import random as rd

# Number of columns: the functions below index states as state[0..W-1].
W = 4
# Maximum allowed column height; also the number of rows print_state draws.
H = 10

def print_state(state):
    """Draw ``state`` as a grid of squares, then echo the raw heights.

    One text row is printed per level, from ``H`` down to 1.  In each row,
    column ``i`` shows a filled square when ``state[i]`` reaches that level
    and an empty square otherwise.  The last line lists the heights
    separated by single spaces.
    """
    for level in range(H, 0, -1):
        cells = ["■" if state[column] >= level else "□" for column in range(W)]
        # Every cell is followed by one space, including the last one.
        print("".join(cell + " " for cell in cells))
    print(" ".join(str(height) for height in state))

def generate_subset_sum(n,k):
    """Enumerate the ways to write ``n`` as an ordered sum of ``k`` parts.

    Every part is at most ``H``, and the parts picked by the recursion are
    non-negative.  Results are grouped by their last part in increasing
    order.  An empty list means no such decomposition exists.

    ``k`` is expected to be at least 1; for ``k == 1`` the single part is
    ``n`` itself, checked only against ``H``.
    """
    if k == 1:
        return [[n]] if n <= H else []
    # The last part ranges over 0..min(n, H); the first k-1 parts sum to the rest.
    upper = min(n, H) + 1
    return [
        head + [last]
        for last in range(upper)
        for head in generate_subset_sum(n - last, k - 1)
    ]

def next(state,col,row,next_col):
    """List every state reachable by cutting column ``col`` down to ``row``.

    Column ``col`` is set to height ``row``.  Of the ``state[col] - row``
    units taken off it, one is dropped and the remaining
    ``state[col] - row - 1`` are spread over the other ``W - 1`` columns in
    every possible way.  Distributions that push any column above ``H`` are
    discarded.  ``state`` itself is not modified.

    ``next_col`` is accepted but not used.  Note that this function shadows
    the builtin ``next`` inside this module.
    """
    surplus = state[col] - row - 1
    base = state.copy()
    base[col] = row
    successors = []
    for shares in generate_subset_sum(surplus, W - 1):
        candidate = base.copy()
        for slot, share in enumerate(shares):
            # Slots map onto the columns in order, skipping ``col``.
            target = slot if slot < col else slot + 1
            candidate[target] += share
        if all(candidate[column] <= H for column in range(W)):
            successors.append(candidate)
    return successors

def next_L(state,col,row,next_col):
    """Cut column ``col`` down to ``row`` and level out the surplus greedily.

    The ``state[col] - row - 1`` surplus units are placed one at a time on
    the currently lowest column, as chosen by ``min_col`` with ``col`` and
    ``next_col`` excluded.

    Returns the new state, or ``None`` when any column ends up higher than
    ``H``.  ``state`` itself is not modified.
    """
    remaining = state[col] - row - 1
    levelled = state.copy()
    levelled[col] = row
    while remaining > 0:
        lowest = min_col(levelled, [col, next_col])
        levelled[lowest] += 1
        remaining -= 1
    if all(height <= H for height in levelled):
        return levelled
    return None

# Memo tables keyed by state_to_str(state, col, row), one per bound function.
dp = {}   # used by bound
dpL = {}  # used by bound_L
dpF = {}  # used by bound_F

def state_to_str(state, col, row):
    """Build the memoisation key for ``(state, col, row)``.

    ``col`` is an index into ``state``, so the key must keep the height of
    that column apart from the others.  Sorting the whole state would make
    ``[2, 1, 0, 0]`` and ``[1, 2, 0, 0]`` with ``col == 0`` collide even
    though ``state[col]`` differs.  The key is therefore the height of the
    target column, then ``row``, then the sorted heights of the remaining
    columns, so states differing only in the order of the non-target
    columns still share one entry.

    When ``col`` is not a valid index (the ``-1`` "no target yet" marker),
    the key is the fully sorted state followed by ``col`` and ``row``.  The
    two forms have different lengths and cannot collide.

    Returns the key as a comma-separated string.
    """
    if 0 <= col < len(state):
        others = sorted(list(state[:col]) + list(state[col + 1:]))
        parts = [state[col], row] + others
    else:
        parts = sorted(state) + [col, row]
    return ",".join(map(str, parts))

def bound(state, col = -1, row = -1):
    """Return the worst-case cost from ``state`` under the best responses.

    ``(col, row)`` is the current target.  Handling it costs
    ``state[col] - row - 1`` and leads to one of the states returned by
    ``next``.  For every following target ``(next_col, next_row)``, taken
    from the other columns of ``state``, the cheapest continuation over all
    those states is kept; the result is the maximum of these minima.  A
    ``next_col`` with no reachable state is skipped.

    With ``col == -1`` the maximum over every ``(col, row)`` of ``state`` is
    returned; that result is not cached.  A state whose heights sum to 1
    costs 0.

    Results are memoised in ``dp``.  Every freshly computed entry is also
    printed as a trace line.
    """
    car = state_to_str(state, col, row)
    if car in dp:
        return dp[car]
    if sum(state) == 1:
        dp[car] = 0
        return 0
    worst = 0
    if col == -1:
        for first_col in range(W):
            for first_row in range(state[first_col]):
                worst = max(worst, bound(state, first_col, first_row))
        return worst
    moved = state[col] - row - 1
    for next_col in range(W):
        # Columns without any unit offer no next target.
        if next_col == col or state[next_col] <= 0:
            continue
        # The successor states do not depend on next_row: compute them once.
        nexts = next(state, col, row, next_col)
        if not nexts:
            continue
        for next_row in range(state[next_col]):
            best = float("inf")
            for state2 in nexts:
                best = min(best, bound(state2, next_col, next_row) + moved)
            worst = max(worst, best)
    dp[car] = worst
    print(state, col, row, " : ",worst)
    return worst

def bound_L(state, col = -1, row = -1):
    """Return the worst-case cost from ``state`` when moves follow ``next_L``.

    ``(col, row)`` is the current target.  Handling it costs
    ``state[col] - row - 1`` and leads to the single state produced by
    ``next_L``.  The result is the maximum over every following target
    ``(next_col, next_row)`` taken from the other columns of ``state``.

    ``next_L`` returns ``None`` when its move would push a column above
    ``H``.  In that case the strategy has no legal answer and the bound is
    ``float("inf")``.

    With ``col == -1`` the maximum over every ``(col, row)`` of ``state`` is
    returned; that result is not cached.  A state whose heights sum to 1
    costs 0.  Results are memoised in ``dpL``.
    """
    car = state_to_str(state, col, row)
    if car in dpL:
        return dpL[car]
    if sum(state) == 1:
        dpL[car] = 0
        return 0
    worst = 0
    if col == -1:
        for first_col in range(W):
            for first_row in range(state[first_col]):
                worst = max(worst, bound_L(state, first_col, first_row))
        return worst
    moved = state[col] - row - 1
    for next_col in range(W):
        # Columns without any unit offer no next target.
        if next_col == col or state[next_col] <= 0:
            continue
        # The successor state does not depend on next_row: compute it once.
        state2 = next_L(state, col, row, next_col)
        if state2 is None:
            # No legal move for this target: nothing can exceed infinity.
            worst = float("inf")
            break
        for next_row in range(state[next_col]):
            worst = max(worst, bound_L(state2, next_col, next_row) + moved)
    dpL[car] = worst
    return worst

def bound_F(state, col = -1, row = -1):
    """Return the cheapest cost from ``state`` when targets are forced.

    The target after ``(col, row)`` is always row 0 of the tallest column
    other than ``col``, as chosen by ``max_col``.  The result is the minimum,
    over every state returned by ``next``, of the recursive cost plus
    ``state[col] - row - 1``.  It is ``float("inf")`` when ``next`` returns
    no state.

    With ``col == -1`` the first target is row 0 of the tallest column.  A
    state whose heights sum to 1 costs 0.  Results are memoised in ``dpF``.
    """
    car = state_to_str(state, col, row)
    if car in dpF:
        return dpF[car]
    if sum(state) == 1:
        dpF[car] = 0
        return 0
    if col == -1:
        return bound_F(state, max_col(state), 0)
    next_col = max_col(state, [col])
    successors = next(state, col, row, next_col)
    cheapest = min(
        (bound_F(state2, next_col, 0) + state[col] - row - 1
         for state2 in successors),
        default=float("inf"),
    )
    dpF[car] = cheapest
    return cheapest

def bound_L_F(state, col = -1, row = -1):
    """Return the cost from ``state`` with forced targets and ``next_L`` moves.

    The target after ``(col, row)`` is always row 0 of the tallest column
    other than ``col``, as chosen by ``max_col``.  Each step applies
    ``next_L`` and costs ``state[col] - row - 1``.  With ``col == -1`` the
    first target is row 0 of the tallest column.  A state whose heights sum
    to 1 costs 0.  Nothing is memoised.

    ``next_L`` returns ``None`` when its move would push a column above
    ``H``.  The result is then ``float("inf")``, the same value ``bound_F``
    yields when no move is possible.
    """
    if sum(state) == 1:
        return 0
    if col == -1:
        return bound_L_F(state, max_col(state), 0)
    next_col = max_col(state, [col])
    state2 = next_L(state, col, row, next_col)
    if state2 is None:
        return float("inf")
    # NOTE(review): if every other column is empty, next_col is an empty column
    # and the following step contributes -1 -- confirm this is intended.
    return bound_L_F(state2, next_col, 0) + state[col] - row - 1


def gen_random_state(w = W, h = H):
    """Return a list of ``w`` random heights, each drawn from ``0..h-1``."""
    heights = []
    for _ in range(w):
        heights.append(rd.randint(0, h - 1))
    return heights


def max_col(state, forbid = None):
    """Return the index of the tallest column of ``state`` not in ``forbid``.

    Ties go to the lowest index.  Index 0 is returned when no column
    qualifies: ``state`` is empty, every index is forbidden, or no allowed
    height is above -1.

    ``forbid`` is any container of column indices; ``None`` means nothing is
    excluded.  All of ``state`` is scanned, whatever its length.
    """
    excluded = () if forbid is None else forbid
    best_index = 0
    best_height = -1
    for index, height in enumerate(state):
        # Strict comparison keeps the first of several equally tall columns.
        if height > best_height and index not in excluded:
            best_height = height
            best_index = index
    return best_index

def min_col(state, forbid = None):
    """Return the index of the lowest column of ``state`` not in ``forbid``.

    Ties go to the lowest index.  Index 0 is returned when no column
    qualifies: ``state`` is empty or every index is forbidden.

    ``forbid`` is any container of column indices; ``None`` means nothing is
    excluded.  All of ``state`` is scanned, whatever its length.
    """
    excluded = () if forbid is None else forbid
    best_index = 0
    best_height = float("inf")
    for index, height in enumerate(state):
        # Strict comparison keeps the first of several equally low columns.
        if height < best_height and index not in excluded:
            best_height = height
            best_index = index
    return best_index

# Example run executed at module level: evaluate bound() for a state with
# three units in column 0 and none elsewhere, with col=0 and row=0, and print
# the returned value (bound() also prints its own trace lines).
s = [3,0,0,0]
print(bound(s,0,0))