from __future__ import annotations

import logging
import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import typer
from module_qc_data_tools import (
    __version__,
    load_json,
    outputDataFrame,
    qcDataFrame,
    save_dict_list,
)

from module_qc_analysis_tools.cli.globals import (
    CONTEXT_SETTINGS,
    OPTIONS,
    FitMethod,
    LogLevel,
)
from module_qc_analysis_tools.utils.analysis import (
    get_nominal_current,
    get_nominal_RextA,
    get_nominal_RextD,
    get_nominal_Voffs,
    perform_qc_analysis,
    submit_results,
)
from module_qc_analysis_tools.utils.misc import (
    DataExtractor,
    JsonChecker,
    bcolors,
    get_inputs,
    get_qc_config,
    get_time_stamp,
    linear_fit,
    linear_fit_np,
)

app = typer.Typer(context_settings=CONTEXT_SETTINGS)

# Shunt K-factor used when the chip config does not provide KShuntA/KShuntD.
DEFAULT_KSHUNT = 1040.0


def _get_kshunt(metadata, key: str, log: logging.Logger):
    """Return ``ChipConfigs -> RD53B -> Parameter -> key`` from chip metadata.

    Falls back to ``DEFAULT_KSHUNT`` (logging a warning) when any level of the
    lookup is missing, including the case where the intermediate dicts exist
    but ``key`` itself is absent (``dict.get`` then returns ``None`` instead of
    raising, which would otherwise crash the R_eff computation later).
    """
    try:
        value = (
            metadata.get("ChipConfigs").get("RD53B").get("Parameter").get(key)
        )
    except (AttributeError, TypeError):
        # A missing intermediate level yields None, which has no .get().
        value = None
    if value is None:
        log.warning(
            bcolors.WARNING
            + f" No {key} parameter found in chip metadata. Using default {key} = {DEFAULT_KSHUNT:g}"
            + bcolors.ENDC
        )
        return DEFAULT_KSHUNT
    # key[1:] keeps the historical "kShuntA"/"kShuntD" spelling of this message.
    log.debug(f" Found k{key[1:]} = {value} from chip config")
    return value


def _draw_curves(ax, x, curves) -> None:
    """Plot each ``(y, label, color, linestyle)`` curve against ``x`` on ``ax``.

    All curves use small circle markers; a ``linestyle`` of ``None`` keeps
    matplotlib's default line style. Curves are drawn in the given order, which
    also fixes their order in the legend.
    """
    for y, label, color, linestyle in curves:
        style = {} if linestyle is None else {"linestyle": linestyle}
        ax.plot(x, y, marker="o", markersize=4, label=label, color=color, **style)


def _save_figure(fig, outfile: Path, log: logging.Logger) -> None:
    """Lay out, save and close ``fig`` so figures do not accumulate in pyplot."""
    fig.tight_layout()
    log.info(f" Saving {outfile}")
    fig.savefig(outfile)
    plt.close(fig)


def _plot_with_temperature(
    current,
    curves,
    temperature,
    *,
    ylabel: str,
    title: str,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    tlim: tuple[float, float],
    outfile: Path,
    log: logging.Logger,
    nominal_line=None,
) -> None:
    """Plot ``curves`` versus set current, with the NTC temperature on a twin axis.

    ``nominal_line``, if given, is an ``(x, y, label)`` triple drawn as a dotted
    line after the measured curves (used for the nominal VI line).
    """
    fig, ax1 = plt.subplots()
    _draw_curves(ax1, current, curves)
    if nominal_line is not None:
        xp, yp, label = nominal_line
        ax1.plot(xp, yp, label=label, color="tab:brown", linestyle="dotted")
    ax1.set_xlabel("I [A]")
    ax1.set_ylabel(ylabel)
    ax1.set_title(title)
    ax1.set_xlim(*xlim)
    ax1.set_ylim(*ylim)
    ax1.legend(loc="upper left", framealpha=0)
    ax1.grid()

    ax2 = ax1.twinx()
    ax2.plot(
        current,
        temperature,
        marker="^",
        markersize=4,
        color="tab:green",
        label="Temperature (NTC)",
        linestyle="-.",
    )
    ax2.set_ylabel("T [C]")
    ax2.set_ylim(*tlim)
    ax2.legend(loc="upper right", framealpha=0)

    _save_figure(fig, outfile, log)


def _plot_residuals(
    current,
    curves,
    *,
    title: str,
    xlim: tuple[float, float],
    ylim: tuple[float, float],
    outfile: Path,
    log: logging.Logger,
) -> None:
    """Plot VI residual ``curves`` (already scaled to mV) versus set current."""
    fig, ax = plt.subplots()
    _draw_curves(ax, current, curves)
    ax.set_xlabel("I [A]")
    ax.set_ylabel("V [mV]")
    ax.set_title(title)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.legend(loc="upper right", framealpha=0)
    ax.grid()
    _save_figure(fig, outfile, log)


def _save_per_module(output_dir: Path, outputs, timestamps, log: logging.Logger):
    """Save --permodule outputs, grouping chips that share a measurement timestamp.

    ``outputs[i]`` belongs to ``timestamps[i]``. Groups are written in ascending
    timestamp order. A single group goes to ``module.json``; with several groups
    each one gets ``module_<n>.json`` (n = position in that order).
    """
    groups = {}
    for output, timestamp in zip(outputs, timestamps):
        groups.setdefault(timestamp, []).append(output)
    ordered = sorted(groups)
    for n, timestamp in enumerate(ordered):
        # Distinct names are required when there are several groups; a shared
        # name would make each group overwrite the previous one.
        name = "module.json" if len(ordered) == 1 else f"module_{n}.json"
        outfile = output_dir.joinpath(name)
        log.info(f" Saving output of analysis to: {outfile}")
        save_dict_list(outfile, groups[timestamp])


@app.command()
def main(
    input_meas: Path = OPTIONS["input_meas"],
    base_output_dir: Path = OPTIONS["output_dir"],
    qc_criteria_path: Path = OPTIONS["qc_criteria"],
    layer: str = OPTIONS["layer"],
    permodule: bool = OPTIONS["permodule"],
    submit: bool = OPTIONS["submit"],
    site: str = OPTIONS["site"],
    nChips: int = OPTIONS["nchips"],
    fit_method: FitMethod = OPTIONS["fit_method"],
    verbosity: LogLevel = OPTIONS["verbosity"],
    lp_enable: bool = OPTIONS["lp_enable"],
):
    """Perform the SLDO (VI scan) analysis.

    For every chip in the input measurements: draw voltage, current and
    residual plots, run the QC criteria, and write the analysis output (one
    JSON per chip, or grouped per measurement timestamp with --permodule).
    """
    log = logging.getLogger(__name__)
    log.setLevel(verbosity.value)

    if submit and not site:
        log.error(
            bcolors.ERROR
            + "You have supplied the --submit option without specifying --site (testing site). Please specify your testing site if you would like to submit these results"
            + bcolors.ENDC
        )
        return

    log.info("")
    log.info(" =======================================")
    log.info(" \tPerforming SLDO analysis")
    log.info(" =======================================")
    log.info("")

    test_type = os.path.basename(__file__).split(".py")[0]

    # SLDO parameters. None of these depend on the chip being analysed, so they
    # are computed once rather than inside the per-chip loop.
    RextA = get_nominal_RextA(layer)
    RextD = get_nominal_RextD(layer)
    Vofs = get_nominal_Voffs(layer, lp_enable)
    sldo_nom_input_current = get_nominal_current(layer)
    log.debug(f" Calculated nominal current to be: {sldo_nom_input_current}")

    allinputs = get_inputs(input_meas, test_type)
    qc_config = get_qc_config(qc_criteria_path, test_type)

    time_start = round(datetime.timestamp(datetime.now()))
    output_dir = base_output_dir.joinpath(test_type).joinpath(f"{time_start}")
    output_dir.mkdir(parents=True, exist_ok=False)

    alloutput = []
    timestamps = []
    for filename in sorted(allinputs):
        log.info("")
        log.info(f" Loading {filename}")
        meas_timestamp = get_time_stamp(filename)
        inputDFs = load_json(filename)

        log.debug(
            f" There are results from {len(inputDFs)} chip(s) stored in this file"
        )
        for inputDF in inputDFs:
            qcframe = inputDF.get_results()
            metadata = qcframe.get_meta_data()

            kShuntA = _get_kshunt(metadata, "KShuntA", log)
            kShuntD = _get_kshunt(metadata, "KShuntD", log)

            # metadata.get() returns None (it does not raise) when "Name" is
            # missing, so test the value explicitly to actually skip the chip.
            try:
                chipname = metadata.get("Name")
            except AttributeError:
                chipname = None
            if not chipname:
                log.error(
                    bcolors.ERROR
                    + f" Chip name not found in input from {filename}, skipping."
                    + bcolors.ENDC
                )
                continue
            log.debug(f" Found chip name = {chipname} from chip config")

            R_eff = 1.0 / ((kShuntA / RextA) + (kShuntD / RextD)) / nChips

            # Nominal VI line V = R_eff * I + Vofs, and the same line without offset.
            p = np.poly1d([R_eff, Vofs])
            p1 = np.poly1d([R_eff, 0])

            # Check file integrity
            checker = JsonChecker(inputDF, test_type)
            try:
                checker.check()
            except (KeyboardInterrupt, SystemExit):
                # Never swallow user interrupts / explicit exits.
                raise
            except BaseException as exc:
                log.exception(exc)
                log.error(
                    bcolors.ERROR
                    + " JsonChecker check not passed, skipping this input."
                    + bcolors.ENDC
                )
                continue
            log.debug(" JsonChecker check passed!")

            # Calculate quantities
            extractor = DataExtractor(inputDF, test_type)
            calculated_data = extractor.calculate()

            I_set = calculated_data["SetCurrent"]["Values"]
            VinA = calculated_data["VinA"]["Values"]
            VinD = calculated_data["VinD"]["Values"]
            VDDA = calculated_data["VDDA"]["Values"]
            VDDD = calculated_data["VDDD"]["Values"]
            Vofs_meas = calculated_data["Vofs"]["Values"]
            VrefOVP = calculated_data["VrefOVP"]["Values"]
            IinA = calculated_data["IinA"]["Values"]
            IinD = calculated_data["IinD"]["Values"]
            IshuntA = calculated_data["IshuntA"]["Values"]
            IshuntD = calculated_data["IshuntD"]["Values"]
            IcoreA = calculated_data["IcoreA"]["Values"]
            IcoreD = calculated_data["IcoreD"]["Values"]
            Iref = calculated_data["Iref"]["Values"]
            temperature = calculated_data["Temperature"]["Values"]
            # Iref is drawn scaled by 1e5 ("Iref*100k") to share the current axis.
            Iref_scaled = Iref * 100000

            # Plot ranges
            Iint_max = (
                max(
                    max(c)
                    for c in (Iref_scaled, IcoreD, IcoreA, IshuntD, IshuntA, IinD, IinA)
                )
                + 0.5
            )
            I_max = max(I_set) + 0.5
            I_min = min(I_set) - 0.5
            V_max = (
                max(max(v) for v in (VrefOVP, Vofs_meas, VDDD, VDDA, VinD, VinA))
                + 2.0
            )
            T_min = min(0.0, min(temperature))
            T_max = max(temperature) + 1.0

            # Internal voltages visualization (with nominal VI line)
            xp = np.linspace(I_min, I_max, 1000)
            _plot_with_temperature(
                I_set,
                [
                    (VinA, "VinA", "tab:red", None),
                    (VinD, "VinD", "tab:red", "--"),
                    (VDDA, "VDDA", "tab:blue", None),
                    (VDDD, "VDDD", "tab:blue", "--"),
                    (Vofs_meas, "Vofs", "tab:orange", None),
                    (VrefOVP, "VrefOVP", "tab:cyan", None),
                ],
                temperature,
                ylabel="V [V]",
                title=f"VI curve for chip: {chipname}",
                xlim=(I_min, I_max),
                ylim=(0.0, V_max),
                tlim=(T_min, T_max),
                outfile=output_dir.joinpath(f"{chipname}_VI.png"),
                log=log,
                nominal_line=(xp, p(xp), f"V = {R_eff:.3f} I + {Vofs:.2f}"),
            )

            # Internal currents visualization
            _plot_with_temperature(
                I_set,
                [
                    (IinA, "IinA", "tab:red", None),
                    (IinD, "IinD", "tab:red", "--"),
                    (IshuntA, "IshuntA", "tab:blue", None),
                    (IshuntD, "IshuntD", "tab:blue", "--"),
                    (IcoreA, "IcoreA", "tab:orange", None),
                    (IcoreD, "IcoreD", "tab:orange", "--"),
                    (Iref_scaled, "Iref*100k", "tab:cyan", None),
                ],
                temperature,
                ylabel="I [A]",
                title=f"Currents for chip: {chipname}",
                xlim=(I_min, I_max),
                ylim=(0.0, Iint_max),
                tlim=(T_min, T_max),
                outfile=output_dir.joinpath(f"{chipname}_II.png"),
                log=log,
            )

            # SLDO fit on the average of the analog and digital input voltages
            VinAvg = (VinA + VinD) / 2.0
            fit = linear_fit if fit_method.value == "root" else linear_fit_np
            slope, offset = fit(I_set, VinAvg)

            # Residual analysis (all except residual_Vin are in mV)
            nominal = p(I_set)
            nominal_no_offset = p1(I_set)
            residual_VinA = (nominal_no_offset - (VinA - Vofs_meas)) * 1000
            residual_VinD = (nominal_no_offset - (VinD - Vofs_meas)) * 1000
            residual_VinA_nomVofs = (nominal - VinA) * 1000
            residual_VinD_nomVofs = (nominal - VinD) * 1000
            residual_Vin = nominal_no_offset - (VinAvg - Vofs_meas)
            residual_Vofs = (Vofs - Vofs_meas) * 1000
            plotted_residuals = (
                residual_VinA_nomVofs,
                residual_VinD_nomVofs,
                residual_VinA,
                residual_VinD,
                residual_Vofs,
            )
            res_max = max(max(r) for r in plotted_residuals) + 20
            res_min = min(min(r) for r in plotted_residuals) - 10

            _plot_residuals(
                I_set,
                [
                    (
                        residual_VinA_nomVofs,
                        f"{R_eff:.3f}I+{Vofs:.2f}-VinA",
                        "tab:red",
                        None,
                    ),
                    (
                        residual_VinD_nomVofs,
                        f"{R_eff:.3f}I+{Vofs:.2f}-VinD",
                        "tab:red",
                        "--",
                    ),
                    (residual_VinA, f"{R_eff:.3f}I+Vofs-VinA", "tab:blue", None),
                    (residual_VinD, f"{R_eff:.3f}I+Vofs-VinD", "tab:blue", "--"),
                    (residual_Vofs, f"{Vofs}-Vofs", "tab:orange", None),
                ],
                title=f"VI curve for chip: {chipname}",
                xlim=(I_min, I_max),
                ylim=(res_min, res_max),
                outfile=output_dir.joinpath(f"{chipname}_VIresidual.png"),
                log=log,
            )

            # Find point measured closest to nominal input current
            idx = np.abs(I_set - sldo_nom_input_current).argmin()
            log.debug(f" Closest current measured to nominal is: {I_set[idx]}")

            # Values at the nominal-current point, in the order they are
            # reported to the QC analysis and stored in the output file.
            nominal_point = {
                "SLDO_VDDA": VDDA[idx],
                "SLDO_VDDD": VDDD[idx],
                "SLDO_VINA": VinA[idx],
                "SLDO_VIND": VinD[idx],
                "SLDO_VOFFS": Vofs_meas[idx],
                "SLDO_IINA": IinA[idx],
                "SLDO_IIND": IinD[idx],
                "SLDO_IREF": Iref[idx] * 1e6,
                "SLDO_ISHUNTA": IshuntA[idx],
                "SLDO_ISHUNTD": IshuntD[idx],
            }

            # Load values to dictionary for QC analysis
            # NOTE(review): SLDOLinearity uses the signed maximum residual, so
            # negative deviations are not considered -- confirm this is intended.
            results = {
                "SLDOLinearity": max(residual_Vin),
                "VInA_VInD": max(abs(VinA - VinD)),
                **nominal_point,
            }

            # Perform QC analysis; -1 signals that the analysis itself failed.
            passes_qc = perform_qc_analysis(
                test_type, qc_config, layer, results, verbosity.value
            )
            if passes_qc == -1:
                log.error(
                    bcolors.ERROR
                    + f" QC analysis for {chipname} was NOT successful. Please fix and re-run. Continuing to next chip.."
                    + bcolors.ENDC
                )
                continue
            verdict_color = bcolors.OKGREEN if passes_qc else bcolors.BADRED
            log.info("")
            log.info(
                f" Chip {chipname} passes QC? "
                + verdict_color
                + f"{passes_qc}"
                + bcolors.ENDC
            )
            log.info("")

            # Output a json file
            outputDF = outputDataFrame()
            outputDF.set_test_type(test_type)
            data = qcDataFrame()
            data._meta_data.update(metadata)
            data.add_property(
                "ANALYSIS_VERSION",
                __version__,
            )
            # Shunt current relative to the non-shunted part of the input current
            analog_overhead = IshuntA[idx] / (IinA[idx] - IshuntA[idx])
            digital_overhead = IshuntD[idx] / (IinD[idx] - IshuntD[idx])
            data.add_parameter("SLDO_VI_SLOPE", round(slope, 4))
            data.add_parameter("SLDO_VI_OFFSET", round(offset, 4))
            data.add_parameter(
                "SLDO_NOM_INPUT_CURRENT", round(sldo_nom_input_current, 4)
            )
            for param_name, param_value in nominal_point.items():
                data.add_parameter(param_name, round(param_value, 4))
            data.add_parameter("SLDO_ANALOG_OVERHEAD", round(analog_overhead, 4))
            data.add_parameter("SLDO_DIGITAL_OVERHEAD", round(digital_overhead, 4))

            outputDF.set_results(data)
            outputDF.set_pass_flag(passes_qc)
            if submit:
                submit_results(
                    outputDF.to_dict(True),
                    time_start,
                    site,
                    output_dir.joinpath("submit.txt"),
                )
            if permodule:
                alloutput.append(outputDF.to_dict(True))
                timestamps.append(meas_timestamp)
            else:
                outfile = output_dir.joinpath(f"{chipname}.json")
                log.info(f" Saving output of analysis to: {outfile}")
                save_dict_list(outfile, [outputDF.to_dict(True)])
    if permodule:
        # Only store results from the same timestamp in the same file
        _save_per_module(output_dir, alloutput, timestamps, log)


if __name__ == "__main__":
    # Run through the Typer app (rather than typer.run(main), which builds a
    # fresh app) so the command keeps the CONTEXT_SETTINGS configured on `app`.
    app()
