#include "main.h"

#include <algorithm>
#include <cassert>
#include <chrono>
#include <iostream>


/**
 * Builds a yard of `_heights.size()` stacks, each with `_height` slots.
 *
 * @param _heights number of blocks in each stack (stored in `heights`).
 * @param _height  slot capacity shared by every stack (stored in `max_height`).
 * @param _size    total block count; -1 means "compute it as the sum of
 *                 `_heights`", any other value is stored as given.
 *
 * Every slot of `stacks` starts out as -1 (empty) even when `_heights` is
 * non-zero; `apply_perm` is what places labels into the slots.
 */
State::State(std::vector<int> _heights, int _height, int _size = -1) {
    max_height = _height;
    max_width = _heights.size();
    heights = _heights;

    if (_size != -1) {
        size = _size;
    } else {
        size = 0;
        for (int h : _heights) {
            size += h;
        }
    }

    // One column of `max_height` empty (-1) slots per stack.
    for (int col = 0; col < max_width; ++col) {
        stacks.emplace_back(max_height, -1);
    }
}

/**
 * Clears the grid and refills it with the labels in `perm`.
 *
 * Stacks are filled in index order, each from the bottom (slot 0) up to its
 * current `heights[col]`, consuming `perm` front to back. `heights` itself is
 * not modified.
 * NOTE(review): `perm` must hold at least sum(heights) values; this is not
 * checked here.
 */
void State::apply_perm(std::vector<int> perm) {
    // Reset every slot to empty.
    for (int col = 0; col < max_width; ++col) {
        std::fill_n(stacks[col].begin(), max_height, -1);
    }

    // Lay the labels down column by column, bottom-up.
    std::vector<int>::size_type next = 0;
    for (int col = 0; col < max_width; ++col) {
        const int filled = heights[col];
        for (int row = 0; row < filled; ++row) {
            stacks[col][row] = perm[next++];
        }
    }
}


/**
 * Returns an independent copy of this state.
 *
 * The copy gets the same heights, capacity and size, plus the first
 * `max_height` cells of every one of the `max_width` columns.
 */
State State::copy() {
    State clone(heights, max_height, size);
    for (int col = 0; col < max_width; ++col) {
        std::copy_n(stacks[col].begin(), max_height, clone.stacks[col].begin());
    }
    return clone;
}

/**
 * Writes the grid to stdout, top row first, so it reads like physical stacks.
 *
 * Occupied cells print their label and empty (-1) cells print "□". A final
 * line lists the height of every stack.
 */
void State::print() {
    for (int row = max_height; row-- > 0; ) {
        for (int col = 0; col < max_width; ++col) {
            const int cell = stacks[col][row];
            if (cell == -1) {
                printf("□ ");
            } else {
                printf("%d ", cell);
            }
        }
        printf("\n");
    }

    // Footer: one height per stack.
    for (int col = 0; col < max_width; ++col) {
        printf("%d ", heights[col]);
    }
    printf("\n");
}



/**
 * Predicate on stack `i` and slot `j`. Returns true when both of these hold:
 *   - `(max_width - 1) * max_height + j >= size`. The left side is the slot
 *     count of `max_width - 1` stacks plus `j`, and `size` is the current
 *     block count.
 *   - slot `j` of stack `i` is occupied, i.e. `j < heights[i]`.
 *
 * NOTE(review): the intended meaning of the first condition is not documented,
 * and this function is not called from this file. Confirm the contract with
 * its callers.
 */
bool State::valid_move(int i, int j) {
    const auto other_capacity = (max_width - 1) * max_height;
    if (other_capacity + j < size) {
        return false;
    }
    return j < heights[i];
}

/**
 * Returns the index of the lowest stack other than `forbidden`.
 *
 * Ties go to the smallest index. If there is no stack other than `forbidden`
 * (e.g. a single-stack yard with forbidden == 0), the result is 0. Because the
 * search starts from `max_height + 1`, a full stack can be returned when
 * every other stack is full. Callers must ensure that cannot happen.
 */
int State::min_stack(int forbidden) {
    int lowest = 0;
    int lowest_height = max_height + 1;
    for (int col = 0; col < max_width; ++col) {
        if (col == forbidden) {
            continue;
        }
        if (heights[col] < lowest_height) {
            lowest_height = heights[col];
            lowest = col;
        }
    }
    return lowest;
}

void State::relocate(int i1, int i2) {
    heights[i1] -= 1;
    heights[i2] += 1;
    stacks[i2][heights[i2] - 1] = stacks[i1][heights[i1]];
    stacks[i1][heights[i1]] = -1; 
}

/**
 * Removes the top block of stack `i` from the yard.
 *
 * The stack shrinks by one, the total `size` drops by one, and the vacated
 * slot is reset to -1.
 * NOTE(review): assumes stack `i` is non-empty; this is not checked.
 */
void State::retrieve(int i){
    --size;
    const auto vacated = --heights[i];
    stacks[i][vacated] = -1;
}

/**
 * Retrieves the block at slot `j` of stack `i`.
 *
 * Every block above slot `j` is first moved, one at a time, onto whichever
 * other stack is currently lowest (`min_stack(i)`). The target, now on top,
 * is then retrieved.
 *
 * @return the number of relocations performed, `heights[i] - j - 1`.
 */
int State::step_L(int i, int j) {
    const int blockers = heights[i] - j - 1;
    int moved = 0;
    while (moved < blockers) {
        relocate(i, min_stack(i));
        ++moved;
    }
    retrieve(i);
    return blockers;
}


/**
 * Finds the cell holding the smallest label currently in the yard.
 *
 * Cells equal to -1 are empty and skipped. Scan order is stack by stack,
 * bottom slot first, and ties keep the first cell found.
 *
 * @return {stack index, slot index} of the smallest label, or {0, 0} when
 *         every cell is empty.
 */
std::pair<int,int> State::next_req() {
    // A "found" flag instead of a magic upper bound, so labels of any
    // magnitude (including >= 1000) are handled.
    bool found = false;
    int best_req = 0;
    std::pair<int,int> best_ind_req = {0,0};
    for (int i = 0; i < max_width; i++) {
        for (int j = 0; j < max_height; j++) {
            const int req = stacks[i][j];
            if (req == -1) {
                continue;  // empty slot
            }
            if (!found || req < best_req) {
                found = true;
                best_req = req;
                best_ind_req = {i, j};
            }
        }
    }
    return best_ind_req;
}

/**
 * Counts the relocations needed to empty `s`.
 *
 * Blocks are always retrieved in increasing label order (`State::next_req`),
 * and each retrieval is resolved with `State::step_L`. `s` itself is left
 * untouched because the simulation runs on `s.copy()`.
 */
int cost_L(State& s) {
    State sim = s.copy();
    int total = 0;
    while (sim.size > 0) {
        const std::pair<int, int> target = sim.next_req();
        total += sim.step_L(target.first, target.second);
    }
    return total;
}



/**
 * Interactive experiment.
 *
 * For every labelling of a yard with the given stack heights, this compares
 * `cost_L` of the yard as-is against `cost_L` of the same yard after first
 * moving the top block of stack `i1` onto stack `j1`.
 *
 * Reads from stdin: the number of stacks, then one height per stack, then the
 * source and destination stack indices. Both cost lists are sorted
 * independently, and their element-wise differences (original minus
 * relocated) are printed. Returns 1 with a message on stderr when the input
 * is malformed or would make the simulation index outside the grid.
 */
int main()
{
    // Slot capacity of every stack.
    const int kMaxHeight = 6;

    std::cout << "give your size :" << std::endl;
    int w = 0;
    if (!(std::cin >> w) || w < 1) {
        std::cerr << "size must be a positive integer" << '\n';
        return 1;
    }

    std::vector<int> heights(w, 0);
    std::cout << "give your heights :" << std::endl;
    for (int i = 0; i < w; i++) {
        if (!(std::cin >> heights[i]) || heights[i] < 0 || heights[i] > kMaxHeight) {
            std::cerr << "every height must be an integer between 0 and "
                      << kMaxHeight << '\n';
            return 1;
        }
    }
    State s = State(heights, kMaxHeight);

    // While a stack is being dug out it still holds the target plus at least
    // one blocker, so the other stacks hold at most size - 2 blocks. Keeping
    // size <= (w - 1) * kMaxHeight + 1 guarantees min_stack always finds a
    // non-full destination. Above that bound, some labelling overflows a stack.
    const int max_blocks = (w - 1) * kMaxHeight + 1;
    if (s.size > max_blocks) {
        std::cerr << "too many blocks: at most " << max_blocks
                  << " fit without overflowing a stack during relocation" << '\n';
        return 1;
    }

    std::cout << "give i j :" << std::endl;
    int i1 = 0;
    int j1 = 0;
    if (!(std::cin >> i1 >> j1)) {
        std::cerr << "expected two stack indices" << '\n';
        return 1;
    }
    if (i1 < 0 || i1 >= w || j1 < 0 || j1 >= w || i1 == j1) {
        std::cerr << "stack indices must be distinct and in [0, " << w << ")" << '\n';
        return 1;
    }
    if (heights[i1] == 0) {
        std::cerr << "source stack " << i1 << " is empty" << '\n';
        return 1;
    }
    if (heights[j1] >= kMaxHeight) {
        std::cerr << "destination stack " << j1 << " is full" << '\n';
        return 1;
    }

    // Labels 1..size in ascending order, so next_permutation visits every
    // labelling exactly once.
    std::vector<int> perm;
    perm.reserve(s.size);
    for (int label = 1; label <= s.size; label++) {
        perm.push_back(label);
    }

    std::vector<int> result_s;
    std::vector<int> result_s2;
    do {
        s.apply_perm(perm);
        State s2 = s.copy();
        s2.relocate(i1, j1);
        result_s.push_back(cost_L(s));
        result_s2.push_back(cost_L(s2));
    } while (std::next_permutation(perm.begin(), perm.end()));

    std::sort(result_s.begin(), result_s.end());
    std::sort(result_s2.begin(), result_s2.end());
    std::cout << "result 1 : ";
    for (std::vector<int>::size_type k = 0; k < result_s.size(); k++) {
        std::cout << result_s[k] - result_s2[k] << " ";
    }
    std::cout << std::endl;

    return 0;
}