import random as rd

# Number of columns in a state; every state is indexed with range(W).
W = 3
# Maximum height of a column; states with a column above H are rejected.
H = 5

def print_state(state):
    """Print a state as a grid of squares followed by its column heights.

    Rows are printed from height H down to 1. In each row, column i is
    drawn filled when state[i] reaches that height and hollow otherwise.
    The last line lists the heights separated by spaces.
    """
    for level in range(H, 0, -1):
        # Every cell is followed by one space, including the last one.
        row = "".join(
            ("■" if state[column] >= level else "□") + " "
            for column in range(W)
        )
        print(row)
    print(" ".join(str(height) for height in state))

def generate_subset_sum(n, k, max_part=None):
    """Return every list of k integers in [0, max_part] that sums to n.

    Despite its name, this enumerates ordered lists: [1, 2] and [2, 1] are
    both returned.

    Args:
        n: Target sum.
        k: Number of entries in each list.
        max_part: Largest value allowed for a single entry. Defaults to the
            module-level column height H.

    Returns:
        A list of fresh lists. The last entry takes increasing values, and
        for each of them the prefixes appear in the order produced by the
        recursive call. The result is empty when no such list exists
        (negative n, n too large, or negative k).
    """
    if max_part is None:
        max_part = H
    if k <= 0:
        # Without this guard the recursion never reaches k == 1 and only
        # stops on RecursionError. The empty list is the single way to write
        # 0 with zero entries.
        return [[]] if k == 0 and n == 0 else []
    if k == 1:
        return [[n]] if 0 <= n <= max_part else []
    result = []
    for last in range(min(n, max_part) + 1):
        for prefix in generate_subset_sum(n - last, k - 1, max_part):
            prefix.append(last)
            result.append(prefix)
    return result

def next(state,col,row):
    """List every state reachable by cutting column `col` down to `row`.

    Column `col` is lowered to height `row`. Of the blocks removed,
    state[col] - row - 1 are redistributed over the other W - 1 columns in
    every possible way, so the total height drops by one. Distributions that
    push a column above H are discarded.

    NOTE(review): this function shadows the builtin `next` for the whole
    module; the name is kept because callers in this file rely on it.

    Args:
        state: List of W column heights (not modified).
        col: Index of the column being cut.
        row: New height of that column.

    Returns:
        A list of new states, one per admissible distribution.
    """
    moved = state[col] - row - 1
    base = state.copy()
    base[col] = row
    # The j-th share of a distribution goes to the j-th column once `col`
    # has been skipped.
    targets = [j + (j >= col) for j in range(W - 1)]
    successors = []
    for shares in generate_subset_sum(moved, W - 1):
        candidate = base.copy()
        for target, share in zip(targets, shares):
            candidate[target] += share
        if all(candidate[i] <= H for i in range(W)):
            successors.append(candidate)
    return successors

def next_L(state,col,row):
    """Cut column `col` down to `row`, placing moved blocks on the lowest columns.

    Column `col` is lowered to height `row`. The state[col] - row - 1 moved
    blocks are then placed one at a time on the currently lowest column
    other than `col`; ties go to the smallest column index.

    Args:
        state: List of W column heights (not modified).
        col: Index of the column being cut.
        row: New height of that column.

    Returns:
        The new state, or None when the move is not possible: a column would
        exceed H, or blocks must be moved but there is no other column.
    """
    new_state = state.copy()
    to_move = state[col] - row - 1
    new_state[col] = row
    # The candidate columns never change during the move, so build the list
    # once instead of once per moved block.
    others = [i for i in range(W) if i != col]
    if to_move > 0 and not others:
        return None
    for _ in range(to_move):
        # min() returns the first minimal item, the same column a stable
        # sort by height would put first.
        lowest = min(others, key=lambda i: new_state[i])
        new_state[lowest] += 1
    if any(height > H for height in new_state):
        return None
    return new_state

# Memo table for bound(), keyed by state_to_str(state).
dp = {}
# Memo table for bound_L(), keyed by state_to_str(state).
dpL = {}

def state_to_str(state):
    """Return a memo key for `state`: its heights sorted and comma-joined.

    Sorting makes the key independent of the order of the columns.
    """
    return ",".join(str(height) for height in sorted(state))

def bound(state):
    """Return the worst-case cost of `state` when moved blocks are placed optimally.

    For every cut (column c, row h) the cheapest successor returned by
    next() is taken, where a successor costs its own bound plus the
    state[c] - h - 1 blocks moved by the cut. The result is the largest of
    these minima over all cuts that have at least one successor. Results are
    memoised in `dp` under the order-independent key from state_to_str().

    Args:
        state: List of W column heights.

    Returns:
        The cost as an int; 0 for a state holding a single block or none.
    """
    key = state_to_str(state)
    if key in dp:
        return dp[key]
    if sum(state) == 1:
        dp[key] = 0
        return 0
    worst = 0
    for c in range(W):
        for h in range(state[c]):
            # Enumerate the successors once and reuse the list.
            successors = next(state, c, h)
            if not successors:
                # No admissible distribution: this cut cannot be played.
                continue
            moved = state[c] - h - 1
            best = min(bound(successor) + moved for successor in successors)
            worst = max(worst, best)
    dp[key] = worst
    return worst

def bound_L(state):
    """Return the worst-case cost of `state` when moved blocks follow next_L().

    For every cut (column c, row h) that next_L() accepts, the cost is the
    bound of the resulting state plus the state[c] - h - 1 blocks moved. The
    result is the largest such cost. Results are memoised in `dpL` under the
    order-independent key from state_to_str().

    Args:
        state: List of W column heights.

    Returns:
        The cost as an int; 0 for a state holding a single block or none.
    """
    key = state_to_str(state)
    if key in dpL:
        return dpL[key]
    if sum(state) == 1:
        dpL[key] = 0
        return 0
    worst = 0
    for c in range(W):
        for h in range(state[c]):
            state2 = next_L(state, c, h)
            # next_L() returns None when the cut would overflow a column.
            if state2 is not None:
                worst = max(worst, bound_L(state2) + state[c] - h - 1)
    dpL[key] = worst
    return worst

def bound_L_F(state):
    """Return the total cost of one fixed sequence of cuts answered by next_L().

    While blocks remain, the first tallest column is cut at the lowest row
    whose overlying blocks still fit in the free space of the other
    columns. The number of blocks moved by each cut is added to the total.

    NOTE(review): assumes every column height is at most H, so that
    next_L() never returns None here - confirm against callers.

    Args:
        state: List of W column heights (not modified; rebound locally).

    Returns:
        The accumulated number of moved blocks.
    """
    total = 0
    while sum(state) != 0:
        # max() keeps the first of several equally tall columns.
        tallest = max(range(W), key=lambda i: state[i])
        top = state[tallest]
        room = sum(H - state[i] for i in range(W) if i != tallest)
        row = max(0, top - room - 1)
        total += top - row - 1
        state = next_L(state, tallest, row)
    return total

def gen_random_state(w = W, h = H):
    """Return a list of `w` random heights, each drawn with rd.randint(0, h - 1)."""
    heights = []
    for _ in range(w):
        heights.append(rd.randint(0, h - 1))
    return heights

def check_L_F_opt(n):
    """Check that bound_L_F() equals bound_L() on every state with n blocks.

    All states from generate_subset_sum(n, W) are evaluated and a progress
    line is printed at regular intervals. If some state disagrees, the
    first one found with the smallest block count is printed.

    Args:
        n: Total number of blocks in the states to test.

    Returns:
        True when no disagreement was found, False otherwise.
    """
    candidates = generate_subset_sum(n, W)
    step = len(candidates) // 100 + 1
    smallest_total = float("inf")
    witness = []
    for index, state in enumerate(candidates):
        if index % step == 0:
            print("Avancement à", index // step, "%")
        if bound_L_F(state) == bound_L(state):
            continue
        total = sum(state)
        if total < smallest_total:
            smallest_total = total
            witness = state.copy()
    if not witness:
        return True
    print("contre-exemple :", witness)
    return False
    
def check_L_opt(n):
    """Check that bound() equals bound_L() on every state with n blocks.

    All states from generate_subset_sum(n, W) are evaluated and a progress
    line is printed at regular intervals. If some state disagrees, the
    first one found with the smallest block count is printed.

    Args:
        n: Total number of blocks in the states to test.

    Returns:
        True when no disagreement was found, False otherwise.
    """
    candidates = generate_subset_sum(n, W)
    step = len(candidates) // 100 + 1
    smallest_total = float("inf")
    witness = []
    for index, state in enumerate(candidates):
        if index % step == 0:
            print("Avancement à", index // step, "%")
        if bound(state) == bound_L(state):
            continue
        total = sum(state)
        if total < smallest_total:
            smallest_total = total
            witness = state.copy()
    if not witness:
        return True
    print("contre-exemple :", witness)
    return False

def custom_cost_L(state):
    """Return the highest heuristic score over all legal cuts of `state`.

    For a cut (column c, row h) answered by next_L(), the score is the
    number of moved blocks (state[c] - h - 1), plus the change in maximum
    column height, plus the change in the number of empty columns.

    Cuts that next_L() rejects (it returns None when a column would exceed
    H) are skipped, as bound_L() does; previously they raised TypeError.

    Args:
        state: List of W column heights.

    Returns:
        The best score, or -inf when the state has no legal cut.
    """
    worst = -float("inf")
    # These two values depend only on the input state.
    peak_before = max(state)
    empty_before = sum(state[i] == 0 for i in range(W))
    for c in range(W):
        for h in range(state[c]):
            state2 = next_L(state, c, h)
            if state2 is None:
                continue
            empty_after = sum(state2[i] == 0 for i in range(W))
            score = state[c] - h - 1 + max(state2) - peak_before
            score += empty_after - empty_before
            if score > worst:
                worst = score
    return worst

def check_cost_L(n):
    """Print each state with n blocks next to its custom_cost_L() score."""
    for candidate in generate_subset_sum(n, W):
        score = custom_cost_L(candidate)
        print(candidate, score)

def check_L_opt_until(N):
    """Run check_L_opt() for n = 3 .. N - 1, stopping at the first failure.

    A confirmation line is printed after each successful n; "aie" is
    printed when a check fails.
    """
    for total in range(3, N):
        if check_L_opt(total):
            print("L est optimal jusque", total)
        else:
            print("aie")
            break

def check_L_F_opt_until(N):
    """Run check_L_F_opt() for n = 3 .. N - 1, stopping at the first failure.

    A confirmation line is printed after each successful n; "aie" is
    printed when a check fails.
    """
    for total in range(3, N):
        if check_L_F_opt(total):
            print("F est optimal contre L jusque", total)
        else:
            print("aie")
            break

def flat(N):
    """Return the state obtained by dropping N blocks, each on the lowest column.

    Starts from W empty columns; ties go to the smallest column index.
    """
    heights = [0] * W
    for _ in range(N):
        # index() finds the first column that has the minimal height.
        heights[heights.index(min(heights))] += 1
    return heights

def max_col(state):
    """Return the index of the tallest of the first W columns (first one on ties)."""
    return max(range(W), key=lambda column: state[column])

def min_col(state):
    """Return the index of the lowest of the first W columns (first one on ties)."""
    return min(range(W), key=lambda column: state[column])

def move_one_L(state):
    """Move one block from the tallest column to the lowest one, in place.

    Both columns are chosen before anything is changed. When they are the
    same column the state ends up unchanged. Returns None.
    """
    source, target = max_col(state), min_col(state)
    state[target] += 1
    state[source] -= 1



