from __future__ import annotations

import logging
import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import typer
from module_qc_data_tools import load_json, outputDataFrame, qcDataFrame, save_dict_list

from module_qc_analysis_tools import __version__
from module_qc_analysis_tools.cli.globals import (
    CONTEXT_SETTINGS,
    OPTIONS,
    FitMethod,
    LogLevel,
)
from module_qc_analysis_tools.utils.analysis import perform_qc_analysis, submit_results
from module_qc_analysis_tools.utils.misc import (
    DataExtractor,
    JsonChecker,
    bcolors,
    get_inputs,
    get_qc_config,
    get_time_stamp,
    getImuxMap,
    getVmuxMap,
    linear_fit,
    linear_fit_np,
)

# Typer application object; `main` below is registered on it via @app.command().
app = typer.Typer(context_settings=CONTEXT_SETTINGS)

# Module-level logger; main() sets its level from its `verbosity` argument.
log = logging.getLogger(__name__)

# Placeholder stored for output parameters that this analysis does not compute
# (e.g. the AR_TEMP_POLY_* and AR_RING_OSCILATOR_* entries written in main()).
EMPTY_VAL = -1


def get_NtcCalPar(metadata):
    """Return the NTC calibration parameters stored in the chip config.

    Args:
        metadata: Mapping holding the chip parameters; looked up under the
            key "NtcCalPar".

    Returns:
        The value stored under "NtcCalPar" when the key is present (returned
        as-is, without validation). Otherwise a warning is logged and a
        built-in three-element default list is returned.
    """
    if "NtcCalPar" not in metadata:
        log.warning(
            f"{bcolors.WARNING} No NtcCalPar found in the input config file! Using the default NTC parameters.{bcolors.ENDC}"
        )
        # Fallback coefficients, consumed by calculate_T() as A, B and C.
        return [
            0.0007488999981433153,
            0.0002769000129774213,
            7.059500006789676e-08,
        ]

    return metadata.get("NtcCalPar")


def get_NfPar(metadata):
    """Return the Nf parameters stored in the chip config.

    Each of "NfASLDO", "NfDSLDO" and "NfACB" is read from ``metadata``. A key
    that is absent is replaced by the default value 1.264 and a warning naming
    that specific key is logged.

    Args:
        metadata: Mapping holding the chip parameters.

    Returns:
        dict with exactly the keys "NfASLDO", "NfDSLDO" and "NfACB" (in that
        order). Values found in ``metadata`` are returned as-is.
    """
    default_nf = 1.264
    NfPar = {}
    for key in ("NfASLDO", "NfDSLDO", "NfACB"):
        if key in metadata:
            NfPar[key] = metadata.get(key)
            continue
        NfPar[key] = default_nf
        # One shared template so the message always names the key that is
        # actually missing.
        log.warning(
            bcolors.WARNING
            + f" No {key} found in the input config file! Using the default Nf parameter value {default_nf}."
            + bcolors.ENDC
        )

    return NfPar


def calculate_T(calculated_data, NtcCalPar, NfPar):
    """Compute the five temperatures reported by the AR_TEMP subtest.

    Args:
        calculated_data: Mapping whose entries "Vntc", "Intc", "TExtExtNTC",
            "VMonSensAna", "VMonSensDig" and "VMonSensAcb" each expose a
            "Values" sequence.
        NtcCalPar: Sequence of three coefficients (A, B, C) used in the
            Steinhart-Hart expression for the NTC temperature.
        NfPar: Mapping with the keys "NfASLDO", "NfDSLDO" and "NfACB".

    Returns:
        Tuple ``(NTC, EXT, ASLDO, DSLDO, ACB)`` of temperatures, each rounded
        to one decimal place. Each computed value has 273.15 subtracted; the
        EXT value is the plain mean of the "TExtExtNTC" values.
    """
    kelvin_offset = 273.15

    def _sensor_temperature(samples, nf):
        # Temperature from the difference between the mean of the last 16 and
        # the mean of the first 16 samples.
        elementary_charge = 1.602e-19
        boltzmann = 1.38064852e-23
        mean_first = np.mean(samples[:16])
        mean_last = np.mean(samples[-16:])
        delta_v = mean_last - mean_first
        return (
            delta_v * elementary_charge / (nf * boltzmann * np.log(15))
            - kelvin_offset
        )

    # --- External NTC: mean resistance fed into the Steinhart-Hart equation.
    ntc_voltage = np.array(calculated_data["Vntc"]["Values"])
    ntc_current = np.array(calculated_data["Intc"]["Values"])
    coeff_a, coeff_b, coeff_c = NtcCalPar[0], NtcCalPar[1], NtcCalPar[2]
    ln_r = np.log(np.mean(ntc_voltage / ntc_current))
    temp_ntc = 1 / (coeff_a + coeff_b * ln_r + coeff_c * (ln_r**3)) - kelvin_offset
    log.debug(f" T Ext NTC: {temp_ntc} C")

    # --- "External external" NTC: plain average of the recorded values.
    temp_ext = np.mean(np.array(calculated_data["TExtExtNTC"]["Values"]))
    log.debug(f" T Ext Ext NTC: {temp_ext} C")

    # --- MOS sensors: read all three inputs first, then convert each one.
    sensor_sources = (
        ("VMonSensAna", "NfASLDO"),
        ("VMonSensDig", "NfDSLDO"),
        ("VMonSensAcb", "NfACB"),
    )
    readings = [
        np.array(calculated_data[data_key]["Values"]) for data_key, _ in sensor_sources
    ]
    temp_asldo, temp_dsldo, temp_acb = (
        _sensor_temperature(samples, NfPar[nf_key])
        for samples, (_, nf_key) in zip(readings, sensor_sources)
    )

    log.debug(f" T MonSensAna: {temp_asldo} C")
    log.debug(f" T MonSensDig: {temp_dsldo} C")
    log.debug(f" T MonSensAcb: {temp_acb} C")

    return tuple(
        round(temperature, 1)
        for temperature in (temp_ntc, temp_ext, temp_asldo, temp_dsldo, temp_acb)
    )


def round_list(list_values, digit=None):
    """Round every value of a sequence, keeping precision for tiny values.

    Values that are exactly zero or whose magnitude is at least 0.01 are
    rounded with the builtin ``round(value, digit)``. Values with a smaller
    magnitude are formatted in scientific notation with ``digit`` digits after
    the decimal point of the mantissa and converted back to ``float``, so they
    are not flattened to zero.

    The threshold is applied to the absolute value, so negative numbers are
    rounded the same way as positive numbers of equal magnitude.

    Args:
        list_values: Iterable of numbers.
        digit: Number of digits passed to ``round`` / used as mantissa
            precision. With ``None`` the builtin ``round`` returns an integer.

    Returns:
        A new list with the rounded values, in input order.
    """

    def _round_one(value):
        if value == 0 or abs(value) >= 0.01:
            return round(value, digit)
        return float(np.format_float_scientific(value, precision=digit))

    return [_round_one(value) for value in list_values]


def plot_vdd_vs_trim(trim, vdd, vdd_name, output_name, chipname, fit_method):
    """Plot a VDD-versus-trim scan with a straight-line fit and save it.

    The figure shows the measured points, a dashed horizontal line at 1.2
    labelled as the nominal value, and the fitted line evaluated on 100 points
    between the first and last trim value.

    Args:
        trim: Sequence of trim settings (must be non-empty; the first and last
            entries define the range of the fitted line).
        vdd: Sequence of measured voltages, one per trim setting.
        vdd_name: Name used in the labels and title (e.g. "VDDA").
        output_name: Destination passed to ``Figure.savefig``.
        chipname: Chip name shown in the title.
        fit_method: "root" to fit with ``linear_fit`` or "numpy" to fit with
            ``linear_fit_np``.

    Raises:
        ValueError: If ``fit_method`` is neither "root" nor "numpy".
    """
    # Resolve the fit before creating the figure so an unsupported method
    # fails with a clear message and without leaving a figure open.
    if fit_method == "root":
        p1, p0 = linear_fit(trim, vdd)
    elif fit_method == "numpy":
        p1, p0 = linear_fit_np(trim, vdd)
    else:
        msg = f"Unsupported fit method {fit_method!r}; expected 'root' or 'numpy'."
        raise ValueError(msg)

    fig, ax1 = plt.subplots()
    try:
        ax1.plot(trim, vdd, "o", label=f"{vdd_name} vs trim")
        ax1.axhline(
            y=1.2, color="r", linestyle="--", label=f"Nominal {vdd_name} value"
        )
        x_line = np.linspace(trim[0], trim[-1], 100)
        ax1.plot(
            x_line,
            p1 * x_line + p0,
            "g-",
            alpha=0.5,
            label=f"Fitted line y = {p1:.3e} * x + {p0:.3e}",
        )
        ax1.set_title(f"{vdd_name} vs Trim Chip {chipname}")
        ax1.set_xlabel("Trim")
        ax1.set_ylabel(f"{vdd_name} (V)")
        ax1.legend()
        log.info(f" Saving {output_name}")
        fig.savefig(output_name)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; this
        # function is called for every chip, so release the figure here.
        plt.close(fig)


@app.command()
def main(
    input_meas: Path = OPTIONS["input_meas"],
    base_output_dir: Path = OPTIONS["output_dir"],
    qc_criteria_path: Path = OPTIONS["qc_criteria"],
    layer: str = OPTIONS["layer"],
    permodule: bool = OPTIONS["permodule"],
    submit: bool = OPTIONS["submit"],
    site: str = OPTIONS["site"],
    fit_method: FitMethod = OPTIONS["fit_method"],
    verbosity: LogLevel = OPTIONS["verbosity"],
):
    # Analog Readback analysis entry point.
    #
    # NOTE: documented with comments instead of a docstring on purpose: Typer
    # uses a command's docstring as its --help text, so adding one would change
    # the CLI output.
    #
    # Flow:
    #   1. Create <base_output_dir>/<test_type>/<start timestamp>/.
    #   2. For every input file, loop over the stored measurements, and per
    #      chip collect output parameters (`data`) and the values used for the
    #      QC selection (`results`) depending on the measurement's subtest
    #      type ("AR_VMEAS", "AR_TEMP" or "AR_VDD").
    #   3. Run perform_qc_analysis() per chip, optionally submit the result,
    #      and write either one JSON file per chip or (with `permodule`) a
    #      module-level JSON file.
    #   4. At DEBUG verbosity additionally dump the internal biases.
    log.setLevel(verbosity.value)

    if submit and site == "":
        log.error(
            bcolors.ERROR
            + "You have supplied the --submit option without specifying --site (testing site). Please specify your testing site if you would like to submit these results"
            + bcolors.ENDC
        )
        return

    log.info("")
    log.info(" ===============================================")
    log.info(" \tPerforming Analog Readback analysis")
    log.info(" ===============================================")
    log.info("")

    # The test type is the name of this file without the ".py" extension.
    test_type = os.path.basename(__file__).split(".py")[0]
    time_start = round(datetime.timestamp(datetime.now()))
    output_dir = base_output_dir.joinpath(test_type).joinpath(f"{time_start}")
    output_dir.mkdir(parents=True, exist_ok=False)

    allinputs = get_inputs(input_meas, test_type)
    qc_config = get_qc_config(qc_criteria_path, test_type)

    # Accumulators used only when writing per-module output.
    alloutput = []
    timestamps = []

    alloutput_int_biases = []
    timestamps_int_biases = []

    for filename in sorted(allinputs):
        log.info("")
        log.info(f" Loading {filename}")
        meas_timestamp = get_time_stamp(filename)
        inputDFs = load_json(filename)
        log.debug(
            f" There are results from {len(inputDFs)} measuremnet(s) stored in this file"
        )

        chipnames = []
        results = {}  # chip name -> values handed to perform_qc_analysis()
        data = {}  # chip name -> qcDataFrame with the output parameters
        int_biases = {}  # chip name -> first value of every AR_VMEAS quantity
        for inputDF in inputDFs:
            qcframe = inputDF.get_results()
            metadata = qcframe.get_meta_data()

            # Read chipname from input DF. A missing "Name" entry yields None
            # rather than an exception, so it is checked explicitly.
            try:
                chipname = metadata.get("Name")
            except Exception:
                chipname = None
            if chipname is None:
                log.warning(
                    bcolors.WARNING
                    + f"Chip name not found in input from {filename}, skipping."
                    + bcolors.ENDC
                )
                continue
            chipnames.append(chipname)
            log.debug(f" Found chip name = {chipname} from chip config")

            # Create an output DF for each chip
            if chipname not in data:
                data[chipname] = qcDataFrame()
                data[chipname].add_property(
                    "ANALYSIS_VERSION",
                    __version__,
                )
                data[chipname]._meta_data.update(metadata)
                int_biases[chipname] = {}

            # Check file integrity
            # NOTE(review): BaseException also swallows KeyboardInterrupt and
            # SystemExit; kept because it is not visible here what
            # JsonChecker.check() raises - confirm and narrow if possible.
            # NOTE(review): `data[chipname]` already exists at this point, so a
            # chip whose inputs all fail the check still reaches the QC step
            # below with `results.get(key)` being None - confirm intended.
            checker = JsonChecker(inputDF, test_type)

            try:
                checker.check()
            except BaseException as exc:
                log.exception(exc)
                log.warning(
                    bcolors.WARNING
                    + " JsonChecker check not passed, skipping this input."
                    + bcolors.ENDC
                )
                continue
            log.debug(" JsonChecker check passed!")

            # Calculate quantities
            # Vmux conversion is embedded.
            extractor = DataExtractor(inputDF, test_type)
            calculated_data = extractor.calculate()
            Vmux_map = getVmuxMap()
            Imux_map = getImuxMap()

            # Names of the 32 Imux channels followed by the 40 Vmux channels;
            # this fixes the order of the AR_NOMINAL_SETTINGS list.
            AR_values_names = [Imux_map[imux] for imux in range(32)]
            AR_values_names += [Vmux_map[vmux] for vmux in range(40)]

            tmpresults = {}
            if inputDF._subtestType == "AR_VMEAS":
                for key in calculated_data:
                    int_biases[chipname][key] = calculated_data[key]["Values"][0]
                # Channels without a measurement are reported as 0.
                NOT_MEASURED = 0
                AR_values = [
                    int_biases[chipname].get(name, NOT_MEASURED)
                    for name in AR_values_names
                ]
                data[chipname].add_parameter(
                    "AR_NOMINAL_SETTINGS", round_list(AR_values, 4)
                )
                # The QC selection receives the unrounded values.
                tmpresults.update({"AR_NOMINAL_SETTINGS": AR_values})

            elif inputDF._subtestType == "AR_TEMP":
                NtcCalPar = get_NtcCalPar(metadata["ChipConfigs"]["RD53B"]["Parameter"])
                NfPar = get_NfPar(metadata["ChipConfigs"]["RD53B"]["Parameter"])
                (
                    AR_TEMP_NTC,
                    AR_TEMP_EXT,
                    AR_TEMP_ASLDO,
                    AR_TEMP_DSLDO,
                    AR_TEMP_ACB,
                ) = calculate_T(calculated_data, NtcCalPar, NfPar)
                # Add parameters for output file
                data[chipname].add_parameter("AR_TEMP_NTC", AR_TEMP_NTC)
                data[chipname].add_parameter("AR_TEMP_EXT", AR_TEMP_EXT)
                data[chipname].add_parameter("AR_TEMP_ASLDO", AR_TEMP_ASLDO)
                data[chipname].add_parameter("AR_TEMP_DSLDO", AR_TEMP_DSLDO)
                data[chipname].add_parameter("AR_TEMP_ACB", AR_TEMP_ACB)
                data[chipname].add_parameter("AR_TEMP_NF_ASLDO", NfPar["NfASLDO"])
                data[chipname].add_parameter("AR_TEMP_NF_DSLDO", NfPar["NfDSLDO"])
                data[chipname].add_parameter("AR_TEMP_NF_ACB", NfPar["NfACB"])
                # Parameters not computed by this analysis get a placeholder.
                data[chipname].add_parameter("AR_TEMP_POLY_TOP", EMPTY_VAL)
                data[chipname].add_parameter("AR_TEMP_POLY_BOTTOM", EMPTY_VAL)
                data[chipname].add_parameter("AR_TEMP_NF_TOP", EMPTY_VAL)
                data[chipname].add_parameter("AR_TEMP_NF_BOTTOM", EMPTY_VAL)
                data[chipname].add_parameter("AR_RING_OSCILATOR_A", [EMPTY_VAL] * 16)
                data[chipname].add_parameter("AR_RING_OSCILATOR_B", [EMPTY_VAL] * 16)
                # Load values to dictionary for QC analysis
                tmpresults.update({"ChipNTC_vs_ExtExt": AR_TEMP_NTC - AR_TEMP_EXT})
                tmpresults.update({"ASLO_ChipNTC": AR_TEMP_ASLDO - AR_TEMP_NTC})
                tmpresults.update({"DSLD_ChipNTC": AR_TEMP_DSLDO - AR_TEMP_NTC})
                tmpresults.update({"ACB_ChipNTC": AR_TEMP_ACB - AR_TEMP_NTC})

            elif inputDF._subtestType == "AR_VDD":
                vdda = calculated_data["VDDA"]["Values"].tolist()
                vddd = calculated_data["VDDD"]["Values"].tolist()
                trimA = calculated_data["SldoTrimA"]["Values"].tolist()
                trimD = calculated_data["SldoTrimD"]["Values"].tolist()
                # Add parameters for output file
                data[chipname].add_parameter("AR_VDDA_VS_TRIM", round_list(vdda, 4))
                data[chipname].add_parameter("AR_VDDD_VS_TRIM", round_list(vddd, 4))
                output_name_vdda = output_dir.joinpath(f"{chipname}_VDDA_TRIM.png")
                output_name_vddd = output_dir.joinpath(f"{chipname}_VDDD_TRIM.png")
                plot_vdd_vs_trim(
                    trimA, vdda, "VDDA", output_name_vdda, chipname, fit_method.value
                )
                plot_vdd_vs_trim(
                    trimD, vddd, "VDDD", output_name_vddd, chipname, fit_method.value
                )
                # Load values to dictionary for QC analysis
                tmpresults.update({"AR_VDDA_VS_TRIM": round_list(vdda, 4)})
                tmpresults.update({"AR_VDDD_VS_TRIM": round_list(vddd, 4)})

            else:
                log.warning(
                    bcolors.WARNING
                    + f"{filename}.json does not have any required subtestType. Skipping."
                    + bcolors.ENDC
                )
                continue

            # Merge this measurement's values into the chip's QC inputs.
            if results.get(chipname):
                results[chipname].update(tmpresults)
            else:
                results[chipname] = tmpresults

        log.debug(
            f" There are results from {len(chipnames)} chip(s) stored in this file"
        )

        # Output a json file
        for key, df in data.items():
            outputDF = outputDataFrame()
            outputDF.set_test_type(test_type)
            outputDF.set_results(df)

            # Perform QC analysis
            passes_qc = perform_qc_analysis(
                test_type,
                qc_config,
                layer,
                results.get(key),
                verbosity.value,
            )
            # -1 is the value this code treats as "analysis could not be run".
            if passes_qc == -1:
                log.error(
                    bcolors.ERROR
                    + f" QC analysis for {key} was NOT successful. Please fix and re-run. Continuing to next chip.."
                    + bcolors.ENDC
                )
                continue
            log.info("")
            verdict_color = bcolors.OKGREEN if passes_qc else bcolors.BADRED
            log.info(
                f" Chip {key} passes QC? "
                + verdict_color
                + f"{passes_qc}"
                + bcolors.ENDC
            )
            log.info("")
            outputDF.set_pass_flag(passes_qc)
            if submit:
                submit_results(
                    outputDF.to_dict(True),
                    time_start,
                    site,
                    output_dir.joinpath("submit.txt"),
                )

            if permodule:
                # Written after all files are processed (see below).
                alloutput.append(outputDF.to_dict(True))
                timestamps.append(meas_timestamp)
            else:
                outfile = output_dir.joinpath(f"{key}.json")
                log.info(f" Saving output of analysis to: {outfile}")
                save_dict_list(outfile, [outputDF.to_dict(True)])

        if verbosity.value == "DEBUG":
            # Save an output file for only internal biases
            for key, chip_biases in int_biases.items():
                if permodule:
                    alloutput_int_biases.append(chip_biases)
                    timestamps_int_biases.append(meas_timestamp)
                else:
                    outfile = output_dir.joinpath(f"{key}_internal_biases.json")
                    log.info(f" Saving DEBUG file with internal biases to: {outfile}")
                    save_dict_list(outfile, [chip_biases])
    if permodule:
        # Only store results from same timestamp into same file
        # NOTE(review): every timestamp group is written to the same
        # "module.json", so with more than one distinct timestamp only the
        # last group survives (the DEBUG dump below does put the timestamp in
        # the file name). Left unchanged because the file name may be relied
        # upon elsewhere - confirm the intended behaviour.
        dfs = np.array(alloutput)
        tss = np.array(timestamps)
        for x in np.unique(tss):
            outfile = output_dir.joinpath("module.json")
            log.info(f" Saving output of analysis to: {outfile}")
            save_dict_list(
                outfile,
                dfs[tss == x].tolist(),
            )
        if verbosity.value == "DEBUG":
            # Save an output file for only internal biases
            dfs = np.array(alloutput_int_biases)
            tss = np.array(timestamps_int_biases)
            for x in np.unique(tss):
                outfile = output_dir.joinpath(f"internal_biases_{x}.json")
                log.info(f" Saving DEBUG file with internal biases to: {outfile}")
                save_dict_list(
                    outfile,
                    dfs[tss == x].tolist(),
                )


# Script entry point: run `main` as a command-line program when this file is
# executed directly.
# NOTE(review): this goes through typer.run(main) rather than the module-level
# `app`, so the CONTEXT_SETTINGS given to `app` presumably do not apply on
# this path - confirm whether `app()` was intended.
if __name__ == "__main__":
    typer.run(main)
