from __future__ import annotations

import logging
import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import typer
from module_qc_data_tools import (
    __version__,
    load_json,
    outputDataFrame,
    qcDataFrame,
    save_dict_list,
)

from module_qc_analysis_tools.cli.globals import (
    CONTEXT_SETTINGS,
    OPTIONS,
    FitMethod,
    LogLevel,
)
from module_qc_analysis_tools.utils.analysis import perform_qc_analysis, submit_results
from module_qc_analysis_tools.utils.misc import (
    DataExtractor,
    JsonChecker,
    bcolors,
    get_inputs,
    get_qc_config,
    get_time_stamp,
    linear_fit,
    linear_fit_np,
)

# Typer application that exposes ``main`` as its command. CONTEXT_SETTINGS
# is imported from module_qc_analysis_tools.cli.globals.
app = typer.Typer(
    context_settings=CONTEXT_SETTINGS,
)


def _fit_line(fit_method: FitMethod, x_values, y_values):
    """Fit a straight line ``y = p1 * x + p0`` to the measured points.

    ``fit_method.value`` selects the backend: ``"root"`` uses ``linear_fit``,
    ``"numpy"`` uses ``linear_fit_np``.

    Returns:
        ``(p1, p0)``: slope and offset, in the units of the inputs.

    Raises:
        ValueError: if ``fit_method`` selects neither supported backend
            (previously this surfaced as an unbound-variable NameError).
    """
    if fit_method.value == "root":
        return linear_fit(x_values, y_values)
    if fit_method.value == "numpy":
        return linear_fit_np(x_values, y_values)
    msg = f"Unsupported fit method: {fit_method.value!r}"
    raise ValueError(msg)


def _plot_fit(x_data, y_data, x_key, y_key, p1, p0, title, outfile):
    """Plot the measured points and the fitted line, then save to *outfile*.

    ``x_data`` / ``y_data`` are the entries returned by
    ``DataExtractor.calculate`` and are read through their ``"Values"`` and
    ``"Unit"`` keys.
    """
    fig, ax1 = plt.subplots()
    try:
        ax1.plot(
            x_data["Values"],
            y_data["Values"],
            "o",
            label="Measured data",
            markersize=10,
        )
        x_line = np.linspace(x_data["Values"][0], x_data["Values"][-1], 100)
        ax1.plot(x_line, p1 * x_line + p0, "r--", label="Fitted line")
        ax1.text(
            x_data["Values"][0],
            0.75 * y_data["Values"][-1],
            f"y = {p1:.4e} * x + {p0:.4e}",
        )
        ax1.set_xlabel(f"{x_key}[{x_data['Unit']}]")
        ax1.set_ylabel(f"{y_key}[{y_data['Unit']}]")
        ax1.set_title(title)
        ax1.legend()
        fig.savefig(f"{outfile}")
    finally:
        # Release the figure; otherwise one stays open per analysed chip.
        plt.close(fig)


def _save_permodule_outputs(output_dir, outputs, timestamps):
    """Write the per-module JSON output, one file per measurement timestamp.

    Results that share a timestamp go into the same file. With a single
    timestamp the file keeps the name ``module.json``. With several, the files
    are named ``module_<i>.json`` (``i`` counts the groups in order of first
    appearance), so that no group overwrites another.
    """
    log = logging.getLogger(__name__)
    groups = {}
    for output, timestamp in zip(outputs, timestamps):
        groups.setdefault(timestamp, []).append(output)
    for index, (timestamp, group) in enumerate(groups.items()):
        name = "module.json" if len(groups) == 1 else f"module_{index}.json"
        outfile = output_dir.joinpath(name)
        log.info(f" Saving output of analysis to: {outfile}")
        log.debug(f" {len(group)} chip result(s) with timestamp {timestamp}")
        save_dict_list(outfile, group)


@app.command()
def main(
    input_meas: Path = OPTIONS["input_meas"],
    base_output_dir: Path = OPTIONS["output_dir"],
    qc_criteria_path: Path = OPTIONS["qc_criteria"],
    layer: str = OPTIONS["layer"],
    permodule: bool = OPTIONS["permodule"],
    submit: bool = OPTIONS["submit"],
    site: str = OPTIONS["site"],
    fit_method: FitMethod = OPTIONS["fit_method"],
    verbosity: LogLevel = OPTIONS["verbosity"],
):
    """Run the ADC calibration analysis on the given measurement files.

    For every chip in every input file: check the input, fit VcalMed against
    ADC_Vmux8 with a straight line, save a plot, apply the QC criteria to the
    fitted slope and offset (converted to mV) and write the results as JSON,
    optionally submitting them.
    """
    log = logging.getLogger(__name__)
    log.setLevel(verbosity.value)

    # `not site` also covers an unset (None) site, not only "".
    if submit and not site:
        log.error(
            bcolors.ERROR
            + "You have supplied the --submit option without specifying --site (testing site). Please specify your testing site if you would like to submit these results"
            + bcolors.ENDC
        )
        return

    log.info("")
    log.info(" ===============================================")
    log.info(" \tPerforming ADC calibration analysis")
    log.info(" ===============================================")
    log.info("")

    # The test type is this script's file name without the ".py" suffix.
    test_type = os.path.basename(__file__).split(".py")[0]

    time_start = round(datetime.timestamp(datetime.now()))
    output_dir = base_output_dir.joinpath(test_type).joinpath(f"{time_start}")
    # exist_ok=False: never write into the output directory of another run.
    output_dir.mkdir(parents=True, exist_ok=False)

    allinputs = get_inputs(input_meas, test_type)
    qc_config = get_qc_config(qc_criteria_path, test_type)

    alloutput = []
    timestamps = []
    for filename in sorted(allinputs):
        log.info("")
        log.info(f" Loading {filename}")
        meas_timestamp = get_time_stamp(filename)
        inputDFs = load_json(filename)
        log.debug(
            f" There are results from {len(inputDFs)} chip(s) stored in this file"
        )
        for inputDF in inputDFs:
            qcframe = inputDF.get_results()
            metadata = qcframe.get_meta_data()

            # Check file integrity
            checker = JsonChecker(inputDF, test_type)
            try:
                checker.check()
            except Exception as exc:
                log.exception(exc)
                log.warning(
                    bcolors.WARNING
                    + " JsonChecker check not passed, skipping this input."
                    + bcolors.ENDC
                )
                continue
            log.debug(" JsonChecker check passed!")

            # dict.get returns None (it does not raise) when "Name" is absent,
            # so the value itself must be checked.
            chipname = metadata.get("Name") if metadata else None
            if not chipname:
                log.warning(
                    bcolors.WARNING
                    + f"Chip name not found in input from {filename}, skipping."
                    + bcolors.ENDC
                )
                continue
            log.debug(f" Found chip name = {chipname} from chip config")

            # Calculate quantities (Vmux conversion is embedded).
            extractor = DataExtractor(inputDF, test_type)
            calculated_data = extractor.calculate()

            # Fit and plot VcalMed against ADC_Vmux8
            x_key = "ADC_Vmux8"
            x_data = calculated_data.pop(x_key)
            y_key = "VcalMed"
            y_data = calculated_data.get(y_key)
            p1, p0 = _fit_line(fit_method, x_data["Values"], y_data["Values"])

            outfile = output_dir.joinpath(f"{chipname}.png")
            log.info(f" Saving {outfile}")
            _plot_fit(x_data, y_data, x_key, y_key, p1, p0, chipname, outfile)

            # Convert from V to mV and load the values for the QC analysis
            results = {
                "ADC_CALIBRATION_SLOPE": round(p1 * 1000, 3),
                "ADC_CALIBRATION_OFFSET": round(p0 * 1000, 3),
            }

            # Perform QC analysis; -1 signals that the analysis itself failed.
            passes_qc = perform_qc_analysis(
                test_type, qc_config, layer, results, verbosity.value
            )
            if passes_qc == -1:
                log.error(
                    bcolors.ERROR
                    + f" QC analysis for {chipname} was NOT successful. Please fix and re-run. Continuing to next chip.."
                    + bcolors.ENDC
                )
                continue
            log.info("")
            verdict_color = bcolors.OKGREEN if passes_qc else bcolors.BADRED
            log.info(
                f" Chip {chipname} passes QC? "
                + verdict_color
                + f"{passes_qc}"
                + bcolors.ENDC
            )
            log.info("")

            # Output a json file
            outputDF = outputDataFrame()
            outputDF.set_test_type(test_type)
            data = qcDataFrame()
            data._meta_data.update(metadata)
            data.add_property(
                "ANALYSIS_VERSION",
                __version__,
            )
            for key, result_value in results.items():
                data.add_parameter(key, result_value)
            outputDF.set_results(data)
            outputDF.set_pass_flag(passes_qc)
            if submit:
                submit_results(
                    outputDF.to_dict(True),
                    time_start,
                    site,
                    output_dir.joinpath("submit.txt"),
                )
            if permodule:
                alloutput.append(outputDF.to_dict(True))
                timestamps.append(meas_timestamp)
            else:
                outfile = output_dir.joinpath(f"{chipname}.json")
                log.info(f" Saving output of analysis to: {outfile}")
                save_dict_list(outfile, [outputDF.to_dict(True)])

    if permodule:
        # Only store results from the same timestamp in the same file.
        _save_permodule_outputs(output_dir, alloutput, timestamps)


if __name__ == "__main__":
    # Invoke the configured Typer app instead of typer.run(main): typer.run
    # builds a fresh app, which would drop the CONTEXT_SETTINGS set on `app`.
    app()
