import time
import math
import numpy as np
from scipy.special import lambertw

import gurobipy as gp
from gurobipy import GRB


def solve_primal_gurobi_cumsum_fast_scaled(
    N: int,
    M: int,
    a: int,
    R: float,
    msg: bool = False,
    method: int | None = None,
) -> float:
    """
    Primal with cumulative prefix sums.

    Same scaling as the original code:
        S[u] / a <= R x[j]

    but removes the variable C and minimizes directly:
        S[min(a-1,M)] / a.
    """

    assert N >= 0
    assert M >= 0
    assert a >= 1
    assert -N <= 0 <= M

    indices = range(-N, M + 1)

    model = gp.Model("primal_cumsum_fast_scaled")
    model.Params.OutputFlag = int(msg)
    model.Params.IgnoreNames = 1

    if method is not None:
        model.Params.Method = method

    x = model.addVars(indices, lb=0.0)
    S = model.addVars(indices, lb=0.0)

    u_star = min(a - 1, M)

    # Objective: minimize S[u_star] / a
    model.setObjective(S[u_star] / a, GRB.MINIMIZE)

    # x_0 >= 1
    model.addConstr(x[0] >= 1.0)

    # x_{i+1} >= x_i
    for i in range(-N, M):
        model.addConstr(x[i + 1] >= x[i])

    # Cumulative sums
    model.addConstr(S[-N] == x[-N])

    for i in range(-N + 1, M + 1):
        model.addConstr(S[i] == S[i - 1] + x[i])

    # Robustness constraints, with the original scaling
    for j in indices:
        u = min(j + a - 1, M)
        model.addConstr(S[u] / a <= R * x[j])

    model.optimize()

    if model.Status != GRB.OPTIMAL:
        raise RuntimeError(f"Gurobi primal failed with status {model.Status}")

    return model.ObjVal


def solve_dual_gurobi_cum_beta_fast_scaled(
    N: int,
    M: int,
    a: int,
    R: float,
    msg: bool = False,
    method: int | None = None,
) -> float:
    """
    Dual with cumulative tail beta variables.

    Keeps the original scaling:

        1_{k <= a-1}/a
        + B_ell/a
        + gamma_k - gamma_{k-1}
        >= lambda 1_{k=0} + R(B_k - B_{k+1})
    """

    assert N >= 0
    assert M >= 0
    assert a >= 1
    assert -N <= 0 <= M

    indices = range(-N, M + 1)
    gamma_indices = range(-N, M)

    model = gp.Model("dual_cum_beta_fast_scaled")
    model.Params.OutputFlag = int(msg)
    model.Params.IgnoreNames = 1

    if method is not None:
        model.Params.Method = method

    lam = model.addVar(lb=0.0)
    B = model.addVars(indices, lb=0.0)
    gamma = model.addVars(gamma_indices, lb=0.0)

    model.setObjective(lam, GRB.MAXIMIZE)

    def B_next(k: int):
        if k < M:
            return B[k + 1]
        return 0.0

    def gamma_at(i: int):
        if -N <= i <= M - 1:
            return gamma[i]
        return 0.0

    for k in indices:
        ell = max(k + 1 - a, -N)

        lhs = (
            (1.0 if k <= a - 1 else 0.0) / a
            + B[ell] / a
            + gamma_at(k)
            - gamma_at(k - 1)
        )

        rhs = (
            (lam if k == 0 else 0.0)
            + R * (B[k] - B_next(k))
        )

        model.addConstr(lhs >= rhs)

    # beta_k >= 0  <=>  B_k >= B_{k+1}
    for k in range(-N, M):
        model.addConstr(B[k] >= B[k + 1])

    model.optimize()

    if model.Status != GRB.OPTIMAL:
        raise RuntimeError(f"Gurobi dual failed with status {model.Status}")

    return lam.X

def wtop_from_R(R: float) -> float:
    """
    Return -1 / W_0(-1 / R), with W_0 the principal Lambert-W branch.

    The result w satisfies w * exp(1 / w) = R.

    Raises:
        ValueError: If ``R < e``.
    """
    if R < math.e:
        raise ValueError("Need R >= e.")

    principal = lambertw(-1.0 / R, 0)
    # lambertw returns a complex number; only the real part is kept.
    inverse = -1.0 / principal
    return float(inverse.real)

def compute_value_closed_form(R: float) -> float:
    """
    Evaluate the closed-form expression in R.

    With w = wtop_from_R(R):
        * if R > 2 / ln 2, the value is R - w;
        * otherwise the value is R / s, where
          s = -W_{-1}(-exp(1 / w - 2)) and W_{-1} is the lower real
          Lambert-W branch.

    Raises:
        ValueError: Propagated from ``wtop_from_R`` when ``R < e``.
    """
    # Both branches need w, so it is computed once up front.
    w_top = wtop_from_R(R)

    if R > 2 / math.log(2):
        return R - w_top

    branch_arg = -math.exp(1.0 / w_top - 2.0)
    # lambertw returns a complex number; only the real part is used.
    scale = -lambertw(branch_arg, -1).real
    return R / scale




if __name__ == "__main__":
    # Sweep R over [e, 2.76], solve the dual LP for each value and append one
    # CSV row per solve: R, dual value, a, N, M, CPU seconds.
    from pathlib import Path

    a = 1000
    N = 10 * a
    M = 10 * a

    n = 30
    R_min = math.e
    R_max = 2.76

    p = 4  # larger => denser near e

    # u**p pushes the uniform grid on [0, 1] towards 0, i.e. towards R_min.
    u = np.linspace(0, 1, n)
    R_values = R_min + (R_max - R_min) * u**p

    # Create the output directory up front: otherwise a missing "outputs/"
    # only surfaces as FileNotFoundError after the first solve has finished.
    output_path = Path("outputs/low_bound.csv")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # tolist() yields built-in floats rather than NumPy scalars.
    for R in R_values.tolist():
        t1 = time.process_time()
        C_dual = solve_dual_gurobi_cum_beta_fast_scaled(
            N, M, a, R, method=0
        )
        t2 = time.process_time()
        # Re-open in append mode for every row so that finished results are
        # on disk even if a later solve fails.
        with open(output_path, "a", encoding="utf-8") as file:
            file.write(f"{R}, {C_dual}, {a}, {N}, {M}, {t2 - t1}\n")