#!/usr/bin/env python3
"""
Attempt to compute a lower bound using a fixed cost wbot for the bidding sequence
(see the accompanying write-up on Typst.app).

Uses SciPy (scipy.optimize.linprog) to solve the linear program.
"""

import sys, math, os
import numpy as np
import scipy.special
import time
from scipy.optimize import linprog
from scipy.sparse import lil_matrix
import random

# Project root: two directory levels above this file's absolute path.
_script_path = os.path.abspath(__file__)
project_root = os.path.dirname(os.path.dirname(_script_path))
# Folder used for the automatically named dual-solution files.
output_folder = os.path.join(project_root, "outputs", "lb_dual")

# Cumulative wall-clock seconds, updated by every call to solve().
building_time = 0
solving_time = 0

# Called without any argument: print the usage text and exit with status 1.
if len(sys.argv) == 1:
    prog = sys.argv[0]
    usage_lines = (
        f"Usage: {prog} [R]             -> will print dual variables",
        f"or:    {prog} [low:high:step] -> will print pairs of robustness and consistency",
        "option: -n [n] nb of bids",
        "option: -b [b] range of bids",
        "option: -g     improved precision around the final gap",
        "option: -s [f] writes the dual solution in the file f",
        "option: --s    writes the dual in the default file",
    )
    for usage_line in usage_lines:
        print(usage_line)
    sys.exit(1)

# Default settings; each one can be overridden by the options parsed below.
N = 400                    # number of bids (-n)
BID_RANGE = 1.             # range of the bids, on the log scale (-b)
SHOW_DUAL = False          # write the dual solution (-s / --s)
GAPS = False               # refine the grid around the final gap (-g)
DEFAULT_FILENAME = False   # use the automatically generated file name (--s)
FILENAME = ""              # explicit output file name (-s)


def _option_value(option, position):
    """Return sys.argv[position], the value expected after `option`.

    Exits with status 1 and a message on stderr when the command line ends
    before the value, instead of failing with a bare IndexError.
    """
    if position >= len(sys.argv):
        print(f"# Missing value for option {option}", file=sys.stderr)
        sys.exit(1)
    return sys.argv[position]


# sys.argv[1] holds the robustness value (or range); options start at index 2.
i = 2
while i < len(sys.argv):
    a = sys.argv[i]
    i += 1
    if a == "-n":
        N = int(_option_value(a, i))
        i += 1
    elif a == "-b":
        BID_RANGE = float(_option_value(a, i))
        i += 1
    elif a == "-s":
        SHOW_DUAL = True
        FILENAME = _option_value(a, i)
        i += 1
    elif a == "--s":
        SHOW_DUAL = True
        DEFAULT_FILENAME = True
    elif a == "-g":
        GAPS = True
    else:
        # Unknown arguments are reported but do not stop the run.
        print("# Wrong argument", a)

def filename(R,b,n, filename = ""):
    """Return the path of the file that receives the dual solution.

    When DEFAULT_FILENAME is set, the name is derived from the robustness R,
    the bid range b and the number of bids n, and placed in output_folder.
    Otherwise the given `filename` is joined to project_root.
    """
    if not DEFAULT_FILENAME:
        return os.path.join(project_root, filename)
    generated_name = f"lb-R{R}-b{b}-n{n}.csv"
    return os.path.join(output_folder, generated_name)

def uniform_bids(n, bid_range):
    """Return the n bids exp(bid_range * i / n) for i = 0 .. n - 1.

    The bids are evenly spaced on the log scale, starting at 1; the result is
    a NumPy array.
    """
    exponents = (bid_range * step / n for step in range(n))
    return np.array(list(map(math.exp, exponents)))

def left_concentrated_bids(n, bid_range, l = 1):
    """Return n bids whose first n - 3 entries are packed into [0, bid_range - l].

    The first n - 3 bids have logarithms evenly spaced from 0 to bid_range - l
    (both ends included); the last three bids are the last three points
    exp(bid_range * i / n) of the uniform grid.
    """
    packed_logs = np.linspace(0, bid_range - l, n - 3)
    head = list(map(math.exp, packed_logs))
    tail = [math.exp(bid_range * i / n) for i in range(n - 3, n)]
    return np.array(head + tail)

def concentrated_bids(n, bid_range, t, l = 0.8, width = 1):
    """Return n bids with a fraction l of them packed around exp(t).

    int(l * n) bids have logarithms evenly spaced in [t - width, t + width).
    The remaining bids are split between the left part [0, t - width) and the
    right part [t + width, bid_range]; the right part gets the extra bid when
    the remainder is odd.
    """
    middle_count = int(l * n)
    outer_count = n - middle_count
    left_count = outer_count // 2
    right_count = outer_count - left_count

    left_logs = np.linspace(0, t - width, left_count, endpoint = False)
    middle_offsets = np.linspace(-1, 1, middle_count, endpoint = False)
    right_logs = np.linspace(t + width, bid_range, right_count)

    left = [math.exp(x) for x in left_logs]
    middle = [math.exp(t + width * x) for x in middle_offsets]
    right = [math.exp(x) for x in right_logs]
    return np.array(left + middle + right)

def add_bids(bids, i1, i2):
    """Return a refined copy of `bids` with midpoints inserted in [i1, i2).

    For every index i with i1 <= i < i2, the geometric mean of bids[i] and
    bids[i + 1] is inserted right after bids[i]; all other bids are copied
    unchanged.

    The window is clamped to the valid index range, so a negative i1 or an
    i2 beyond the last bid refines up to the corresponding end of the
    sequence instead of wrapping around or raising IndexError.

    Args:
        bids: sequence (list or NumPy array) of bid values.
        i1: first index of the window to refine.
        i2: index one past the last refined position.

    Returns:
        A new list containing the original bids and the inserted midpoints.
    """
    last = len(bids) - 1
    if last < 0:
        return []
    # bids[i + 1] must exist for every refined i, hence the upper limit `last`.
    lo = min(max(i1, 0), last)
    hi = min(max(i2, lo), last)

    new_bids = list(bids[:lo])
    for i in range(lo, hi):
        new_bids.append(bids[i])
        new_bids.append(math.sqrt(bids[i] * bids[i + 1]))
    new_bids.extend(bids[hi:])
    return new_bids

def solve(R, n, bids):
    """Build and solve the linear program for robustness R on the given bids.

    The LP has 3 * n non-negative variables laid out in three blocks of n:
    Y_i (index i), Z_i (index n + i) and Z^b_i (index 2 * n + i).  Other
    functions of this file read the differences Y_i - Y_(i - 1) and
    Z_i - Z_(i - 1) as the per-bid values y_i and z_i, i.e. Y and Z act as
    running sums.  The equality constraints tie Z^b to Z through
    Z_j - Z_(j - 1) = b_j * (Z^b_j - Z^b_(j - 1)).

    Args:
        R: robustness parameter; it weights the Z terms of the objective.
        n: number of bids, i.e. number of variables per block.
        bids: indexable sequence holding the n bid values b_0 .. b_(n - 1).

    Returns:
        The result object returned by scipy.optimize.linprog (the objective
        passed to linprog is the negated maximisation objective).

    Side effects:
        Adds the elapsed wall-clock time to the module-level counters
        building_time (matrix construction) and solving_time (linprog call).

    NOTE(review): the objective indexes Z_(p - 1) with p = n - 2, so for
    n < 3 the index becomes negative and lands in another block; the code
    presumably assumes n >= 3 — confirm before calling with smaller n.
    """
    global solving_time
    global building_time

    t0 = time.time()
    p = n-2
    # Variable layout: three consecutive blocks of n variables each.
    N = 3 * n
    def YY(i): return 0*n + i
    def ZZ(i): return 1*n + i
    def ZB(i): return 2*n + i

    # Objective (maximize → minimize negative)
    # --------------------------------------------- Y_(n - 1) - R Z_(n - 1) + R Z_p - R Z_(p - 1)
    # i.e. maximize (sum of y_j) - R * (sum of z_j for j != p); the vector
    # below stores the negated coefficients because linprog minimizes.
    obj = np.zeros(N)
    obj[YY(n - 1)] = -1     # sum of y_j
    obj[ZZ(n - 1)] = R      # -R times sum of z_j
    obj[ZZ(p)] += -R         # except for j == p: coefficient of Z_p in z_p = Z_p - Z_(p - 1)
    obj[ZZ(p - 1)] = R      # except for j == p: coefficient of Z_(p - 1) in z_p

    # Equality constraints: one row per bid.
    nb_eq = n
    A_eq = lil_matrix((nb_eq, N))
    b_eq = np.zeros(nb_eq)

    # Relate ZB to ZZ
    for j in range(n):
        # ----------------------------------------- b_j Z^b_(j - 1) - b_j Z^b_j - Z_(j - 1) + Z_j = 0
        # For j == 0 the (j - 1) terms are absent: Z_0 = b_0 Z^b_0.
        A_eq[j, ZB(j)] = -bids[j]
        A_eq[j, ZZ(j)] = 1
        if j > 0:
            A_eq[j, ZB(j - 1)] = bids[j]            # not bids[j - 1]
            A_eq[j, ZZ(j - 1)] = -1
        
    # Inequality constraints.  Row count, term by term:
    #   n (n + 1) / 2  rows for all pairs i <= k,
    #   2 (n - 1)      monotonicity rows for j in range(1, n),
    #   1              row bounding z_p.
    nb_ub = n * (n + 1) // 2 + 2 * (n - 1) + 1
    A_ub = lil_matrix((nb_ub, N))
    b_ub = np.zeros(nb_ub)

    row_id = 0
    for k in range(n):
        for i in range(k + 1):
            # ------------------------------------- Y_k - Y_(i - 1) - b_k Z^b_(n - 1) + b_k Z^b_(i - 1) <= 0
            # For i == 0 the (i - 1) terms are absent.
            A_ub[row_id, YY(k)] = 1
            A_ub[row_id, ZB(n - 1)] = -bids[k]
            if i > 0:
                A_ub[row_id, YY(i - 1)] = -1
                A_ub[row_id, ZB(i - 1)] += bids[k]
            row_id += 1

    # Y and Z are non-decreasing in their index.
    for j in range(1, n):
        # ----------------------------------------- Y_(j - 1) - Y_j <= 0
        A_ub[row_id, YY(j - 1)] = 1
        A_ub[row_id, YY(j)] = -1
        row_id += 1
        # ----------------------------------------- Z_(j - 1) - Z_j <= 0
        A_ub[row_id, ZZ(j - 1)] = 1
        A_ub[row_id, ZZ(j)] = -1
        row_id += 1
    # --------------------------------------------- Z_(p) - Z_(p - 1) <= 1
    A_ub[row_id, ZZ(p)] = 1
    A_ub[row_id, ZZ(p - 1)] = -1
    b_ub[row_id] = 1
    row_id += 1 

    # Every pre-allocated inequality row must have been filled exactly once.
    assert row_id == nb_ub

    # variables are non-negative
    bounds = [(0, None)] * N
    # The matrices were filled in LIL format; convert to CSR before solving.
    A_ub=A_ub.tocsr()
    A_eq=A_eq.tocsr()

    building_t = time.time()

    res = linprog(
        obj,
        A_ub=A_ub,
        b_ub=b_ub,
        A_eq=A_eq,
        b_eq=b_eq,
        bounds=bounds,
        # method="highs-ds",
        method="highs-ipm"  ,
    )
    solving_t = time.time()
    # Accumulate the timings reported at the end of the script.
    solving_time += solving_t - building_t
    building_time += building_t - t0
    return res

def compute_gaps(R, n, bid_range, eps = 1e-2):
    """Locate the final gap by repeatedly refining the bid grid around it.

    Starting from a uniform grid, the LP is solved and the first index j >= 1
    whose increment Y_j - Y_(j - 1) is (numerically) zero is looked up.  The
    interval [log b_(j - 1), log b_j] brackets the gap; midpoints are then
    inserted in a window of 20 bids on each side of j and the LP is solved
    again, until the bracket is narrower than `eps`.  The current bracket is
    printed before every solve and once more at the end.

    Args:
        R: robustness parameter.
        n: initial number of bids.
        bid_range: range of the bids on the log scale.
        eps: target width of the bracket, on the log scale.

    Returns:
        (bid_range - end) / log(R), where `end` is the upper end of the
        final bracket.

    Raises:
        RuntimeError: if the solver does not succeed, or if no bid with a
            vanishing y increment exists in the solution.
    """
    def YY(i): return 0*n + i

    margin = 20  # number of bids refined on each side of the gap
    bids = uniform_bids(n, bid_range)
    start = 0
    end = 1
    while end - start > eps:
        print(start, end)
        res = solve(R, n, bids)
        if not res.success:
            # res.x is not usable when the solver failed.
            raise RuntimeError(res.message)

        j = next(
            (i for i in range(1, n) if res.x[YY(i)] - res.x[YY(i - 1)] < 1e-10),
            None,
        )
        if j is None:
            # Previously this fell back to j = 0, which silently used
            # bids[-1] and produced a meaningless result.
            raise RuntimeError("no bid with a vanishing y increment was found")

        start = math.log(bids[j-1])
        end = math.log(bids[j])
        # Clamp the window: add_bids reads bids[i + 1] for every refined i,
        # and negative indices would wrap around to the end of the grid.
        window_start = max(j - margin, 0)
        window_end = min(j + margin, len(bids) - 1)
        bids = add_bids(bids, window_start, window_end)
        n = len(bids)
    print(start, end)
    return (bid_range - end) / math.log(R)

def ln0(x):
    """Return the natural logarithm of x, or 0 when x is not positive."""
    if x > 0:
        return math.log(x)
    return 0

def consistency(R, n = N, bid_range = BID_RANGE):
    """Solve the LP on a uniform bid grid and report the resulting lower bound.

    When SHOW_DUAL is set, the per-bid values y_i and z_i (increments of the
    Y and Z variables) are written together with log(b_i), followed by summary
    comment lines.  The destination is the file given by filename() when a
    file name is configured, and standard output otherwise.  Without
    SHOW_DUAL the summary lines are discarded.

    Args:
        R: robustness parameter.
        n: number of bids.
        bid_range: range of the bids on the log scale.

    Returns:
        A pair (lower_bound, gap_offset):
        - lower_bound is the optimal value of the maximisation LP;
        - gap_offset is bid_range - log(T), where T is the first bid whose
          y increment is numerically zero, or NaN when there is no such bid.
        (inf, inf) is returned when the LP is reported unbounded.

    Raises:
        RuntimeError: if the solver fails for a reason other than
            unboundedness.
    """
    # Alternative grid: left_concentrated_bids(n, bid_range, 1.15)
    bids = uniform_bids(n, bid_range)

    # Choose the output stream.  FILENAME (the configured name) is compared
    # here; comparing the function `filename` to "" would always be true.
    owns_output = True
    if SHOW_DUAL and (FILENAME != "" or DEFAULT_FILENAME):
        path = filename(R, bid_range, n, FILENAME)
        directory = os.path.dirname(path)
        if directory:
            # The default output folder may not exist yet.
            os.makedirs(directory, exist_ok=True)
        output = open(path, "w")
    elif SHOW_DUAL:
        output = sys.stdout
        owns_output = False
    else:
        output = open(os.devnull, 'w')

    try:
        res = solve(R, n, bids)

        if res.status == 3:
            print("⚠️ LP is unbounded (as expected).", file=sys.stderr)
            return float("inf"), float('inf')

        if not res.success:
            raise RuntimeError(res.message)

        # Per-bid increments; index 0 is the variable itself (no predecessor).
        x = np.asarray(res.x)
        y = np.diff(x[0:n], prepend=0.0)
        z = np.diff(x[n:2 * n], prepend=0.0)

        if SHOW_DUAL:
            print("# bid","y","z", file=output)
            for i in range(n):
                print(math.log(bids[i]), y[i], z[i], file=output)

        # First bid (past index 0) whose y increment vanishes.
        gap_index = next((i for i in range(1, n) if y[i] < 1e-10), None)
        if gap_index is None:
            # No gap in the solution: log(T) is undefined, report NaN rather
            # than failing with a math domain error.
            gap_offset = float("nan")
        else:
            gap_offset = bid_range - math.log(bids[gap_index])

        # Strict local maxima of z, expressed as (bid_range - log b_i) / log R.
        TZ = [
            (bid_range - math.log(bids[i])) / math.log(R)
            for i in range(1, n - 1)
            if z[i] > z[i - 1] and z[i] > z[i + 1]
        ]
        print(f"# local maxima for z at bids {R=} {TZ[-4:]}", file=output)
        print("# ratio between T and prediction :", gap_offset, file=output)
        print("# lower bound : ", -res.fun, file=output)
        return -res.fun, gap_offset
    finally:
        # Close what this call opened (never sys.stdout).
        if owns_output:
            output.close()

def handle(R):
    """Run the computation selected on the command line for robustness R.

    With the -g option (GAPS) the gap is computed and printed; otherwise the
    consistency is computed and printed.
    """
    if not GAPS:
        res, _ = consistency(R, N, BID_RANGE)
        print(f"For Robustness {R}, the consistency is {res}")
        return
    gap = compute_gaps(R, N, BID_RANGE)
    print(f"For Robustness {R}, the gap T is {gap}.")


# The first argument is either a single robustness value R or a
# "low:high:step" range whose values are handled one after the other.
arg = sys.argv[1]
if ":" not in arg:
    R = float(arg)
    handle(R)
else:
    low, high, step = map(float, arg.split(":"))
    for R in np.arange(low, high, step):
        handle(R)

# Timing summary accumulated over all calls to solve().
print(f"# Building time : {building_time}s")
print(f"# Solving time : {solving_time}s")