from math import floor,ceil
import random as rd
import math
import matplotlib.pyplot as plt

# Memo tables keyed by the to_str() string of a state sorted in
# non-increasing order: memo_f backs compute(), memo_f2 backs compute1().
memo_f = {}
memo_f2 = {}

# fact_mem[k] == k!  (grown on demand by fact()).
fact_mem = [1]

def fact(n):
    """Return n! using the module-level ``fact_mem`` table as a cache.

    The table is extended iteratively up to ``n``, so every intermediate
    factorial is cached and large ``n`` does not hit the recursion limit.

    Raises:
        ValueError: if ``n`` is negative (indexing the table with a negative
            ``n`` would otherwise silently return a wrong cached entry).
    """
    if n < 0:
        raise ValueError(f"factorial is not defined for negative n ({n})")
    while len(fact_mem) <= n:
        k = len(fact_mem)
        fact_mem.append(k * fact_mem[-1])
    return fact_mem[n]

def to_str(a, b, c):
    """Return the comma-separated string "a,b,c" used as a memo key."""
    return ",".join(map(str, (a, b, c)))

def L(a, b, c, i, j):
    """Return the state obtained by cutting stack ``i`` at height ``j``.

    Stack ``i`` keeps ``j`` containers, one container disappears, and the
    ``s[i] - 1 - j`` containers above it are moved onto the two other stacks.
    The lower one is first topped up to the height of the higher one. The
    rest is then split evenly, with the first remaining stack (lower index)
    taking the extra container when the count is odd.

    The result is ``(j, first_other, second_other)`` and is not necessarily
    sorted. Assumes the first remaining stack is at least as tall as the
    second, which holds for states sorted in non-increasing order.
    """
    stacks = (a, b, c)
    above = stacks[i] - 1 - j
    big = stacks[1 if i == 0 else 0]
    small = stacks[1 if i == 2 else 2]
    gap = big - small
    if above > gap:
        rest = above - gap
        half = rest // 2
        return (j, big + rest - half, big + half)
    return (j, big, small + above)

def compute1(a, b, c, depth=0):
    """Return the memoized cost of state (a, b, c) under the "empty = 1" model.

    The state is first normalized to non-increasing order. A state with no
    containers costs 1. Otherwise the cost is
    (n-1)! * (a**2 + b**2 + c**2) plus the cost of every successor produced
    by ``L`` (n = a + b + c). Results are cached in ``memo_f2``.

    ``depth`` is only threaded through the recursion and does not affect
    the result.
    """
    if a + b + c <= 0:
        return 1
    if a < b or b < c:
        top, mid, low = sorted([a, b, c], reverse=True)
        return compute1(top, mid, low, depth)
    key = to_str(a, b, c)
    if key in memo_f2:
        return memo_f2[key]
    total = fact(a + b + c - 1) * (a**2 + b**2 + c**2)
    for stack, height in enumerate((a, b, c)):
        for level in range(height):
            total += compute1(*L(a, b, c, stack, level), depth + 1)
    memo_f2[key] = total
    return total

def compute(a, b, c, depth=0):
    """Return the memoized cost of the state with stack heights (a, b, c).

    The state is first normalized to non-increasing order. A state with a
    single container costs 2. Otherwise the cost is
    (n-1)! * (a**2 + b**2 + c**2) plus the cost of every successor produced
    by ``L`` (n = a + b + c, one successor per container). Results are cached
    in ``memo_f`` under the ``to_str`` key of the sorted state.

    ``depth`` is only threaded through the recursion and does not affect
    the result.
    """
    if a + b + c == 1:
        return 2
    if a < b or b < c:
        top, mid, low = sorted([a, b, c], reverse=True)
        return compute(top, mid, low, depth)
    key = to_str(a, b, c)
    if key in memo_f:
        return memo_f[key]
    total = fact(a + b + c - 1) * (a**2 + b**2 + c**2)
    for stack, height in enumerate((a, b, c)):
        for level in range(height):
            total += compute(*L(a, b, c, stack, level), depth + 1)
    memo_f[key] = total
    return total

def gen_random_state(h):
    """Return a list of three independent random heights, each in [0, h - 1]."""
    upper = h - 1
    state = []
    for _ in range(3):
        state.append(rd.randint(0, upper))
    return state

def gen_all_states(n):
    """Yield every state [a, b, c] with a >= b >= c >= 0 and a + b + c == n.

    Each state is yielded once, as a fresh list, ordered by increasing
    ``a`` and then increasing ``b``.
    """
    for top in range(n + 1):
        # mid can exceed neither top nor the containers left after top.
        for mid in range(min(top, n - top) + 1):
            low = n - top - mid
            if low > mid:
                continue
            yield [top, mid, low]

def check_ratio(a, b, c):
    """Check that moving one container from stack a to stack b saves enough.

    Verifies ``compute(a, b, c) - compute(a - 1, b + 1, c) >= (a - b - 1) * (n - 1)!``
    where n = a + b + c.

    Returns:
        ``(v1, v2, delta, fn)``: both costs, ``a - b - 1`` and ``(n - 1)!``.

    Raises:
        NameError: after printing ``a b c``, if the inequality fails.
            NOTE(review): NameError is an unusual choice for a failed
            invariant; kept as-is in case something relies on it.
    """
    v1 = compute(a, b, c)
    v2 = compute(a - 1, b + 1, c)
    delta = a - b - 1
    fn = fact(a + b + c - 1)
    if v1 - v2 < fn * delta:
        print(a, b, c)
        raise NameError("should not happen rip")
    return v1, v2, delta, fn

def successors(a, b, c):
    """Return every successor of (a, b, c), each as a list in non-increasing order.

    There is one successor per container: for every stack index and every
    height below that stack's top, the state produced by ``L``.
    """
    result = []
    for stack, height in enumerate((a, b, c)):
        for level in range(height):
            result.append(sorted(L(a, b, c, stack, level), reverse=True))
    return result

def potential(a, b, c):
    """Return the sum of the squared stack heights."""
    return sum(height ** 2 for height in (a, b, c))

def pot_next_2(a, b, c):
    """Return n * potential(a, b, c) plus the potentials of all its successors (n = a + b + c)."""
    own = (a + b + c) * potential(a, b, c)
    return own + sum(potential(*succ) for succ in successors(a, b, c))

def new_pot(a, b, c):
    """Return the squared excess of a state over the flat state of the same size.

    The state is sorted in non-increasing order and compared stack by stack
    with the balanced split of n = a + b + c containers,
    (ceil(n/3), ceil((n-1)/3), ceil((n-2)/3)). Only stacks taller than their
    target contribute, each by the square of its excess.
    """
    if not (a >= b >= c):
        a, b, c = sorted([a, b, c], reverse=True)
    n = a + b + c
    # ceil((n - k) / 3) == (n - k + 2) // 3 in exact integer arithmetic;
    # float division (n / 3) loses precision once n exceeds 2**53.
    d, e, f = (n + 2) // 3, (n + 1) // 3, n // 3
    return max(0, a - d) ** 2 + max(0, b - e) ** 2 + max(0, c - f) ** 2

def new_pot2(a, b, c):
    """Return the squared distance of a state to the flat state of the same size.

    The state is sorted in non-increasing order and compared stack by stack
    with the balanced split of n = a + b + c containers,
    (ceil(n/3), ceil((n-1)/3), ceil((n-2)/3)). Unlike ``new_pot``, stacks
    shorter than their target also contribute.
    """
    if not (a >= b >= c):
        a, b, c = sorted([a, b, c], reverse=True)
    n = a + b + c
    # ceil((n - k) / 3) == (n - k + 2) // 3 in exact integer arithmetic;
    # float division (n / 3) loses precision once n exceeds 2**53.
    d, e, f = (n + 2) // 3, (n + 1) // 3, n // 3
    return (a - d) ** 2 + (b - e) ** 2 + (c - f) ** 2

def compute_successors_new_pot(a, b, c):
    """Return the sum of ``new_pot`` over all successors of (a, b, c)."""
    return sum(new_pot(d, e, f) for d, e, f in successors(a, b, c))

def compute_successors_new_pot2(a, b, c):
    """Return the sum of ``new_pot2`` over all successors of (a, b, c)."""
    return sum(new_pot2(d, e, f) for d, e, f in successors(a, b, c))

def compute_successors_potential(a, b, c):
    """Return the sum of ``potential`` over all successors of (a, b, c)."""
    return sum(potential(d, e, f) for d, e, f in successors(a, b, c))

def counter_ex_pot(n=11):
    """Print every sorted state with ``n`` containers, ordered by potential.

    Each line is ``(a, b, c) - potential - sum of successors' potentials``,
    for manual inspection.

    Args:
        n: number of containers (defaults to the previously hard-coded 11).
    """
    rows = [
        [(a, b, c), potential(a, b, c), compute_successors_potential(a, b, c)]
        for a, b, c in gen_all_states(n)
    ]
    # A single stable sort after collection gives the same order as
    # re-sorting after every append, without the quadratic cost.
    rows.sort(key=lambda row: row[1])
    for row in rows:
        row[1] = round(row[1], 3)
        print(" - ".join(map(str, row)))

def counter_ex_new_pot(n=11):
    """Print every sorted state with ``n`` containers, ordered by ``new_pot``.

    Each line is ``(a, b, c) - new_pot - sum of successors' new_pot``,
    for manual inspection.

    Args:
        n: number of containers (defaults to the previously hard-coded 11).
    """
    rows = [
        [(a, b, c), new_pot(a, b, c), compute_successors_new_pot(a, b, c)]
        for a, b, c in gen_all_states(n)
    ]
    # One stable sort after collection (ties keep generation order).
    rows.sort(key=lambda row: row[1])
    for row in rows:
        print(" - ".join(map(str, row)))

def flat(a, b, c):
    """Return True when the first stack is at most one taller than the third.

    ``b`` is accepted for signature uniformity but not inspected.
    """
    return c + 1 >= a

def distance_flat(a, b, c):
    """Return how many single-container moves a fixed greedy rule needs to flatten the state.

    The loop runs until the first stack is at most one taller than the third
    (the ``flat`` criterion). Each move either takes a container off the
    first stack and puts it on the third (or on the second when the second
    and third are level), or, when the first stack does not exceed the
    second, moves one container from the second stack to the third.
    """
    steps = 0
    top, mid, low = a, b, c
    while top > low + 1:
        steps += 1
        if top <= mid:
            mid, low = mid - 1, low + 1
            continue
        top -= 1
        if mid > low:
            low += 1
        else:
            mid += 1
    return steps

def distance_flat_succ(a, b, c):
    """Return the sum of ``distance_flat`` over all successors of (a, b, c)."""
    return sum(distance_flat(d, e, f) for d, e, f in successors(a, b, c))

def check_distance_flat(n=300):
    """Print states by distance to flat and flag non-monotone successor totals.

    Builds ``((a, b, c), distance_flat, distance_flat_succ)`` for every sorted
    state with ``n`` containers and prints them ordered by (own distance,
    successors' total). It then prints ``"rip : "`` with the row for every
    adjacent pair where the successors' total decreases.

    Args:
        n: number of containers (defaults to the previously hard-coded 300).
    """
    rows = [
        ((a, b, c), distance_flat(a, b, c), distance_flat_succ(a, b, c))
        for a, b, c in gen_all_states(n)
    ]
    # Sort once after collection; re-sorting after every append was
    # quadratic in the number of states (~7,500 for n=300).
    rows.sort(key=lambda row: (row[1], row[2]))
    for row in rows:
        print(" - ".join(map(str, row)))
    for current, following in zip(rows, rows[1:]):
        if current[2] > following[2]:
            print("rip : ", current)

def plot_potentiels_etats(n):
    """Show a 50-bin histogram of the potentials of all sorted states with n containers."""
    # Named `values` rather than `L` so the module-level function L is not shadowed.
    values = [potential(a, b, c) for a, b, c in gen_all_states(n)]
    plt.hist(values, 50)
    plt.show()

def plot_pot_succ(n):
    """Scatter (figure 1) each state's potential against its successors' total potential divided by n.

    A state with n containers has exactly n successors, so the y value is
    their average potential. The figure is not shown here.
    """
    points = [
        (potential(a, b, c), compute_successors_potential(a, b, c) / n)
        for a, b, c in gen_all_states(n)
    ]
    X = [x for x, _ in points]
    Y = [y for _, y in points]
    plt.figure(1)
    plt.scatter(X, Y, s=1)

def plot_pot_cost(n):
    """Scatter (figure 2) each state's potential against its cost divided by n!.

    Also prints the ratio (cost range / potential range), then every
    ``(potential, cost / n!, a, b, c)`` tuple in ascending order.
    The figure is not shown here.
    """
    X, Y, XX = [], [], []
    for a, b, c in gen_all_states(n):
        pot = potential(a, b, c)
        cost = compute(a, b, c) / fact(n)
        X.append(pot)
        Y.append(cost)
        XX.append((pot, cost, a, b, c))
    print((max(Y) - min(Y)) / (max(X) - min(X)))
    for entry in sorted(XX):
        print(entry)
    plt.figure(2)
    plt.scatter(X, Y, s=1)

def plot_next2_pot_cost(n):
    """Scatter (figure 2) ``pot_next_2`` against cost divided by n! for every sorted state.

    Also prints the ratio (cost range / pot_next_2 range).
    """
    X, Y = [], []
    for state in gen_all_states(n):
        X.append(pot_next_2(*state))
        Y.append(compute(*state) / fact(n))
    spread_y = max(Y) - min(Y)
    spread_x = max(X) - min(X)
    print(spread_y / spread_x)
    plt.figure(2)
    plt.scatter(X, Y, s=1)

def plot_cost_flat_state(n):
    """Scatter (figure 3) the cost divided by size! of the flat states with 1..n containers.

    Flat states are built by repeatedly adding a container to the shortest
    stack and re-sorting. The x axis is the step index 0..n-1.
    """
    heights = [0, 0, 0]
    X, Y = [], []
    for step in range(n):
        heights[2] += 1
        heights.sort(reverse=True)
        Y.append(compute(*heights) / fact(sum(heights)))
        X.append(step)
    plt.figure(3)
    plt.scatter(X, Y, s=1)

def plot(n):
    """Draw figures 1-3 (potential vs. successors, potential vs. cost, flat-state cost) and show them."""
    for draw in (plot_pot_succ, plot_pot_cost, plot_cost_flat_state):
        draw(n)
    plt.show()

def check_new_pot(n):
    """Report the first place where the successors' total ``new_pot`` is not monotone.

    States with ``n`` containers are ordered by ``new_pot``. The first
    adjacent pair whose successors' total ``new_pot`` decreases is printed
    with the ``"rippppp"`` marker.
    """
    rows = [
        [(a, b, c), new_pot(a, b, c), compute_successors_new_pot(a, b, c)]
        for a, b, c in gen_all_states(n)
    ]
    # One stable sort after collection (ties keep generation order) instead
    # of re-sorting after every append.
    rows.sort(key=lambda row: row[1])
    for current, following in zip(rows, rows[1:]):
        if current[2] > following[2]:
            print("rippppp", n, current, following)
            break

def check_new_pot2(n):
    """Report the first place where the successors' total ``new_pot2`` is not monotone.

    States with ``n`` containers are ordered by ``new_pot2``. The first
    adjacent pair whose successors' total ``new_pot2`` decreases is printed
    with the ``"rippppp"`` marker.
    """
    rows = [
        [(a, b, c), new_pot2(a, b, c), compute_successors_new_pot2(a, b, c)]
        for a, b, c in gen_all_states(n)
    ]
    # One stable sort after collection (ties keep generation order) instead
    # of re-sorting after every append.
    rows.sort(key=lambda row: row[1])
    for current, following in zip(rows, rows[1:]):
        if current[2] > following[2]:
            print("rippppp", n, current, following)
            break


def plot_new_pot_cost(n):
    """Scatter (figure 2) ``new_pot`` against cost divided by n! for every sorted state.

    Also prints the ratio (cost range / new_pot range).
    """
    X, Y = [], []
    for state in gen_all_states(n):
        X.append(new_pot(*state))
        Y.append(compute(*state) / fact(n))
    spread_y = max(Y) - min(Y)
    spread_x = max(X) - min(X)
    print(spread_y / spread_x)
    plt.figure(2)
    plt.scatter(X, Y, s=1)

def plot_new_pot2_cost(n):
    """Scatter (figure 2) ``new_pot2`` against cost divided by n! for every sorted state.

    Also prints the ratio (cost range / new_pot2 range).
    """
    X, Y = [], []
    for state in gen_all_states(n):
        X.append(new_pot2(*state))
        Y.append(compute(*state) / fact(n))
    spread_y = max(Y) - min(Y)
    spread_x = max(X) - min(X)
    print(spread_y / spread_x)
    plt.figure(2)
    plt.scatter(X, Y, s=1)

def plot_delta_diff(n):
    """Scatter the cost drop from moving one container down a height gap.

    For every sorted state with ``n`` containers and every pair of stacks
    whose heights differ by more than one, the x value is the gap minus one.
    The y value is cost(state) - cost(state with one container moved from
    the taller to the shorter stack), both divided by (n-1)!. The diagonal
    y = x is drawn for reference.

    Each (gap, drop) pair is printed until the first one with gap > drop,
    which also prints ``"AHHH"``. Finally the largest drop/gap ratio seen is
    printed (at least 1).
    """
    # (A leftover `n = 100` used to overwrite the argument here.)
    X, Y = [], []
    for a, b, c in gen_all_states(n):
        v = compute(a, b, c) / fact(n - 1)
        if a > b + 1:
            v1 = compute(a - 1, b + 1, c) / fact(n - 1)
            X.append(a - b - 1)
            Y.append(v - v1)
        if a > c + 1:
            v2 = compute(a - 1, b, c + 1) / fact(n - 1)
            X.append(a - c - 1)
            Y.append(v - v2)
        if b > c + 1:
            v3 = compute(a, b - 1, c + 1) / fact(n - 1)
            X.append(b - c - 1)
            Y.append(v - v3)

    plt.scatter(X, Y, s=1)
    if X:
        # Small n may have no pair of stacks differing by more than one;
        # max() of an empty list would raise.
        plt.plot([0, max(X)], [0, max(X)])
    maxi = 1
    for gap, drop in zip(X, Y):
        print(gap, drop)
        if gap > drop:
            print("AHHH")
            break
        # gap >= 1 by construction, so the division is safe.
        if drop / gap > maxi:
            maxi = drop / gap
    print(maxi)

def check_L_opt(start=10, limit=None):
    """Search for a sorted state where moving a container from stack a to stack b raises the cost.

    For each size n = start, start + 1, ... and every state with a > b + 1,
    checks that compute(a, b, c) >= compute(a - 1, b + 1, c). Prints the
    first violating state and returns False. After each size is fully
    checked, prints a progress line for that size.

    Args:
        start: first number of containers to check (previously fixed at 10).
        limit: last size to check, inclusive. ``None`` (the default) keeps
            the original behaviour of searching forever.

    Returns:
        False on the first violation; True once ``limit`` is reached without one.
    """
    n = start
    while limit is None or n <= limit:
        for a, b, c in gen_all_states(n):
            if a <= b + 1:
                continue
            v1 = compute(a, b, c)
            v2 = compute(a - 1, b + 1, c)
            if v1 < v2:
                print("L n'est pas optimal pour", a, b, c)
                return False
        # Report the size that was just checked. Printing after the
        # increment overstated progress by one.
        print("L est optimal jusque", n)
        n += 1
    return True

def check_convexity(a, b):
    """Check that cost differences shrink strictly as two stacks are levelled.

    ``a`` is the height of the first stack and ``b`` the combined height of
    the other two. The states [a, b - k, k] for k = 0 .. b // 2 are visited
    in order. Each consecutive cost difference c(k) - c(k+1) must be
    strictly below the smallest difference recorded so far; otherwise the
    pair is printed with ``"AIE AIE AIE"``.
    """
    layouts = [[a, b - k, k] for k in range(b // 2 + 1)]
    best = float("inf")
    for left, right in zip(layouts, layouts[1:]):
        s1, s2 = sorted(left), sorted(right)
        c1, c2 = compute(*s1), compute(*s2)
        gap = c1 - c2
        if gap < best:
            best = gap
        else:
            print("AIE AIE AIE", s1, s2, c1, c2)

def check_all_convexity(limit_n):
    """Run ``check_convexity(a, b)`` for every a, b >= 0 with a + b < limit_n."""
    for first in range(limit_n):
        print("Starting a = ", first)
        remaining = range(limit_n - first)
        for rest in remaining:
            check_convexity(first, rest)

def descend(s1, s2):
    """Return True if ``s2`` is ``s1`` with one container moved to a lower stack.

    Both states must hold the same total number of containers, and the
    caller guarantees this. Three moves are recognised: first -> second
    stack (first stack more than one taller than the second), first -> third
    (when a > b > c), and second -> third (second stack more than one taller
    than the third).
    """
    a, b, c = s1
    d, e, f = s2
    return (
        (d == a - 1 and e == b + 1 and a > b + 1)
        or (d == a - 1 and f == c + 1 and a > b and b > c)
        or (e == b - 1 and b > c + 1 and f == c + 1)
    )

def check_convexity_implies_theorem(s1,i,j):
    """Compare successor costs of ``s1`` and of ``s2`` (= ``s1`` with one
    container moved from stack ``i`` to stack ``j``), printing ``s1 s2 p``
    for each level ``p`` where the check fails.

    - For ``0 <= p < s1[j]``: fails when cost(L(s1, i, p)) + cost(L(s1, j, p))
      is below cost(L(s2, i, p)) + cost(L(s2, j, p)).
    - For ``s1[j] < p < s1[i] - 1``: fails when cost(L(s1, i, p)) + 2 * (n-1)!
      is still below cost(L(s2, i, p)).

    Args:
        s1: state as a list ``[a, b, c]``. It is copied, not mutated.
            Presumably sorted in non-increasing order, as yielded by
            gen_all_states. TODO confirm for other callers.
        i: index of the stack losing a container.
        j: index of the stack gaining a container.
    """
    # NOTE(review): k (the remaining stack index for the (i, j) pairs used by
    # check_all_convexity_implies_theorem) is never used below.
    k = (0 if i == 1 else (1 if j == 2 else 2))
    s2 = s1.copy()
    s2[i] -= 1
    s2[j] += 1
    hi = s1[i] - 1  # height of stack i in s2
    hj = s1[j]      # height of stack j in s1
    # NOTE(review): h is computed but never used.
    h = sum(s1) - hi - hj
    a,b,c = s1
    n = a + b + c
    d,e,f = s2
    # Successors from cutting stack i / stack j at each height p < hj,
    # in s1 (l1*) and in s2 (l2*).
    l1i = [L(a,b,c,i,p) for p in range(hj)]
    l1j = [L(a,b,c,j,p) for p in range(hj)]
    l2i = [L(d,e,f,i,p) for p in range(hj)] 
    l2j = [L(d,e,f,j,p) for p in range(hj)]
    for p in range(hj):
        cost1 = compute(*l1i[p]) + compute(*l1j[p])
        cost2 = compute(*l2i[p]) + compute(*l2j[p])
        if cost1 < cost2:
            print(s1,s2,p)
    # NOTE(review): p == hj is checked by neither loop (this one starts at
    # hj + 1). Confirm that is intentional.
    for p in range(hj+1,hi):
        cost1 = compute(*L(a,b,c,i,p))
        cost2 = compute(*L(d,e,f,i,p))
        if cost1 + 2 * fact(n-1) < cost2:
            print(s1,s2,p)


def check_all_convexity_implies_theorem(limit_n):
    """Run ``check_convexity_implies_theorem`` on every eligible (state, move).

    For n = 1 .. limit_n - 1, every sorted state and every stack pair
    (i, j) in (0, 1), (0, 2), (1, 2) with s[i] > s[j] + 1 is checked. The
    pair (0, 2) is only checked when s[0] > s[1] > s[2].
    """
    moves = ((0, 1), (0, 2), (1, 2))
    for n in range(1, limit_n):
        print("checked until n =", n)
        for s in gen_all_states(n):
            for i, j in moves:
                if s[i] <= s[j] + 1:
                    continue
                if (i, j) == (0, 2) and not (s[0] > s[1] > s[2]):
                    continue
                check_convexity_implies_theorem(s, i, j)

if __name__ == "__main__":
    # Print the state L produces when stack 0, then stack 1, is cut at
    # height h, for two neighbouring states.
    h = 1
    s1 = [4, 2, 0]
    s2 = [3, 3, 0]
    for state in (s1, s2):
        for stack in (0, 1):
            print(L(*state, stack, h))
    # Disabled experiment, kept for reference. It compares the successors of
    # s1 and s3 with twice those of s2, where s2 / s3 are s1 with one / two
    # containers moved from stack i to stack j:
    #
    # i,j = 0,2
    # s1 = [6,3,0]
    # s2 = s1.copy()
    # s2[i] -= 1
    # s2[j] += 1
    # s3 = s1.copy()
    # s3[i] -= 2
    # s3[j] += 2
    #
    # l1 = successors(*s1) + successors(*s3)
    # l2 = successors(*s2) * 2
    # l1.sort()
    # l2.sort()
    # print(l1)
    # print(l2)
    # to_rem = []
    # for x in l1:
    #     if x in l2:
    #         to_rem.append(x)
    #         l2.remove(x)
    # for x in to_rem:
    #     l1.remove(x)
    # print(l1)
    # print(l2)
    # for i in range(len(l1)-1,-1,-1):
    #     print(descend(l1[i],l2[i]))