#!/usr/bin/env python3
"""
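Attempt to compute a lower bound, using a fixed cost wbot, for the bidding sequence.
See the corresponding document on Typst.app.

Uses SciPy's linprog (HiGHS interior-point method) to solve the linear program.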
"""

import sys, math
import numpy as np
import scipy.special
import time
from scipy.optimize import linprog
from scipy.sparse import lil_matrix
import random

# Wall-clock seconds accumulated across all calls to solve().
solving_time = building_time = 0

# Defaults; the command line can override them with -n, -b and -s.
N = 400            # number of bids
BID_RANGE = 1.0    # log-scale range handed to uniform_bids()
SHOW_DUAL = False  # when True, print the positive entries of the LP solution


def uniform_bids(n, bid_range):
    """Return the n bids exp(bid_range * i / n) for i = 0, ..., n - 1.

    The bids are evenly spaced on a log scale starting at exp(0) = 1; the
    endpoint exp(bid_range) itself is not included.
    """
    exponents = (bid_range * i / n for i in range(n))
    return np.fromiter(map(math.exp, exponents), dtype=float)

def left_concentrated_bids(n, bid_range, l=1):
    """Return n bids whose first n - 3 are packed into the low end of the range.

    The first n - 3 log-bids are evenly spaced over [0, bid_range - l]
    (both endpoints included); the last three are exp(bid_range * j / n)
    for j = n - 3, n - 2, n - 1, as in uniform_bids().
    """
    log_bids = list(np.linspace(0, bid_range - l, n - 3))
    log_bids.extend(bid_range * j / n for j in range(n - 3, n))
    return np.array([math.exp(v) for v in log_bids])

def concentrated_bids(n, bid_range, t, l=0.8, width=1):
    """Return n bids, a fraction l of which is packed around exp(t).

    In log scale:
      * int(l * n) bids are evenly spaced over [t - width, t + width);
      * the remaining n - int(l * n) are split in two: the smaller half
        (floor division) over [0, t - width), the rest over
        [t + width, bid_range] with both endpoints included.
    """
    dense = int(l * n)
    sparse = n - dense
    below = sparse // 2
    above = sparse - below

    low = [math.exp(x) for x in np.linspace(0, t - width, below, endpoint=False)]
    mid = [math.exp(t + width * x) for x in np.linspace(-1, 1, dense, endpoint=False)]
    high = [math.exp(x) for x in np.linspace(t + width, bid_range, above)]
    return np.array(low + mid + high)

def solve(R, n, bids):
    """Build and solve the bidding LP for ratio R and return the linprog result.

    Variable layout (length n**2 + 1):
      * index 0 is C, the objective, constrained to C >= 0;
      * X(i, k) = i * n + k + 1 for 0 <= i, k < n, constrained to [0, 1].
        Only pairs with i <= k occur in a constraint; the others are
        unconstrained apart from their bounds.

    Constraints, stored as A_ub @ x <= b_ub, with p = n - 1:
      * rows 0 .. n-1 (covering), for every t:
            sum_{i <= t <= k} X(i, k) >= 1
      * rows n .. 2n-1 (cost), for every t:
            sum_{i <= t} sum_{k >= i} bids[k] * X(i, k) <= R * bids[t]   if t < p
            sum_{i <= p} sum_{k >= i} bids[k] * X(i, k) <= bids[p] * C   if t = p

    The objective is to minimize C.

    Args:
        R: multiplier on the right-hand side of the cost rows t < n - 1.
        n: number of bids used.
        bids: sequence of at least n numbers; only bids[:n] is used.

    Returns:
        The scipy.optimize.OptimizeResult from linprog(method="highs-ipm");
        res.fun is the optimal C when res.status == 0.

    Raises:
        ValueError: if n < 1 or fewer than n bids are supplied.

    Side effects:
        Adds the time spent building the LP to the module global
        building_time and the time spent in linprog to solving_time.
    """
    global solving_time
    global building_time

    from scipy.sparse import csr_matrix

    t0 = time.time()

    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    bids = np.asarray(bids, dtype=float)
    if len(bids) < n:
        raise ValueError(f"need at least n={n} bids, got {len(bids)}")

    p = n - 1
    n_vars = n**2 + 1  # not called N, so the module-level N is not shadowed
    C = 0

    def X(i, k):
        # Column of variable X(i, k); works elementwise on NumPy index arrays too.
        return i * n + k + 1

    obj = np.zeros(n_vars)
    obj[C] = 1

    bounds = [(0, 1)] * n_vars
    bounds[C] = (0, None)

    nb_ub = 2 * n
    b_ub = np.zeros(nb_ub)

    # The matrix has O(n**3) nonzeros, so it is assembled from per-row
    # (row, col, value) triplets built with NumPy rather than by assigning
    # lil_matrix entries one at a time from Python loops.
    rows, cols, vals = [], [], []

    def add_row(row, row_cols, row_vals):
        rows.append(np.full(len(row_cols), row))
        cols.append(row_cols)
        vals.append(row_vals)

    # Covering rows: -sum_{i <= t <= k} X(i, k) <= -1.
    for t in range(n):
        i = np.arange(t + 1)[:, None]
        k = np.arange(t, n)[None, :]
        row_cols = X(i, k).ravel()
        add_row(t, row_cols, np.full(row_cols.size, -1.0))
        b_ub[t] = -1

    # Cost rows: row t holds bids[k] * X(i, k) for every i <= t and k >= i,
    # i.e. the concatenation of the first t + 1 per-i blocks below.
    block_cols = [X(i, np.arange(i, n)) for i in range(n)]
    block_vals = [bids[i:n] for i in range(n)]
    for t in range(n):
        row = n + t
        add_row(
            row,
            np.concatenate(block_cols[: t + 1]),
            np.concatenate(block_vals[: t + 1]),
        )
        if t != p:
            b_ub[row] = R * bids[t]
        else:
            # Last row: bids[p] * C is moved to the left-hand side; b_ub stays 0.
            add_row(row, np.array([C]), np.array([-bids[p]]))

    A_ub = csr_matrix(
        (np.concatenate(vals), (np.concatenate(rows), np.concatenate(cols))),
        shape=(nb_ub, n_vars),
    )

    building_t = time.time()

    # Capping X(i, k) at 1 cannot change the optimal C when bids are positive:
    # lowering any X(i, k) > 1 to 1 keeps every covering row satisfied
    # (that term alone already reaches 1) and only decreases the cost rows.
    res = linprog(
        obj,
        A_ub=A_ub,
        b_ub=b_ub,
        bounds=bounds,
        method="highs-ipm",
    )
    solving_t = time.time()
    solving_time += solving_t - building_t
    building_time += building_t - t0

    return res

# Command line:  <script> R [-n N] [-b BID_RANGE] [-s]
#   R    float passed to solve() as R; must be the first argument
#   -n   number of bids N (default 400)
#   -b   BID_RANGE passed to uniform_bids() (default 1.0)
#   -s   also print every positive X(i, k), k >= i, of the LP solution
# Unknown arguments are reported and otherwise ignored.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} R [-n N] [-b BID_RANGE] [-s]")

i = 2
while i < len(sys.argv):
    a = sys.argv[i]
    i += 1
    if a in ("-n", "-b") and i >= len(sys.argv):
        sys.exit(f"# Missing value after {a}")
    if a == "-n":
        N = int(sys.argv[i])
        i += 1
    elif a == "-b":
        BID_RANGE = float(sys.argv[i])
        i += 1
    elif a == "-s":
        SHOW_DUAL = True
    else:
        print("# Wrong argument", a)


arg = sys.argv[1]
R = float(arg)
res = solve(R, N, uniform_bids(N, BID_RANGE))
if res.status != 0:
    # When linprog does not reach an optimum, res.x / res.fun may be None.
    print("#", res.message, file=sys.stderr)
if SHOW_DUAL and res.x is not None:
    # Column layout matches X(i, k) = i * N + k + 1 inside solve().
    for i in range(N):
        for k in range(i, N):
            value = res.x[i * N + k + 1]
            if value > 0:
                print(f"{i} {k} : {value}")
print("status : ", res.status)
print(R, res.fun)