/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.statistic;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.stream.DoubleStream;
import si.ijs.kt.clus.algo.kNN.KnnModel;
import si.ijs.kt.clus.algo.kNN.methods.SearchAlgorithm;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.statistic.ClassificationStat;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.io.ini.INIFileNominalOrDoubleOrVector;

/**
 * Classification statistic for multi-label k-nearest-neighbour prediction.
 *
 * <p>Every target must be a two-valued nominal attribute (enforced by the constructor).
 * {@link #tryInitializeMLC} estimates, from a training set, a smoothed prior for each label value
 * and, for each requested neighbourhood size, a smoothed table over "how many neighbours have
 * nominal value 0". {@link #getMajorityClass(int)} then picks the label value that maximises
 * prior times table entry.
 *
 * <p>NOTE(review): this looks like the ML-kNN scheme of Zhang and Zhou, and nominal value index 0
 * is presumably the "label is relevant" value - confirm against the schema conventions.
 */
public class KnnMlcStat
extends ClassificationStat {
    /** Smoothed prior of each label value, indexed {@code [label][labelValue]}; null until initialised. */
    private double[][] m_PriorLabelProbabilities;
    /**
     * One map per label, keyed by neighbourhood size k. Each value is indexed
     * {@code [labelValue of the example][number of its k nearest neighbours with nominal value 0]}.
     * Null until initialised.
     */
    private ArrayList<HashMap<Integer, double[][]>> m_NeighbourhoodProbabilities;
    /** True once {@link #tryInitializeMLC} has completed successfully (or was copied by {@link #cloneStat()}). */
    private boolean m_IsInitialized = false;

    /**
     * Creates the statistic and verifies that every target has exactly two class-count slots.
     *
     * @param sett    settings forwarded to the superclass
     * @param nomAtts target attributes forwarded to the superclass
     * @throws RuntimeException if some target does not have exactly two class counts
     */
    public KnnMlcStat(Settings sett, NominalAttrType[] nomAtts) {
        super(sett, nomAtts);
        for (int i = 0; i < this.m_NbTarget; ++i) {
            if (this.m_ClassCounts[i].length != 2) {
                throw new RuntimeException("This is not MLC!");
            }
        }
    }

    /**
     * Unsupported constructor variant: it always throws after delegating to the superclass,
     * because thresholds are not used by this statistic.
     *
     * @param sett                settings forwarded to the superclass
     * @param nomAtts             target attributes forwarded to the superclass
     * @param multiLabelThreshold forwarded to the superclass, otherwise unused
     * @throws RuntimeException always
     */
    public KnnMlcStat(Settings sett, NominalAttrType[] nomAtts, INIFileNominalOrDoubleOrVector multiLabelThreshold) {
        super(sett, nomAtts, multiLabelThreshold);
        throw new RuntimeException("Thresholds do not have any influence here.");
    }

    /**
     * Creates a new statistic over the same settings and attributes.
     *
     * <p>The thresholds array is copied. The probability tables are shared by reference with this
     * instance; nothing in this class modifies them after initialisation completes.
     *
     * @return the new {@code KnnMlcStat}
     */
    @Override
    public ClusStatistic cloneStat() {
        KnnMlcStat res = new KnnMlcStat(this.m_Settings, this.m_Attrs);
        res.m_Training = this.m_Training;
        res.m_ParentStat = this.m_ParentStat;
        if (this.m_Thresholds != null) {
            res.m_Thresholds = Arrays.copyOf(this.m_Thresholds, this.m_Thresholds.length);
        }
        res.m_PriorLabelProbabilities = this.m_PriorLabelProbabilities;
        res.m_NeighbourhoodProbabilities = this.m_NeighbourhoodProbabilities;
        res.m_IsInitialized = this.m_IsInitialized;
        return res;
    }

    /**
     * Estimates the prior and neighbourhood probability tables from the training set.
     *
     * <p>For every training tuple the {@code model.getMaxK()} nearest neighbours are fetched once
     * and scanned in the order returned; whenever the number of neighbours scanned so far equals
     * the next entry of {@code ks}, the running count of neighbours with nominal value 0 is
     * recorded for that k. Because entries of {@code ks} are matched sequentially, {@code ks}
     * must be strictly ascending and each entry must be reachable within the returned neighbour
     * list; otherwise a {@link RuntimeException} is thrown.
     *
     * <p>The tables are built in local variables and only published (together with the
     * initialised flag) after everything succeeded, so a failure part-way through leaves this
     * object uninitialised rather than holding partial, unsmoothed counts.
     *
     * @param ks       neighbourhood sizes to collect statistics for, strictly ascending
     * @param trainSet training examples; its target attributes are cast to {@link NominalAttrType}
     * @param model    supplies the maximal k and the neighbour search
     * @param smoother additive smoothing constant passed to the smoothing formula
     * @throws ClusException    declared for the calls made on {@code trainSet}, {@code model} and the
     *                          neighbour search; which of them throws is not visible here
     * @throws RuntimeException if this statistic is already initialised, or if not every entry of
     *                          {@code ks} could be matched for some training tuple
     */
    public void tryInitializeMLC(int[] ks, RowData trainSet, KnnModel model, double smoother) throws ClusException {
        if (this.m_IsInitialized) {
            throw new RuntimeException("This method was called more than once.");
        }
        int nbExamples = trainSet.getNbRows();
        int maxK = model.getMaxK();
        SearchAlgorithm search = model.getSearch();
        ClusAttrType[] labels = trainSet.getSchema().getAllAttrUse(ClusAttrType.AttributeUseType.Target);
        int nbLabels = labels.length;

        double[][] priors = new double[nbLabels][2];
        ArrayList<HashMap<Integer, double[][]>> neighbourhoods = new ArrayList<>(nbLabels);
        for (int label = 0; label < nbLabels; ++label) {
            HashMap<Integer, double[][]> labelMap = new HashMap<>();
            for (int k : ks) {
                // NOTE(review): every table has maxK + 1 columns, also for k < maxK, so the
                // smoothing below distributes mass over maxK + 1 outcomes instead of k + 1.
                // Left as is to keep numeric results unchanged - confirm whether intended.
                labelMap.put(k, new double[2][maxK + 1]);
            }
            neighbourhoods.add(labelMap);
        }

        // Pass 1: raw counts.
        for (DataTuple dt : trainSet.getData()) {
            LinkedList<DataTuple> nearest = search.returnNNs(dt, maxK);
            for (int label = 0; label < nbLabels; ++label) {
                NominalAttrType attr = (NominalAttrType)labels[label];
                int dtAttrValue = attr.getNominal(dt);
                priors[label][dtAttrValue] += 1.0;
                HashMap<Integer, double[][]> labelMap = neighbourhoods.get(label);
                int labelCount = 0;
                int nbNeighs = 0;
                int kIndex = 0;
                for (DataTuple neighbour : nearest) {
                    if (kIndex == ks.length) {
                        // Every requested k has been recorded; also keeps ks[kIndex] in bounds
                        // when ks is empty or the search returns extra neighbours.
                        break;
                    }
                    ++nbNeighs;
                    if (attr.getNominal(neighbour) == 0) {
                        ++labelCount;
                    }
                    if (ks[kIndex] == nbNeighs) {
                        labelMap.get(nbNeighs)[dtAttrValue][labelCount] += 1.0;
                        ++kIndex;
                    }
                }
                if (kIndex != ks.length) {
                    throw new RuntimeException("Not all neighbourhood sizes were analyzed.");
                }
            }
        }

        // Pass 2: turn counts into smoothed relative frequencies, in place.
        for (int label = 0; label < nbLabels; ++label) {
            for (double[][] counts : neighbourhoods.get(label).values()) {
                for (int isRelevant = 0; isRelevant < counts.length; ++isRelevant) {
                    double[] row = counts[isRelevant];
                    // The sum must be taken before any cell of the row is overwritten.
                    double countSum = DoubleStream.of(row).sum();
                    for (int labelCount = 0; labelCount < row.length; ++labelCount) {
                        row[labelCount] = KnnMlcStat.makeSmoother(row[labelCount], countSum, row.length, smoother);
                    }
                }
            }
            double[] prior = priors[label];
            for (int isRelevant = 0; isRelevant < prior.length; ++isRelevant) {
                prior[isRelevant] = KnnMlcStat.makeSmoother(prior[isRelevant], nbExamples, prior.length, smoother);
            }
        }

        // Publish only after everything succeeded.
        this.m_PriorLabelProbabilities = priors;
        this.m_NeighbourhoodProbabilities = neighbourhoods;
        this.m_IsInitialized = true;
    }

    /**
     * Recomputes the majority class of every target via {@link #getMajorityClass(int)} and stores
     * the results in {@code m_MajorityClasses}.
     */
    @Override
    public void calcMean() {
        this.m_MajorityClasses = new int[this.m_NbTarget];
        for (int i = 0; i < this.m_NbTarget; ++i) {
            this.m_MajorityClasses[i] = this.getMajorityClass(i);
        }
    }

    /**
     * Returns the label value (0 or 1) with the larger probability score for the given target.
     *
     * <p>The score is looked up using the rounded class count of value 0 for that target. On a
     * tie the smaller value wins; if neither score compares greater than negative infinity
     * (e.g. both are NaN), -1 is returned.
     *
     * @param attr index of the target
     * @return 0, 1, or -1 as described above
     * @throws IllegalStateException if the tables are not initialised or hold no entry for the
     *                               current {@code m_NbExamples}
     */
    @Override
    public int getMajorityClass(int attr) {
        int majClass = -1;
        double pMax = Double.NEGATIVE_INFINITY;
        int nbRelevantNeighbourhood = (int)Math.round(this.m_ClassCounts[attr][0]);
        for (int i = 0; i < 2; ++i) {
            double p = this.getProbability(attr, i, nbRelevantNeighbourhood);
            if (p > pMax) {
                majClass = i;
                pMax = p;
            }
        }
        return majClass;
    }

    /**
     * Computes prior times neighbourhood-table entry for one label value.
     *
     * <p>The neighbourhood table is selected by {@code m_NbExamples}, i.e. that field is used as
     * the neighbourhood size k. NOTE(review): presumably it holds the number of neighbours that
     * were added to this statistic - confirm with the caller.
     *
     * @param labelIndex index of the target
     * @param labelValue candidate label value (0 or 1)
     * @param classCount column of the table, i.e. number of neighbours with value 0
     * @return the unnormalised score
     * @throws IllegalStateException if the tables are not initialised or no table exists for the
     *                               current {@code m_NbExamples}
     */
    private double getProbability(int labelIndex, int labelValue, int classCount) {
        if (this.m_PriorLabelProbabilities == null || this.m_NeighbourhoodProbabilities == null) {
            throw new IllegalStateException("tryInitializeMLC must be called before probabilities are requested.");
        }
        double[][] conditional = this.m_NeighbourhoodProbabilities.get(labelIndex).get(this.m_NbExamples);
        if (conditional == null) {
            // HashMap.get returns null for a k that was not part of ks during initialisation.
            throw new IllegalStateException("No neighbourhood statistics were computed for k = " + this.m_NbExamples + ".");
        }
        return this.m_PriorLabelProbabilities[labelIndex][labelValue] * conditional[labelValue][classCount];
    }

    /**
     * Additive smoothing: {@code (numerator + s) / (classes * s + denominator)}.
     *
     * <p>Yields NaN when both {@code smoother} and {@code origDenominator} are zero.
     */
    private static double makeSmoother(double origNumerator, double origDenominator, double classes, double smoother) {
        return (origNumerator + smoother) / (classes * smoother + origDenominator);
    }
}

