/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.error.mlcForHmlc;

import java.io.Serializable;
import si.ijs.kt.clus.error.BinaryPredictionList;
import si.ijs.kt.clus.error.ROCAndPRCurve;
import si.ijs.kt.clus.error.mlcForHmlc.MlcHmlcSubError;

/**
 * Base class for multi-label ROC / precision-recall curve errors.
 *
 * <p>Keeps one {@link BinaryPredictionList} and one {@link ROCAndPRCurve} per target
 * dimension. It summarises them as:
 * <ul>
 *   <li>the mean AUROC over dimensions;</li>
 *   <li>the mean AUPRC over dimensions;</li>
 *   <li>the AUPRC weighted by each dimension's {@code getFrequency()};</li>
 *   <li>the AUPRC of a single curve built from all dimensions' predictions pooled together.</li>
 * </ul>
 *
 * <p>{@link #computeAll()} calls {@code clearData()} / {@code clear()} on the per-dimension
 * lists and curves after using them. Recomputing without new examples would therefore work
 * on cleared data. For that reason the summaries are cached, and {@link #getCurveError}
 * only recomputes them when examples were added through {@link #addExample} since the last
 * computation.
 */
public abstract class MLROCAndPRCurve implements MlcHmlcSubError, Serializable {

    // NOTE(review): the next five fields are not read or written in this class;
    // presumably used by subclasses or kept for the serialized form — TODO confirm.
    protected double m_AreaROC;
    protected double m_AreaPR;
    protected double[] m_Thresholds;
    protected transient boolean m_ExtendPR;
    protected transient BinaryPredictionList m_Values;

    /** Per-dimension prediction lists, indexed like the {@code actual}/{@code predicted} arrays. */
    protected BinaryPredictionList[] m_ClassWisePredictions;

    /** Per-dimension curves; entry {@code i} is built over {@code m_ClassWisePredictions[i]}. */
    protected ROCAndPRCurve[] m_ROCAndPRCurves;

    /** Cached summaries; {@code -1.0} until the first {@link #computeAll()}. */
    protected double m_AverageAUROC = -1.0;
    protected double m_AverageAUPRC = -1.0;
    protected double m_WAvgAUPRC = -1.0;
    protected double m_PooledAUPRC = -1.0;

    /**
     * True when the cached summaries may be out of date. It is set on construction and by
     * {@link #addExample}, and cleared by {@link #computeAll()}.
     *
     * <p>The field is private and transient, so it does not enter the default
     * serialVersionUID. A deserialized instance starts with {@code false} and reports its
     * serialized summaries instead of recomputing them.
     *
     * <p>Code that pushes data into {@link #m_ClassWisePredictions} directly must call
     * {@link #computeAll()} itself.
     */
    private transient boolean m_Stale = true;

    /**
     * Creates empty per-dimension prediction lists, each with its own curve.
     *
     * @param dim number of target dimensions (labels)
     */
    public MLROCAndPRCurve(int dim) {
        this.m_ClassWisePredictions = new BinaryPredictionList[dim];
        this.m_ROCAndPRCurves = new ROCAndPRCurve[dim];
        for (int i = 0; i < dim; ++i) {
            BinaryPredictionList predlist = new BinaryPredictionList();
            this.m_ClassWisePredictions[i] = predlist;
            this.m_ROCAndPRCurves[i] = new ROCAndPRCurve(predlist);
        }
    }

    /**
     * Returns the requested curve summary.
     *
     * <p>Summaries are recomputed first only if new examples arrived since the last
     * computation. The original version recomputed on every call, which ran over the data
     * that {@link #computeAll()} had just cleared.
     *
     * @param typeOfCurve which summary to return
     * @return the cached summary value
     * @throws IllegalArgumentException if {@code typeOfCurve} is not handled
     * @throws NullPointerException if {@code typeOfCurve} is {@code null}
     */
    public double getCurveError(CurveType typeOfCurve) {
        if (this.m_Stale) {
            this.computeAll();
        }
        switch (typeOfCurve) {
            case averageAUROC:
                return this.m_AverageAUROC;
            case averageAUPRC:
                return this.m_AverageAUPRC;
            case weightedAUPRC:
                return this.m_WAvgAUPRC;
            case pooledAUPRC:
                return this.m_PooledAUPRC;
            default:
                throw new IllegalArgumentException("Unknown type of curve: " + typeOfCurve);
        }
    }

    /**
     * Computes every per-dimension curve and the pooled curve, then refreshes the cached
     * summaries.
     *
     * <p>After each dimension's curve is computed, its curve is {@code clear()}ed and its
     * prediction list is {@code clearData()}ed.
     *
     * <p>NOTE(review): the weighted average reads {@code getFrequency()} after
     * {@code clearData()}. This assumes {@code clearData()} keeps whatever counts the
     * frequency is derived from — TODO confirm against {@code BinaryPredictionList}.
     *
     * <p>With zero dimensions, or a zero total frequency, the corresponding averages
     * are 0/0, i.e. NaN.
     */
    public void computeAll() {
        final int dim = this.m_ROCAndPRCurves.length;
        final BinaryPredictionList pooled = new BinaryPredictionList();
        final ROCAndPRCurve pooledCurve = new ROCAndPRCurve(pooled);
        for (int i = 0; i < dim; ++i) {
            this.m_ClassWisePredictions[i].sort();
            this.m_ROCAndPRCurves[i].computeCurves();
            this.m_ROCAndPRCurves[i].clear();
            // Pool before clearData() below, which (by its name) discards this dimension's raw data.
            pooled.add(this.m_ClassWisePredictions[i]);
            this.m_ClassWisePredictions[i].clearData();
        }
        pooled.sort();
        pooledCurve.computeCurves();
        pooledCurve.clear();

        double sumAUROC = 0.0;
        double sumAUPRC = 0.0;
        double sumAUPRCw = 0.0;
        double sumFrequency = 0.0;
        for (int i = 0; i < dim; ++i) {
            final double freq = this.m_ClassWisePredictions[i].getFrequency();
            final double areaPR = this.m_ROCAndPRCurves[i].getAreaPR();
            sumAUROC += this.m_ROCAndPRCurves[i].getAreaROC();
            sumAUPRC += areaPR;
            sumAUPRCw += freq * areaPR;
            sumFrequency += freq;
        }
        this.m_AverageAUROC = sumAUROC / dim;
        this.m_AverageAUPRC = sumAUPRC / dim;
        this.m_WAvgAUPRC = sumAUPRCw / sumFrequency;
        this.m_PooledAUPRC = pooledCurve.getAreaPR();
        this.m_Stale = false;
    }

    /**
     * Records one example's per-dimension labels and scores.
     *
     * @param actual true labels, one per dimension
     * @param predicted predicted scores, indexed like {@code actual}
     * @param predictedThresholded unused by this class
     */
    @Override
    public void addExample(boolean[] actual, double[] predicted, boolean[] predictedThresholded) {
        // Mark stale first, so a partially applied add (e.g. on an index exception) is not cached as current.
        this.m_Stale = true;
        for (int i = 0; i < actual.length; ++i) {
            this.m_ClassWisePredictions[i].addExample(actual[i], predicted[i]);
        }
    }

    /**
     * Merges another instance's per-dimension predictions into this one and recomputes the
     * summaries.
     *
     * @param other must be an {@code MLROCAndPRCurve} with at least as many dimensions
     * @throws ClassCastException if {@code other} is not an {@code MLROCAndPRCurve}
     */
    @Override
    public void add(MlcHmlcSubError other) {
        MLROCAndPRCurve o = (MLROCAndPRCurve) other;
        for (int i = 0; i < this.m_ROCAndPRCurves.length; ++i) {
            this.m_ClassWisePredictions[i].add(o.m_ClassWisePredictions[i]);
        }
        this.computeAll();
    }

    /** Selects which summary {@link #getCurveError} returns. */
    public static enum CurveType {
        averageAUROC,
        averageAUPRC,
        weightedAUPRC,
        pooledAUPRC;

    }
}

