#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import sys
import os

project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# add_data() appends its results to <project_root>/outputs/smoothness/vianney.csv
output_file = os.path.join(project_root, "outputs", "smoothness", "vianney.csv")


def parse_verbosity(argv):
    """Return the verbosity level selected by the command-line options ``argv``.

    ``-verbose`` selects level 2. Any other option is reported with
    "Wrong argument" and ignored; it does not consume the option after it.
    With no recognised option the level is 0.
    """
    level = 0
    for opt in argv:
        if opt == "-verbose":
            level = 2
        else:
            print("Wrong argument", opt)
    return level


# parse command line arguments; this must run before plot_vianney is defined,
# because that function uses ``verbose`` as its default argument value.
verbose = parse_verbosity(sys.argv[1:])

def vianney_CR_exact(a, u, pred, rho, T=30.0):
    """
    Exact truncated integral on [0,T] for any rho >= 0.

    With r = u / pred and m(t) = floor(log_a(r / (1 + rho * t))), returns

        2 * a**2 / ((a - 1) * r) * integral_0^T a**m(t) * (1 + rho*t) / (1 + t)**3 dt

    m(t) is piecewise constant in t, so the integral is evaluated exactly as a
    finite sum over the values taken by m, using the closed-form
    antiderivative F of (1 + rho*t) / (1 + t)**3.

    Parameters
    ----------
    a : float
        Base of the geometric grid; must be > 1.
    u, pred : float
        Positive values; only their ratio r = u / pred matters.
    rho : float
        Non-negative perturbation scale.
    T : float, default 30.0
        Upper limit of the truncated integral.

    Returns
    -------
    float
        The value of the expression above.

    Raises
    ------
    ValueError
        If a <= 1, rho < 0, or u / pred are not positive.
    """
    if a <= 1:
        raise ValueError(f"a must be > 1, got {a}")
    if rho < 0:
        raise ValueError(f"rho must be >= 0, got {rho}")
    if u <= 0 or pred <= 0:
        raise ValueError(f"u and pred must be positive, got u={u}, pred={pred}")

    r = u / pred
    log_a = np.log(a)
    coeff = 2.0 * a**2 / (a - 1.0) * (pred / u)

    if rho == 0:
        # m(t) = floor(log_a r) for every t, and
        # integral_0^T (1 + t)**-3 dt = (1 - (1 + T)**-2) / 2.
        k = np.floor(np.log(r) / log_a)
        # log(r) / log(a) can round across an integer when r is (numerically)
        # a power of a; re-check the floor against a**k directly.
        if a ** (k + 1.0) <= r:
            k += 1.0
        elif a**k > r:
            k -= 1.0
        weight = 1.0 - 1.0 / (1.0 + T)**2
        return (a**2 / (a - 1.0)) * (a**k / r) * weight

    def F(t):
        # Antiderivative of rho/(1+t)**2 + (1-rho)/(1+t)**3 = (1 + rho*t)/(1+t)**3.
        return -rho / (1.0 + t) - (1.0 - rho) / (2.0 * (1.0 + t)**2)

    # m(t) is non-increasing: floor(log_a r) at t = 0 down to
    # floor(log_a(r / (1 + rho*T))) at t = T.  Both ends are widened by one so
    # a floor() that rounds the wrong way cannot drop a piece; pieces that are
    # actually empty are skipped by the L < R test below.
    m_min = int(np.floor(np.log(r / (1.0 + rho * T)) / log_a)) - 1
    m_max = int(np.floor(np.log(r) / log_a)) + 1

    total = 0.0
    for m in range(m_min, m_max + 1):
        # m(t) == m exactly when r / a**(m+1) < 1 + rho*t <= r / a**m,
        # i.e. for t in (L, R], clipped to [0, T].
        L = max(0.0, (r / a**(m + 1) - 1.0) / rho)
        R = min(T, (r / a**m - 1.0) / rho)
        if L < R:
            total += a**m * (F(R) - F(L))

    return coeff * total

def vianney_CR_many_u_exact(a, u_values, pred, rho, T=30.0):
    """Evaluate vianney_CR_exact once per entry of ``u_values``.

    ``u_values`` is first converted to a float array; the result is a NumPy
    array with one value per element along its first axis, in the same order.
    """
    grid = np.asarray(u_values, dtype=float)
    results = []
    for point in grid:
        results.append(vianney_CR_exact(a, point, pred, rho, T=T))
    return np.array(results)


def _report_vianney(a, rho, min_val, max_val, worst_left_slope, worst_right_slope, verbose):
    """Print one (a, rho) summary: every metric if verbose == 2, header only if 1."""
    if verbose == 2:
        print(f"✅ a = {a:.5f}, rho = {rho:.5f}")
        print(f"Consistency = {min_val:.5f}")
        print(f"Robustness = {max_val:.5f}")
        print(f"worst left slope = {worst_left_slope:.5f}")
        print(f"worst right slope = {worst_right_slope:.5f}")
    elif verbose == 1:
        print(f"✅ a = {a:.5f}, rho = {rho:.5f}")


def plot_vianney(a, rho, n_plot=300, T=30.0, plot=True, verbose=verbose):
    """Summarise vianney_CR_exact (pred = 1) over u in [1, a**2].

    The curve is sampled at ``n_plot`` points x evenly spaced in [0, 2*log(a)],
    with u = exp(x). Reported metrics:

    * consistency (min_val): the value at u = a;
    * robustness (max_val): the maximum over the sampled grid;
    * worst left/right slope: the largest |CR(x) - min_val| / |x - log(a)|
      over grid points left / right of x = log(a).

    For rho == 0 the closed forms a/(a-1), a**2/(a-1), a/log(a) and inf are
    returned instead, and nothing is plotted.

    Parameters
    ----------
    a : float
        Grid base passed to vianney_CR_exact.
    rho : float
        Perturbation scale passed to vianney_CR_exact.
    n_plot : int
        Number of sample points.
    T : float
        Truncation passed to vianney_CR_exact.
    plot : bool
        If true, draw the curve and mark (log(a), min_val) on the current axes.
    verbose : int
        2 prints every metric, 1 prints only the header, anything else is silent.

    Returns
    -------
    tuple
        (a, rho, min_val, max_val, worst_left_slope, worst_right_slope)
    """
    log_a = np.log(a)

    if rho == 0:
        # Closed forms; the sampled curve is not needed, so it is not computed.
        min_val = a / (a - 1)
        max_val = a**2 / (a - 1)
        worst_left_slope = a / log_a
        worst_right_slope = float("inf")
        _report_vianney(a, rho, min_val, max_val,
                        worst_left_slope, worst_right_slope, verbose)
        return a, rho, min_val, max_val, worst_left_slope, worst_right_slope

    X = np.linspace(0, 2 * log_a, n_plot)
    U = np.exp(X)
    Y = vianney_CR_many_u_exact(a, U, 1, rho, T=T)

    min_val = vianney_CR_exact(a, a, 1, rho, T=T)
    max_val = Y.max()

    # Chord slopes from the reference point (log a, min_val); the reference
    # abscissa itself (if sampled) is excluded to avoid dividing by zero.
    dx = X - log_a
    off_ref = dx != 0
    slopes = np.abs((Y[off_ref] - min_val) / dx[off_ref])
    worst_left_slope = np.max(slopes[dx[off_ref] < 0], initial=0.0)
    worst_right_slope = np.max(slopes[dx[off_ref] > 0], initial=0.0)

    _report_vianney(a, rho, min_val, max_val,
                    worst_left_slope, worst_right_slope, verbose)

    if plot:
        plt.plot(X, Y, label=f"a={a}")
        plt.scatter([log_a], [min_val])

    return a, rho, min_val, max_val, worst_left_slope, worst_right_slope

def plot_data_fixed_a(a, n_plot=100, n=2000, T=50):
    """Plot plot_vianney's summary metrics against rho for a fixed ``a``.

    rho takes ``n_plot`` values exp(x), x evenly spaced in [-10, 3]. Each value
    is evaluated by plot_vianney on an ``n``-point grid with truncation ``T``,
    without plotting or printing. Four curves are drawn on the current
    matplotlib axes: consistency, robustness, worst left slope, worst right slope.
    """
    rhos = np.exp(np.linspace(-10, 3, n_plot))
    # plot_vianney takes (a, rho, n_plot, T, plot, verbose); keywords make the
    # mapping n -> n_plot explicit.  [2:] drops the echoed (a, rho).
    metrics = np.array(
        [plot_vianney(a, rho, n_plot=n, T=T, plot=False, verbose=False)[2:] for rho in rhos],
        dtype=float,
    ).reshape(-1, 4)
    labels = ("consistency", "robustness", "worst left slope", "worst right slope")
    for column, label in zip(metrics.T, labels):
        plt.plot(rhos, column, label=label)

def plot_data_fixed_rho(rho, n_plot=200, n=20000, T=50):
    """Plot plot_vianney's summary metrics against ``a`` for a fixed rho.

    ``a`` takes ``n_plot`` values evenly spaced in [2, 100]. Each value is
    evaluated by plot_vianney on an ``n``-point grid with truncation ``T``,
    without plotting or printing. Four curves are drawn on the current
    matplotlib axes: consistency, robustness, worst left slope, worst right slope.
    """
    a_values = np.linspace(2, 100, n_plot)
    # plot_vianney takes (a, rho, n_plot, T, plot, verbose); keywords make the
    # mapping n -> n_plot explicit.  [2:] drops the echoed (a, rho).
    metrics = np.array(
        [plot_vianney(a, rho, n_plot=n, T=T, plot=False, verbose=False)[2:] for a in a_values],
        dtype=float,
    ).reshape(-1, 4)
    labels = ("consistency", "robustness", "worst left slope", "worst right slope")
    for column, label in zip(metrics.T, labels):
        plt.plot(a_values, column, label=label)

def add_data(
    start_a=2,
    end_a=100,
    n_a=10,
    start_rho=0.001,
    end_rho=1,
    n_rho=10,
    n_plot=2000,
    T=30,
):
    """Append plot_vianney summaries for an (a, rho) grid to ``output_file``.

    ``a`` takes ``n_a`` evenly spaced values in [start_a, end_a] and rho takes
    ``n_rho`` evenly spaced values in [start_rho, end_rho]. Each grid point is
    evaluated by plot_vianney (``n_plot`` samples, truncation ``T``, no plot,
    verbose level 1). One line is appended per point:
    a rho consistency robustness left_slope right_slope, each formatted to 5 decimals.
    """
    # NOTE(review): fields are space-separated although the file is named
    # .csv; kept as-is because existing readers presumably expect it.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    a_grid = np.linspace(start_a, end_a, n_a)
    rho_grid = np.linspace(start_rho, end_rho, n_rho)
    with open(output_file, "a", encoding="utf-8") as file:
        for a in a_grid:
            for rho in rho_grid:
                # plot_vianney returns (a, rho, c, r, ls, rs) — exactly the row.
                row = plot_vianney(a, rho, n_plot, T, False, 1)
                print(*(f"{float(x):.5f}" for x in row), file=file)

# Run the data generation only when executed as a script, not on import.
if __name__ == "__main__":
    add_data(
        start_a=2,
        end_a=10,
        n_a=300,
        start_rho=0,
        end_rho=1,
        n_rho=300,
        n_plot=1000,
        T=100000,
    )