#!/usr/bin/env python3
import matplotlib.pyplot as plt
import math
import numpy as np
from sys import argv 
import os
import re

# Parent of the directory that contains this script (abspath, then dirname twice).
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Folder scanned by closest_file_below() and read by plot_smoothness(). Unlike the
# relative "outputs/..." paths used by the other plot_* helpers in this file, it is
# anchored at project_root, so it does not depend on the current working directory.
ratio_folder = os.path.join(project_root, "outputs", "ratio")


def plot_pareto(n, R_max):
    """Plot the (R, C) points of ``outputs/separated_pareto_<n>.txt`` with R <= R_max.

    The path is relative to the current working directory. Every non-blank
    line must hold at least two whitespace-separated numbers: the first is R,
    the second is C, and any further numeric columns are parsed but ignored.
    The kept points are drawn with the "." marker via ``plt.plot``.

    Args:
        n: value inserted into the file name.
        R_max: points whose R exceeds this value are skipped.
    """
    filename = "outputs/separated_pareto_" + str(n) + ".txt"
    X, Y = [], []
    with open(filename, "r") as f:
        for line in f:
            # split() with no separator tolerates runs of spaces, tabs and
            # trailing newlines; a blank line yields no fields and is skipped.
            values = [float(field) for field in line.split()]
            if not values:
                continue
            R, C = values[0], values[1]
            if R <= R_max:
                X.append(R)
                Y.append(C)
    plt.plot(X, Y, ".")

def plot_file(filename, R_max):
    """Plot the (R, C) points of a two-column file, keeping those with R <= R_max.

    Every non-blank line must hold exactly two whitespace-separated numbers,
    R then C. The kept points are drawn with the "." marker via ``plt.plot``.

    Args:
        filename: path of the data file.
        R_max: points whose R exceeds this value are skipped.
    """
    X, Y = [], []
    with open(filename, "r") as f:
        for line in f:
            # split() with no separator tolerates runs of spaces/tabs; blank
            # (e.g. trailing) lines produce no fields and are skipped.
            fields = line.split()
            if not fields:
                continue
            R, C = map(float, fields)
            if R <= R_max:
                X.append(R)
                Y.append(C)
    plt.plot(X, Y, ".")

def ln0(x):
    """Return ``math.log(x)``, or the integer 0 when ``x <= 0``.

    NaN is not ``<= 0``, so it is handed to ``math.log`` and comes back as NaN.
    """
    return 0 if x <= 0 else math.log(x)

def plot_lb(filename):
    """Plot two data columns of ``filename`` against x and against exp(x).

    Any line containing ``#`` is skipped. Every other non-blank line must hold
    exactly three whitespace-separated numbers ``x y z``. A 2x2 grid of
    subplots (figsize 12x8) is created:

    * top-left: ``y`` (label "f", solid) and ``z`` (label "g", dashed) vs ``x``;
    * top-right: the same two series vs ``exp(x)``;
    * the bottom two panels are currently left empty.

    Args:
        filename: path of the data file.
    """
    X, Y1, Y2 = [], [], []
    with open(filename, "r") as f:
        for line in f:
            if "#" in line:
                continue
            # split() with no separator tolerates runs of spaces/tabs; blank
            # (e.g. trailing) lines produce no fields and are skipped.
            fields = line.split()
            if not fields:
                continue
            x, y, z = map(float, fields)
            X.append(x)
            Y1.append(y)
            Y2.append(z)

    XX = [math.exp(x) for x in X]
    _, axes = plt.subplots(2, 2, figsize=(12, 8))
    top_left, top_right = axes[0]
    for ax, abscissa in ((top_left, X), (top_right, XX)):
        ax.plot(abscissa, Y1, "-", label="f")
        ax.plot(abscissa, Y2, "--", label="g")
        ax.legend()

    
def plot_deterministic(R_max):
    """Plot ``outputs/deterministic_pareto.txt`` (CWD-relative) through plot_file.

    Only points with R <= R_max are drawn.
    """
    plot_file(filename="outputs/deterministic_pareto.txt", R_max=R_max)

def plot_uniform_separated(R_max):
    """Plot ``outputs/uniform_separated_pareto.txt`` (CWD-relative) through plot_file.

    Only points with R <= R_max are drawn.
    """
    plot_file(filename="outputs/uniform_separated_pareto.txt", R_max=R_max)

def plot_stochastic_lower_bound(R_max):
    """Plot ``outputs/stochastic_lower_bound.txt`` (CWD-relative) through plot_file.

    Only points with R <= R_max are drawn.

    This function was previously named ``plot_lower_bound``. The later
    ``plot_lower_bound(R, b)`` definition in this module rebinds that name, so
    under its old name this function was unreachable.
    """
    plot_file("outputs/stochastic_lower_bound.txt", R_max)

def plot_contiguous_pareto(R_max):
    """Plot ``outputs/contiguous_pareto.txt`` (CWD-relative) through plot_file.

    Only points with R <= R_max are drawn.
    """
    plot_file(filename="outputs/contiguous_pareto.txt", R_max=R_max)

def plot_lower_bound(R, b):
    """Plot the lower-bound file for parameters R and b through plot_lb.

    The file read is ``outputs/lb-R{R}-b{b}.csv``, relative to the current
    working directory.
    """
    path = f"outputs/lb-R{R}-b{b}.csv"
    plot_lb(path)

def plot_vianney(R):
    """Plot the best smoothness reachable under a threshold C on column ``c``.

    Reads ``outputs/smoothness/vianney.csv`` (relative to the current working
    directory). Every non-blank line holds six whitespace-separated numbers
    ``a rho c r ls rs``. Let C range over 200 evenly spaced values between the
    smallest and largest ``c``. For each C, the plotted value is the minimum of
    ``max(ls, rs)`` over rows with ``c <= C`` and ``r <= R``. It is ``inf``
    when no row qualifies.

    Args:
        R: rows whose ``r`` exceeds this value are ignored.

    Raises:
        ValueError: if the file contains no data rows, or a line does not
            hold exactly six numbers.
    """
    c_values, r_values, smoothness = [], [], []
    with open("outputs/smoothness/vianney.csv", "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            _a, _rho, c, r, ls, rs = map(float, fields)
            c_values.append(c)
            r_values.append(r)
            smoothness.append(max(ls, rs))
    if not c_values:
        raise ValueError("outputs/smoothness/vianney.csv contains no data rows")

    c_arr = np.asarray(c_values)
    smooth_arr = np.asarray(smoothness)
    feasible = np.asarray(r_values) <= R  # row mask, independent of C: compute once
    C_abs = np.linspace(c_arr.min(), c_arr.max(), 200)
    # For each threshold: min over feasible rows with c <= C; `initial=inf`
    # makes an empty selection yield inf instead of raising.
    Y = [smooth_arr[feasible & (c_arr <= C)].min(initial=np.inf) for C in C_abs]
    plt.plot(C_abs, Y)

def extract_R(filename):
    """Return the number that immediately follows a leading ``R`` in ``filename``.

    The number is the longest prefix after ``R`` that reads as a decimal. That
    is either digits with an optional ``.`` and fractional digits (``"2"``,
    ``"2."``, ``"1.5"``), or ``.`` followed by digits (``".5"``). Anything
    after it, such as a file extension, is ignored, so ``"R1.5.txt"`` gives 1.5.

    Returns:
        The parsed float, or None if ``filename`` is empty/None, does not start
        with ``R``, or has no number right after the ``R``.
    """
    if not filename:
        return None
    # A greedy digit/dot scan would also swallow the extension's dot
    # ("R1.5.txt" -> "1.5.") and fail to parse, so match a well-formed number.
    match = re.match(r"R(\d+(?:\.\d*)?|\.\d+)", filename)
    if match is None:
        return None
    return float(match.group(1))
def closest_file_below(target_R):
    """Find the file in ``ratio_folder`` whose parsed R is the largest value <= target_R.

    Only regular files are considered, and R is parsed from each name with
    extract_R. Names that yield no R are ignored. If several files share the
    best R, the one listed first by ``os.listdir`` wins.

    Returns:
        A ``(name, R)`` tuple, or ``(None, None)`` when no file qualifies.
    """
    candidates = []
    for entry in os.listdir(ratio_folder):
        if not os.path.isfile(os.path.join(ratio_folder, entry)):
            continue
        value = extract_R(entry)
        if value is not None and value <= target_R:
            candidates.append((value, entry))

    if not candidates:
        return None, None
    # max() keeps the first maximal element, matching a strict ">" scan.
    chosen_R, chosen_name = max(candidates, key=lambda pair: pair[0])
    return chosen_name, chosen_R

def plot_smoothness(R):
    """Plot a slope-based smoothness measure against y for the nearest ratio file.

    The data file is the one returned by ``closest_file_below(R)`` inside
    ``ratio_folder``. Each non-blank line holds two numbers, x and y. The scan
    starts at the index of the minimum of y and moves towards index 0 while y
    stays below ``e - 0.01``. Each visited index i contributes the point
    ``(y[i], max(L_i, R_i))``, where:

    * ``L_i = max(0, max over j < i of |y[j] - y[i]| / (x[i] - x[j]))``
    * ``R_i = max(0, max over j > argmin(y) of (y[j] - y[i]) / (x[j] - x[i]))``

    The resulting curve is drawn with ``plt.plot``.

    Args:
        R: upper bound passed to closest_file_below.

    Raises:
        FileNotFoundError: if no file in ``ratio_folder`` has an R <= ``R``.
    """
    filename, _ = closest_file_below(R)
    if filename is None:
        raise FileNotFoundError(f"no ratio file with R <= {R} in {ratio_folder}")

    x, y = [], []
    with open(os.path.join(ratio_folder, filename), "r") as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / trailing lines
            a, b = map(float, fields)
            x.append(a)
            y.append(b)

    threshold = np.exp(1) - 0.01
    i_min = int(np.argmin(y))
    C, smoothness = [], []
    i = i_min
    # Stop at index 0: a negative i would silently wrap around to the end of
    # the lists and mix unrelated points into the slopes.
    while i >= 0 and y[i] < threshold:
        max_l_slope = 0
        for j in range(i):
            slope = abs(y[j] - y[i]) / (x[i] - x[j])
            if slope > max_l_slope:
                max_l_slope = slope
        max_r_slope = 0
        # Right-hand slopes use only points after the minimum (j > i_min).
        for j in range(i_min + 1, len(x)):
            slope = (y[j] - y[i]) / (x[j] - x[i])
            if slope > max_r_slope:
                max_r_slope = slope
        C.append(y[i])
        smoothness.append(max(max_l_slope, max_r_slope))
        i -= 1
    plt.plot(C, smoothness)

        

if __name__ == "__main__":
    # CLI: <script> LB_FILE -- plots the file with plot_lb and shows the figure.
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} LB_FILE")
    filename = argv[1]
    plot_lb(filename)
    plt.show()
