import random as rd
from lower_bound_online import gen_random_state,bound_L_F
import itertools

# Number of stacks. Most helpers (e.g. heights_to_mat, next_container,
# state_to_str) iterate over range(W).
W = 3
# Slots per stack: matrix rows built by heights_to_mat have H cells, and
# next_states treats a stack of height H as full.
H = 5

def heights_to_mat(s):
    """Build an empty W x H slot matrix with the occupied slots marked.

    Row ``col`` stands for stack ``col``. Its first ``s[col]`` cells are set
    to -1 and the rest to 0.

    Returns:
        ``(matrix, s)``, where ``s`` is the very list that was passed in.
    """
    grid = []
    for col in range(W):
        column = [0] * H
        for level in range(s[col]):
            column[level] = -1
        grid.append(column)
    return grid, s

def min(s,forbid = -1):
    """Return the index of the smallest entry of ``s``, skipping index ``forbid``.

    Ties resolve to the lowest index. If no index qualifies, returns -1.
    This also happens when every candidate entry is ``inf``.

    NOTE: this name shadows the builtin ``min`` for the whole module.
    """
    best_index = -1
    best_value = float("inf")
    for idx, value in enumerate(s):
        if idx == forbid:
            continue
        if value < best_value:
            best_index, best_value = idx, value
    return best_index

def L(m,s,i,j):
    """Retrieve the container at slot ``j`` of stack ``i``, relocating those above it.

    The containers above slot ``j`` are popped from the top of stack ``i``
    one at a time. Each goes onto the stack chosen by ``min(s, i)``, i.e.
    the lowest other stack, ties to the lowest index. The container at slot
    ``j``, now on top, is then removed.

    ``m`` and ``s`` are modified in place. The destination stack's capacity
    is not checked, so it must have a free slot.
    """
    for r in reversed(range(j + 1, s[i])):
        moving = m[i][r]
        dest = min(s, i)
        m[i][r] = 0
        m[dest][s[dest]] = moving
        s[i] -= 1
        s[dest] += 1
    # The target is on top now: lower the stack and clear its slot.
    s[i] -= 1
    m[i][s[i]] = 0

def apply_perm_on_mat(p,m,s):
    """Write the labels ``p[0], p[1], ...`` into the occupied slots of ``m`` in place.

    Slots are visited stack by stack (0..W-1), bottom to top, covering the
    first ``s[col]`` slots of each stack.
    """
    slots = ((col, row) for col in range(W) for row in range(s[col]))
    for k, (col, row) in enumerate(slots):
        m[col][row] = p[k]

def random_perm(n):
    """Return the numbers 1..n shuffled by a forward Fisher-Yates pass.

    Step ``pos`` swaps ``pos`` with an index drawn by ``rd.randint(pos, n - 1)``.
    """
    perm = list(range(1, n + 1))
    pos = 0
    while pos < n:
        pick = rd.randint(pos, n - 1)
        perm[pos], perm[pick] = perm[pick], perm[pos]
        pos += 1
    return perm

def next_container(m):
    """Return ``(stack, level)`` of the smallest positive label in the W x H matrix ``m``.

    Cells holding 0 or a negative value are ignored. Ties go to the first
    cell in stack-major scan order. Returns ``(0, 0)`` when no cell is
    positive.
    """
    best = float("inf")
    pos = (0, 0)
    for i in range(W):
        row = m[i]
        for j in range(H):
            label = row[j]
            if 0 < label < best:
                best = label
                pos = (i, j)
    return pos

def rem(m,s,i1):
    """Remove the top container of stack ``i1`` in place.

    Clears the top occupied slot of row ``m[i1]`` and lowers ``s[i1]`` by one.

    Raises:
        IndexError: if stack ``i1`` is empty. Previously an empty stack
            wrote to ``m[i1][-1]``, clobbering the last cell of the row, and
            left the height at -1.
    """
    top = s[i1] - 1
    if top < 0:
        raise IndexError(f"cannot remove from empty stack {i1}")
    m[i1][top] = 0
    s[i1] = top

def move_container(m,s,i1,i2):
    """Move the top container of stack ``i1`` onto stack ``i2``, in place.

    The destination's capacity is not checked here; callers must ensure
    stack ``i2`` has a free slot.
    """
    moving = m[i1][s[i1] - 1]
    rem(m, s, i1)  # clears the source slot and lowers its height
    dst_height = s[i2]
    m[i2][dst_height] = moving
    s[i2] = dst_height + 1

def gen_list_stack(n,forbid,width=None):
    """Return every length-``n`` sequence of destination stacks that avoids ``forbid``.

    Each sequence is a list of stack indices from ``range(width)``, with
    ``forbid`` excluded. ``width`` defaults to the module constant ``W``.

    Sequences are in lexicographic order, with the last position varying
    fastest. This is the same order as the previous recursive version.
    ``n == 0`` yields ``[[]]``.

    Raises:
        ValueError: if ``n`` is negative. The previous recursion never hit
            its base cases then and ran into RecursionError.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if width is None:
        width = W
    choices = [i for i in range(width) if i != forbid]
    # itertools.product replaces the hand-rolled recursion, whose local
    # variable also shadowed the module-level function L.
    return [list(seq) for seq in itertools.product(choices, repeat=n)]

def next_states(m,s):
    """Enumerate the states reachable by retrieving the next container.

    The next container is the one ``next_container`` picks. Every way of
    relocating the containers above it onto other stacks is tried, except
    plans that would push onto a stack already at height ``H``. In each
    surviving plan the target is then removed.

    Returns:
        ``(states, to_rem)``: a list of ``(matrix, heights)`` copies, and the
        number of relocations each plan performs.
    """
    ci, cj = next_container(m)
    to_rem = s[ci] - cj - 1
    successors = []
    for targets in gen_list_stack(to_rem, ci):
        grid = [m[i].copy() for i in range(W)]
        heights = s.copy()
        for dest in targets:
            if heights[dest] == H:
                break  # destination full: this relocation plan is infeasible
            move_container(grid, heights, ci, dest)
        else:
            rem(grid, heights, ci)
            successors.append((grid, heights))
    return successors, to_rem

# Memo table for opt(): maps state_to_str(...) of the state after projection
# to its optimal relocation count.
dp = {}

def state_to_str(m,s):
    """Serialise a state as ``"h0,h1,...#label,label,..."``.

    The heights come first. After the ``#`` come the occupied labels, listed
    stack by stack (0..W-1), bottom to top.
    """
    heights = ",".join(str(h) for h in s)
    labels = ",".join(str(m[i][j]) for i in range(W) for j in range(s[i]))
    return heights + "#" + labels

def tuple_to_str(s):
    """Join the items of ``s`` with commas, e.g. ``[2, 1, 2] -> "2,1,2"``."""
    return ",".join([str(item) for item in s])

def projection(m,s):
    """Normalise a state in place: stacks sorted by height, labels starting at 1.

    Stack rows are reordered by increasing height. The sort is stable, so
    equal heights keep their relative order, and ``s`` is sorted to match.
    Occupied labels are then shifted so the smallest one becomes 1.

    Why the shift is conditional: opt_against_all hands opt labels 1..n,
    and after next_states retrieves label 1 the survivors are 2..k.
    Subtracting 1 unconditionally (the old behaviour) turned the initial
    label 1 into 0. next_container treats 0 as an empty slot, so the first
    container was never retrieved. Shifting by ``smallest - 1`` is a no-op
    on 1..n and exactly the old decrement on 2..k.

    Only the first ``len(s)`` rows of ``m`` are touched.
    """
    width = len(s)
    order = sorted(range(width), key=lambda i: s[i])
    rows = [m[i].copy() for i in order]
    for i in range(width):
        m[i] = rows[i]
    s.sort()
    labels = [m[i][j] for i in range(width) for j in range(s[i])]
    if not labels:
        return
    # sorted(...)[0] rather than min(): this module redefines min.
    shift = sorted(labels)[0] - 1
    if shift:
        for i in range(width):
            for j in range(s[i]):
                m[i][j] -= shift
    

def opt(m,s):
    """Return the minimum number of relocations needed to retrieve every container.

    ``m`` and ``s`` are normalised in place by ``projection`` first, so pass
    copies you do not need afterwards. Results are cached in the
    module-level ``dp``, keyed by the normalised ``state_to_str``.

    If no successor fits under the height limit, the result is
    ``float("inf")`` plus the step cost.
    """
    projection(m, s)
    key = state_to_str(m, s)
    if key in dp:
        return dp[key]
    # With zero or one container left nothing can block a retrieval. The old
    # ``== 1`` test let an empty state fall through to next_states, which
    # then asked for a negative number of relocations.
    if sum(s) <= 1:
        return 0
    nexts, cost = next_states(m, s)
    best = float("inf")
    # Explicit loop: the module's own ``min`` shadows the builtin.
    for mm, ss in nexts:
        candidate = opt(mm, ss)
        if candidate < best:
            best = candidate
    score = best + cost
    dp[key] = score
    return score

def opt_against_all(s):
    """Return the largest ``opt`` value over every labelling of layout ``s``.

    ``s`` lists stack heights. Every permutation of the labels 1..n
    (n = sum(s)) is placed into the layout and solved with ``opt``.

    ``s`` itself is left untouched. ``opt`` sorts the heights list it
    receives in place, and passing ``s`` directly used to reorder the
    caller's list.
    """
    n = sum(s)
    worst = 0
    for perm in itertools.permutations(range(1, n + 1), n):
        m, _ = heights_to_mat(s)
        apply_perm_on_mat(perm, m, s)
        score = opt(m, list(s))  # private copy: opt/projection sort it in place
        if score > worst:
            worst = score
    return worst

def len_non_z(x):
    """Return the length of the leading run of strictly positive entries of ``x``."""
    count = 0
    for value in x:
        if not value > 0:
            break
        count += 1
    return count

# Factorial table: fact_mem[k] == k! for every k < len(fact_mem).
fact_mem = [1]

def fact(n):
    """Return ``n!``, extending the shared table ``fact_mem`` as needed.

    Raises:
        ValueError: if ``n`` is negative. Previously a negative ``n`` passed
            the ``len(fact_mem) > n`` test and returned ``fact_mem[n]``,
            i.e. a cached factorial counted from the end of the list.
    """
    if n < 0:
        raise ValueError(f"factorial is undefined for negative n={n}")
    # Fill iteratively: no recursion depth limit, and every value computed
    # is cached (the old recursive path never stored the top value).
    while len(fact_mem) <= n:
        k = len(fact_mem)
        fact_mem.append(k * fact_mem[-1])
    return fact_mem[n]

# Cache for compute_single_cont_cost, keyed by tuple_to_str(heights).
memo = {}

def compute_single_cont_cost(s,depth=0):
    """Return a per-slot table of accumulated relocation counts for layout ``s``.

    ``s`` lists stack heights. The result ``c`` has an entry ``c[i][j]``
    for every occupied slot (stack ``i``, level ``j``). Containers are
    labelled 1..n (n = sum(s)) stack by stack, bottom to top.

    How the table is built:
      * Each entry starts at ``j * (n-1)!``.
      * Every slot is tried as the first retrieval. ``L`` is applied and the
        resulting layout is canonicalised (stacks by decreasing height).
      * That layout is solved recursively. Each surviving container's
        sub-result is added back to the slot it started in.

    Presumably this totals, per container, the relocations it suffers over
    all n! retrieval orders under L's "lowest other stack" rule. TODO
    confirm.

    Every result, not only the base cases, is cached in ``memo``. Cached
    tables are shared objects, so callers must not mutate the returned
    lists.

    ``depth`` only gates the debug printing done at the top level
    (``depth <= 0``). That printing is skipped when the answer is already
    cached.
    """
    key = tuple_to_str(s)
    if key in memo:
        return memo[key]
    w = len(s)
    n = sum(s)
    c = [[0] * s[i] for i in range(w)]
    # No container sits on another, so nothing is ever relocated.
    # ``not s`` keeps max() from raising on an empty layout.
    if not s or max(s) <= 1:
        memo[key] = c
        return c

    # One row per stack of *this* layout, not the global W. Each row is long
    # enough for L to pile every container onto one stack; a fixed H
    # overflowed when n > H.
    m = [[0] * max(H, n) for _ in range(w)]
    location = {}
    orders_per_first_pick = fact(n - 1)
    label = 1
    for i in range(w):
        for j in range(s[i]):
            m[i][j] = label
            location[label] = (i, j)
            label += 1
            # L relocates everything above the retrieved slot, so this
            # container moves whenever one of the j slots beneath it is
            # retrieved first; presumably (n-1)! orders per such pick.
            c[i][j] = j * orders_per_first_pick

    for ri in range(w):
        for rj in range(s[ri]):
            mm = [row.copy() for row in m]
            ss = list(s)
            L(mm, ss, ri, rj)
            # Canonical sub-layout: stacks by decreasing height.
            # - Both sorts are stable.
            # - Row i's non-zero prefix has length ss[i]: labels are >= 1
            #   and vacated slots are 0.
            # So rows and heights stay aligned.
            mm.sort(key=len_non_z, reverse=True)
            ss.sort(reverse=True)
            sub = compute_single_cont_cost(ss, depth + 1)
            if depth <= 0:
                print(m)
                print(ss)
                print(mm)
                print()
            # Credit each surviving container's sub-result to its starting slot.
            for i in range(w):
                for j in range(ss[i]):
                    i1, j1 = location[mm[i][j]]
                    c[i1][j1] += sub[i][j]
    memo[key] = c
    return c


# Script entry: print the per-slot table for stack heights [2, 1, 2].
# compute_single_cont_cost also prints debug output at depth 0.
# NOTE(review): there is no ``if __name__ == "__main__":`` guard, so this also
# runs (and prints) whenever the module is imported.
print(compute_single_cont_cost([2,1,2]))