#include "sequence.h"

#include <cmath>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <utility>

#include "plot.h"

// Creates an empty sequence (no intervals, size 0).
// Both cost-cache flags are explicitly marked stale so they are never read
// uninitialised (e.g. by compute_ratio), in case sequence.h does not give
// them default values. `intervals` is default-constructed empty.
Sequence::Sequence() {
    size = 0;
    updated_cumulative = false;
    updated_initial = false;
}

// Builds a sequence directly from a list of intervals.
// Unlike add_interval, the ordering / non-overlap of the intervals is NOT
// validated here; the caller is assumed to pass them already ordered.
// `start` and `end` are derived exactly as add_interval derives them
// (first interval's start, last interval's *start*), so compute_ratio's
// bounds check works for sequences built this way too.
Sequence::Sequence(std::vector<Interval> _intervals)
    : intervals(std::move(_intervals)) {
    size = intervals.size();
    // Nothing has been computed for this interval list yet.
    updated_cumulative = false;
    updated_initial = false;
    if (!intervals.empty()) {
        start = intervals.front().start;
        end = intervals.back().start;
    }
}

// Returns a new Sequence holding copies of this sequence's intervals.
// The copy is rebuilt interval by interval through add_interval, so its
// start/end are recomputed and the ordering is re-checked (add_interval
// throws std::invalid_argument on a bounds violation). Cached costs and the
// consistency/robustness bookkeeping are not carried over.
Sequence Sequence::copy() {
    Sequence duplicate;
    for (const Interval& source : intervals) {
        duplicate.add_interval(source);
    }
    return duplicate;
}

// Appends `interval` after the current last interval.
//
// Throws std::invalid_argument, leaving the sequence unchanged, when the new
// interval starts before the current last interval ends.
//
// `end` tracks the *start* of the last interval: compute_ratio looks up
// intervals[i + 1] for a query point, so a query can go no further than the
// start of the last interval.
void Sequence::add_interval(Interval interval){
    // Written as !(a <= b) instead of a > b so NaN bounds are still rejected.
    if (!intervals.empty() && !(intervals.back().end <= interval.start)) {
        throw std::invalid_argument("Could not add interval because of bounds\n");
    }
    // Both caches depend on the interval list. The initial cost reads the
    // first two intervals, so it must be recomputed when this interval becomes
    // one of them; invalidating unconditionally is simplest and cheap.
    updated_cumulative = false;
    updated_initial = false;
    if (intervals.empty()) {
        start = interval.start;
    }
    end = interval.start;
    intervals.push_back(std::move(interval));
    size++;
}

void Sequence::extend_left() {
    updated_cumulative = false;
    updated_initial = false;
    if (size <= 1) {
        Interval interval = intervals.at(0);
        double left = interval.start;
        double length = interval.end - interval.start;
        Interval new_interval = Interval(left - length, left);
        intervals.insert(intervals.begin(), new_interval);
        size++;
        start = left - length;
    } else {
        Interval interval = intervals.at(0);
        double left = interval.start;
        double length = interval.end - interval.start;
        double hole = intervals.at(1).start - interval.end;
        Interval new_interval = Interval(left - length - hole, left - hole);
        intervals.insert(intervals.begin(), new_interval);
        size++;
        start = left - length - hole;
    }
    
}

void Sequence::extend_right() {
    updated_cumulative = false;
    Interval interval = intervals.back();
    double right = interval.end;
    double length = interval.end - interval.start;
    Interval new_interval = Interval(right, right + length);
    add_interval(new_interval);
}


void Sequence::compute_cumulative(){
    cumulative_cost = std::vector<double>();
    cumulative_cost.push_back(intervals[0].total_cost());
    for (int i = 1; i < size; i++) {
        Interval interval = intervals[i];
        cumulative_cost.push_back(cumulative_cost[i-1] + interval.total_cost());
    }
    updated_cumulative = true;
}

// Computes the left-infinite cost accumulated up to the end of the first
// interval, supposing the pattern of the first interval (and the hole after
// it) also repeats on the left side with period `step`:
//   step = intervals[1].start - intervals[0].start   (more than one interval)
//   step = first.end - first.start                   (single interval)
//   initial_cost = c0 / (e^step - 1)  ==  sum_{k>=1} c0 * e^{-k*step}
// where c0 = intervals[0].total_cost(). Marks the cache as up to date.
// Throws std::invalid_argument if the sequence is empty.
void Sequence::compute_initial_cost(){
    if (intervals.empty()) {
        throw std::invalid_argument("Interval list is empty, cannot compute inital cost");
    }
    Interval& first = intervals.front();
    double step = first.end - first.start;
    if (intervals.size() > 1) {
        step = intervals[1].start - first.start;
    }
    // std::expm1(step) computes e^step - 1 without the catastrophic
    // cancellation that std::exp(step) - 1 suffers when step is small.
    initial_cost = first.total_cost() / std::expm1(step);
    updated_initial = true;
}

// Evaluates the ratio at time t:
//   (initial_cost + cumulative_cost[k] + intervals[k+1].integral(0, lambda)) / e^t
// where k is the first index with intervals[k+1].start >= t and
// lambda = intervals[k].stochastic_limit(t).
// Stale cost caches are refreshed first. Throws std::invalid_argument when t
// lies outside [start, end]; vector::at throws std::out_of_range if no
// successor interval exists (e.g. a single-interval sequence).
double Sequence::compute_ratio(double t) {
    if (!updated_cumulative) {
        compute_cumulative();
    }
    if (!updated_initial) {
        compute_initial_cost();
    }

    const bool outside = (t < start) || (t > end);
    if (outside) {
        throw std::invalid_argument("argument of Sequence::compute_ratio not in the sequence bound");
    }

    int k = 0;
    for (; t > intervals.at(k + 1).start; ++k) {
        // advance until the following interval starts at or after t
    }

    const double lambda = intervals.at(k).stochastic_limit(t);
    const double accumulated = initial_cost + cumulative_cost.at(k) + intervals.at(k + 1).integral(0, lambda);
    return accumulated / std::exp(t);
}


// Samples compute_ratio at from, from + step, from + 2*step, ... (every
// point strictly below `to`), then always adds `to` itself as the final
// point, and wraps the samples in a Plot.
//
// Throws std::invalid_argument if [from, to] is not inside [start, end], or if
// step is not strictly positive (a zero, negative or NaN step would never
// reach `to`).
Plot Sequence::generate_ratio(double step, double from, double to) {
    if (from < start || to > end) {
        throw std::invalid_argument("generate_ratio argument out of sequence bounds");
    }
    if (!(step > 0)) {
        throw std::invalid_argument("generate_ratio step must be positive");
    }
    plot_t res;
    // Sample points are computed as from + k*step rather than by repeated
    // addition: rounding error does not accumulate, and a step smaller than
    // the spacing of doubles near `from` cannot stall the loop.
    for (std::size_t k = 0;; ++k) {
        const double t = from + static_cast<double>(k) * step;
        if (!(t < to)) {
            break;
        }
        std::pair<double, double> point{t, compute_ratio(t)};
        res.push_back(point);
    }
    std::pair<double, double> last_point{to, compute_ratio(to)};
    res.push_back(last_point);
    return Plot(res);
}

// Find minimum ratio in an interval [from,to], supposing it is unique.
//
// Bisects on the sign of the forward difference ratio(c + precision) -
// ratio(c) at the bracket midpoint c, keeping the half that contains the
// minimum under that uniqueness assumption, until the bracket is narrower
// than 2 * precision. It then returns the ratio at the bracket midpoint.
// If check_end is true (and the bracket is not already that narrow), the
// right end is tested first: when the ratio is still decreasing at `to`,
// ratio(to) is returned directly.
//
// precision is expected to be positive. The loop also stops when a bisection
// step would leave the bracket unchanged (floating-point resolution reached,
// e.g. precision <= 0 or too small for the magnitude of from/to), or when the
// bracket width / precision is NaN. Without these exits the search would
// never terminate.
double Sequence::find_minimum(double from, double to, double precision, bool check_end){
    if (to - from < 2 * precision) {
        return compute_ratio((from + to) / 2);
    }
    if (check_end && compute_ratio(to) - compute_ratio(to - precision) < 0) {
        return compute_ratio(to);
    }
    while (to - from >= 2 * precision) {
        const double c = (from + to) / 2;
        const double y = compute_ratio(c);
        const double z = compute_ratio(c + precision);
        // Increasing at c: the minimum is left of c; otherwise it is right of c.
        const double next_from = (z > y) ? from : c;
        const double next_to = (z > y) ? c : to;
        if (next_from == from && next_to == to) {
            break;  // no further progress is possible in floating point
        }
        from = next_from;
        to = next_to;
    }
    return compute_ratio((from + to) / 2);
}

// Find maximum ratio in an interval [from,to], supposing it is unique.
//
// Ternary search: the ratio is evaluated at the one-third points c1 < c2.
// If ratio(c1) > ratio(c2) the bracket shrinks to [from, c2], otherwise to
// [c1, to]. This repeats until the bracket is narrower than 2 * precision,
// then the ratio at the bracket midpoint is returned.
//
// check_end is accepted for signature symmetry with find_minimum but is not
// used. precision is expected to be positive. The loop also stops when a step
// would leave the bracket unchanged (floating-point resolution reached, e.g.
// precision <= 0), or when the bracket width / precision is NaN. Without
// these exits the search would never terminate.
double Sequence::find_maximum(double from, double to, double precision, [[maybe_unused]] bool check_end){
    while (to - from >= 2 * precision) {
        const double c1 = (2 * from + to) / 3;
        const double c2 = (from + 2 * to) / 3;
        const double r1 = compute_ratio(c1);
        const double r2 = compute_ratio(c2);
        const double next_from = (r1 > r2) ? from : c1;
        const double next_to = (r1 > r2) ? c2 : to;
        if (next_from == from && next_to == to) {
            break;  // no further progress is possible in floating point
        }
        from = next_from;
        to = next_to;
    }
    return compute_ratio((from + to) / 2);
}
// Runs find_minimum on [from, to] and lowers `consistency` to the result when
// the result is smaller, so repeated calls keep the smallest value seen.
void Sequence::check_consistency(double from, double to, double precision, bool check_end) {
    const double local_min = find_minimum(from, to, precision, check_end);
    if (!(local_min < consistency)) {
        return;
    }
    consistency = local_min;
}

// Runs check_consistency for every pending (interval, check_end) entry in
// consistency_to_check, empties that queue, and returns the resulting
// consistency value. With nothing pending, the stored value is returned as is.
double Sequence::get_consistency(double precision) {
    for (const auto& [pending, at_end] : consistency_to_check) {
        check_consistency(pending.start, pending.end, precision, at_end);
    }
    consistency_to_check.clear();
    return consistency;
}

// Runs check_robustness for every pending (interval, check_end) entry in
// robustness_to_check, empties that queue, and returns the resulting
// robustness value, mirroring get_consistency.
double Sequence::get_robustness(double precision) {
    for (const auto& [pending, at_end] : robustness_to_check) {
        check_robustness(pending.start, pending.end, precision, at_end);
    }
    // Clear so each pending entry is evaluated only once, rather than the whole
    // queue being re-run on every call.
    robustness_to_check.clear();
    return robustness;
}

// Runs find_maximum on [from, to] and raises `robustness` to the result when
// the result is larger, so repeated calls keep the largest value seen.
void Sequence::check_robustness(double from, double to, double precision, bool check_end) {
    const double local_max = find_maximum(from, to, precision, check_end);
    if (!(local_max > robustness)) {
        return;
    }
    robustness = local_max;
}

void Sequence::display(double step) {
    Plot plot = generate_ratio(step, start, end);
    plot.display();
}

std::ostream& operator<<(std::ostream &out, Sequence const& seq) {
    for (Interval interval : seq.intervals) {
        out << "[ " << interval.start << " - " << interval.end << " ]\n";
    }
    return out;
}