/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.ensemble;

import java.util.ArrayList;
import java.util.Arrays;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.RegressionStat;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.util.array.MDoubleArrayComparator;

/**
 * An ensemble built by boosting, where every member model is stored together with the beta
 * value supplied in {@link #addModelToForest(ClusModel, double)}.
 *
 * <p>For regression, {@link #predictWeightedRegression(RegressionStat, DataTuple)} combines the
 * member predictions per target attribute with a weighted median, giving model {@code i} the
 * weight {@code log(1 / beta_i)}.
 */
public class ClusBoostingForest extends ClusForest {
    private static final long serialVersionUID = 1L;

    /** Beta value of each member model, in the order the models were added. */
    protected ArrayList<Double> m_BetaI = new ArrayList<>();

    /**
     * Comparator used to sort {@code [prediction, weight]} rows (constructed with index 0,
     * presumably ordering rows by their first column, the prediction). The field is transient,
     * so it is {@code null} after deserialization; read it through {@link #getComparator()}.
     */
    protected transient MDoubleArrayComparator m_Compare = new MDoubleArrayComparator(0);

    /**
     * Creates an empty boosting forest.
     *
     * @param statmgr statistics manager forwarded to the {@link ClusForest} constructor
     */
    public ClusBoostingForest(ClusStatManager statmgr) {
        super(statmgr, null);
    }

    /**
     * Adds a member model together with its boosting beta.
     *
     * @param model the model to add to the forest
     * @param beta the model's beta; its weight in the regression median is {@code log(1 / beta)}
     */
    public void addModelToForest(ClusModel model, double beta) {
        super.addModelToForest(model);
        // Autoboxing (Double.valueOf) replaces the deprecated Double(double) constructor.
        this.m_BetaI.add(beta);
    }

    /**
     * Returns the beta of the {@code i}-th member model.
     *
     * @param i index of the model, in insertion order
     * @return the beta stored for that model
     */
    public double getBetaI(int i) {
        return this.m_BetaI.get(i);
    }

    /**
     * Returns half of the total model weight, i.e. {@code 0.5 * sum_i log(1 / beta_i)}.
     * This is the cumulative weight a prediction must reach to be the weighted median.
     *
     * @return the weighted-median threshold ({@code 0.0} when no betas are stored)
     */
    public double getMedianThreshold() {
        double sum = 0.0;
        for (int i = 0; i < this.m_BetaI.size(); ++i) {
            sum += Math.log(1.0 / this.getBetaI(i));
        }
        return 0.5 * sum;
    }

    /**
     * Averages the member predictions with equal weight {@code 1 / nbModels}.
     * NOTE(review): this path does not use the betas; only
     * {@link #predictWeightedRegression(RegressionStat, DataTuple)} does.
     *
     * @param tuple the example to predict
     * @return a fresh statistic holding the averaged prediction
     * @throws ClusException propagated from the member models' {@code predictWeighted}
     * @throws InterruptedException propagated from the member models' {@code predictWeighted}
     */
    @Override
    public ClusStatistic predictWeighted(DataTuple tuple) throws ClusException, InterruptedException {
        ClusStatistic predicted = this.m_Stat.cloneSimple();
        int nbModels = this.getNbModels();
        double weight = 1.0 / (double) nbModels;
        for (int i = 0; i < nbModels; ++i) {
            predicted.addPrediction(this.getModel(i).predictWeighted(tuple), weight);
        }
        predicted.computePrediction();
        return predicted;
    }

    /**
     * Writes, for each target attribute, the weighted median of the member models' numeric
     * predictions into {@code predicted.getNumericPred()} (the array is modified in place).
     * Model {@code i} has weight {@code log(1 / beta_i)}. The chosen value is the first sorted
     * prediction at which the cumulative weight reaches {@link #getMedianThreshold()}.
     *
     * @param predicted statistic whose numeric prediction array receives the result
     * @param tuple the example to predict
     * @throws ClusException propagated from the member models' {@code predictWeighted}
     * @throws InterruptedException propagated from the member models' {@code predictWeighted}
     * @throws IllegalStateException if the forest contains no models
     */
    public void predictWeightedRegression(RegressionStat predicted, DataTuple tuple) throws ClusException, InterruptedException {
        final int nbModels = this.getNbModels();
        if (nbModels == 0) {
            throw new IllegalStateException("Cannot compute a weighted-median prediction: the boosting forest contains no models");
        }
        double[] result = predicted.getNumericPred();
        double[][] treePredictions = new double[nbModels][];
        // The model weights do not depend on the attribute, so compute them once.
        double[] weights = new double[nbModels];
        for (int m = 0; m < nbModels; ++m) {
            RegressionStat pred = (RegressionStat) this.getModel(m).predictWeighted(tuple);
            treePredictions[m] = pred.getNumericPred();
            weights[m] = Math.log(1.0 / this.getBetaI(m));
        }
        double medianThr = this.getMedianThreshold();
        MDoubleArrayComparator comparator = this.getComparator();
        double[][] preds = new double[nbModels][2];
        int nbAttr = predicted.getNbAttributes();
        for (int attr = 0; attr < nbAttr; ++attr) {
            // The previous sort permuted the rows, so refill both columns every time.
            for (int m = 0; m < nbModels; ++m) {
                preds[m][0] = treePredictions[m][attr];
                preds[m][1] = weights[m];
            }
            Arrays.sort(preds, comparator);
            // Accumulate weight along the sorted predictions until the threshold is reached.
            // Bounding j by the last row guarantees termination even when non-positive
            // weights (beta >= 1) keep the running sum below the threshold.
            int j = 0;
            double sum = preds[0][1];
            while (sum < medianThr && j < nbModels - 1) {
                ++j;
                sum += preds[j][1];
            }
            result[attr] = preds[j][0];
        }
    }

    /**
     * Returns the row comparator, recreating it if it was lost to deserialization
     * ({@code m_Compare} is transient and its initializer does not run on deserialization).
     */
    private MDoubleArrayComparator getComparator() {
        if (this.m_Compare == null) {
            this.m_Compare = new MDoubleArrayComparator(0);
        }
        return this.m_Compare;
    }

    /**
     * Creates a forest with the same models and betas but a copy of this forest's statistic
     * whose threshold is set to {@code threshold}.
     * NOTE(review): the beta list is shared with (not copied from) this forest, so adding a
     * model to either forest also changes the other's betas. Whether {@code setModels} copies
     * the model list is not visible here; confirm before mutating a clone.
     *
     * @param threshold the threshold to set on the cloned statistic
     * @return the new forest
     * @throws ClassCastException if this forest's statistic is not a {@link WHTDStatistic}
     */
    public ClusBoostingForest cloneBoostingForestWithThreshold(double threshold) {
        ClusBoostingForest clone = new ClusBoostingForest(this.m_StatManager);
        clone.setModels(this.getModels());
        clone.m_BetaI = this.m_BetaI;
        WHTDStatistic stat = (WHTDStatistic) this.getStat().cloneStat();
        stat.copyAll(this.getStat());
        stat.setThreshold(threshold);
        clone.setStat(stat);
        return clone;
    }
}

