import random as rd

# Number of columns in a game state; most helpers below iterate over range(W).
W = 3


def print_state(state):
    """Print ``state`` as a grid of columns, tallest row first, then the heights.

    Row ``h`` (from ``max(state)`` down to 1) shows ``■`` for every column of
    height >= ``h`` and ``□`` otherwise, each followed by a space. The grid has
    one cell per entry of ``state``, so it always matches the heights line
    printed below it. An empty or all-zero state prints only that last line.
    """
    H = max(state, default=0)
    for h in range(H, 0, -1):
        for v in state:
            print("■" if v >= h else "□", end=" ")
        print()
    print(" ".join(map(str, state)))

def generate_subset_sum(n,k):
    """Return every ordered ``k``-tuple of non-negative integers summing to ``n``.

    Despite the name, these are the weak compositions of ``n`` into ``k``
    parts. Each one is a fresh list of length ``k``. Results are ordered by
    the value of the last part (0, 1, ..., n), recursively.

    Edge cases: ``n < 0`` or ``k < 0`` gives ``[]``. ``k == 0`` gives
    ``[[]]`` (the empty composition) when ``n == 0`` and ``[]`` otherwise.
    """
    if n < 0 or k < 0:
        return []
    if k == 0:
        # Only the empty tuple has zero parts, and it sums to 0.
        return [[]] if n == 0 else []
    if k == 1:
        return [[n]]
    result = []
    for last in range(n + 1):
        # Sub-results are freshly built lists, so appending in place is safe.
        for prefix in generate_subset_sum(n - last, k - 1):
            prefix.append(last)
            result.append(prefix)
    return result

def next(state,col,row):
    """Return every state reachable by cutting column ``col`` down to ``row``.

    Column ``col`` is set to ``row``. The ``state[col] - row - 1`` displaced
    blocks are then added to the other ``W - 1`` columns in every possible way,
    giving one result per composition from ``generate_subset_sum``. The input
    list is not modified, and each result is a fresh list.

    NOTE: this name shadows the builtin ``next`` for the whole module.
    """
    displaced = state[col] - row - 1
    cut = state.copy()
    cut[col] = row
    # Part index i (0..W-2) maps to a real column, skipping ``col``.
    targets = [i + (i >= col) for i in range(W - 1)]
    successors = []
    for split in generate_subset_sum(displaced, W - 1):
        successor = cut.copy()
        for column, amount in zip(targets, split):
            successor[column] += amount
        successors.append(successor)
    return successors

def exist_free(state,col):
    """Return True if some column other than ``col`` has a height accepted by ``is_free``."""
    return any(is_free(state[i]) for i in range(W) if i != col)

def is_free(v):
    """Return True when ``v`` is at most 1 or odd.

    NOTE(review): used on column heights by ``exist_free``/``next_L``; the
    game meaning of "free" is not stated here.
    """
    return not (v > 1 and v % 2 != 1)

def next_L(state,col,row):
    """Return the single successor of cutting column ``col`` to ``row``, placing greedily.

    Column ``col`` is set to ``row``. The ``state[col] - row - 1`` displaced
    blocks are then added one at a time. Each block goes onto the lowest column
    other than ``col`` whose height passes ``is_free``, taking the highest index
    on ties. If no other column passes ``is_free``, it goes onto
    ``min_col(result, col)`` instead. The input list is not modified.
    """
    result = state.copy()
    remaining = state[col] - row - 1
    result[col] = row
    for _ in range(remaining):
        candidates = [i for i in range(W) if i != col and is_free(result[i])]
        if not candidates:
            result[min_col(result, col)] += 1
            continue
        lowest = min(result[i] for i in candidates)
        # Highest index among the candidates tied at the lowest height.
        target = max(i for i in candidates if result[i] == lowest)
        result[target] += 1
    return result

# Memo tables for bound, bound_L and bound_F respectively,
# keyed by state_to_str(state).
dp = {}
dpL = {}
dpF = {}

def state_to_str(state):
    """Return a canonical key for ``state``: heights sorted ascending, comma-joined.

    Column order is ignored, so permutations of the same heights share a key.
    """
    return ",".join(str(height) for height in sorted(state))

def bound(state):
    """Memoised minimax value of ``state`` when any redistribution is allowed.

    For every column ``c`` and cut row ``h`` with ``state[c] // 2 <= h < state[c]``,
    ``next(state, c, h)`` lists every way to spread the ``state[c] - h - 1``
    displaced blocks over the other columns. The value is the maximum over
    (c, h) of the minimum over those successors of
    ``state[c] - h - 1 + bound(successor)``. States whose heights sum to at
    most 1 are worth 0.

    Results are cached in ``dp`` under ``state_to_str(state)``, so permuted
    states share an entry.
    """
    car = state_to_str(state)
    if car in dp:
        return dp[car]
    if sum(state) <= 1:
        dp[car] = 0
        return 0
    worst = 0
    for c in range(W):
        for h in range(state[c] // 2, state[c]):
            moved = state[c] - h - 1
            # Enumerate the successors once and reuse them.
            nexts = next(state, c, h)
            if nexts:
                best = min(bound(state2) for state2 in nexts) + moved
                worst = max(worst, best)
    dp[car] = worst
    return worst

def bound_F(state):
    """Memoised value of ``state`` when the cut is fixed to the tallest column at half height.

    The cut is column ``c = max_col(state)`` at row ``state[c] // 2``. The value
    is the minimum over all successors ``next(state, c, h)`` of
    ``state[c] - h - 1 + bound_F(successor)``. States whose heights sum to at
    most 1, including the all-zero state, are worth 0. Results are cached in
    ``dpF`` under ``state_to_str(state)``.
    """
    car = state_to_str(state)
    if car in dpF:
        return dpF[car]
    # <= 1 rather than == 1: on an all-zero state the forced cut would move
    # -1 blocks, which is meaningless and used to leave the value at inf.
    if sum(state) <= 1:
        dpF[car] = 0
        return 0
    c = max_col(state)
    h = state[c] // 2
    moved = state[c] - h - 1
    best = float("inf")
    # Enumerate the successors once and reuse them.
    nexts = next(state, c, h)
    if nexts:
        best = min(bound_F(state2) for state2 in nexts) + moved
    dpF[car] = best
    return best

def bound_L(state):
    """Memoised worst-case value of ``state`` when blocks are placed by ``next_L``.

    For every column ``c`` and cut row ``h`` with ``state[c] // 2 <= h < state[c]``,
    the single greedy successor ``next_L(state, c, h)`` is taken. The value is
    the maximum over (c, h) of ``state[c] - h - 1 + bound_L(successor)``.
    States whose heights sum to at most 1 are worth 0. Results are cached in
    ``dpL`` under ``state_to_str(state)``.
    """
    car = state_to_str(state)
    if car in dpL:
        return dpL[car]
    if sum(state) <= 1:
        dpL[car] = 0
        return 0
    worst = 0
    for c in range(W):
        for h in range(state[c] // 2, state[c]):
            # next_L always returns a list, so no None check is needed.
            cost = bound_L(next_L(state, c, h)) + state[c] - h - 1
            worst = max(worst, cost)
    dpL[car] = worst
    return worst

def bound_L_F(state):
    """Return the total cost of the greedy line of play starting from ``state``.

    Each step cuts the tallest column (lowest index on ties) down to half its
    height (``peak // 2``) and places the displaced blocks with ``next_L``. The
    step adds ``state[c] - peak // 2 - 1`` to the cost. This repeats until the
    heights sum to 0. Not memoised; the caller's list is not modified.
    """
    total = 0
    while sum(state) != 0:
        heights = [state[i] for i in range(W)]
        peak = max([0] + heights)
        # First tallest column; column 0 when no column is taller than 0.
        target = heights.index(peak) if peak > 0 else 0
        keep = peak // 2
        total += state[target] - keep - 1
        state = next_L(state, target, keep)
    return total

def gen_random_state(h, w = W):
    """Return ``w`` heights, each drawn uniformly from ``0..h-1`` with ``random.randint``.

    The default for ``w`` is the value of ``W`` when this function was defined.
    """
    heights = []
    for _ in range(w):
        heights.append(rd.randint(0, h - 1))
    return heights

def _check_agreement(n, first, second):
    """Compare two bound functions on every state of total height ``n``.

    Iterates over ``generate_subset_sum(n, W)`` and prints a progress line
    roughly every 1% of the states. It evaluates ``first(state)`` and then
    ``second(state)`` for every state. Among states where they differ, it keeps
    the one with the smallest total height, taking the earliest on ties. Every
    generated state sums to ``n``, so this is in practice the first mismatch.

    Returns True when there is no mismatch. Otherwise it prints the mismatch as
    a counterexample and returns False.
    """
    best = float("inf")
    best_s = None
    states = generate_subset_sum(n,W)
    # Progress-report stride ("one hundredth" of the states).
    centieme = 1 + len(states)//100
    for i,state in enumerate(states):
        if i%centieme == 0:
            print("Avancement à",i//centieme,"%")
        k = sum(state)
        # Both functions are evaluated for every state, even after a mismatch.
        if first(state) != second(state) and k < best:
            best = k
            best_s = state.copy()
    if best_s is None:
        return True
    print("contre-exemple :",best_s)
    return False

def check_L_F_opt(n):
    """Check that ``bound_L_F`` equals ``bound_L`` on every state of total height ``n``.

    Returns True if they agree everywhere. Otherwise prints a counterexample
    and returns False. Progress is printed along the way.
    """
    return _check_agreement(n, bound_L_F, bound_L)

def check_L_opt(n):
    """Check that ``bound`` equals ``bound_L`` on every state of total height ``n``.

    Returns True if they agree everywhere. Otherwise prints a counterexample
    and returns False. Progress is printed along the way.
    """
    return _check_agreement(n, bound, bound_L)

def custom_cost_L(state):
    """Return the largest heuristic score over all cuts of ``state``.

    Every column ``c`` and every row ``h`` in ``range(state[c])`` is tried.
    Note this covers all rows, not only the upper half used by the bounds. The
    successor ``next_L(state, c, h)`` scores the sum of:

    * blocks moved (``state[c] - h - 1``);
    * the change in the tallest height;
    * the change in the number of empty columns among the first ``W``.

    Returns ``-inf`` when no cut exists.
    """
    def empties(cols):
        return sum(cols[i] == 0 for i in range(W))

    base_empty = empties(state)
    scores = []
    for c in range(W):
        for h in range(state[c]):
            after = next_L(state, c, h)
            moved = state[c] - h - 1
            scores.append(moved + max(after) - max(state) + empties(after) - base_empty)
    return max(scores, default=-float("inf"))

def check_cost_L(n):
    """Print every state of total height ``n`` next to its ``custom_cost_L`` score."""
    for state in generate_subset_sum(n, W):
        score = custom_cost_L(state)
        print(state, score)

def check_L_opt_until(N):
    """Run ``check_L_opt`` for n = 3, ..., N-1, stopping at the first failure.

    Prints "aie" on failure and a confirmation line after each n that passes.
    """
    for n in range(3, N):
        passed = check_L_opt(n)
        if not passed:
            print("aie")
            return
        print("L est optimal jusque",n)

def check_L_F_opt_until(N):
    """Run ``check_L_F_opt`` for n = 3, ..., N-1, stopping at the first failure.

    Prints "aie" on failure and a confirmation line after each n that passes.
    """
    for n in range(3, N):
        passed = check_L_F_opt(n)
        if not passed:
            print("aie")
            return
        print("F est optimal contre L jusque",n)

def flat(N):
    """Return ``W`` column heights holding ``N`` blocks spread as evenly as possible.

    Blocks are dealt one at a time onto the currently lowest column, taking
    the lowest index on ties.
    """
    heights = [0] * W
    for _ in range(N):
        # min() keeps the first minimum, so ties go to the lowest index.
        heights[min(range(W), key=heights.__getitem__)] += 1
    return heights

def max_col(state):
    """Return the index of the tallest of the first ``W`` columns (lowest index on ties)."""
    # max() keeps the first maximal item, matching a strict ">" scan.
    return max(range(W), key=state.__getitem__)

def min_col(state,forbid=None):
    """Return the index of the lowest of the first ``W`` columns, skipping ``forbid``.

    Ties go to the lowest index. ``forbid`` defaults to None, meaning no
    column is excluded; this makes single-argument calls valid. If no column
    qualifies (e.g. ``W == 1`` and ``forbid == 0``), 0 is returned.
    """
    c = 0
    h = float("inf")
    for i in range(W):
        if state[i] < h and i != forbid:
            h = state[i]
            c = i
    return c

def move_one_L(state):
    """Move one block, in place, from the tallest column to the lowest other column.

    The source is ``max_col(state)``. The destination is the lowest column
    other than the source (lowest index on ties), mirroring how ``next_L``
    passes the cut column to ``min_col``. Mutates ``state`` and returns None.
    """
    c1 = max_col(state)
    # Exclude the source so the block always lands on a different column.
    c2 = min_col(state, c1)
    state[c1] -= 1
    state[c2] += 1

def pot(s):
    """Return the sum of ``min(s[i] - j - 1, j)`` over the first ``W`` columns ``i``.

    The inner sum runs over every ``j`` in ``range(s[i])``. For a block at
    index ``j`` in a column of height ``s[i]``, this is the smaller of the
    counts of blocks on either side of it.
    """
    return sum(min(s[i] - j - 1, j) for i in range(W) for j in range(s[i]))

# Sanity check: for 10 random states sorted so that a >= b >= c, with
# a >= b + 2, move one block from the tallest column onto the second one.
# Report any case where bound_L changes by more than 1.
# NOTE(review): this runs whenever the module is imported; consider moving it
# under an ``if __name__ == "__main__":`` guard.
for _ in range(10):
    s = gen_random_state(30)
    s.sort(reverse=True)
    # Unpacking also enforces exactly three columns (ValueError otherwise).
    a,b,c = s
    if a >= b+2:
        t = [a-1,b+1,c]
        d,e = bound_L(s),bound_L(t)
        if abs(d-e) > 1:
            print("Erreur dans bound_L pour",s,"et",t)