#include <cmath>

#include "lambert_w.h"
#include "bidding_functions.h"
#include "../plot.h"

using namespace std;

// Lower solution y of  y * e^(1/y) = R.
// Substituting w = -1/y turns the equation into w * e^w = -1/R, hence y = -1 / W(-1/R).
// W(-1/R) is real only for -1/R >= -1/e, i.e. R >= e (for R > 0).
// NOTE(review): assumes lambert_w(x, 1, 0) evaluates the lower real branch W_{-1}, which gives
// the smaller root (0 < y <= 1) — confirm against lambert_w.h; the meaning of the third
// argument is not visible here.
double ybot(double R) {
    const auto w = lambert_w(-1. / R, 1, 0);
    return -1. / w;
}

// Upper solution y of  y * e^(1/y) = R.
// Same substitution as ybot (w = -1/y, w * e^w = -1/R), but on the other branch of W.
// NOTE(review): assumes lambert_w(x, 0, 0) evaluates the principal branch W_0, which gives
// the larger root (y >= 1) — confirm against lambert_w.h.
double ytop(double R) {
    const auto w = lambert_w(-1. / R, 0, 0);
    return -1. / w;
}

// Returns y + (e^z - 1) / z, using the limit value y + 1 at z == 0.
//
// (e^z - 1) is computed with std::expm1 rather than std::exp(z) - 1. The subtraction form
// cancels almost all significant digits when |z| is small: at z = 1e-10 it is off by ~8e-8
// relative, and at z = 1e-12 by ~9e-5. expm1 stays accurate to full precision.
double local_ratio(double y, double z) {
    if (z == 0) {
        return y + 1;
    }
    return y + std::expm1(z) / z;
}

// local_ratio(y, z) scaled down by the growth factor e^z, i.e. local_ratio(y, z) / e^z.
double next_cost(double y, double z) {
    const double growth = std::exp(z);
    const double ratio = local_ratio(y, z);
    return ratio / growth;
}

// Returns (-1 - y * W) / y  ==  -1/y - W,  where W = lambert_w(-(1/y) * e^(-1/y), 1, 0).
// t = -1/y is itself one solution of t * e^t = -(1/y) * e^(-1/y). The branch selected by
// lambert_w(., 1, 0) therefore presumably yields the *other* solution, which makes the
// result the difference between the two roots.
// NOTE(review): branch/argument semantics of lambert_w not visible here — confirm.
double revert_exp_integral(double y) {
    const double inv_y = 1 / y;
    const auto w = lambert_w(-inv_y * exp(-inv_y), 1, 0);
    return (-1 - y * w) / y;
}

double compute_R0() {
    auto g = [](auto vec) {
        double R = vec[0];
        double x = ytop(R) - ybot(R);
        if (x > 1) {
            return 1000.;
        }
        return -R;
    };
    vector<double> v1 = {3};
    vector<double> v2 = {3};
    return ternary_search(g, v1, v2, 0, 2.5, 3.5);
}

double contiguous_ratio(double R) {
    if (R == exp(1)) {
        return R;
    }
    double yt = ytop(R);
    double yb = ybot(R);
    if (yt - yb >= 1) {
        return yb + 1;
    } else {
        double z = dichotomy(
            [yb](double x){return next_cost(yb, x);},
            yt,
            1e-12,
            0,
            1.1,
            false
        );
        return yt * exp(z);
    }
    
}

// Fonction de dichotomie pour trouver x tel que |f(x) - c| <= epsilon
double dichotomy(
    const std::function<double(double)>& f,  // Fonction monotone (croissante ou décroissante)
    double c,                                // Valeur cible
    double epsilon,                          // Précision souhaitée
    double a,                                // Borne inférieure de l'intervalle
    double b,                                // Borne supérieure de l'intervalle
    bool is_increasing            // true si f est croissante, false si décroissante
) {
    if (epsilon <= 0) {
        throw std::invalid_argument("epsilon doit être strictement positif.");
    }

    // Vérification que c est bien entre f(a) et f(b)
    double fa = f(a);
    double fb = f(b);
    if (is_increasing) {
        if ((fa > c && fb > c) || (fa < c && fb < c)) {
            throw std::invalid_argument("c n'est pas dans l'intervalle [f(a), f(b)] ou [f(b), f(a)].");
        }
    } else {
        if ((fa < c && fb < c) || (fa > c && fb > c)) {
            throw std::invalid_argument("c n'est pas dans l'intervalle [f(b), f(a)] ou [f(a), f(b)].");
        }
    }

    // Boucle de dichotomie
    while (b - a > epsilon) {
        double mid = (a + b) / 2.0;
        double fmid = f(mid);

        if (std::abs(fmid - c) <= epsilon) {
            return mid;  // Solution trouvée
        }

        if (is_increasing) {
            if (fmid < c) {
                a = mid;  // Chercher dans [mid, b]
            } else {
                b = mid;  // Chercher dans [a, mid]
            }
        } else {
            if (fmid > c) {
                a = mid;  // Chercher dans [mid, b] (f décroissante)
            } else {
                b = mid;  // Chercher dans [a, mid] (f décroissante)
            }
        }
    }

    // Retourne le milieu de l'intervalle final
    return (a + b) / 2.0;
}

// Ternary search that minimises `func` along one coordinate of a parameter vector.
//
// On each step, params[param_index] is set to the two interior third-points, m1 and then m2.
// func is evaluated at each, and the interval is narrowed towards the smaller value. When the
// two values tie (or either is NaN), valid_params[param_index] acts as an anchor:
//   - anchor left of m1  -> keep [left, m1]
//   - anchor right of m2 -> keep [m2, right]
//   - otherwise          -> keep [m1, m2]
// The search stops after max_iter steps or once the interval is narrower than
// ternary_precision. Returns the midpoint of the final interval.
// Side effect: params[param_index] is left at the last probed m2 (unchanged if max_iter <= 0).
double ternary_search(
    const function<double(const vector<double>&)>& func,
    vector<double>& params,
    vector<double>& valid_params,
    int param_index,
    double left,
    double right,
    int max_iter,
    double ternary_precision
) {
    double lo = left;
    double hi = right;

    for (int step = 0; step < max_iter; ++step) {
        const double third = (hi - lo) / 3;
        const double probe_low = lo + third;
        const double probe_high = hi - third;

        params[param_index] = probe_low;
        const double f_low = func(params);
        params[param_index] = probe_high;
        const double f_high = func(params);

        if (f_low < f_high) {
            hi = probe_high;
        } else if (f_low > f_high) {
            lo = probe_low;
        } else {
            const double anchor = valid_params.at(param_index);
            if (anchor < probe_low) {
                hi = probe_low;
            } else if (anchor > probe_high) {
                lo = probe_high;
            } else {
                lo = probe_low;
                hi = probe_high;
            }
        }

        if (abs(hi - lo) < ternary_precision) {
            break;
        }
    }

    return (lo + hi) / 2;
}

// Coordinate-descent minimisation of `func` over a box of parameter ranges.
//
// Each outer iteration sweeps the coordinates in order. For each coordinate it runs
// ternary_search over ranges[i], and keeps the new value unless it increased the objective;
// in that case params[i] and valid_params[i] are restored.
//
// Parameters:
//   func               objective to minimise; receives the full parameter vector
//   ranges             per-coordinate search interval [first, second]
//   max_iter           maximum number of full sweeps
//   iter_precision     stop once a sweep changes the objective by less than this
//                      (only when the final objective is below +infinity)
//   ternary_precision  interval-width tolerance forwarded to ternary_search
// Returns the final parameter vector (same length as ranges; empty if ranges is empty).
// Progress is printed to std::cout.
//
// Starting point: every coordinate is 1, except coordinate 1, which starts at 0.001. Its
// tie-breaking "valid" value starts at ranges[1].first. These special cases apply only when
// coordinate 1 exists: the previous version wrote out of bounds for empty ranges and threw
// std::out_of_range for one-dimensional ranges.
vector<double> multidimensional_ternary_search(
    const function<double(const vector<double>&)>& func,
    vector<pair<double, double>> ranges,
    int max_iter,
    double iter_precision,
    double ternary_precision
) {
    const size_t dims = ranges.size();
    vector<double> valid_params(dims, 1.0);
    vector<double> params(dims, 1.0);
    if (dims == 0) {
        return params;
    }
    if (dims > 1) {
        valid_params[1] = ranges[1].first;
        params[1] = 0.001;
    }

    for (int iter = 0; iter < max_iter; ++iter) {
        const double start_val = func(params);

        for (size_t i = 0; i < dims; ++i) {
            // Each objective value is computed once and reused for printing.
            const double prev = func(params);
            const double prev_param = params[i];
            cout << "   from " << prev;

            params[i] = ternary_search(func, params, valid_params, i, ranges[i].first, ranges[i].second, 100000, ternary_precision);
            valid_params[i] = params[i];
            const double next = func(params);
            cout << " to " << next;
            const double next_param = params[i];
            cout << " by changing " << prev_param << " by " << next_param << "\n";

            if (next > prev) {
                // Diagnostic sweep of the objective along coordinate i.
                // NOTE(review): the plot is built but never displayed (display() is commented
                // out), so this sweep currently only costs func evaluations.
                Plot plot;
                for (double x = ranges[i].first; x <= ranges[i].second; x = x + 0.001) {
                    params[i] = x;
                    double val = func(params);
                    plot.add({x, (val<100)?val:10});
                }
                // The 1-D search made things worse: restore the previous value.
                params[i] = prev_param;
                valid_params[i] = prev_param;
                /* plot.display(); */
            }
        }

        const double end_val = func(params);
        std::cout << "went from " << start_val << " to " << end_val << "\n";
        if (end_val < numeric_limits<double>::infinity() && abs(end_val - start_val) < iter_precision) {
            break;
        }
    }
    return params;
}