#!/usr/bin/env python3
import sys, math, os
from scipy.special import lambertw
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
import time
import numpy as np
import csv
import gurobipy as gp
from gurobipy import GRB

if len(sys.argv) == 1:
    # No arguments at all: print the usage summary and exit with status 1.
    # NOTE(review): "-p" is advertised below but the option parser does not
    # handle it, and the parser's "-c" option is not listed — confirm intent.
    prog = sys.argv[0]
    usage_lines = [
        f"Usage: {prog} [R]             -> dichotomy to find best consistency",
        f"or:    {prog} [low:high:step] -> find consistency for a range of robustness ratios",
        "option: -n    [n]   nb of intervals",
        "option: -a    [a]   nb of subdiv of intervals",
        "option: -d          display the ratio at the end",
        "option: -w    [delta]       use affine combination between wbot (delta=0) and wtop (delta=1)",
        "option: -p    [p1,p2,...,pk] set predictions values",
        "option: -prec [eps] set a precision for dichotomy",
        "option: -verbose    display various informations",
    ]
    print("\n".join(usage_lines))
    sys.exit(1)

# Project root = parent of the directory containing this script; results are
# written under <project_root>/outputs/{ratio,seq}.
_this_script = os.path.abspath(__file__)
project_root = os.path.dirname(os.path.dirname(_this_script))
_outputs_dir = os.path.join(project_root, "outputs")
ratio_folder = os.path.join(_outputs_dir, "ratio")
seq_folder = os.path.join(_outputs_dir, "seq")

def wbot(R):
    """Return the smaller solution w of ``w * exp(1 / w) == R``.

    Uses the lower branch (k=-1) of the Lambert W function. For R >= e the
    branch value is real and the result lies in (0, 1]; only the real part
    of the branch value is used.
    """
    lower_branch = lambertw(-1 / R, k=-1)
    return -1 / lower_branch.real

def wtop(R):
    """Return the larger solution w of ``w * exp(1 / w) == R``.

    Uses the principal branch (k=0) of the Lambert W function. For R >= e the
    branch value is real and the result is >= 1; only the real part is used.
    """
    principal_branch = lambertw(-1 / R, k=0)
    return -1 / principal_branch.real

class Instance:
    """Gurobi feasibility LP for a robustness ratio R and a consistency C.

    The decision vector ``X`` has ``N + a`` entries with ``N = n * a``. The
    output methods work on the increments ``x[0] = X[0]``,
    ``x[i] = X[i] - X[i-1]``. The objective is identically zero, so solving
    the model is a pure feasibility check.
    """

    def __init__(self, n = 10, a = 200, p = 0, w_delta = 1):
        """Store the discretisation parameters.

        Args:
            n: number of intervals.
            a: number of subdivisions per interval (``N = n * a``).
            p: prediction index; ``0`` selects the default ``N - 8 * a``.
            w_delta: affine weight between ``wbot(R)`` (0) and ``wtop(R)`` (1)
                used to choose ``w`` in `init_model`; a negative value
                forces ``w = 1``.
        """
        self.n = n
        self.a = a
        self.N = n * a
        self.p = self.N - 8 * self.a if p == 0 else p
        self.w_delta = w_delta

    def set_R(self, R):
        """Set the robustness ratio used by `init_model`."""
        self.R = R

    def init_model(self):
        """Build the LP for robustness ``self.R`` (call `set_R` first).

        The consistency constraint at index ``p`` is created with ``C = e``
        and stored in ``self.C_constraint`` so `change_consistency` can edit it.

        Raises:
            ValueError: if ``p`` is not in ``[1, N - 1]``. The constraints index
                ``X[p - 1]`` and ``X[p + a]``, which would otherwise wrap around
                or fall outside the variable vector.
        """
        R = self.R  # this instance's ratio — never a module-level global
        N = self.N
        a = self.a
        p = self.p
        if not 1 <= p <= N - 1:
            raise ValueError(f"prediction index p={p} must lie in [1, {N - 1}] (N={N})")

        wb = wbot(R) + 1e-2
        wt = wtop(R) - 1e-2
        self.w = wb + self.w_delta * (wt - wb)
        if self.w_delta < 0:
            self.w = 1
        w = self.w

        env = gp.Env(empty=True)
        env.setParam('OutputFlag', 0)
        env.start()

        self.model = gp.Model(env = env)
        model = self.model
        model.setParam("Method", 1)
        model.setParam("NumericFocus", 3)
        model.setParam("FeasibilityTol", 1e-8)
        model.setParam("OptimalityTol", 1e-8)

        # Variables
        self.X = model.addMVar(N+a, name="X")
        X = self.X

        # Zero objective: optimize() only decides feasibility.
        obj = np.zeros(N+a)
        model.setObjective(obj @ self.X, GRB.MINIMIZE)

        # Equality constraints
        rows = []
        cols = []
        data = []
        b_eq = np.zeros(2 * a)

        # X_i = a * (w * exp((i + 1) / (a * w)) - w) for the first a points
        for i in range(a):
            rows.append(i)
            cols.append(i)
            data.append(1)
            b_eq[i] = a * (w * math.exp((i + 1) / (a * w)) - w)

        # X_i - X_(i-1) = (X_(N-1) - X_(N-2)) * exp((i-N)/a)   for N < i < N + a
        # NOTE(review): an older comment had a * wtop in the exponent; the
        # code uses a * 1 — confirm which is intended.
        row_id = a
        for i in range(N+1, N+a):
            factor = math.exp((i-N)/(a * 1))
            rows += [row_id, row_id, row_id, row_id]
            cols += [i, i-1, N-1, N-2]
            data += [1, -1, -factor, factor]
            row_id += 1

        # X_N - X_(N-1) = X_(N-1) - X_(N-2)
        rows += [row_id, row_id, row_id]
        cols += [N, N-1, N-2]
        data += [1, -2, 1]
        row_id += 1

        A_eq = csr_matrix((data, (rows, cols)), shape=(2*a, N+a))
        model.addMConstr(A_eq, self.X, "=", b_eq)

        # Inequality constraints (allocation is an upper bound; trimmed below)
        nb_constraint_ub = 2*N + a - 2
        rows = []
        cols = []
        data = []
        b_ub = np.zeros(nb_constraint_ub)
        row_id = 0

        # X_(i+2) - X_(i+1) >= X_(i+1) - X_i
        # -> -X_(i+2) + 2X_(i+1) - X_i <= 0
        for i in range(N+a-2):
            rows += [row_id, row_id, row_id]
            cols += [i, i+1, i+2]
            data += [-1, 2, -1]
            row_id += 1

        # w + X_(i+a)/a <= R * (X_i - X_(i-1)) for every i except the prediction p
        for i in range(1, N):
            if i != p:
                rows += [row_id, row_id, row_id]
                cols += [i+a, i, i-1]
                data += [1/a, -R, R]
                b_ub[row_id] = -w
                row_id += 1

        # w + X_a/a <= R X_0
        rows += [row_id, row_id]
        cols += [a, 0]
        data += [1/a, -R]
        b_ub[row_id] = -w
        row_id += 1

        # Skipping index p leaves one slot unused: size the matrix by the rows
        # actually written instead of adding an all-zero "0 <= 0" row.
        A_ub = csr_matrix((data, (rows, cols)), shape=(row_id, N+a))
        model.addMConstr(A_ub, self.X, "<=", b_ub[:row_id])

        # Consistency at p: w + X_(p+a)/a <= C * (X_p - X_(p-1)), created with C = e.
        self.C_constraint = model.addConstr(X[p+a] * 1./a - math.exp(1) * (X[p] - X[p-1]) <= -w)

    def change_consistency(self, newC):
        """Replace C in the consistency constraint by ``newC``."""
        self.model.chgCoeff(self.C_constraint, self.X[self.p], -newC)
        self.model.chgCoeff(self.C_constraint, self.X[self.p-1], newC)
        self.model.update()

    def feasible(self):
        """Solve the model; on success store the solution in ``self.X_sol``.

        Returns:
            True iff Gurobi reports an optimal (hence feasible) status.
        """
        self.model.optimize()
        success = self.model.Status == GRB.OPTIMAL
        if success:
            self.X_sol = np.array(self.X.X)
        return success

    @staticmethod
    def _increments(X, N):
        """Return ``[X[0], X[1] - X[0], ..., X[N-1] - X[N-2]]`` (length N)."""
        return [X[0]] + [X[i + 1] - X[i] for i in range(N - 1)]

    @staticmethod
    def _ratio_curve(x, X, w, a, num_points=1000):
        """Sample ``(w + X[i+a]/a) / u`` on a log-spaced grid of u values.

        The grid spans ``[min(x) * e^0.1, max(x) / e^0.1]``. For each u, i is
        the index with ``x[i] < u <= x[i+1]``; the scan assumes x is
        non-decreasing past its first entry.

        Returns:
            ``(U, ratio_U)``: the u grid (ndarray) and the list of ratios.
        """
        U = np.exp(np.linspace(math.log(min(x)) + 0.1, math.log(max(x)) - 0.1, num_points))
        ratio_U = []
        i = 0
        for u in U:
            while x[i] < u:
                i += 1
            i -= 1
            # x[i] < u <= x[i+1]
            ratio_U.append((w + X[i+a] * 1./a)/u)
        return U, ratio_U

    def display(self):
        """Plot log/linear increments, the cost ratio, and ratio_u for the last solution."""
        X = self.X_sol
        w = self.w
        a = self.a
        N = self.N
        x = self._increments(X, N)
        _fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
        ABS = [i/a for i in range(N-a)]
        Y = [math.log(x[i]) for i in range(N-a)]
        Y2 = [(w + X[i] * 1./a)/x[i] for i in range(N-a)]
        U, ratio_U = self._ratio_curve(x, X, w, a)
        ax1.plot(ABS, Y)
        ax1.set_title("log scale")
        ax2.plot(ABS, Y2)
        ax2.set_title("cost w")
        ax3.plot(ABS, x[:N-a])
        ax3.set_title("lin scale")
        ax4.plot(np.log(U), ratio_U)
        ax4.set_title("ratio_u")
        plt.show()

    def write_ratio(self):
        """Write space-separated ``log(u) ratio_u`` lines for the last solution.

        Creates ``ratio_folder`` if it does not exist yet.
        """
        os.makedirs(ratio_folder, exist_ok=True)
        filename = os.path.join(ratio_folder, f"R{self.R}-n{self.n}-a{self.a}.csv")
        x = self._increments(self.X_sol, self.N)
        U, ratio_U = self._ratio_curve(x, self.X_sol, self.w, self.a)
        with open(filename, "w") as file:
            for u, ratio in zip(U, ratio_U):
                print(f"{np.log(u)} {ratio}", file = file)

    def write_seq(self):
        """Write the increments of the last solution, one per line.

        Creates ``seq_folder`` if it does not exist yet.
        """
        os.makedirs(seq_folder, exist_ok=True)
        filename = os.path.join(seq_folder, f"R{self.R}-n{self.n}-a{self.a}.csv")
        x = self._increments(self.X_sol, self.N)
        with open(filename, "w") as file:
            for xx in x:
                print(f"{xx}", file = file)

eps = 1e-3
display_bool = False
verbose = False
n = 10
a = 200
w_delta = 1
set_C = -1  # NOTE(review): set by "-c" but not read anywhere in this file — confirm intent


def _option_value(argv, idx, opt):
    """Return ``argv[idx]``, the value of option ``opt``; exit with a message if it is missing."""
    if idx >= len(argv):
        print(f"Missing value for option {opt}")
        sys.exit(1)
    return argv[idx]


# parse command line arguments (argv[1] is the R / range argument, handled at the end)
i = 2
while i < len(sys.argv):
    opt = sys.argv[i]
    i += 1
    if opt == "-n":
        n = int(_option_value(sys.argv, i, opt))
        i += 1
    elif opt == "-a":
        a = int(_option_value(sys.argv, i, opt))
        i += 1
    elif opt == "-d":
        display_bool = True
    elif opt == "-w":
        w_delta = float(_option_value(sys.argv, i, opt))
        i += 1
    elif opt == "-prec":
        eps = float(_option_value(sys.argv, i, opt))
        i += 1
    elif opt == "-verbose":
        verbose = True
    elif opt == "-c":
        set_C = float(_option_value(sys.argv, i, opt))
        i += 1
    else:
        print("Wrong argument", opt)
        # Skip what is presumably this unknown option's value, but never
        # swallow the next option (e.g. "-x -d" must still honour "-d").
        if i < len(sys.argv) and not sys.argv[i].startswith("-"):
            i += 1

N = a * n
p = N - 8 * a
# The prediction index must be an interior grid index. With n <= 8 it is <= 0
# (and p == 0 would even be re-read by Instance as "use the default").
if a < 1 or not 1 <= p <= N - 1:
    print(f"Invalid sizes n={n}, a={a}: need a >= 1 and n > 8 so that p = N - 8a lies in [1, N-1]")
    sys.exit(1)

def consistency(R, eps = eps):
    """Bisect for the best consistency C in [1, e] at robustness ratio R.

    Feasibility is presumably monotone in C (a larger C loosens the
    consistency constraint) — the bisection relies on it. After a feasible
    solution is found, the ratio/sequence files are written and, with ``-d``,
    the solution is plotted.

    Returns:
        The upper end of the final bracket, which is always a C verified
        feasible. Returns ``math.nan`` (with a warning on stderr) if the model
        is infeasible even at C = e; no files are written in that case.
    """
    instance = Instance(n, a, p, w_delta)
    instance.set_R(R)
    instance.init_model()
    Cmin, Cmax = 1, math.exp(1)
    if verbose:
        print(f"{wbot(R)=}    {wtop(R)=}")
    found_feasible = False
    while Cmax - Cmin > eps:
        C = (Cmin + Cmax)/2
        if verbose:
            print(f"checking feasibility of {R=}, {C=}...")
        instance.change_consistency(C)
        if instance.feasible():
            Cmax = C
            found_feasible = True
        else:
            Cmin = C
    if not found_feasible:
        # Cmax was never probed: either the loop was skipped (eps >= e - 1) or every
        # probe was infeasible. Check it directly so X_sol matches the returned value.
        instance.change_consistency(Cmax)
        found_feasible = instance.feasible()
    if not found_feasible:
        print(f"warning: no feasible solution for R={R} even at C={Cmax}; nothing written",
              file=sys.stderr)
        return math.nan
    if display_bool:
        instance.display()
    instance.write_ratio()
    instance.write_seq()
    return Cmax

def handle_R(R):
    """Compute the best consistency for robustness ratio R and print it."""
    print(f"For Robustness {R} got consistency {consistency(R)}")

def _robustness_range(low, high, step):
    """Yield ``low, low + step, ...`` while the value is <= ``high`` (1e-6 slack).

    Each value is computed from its index instead of by repeated addition, so
    floating-point error does not accumulate. Values are rounded to 12 decimals
    so that noise such as 1.7000000000000002 does not end up in output file names.

    Raises:
        ValueError: if ``step`` is not positive (the range would never end).
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    k = 0
    while (value := low + k * step) <= high + 1e-6:
        yield round(value, 12)
        k += 1


# two modes, range on R or single given R
arg = sys.argv[1]
if ":" in arg:
    low, high, step = map(float, arg.split(":"))
    if step <= 0:
        print(f"step must be positive (got {step})")
        sys.exit(1)
    for R in _robustness_range(low, high, step):
        handle_R(R)
else:
    R = float(arg)
    handle_R(R)