#!/usr/bin/env python3
import numpy as np
import matplotlib.pyplot as plt

class PiecewiseLinearEditor:
    """Interactive matplotlib editor for a strictly increasing piecewise-linear f.

    Left axes: the control points (self.x, self.y) of f.  Clicking on an empty
    spot inserts a new point (only if the knots stay strictly increasing in
    both coordinates); pressing near an existing point and dragging moves it.

    Right axes: CR(t) evaluated on self.t2, with its minimum (C) and maximum
    (R) shown in a text box.
    """

    # Minimum spacing kept between neighbouring knots, in both x and y.  It
    # keeps every segment slope finite and strictly positive, which CR_of_f
    # and np.interp(t, self.y, self.x) rely on.
    MIN_GAP = 1e-3

    def __init__(self, x_min=0., x_max=10., y_min=0., y_max=10.):
        """Create the figure, the initial two-point f, and the event hooks.

        Parameters
        ----------
        x_min, x_max : float
            Initial x coordinates of the two end knots.
        y_min, y_max : float
            Initial y coordinates of the two end knots; also the range of
            self.t2 on which CR is evaluated.
        """
        self.fig, (self.ax_f, self.ax_inv) = plt.subplots(1, 2, figsize=(10, 4))

        self.t1 = np.linspace(x_min, x_max, 400)
        self.t2 = np.linspace(y_min, y_max, 400)

        # Minimum (C) and maximum (R) of CR over self.t2; refreshed by update().
        self.R = 0
        self.C = 0

        # Knots of f.  Forced to float so inserted / dragged coordinates are
        # never truncated by an integer dtype.
        self.x = np.array([x_min, x_max], dtype=float)
        self.y = np.array([y_min, y_max], dtype=float)

        self.selected = None  # index of the point being dragged, or None
        self.eps = 0.15  # pick radius around a point, in data units

        self.line_f, = self.ax_f.plot(self.x, self.y, '-o')
        # Curve data is filled in by the update() call at the end of __init__,
        # so CR is not evaluated twice during construction.
        self.line_inv, = self.ax_inv.plot([], [])

        self.ax_f.set_title("f(x)")
        self.ax_inv.set_title("f⁻¹(y)")

        self.text_info = self.ax_inv.text(
            0.02, 0.98, "",
            transform=self.ax_inv.transAxes,
            verticalalignment='top',
            fontsize=11,
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8)
        )

        for ax in (self.ax_f, self.ax_inv):
            ax.set_xlim(x_min-1, x_max+1)
            ax.set_ylim(y_min-1, y_max+1)

        self.fig.canvas.mpl_connect("button_press_event", self.on_press)
        self.fig.canvas.mpl_connect("motion_notify_event", self.on_motion)
        self.fig.canvas.mpl_connect("button_release_event", self.on_release)
        self.update()

    def update(self):
        """Recompute CR on self.t2, refresh C (min) / R (max) and redraw."""
        y_data = self.CR(self.t2)
        self.C = min(y_data)
        self.R = max(y_data)
        self.line_f.set_data(self.x, self.y)
        self.line_inv.set_data(self.t2, y_data)
        self.text_info.set_text(f"C = {self.C:.4f}\nR = {self.R:.4f}")
        self.fig.canvas.draw_idle()

    def CR_of_f(self, t):
        """Evaluate  [ integral_{-inf}^{t+1} exp(f(x)) dx ] / exp(f(t)).

        f is the piecewise-linear function through (self.x, self.y).  Outside
        [x[0], x[-1]] the integrand uses f extended linearly with the slope of
        the first / last segment.  The lower tail integral is therefore
        exp(y[0]) / slope[0], which is finite only because the first slope is
        positive (guaranteed while the knots stay strictly increasing).  The
        denominator uses np.interp, i.e. f clamped to [y[0], y[-1]] outside
        the knot range.

        Parameters
        ----------
        t : float or array_like of float
            Evaluation points; a scalar is treated as a 1-element array.

        Returns
        -------
        list of float
            One value per element of ``t``.
        """
        t = np.atleast_1d(np.asarray(t, dtype=float)).ravel()
        x = np.asarray(self.x, dtype=float)
        y = np.asarray(self.y, dtype=float)

        # Per-segment quantities do not depend on t, so compute them once.
        seg_len = np.diff(x)
        slopes = np.diff(y) / seg_len
        # Exact integral of exp(f) over each full segment:
        #   exp(y_k) * (exp(a_k * L_k) - 1) / a_k ; expm1 keeps precision for small a_k.
        seg_int = np.exp(y[:-1]) * np.expm1(slopes * seg_len) / slopes
        # Integral from -inf to x[0] with the first segment extended leftwards.
        tail = np.exp(y[0]) / slopes[0]
        # cumulative[k] = integral from -inf up to knot x[k].
        cumulative = tail + np.concatenate(([0.0], np.cumsum(seg_int)))

        upper = t + 1.0
        # Segment whose line is integrated up to `upper`: left of x[0] uses
        # segment 0, right of x[-1] the last segment (linear extrapolation).
        k = np.clip(np.searchsorted(x, upper, side="right") - 1, 0, len(x) - 2)
        partial = np.exp(y[k]) * np.expm1(slopes[k] * (upper - x[k])) / slopes[k]
        integral = cumulative[k] + partial

        return list(integral / np.exp(np.interp(t, x, y)))

    def CR(self, t):
        """Return CR_of_f(f^-1(t)) for the values in ``t``.

        f^-1 is evaluated as np.interp(t, self.y, self.x): this requires
        self.y to be increasing and clamps t outside [y[0], y[-1]] to the
        end knots.
        """
        t = np.asarray(t)
        f_inv = np.interp(t, self.y, self.x)
        return self.CR_of_f(f_inv)

    def find_point(self, event):
        """Return the index of the knot within self.eps of the event, else None."""
        if event.inaxes != self.ax_f:
            return None
        d = np.hypot(self.x - event.xdata, self.y - event.ydata)
        idx = np.argmin(d)
        return idx if d[idx] < self.eps else None

    def on_press(self, event):
        """Select a nearby knot for dragging, or try to insert a new one."""
        if event.inaxes != self.ax_f:
            return

        idx = self.find_point(event)

        if idx is None:
            self.add_point(event)
            self.selected = None
        else:
            self.selected = idx

    def on_motion(self, event):
        """Drag the selected knot, keeping it MIN_GAP away from its neighbours."""
        if self.selected is None or event.inaxes != self.ax_f:
            return

        i = self.selected
        last = len(self.x) - 1
        gap = self.MIN_GAP

        # End knots are free on their outer side; interior knots stay strictly
        # between their neighbours so both coordinates remain increasing.
        xmin = self.x[i-1] + gap if i > 0 else -np.inf
        xmax = self.x[i+1] - gap if i < last else np.inf
        ymin = self.y[i-1] + gap if i > 0 else -np.inf
        ymax = self.y[i+1] - gap if i < last else np.inf

        self.x[i] = np.clip(event.xdata, xmin, xmax)
        self.y[i] = np.clip(event.ydata, ymin, ymax)

        self.update()

    def on_release(self, event):
        """Stop dragging."""
        self.selected = None

    def add_point(self, event):
        """Insert (event.xdata, event.ydata) as a knot if it keeps f increasing.

        The point is rejected (no change) when it lies less than MIN_GAP from
        a neighbouring knot in x or y, or would break the strict ordering of
        self.y.  Rejecting exact-x duplicates avoids a zero-length segment
        (division by zero in the slope), and the gap keeps the clip bounds
        used by on_motion consistent (lower bound never above upper bound).
        """
        x_new, y_new = event.xdata, event.ydata
        if x_new is None or y_new is None:
            return

        idx = np.searchsorted(self.x, x_new)
        gap = self.MIN_GAP

        if idx > 0 and (x_new < self.x[idx-1] + gap or y_new < self.y[idx-1] + gap):
            return
        if idx < len(self.x) and (x_new > self.x[idx] - gap or y_new > self.y[idx] - gap):
            return

        self.x = np.insert(self.x, idx, x_new)
        self.y = np.insert(self.y, idx, y_new)

        self.update()

# Only open the interactive window when run as a script, not on import.
if __name__ == "__main__":
    editor = PiecewiseLinearEditor()
    plt.show()
