import random as rd
import itertools
import time
import matplotlib.pyplot as plt

# Board size: states are lists of W column heights; next/next_L reject any
# state in which a column holds more than H blocks.
W, H = 3, 200

def print_state(state):
    """Print *state* as a text grid, highest level (H) first.

    A cell shows ■ when its column reaches that level and □ otherwise; every
    cell is followed by a single space. A last line lists the heights.
    """
    for level in reversed(range(1, H + 1)):
        cells = ("■" if state[column] >= level else "□" for column in range(W))
        print("".join(cell + " " for cell in cells))
    print(" ".join(str(height) for height in state))

def generate_subset_sum(n,k,h=None):
    """Return every list of *k* integers, each in ``[0, h]``, that sums to *n*.

    Lists are ordered by their last entry, ascending, e.g.
    ``generate_subset_sum(2, 2, h=5) == [[2, 0], [1, 1], [0, 2]]``.

    Args:
        n: target sum.
        k: list length. ``k <= 0`` yields ``[[]]`` when ``n == 0`` and ``[]``
            otherwise.
        h: per-entry upper bound; defaults to the module-level ``H`` (read at
            call time).
    """
    if h is None:
        h = H
    if k <= 0:
        # Guard: without it, k == 0 (e.g. next() on a 1-column board) recursed forever.
        return [[]] if n == 0 else []
    if k == 1:
        # A single entry must itself be a valid height.
        return [[n]] if 0 <= n <= h else []
    L = []
    for i in range(min(n, h) + 1):
        for x in generate_subset_sum(n - i, k - 1, h):
            x.append(i)  # x is a fresh list from the recursive call
            L.append(x)
    return L

def next(state,col,row):
    """Return every state reachable by cutting column *col* down to height *row*.

    The cut column keeps *row* blocks, one block disappears, and the
    ``state[col] - row - 1`` blocks that sat above it are spread over the
    other W - 1 columns in every way produced by generate_subset_sum.
    Candidates in which any of the first W columns exceeds H are dropped.

    NOTE(review): this name shadows the built-in ``next`` module-wide.
    """
    overflow = state[col] - row - 1
    base = state.copy()
    base[col] = row
    # Receiving columns in index order, skipping `col` itself.
    targets = [k + (k >= col) for k in range(W - 1)]
    reachable = []
    for shares in generate_subset_sum(overflow, W - 1):
        candidate = base.copy()
        for target, extra in zip(targets, shares):
            candidate[target] += extra
        if all(candidate[k] <= H for k in range(W)):
            reachable.append(candidate)
    return reachable

def next_smart(state,col,row):
    """Like ``next``, but skip redistributions that only permute equal-height columns.

    Column *col* is cut to height *row*. The ``state[col] - row - 1`` blocks
    above the cut are spread over the other columns. Columns that currently
    share a height are interchangeable, so inside each such group only
    non-increasing shares (decreasing_partition) are generated. Blocks are
    split between groups using every composition (all_partition).
    Candidates in which any of the first W columns exceeds H are dropped.

    Returns a list of fresh state lists. Two entries can still be equal as
    multisets of heights when they come from different height groups.
    """
    new_state = state.copy()
    to_move = state[col] - row - 1
    new_state[col] = row
    # Other columns grouped by their current height, in order of first appearance.
    groups = {}
    for i in range(W):
        if i != col:
            groups.setdefault(state[i], []).append(i)
    group_columns = list(groups.values())
    if not group_columns:
        # No other column can receive blocks (single-column board).
        return [new_state] if to_move == 0 else []
    states_list = []
    for amounts in all_partition(to_move, len(group_columns)):
        # Candidate shares for each group; itertools.product(*...) keeps the
        # combinations flat (chaining product() pairwise nested the tuples).
        per_group = [
            decreasing_partition(amount, len(cols))
            for amount, cols in zip(amounts, group_columns)
        ]
        for shares_per_group in itertools.product(*per_group):
            candidate = new_state.copy()
            for cols, shares in zip(group_columns, shares_per_group):
                for c, extra in zip(cols, shares):
                    candidate[c] += extra
            if all(candidate[i] <= H for i in range(W)):
                states_list.append(candidate)
    return states_list

def next_L(state,col,row):
    """Cut column *col* to height *row* and pour the overflow greedily.

    The ``state[col] - row - 1`` blocks above the cut are added one at a time,
    each to the currently lowest other column (lowest index on ties). Returns
    the new state as a fresh list, or None if any column ends up above H.
    """
    result = state.copy()
    remaining = state[col] - row - 1
    result[col] = row
    others = [c for c in range(W) if c != col]
    while remaining > 0:
        # Stable sort, so ties go to the lowest column index.
        lowest = sorted(others, key=result.__getitem__)[0]
        result[lowest] += 1
        remaining -= 1
    if any(height > H for height in result):
        return None
    return result

# NOTE(review): neither table is read or written anywhere else in this file;
# they look like leftover memo tables — confirm before removing.
dp, dpL = {}, {}

def state_to_str(state):
    """Return an order-independent cache key: heights sorted ascending, comma-joined."""
    return ",".join(str(height) for height in sorted(state))

def bound_L_F(state):
    """Count the blocks relocated when every cut is made on the tallest column.

    Each round does the following:

    - Pick the tallest column, lowest index on ties.
    - Cut it at the lowest row (never below 0) whose overflow fits into the
      other columns' free space, i.e. H minus their heights.
    - Pour the overflow with next_L and add the number of moved blocks to the
      cost.

    Each round removes one block, so the loop stops when the board is empty.
    The caller's list is not modified (next_L returns copies).
    """
    cost = 0
    while sum(state) != 0:
        tallest = max(range(W), key=lambda i: state[i])
        room = sum(H - state[i] for i in range(W) if i != tallest)
        cut = max(0, state[tallest] - room - 1)
        cost += state[tallest] - cut - 1
        state = next_L(state, tallest, cut)
    return cost

def gen_random_state(w = W, h = H):
    """Return *w* independent heights drawn uniformly from ``0 .. h - 1``."""
    heights = []
    for _ in range(w):
        heights.append(rd.randint(0, h - 1))
    return heights

def max_col(state, forbid = -1):
    """Return the index of the tallest column of *state*, never returning *forbid*.

    Ties go to the lowest index. The default ``forbid=-1`` allows every column.

    Raises:
        ValueError: if no column other than *forbid* exists.
    """
    # The old loop started from column 0 unconditionally, so forbid=0 could
    # still return column 0; only consider allowed columns from the start.
    candidates = [i for i in range(len(state)) if i != forbid]
    if not candidates:
        raise ValueError("no column available besides the forbidden one")
    # max() returns the first maximal item, matching the strict '>' scan.
    return max(candidates, key=state.__getitem__)

def min_col(state, forbid = -1):
    """Return the index of the shortest column of *state*, never returning *forbid*.

    Ties go to the lowest index. The default ``forbid=-1`` allows every column.

    Raises:
        ValueError: if no column other than *forbid* exists.
    """
    # The old loop started from column 0 unconditionally, so forbid=0 could
    # still return column 0; only consider allowed columns from the start.
    candidates = [i for i in range(len(state)) if i != forbid]
    if not candidates:
        raise ValueError("no column available besides the forbidden one")
    # min() returns the first minimal item, matching the strict '<' scan.
    return min(candidates, key=state.__getitem__)

def move_one_L(state, forbid = -1):
    """Move one block, in place, from the tallest column to the shortest one.

    Both columns are chosen by max_col / min_col with the same *forbid*
    argument. When both pick the same column the state is left unchanged.
    Returns None.
    """
    source, target = max_col(state, forbid), min_col(state, forbid)
    state[source] -= 1
    state[target] += 1

dp_sum = dict()  # memo for sum_cost_L, keyed by state_to_str(state)

# fact[k] == k! for k = 0 .. W*H. W*H is the largest number of blocks a board
# can hold, so fact[N] (as used by plot_) is valid for every reachable total N.
# The previous bound stopped at (W*H - 1)!.
fact = [1]
for i in range(1, W * H + 1):
    fact.append(i * fact[-1])

def potential(state, exp = 2):
    """Return the sum of every height raised to *exp* (sum of squares by default)."""
    return sum(height ** exp for height in state)

def sum_cost_L(state):
    """Return the memoised aggregate cost of the greedy next_L strategy from *state*.

    ``None`` (an invalid next_L successor) gives 0 and an empty board gives 1.
    Otherwise the value is the sum of ``sum_cost_L(next_L(state, i, j))`` over
    every column ``i`` and cut row ``j < state[i]``, plus
    ``fact[n - 1] * potential(state)`` where ``n = sum(state)``. Results are
    cached in ``dp_sum``. The caller's list is not modified.

    NOTE(review): presumably this totals the cost over all move orders —
    confirm the intended interpretation.
    """
    if state is None:
        return 0
    if sum(state) == 0:
        return 1
    # Work on a sorted copy. The memo key already ignores column order, and
    # the former in-place sort silently reordered the caller's list.
    state = sorted(state, reverse=True)
    key = state_to_str(state)
    if key in dp_sum:
        return dp_sum[key]
    total = 0
    for i in range(W):
        for j in range(state[i]):
            new_state = next_L(state, i, j)
            if new_state is not None:
                total += sum_cost_L(new_state)
    dp_sum[key] = total + fact[sum(state) - 1] * potential(state)
    return dp_sum[key]

dp_opt = dict()  # memo for sum_opt, keyed by state_to_str(state)

def sum_opt(state):
    """Return the memoised optimal-choice cost total for *state* (compare sum_cost_L).

    For every first move (column ``i``, cut row ``j``), this adds the cheapest
    continuation over the redistributions produced by ``next(state, i, j)``:
    ``sum_opt(s2) + moved * fact[sum(s2)]``, where ``moved`` is the number of
    blocks relocated by the move. Moves with no valid redistribution add
    nothing. ``None`` and single-block states give 0. Results are cached in
    ``dp_opt``. The caller's list is not modified.
    """
    if state is None:
        return 0
    if sum(state) == 1:
        return 0
    # Work on a sorted copy. The memo key already ignores column order, and
    # the former in-place sort silently reordered the caller's list.
    state = sorted(state)
    key = state_to_str(state)
    if key in dp_opt:
        return dp_opt[key]
    total = 0
    for i in range(W):
        for j in range(state[i]):
            moved = state[i] - j - 1
            candidates = next(state, i, j)
            if not candidates:
                continue
            # Every sum_opt value is >= 0, so the minimum is always finite.
            total += min(sum_opt(s2) + moved * fact[sum(s2)] for s2 in candidates)
    dp_opt[key] = total
    return total

def check():
    """Report the smallest cost ratio seen after one balancing move on random states.

    Draws 1000 random states ``s``. For each, it applies move_one_L to a copy
    ``t``, using a random forbidden column, and tracks the minimum of
    ``sum_cost_L(s) / sum_cost_L(t)``, starting from 1. Progress is printed
    every 10 samples. At the end it prints the state that achieved the minimum
    (``[]`` if none went below 1) and that ratio.
    """
    worst = 1
    mems = []
    for i in range(1000):
        if i % 10 == 0:
            print(i // 10, "%")
        s = gen_random_state()
        # Separate name: this used to overwrite the loop counter `i`.
        forbid = rd.randint(0, W - 1)
        t = s.copy()
        move_one_L(t, forbid)
        a, b = sum_cost_L(s), sum_cost_L(t)
        ratio = a / b
        # Compare against the running minimum; the old `< 1` test kept the
        # last ratio below 1 instead of the worst one.
        if ratio < worst:
            worst = ratio
            mems = s
    print(mems)
    print(worst)

def check_L_opt():
    """Search random states for one where sum_opt is positive but below sum_cost_L.

    Samples 1,000,000 states and prints progress every 1% of the run. For each
    state where ``0 < sum_opt(s) < sum_cost_L(s)``, it prints that state
    followed by 0.
    """
    trials, steps = 1000000, 100
    chunk = trials // steps
    worst, mems = 1, []
    for k in range(trials):
        if k % chunk == 0:
            print((100 * k // chunk) / steps, "%")
        s = gen_random_state()
        greedy_cost = sum_cost_L(s)
        optimal_cost = sum_opt(s)
        if 0 < optimal_cost < greedy_cost:
            worst, mems = 0, s
            print(mems)
            print(worst)

def all_partition(n,k):
    """Return every list of *k* non-negative integers summing to *n* (compositions).

    Order matters and the last entry varies slowest, e.g.
    ``all_partition(2, 2) == [[2, 0], [1, 1], [0, 2]]``.

    Edge cases:

    - Negative *n* yields ``[]``.
    - ``k <= 0`` yields ``[[]]`` when ``n == 0`` and ``[]`` otherwise.
    """
    if n < 0:
        return []
    if k <= 0:
        # Guard: without it, k == 0 with n > 0 recursed forever.
        return [[]] if n == 0 else []
    if n == 0:
        return [[0] * k]
    if k == 1:
        return [[n]]
    l = []
    for i in range(n + 1):
        for c in all_partition(n - i, k - 1):
            c.append(i)  # c is a fresh list from the recursive call
            l.append(c)
    return l

def decreasing_partition(n,k,end = None):
    """Return every non-increasing list of *k* non-negative integers summing to *n*.

    Each partition of *n* into at most *k* parts appears once, padded with
    zeros. For example, ``decreasing_partition(4, 3) ==
    [[4, 0, 0], [3, 1, 0], [2, 2, 0], [2, 1, 1]]``.

    Args:
        n: target sum.
        k: list length. ``k <= 0`` yields ``[[]]`` when ``n == 0`` and ``[]``
            otherwise.
        end: when given, fixes the last (smallest) entry; used by the recursion.
    """
    if k <= 0:
        # Guard: without it, k <= 0 with n > 0 recursed forever.
        return [[]] if n == 0 else []
    if end is None:
        L = []
        # The smallest of k entries cannot exceed n // k; larger `end`
        # values always produced empty results.
        for i in range(n // k + 1):
            L += decreasing_partition(n, k, i)
        return L
    if n == 0 and end == 0:
        return [[0] * k]
    if k == 1 and end == n:
        return [[n]]
    if k == 1 or n == 0:
        return []
    l = []
    # The first k-1 entries sum to n - end and end with some i >= end. Values
    # of i above n - end cannot occur, so the range stops there.
    for i in range(end, n - end + 1):
        for c in decreasing_partition(n - end, k - 1, i):
            c.append(end)  # c is a fresh list from the recursive call
            l.append(c)
    return l

def plot_(N):
    """Scatter potential drop against greedy-cost drop for single balancing moves.

    The function proceeds as follows:

    - Take every non-increasing W-column state holding N blocks
      (``decreasing_partition(N, W)``).
    - For every column i >= 1 at least two blocks lower than column 0, move
      one block from column 0 to column i.
    - Record x = drop in potential and y = drop in sum_cost_L divided by
      fact[N].
    - Print the smallest cost drop divided by fact[N - 1].
    - Show the scatter plot.
    """
    potential_gaps = []
    scaled_cost_gaps = []
    min_diff = float("inf")
    for s in decreasing_partition(N, W):
        for i in range(1, W):
            if s[0] <= s[i] + 1:
                continue
            moved = s.copy()
            moved[0] -= 1
            moved[i] += 1
            cost_before, cost_after = sum_cost_L(s), sum_cost_L(moved)
            min_diff = min(min_diff, cost_before - cost_after)
            potential_gaps.append(potential(s) - potential(moved))
            scaled_cost_gaps.append((cost_before - cost_after) / fact[N])
    print(min_diff / fact[N - 1])
    plt.scatter(potential_gaps, scaled_cost_gaps)
    plt.show()

if __name__ == "__main__":
    # Plot the single-move comparison for boards holding 50 blocks in total.
    total_blocks = 50
    plot_(total_blocks)

