#include "main.h"

#include <algorithm>
#include <chrono>
#include <utility>

#include <gmpxx.h>


// Builds a three-stack state.
//
// stacks_arg: heights of the stacks. Every State method in this file indexes
//             stacks[0..2], so an input with fewer than three entries is
//             padded with empty stacks instead of being read out of bounds.
// _size:      total number of items to record in `size`; the default -1 means
//             "compute it as the sum of the first three stack heights".
State::State(std::vector<int> stacks_arg, int _size = -1) {
    if (stacks_arg.size() < 3) {
        stacks_arg.resize(3, 0);
    }
    max_width = 3;
    if (_size == -1) {
        size = 0;
        for (int i = 0; i < 3; i++) {
            size += stacks_arg[i];
        }
    } else {
        size = _size;
    }
    // The parameter is already a private copy: move it rather than copying it
    // a second time (State::copy() goes through this constructor).
    stacks = std::move(stacks_arg);
}

// Returns the sum of the squared heights of the three stacks.
int State::sum_depth() {
    int squares = 0;
    int idx = 0;
    while (idx < 3) {
        const int height = stacks[idx];
        squares += height * height;
        ++idx;
    }
    return squares;
}

// Returns a new State built from this state's stacks and recorded size.
State State::copy() {
    State duplicate(stacks, size);
    return duplicate;
}


// Writes the three stack heights to std::cout as "[ a b c ]" (no newline).
void State::print_simple() {
    std::cout << "[ "
              << stacks[0] << " "
              << stacks[1] << " "
              << stacks[2] << " "
              << "]";
}

// Exchanges the heights stored at stack indices i and j.
void State::swap(int i, int j) {
    std::swap(stacks[i], stacks[j]);
}

void State::sort_insertion() {
    if (stacks[0] < stacks[1]) {
        swap(0, 1);
    }
    if (stacks[1] < stacks[2]) {
        swap(1, 2);
    }
    if (stacks[0] < stacks[1]) {
        swap(0, 1);
    }
}

// Returns the index of the lowest stack among the three, ignoring index
// `forbidden`. Ties go to the smallest index. If no index is eligible the
// function falls back to 0.
//
// The minimum is tracked by index rather than against a fixed sentinel, so
// the answer is also correct for arbitrarily tall stacks (a sentinel such as
// 1 << 20 would make every taller stack ineligible and return index 0 even
// when 0 is the forbidden one).
int State::min_stack(int forbidden) {
    int min_ind = -1;
    for (int i = 0; i < 3; i++) {
        if (i == forbidden) {
            continue;
        }
        if (min_ind == -1 || stacks[i] < stacks[min_ind]) {
            min_ind = i;
        }
    }
    return (min_ind == -1) ? 0 : min_ind;
}


// Moves k items from stack i1 to stack i2; the recorded total size is
// left untouched.
void State::relocate_k(int k, int i1, int i2) {
    stacks[i1] = stacks[i1] - k;
    stacks[i2] = stacks[i2] + k;
}

// Removes k items (default 1) from stack i and lowers the recorded total
// size by the same amount.
void State::retrieve(int i, int k = 1){
    size -= k;
    stacks[i] -= k;
}

// Puts k items (default 1) onto stack i and raises the recorded total size
// by the same amount.
void State::add(int i, int k = 1){
    size += k;
    stacks[i] += k;
}

// Applies one retrieval step to the state in place.
//
// The item at position j of stack i is removed, and the stacks[i] - j - 1
// items above it are spread over the two other stacks: the lower one is
// filled first until both are level, then the remainder is split in halves
// (the extra item of an odd remainder goes to the stack that was taller).
// Afterwards the stacks are re-sorted and `size` drops by one.
void State::step_L(int i, int j) {
    const int blockers = stacks[i] - j - 1;
    stacks[i] -= blockers + 1;

    // The two stacks other than i, ordered so that `low` is not taller
    // than `high` (on a tie the smaller index is `low`).
    const int first = (i == 0) ? 1 : 0;
    const int second = (i == 2) ? 1 : 2;
    const bool first_is_taller = stacks[first] > stacks[second];
    const int low = first_is_taller ? second : first;
    const int high = first_is_taller ? first : second;

    const int gap = stacks[high] - stacks[low];
    if (blockers > gap) {
        // Level both stacks, then share what is left.
        const int surplus = blockers - gap;
        const int half = surplus / 2;
        stacks[low] = stacks[high] + half;
        stacks[high] += surplus - half;
    } else {
        stacks[low] += blockers;
    }

    sort_insertion();
    size -= 1;
}

// Returns the accumulated cost of state `s` when every retrieval is followed
// by State::step_L, memoised in the per-size table memo_L[s.size].
//
// For a non-empty state that is not yet cached, the result is s.sum_depth()
// plus the recursive result of every successor obtained by calling
// step_L(i, j) for each stack i and each position j in it. The value is a
// plain sum; no normalisation is applied here. An empty, uncached state
// yields 0 and is not stored.
//
// The table for s.size must exist: memo_L.at() throws std::out_of_range when
// s.size is not a valid index instead of reading past the vector.
big_int Average_cost::average_cost_L(State& s) {
    int n = s.size;
    memo_t& table = *memo_L.at(n);

    // One lookup serves both the hit test and the returned value.
    auto cached = table.find(s);
    if (cached != table.end()) {
        return cached->second;
    }
    if (s.size <= 0) {
        return 0;
    }

    big_int total_cost = s.sum_depth();
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < s.stacks[i]; j++) {
            State s2 = s.copy();
            s2.step_L(i,j);
            total_cost += average_cost_L(s2);
        }
    }
    table.insert({s,total_cost});
    return total_cost;
}

// Two states are equal when their recorded sizes match and their first
// three stack heights match position by position.
bool State::operator==(const State& other) const {
    bool same = (other.size == size);
    for (int k = 0; same && k < 3; ++k) {
        same = (stacks[k] == other.stacks[k]);
    }
    return same;
}
// Sets up the memo tables: allocates an empty memo_opt table and seeds the
// factorial cache with 0! = 1.
//
// The memo_L member is intentionally left as default-constructed; main()
// appends its per-size tables. (A local variable named memo_L used to be
// declared here; it shadowed the member and had no effect.)
// NOTE(review): memo_opt is allocated with new and no matching delete is
// visible in this file — confirm ownership in the class declaration.
Average_cost::Average_cost() {
    memo_opt = new memo_t();
    memo_fact = std::vector<big_int>();
    memo_fact.push_back(1);
}

// Returns n!, extending the memo_fact cache as far as needed.
// memo_fact[k] holds k!; the constructor seeds it with 0! = 1.
big_int Average_cost::fact(int n) {
    int next = memo_fact.size();
    while (next <= n) {
        memo_fact.push_back(next * memo_fact[next - 1]);
        ++next;
    }
    return memo_fact[n];
}

// Returns the accumulated cost of state `s` when the items above a retrieved
// item are distributed in the cheapest possible way, memoised in memo_opt.
//
// `s` is sorted in place first (so the caller's state is reordered). For an
// uncached state with more than one item, the result is
// s.sum_depth() * fact(s.size - 1) plus, for every stack i and position j,
// the minimum recursive result over all ways of splitting the
// stacks[i] - j - 1 items above position j between the two other stacks.
// States with at most one item yield 0 and are not stored.
// `depth` is only forwarded (incremented) to the recursive calls.
big_int Average_cost::average_cost_opt(State& s, int depth = 0) {
    s.sort_insertion();

    auto cached = memo_opt->find(s);
    if (cached != memo_opt->end()) {
        return cached->second;
    }
    if (s.size <= 1) {
        return 0;
    }

    big_int total_cost = s.sum_depth() * fact(s.size - 1);
    for (int i = 0; i < 3; i++) {
        // Indices of the two stacks other than i.
        const int first_other = (i == 0);
        const int second_other = 2 - (i == 2);

        for (int j = 0; j < s.stacks[i]; j++) {
            const int to_move = s.stacks[i] - j - 1;

            // Temporarily take the retrieved item and everything above it
            // off stack i; restored below once all splits are evaluated.
            s.retrieve(i, to_move + 1);

            big_int best = -1;  // -1 marks "no candidate evaluated yet"
            for (int a = 0; a <= to_move; a++) {
                State candidate = s.copy();
                candidate.add(first_other, a);
                candidate.add(second_other, to_move - a);
                big_int cost = average_cost_opt(candidate, depth + 1);
                if (best == -1 || cost < best) {
                    best = cost;
                }
            }

            s.add(i, to_move + 1);
            total_cost += best;
        }
    }

    memo_opt->insert({s, total_cost});
    return total_cost;
}

// Lists every triple {a, b, c} with a >= b >= c >= 0 and a + b + c == n,
// ordered by increasing a and, for equal a, by increasing b.
// Returns an empty list when n is negative.
std::vector<std::vector<int>> gen_all_state(int n){
    std::vector<std::vector<int>> states;
    for (int largest = n / 3; largest <= n; ++largest) {
        const int rest = n - largest;
        // The middle value can exceed neither what is left nor the largest.
        const int middle_max = (rest < largest) ? rest : largest;
        // Starting at ceil(rest / 2) keeps the middle value >= the smallest.
        for (int middle = (rest + 1) / 2; middle <= middle_max; ++middle) {
            states.push_back({largest, middle, rest - middle});
        }
    }
    return states;
}


// Endless driver. For n = 2, 3, ... it evaluates average_cost_L on every
// sorted three-stack state holding n items. For each state whose tallest
// stack exceeds the second one by more than one item, it also evaluates the
// state obtained by moving one item from the tallest to the second stack,
// and prints an "Error" line if the original state is strictly cheaper.
// Never returns.
int main()
{
    Average_cost helper;

    // Seed the per-size tables for 0 and 1 items, both with cost 0.
    helper.memo_L.push_back(new memo_t());
    helper.memo_L[0]->insert({State({0,0,0}), 0});
    helper.memo_L.push_back(new memo_t());
    helper.memo_L[1]->insert({State({1,0,0}), 0});

    int n = 2;
    while (1) {
        helper.memo_L.push_back(new memo_t());
        // Only the table for n - 1 is consulted while filling the table for
        // n, so the one for n - 2 is released. The slot is nulled so that no
        // dangling pointer stays in the vector.
        delete helper.memo_L[n-2];
        helper.memo_L[n-2] = nullptr;

        for (const auto& tab : gen_all_state(n)) {
            State s(tab);
            // Evaluated for every state so the table for n ends up complete.
            big_int a = helper.average_cost_L(s);
            if (s.stacks[0] > s.stacks[1] + 1) {
                std::vector<int> tab2 = {tab[0]-1, tab[1]+1, tab[2]};
                State s2(tab2);
                big_int b = helper.average_cost_L(s2);
                if (a < b) {
                    std::cout << "Error: " << a << " < " << b << " for state ";
                    s.print_simple();
                    // End the report so it does not run into the next line.
                    std::cout << std::endl;
                }
            }
        }
        n++;
        // Reports the size that is about to be processed next.
        std::cout << "n = " << n << std::endl;
    }

}
