import os
import time
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn_extra.cluster import KMedoids

# --- CONFIGURATION ---
# Results CSV: rows are appended after each batch, and the file is re-read on
# startup so that k values already present are skipped (resume support).
RESULTS_FILE = "resultats_kmedoids.csv"
BATCH_SIZE = 30  # number of k values dispatched per batch before results are appended
N_CPUS = 30  # number of parallel joblib workers (passed as n_jobs)
N_TRIALS = 5  # Number of KMedoids restarts for each k (the lowest-inertia run is kept)


def solve_single_k(k, dist_matrix, n_trials=None, random_state=None):
    """Find the best k-medoids partition of a precomputed distance matrix.

    KMedoids (``method='alternate'``, ``init='k-medoids++'``) is refit
    ``n_trials`` times. Each run depends on its random initialisation, so the
    run with the lowest inertia is kept.

    Args:
        k: Number of medoids. ``k == N`` is handled without fitting: every point
            is its own medoid.
        dist_matrix: Square ``(N, N)`` array of pairwise distances.
        n_trials: Number of restarts. Defaults to the module-level ``N_TRIALS``.
        random_state: Optional seed (int) for reproducible restarts. When None,
            KMedoids uses its default, unseeded randomness (original behaviour).

    Returns:
        ``[k, best_inertia, duration_s, medoids_str]``. ``medoids_str`` holds the
        medoid indices of the best run, joined with ``";"``.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
    """
    t0 = time.perf_counter()  # monotonic clock, suited to measuring durations
    N = dist_matrix.shape[0]

    # Trivial case k == N: each point is its own medoid, so the cost is zero.
    if k == N:
        medoids_str = ";".join(map(str, range(N)))
        return [k, 0.0, time.perf_counter() - t0, medoids_str]

    if n_trials is None:
        n_trials = N_TRIALS
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    # One shared RandomState gives every restart a different but reproducible
    # initialisation. None keeps KMedoids' default randomness.
    rng = None if random_state is None else np.random.RandomState(random_state)

    best_inertia = np.inf
    best_medoids = None

    # Restart the fit n_trials times and keep the lowest-inertia solution (F_k*).
    for _ in range(n_trials):
        model = KMedoids(
            n_clusters=k,
            metric='precomputed',
            method='alternate',
            init='k-medoids++',
            max_iter=300,
            random_state=rng,
        )
        model.fit(dist_matrix)

        # `best_medoids is None` guarantees a result even if inertia is never
        # strictly below the initial +inf (e.g. NaN distances).
        if best_medoids is None or model.inertia_ < best_inertia:
            best_inertia = model.inertia_
            best_medoids = model.medoid_indices_

    duration = time.perf_counter() - t0
    medoids_str = ";".join(map(str, best_medoids))

    return [k, best_inertia, duration, medoids_str]


_RESULTS_HEADER = "k,cout_F_k*,temps_calcul_s,F_k*\n"


def _drop_incomplete_last_line(path):
    """Truncate ``path`` just after its last newline, discarding a partial row.

    Every row (header included) is written with a trailing newline. A file that
    does not end with one was interrupted mid-write, so its last row may be cut
    short (for example, a truncated medoid list) and must not count as computed.
    """
    with open(path, "rb+") as f:
        data = f.read()
        if not data or data.endswith(b"\n"):
            return
        # rfind returns -1 when there is no newline at all, which truncates to 0.
        f.truncate(data.rfind(b"\n") + 1)


def _load_computed_k(path):
    """Return the set of k values already saved in ``path``, creating the file if needed.

    A missing or empty file gets the CSV header written. An unreadable file is
    reported and treated as holding no results (best-effort resume).
    """
    if os.path.exists(path):
        _drop_incomplete_last_line(path)

    # An empty file would otherwise never receive its header, and the first
    # appended data row would later be parsed as the header.
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        with open(path, "w", encoding="utf-8") as f:
            f.write(_RESULTS_HEADER)
        return set()

    try:
        # Only the k column is needed; skipping the medoid lists saves memory.
        df_old = pd.read_csv(path, usecols=["k"])
    except ValueError as err:
        # Covers a missing "k" column, parser errors, empty data and decode errors
        # (all ValueError subclasses in pandas / Python).
        print(f"# Avertissement : lecture de {path} impossible ({err}) ; "
              "aucun k considéré comme déjà calculé.")
        return set()

    # Rows whose k cannot be parsed are ignored rather than aborting the resume.
    ks = pd.to_numeric(df_old["k"], errors="coerce").dropna()
    return {int(k) for k in ks}


def run_parallel_kmedoids(dist_matrix_file):
    """Run ``solve_single_k`` for every k in 1..N in parallel, with resumable CSV output.

    Already computed k values (read from ``RESULTS_FILE``) are skipped. The
    remaining ones are processed in batches of ``BATCH_SIZE`` on ``N_CPUS``
    workers, and each batch is appended to ``RESULTS_FILE`` as soon as it ends.

    Args:
        dist_matrix_file: Path to a ``.npy`` file holding a square distance matrix.

    Raises:
        ValueError: If the loaded matrix is not square.
    """
    print("# 1. Chargement de la matrice de distances...")
    dist_matrix = np.load(dist_matrix_file)
    if dist_matrix.ndim != 2 or dist_matrix.shape[0] != dist_matrix.shape[1]:
        raise ValueError(
            f"distance matrix must be square (N, N), got shape {dist_matrix.shape}"
        )
    N = dist_matrix.shape[0]

    print("# 2. Vérification de l'état d'avancement...")
    computed_k = _load_computed_k(RESULTS_FILE)

    # k values still to compute.
    tasks = [k for k in range(1, N + 1) if k not in computed_k]

    if not tasks:
        print("# Tous les calculs sont déjà terminés.")
        return

    print(f"# Lancement : {len(tasks)} valeurs de k (Lots de {BATCH_SIZE} | {N_CPUS} CPUs | {N_TRIALS} essais/k).")

    # One Parallel context keeps the same worker pool for every batch instead of
    # setting it up again per batch.
    with Parallel(n_jobs=N_CPUS, backend="loky") as parallel:
        for start in range(0, len(tasks), BATCH_SIZE):
            batch = tasks[start:start + BATCH_SIZE]

            results = parallel(
                delayed(solve_single_k)(k, dist_matrix) for k in batch
            )

            # Append right after each batch so an interruption loses at most
            # one batch of work.
            with open(RESULTS_FILE, "a", encoding="utf-8") as f:
                f.writelines(
                    f"{res[0]},{res[1]},{res[2]:.2f},{res[3]}\n" for res in results
                )

            print(f"  > Batch fini : k={batch[0]} à k={batch[-1]} sauvegardés.")


if __name__ == "__main__":
    FICHIER_MATRICE = "matrice_distances.npy"
    if not os.path.exists(FICHIER_MATRICE):
        # SystemExit with a message prints it to stderr and exits with status 1,
        # so batch schedulers and shell scripts detect the failure.
        raise SystemExit(f"Erreur : {FICHIER_MATRICE} introuvable.")
    run_parallel_kmedoids(FICHIER_MATRICE)
