#!/usr/bin/env python
"""
Attempt to compute a lower bound using fixed cost wbot for the bidding sequence
see Typst.app

Uses Gurobi (gurobipy) to solve the LPs and SciPy sparse matrices to assemble their constraints.
"""

import sys, math, os
import numpy as np
import scipy.sparse as sp
import time
import gurobipy as gp
from gurobipy import GRB
import random

# Accumulated wall-clock times (seconds) over all solve / solve_cost calls.
solving_time = 0
building_time = 0

# Without any argument, print the usage (every option below is documented,
# including the value each one consumes) and stop.
if len(sys.argv) == 1:
    prog = sys.argv[0]
    print(f"Usage: {prog} [R]             -> will print the consistency for R (dual variables with -s/--s)")
    print(f"or:    {prog} [low:high:step] -> will print pairs of robustness and consistency")
    print("option: -n [n] nb of bids")
    print("option: -b [b] range of bids")
    print("option: -s [f] writes the dual solution in the file f (relative to the project root)")
    print("option: --s    writes the dual in the default file")
    print("option: -w [w0] Compute the minimal value of w instead of C, from w0")
    print("option: -lw [lw] weight of the b_0 * ZB_(n-1) term in the consistency objective")
    print("option: -lin   use linearly spaced bids (consistency mode)")
    sys.exit(1)

# Defaults, overridden by the command-line options parsed below.
w0 = 0.                   # -w: value passed to solve_cost
lw0 = 0.                  # -lw: objective weight read by solve
N = 400                   # -n: number of bids
BID_RANGE = 1.            # -b: range of bids
SHOW_DUAL = False         # -s / --s: write the solution file
DEFAULT_FILENAME = False  # --s: use the generated default file name
FILENAME = ""             # -s: explicit file name
COMPUTE_COST = False      # -w: run solve_cost instead of the consistency LP
LINEAR = False            # -lin: linearly spaced bids


# Parse the options that follow the positional R / range argument (argv[1]).
i = 2
while i < len(sys.argv):
    a = sys.argv[i]
    i += 1
    # These options read sys.argv[i] as their value; report a missing value
    # clearly instead of crashing with an IndexError.
    if a in ("-n", "-b", "-s", "-w", "-lw") and i >= len(sys.argv):
        print(f"# Missing value for option {a}")
        sys.exit(1)
    if a == "-n":
        N = int(sys.argv[i])
        i += 1
    elif a == "-b":
        BID_RANGE = float(sys.argv[i])
        i += 1
    elif a == "-s":
        SHOW_DUAL = True
        FILENAME = sys.argv[i]
        i += 1
    elif a == "--s":
        SHOW_DUAL = True
        DEFAULT_FILENAME = True
    elif a == "-w":
        w0 = float(sys.argv[i])
        i += 1
        COMPUTE_COST = True
    elif a == "-lw":
        lw0 = float(sys.argv[i])
        i += 1
    elif a == "-lin":
        LINEAR = True
    else:
        # Unknown options are reported but deliberately not fatal.
        print("# Wrong argument", a)

# Everything is anchored at the project root: the parent of this script's
# directory.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Result logs; collect_file is the one matching the current mode.
collect_lb_file, collect_cost_file = (
    os.path.join(project_root, name) for name in ("collect_lb.txt", "collect_cost.txt")
)
collect_file = collect_lb_file if not COMPUTE_COST else collect_cost_file

# Default directory for the solution files (created at start-up), and helper data.
output_folder = os.path.join(project_root, "outputs", "lb_dual")
os.makedirs(output_folder, exist_ok=True)
helper_folder = os.path.join(project_root, "data_helper")
cost_file = os.path.join(helper_folder, "minimal_cost.csv")

def filename(R, b, n, filename=""):
    """Return the path of the solution file for one LP.

    Under ``--s`` (DEFAULT_FILENAME) the name is generated from R, b (bid
    range) and n inside ``output_folder``, with a ``-lin`` suffix when linear
    bids are in use. Otherwise the given ``filename`` is joined to
    ``project_root``.
    """
    if not DEFAULT_FILENAME:
        return os.path.join(project_root, filename)
    generated = f"lb-R{R}-b{b}-n{n}-lin.csv" if LINEAR else f"lb-R{R}-b{b}-n{n}.csv"
    return os.path.join(output_folder, generated)

def uniform_bids(n, bid_range):
    """Return n bids exp(bid_range * k / n) for k = 1..n (evenly spaced in log scale)."""
    exponents = (bid_range * k / n for k in range(1, n + 1))
    return np.array(list(map(math.exp, exponents)))

def linear_bids(n, bid_range):
    """Return n bids (k / n) * exp(bid_range) for k = 1..n (evenly spaced, top bid exp(bid_range))."""
    top = math.exp(bid_range)
    return np.array([(k / n) * top for k in range(1, n + 1)])

def left_concentrated_bids(n, bid_range, l = 1):
    """Return n bids: n-3 log-spaced over [0, bid_range - l], then the last 3 uniform bids.

    The head is exp of ``np.linspace(0, bid_range - l, n - 3)``; the tail is
    exp(bid_range * k / n) for k = n-3, n-2, n-1.
    """
    head = [math.exp(x) for x in np.linspace(0, bid_range - l, n - 3)]
    tail = [math.exp(bid_range * k / n) for k in range(n - 3, n)]
    return np.array(head + tail)

def concentrated_bids(n, bid_range, t, l = 0.8, width = 1):
    """Return n bids whose logs cluster around t.

    int(l * n) bids have logs in [t - width, t + width). The remaining bids
    are split between [0, t - width) (the smaller half) and
    [t + width, bid_range] (inclusive), all in log scale.
    """
    middle_count = int(l * n)
    outside = n - middle_count
    before = outside // 2
    after = outside - before

    left = [math.exp(x) for x in np.linspace(0, t - width, before, endpoint=False)]
    middle = [math.exp(t + width * x) for x in np.linspace(-1, 1, middle_count, endpoint=False)]
    right = [math.exp(x) for x in np.linspace(t + width, bid_range, after)]
    return np.array(left + middle + right)

def add_bids(bids, i1, i2):
    """Return a list copy of ``bids`` refined on [i1, i2).

    After each bids[k] with i1 <= k < i2, the geometric mean
    sqrt(bids[k] * bids[k+1]) is inserted. bids[i2] must exist when i2 > i1.
    """
    refined = [bids[k] for k in range(i1)]
    for k in range(i1, i2):
        refined.extend((bids[k], math.sqrt(bids[k] * bids[k + 1])))
    refined.extend(bids[k] for k in range(i2, len(bids)))
    return refined

def solve(R, n, bids):
    """Build and solve the consistency LP for robustness ``R`` with Gurobi.

    Variables (all >= 0), in the column order of ``model.getVars()``:
        Y_0 .. Y_(n-1)    -> columns 0 .. n-1
        Z_0 .. Z_(n-1)    -> columns n .. 2n-1
        ZB_0 .. ZB_(n-1)  -> columns 2n .. 3n-1

    With p = int(0.7 * n), the LP maximizes
        Y_(n-1) - R Z_(n-1) + R Z_p - R Z_(p-1) + lw0 * b_0 * ZB_(n-1)
    by minimizing its negation. It is subject to the equality, triangular,
    monotonicity and ``Z_p - Z_(p-1) <= 1`` constraints built below.

    Args:
        R: robustness value (objective coefficient).
        n: number of bids; ``bids`` must have at least n entries.
        bids: bid values b_0 .. b_(n-1).

    Returns:
        The optimized ``gurobipy.Model`` (callers read ``ObjVal`` and ``x``).

    Side effects:
        Reads the module-level ``lw0``. Adds the elapsed model-building and
        solving wall-clock times to the globals ``building_time`` and
        ``solving_time``.

    NOTE(review): for n < 2, p - 1 < 0 and the column indices below wrap
    around (e.g. n = 1 overwrites the Y_(n-1) objective coefficient).
    Presumably n is always large — confirm.
    """
    global solving_time
    global building_time

    t0 = time.time()
    # Z_p / Z_(p-1) appear in the objective and in the special constraint.
    p = int(0.7*n)

    # Create model
    model = gp.Model()
    model.setParam("OutputFlag", 0)  # silent mode (optional)

    # ----------------------------
    # Variables: n each of Y, Z, ZB, all with lower bound 0.
    # The returned handles are not used below; the objective and the
    # constraint matrices address variables by column index in vars_flat.
    # ----------------------------
    nvar = 3 * n
    Y = model.addVars(n, lb=0.0, name="Y")
    Z = model.addVars(n, lb=0.0, name="Z")
    ZB = model.addVars(n, lb=0.0, name="ZB")

    # Apply the pending additions so getVars() returns all 3n variables.
    model.update()

    vars_flat = model.getVars()

    # ----------------------------
    # Objective (minimize). The maximized quantity is
    #   Y_(n - 1) - R Z_(n - 1) + R Z_p - R Z_(p - 1) + lw0 b_0 ZB_(n - 1)
    # so every coefficient below carries the opposite sign.
    # ----------------------------
    obj = np.zeros(nvar)
    obj[n-1] = -1
    obj[n + (n-1)] = R
    # Unlike the other assignments this one accumulates, which matters only
    # when p == n - 1 (Z_p and Z_(n-1) then share a column).
    obj[n + p] += -R
    obj[n + p - 1] = R
    obj[2*n + n-1] = -lw0 * bids[0]

    # Coefficient vector dotted with the Var list gives the linear objective.
    model.setObjective(obj @ vars_flat, GRB.MINIMIZE)

    # ----------------------------
    # Equality constraints, one row per j = 0 .. n-1:
    #   b_j ZB_(j - 1) - b_j ZB_j - Z_(j - 1) + Z_j = 0
    # (the ZB_(j-1) and Z_(j-1) terms are absent for j = 0).
    # Entries are collected as COO triplets and converted to CSR.
    # ----------------------------
    rows = []
    cols = []
    data = []
    for j in range(n):
        row = j

        rows.append(row)
        cols.append(2*n + j)
        data.append(-bids[j])

        rows.append(row)
        cols.append(n+j)
        data.append(1)

        if j > 0:
            rows.append(row)
            cols.append(2 * n + j-1)
            data.append(bids[j])

            rows.append(row)
            cols.append(n + j - 1)
            data.append(-1)

    Aeq = sp.csr_matrix((data, (rows, cols)), shape = (n, nvar))
    beq = np.zeros(n)

    model.addMConstr(Aeq, vars_flat, "=", beq)

    # ----------------------------
    # Inequality constraints (all "<="), stacked in this row order:
    #   1. triangular, one row per 0 <= i <= k < n (n(n+1)/2 rows):
    #        Y_k - Y_(i - 1) - b_k ZB_(n - 1) + b_k ZB_(i - 1) <= 0
    #      (the Y_(i-1) and ZB_(i-1) terms are absent for i = 0)
    #   2. monotonicity, 2(n-1) rows
    #   3. one special row  Z_p - Z_(p - 1) <= 1
    # Right-hand sides are 0 except for the special row.
    # ----------------------------
    nb_tri = n * (n+1) // 2
    nb_mono = 2*(n-1)
    nb_special = 1

    total_ineq = nb_tri + nb_mono + nb_special

    rows = []
    cols = []
    data = []
    b = np.zeros(total_ineq)

    row_id = 0
    for k in range(n):
        for i in range(k + 1):
            rows.append(row_id)
            cols.append(k)
            data.append(1)

            rows.append(row_id)
            cols.append(2*n + n - 1)
            data.append(-bids[k])

            if i > 0:
                rows.append(row_id)
                cols.append(i-1)
                data.append(-1)

                rows.append(row_id)
                cols.append(2*n + i - 1)
                data.append(bids[k])

            row_id += 1

    # ----------------------------
    # Monotonicity constraints, for j = 1 .. n-1 (two rows each):
    #   Y_(j - 1) - Y_j <= 0   and   Z_(j - 1) - Z_j <= 0
    # ----------------------------
    for j in range(1, n):
        rows.append(row_id)
        cols.append(j-1)
        data.append(1)

        rows.append(row_id)
        cols.append(j)
        data.append(-1)
        row_id += 1

        rows.append(row_id)
        cols.append(n+j-1)
        data.append(1)

        rows.append(row_id)
        cols.append(n+j)
        data.append(-1)
        row_id += 1
    # ---------------------------------------------
    # Z_p - Z_(p - 1) <= 1  (last row; the only nonzero right-hand side)
    # ---------------------------------------------
    rows.append(row_id)
    cols.append(n+p)
    data.append(1)

    rows.append(row_id)
    cols.append(n+p-1)
    data.append(-1)

    b[row_id] = 1

    Aub = sp.csr_matrix((data, (rows, cols)), shape = (total_ineq, nvar))

    model.addMConstr(Aub, vars_flat, "<", b)

    # Model construction ends here; the time spent so far is building time.
    building_t = time.time()

    # Method=2 selects the barrier algorithm; crossover and numeric focus are
    # set explicitly as well.
    model.setParam("Method", 2)
    model.setParam("Crossover", 1)
    model.setParam("NumericFocus", 1)
    model.optimize()

    solving_t = time.time()
    solving_time += solving_t - building_t
    building_time += building_t - t0

    return model

def solve_cost(R, n, bids, w0):
    """Build and solve the cost LP for robustness ``R`` and value ``w0``.

    Variables (all >= 0), in the column order of ``model.getVars()``:
        Y_0 .. Y_(n-1)    -> columns 0 .. n-1      (Y(i))
        Z_0 .. Z_(n-1)    -> columns n .. 2n-1     (Z(i))
        ZB_0 .. ZB_(n-1)  -> columns 2n .. 3n-1    (ZB(i))
        Gamma             -> column 3n             (gamma)

    With p = n // 2, the LP maximizes
        Y_(n-1) - R Z_(n-1) + w0 ZB_(n-1) + (w0 / b_p) Gamma
    by minimizing its negation. It is subject to the equality, triangular
    (with an extra Gamma term for k < p), monotonicity and ``Gamma <= 1``
    constraints built below.

    Args:
        R: robustness value (objective coefficient).
        n: number of bids; ``bids`` must have at least n entries.
        bids: bid values b_0 .. b_(n-1).
        w0: cost value used in the objective (shadows the module-level w0).

    Returns:
        The optimized ``gurobipy.Model`` (the caller reads ``ObjVal``).

    Side effects:
        Adds the elapsed model-building and solving wall-clock times to the
        globals ``building_time`` and ``solving_time``.
    """
    global solving_time
    global building_time

    t0 = time.time()
    # b_p normalizes the Gamma terms; Gamma enters the rows with k < p.
    p = n//2

    # Create model
    model = gp.Model()
    model.setParam("OutputFlag", 0)  # silent mode (optional)

    # ----------------------------
    # Variables: n each of Y, Z, ZB plus a single Gamma, all >= 0.
    # ----------------------------
    nvar = 3 * n + 1
    model.addVars(n, lb=0.0, name="Y")
    model.addVars(n, lb=0.0, name="Z")
    model.addVars(n, lb=0.0, name="ZB")
    model.addVars(1, lb=0.0, name="Gamma")

    # Column-index helpers (used for the objective; the constraint loops
    # below use the equivalent raw offsets).
    def Y(i): return i
    def Z(i): return i + n
    def ZB(i) : return i + 2*n
    gamma = 3*n

    # Apply the pending additions so getVars() returns all 3n + 1 variables.
    model.update()

    vars_flat = model.getVars()

    # ----------------------------
    # Objective (minimize). The maximized quantity is
    #   Y_(n - 1) - R Z_(n - 1) + w0 ZB_(n - 1) + (w0 / b_p) Gamma
    # so every coefficient below carries the opposite sign.
    # ----------------------------
    obj = np.zeros(nvar)
    obj[Y(n-1)] = -1
    obj[Z(n-1)] = R
    obj[ZB(n-1)] = -w0
    obj[gamma] = - w0 / bids[p]

    # Coefficient vector dotted with the Var list gives the linear objective.
    model.setObjective(obj @ vars_flat, GRB.MINIMIZE)

    # ----------------------------
    # Equality constraints, one row per j = 0 .. n-1:
    #   b_j ZB_(j - 1) - b_j ZB_j - Z_(j - 1) + Z_j = 0
    # (the ZB_(j-1) and Z_(j-1) terms are absent for j = 0).
    # ----------------------------
    rows = []
    cols = []
    data = []
    for j in range(n):
        row = j

        rows.append(row)
        cols.append(2*n + j)
        data.append(-bids[j])

        rows.append(row)
        cols.append(n+j)
        data.append(1)

        if j > 0:
            rows.append(row)
            cols.append(2 * n + j-1)
            data.append(bids[j])

            rows.append(row)
            cols.append(n + j - 1)
            data.append(-1)

    Aeq = sp.csr_matrix((data, (rows, cols)), shape = (n, nvar))
    beq = np.zeros(n)

    model.addMConstr(Aeq, vars_flat, "=", beq)

    # ----------------------------
    # Inequality constraints (all "<="), stacked in this row order:
    #   1. triangular, one row per 0 <= i <= k < n (n(n+1)/2 rows):
    #        Y_k - Y_(i - 1) - b_k ZB_(n - 1) + b_k ZB_(i - 1)
    #            - [k < p] (b_k / b_p) Gamma <= 0
    #      (the Y_(i-1) and ZB_(i-1) terms are absent for i = 0)
    #   2. monotonicity, 2(n-1) rows
    #   3. one special row  Gamma <= 1
    # Right-hand sides are 0 except for the special row.
    # ----------------------------
    nb_tri = n * (n+1) // 2
    nb_mono = 2*(n-1)
    nb_special = 1

    total_ineq = nb_tri + nb_mono + nb_special

    rows = []
    cols = []
    data = []
    b = np.zeros(total_ineq)

    row_id = 0
    for k in range(n):
        for i in range(k + 1):
            rows.append(row_id)
            cols.append(k)
            data.append(1)

            rows.append(row_id)
            cols.append(2*n + n - 1)
            data.append(-bids[k])

            if i > 0:
                rows.append(row_id)
                cols.append(i-1)
                data.append(-1)

                rows.append(row_id)
                cols.append(2*n + i - 1)
                data.append(bids[k])

            # Gamma term, only for the first p bids.
            if k < p:
                rows.append(row_id)
                cols.append(gamma)
                data.append(-bids[k]/bids[p])

            row_id += 1

    # ----------------------------
    # Monotonicity constraints, for j = 1 .. n-1 (two rows each):
    #   Y_(j - 1) - Y_j <= 0   and   Z_(j - 1) - Z_j <= 0
    # ----------------------------
    for j in range(1, n):
        rows.append(row_id)
        cols.append(j-1)
        data.append(1)

        rows.append(row_id)
        cols.append(j)
        data.append(-1)
        row_id += 1

        rows.append(row_id)
        cols.append(n+j-1)
        data.append(1)

        rows.append(row_id)
        cols.append(n+j)
        data.append(-1)
        row_id += 1

    # Gamma <= 1  (last row; the only nonzero right-hand side)
    rows.append(row_id)
    cols.append(gamma)
    data.append(1)
    b[row_id] = 1

    Aub = sp.csr_matrix((data, (rows, cols)), shape = (total_ineq, nvar))

    model.addMConstr(Aub, vars_flat, "<", b)

    # Model construction ends here; the time spent so far is building time.
    building_t = time.time()

    # Method=2 selects the barrier algorithm; crossover and numeric focus are
    # set explicitly as well.
    model.setParam("Method", 2)
    model.setParam("Crossover", 1)
    model.setParam("NumericFocus", 1)
    model.optimize()

    solving_t = time.time()
    solving_time += solving_t - building_t
    building_time += building_t - t0

    return model

def ln0(x):
    """Natural log of x, or 0 when x is not strictly positive (including NaN)."""
    if not x > 0:
        return 0
    return math.log(x)

def consistency(R, n = N, bid_range = BID_RANGE):
    """Solve the consistency LP for robustness ``R`` and return its optimum.

    Bids are linearly spaced when ``-lin`` was given (LINEAR) and generated
    by ``uniform_bids`` otherwise. When ``-s``/``--s`` was given (SHOW_DUAL),
    the solution is written out through ``write_dual``.

    Args:
        R: robustness value.
        n: number of bids (defaults to N as parsed from the command line).
        bid_range: bid range parameter (defaults to BID_RANGE).

    Returns:
        The maximized objective, i.e. ``-model.ObjVal`` of the minimization
        solved by ``solve``.
    """
    bids = linear_bids(n, bid_range) if LINEAR else uniform_bids(n, bid_range)

    model = solve(R, n, bids)

    # write_dual opens and closes its own output file, so no handle is
    # opened here (which would otherwise leak on every call).
    if SHOW_DUAL:
        write_dual(model, R, n, bid_range, bids)

    return -model.ObjVal

def write_dual(model, R, n, bid_range, bids):
    """Write the solved Y and Z values of ``model`` to the solution file.

    The file starts with a ``# bid y z`` header line, then one line per bid:
    ``log(b_i)`` followed by ``Y_i - Y_(i-1)`` and ``Z_i - Z_(i-1)``. The
    first line (i = 0) holds the absolute values Y_0 and Z_0. The path
    comes from ``filename``: the generated default under ``--s``, or the
    file given with ``-s`` (relative to the project root).

    Args:
        model: optimized model from ``solve`` (Y in columns 0..n-1, Z in
            columns n..2n-1).
        R, n, bid_range: used to build the default file name.
        bids: bid values; ``math.log`` of each is written (must be > 0).
    """
    # Read the solution vector once instead of re-querying it from the model
    # on every index access inside the loop.
    x = model.x
    # Pass FILENAME so that "-s f" writes to f rather than trying to open the
    # project root directory itself.
    with open(filename(R, bid_range, n, FILENAME), "w") as file:
        print("# bid y z", file = file)
        print(math.log(bids[0]), x[0], x[n], file = file)
        for i in range(1, n):
            print(math.log(bids[i]), x[i] - x[i-1], x[n + i] - x[n + i - 1], file = file)


def handle(R):
    """Process one robustness value R and print the result.

    In consistency mode (default), compute ``consistency(R, N, BID_RANGE)``,
    print it and append the line "R value" to ``collect_lb_file``. With
    ``-w``, solve the cost LP on ``uniform_bids`` and print w0 next to the
    optimum.

    NOTE(review): the cost branch always uses uniform_bids (ignoring -lin)
    and records nothing in collect_cost_file / collect_file — confirm this
    is intended.
    """
    if not COMPUTE_COST:
        res = consistency(R, N, BID_RANGE)
        print(f"For Robustness {R}, the consistency is {res}")
        with open(collect_lb_file, "a") as log:
            print(R, res, file=log)
        return

    cost_model = solve_cost(R, N, uniform_bids(N, BID_RANGE), w0)
    print(f"For Robustness {R}, from {w0} to {- cost_model.ObjVal}")


# argv[1] is either a single robustness "R" or an exclusive range "low:high:step".
arg = sys.argv[1]
if ":" in arg:
    low, high, step = map(float, arg.split(":"))
    for R in np.arange(low, high, step):
        # With a non-integer step, np.arange can emit one extra value at (or
        # just past) `high` because of rounding; e.g. "1:1.3:0.1" would also
        # yield 1.3000000000000003. Skip it so `high` stays exclusive.
        # Dividing by `step` makes the test work for negative steps too.
        if (high - R) / step <= 1e-9:
            continue
        handle(R)
else:
    R = float(arg)
    handle(R)

print(f"# Building time : {building_time}s")
print(f"# Solving time : {solving_time}s")