
# Sample boards.  Each inner list is one column, listed bottom to top;
# 0 is an empty cell, 1 is drawn as "o" and 2 as "x" by State.print.
s2 = [
    [1, 2, 2, 2, 0, 0, 0],
    [1, 1, 1, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0],
]

s3 = [
    [2, 1, 1, 2, 0],
    [1, 2, 0, 0, 0],
    [1, 1, 0, 0, 0],
    [2, 1, 1, 0, 0],
]

class State:
    """A board of equal-height columns plus an optional pending request.

    ``stacks`` holds ``w`` columns of ``h`` cells each, listed bottom to top
    (``print`` draws index ``h - 1`` first).  0 is an empty cell, 1 is drawn
    as ``o`` and any other value as ``x``.  ``heights[i]`` counts the
    consecutive non-empty (``> 0``) cells at the bottom of column ``i``.
    ``req`` is ``None`` or a ``(column, row)`` pair naming a requested cell.

    Note: ``stacks`` is stored as given, not copied; use :meth:`copy` to get
    an independent board.
    """

    def __init__(self, _stacks, _req = None):
        self.stacks = _stacks
        self.req = _req
        self.w = len(_stacks)
        self.h = len(_stacks[0])
        self.heights = [0] * self.w
        self._recompute_heights()

    def _recompute_heights(self):
        """Refresh ``heights`` from the current contents of ``stacks``."""
        for i, column in enumerate(self.stacks):
            j = 0
            while j < self.h and column[j] > 0:
                j += 1
            self.heights[i] = j

    def copy(self):
        """Return an independent copy: columns are copied, ``req`` is shared."""
        return State([column.copy() for column in self.stacks], self.req)

    def __hash__(self):
        # Hash exactly the data __eq__ compares, so equal states always hash
        # equal.  Every column must contribute, otherwise states that differ
        # only in earlier columns all collide.
        return hash((tuple(tuple(column) for column in self.stacks), self.req))

    def __eq__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        # Dimensions must be compared against *other*, not against each
        # other, or boards with w != h could never compare equal.
        if self.w != other.w or self.h != other.h:
            return False
        return self.stacks == other.stacks and self.req == other.req

    def add(self, i, v):
        """Place ``v`` on top of column ``i``.

        Returns True on success, False (board unchanged) if the column is full.
        """
        if self.heights[i] >= self.h:
            return False
        self.stacks[i][self.heights[i]] = v
        self.heights[i] += 1
        return True

    def pop(self, i):
        """Remove the top block of column ``i``.

        Raises IndexError if the column is empty.  Without the check, index -1
        would silently zero the column's top cell and make the height negative.
        """
        if self.heights[i] == 0:
            raise IndexError(f"pop from empty column {i}")
        self.heights[i] -= 1
        self.stacks[i][self.heights[i]] = 0

    def print(self):
        """Draw the board to stdout, top row first, highlighting ``req``."""
        print("-" * (self.w+2))
        a, b = -1, -1
        if self.req is not None:
            a, b = self.req
        for j in range(self.h-1, -1, -1):
            print("|", end="")
            for i in range(self.w):
                highlighted = (a == i and b == j)
                if highlighted:
                    print('\x1b[6;30;42m', end="")
                cell = self.stacks[i][j]
                if cell == 0:
                    print(" ", end="")
                elif cell == 1:
                    print("o", end="")
                else:
                    print("x", end="")
                if highlighted:
                    print('\x1b[0m', end="")
            print("|")
        print("-" * (self.w+2))

    def sort(self):
        """Put the columns into canonical (lexicographic) order in place.

        Skipped while a request is pending, because ``req`` refers to a column
        by index and reordering would make it point at a different column.
        """
        if self.req is None:
            self.stacks.sort()
            self._recompute_heights()

# Caches filled by compute(): position -> value, and position -> the successor
# position chosen for it (or None when no move was recorded).
memo = {}
memo_strat = {}
# Starting positions built from the sample boards (they share the lists).
state2, state3 = State(s2), State(s3)

def gen_all_seq(n, k, forbidden=-1):
    """Return every length-``k`` list drawn from ``range(n)`` minus ``forbidden``.

    Each returned list is a fresh object.  In the ordering, the first position
    varies fastest and the last position slowest, e.g.
    ``gen_all_seq(2, 2) == [[0, 0], [1, 0], [0, 1], [1, 1]]``.
    """
    allowed = [value for value in range(n) if value != forbidden]
    sequences = [[]]
    for _ in range(k):
        # Appending in outer-loop order over `allowed` keeps the last
        # position as the slowest-varying one.
        sequences = [prefix + [value] for value in allowed for prefix in sequences]
    return sequences

def compute(state, depth = 0):
    """Return the minimax number of block relocations for ``state``.

    Positions alternate between two kinds:

    * ``state.req is None``: the value is the maximum over requesting each
      ``o`` block (cell value 1).  With no ``o`` block left the position is
      terminal and costs 0.
    * ``state.req == (i, j)``: every block above ``(i, j)`` is moved onto
      some other column (topmost first), then the requested block is removed.
      The value is the minimum over all placements of the moved blocks, plus
      the number of blocks moved.  It is ``float("inf")`` if no placement
      fits.

    Side effects: a request-free ``state`` is sorted in place.  The value is
    stored in ``memo[state]`` and the chosen successor position (or
    ``None``) in ``memo_strat[state]``.

    ``depth`` is only threaded through the recursion; it does not affect
    the result.
    """
    if state in memo:
        return memo[state]
    if state.req is None:
        best, choice = _best_request(state, depth)
    else:
        best, choice = _best_relocation(state, depth)
    memo[state] = best
    memo_strat[state] = choice
    return best


def _best_request(state, depth):
    """Maximise over requesting each ``o`` block; return ``(value, successor)``."""
    state.sort()
    best = -1
    best_req = None
    for i in range(state.w):
        for j in range(state.heights[i]):
            if state.stacks[i][j] == 1:
                candidate = state.copy()
                candidate.req = (i, j)
                val = compute(candidate, depth + 1)
                # Every val is >= 0, so the -1 sentinel guarantees the first
                # request is recorded even when it costs nothing.
                if val > best:
                    best = val
                    best_req = candidate
    if best_req is None:
        # No "o" block left to request: terminal position, zero further cost.
        best = 0
    return best, best_req


def _best_relocation(state, depth):
    """Minimise over placements of the blocks above ``state.req``; return ``(value, successor)``."""
    i, j = state.req
    # Blocks above the requested one, topmost first (the order they are moved).
    to_move = [state.stacks[i][k] for k in range(state.heights[i] - 1, j, -1)]
    best = float("inf")
    best_state = None
    for seq in gen_all_seq(state.w, len(to_move), i):
        candidate = state.copy()
        candidate.req = None
        # Lift off the blocks above, then the requested block itself.
        for _ in range(len(to_move) + 1):
            candidate.pop(i)
        if not all(candidate.add(dest, block) for dest, block in zip(seq, to_move)):
            continue  # a destination column overflowed; placement is invalid
        candidate.sort()
        val = compute(candidate, depth + 1) + len(to_move)
        if val < best:
            best = val
            best_state = candidate
    return best, best_state

# Solve the first sample board, then replay the memoised strategy one position
# at a time until the board is empty or no successor was recorded.
print(compute(state2))
state = state2
while state is not None:
    state.print()
    if sum(state.heights) == 0:
        break
    state = memo_strat[state]

    
