/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.semisupervised.confidence.regression;

import java.util.HashMap;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.error.Accuracy;
import si.ijs.kt.clus.error.RMSError;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.hmlc.HierErrorMeasures;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.ext.semisupervised.confidence.PredictionConfidence;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.section.SettingsHMLC;
import si.ijs.kt.clus.main.settings.section.SettingsSSL;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Prediction-confidence measure for random forests that scores a tuple by a
 * proximity-weighted sum of the out-of-bag (OOB) errors of the originally labeled
 * training tuples.
 *
 * <p>For every labeled tuple (the first {@code m_origLabeledMax} tuples of
 * {@code m_trainingSet}) that has an entry in the forest's proximity map and has OOB
 * predictions, the per-target error of its weighted OOB prediction is multiplied by its
 * proximity value and summed per target attribute. How these raw scores are normalized
 * and aggregated is handled by {@link PredictionConfidence}.
 *
 * <p>Until {@link #setTrainingSet(RowData, int)} is called, {@code m_origLabeledMax} is 0,
 * so no labeled tuples contribute and every score is 0.
 * NOTE(review): setTrainingSet does not reset {@code proximitiesInitialized}; confirm that
 * callers never swap the training set while that flag is set.
 */
public class RForestProximities
extends PredictionConfidence {
    /** Training data; its first {@link #m_origLabeledMax} tuples are the originally labeled ones. */
    RowData m_trainingSet;
    /** Number of leading tuples of {@link #m_trainingSet} that are treated as originally labeled. */
    int m_origLabeledMax;
    /**
     * Set by {@link #calculateOOBConfidenceScores} once the labeled tuples' proximities are
     * initialized, so that a following {@link #calculateConfidenceScores} reuses them instead
     * of initializing again; cleared at the end of {@link #calculateConfidenceScores}.
     */
    boolean proximitiesInitialized = false;

    /**
     * @param statManager       statistics manager providing target mode, schema and settings
     * @param normalizationType how raw scores are normalized (handled by the base class)
     * @param aggregationType   how per-target scores are aggregated (handled by the base class)
     */
    public RForestProximities(ClusStatManager statManager, SettingsSSL.SSLNormalization normalizationType, SettingsSSL.SSLAggregation aggregationType) {
        super(statManager, normalizationType, aggregationType);
    }

    /**
     * Accumulates, per target attribute, the proximity-weighted OOB error of the originally
     * labeled training tuples.
     *
     * <p>The proximity map is read from the forest via {@code getProximities()}. Both callers in
     * this class run one of the forest's {@code ...AndGetProximities} predictions for the scored
     * tuple immediately before calling this method.
     *
     * @param model the forest; must be a {@link ClusForest}
     * @return array of length {@code getNbTargetAttributes()} whose entry {@code k} is the sum,
     *         over eligible labeled tuples, of {@code proximity * errorComponent(k)}
     */
    private double[] calculateExpectedError(ClusModel model) throws ClusException, InterruptedException {
        ClusForest forest = (ClusForest) model;
        HashMap<Integer, Double> proximities = forest.getProximities();
        int nbTargets = this.getNbTargetAttributes();
        double[] expectedOOBE = new double[nbTargets];
        for (int j = 0; j < this.m_origLabeledMax; ++j) {
            DataTuple labeled = this.m_trainingSet.getTuple(j);
            // Only tuples that have a proximity entry AND OOB predictions contribute.
            Double proximity = proximities.get(labeled.getIndex());
            if (proximity == null || !forest.containsOOBForTuple(labeled)) {
                continue;
            }
            // A fresh list per tuple, so the error components describe this single example only.
            ClusErrorList errListOOB = new ClusErrorList();
            ClusError error = this.createErrorMeasure(errListOOB);
            errListOOB.addError(error);
            errListOOB.addExample(labeled, forest.predictWeightedOOB(labeled));
            for (int k = 0; k < nbTargets; ++k) {
                expectedOOBE[k] += proximity * error.getModelErrorComponent(k);
            }
        }
        return expectedOOBE;
    }

    /**
     * Creates the error measure used to evaluate a single OOB prediction, chosen by the
     * statistics manager's target mode and constructed with {@code errList} as its list:
     * {@link HierErrorMeasures} (configured with PooledAUPRC) for hierarchical targets,
     * {@link Accuracy} for classification, and {@link RMSError} for regression and any other mode.
     */
    private ClusError createErrorMeasure(ClusErrorList errList) throws ClusException, InterruptedException {
        switch (this.m_StatManager.getTargetMode()) {
            case HIERARCHICAL:
                return new HierErrorMeasures(errList, this.m_StatManager.getHier(),
                        this.m_StatManager.getSettings().getHMLC().getRecallValues().getDoubleVector(),
                        SettingsHMLC.HierarchyMeasures.PooledAUPRC,
                        this.m_StatManager.getSettings().getOutput().isWriteCurves(),
                        this.m_StatManager.getSettings().getOutput().isGzipOutput());
            case CLASSIFY:
                return new Accuracy(errList, this.m_StatManager.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target));
            case REGRESSION:
            default:
                return new RMSError(errList, this.m_StatManager.getSchema().getNumericAttrUse(ClusAttrType.AttributeUseType.Target));
        }
    }

    /**
     * Scores {@code tuple} using the forest's standard weighted prediction to obtain proximities.
     */
    @Override
    public double[] calculatePerTargetScores(ClusModel model, DataTuple tuple) throws ClusException, InterruptedException {
        ((ClusForest) model).predictWeightedStandardAndGetProximities(tuple);
        return this.calculateExpectedError(model);
    }

    /**
     * Scores {@code tuple} using the forest's OOB weighted prediction to obtain proximities.
     */
    @Override
    public double[] calculatePerTargetOOBScores(ClusForest model, DataTuple tuple) throws ClusException, InterruptedException {
        model.predictWeightedOOBAndGetProximities(tuple);
        return this.calculateExpectedError(model);
    }

    /**
     * Computes confidence scores for {@code unlabeledData}. The labeled tuples' proximities are
     * initialized first unless a preceding {@link #calculateOOBConfidenceScores} already did so.
     * Afterwards the flag is cleared, so the next call initializes them again.
     */
    @Override
    public void calculateConfidenceScores(ClusModel model, RowData unlabeledData) throws ClusException, InterruptedException {
        if (!this.proximitiesInitialized) {
            this.initializeLabeledProximities((ClusForest) model);
        }
        super.calculateConfidenceScores(model, unlabeledData);
        // Reset so the next round re-initializes (presumably the forest is retrained in between).
        this.proximitiesInitialized = false;
    }

    /**
     * Computes OOB confidence scores for {@code data}, initializing the labeled tuples'
     * proximities if needed and marking them as initialized so that a following
     * {@link #calculateConfidenceScores} can reuse them.
     */
    @Override
    public void calculateOOBConfidenceScores(ClusForest model, RowData data) throws ClusException, InterruptedException {
        if (!this.proximitiesInitialized) {
            this.initializeLabeledProximities(model);
        }
        this.proximitiesInitialized = true;
        super.calculateOOBConfidenceScores(model, data);
    }

    /** Calls {@code initializeProximities} on {@code forest} for every originally labeled training tuple. */
    private void initializeLabeledProximities(ClusForest forest) throws ClusException, InterruptedException {
        for (int j = 0; j < this.m_origLabeledMax; ++j) {
            forest.initializeProximities(this.m_trainingSet.getTuple(j));
        }
    }

    /**
     * Sets the training data whose first {@code origLabeledMax} tuples are the originally
     * labeled examples used as the error/proximity reference.
     *
     * @param trainingSet    training data
     * @param origLabeledMax number of leading tuples of {@code trainingSet} treated as labeled
     */
    public void setTrainingSet(RowData trainingSet, int origLabeledMax) {
        this.m_trainingSet = trainingSet;
        this.m_origLabeledMax = origLabeledMax;
    }
}

