/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.ensemble;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.TupleIterator;
import si.ijs.kt.clus.ext.ensemble.ClusReadWriteLock;
import si.ijs.kt.clus.ext.ensemble.ros.ClusEnsembleROSInfo;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Base class for ensemble induction that aggregates member-model predictions incrementally over a
 * fixed set of train and test tuples.
 *
 * <p>At construction every tuple read from the train iterator, followed by every tuple read from
 * the test iterator, is assigned a consecutive zero-based position that {@link #locateTuple}
 * returns. Subclasses presumably index their per-tuple prediction state by that position (TODO
 * confirm against subclasses) and fold in each new member model via
 * {@link #updatePredictionsForTuples}.
 *
 * <p>The static helpers implement the arithmetic used for that aggregation: running averages,
 * element-wise sums, majority one-hot encoding, and row normalisation.
 */
public abstract class ClusEnsembleInduceOptimization
implements Serializable {
    private static final long serialVersionUID = 1L;
    public static final int SIGNIFICANT_DIGITS_IN_PREDICTIONS = 4;

    /** Position of every registered train/test tuple (train first, then test), starting at 0. */
    protected HashMap<DataTuple, Integer> m_TuplePositions;
    /** Update counter for subclasses; this class never reads or writes it. */
    protected int m_NbUpdates = 0;
    /** Lock exposed to subclasses, presumably guarding {@link #m_NbUpdates}; not acquired here. */
    protected ClusReadWriteLock m_NbUpdatesLock = new ClusReadWriteLock();
    /** Lock exposed to subclasses, presumably guarding the averaged predictions; not acquired here. */
    protected ClusReadWriteLock m_AvgPredictionsLock = new ClusReadWriteLock();
    /** Random-output-selection info; only read by the subspace-averaging branch of {@link #incrementPredictions(double[][], double[][], int)}. */
    protected ClusEnsembleROSInfo m_EnsembleROSInfo = null;
    private final Settings m_Settings;

    /**
     * Registers the positions of all train and test tuples.
     *
     * <p>Each non-null iterator is {@code init()}-ialised and read to exhaustion; the iterators are
     * not closed here. Train tuples receive positions {@code 0..t-1}, test tuples continue from
     * {@code t}. If a tuple equal (per {@code DataTuple.equals}/{@code hashCode}) to an earlier one
     * is read, its later position overwrites the earlier one while the counter still advances.
     *
     * @param train training tuples, or {@code null} to register none
     * @param test  test tuples, or {@code null} to register none
     * @param sett  settings returned by {@link #getSettings()}
     * @throws IOException    if an iterator fails to initialise or read
     * @throws ClusException  if an iterator reports a CLUS-level error
     */
    public ClusEnsembleInduceOptimization(TupleIterator train, TupleIterator test, Settings sett) throws IOException, ClusException {
        this.m_Settings = sett;
        this.m_TuplePositions = new HashMap<>();
        int next = this.registerTuples(train, 0);
        this.registerTuples(test, next);
    }

    /**
     * Assigns consecutive positions to every tuple produced by {@code tuples}.
     *
     * @param tuples        iterator to drain; {@code null} registers nothing
     * @param firstPosition position given to the first tuple read
     * @return the position that the next registered tuple should receive
     */
    private int registerTuples(TupleIterator tuples, int firstPosition) throws IOException, ClusException {
        int position = firstPosition;
        if (tuples == null) {
            return position;
        }
        tuples.init();
        for (DataTuple tuple = tuples.readTuple(); tuple != null; tuple = tuples.readTuple()) {
            this.m_TuplePositions.put(tuple, position);
            ++position;
        }
        return position;
    }

    /** Returns the settings supplied at construction. */
    public final Settings getSettings() {
        return this.m_Settings;
    }

    /**
     * Returns the position assigned to {@code tuple} at construction.
     *
     * @param tuple a tuple that was read from the train or test iterator given to the constructor
     * @return its zero-based position
     * @throws IllegalArgumentException if the tuple was never registered
     */
    public int locateTuple(DataTuple tuple) {
        Integer position = this.m_TuplePositions.get(tuple);
        if (position == null) {
            // Unboxing a missing entry used to surface as a bare NullPointerException.
            throw new IllegalArgumentException("Tuple is not among the train/test tuples registered at construction");
        }
        return position;
    }

    /**
     * Initialises the subclass's prediction storage.
     *
     * @param stat           statistic used to initialise the predictions (subclass-defined)
     * @param ensembleROSInfo random-output-selection info, possibly {@code null} (TODO confirm)
     */
    public abstract void initPredictions(ClusStatistic stat, ClusEnsembleROSInfo ensembleROSInfo);

    /**
     * Folds the predictions of one more member model into the stored predictions.
     *
     * @param model the newly built member model
     * @param train presumably the training tuples (TODO confirm against subclasses)
     * @param test  presumably the test tuples (TODO confirm against subclasses)
     */
    public abstract void updatePredictionsForTuples(ClusModel model, TupleIterator train, TupleIterator test) throws IOException, ClusException, InterruptedException;

    /**
     * Updates a running average element-wise with one new prediction vector.
     *
     * @param avg_predictions current averages over {@code nb_models - 1} models
     * @param predictions     the new model's predictions; must be at least as long as {@code avg_predictions}
     * @param nb_models       number of models including the new one
     * @return a fresh array of the updated averages, same length as {@code avg_predictions}
     */
    public static double[] incrementPredictions(double[] avg_predictions, double[] predictions, double nb_models) {
        int plength = avg_predictions.length;
        double[] result = new double[plength];
        for (int i = 0; i < plength; ++i) {
            result[i] = ClusEnsembleInduceOptimization.computeNextAverage(avg_predictions[i], predictions[i], nb_models);
        }
        return result;
    }

    /**
     * Updates per-target running averages with the predictions of the {@code nb_models}-th model.
     *
     * <p>Without ROS subspace averaging, every row is averaged over {@code nb_models}. With ROS
     * enabled and the voting type {@code SubspaceAveraging}, only targets enabled in the newest
     * model's subspace (subspace index {@code nb_models - 1}) are updated, each averaged over
     * {@code m_EnsembleROSInfo.getCoverageOpt(i)} (presumably the number of models covering target
     * {@code i} so far — TODO confirm). Rows that are not updated are returned as the same array
     * object as in {@code sum_predictions}, not copied.
     *
     * @param sum_predictions current averages, one row per target
     * @param predictions     the new model's predictions, same shape
     * @param nb_models       number of models including the new one
     * @return the updated averages, one row per target
     */
    public double[][] incrementPredictions(double[][] sum_predictions, double[][] predictions, int nb_models) {
        double[][] result = new double[sum_predictions.length][];
        if (this.isSubspaceAveraging()) {
            int[] enabled = this.m_EnsembleROSInfo.getOnlyTargets(this.m_EnsembleROSInfo.getModelSubspace(nb_models - 1));
            for (int i = 0; i < sum_predictions.length; ++i) {
                if (enabled[i] == 1) {
                    result[i] = ClusEnsembleInduceOptimization.incrementPredictions(sum_predictions[i], predictions[i], this.m_EnsembleROSInfo.getCoverageOpt(i));
                } else {
                    // Target outside the new model's subspace: keep its previous averages untouched.
                    result[i] = sum_predictions[i];
                }
            }
        } else {
            for (int i = 0; i < sum_predictions.length; ++i) {
                result[i] = ClusEnsembleInduceOptimization.incrementPredictions(sum_predictions[i], predictions[i], nb_models);
            }
        }
        return result;
    }

    /** Whether ROS is enabled with the {@code SubspaceAveraging} voting type. */
    private boolean isSubspaceAveraging() {
        return this.getSettings().getEnsemble().isEnsembleROSEnabled()
                && this.getSettings().getEnsemble().getEnsembleROSVotingType().equals(SettingsEnsemble.EnsembleROSVotingType.SubspaceAveraging);
    }

    /**
     * Returns the average of {@code nbValues} values given the average of the first
     * {@code nbValues - 1} of them and the next value.
     */
    private static double computeNextAverage(double currentAverage, double nextValue, double nbValues) {
        // Kept in this exact form so floating-point results are unchanged.
        return nextValue / nbValues + currentAverage * (nbValues - 1.0) / nbValues;
    }

    /**
     * Adds two per-target prediction matrices element-wise.
     *
     * @param sum_predictions running sums, one row per target
     * @param predictions     values to add; each row at least as long as the matching sum row
     * @return a fresh matrix of the sums, shaped like {@code sum_predictions}
     */
    public static double[][] incrementPredictions(double[][] sum_predictions, double[][] predictions) {
        double[][] result = new double[sum_predictions.length][];
        for (int i = 0; i < sum_predictions.length; ++i) {
            result[i] = new double[sum_predictions[i].length];
            for (int j = 0; j < sum_predictions[i].length; ++j) {
                result[i][j] = sum_predictions[i][j] + predictions[i][j];
            }
        }
        return result;
    }

    /**
     * Converts per-target counts into one-hot majority vectors.
     *
     * <p>In each row the first index holding the strictly largest value is set to 1.0 and every
     * other entry to 0.0. A row with no value greater than negative infinity (empty, or all entries
     * NaN or {@code -Infinity}) is left all zero.
     *
     * @param counts per-target counts, one row per target
     * @return a fresh matrix of the same shape
     */
    public static double[][] transformToMajority(double[][] counts) {
        double[][] result = new double[counts.length][];
        for (int target = 0; target < counts.length; ++target) {
            double[] row = counts[target];
            result[target] = new double[row.length];
            int winner = -1;
            double best = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < row.length; ++j) {
                // Strict '>' keeps the first index on ties and skips NaN.
                if (row[j] > best) {
                    winner = j;
                    best = row[j];
                }
            }
            if (winner >= 0) {
                result[target][winner] = 1.0;
            }
        }
        return result;
    }

    /**
     * Normalises each row of {@code counts} by its sum.
     *
     * <p>A row summing to 0 yields NaN (or infinite) entries; callers must guard against that.
     *
     * @param counts per-target counts, one row per target
     * @return a fresh matrix of the same shape whose rows each sum to 1 (when the row sum is non-zero)
     */
    public static double[][] transformToProbabilityDistribution(double[][] counts) {
        double[][] result = new double[counts.length][];
        for (int target = 0; target < counts.length; ++target) {
            double[] row = counts[target];
            double total = 0.0;
            for (double value : row) {
                total += value;
            }
            result[target] = new double[row.length];
            for (int j = 0; j < row.length; ++j) {
                result[target][j] = row[j] / total;
            }
        }
        return result;
    }

    /**
     * Returns a prediction length defined by the subclass.
     *
     * @param var1 subclass-defined index; its meaning is not visible in this class (TODO document)
     */
    public abstract int getPredictionLength(int var1);

    /**
     * Rounds the stored predictions; presumably to {@link #SIGNIFICANT_DIGITS_IN_PREDICTIONS}
     * significant digits (TODO confirm in subclasses).
     */
    public abstract void roundPredictions();
}

