#include <iostream>
#include <vector>
#include <unordered_map>
#include <algorithm>
#include <tuple>
#include "main.h"
#include <climits>
#include <time.h>
using namespace std;

/// Builds a state from a grid of columns: _stacks[i][j] is the item at
/// height j (j == 0 is the bottom) of stack i, and 0 marks an empty cell.
/// heights[i] becomes the number of positive cells below the first
/// non-positive one; size counts those cells whose value is 1.
/// _req is the pending request as (stack, height), or {-1, -1} for none.
State::State(const vector<vector<int>>& _stacks, pair<int, int> _req = {-1, -1})
    : stacks(_stacks), req(_req) {
    w = stacks.size();
    // An empty grid has no first column to take the height from; indexing
    // stacks[0] there would be out of bounds.
    h = stacks.empty() ? 0 : stacks[0].size();
    heights.resize(w, 0);
    size = 0;
    for (int i = 0; i < w; ++i) {
        int j = 0;
        while (j < h && stacks[i][j] > 0) {
            if (stacks[i][j] == 1) {
                size++;
            }
            j++;
        }
        heights[i] = j;
    }
}

/// Default constructor: an empty 0x0 state with no pending request.
/// The scalar members are set explicitly so that a default-constructed State
/// (e.g. one default-inserted by unordered_map::operator[]) never exposes
/// indeterminate values to print() or to the `size` check in main().
State::State() {
    w = 0;
    h = 0;
    size = 0;
    req = {-1, -1};
}

/// Returns a duplicate of this state built through the main constructor,
/// so heights and size are recomputed from the grid rather than copied.
State State::copy() const {
    State fresh(stacks, req);
    return fresh;
}

/// Two states are equal when they have the same dimensions, the same pending
/// request and identical grids; heights and size are not compared.
bool State::operator==(const State& other) const {
    // Cheap scalar checks first; the full grid comparison only runs when
    // everything else already matches.
    return w == other.w && h == other.h && req == other.req &&
           stacks == other.stacks;
}

struct State::HashFunction {
    /// Hashes a State for unordered_map keys, consistent with operator==:
    /// each column is encoded as a number, the columns are mixed in order with
    /// a hash_combine-style step, and the pending request (if any) is folded in.
    size_t operator()(const State& s) const {
        size_t hash = 0;
        for (int i = 0; i < s.w; ++i) {
            // Positional encoding of the column: item j contributes
            // (value - 1) * 2^j (one bit per item when values are 1 or 2),
            // and a final 2^height term marks the column's height.
            // size_t keeps the doubling well-defined; with int, p overflowed
            // (undefined behavior) once a column held 31 items.
            size_t v = 0;
            size_t p = 1;
            int j = 0;
            while (j < s.h && s.stacks[i][j] > 0) {
                v += static_cast<size_t>(s.stacks[i][j] - 1) * p;
                p *= 2;
                j++;
            }
            v += p;
            hash ^= v + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        }
        if (s.req.first != -1) {
            hash += 157 * s.req.first + s.req.second;
        } else {
            hash -= 1;
        }
        return hash;
    }
};

/// Pushes value v onto the top of stack i.
/// Returns false, leaving the state unchanged, when stack i is already h high.
/// `size` is the number of 1-valued items (that is how the constructor
/// initialises it), so only adding a 1 increments it.
bool State::add(int i, int v) {
    if (heights[i] >= h) return false;
    stacks[i][heights[i]++] = v;
    if (v == 1) size += 1;
    return true;
}

/// Removes the top item of stack i and clears its cell to 0; no-op when the
/// stack is empty. Decrements `size` only when the removed item is a 1, so
/// `size` keeps matching the count of 1-valued items in the grid.
void State::pop(int i) {
    if (heights[i] <= 0) return;
    int& top = stacks[i][--heights[i]];
    if (top == 1) size -= 1;
    top = 0;
}

/// Puts the columns into lexicographic order so that states differing only
/// by column order compare equal. Skipped while a request is pending,
/// because req.first is a column index that reordering would invalidate.
void State::sort_stacks() {
    if (req.first != -1) return;

    sort(stacks.begin(), stacks.end());

    // Columns moved, so rebuild each column's height from its cells.
    for (int col = 0; col < w; ++col) {
        int top = 0;
        for (; top < h && stacks[col][top] > 0; ++top) {
        }
        heights[col] = top;
    }
}

/// Draws the grid to stdout, top row first, between '-' borders with '|'
/// sides: empty cells as ' ', 1-items as 'o', any other value as 'x'.
/// The requested cell, if any, is wrapped in ANSI escape sequences.
void State::print() const {
    const string border(w + 2, '-');
    cout << border << "\n";
    for (int row = h - 1; row >= 0; --row) {
        cout << "|";
        for (int col = 0; col < w; ++col) {
            const bool highlight = req.first == col && req.second == row;
            if (highlight) cout << "\033[6;30;42m";
            switch (stacks[col][row]) {
                case 0:
                    cout << " ";
                    break;
                case 1:
                    cout << "o";
                    break;
                default:
                    cout << "x";
                    break;
            }
            if (highlight) cout << "\033[0m";
        }
        cout << "|\n";
    }
    cout << border << "\n";
}

/// Enumerates every length-k sequence over the values [0, n) that never uses
/// `forbidden`, in lexicographic order.
/// k == 0 yields one empty sequence; k < 0, or no allowed value, yields none.
/// Digits are drawn directly from the allowed values instead of generating
/// all n^k sequences and filtering, so the work is (n-1)^k when forbidden is
/// in range, and no sequence ever contains a value outside [0, n).
vector<vector<int>> gen_all_seq(int n, int k, int forbidden = -1) {
    vector<vector<int>> result;
    if (k == 0) {
        result.push_back({});
        return result;
    }
    if (k < 0) return result;

    vector<int> allowed;
    for (int value = 0; value < n; ++value) {
        if (value != forbidden) allowed.push_back(value);
    }
    if (allowed.empty()) return result;

    const int base = static_cast<int>(allowed.size());
    vector<int> digit(k, 0);  // odometer of indices into `allowed`
    while (true) {
        result.emplace_back();
        vector<int>& seq = result.back();
        seq.reserve(k);
        for (int d : digit) seq.push_back(allowed[d]);

        // Advance the odometer: bump the last digit and carry leftwards.
        int pos = k - 1;
        while (pos >= 0 && ++digit[pos] == base) {
            digit[pos] = 0;
            --pos;
        }
        if (pos < 0) break;
    }
    return result;
}

unordered_map<State, int, State::HashFunction> memo;
unordered_map<State, State, State::HashFunction> memo_strat;

/// Minimax cost of `state`, memoised in `memo` (cost) and `memo_strat`
/// (chosen successor).
///
/// - No pending request (req.first == -1): any item equal to 1 may be
///   requested next; the cost is the largest over those requests, and the
///   successor is that worst-case request state. A state with no 1-items
///   costs 0.
/// - Pending request (i, j): every item above (i, j) must be relocated to a
///   stack other than i; the cost is the fewest relocations now plus the cost
///   of the resulting (sorted) state.
///
/// Returns INT_MAX when the state cannot be fully served, i.e. some request
/// admits no placement of its blocking items that fits and stays feasible.
/// @param depth recursion depth; only forwarded to recursive calls.
int compute(const State& state, int depth = 0) {
    auto cached = memo.find(state);
    if (cached != memo.end()) return cached->second;

    if (state.req.first == -1) {
        int best = 0;
        State best_req = state.copy();
        for (int i = 0; i < state.w; ++i) {
            for (int j = 0; j < state.heights[i]; ++j) {
                if (state.stacks[i][j] != 1) continue;
                State state2 = state.copy();
                state2.req = {i, j};
                const int val = compute(state2, depth + 1);
                // >= keeps the last of equally bad requests, as before.
                if (val >= best) {
                    best = val;
                    best_req = state2;
                }
            }
        }
        memo[state] = best;
        memo_strat[state] = best_req.copy();
        return best;
    }

    const int i = state.req.first;
    const int j = state.req.second;

    // Items above the requested one, topmost first (the order they come off).
    vector<int> to_move;
    for (int k = state.heights[i] - 1; k > j; --k) {
        to_move.push_back(state.stacks[i][k]);
    }
    const int n_move = static_cast<int>(to_move.size());

    int best = INT_MAX;
    State best_state = state.copy();  // kept if no placement is feasible
    for (const auto& seq : gen_all_seq(state.w, n_move, i)) {
        State state2 = state.copy();
        state2.req = {-1, -1};
        // Remove the blockers and the requested item itself...
        for (int x = 0; x <= n_move; ++x) state2.pop(i);
        // ...then place each blocker on its chosen target stack.
        bool ok = true;
        for (int k = 0; k < n_move && ok; ++k) {
            ok = state2.add(seq[k], to_move[k]);
        }
        if (!ok) continue;  // some target stack was full

        state2.sort_stacks();
        const int sub = compute(state2, depth + 1);
        // An infeasible continuation must not be offset by n_move: INT_MAX
        // plus a positive count would wrap to a negative "best" cost.
        if (sub == INT_MAX) continue;
        const int val = sub + n_move;
        if (val < best) {
            best = val;
            best_state = state2;
        }
    }

    memo[state] = best;
    memo_strat[state] = best_state;
    return best;
}

/// Interactive driver: solves a hard-coded yard, then walks the user through
/// it one worst-case request at a time, asking where to put each item that
/// has to be relocated.
int main() {
    vector<vector<int>> initial = {
        {2, 1, 2, 2,2,0,0},
        {2, 2, 1, 2,1,0,0},
        {2, 1, 1, 0,0,0,0},
        {2, 1, 2, 1,0,0,0},
        {2, 1, 2, 1,0,0,0}
    };

    State initial_state(initial);

    initial_state.sort_stacks();
    clock_t start = clock();
    int result = compute(initial_state);
    clock_t end = clock();
    double elapsed = double(end - start) / CLOCKS_PER_SEC;
    cout << "Time taken: " << elapsed << " seconds\n";
    // compute() signals "cannot be served" with INT_MAX; don't print it as a move count.
    if (result == INT_MAX) {
        cout << "No feasible relocation strategy.\n";
        return 0;
    }
    cout << "Minimum moves: " << result << "\n";
    initial_state.print();

    while (initial_state.size > 0) {
        initial_state = memo_strat[initial_state];
        cout << "We ask :\n";
        initial_state.print();
        cout << "You should try:\n";
        memo_strat[initial_state].print();

        const int i = initial_state.req.first;
        const int j = initial_state.req.second;
        vector<int> to_move;
        for (int k = initial_state.heights[i] - 1; k > j; --k) {
            to_move.push_back(initial_state.stacks[i][k]);
        }

        // Take the blockers and the requested item off stack i first, as
        // compute() does. Adding first would let items placed back on stack i
        // be popped again and lost.
        for (size_t k = 0; k < to_move.size() + 1; ++k) {
            initial_state.pop(i);
        }

        cout << "Enter the stacks where to relocate (size " << to_move.size() << "): ";
        for (size_t k = 0; k < to_move.size(); ++k) {
            int a;
            // Re-prompt until the target is an existing, non-full stack other
            // than the request stack (the same rule compute() plays by).
            while (true) {
                if (!(cin >> a)) {
                    cout << "\nInvalid input, exiting.\n";
                    return 1;
                }
                if (a >= 0 && a < initial_state.w && a != i &&
                    initial_state.heights[a] < initial_state.h) {
                    break;
                }
                cout << "Stack " << a << " is not a valid target, try another: ";
            }
            initial_state.add(a, to_move[k]);
        }

        initial_state.req = {-1, -1};
        initial_state.sort_stacks();
        cout << "Current state:\n";
        initial_state.print();
        cout << compute(initial_state) << "\n";
    }
    return 0;
}