import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit


def custom_fit(
    x, y, sigma, func, initial_guess=None, bounds=(0, np.inf), **kwargs
):
    """
    Fit a user-defined model ``func`` to the data (x, y).

    Thin wrapper around :func:`scipy.optimize.curve_fit`.

    Parameters
    ----------
    x, y : array_like
        Independent and dependent data.
    sigma : array_like or None
        Uncertainties of ``y``, forwarded to ``curve_fit`` as ``sigma``.
    func : callable
        Model ``func(x, p1, p2, ...)``: the first argument is the independent
        variable, the remaining ones are the parameters to be estimated.
    initial_guess : array_like, optional
        Starting values for the parameters (``curve_fit``'s ``p0``). If None,
        ``curve_fit`` picks its own starting point and counts the parameters
        by inspecting ``func``'s signature.
    bounds : 2-tuple, optional
        Lower and upper parameter bounds. The default ``(0, np.inf)``
        constrains every parameter to be non-negative, as before.
    **kwargs
        Any further keyword arguments are forwarded to ``curve_fit``. Example:
        ``absolute_sigma=True``. Without it, ``curve_fit`` treats ``sigma``
        as relative weights and rescales ``pcov`` by the reduced chi-square.

    Returns
    -------
    popt : ndarray
        Best-fit parameter values.
    pcov : ndarray
        Estimated covariance matrix of ``popt``.
    """
    popt, pcov = curve_fit(
        func, x, y, sigma=sigma, p0=initial_guess, bounds=bounds, **kwargs
    )
    return popt, pcov


# Model functions for enzyme kinetics with substrate inhibition, for use with
# custom_fit: the first argument is the independent variable (substrate
# concentration S), the remaining arguments are the parameters to be estimated.


def michaelismenten_subinh(S, V, Km, Ki):
    """
    Michaelis-Menten rate law with substrate inhibition.

        v = V * S / (Km + S + S**2 / Ki)

    Parameters
    ----------
    S : float or ndarray
        Substrate concentration (independent variable).
    V : float
        Rate parameter (the limiting rate if there were no inhibition term).
    Km : float
        Michaelis constant.
    Ki : float
        Substrate-inhibition constant.

    Returns
    -------
    float or ndarray
        Rate evaluated at each ``S``.
    """
    denominator = Km + S + S**2 / Ki
    return V * S / denominator


def calc_max(V, Km, Ki):
    """
    Peak rate reached by ``michaelismenten_subinh``.

    The rate law is maximal at S = sqrt(Km * Ki), where it evaluates to
    V / (1 + 2 * sqrt(Km / Ki)).

    Parameters
    ----------
    V, Km, Ki : float or ndarray
        Parameters of ``michaelismenten_subinh``.

    Returns
    -------
    float or ndarray
        The maximal rate.
    """
    root_ratio = np.sqrt(Km / Ki)
    return V / (1 + 2 * root_ratio)


def calc_Ks(Km, Ki):
    """
    Lower substrate concentration at which ``michaelismenten_subinh`` reaches
    half of its peak rate ``V / (1 + 2*sqrt(Km/Ki))``.

    Setting the rate equal to half the peak gives the quadratic

        S**2 + Ki*(1 - lmda)*S + Km*Ki = 0,   lmda = 2 + 4*sqrt(Km/Ki),

    and this function returns its smaller root.

    Parameters
    ----------
    Km, Ki : float or ndarray
        Michaelis and substrate-inhibition constants (assumed positive).

    Returns
    -------
    float or ndarray
        The lower half-maximal substrate concentration.
    """
    lmda = 2 + 4 * np.sqrt(Km / Ki)
    p = Ki * (1 - lmda) / 2  # half the linear coefficient; negative for Ki > 0
    # Larger root: -p > 0, so this is a sum of positives (no cancellation).
    s_upper = -p + np.sqrt(p**2 - Km * Ki)
    # Smaller root from Vieta (product of roots = Km*Ki). The textbook form
    # -p - sqrt(...) subtracts nearly equal numbers when Km << Ki and loses
    # most of its significant digits.
    return Km * Ki / s_upper


def calc_cmax(Km, Ki):
    """
    Substrate concentration at which ``michaelismenten_subinh`` peaks.

    Setting d/dS [V*S / (Km + S + S**2/Ki)] = 0 yields S**2 = Km * Ki.

    Parameters
    ----------
    Km, Ki : float or ndarray
        Michaelis and substrate-inhibition constants.

    Returns
    -------
    float or ndarray
        sqrt(Km * Ki).
    """
    return np.sqrt(Km * Ki)


def michaelismenten_subinh_mod(S, Vmax, Smax, Ki):
    """
    Substrate-inhibition rate law parametrised by its peak.

        v = (Ki + 2*Smax) * Vmax * S / (Ki*S + Smax**2 + S**2)

    This is ``michaelismenten_subinh`` rewritten with Km = Smax**2 / Ki, so
    that the curve peaks at S = Smax with height Vmax.

    Parameters
    ----------
    S : float or ndarray
        Substrate concentration (independent variable).
    Vmax : float
        Peak rate.
    Smax : float
        Substrate concentration at the peak.
    Ki : float
        Substrate-inhibition constant.

    Returns
    -------
    float or ndarray
        Rate evaluated at each ``S``.
    """
    numerator = (Ki + 2 * Smax) * Vmax * S
    denominator = Ki * S + Smax**2 + S**2
    return numerator / denominator


# calculation Ks from smax and Ki
def calc_Ks_mod(Smax, Ki):
    w = np.sqrt(((Ki + 4 * Smax) ** 2 / 4) - Smax**2)
    Ks_mod = (Ki + 4 * Smax) / 2 - w
    return Ks_mod


def calc_Ks_mod_err(Smax, Ki, Smaxerr, Kierr, cov=0.0):
    """
    First-order (linear) uncertainty of ``calc_Ks_mod(Smax, Ki)``.

    Propagates the standard errors of Smax and Ki through

        Ks = (Ki + 4*Smax)/2 - sqrt(D)/2,   D = Ki**2 + 8*Ki*Smax + 12*Smax**2,

    as sigma_Ks**2 = a**2*Kierr**2 + b**2*Smaxerr**2 + 2*a*b*cov, where
    a = dKs/dKi and b = dKs/dSmax.

    Parameters
    ----------
    Smax, Ki : float or ndarray
        Parameter values.
    Smaxerr, Kierr : float or ndarray
        Standard errors of Smax and Ki.
    cov : float or ndarray, optional
        Covariance between Smax and Ki. The default 0 reproduces the original
        uncorrelated propagation. NOTE(review): when Smax and Ki come from the
        same fit, this is presumably the matching off-diagonal entry of that
        fit's pcov; check the parameter order of the fitted model.

    Returns
    -------
    float or ndarray
        Standard error of Ks. The result is NaN if ``cov`` makes the variance
        negative (inconsistent inputs).
    """
    sqrt_d = np.sqrt(Ki**2 + 8 * Ki * Smax + 12 * Smax**2)
    # Analytic partials, equivalent to
    #   a = 0.5 - (Ki + 4*Smax) / (2*sqrt_d)
    #   b = 2 - 2*(Ki + 3*Smax) / sqrt_d
    # but rationalised using D - (Ki + 4*Smax)**2 = -4*Smax**2 and
    # D - (Ki + 3*Smax)**2 = Smax*(2*Ki + 3*Smax). This avoids subtracting
    # nearly equal numbers when Smax << Ki.
    dks_dki = -2 * Smax**2 / (sqrt_d * (sqrt_d + Ki + 4 * Smax))
    dks_dsmax = (
        2 * Smax * (2 * Ki + 3 * Smax) / (sqrt_d * (sqrt_d + Ki + 3 * Smax))
    )
    variance = (
        (dks_dki * Kierr) ** 2
        + (dks_dsmax * Smaxerr) ** 2
        + 2 * dks_dki * dks_dsmax * cov
    )
    return np.sqrt(variance)


def calc_Ks2_mod(Smax, Ki):
    """
    Upper substrate concentration (above the optimum Smax) at which
    ``michaelismenten_subinh_mod`` falls back to half of its peak rate Vmax.

    This is the larger root of S**2 - (Ki + 4*Smax)*S + Smax**2 = 0.

    Parameters
    ----------
    Smax, Ki : float or ndarray
        Peak position and substrate-inhibition constant.

    Returns
    -------
    float or ndarray
        The upper half-maximal substrate concentration.
    """
    midpoint = (Ki + 4 * Smax) / 2
    half_gap = np.sqrt(((Ki + 4 * Smax) ** 2 / 4) - Smax**2)
    return midpoint + half_gap


def calc_Ks2_mod_err(Smax, Ki, Smaxerr, Kierr, cov=0.0):
    """
    First-order (linear) uncertainty of ``calc_Ks2_mod(Smax, Ki)``.

    Propagates the standard errors of Smax and Ki through

        Ks2 = (Ki + 4*Smax)/2 + sqrt(D)/2,   D = Ki**2 + 8*Ki*Smax + 12*Smax**2,

    as sigma**2 = a**2*Kierr**2 + b**2*Smaxerr**2 + 2*a*b*cov, where
    a = dKs2/dKi and b = dKs2/dSmax.

    Parameters
    ----------
    Smax, Ki : float or ndarray
        Parameter values.
    Smaxerr, Kierr : float or ndarray
        Standard errors of Smax and Ki.
    cov : float or ndarray, optional
        Covariance between Smax and Ki. The default 0 reproduces the original
        uncorrelated propagation. NOTE(review): when Smax and Ki come from the
        same fit, this is presumably the matching off-diagonal entry of that
        fit's pcov; check the parameter order of the fitted model.

    Returns
    -------
    float or ndarray
        Standard error of Ks2. The result is NaN if ``cov`` makes the variance
        negative (inconsistent inputs).
    """
    sqrt_d = np.sqrt(Ki**2 + 8 * Ki * Smax + 12 * Smax**2)
    # Both partials are sums of positive terms, so there is no cancellation.
    dks2_dki = 0.5 + (Ki + 4 * Smax) / (2.0 * sqrt_d)
    dks2_dsmax = 2 + (2 * (Ki + 3 * Smax)) / sqrt_d
    variance = (
        (dks2_dki * Kierr) ** 2
        + (dks2_dsmax * Smaxerr) ** 2
        + 2 * dks2_dki * dks2_dsmax * cov
    )
    return np.sqrt(variance)
