#!/usr/bin/env python3
import sys, math, pulp, scipy, os
from scipy.optimize import linprog
from scipy.sparse import lil_matrix
import matplotlib.pyplot as plt
import time
import numpy as np
import csv

# The project root is the parent of the directory holding this script; the
# file of known consistency bounds lives in <project_root>/data_helper/.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
helper_folder = os.path.join(project_root, "data_helper")
R_lb_up_filename = os.path.join(helper_folder, "R_lb_up.csv")

def get_known_lb_up(R, filename=None):
    """Return the best known (lower, upper) consistency bounds for robustness R.

    Each data row of the bounds file holds three whitespace-separated numbers
    ``R_csv lb up``. Text after ``#`` is a comment; blank and comment-only
    lines are ignored. A row's ``lb`` is used when ``R_csv >= R`` and its
    ``up`` when ``R_csv <= R``. The largest applicable lower bound and the
    smallest applicable upper bound are kept, starting from ``1`` and ``e``.

    Args:
        R: robustness ratio to look up.
        filename: bounds file to read; defaults to ``R_lb_up_filename``.

    Returns:
        Tuple ``(best_lb, best_up)``.

    Raises:
        ValueError: if a data row does not consist of exactly three numbers.
    """
    if filename is None:
        filename = R_lb_up_filename
    best_lb = 1
    best_up = math.exp(1)
    with open(filename, mode='r', encoding='utf-8') as file:
        for lineno, line in enumerate(file, start=1):
            # Drop comments (including indented / trailing ones), then split on
            # any run of whitespace so blank lines and repeated spaces are harmless.
            fields = line.split('#', 1)[0].split()
            if not fields:
                continue
            try:
                R_csv, lb, up = map(float, fields)
            except ValueError as err:
                raise ValueError(
                    f"{filename}:{lineno}: expected 'R lb up', got {line.strip()!r}"
                ) from err
            if R_csv >= R and lb > best_lb:
                best_lb = lb
            if R_csv <= R and up < best_up:
                best_up = up
    return best_lb, best_up

def wbot(R):
    """Return the smaller root w of ``R = w * exp(1 / w)``.

    Uses branch k=-1 of the Lambert W function at ``-1/R`` and keeps only the
    real part of the (complex-typed) result. Real roots exist for R >= e;
    for smaller R the discarded imaginary part is non-zero.
    """
    lower_branch = scipy.special.lambertw(-1 / R, k=-1)
    return -1 / lower_branch.real

def wtop(R):
    """Return the larger root w of ``R = w * exp(1 / w)``.

    Uses the principal branch (k=0) of the Lambert W function at ``-1/R`` and
    keeps only the real part of the (complex-typed) result. Real roots exist
    for R >= e.
    """
    principal_branch = scipy.special.lambertw(-1 / R)
    return -1 / principal_branch.real


# Wall-clock start of the run, and the accumulated time spent in linprog.
t0, solving_time = time.time(), 0

# Whether each option was left at its default value.
default_n = default_a = default_p = True
factor = False  # NOTE(review): never read in this file

# Discretisation parameters (overridable from the command line).
n = 20      # number of intervals
a = 200     # each interval is divided in a equal parts
N = n * a   # total number of discretisation points
eps = 1e-3  # dichotomy precision

p = []  # prediction values; converted to indices after option parsing
display_bool = verbose = False
w_delta = 1  # interpolation weight: 0 -> wbot, 1 -> wtop

# With no arguments at all, print the usage text and exit with status 1.
if len(sys.argv) == 1:
    prog = sys.argv[0]
    usage_lines = (
        f"Usage: {prog} [R]             -> dichotomy to find best consistency",
        f"or:    {prog} [low:high:step] -> find consistency for a range of robustness ratios",
        f"or:    {prog} [R,C]           -> solve LP for robustness R and consistency C",
        "option: -n    [n]   nb of intervals",
        "option: -a    [a]   nb of subdiv of intervals",
        "option: -d          display the ratio at the end",
        "option: -w    [delta]       use affine combination between wbot (delta=0) and wtop (delta=1)",
        "option: -p    [p1,p2,...,pk] set predictions values",
        "option: -prec [eps] set a precision for dichotomy",
        "option: -verbose    display various informations",
    )
    print(*usage_lines, sep="\n")
    sys.exit(1)

# Parse command-line options. argv[1] is the mode argument (dispatched at the
# end of the script), so options start at argv[2].
options_with_value = ("-n", "-a", "-w", "-p", "-prec")
i = 2
while i < len(sys.argv):
    opt = sys.argv[i]
    i += 1
    if opt in options_with_value:
        # Give a readable message instead of an IndexError when the option is
        # the last token on the command line.
        if i >= len(sys.argv):
            print("Missing value for option", opt)
            sys.exit(1)
        value = sys.argv[i]
        i += 1
    if opt == "-n":
        n = int(value)
        default_n = False
    elif opt == "-a":
        a = int(value)
        default_a = False
    elif opt == "-d":
        display_bool = True
    elif opt == "-w":
        w_delta = float(value)
    elif opt == "-p":
        p = list(map(float, value.split(",")))
        default_p = False
    elif opt == "-prec":
        eps = float(value)
    elif opt == "-verbose":
        verbose = True
    else:
        # Report the unknown option but do not consume the next token: it may
        # be a legitimate option that would otherwise be silently dropped.
        print("Wrong argument", opt)

# Convert prediction values into discretisation indices:
#   default     -> start of interval (n + 1) // 2
#   pred <= 0   -> counted from the end in whole intervals: (n + int(pred)) * a
#                  (int() truncates toward zero)
#   pred > 0    -> a fraction of the full range: int(pred * n * a)
if default_p:
    p = [((n + 1) // 2) * a]
else:
    p = [(n + int(pred)) * a if pred <= 0 else int(pred * n * a) for pred in p]
N = a * n

if verbose:
    # Space-separate the indices so several predictions stay readable.
    print("predictions are at indices :", " ".join(str(pred) for pred in p))



def feasibility(R, C, display_bool = False):
    """Build the discretised LP for (R, C) and report whether it is feasible.

    There is one variable X_i per discretisation point (N of them). The first
    a variables are fixed by equalities, the sequence must be convex, and
    ``w + X_(i+a) / a <= ratio * (X_i - X_(i-1))`` must hold, where ``ratio``
    is C at the prediction indices in ``p`` and R elsewhere. The objective is
    zero, so only feasibility matters.

    Args:
        R: robustness ratio.
        C: consistency ratio used at the prediction indices.
        display_bool: plot the solution with display() when one is found.

    Returns:
        ``res.success`` from linprog.

    Side effects:
        Adds the solver wall-clock time to the module-level ``solving_time``.
    """
    global solving_time

    # Interpolate w between wbot(R) and wtop(R) (w_delta = 0 -> bottom,
    # 1 -> top), after moving both endpoints 1e-2 inward.
    wb = wbot(R) + 1e-2
    wt = wtop(R) - 1e-2
    w = w_delta * wt + (1 - w_delta) * wb

    n_var = N
    def XX(i):
        # Column of variable X_i in the constraint matrices (identity mapping).
        return i
    # Zero objective: pure feasibility problem.
    obj = np.zeros(n_var)
    # (N - 2) convexity rows + (N - a - 1) ratio rows + 1 row involving X_0.
    nb_constraint_ub = 2*N - a - 2
    nb_constraint_eq = a
    A_eq = lil_matrix((nb_constraint_eq, n_var))
    b_eq = np.zeros(nb_constraint_eq)
    A_ub = lil_matrix((nb_constraint_ub, n_var))
    b_ub = np.zeros(nb_constraint_ub)
    # Membership is tested for every ratio row, so hoist it into a set.
    prediction_indices = set(p)

    row_id = 0

    # X_i = a * (w * exp((i + 1) / (a * w)) - w) for the first a points.
    for i in range(a):
        A_eq[row_id, XX(i)] = 1
        b_eq[row_id] = a * (w * math.exp((i + 1)/(a * w)) - w)
        row_id += 1

    assert row_id == nb_constraint_eq
    row_id = 0

    # X_(i+2) - X_(i+1) >= X_(i+1) - X_i, i.e. 2 X_(i+1) - X_i - X_(i+2) <= 0
    for i in range(N-2):
        A_ub[row_id, XX(i+1)] = 2
        A_ub[row_id, XX(i)] = -1
        A_ub[row_id, XX(i+2)] = -1
        row_id += 1

    # w + X_(i+a) / a <= ratio * (X_i - X_(i-1)), ratio = C at predictions, R elsewhere
    for i in range(1,N-a):
        ratio = C if i in prediction_indices else R
        A_ub[row_id, XX(i+a)] = 1/a
        A_ub[row_id, XX(i)] = - ratio
        A_ub[row_id, XX(i-1)] = + ratio
        b_ub[row_id] = - w
        row_id += 1

    # w + X_a / a <= R * X_0
    A_ub[row_id, XX(a)] = 1/a
    A_ub[row_id, XX(0)] = - R
    b_ub[row_id] = - w
    row_id += 1

    assert row_id == nb_constraint_ub

    t0_solving_time = time.time()

    # No `bounds` argument is passed, so linprog's default (0, None) applies:
    # every X_i is constrained to be non-negative.
    res = linprog(
        obj,
        A_ub=A_ub.tocsr(),
        b_ub=b_ub,
        A_eq=A_eq.tocsr(),
        b_eq = b_eq,
        method="highs-ds"
    )

    solving_time += time.time() - t0_solving_time

    # res.x may be None when the solver finds no solution; indexing it would
    # raise TypeError, so only plot a successful solve.
    if display_bool and res.success:
        X = res.x
        increments = [X[XX(0)]] + [X[XX(i)] - X[XX(i-1)] for i in range(1, N)]
        display(increments, [X[XX(i)] for i in range(N)], w)

    return res.success

def display(x, X, w):
    """Plot an LP solution and, in verbose mode, fit a line before the first prediction.

    Four panels are drawn for the first N - a points, with abscissa i / a:
    log(x), (w + X_i / a) / x_i, the ratio (w + X_(i+a) / a) / x_i (plotted
    against log(x)), and x on a linear scale.

    Args:
        x: increments of the solution (x[0] = X[0], x[i] = X[i] - X[i-1]),
            as passed by feasibility().
        X: the solution values X_i.
        w: the w used to build the LP.
    """
    if verbose:
        print(f"{w=}, {1/w=}")
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
    ABS = [i/a for i in range(N-a)]
    Y = [math.log(x[i]) for i in range(N-a)]
    Y2 = [(w + X[i] * 1./a)/x[i] for i in range(N-a)]
    Y3 = [(w + X[i+a] * 1./a) / x[i] for i in range(N-a)]
    Y4 = [x[i] for i in range(N-a)]
    ax1.plot(ABS,Y)
    ax1.set_title("log scale")
    ax2.plot(ABS,Y2)
    ax2.set_title("cost w")
    ax3.plot(Y,Y3)
    ax3.set_title("ratio")
    ax4.set_title("linear scale")
    ax4.plot(ABS,Y4)

    # The fit results are only ever printed, so skip the work when not verbose.
    if verbose:
        _print_log_fit(x, ABS, p[0])
    plt.show()


def _print_log_fit(x, ABS, pred):
    """Fit a straight line to x over the interval before index pred and print it."""
    # Use points pred - a + 5 .. pred - 5 (5 trimmed at each end); ABS[pred - a]
    # is read below as well.
    x0_ind = pred - a + 5
    x1_ind = pred - 5
    # Skip rather than raise IndexError (prediction near the end), silently
    # wrap negative indices (prediction near the start), or divide by zero
    # (a <= 10 makes the two points coincide or cross).
    if pred < a or x1_ind >= len(ABS) or x0_ind >= x1_ind:
        print(f"Skipping log analysis: prediction index {pred} is too close to the boundary")
        return
    x0 = ABS[x0_ind]
    x1 = ABS[x1_ind]
    y0 = x[x0_ind]
    y1 = x[x1_ind]
    slope = (y1 - y0)/(x1 - x0)
    constant = y1 - x1 * slope
    # math.log below requires both quantities to be positive.
    if slope <= 0 or y0 <= 0:
        print(f"Skipping log analysis: non-positive slope {slope} or value {y0}")
        return
    log_const = math.log(slope)
    log_offset = constant/slope
    print(f"slope is {slope} and constant is {constant}")
    print(f"log_offset is {log_offset} and log const is {log_const}")
    print(f"After vert/horiz shifting, log_offset is {log_offset + ABS[pred-a]} and log_const is {log_const - math.log(y0)}")

def consistency(R, eps = eps):
    """Binary-search the smallest feasible consistency for robustness R.

    The bracket starts from the bounds in the helper file. An upper bound
    below e is verified first. While it is infeasible, the bracket is shifted
    upwards: the gap is doubled, and the new lower bound lands on the old
    (infeasible) upper bound. An upper bound of e is used without checking.
    The bracket is then halved until its width is at most eps.

    Args:
        R: robustness ratio.
        eps: target width of the final bracket.

    Returns:
        The upper end of the final bracket (a consistency found feasible).
    """
    wb, wt = wbot(R), wtop(R)
    if verbose:
        print(f"wbot is {wb} | wtop is {wt}")
    lo, hi = get_known_lb_up(R)
    if verbose:
        print(f"Got lower bound {lo} and up bound {hi} from helper file")
        print(f"Check feasibility of {hi}")
    if hi < math.exp(1):
        while not feasibility(R, hi, False):
            # Order matters: lo is computed from the already-updated hi.
            hi = lo + (hi - lo) * 2
            lo = (lo + hi) / 2
            if verbose:
                print(f"Rolled back initial up bound from {lo} to {hi}")
    while hi - lo > eps:
        C = (lo + hi) / 2
        if verbose:
            print(f"checking feasibility of {R=}, {C=}...")
        if feasibility(R, C, False):
            hi = C
        else:
            lo = C
    if display_bool:
        feasibility(R, hi, True)
    return hi

# Mode selection from argv[1]:
#   "low:high:step" -> print the best consistency for each R in the range
#   "R,C"           -> report whether the LP is feasible for R and C
#   "R"             -> print the best consistency for R
arg = sys.argv[1]
if ":" in arg:
    low, high, step = map(float, arg.split(":"))
    # A non-positive step never moves R past `high`, so the loop would not end.
    if step <= 0:
        print("Step must be positive, got", step)
        sys.exit(1)
    # Derive each R from its index instead of accumulating `R += step`, so
    # floating-point error does not build up along long ranges.
    k = 0
    R = low
    while R <= high + 1e-6:
        print(R, consistency(R))
        k += 1
        R = low + k * step
elif "," in arg:
    R, C = map(float, arg.split(","))
    if feasibility(R, C, display_bool):
        print("Feasible")
    else:
        print("Unfeasible")
else:
    R = float(arg)
    print(R, consistency(R))

t1 = time.time()
if verbose:
    print(f"Time elapsed : {t1 - t0} s")
    print(f"Solving time : {solving_time} s")