#include "sequence.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <stdexcept>
#include <vector>

#include "utils/bidding_functions.h"

/// Build the four-interval "smooth" sequence for parameter R.
///
/// Intervals are added left to right as
///   [-1/ybot(R), 0], [0, x], [x, x + y], [x + y, x + y + 1/ytop(R)],
/// and extend_right() is called afterwards (its effect is defined by Sequence,
/// not visible here).
///
/// @param R parameter forwarded to ybot/ytop.
/// @param x length of the interval starting at 0.
/// @param y length of the interval starting at x.
/// @return the assembled sequence.
Sequence smooth_sequence(double R, double x, double y){
    Sequence seq;
    const double left_width = 1/ybot(R);
    const double right_width = 1/ytop(R);
    const double split = x + y;

    seq.add_interval(Interval(-left_width, 0));
    seq.add_interval(Interval(0, x));
    seq.add_interval(Interval(x, split));
    seq.add_interval(Interval(split, split + right_width));

    seq.extend_right();
    return seq;
}

/// Choose a smooth sequence for bound R and record its robustness/consistency.
///
/// With yb = ybot(R) and yt = ytop(R), the third interval always has length
/// 1/yt and the recorded robustness is always R:
///  - if yt - yb >= 1: x = 0 and the recorded consistency is yb + 1;
///  - otherwise: x = dichotomy(z -> next_cost(yb, z), yt, precision, 0, 100, false)
///    (presumably the root of next_cost(yb, x) = yt on [0, 100] — confirm against
///    dichotomy) and the recorded consistency is yt * exp(x).
///
/// @param R         bound forwarded to ybot/ytop and smooth_sequence.
/// @param precision tolerance forwarded to dichotomy (used only when yt - yb < 1).
/// @return the chosen sequence with `consistency` and `robustness` set.
Sequence best_smooth_sequence(double R, double precision) {
    Sequence seq;
    const double yb = ybot(R);
    const double yt = ytop(R);
    const double ty = 1/yt;

    if (yt - yb >= 1) {
        // Gaps in [1, 2) and gaps >= 2 are handled identically here.
        // TODO(review): a search on next_cost(yb + 1, ·) may have been intended
        // for the [1, 2) range — confirm.
        seq = smooth_sequence(R, 0, ty);
        seq.consistency = yb + 1;
    } else {
        auto f = [yb](double z) { return next_cost(yb, z); };
        const double x = dichotomy(f, yt, precision, 0, 100, false);
        seq = smooth_sequence(R, x, ty);
        seq.consistency = yt * std::exp(x);
    }
    seq.robustness = R;
    return seq;
}

/// Build a sequence with two intervals left of 0 followed by a chain of
/// intervals separated by gaps.
///
/// Intervals are added as
///   [-2z - h, -z - h], [-z, 0],
///   then, for each k, a gap of holes[k] followed by an interval of length lengths[k],
///   then a final interval of length 1/ytop(R).
/// extend_left() is called after the chain and extend_right() after the final
/// interval (their effect is defined by Sequence, not visible here).
///
/// Bookkeeping: every added interval except [-z, 0] is pushed to
/// robustness_to_check; the gap preceding the last chained interval (if any) is
/// pushed to consistency_to_check.
///
/// @param holes gap sizes; needs at least lengths.size() entries (accessed via at(),
///              so a shorter vector throws std::out_of_range).
Sequence separated_sequence(double R, double z, double h, std::vector<double>& lengths, std::vector<double>& holes) {
    Sequence seq;
    const double tail_len = 1/ytop(R);
    const double outer_lo = -2*z - h;
    const double outer_hi = -z - h;

    seq.add_interval(Interval(outer_lo, outer_hi));
    seq.add_interval(Interval(-z,0));
    seq.robustness_to_check.push_back({Interval(outer_lo, outer_hi), true});

    const std::size_t count = lengths.size();
    double cursor = 0;
    for (std::size_t k = 0; k < count; ++k) {
        const double gap = holes.at(k);
        if (k + 1 == count) {
            // Only the gap before the last chained interval is checked for consistency.
            seq.consistency_to_check.push_back({Interval(cursor, cursor + gap), true});
        }
        const double lo = cursor + gap;
        const double hi = lo + lengths.at(k);
        seq.add_interval(Interval(lo, hi));
        seq.robustness_to_check.push_back({Interval(lo, hi), true});
        cursor = hi;
    }

    seq.extend_left();
    seq.add_interval(Interval(cursor, cursor + tail_len));
    seq.extend_right();
    seq.robustness_to_check.push_back({Interval(cursor, cursor + tail_len), true});
    return seq;
}

/// Packed-parameter form of separated_sequence.
///
/// params is laid out as {z, h, lengths[0], lengths[1], ...}; the result is
/// identical to separated_sequence(R, z, h, lengths, holes).
///
/// @param params at least two entries (z and h), followed by interval lengths.
/// @param holes  gap sizes; needs at least params.size() - 2 entries.
/// @throws std::invalid_argument if params has fewer than two entries
///         (previously this indexed out of bounds).
Sequence separated_sequence(double R, std::vector<double>& params, std::vector<double>& holes) {
    if (params.size() < 2) {
        throw std::invalid_argument("params must contain at least z and h");
    }
    const double z = params[0];
    const double h = params[1];
    // Interval lengths are everything after the two leading scalars.
    std::vector<double> lengths(params.begin() + 2, params.end());
    return separated_sequence(R, z, h, lengths, holes);
}

/// Greedily widen the gaps ("holes") of a separated sequence, left to right.
///
/// For hole i, the check window is the interval immediately to its left
/// ([-z, 0] for i = 0). A hole size is feasible when
/// seq.find_maximum(window, precision, false) <= R.
/// Each hole is first grown geometrically (0, 1, 2, 4, ...) until infeasible,
/// then bisected down to `precision`; the last feasible size is kept.
///
/// @param R         bound compared against find_maximum / get_robustness.
/// @param z, h      forwarded to separated_sequence.
/// @param lengths   interval lengths after 0 (not modified).
/// @param precision bisection tolerance; also forwarded to find_maximum.
/// @return the sequence built with the final hole sizes.
/// @throws std::invalid_argument if the sequence with all-zero holes already has
///         get_robustness() > R.
Sequence tight_separated_sequence(double R, double const z, double h, std::vector<double>& lengths, double precision) {
    const std::size_t n = lengths.size();
    std::vector<double> holes(n, 0.0);
    Sequence seq = separated_sequence(R, z, h, lengths, holes);
    if (seq.get_robustness() > R) {
        throw std::invalid_argument("Initial sequence has out of bounds robustness");
    }

    double check_start = -z;
    double check_end = 0;
    for (std::size_t i = 0; i < n; i++) {
        // Grow hole i geometrically while the window stays feasible.
        // NOTE(review): if the window never becomes infeasible, this only stops
        // once the hole overflows; consider a cap.
        while (seq.find_maximum(check_start, check_end, precision, false) <= R) {
            holes[i] = (holes[i] == 0) ? 1 : holes[i] * 2;
            seq = separated_sequence(R, z, h, lengths, holes);
        }

        // Bisect between the last feasible size and the first infeasible one.
        double h_end = holes[i];
        double h_start = (holes[i] == 1) ? 0 : holes[i] / 2;
        while (h_end - h_start > precision) {
            holes[i] = (h_end + h_start) / 2;
            seq = separated_sequence(R, z, h, lengths, holes);
            const double r = seq.find_maximum(check_start, check_end, precision, false);
            // Same feasibility predicate as the growth loop: feasible iff r <= R.
            if (r > R) {
                h_end = holes[i];
            } else {
                h_start = holes[i];
            }
        }
        holes[i] = h_start;

        // Rebuild so the next hole's checks see the accepted size of this hole,
        // not the last probed (possibly infeasible) size.
        seq = separated_sequence(R, z, h, lengths, holes);

        check_start = check_end + holes[i];
        check_end = check_start + lengths[i];
    }
    return seq;
}

/// Packed-parameter form of tight_separated_sequence.
///
/// params is laid out as {z, h, lengths[0], lengths[1], ...}. Holes are widened
/// greedily left to right. For hole i, the check window is the interval just left
/// of it ([-z, 0] for i = 0), and a size is feasible when
/// seq.find_maximum(window, precision, false) <= R. Each hole is grown
/// geometrically (0, 1, 2, 4, ...) until infeasible, then bisected down to
/// `precision`; the last feasible size is kept.
///
/// @param R         bound compared against find_maximum / get_robustness.
/// @param params    at least two entries (z, h) followed by interval lengths.
/// @param precision bisection tolerance; also forwarded to find_maximum.
/// @return the sequence built with the final hole sizes.
/// @throws std::invalid_argument if params has fewer than two entries, or if the
///         sequence with all-zero holes already has get_robustness() > R.
Sequence tight_separated_sequence(double R, std::vector<double>& params, double precision) {
    if (params.size() < 2) {
        throw std::invalid_argument("params must contain at least z and h");
    }
    const std::size_t n = params.size() - 2;
    const double z = params[0];
    std::vector<double> holes(n, 0.0);
    Sequence seq = separated_sequence(R, params, holes);
    if (seq.get_robustness() > R) {
        throw std::invalid_argument("Initial sequence has out of bounds robustness");
    }

    double check_start = -z;
    double check_end = 0;
    for (std::size_t i = 0; i < n; i++) {
        // Grow hole i geometrically while the window stays feasible.
        // NOTE(review): if the window never becomes infeasible, this only stops
        // once the hole overflows; consider a cap.
        while (seq.find_maximum(check_start, check_end, precision, false) <= R) {
            holes[i] = (holes[i] == 0) ? 1 : holes[i] * 2;
            seq = separated_sequence(R, params, holes);
        }

        // Bisect between the last feasible size and the first infeasible one.
        double h_end = holes[i];
        double h_start = (holes[i] == 1) ? 0 : holes[i] / 2;
        while (h_end - h_start > precision) {
            holes[i] = (h_end + h_start) / 2;
            seq = separated_sequence(R, params, holes);
            const double r = seq.find_maximum(check_start, check_end, precision, false);
            // Same feasibility predicate as the growth loop: feasible iff r <= R.
            if (r > R) {
                h_end = holes[i];
            } else {
                h_start = holes[i];
            }
        }
        holes[i] = h_start;

        // Rebuild so the next hole's checks see the accepted size of this hole,
        // not the last probed (possibly infeasible) size.
        seq = separated_sequence(R, params, holes);

        check_start = check_end + holes[i];
        check_end = check_start + params[i + 2];
    }
    return seq;
}

/// Wrap tight_separated_sequence as an objective function of its packed params.
///
/// The returned callable maps params = {z, h, lengths...} to the consistency of
/// tight_separated_sequence(R, params, precision). It returns +infinity when any
/// parameter is negative or NaN, or when the construction throws. This is a
/// deliberate penalty value so that a caller (presumably a minimizer) treats
/// such points as infeasible.
///
/// @param R         bound forwarded to tight_separated_sequence.
/// @param precision tolerance forwarded to tight_separated_sequence.
/// @return a callable taking params by value (copied so the non-const reference
///         overload can be used).
std::function<double(std::vector<double>)> tight_separated_sequence_functional(double R, double precision) {
    return [R, precision](std::vector<double> params) -> double {
        constexpr double penalty = std::numeric_limits<double>::infinity();
        // `!(x >= 0)` rejects negatives and NaN alike.
        const bool has_invalid = std::any_of(params.begin(), params.end(),
                                             [](double x) { return !(x >= 0); });
        if (has_invalid) {
            return penalty;
        }
        try {
            Sequence seq = tight_separated_sequence(R, params, precision);
            return seq.get_consistency();
        } catch (...) {
            // Intentional: any failure to build the sequence marks the point infeasible.
            return penalty;
        }
    };
}