/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.statistic;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import org.apache.commons.math3.distribution.TDistribution;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.ext.ensemble.ClusOOBWeights;
import si.ijs.kt.clus.ext.ensemble.ros.ClusROSForestInfo;
import si.ijs.kt.clus.ext.ensemble.ros.ClusROSModelInfo;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.CombStat;
import si.ijs.kt.clus.util.format.ClusFormat;
import si.ijs.kt.clus.util.format.ClusNumberFormat;
import si.ijs.kt.clus.util.jeans.util.StringUtils;

/**
 * Base class for statistics over a fixed set of numeric target attributes.
 *
 * <p>Subclasses provide the per-attribute mean ({@link #getMean(int)}), the per-attribute value
 * returned by {@link #getSVarS(int)}, and {@link #calcMean(double[])}. On top of those this class
 * derives variances, distances, ensemble votes and the textual/tuple renderings of the prediction.
 *
 * <p>The prediction is the mean vector {@link #m_Means}. It is only allocated eagerly when the
 * {@code onlymean} constructor flag is set; otherwise it stays {@code null} until
 * {@link #calcMean()}, {@link #setMeans(double[])} or one of the {@code vote(...)} methods fills it.
 */
public abstract class RegressionStatBase
extends ClusStatistic {
    public static final long serialVersionUID = 1L;

    /** Number of numeric target attributes handled by this statistic. */
    protected int m_NbAttrs;

    /** Target attribute descriptors; position {@code i} corresponds to {@code m_Means[i]}. */
    protected NumericAttrType[] m_Attrs;

    /** Per-attribute means used as the numeric prediction; may be {@code null} until computed. */
    public double[] m_Means;

    /**
     * Overrides the attribute count. Neither {@link #m_Attrs} nor {@link #m_Means} is resized.
     *
     * @param value the new number of attributes
     */
    public void setNbAttributes(int value) {
        this.m_NbAttrs = value;
    }

    /**
     * Creates a statistic over {@code attrs} without allocating the mean vector.
     *
     * @param sett  settings forwarded to the {@link ClusStatistic} constructor
     * @param attrs the numeric target attributes
     */
    public RegressionStatBase(Settings sett, NumericAttrType[] attrs) {
        this(sett, attrs, false);
    }

    /**
     * Creates a statistic over {@code attrs}.
     *
     * @param sett     settings forwarded to the {@link ClusStatistic} constructor
     * @param attrs    the numeric target attributes; its length becomes the attribute count
     * @param onlymean when {@code true}, the mean vector is allocated immediately (all zeros)
     */
    public RegressionStatBase(Settings sett, NumericAttrType[] attrs, boolean onlymean) {
        super(sett);
        this.m_Attrs = attrs;
        this.m_NbAttrs = attrs.length;
        if (onlymean) {
            this.m_Means = new double[this.m_NbAttrs];
        }
    }

    /** @return the number of numeric target attributes */
    @Override
    public int getNbAttributes() {
        return this.m_NbAttrs;
    }

    /** @return the internal attribute array (not a copy) */
    public NumericAttrType[] getAttributes() {
        return this.m_Attrs;
    }

    /**
     * @param idx position of the attribute within this statistic
     * @return the attribute descriptor at {@code idx}
     */
    public NumericAttrType getAttribute(int idx) {
        return this.m_Attrs[idx];
    }

    /**
     * Adds {@code weight} times the other statistic's means to this statistic's means.
     * Components whose other mean is NaN are skipped, so a missing component does not turn the
     * accumulated value into NaN.
     *
     * @param other  another {@code RegressionStatBase} with at least {@code m_NbAttrs} means
     * @param weight multiplier applied to the other statistic's means
     */
    @Override
    public void addPrediction(ClusStatistic other, double weight) {
        RegressionStatBase otherStat = (RegressionStatBase) other;
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            double otherMean = otherStat.m_Means[i];
            if (!Double.isNaN(otherMean)) {
                this.m_Means[i] += weight * otherMean;
            }
        }
    }

    /** Updates with {@code tuple} using the tuple's own weight; {@code idx} is ignored. */
    @Override
    public void updateWeighted(DataTuple tuple, int idx) {
        this.updateWeighted(tuple, tuple.getWeight());
    }

    /** Does nothing in this base class; the mean vector is produced by {@link #calcMean()}. */
    @Override
    public void computePrediction() {
    }

    /**
     * Fills {@code var1} with the per-attribute means. {@link #calcMean()} passes an array of
     * length {@link #m_NbAttrs}.
     */
    public abstract void calcMean(double[] var1);

    /** Allocates {@link #m_Means} if necessary and fills it via {@link #calcMean(double[])}. */
    @Override
    public void calcMean() {
        if (this.m_Means == null) {
            this.m_Means = new double[this.m_NbAttrs];
        }
        this.calcMean(this.m_Means);
    }

    /** Replaces the mean vector with {@code means} (stored by reference, not copied). */
    public void setMeans(double[] means) {
        this.m_Means = means;
    }

    /** @return the mean of attribute {@code var1} as computed by the subclass */
    public abstract double getMean(int var1);

    /**
     * Returns the subclass-defined per-attribute quantity that {@link #getVariance(int)} divides
     * by the total weight (presumably a weighted sum of squared deviations — TODO confirm).
     */
    public abstract double getSVarS(int var1);

    /**
     * @param i attribute position
     * @return {@code getSVarS(i) / m_SumWeight}, or {@code 0.0} when the total weight is zero
     */
    public double getVariance(int i) {
        return this.m_SumWeight != 0.0 ? this.getSVarS(i) / this.m_SumWeight : 0.0;
    }

    /**
     * Returns {@code sqrt(getSVarS(i) / (m_SumWeight - 1))}.
     *
     * <p>NOTE(review): unlike {@link #getVariance(int)} this divides by (total weight - 1) and is
     * not guarded, so for a total weight of at most 1 the result may be infinite or NaN.
     */
    public double getStandardDeviation(int i) {
        return Math.sqrt(this.getSVarS(i) / (this.m_SumWeight - 1.0));
    }

    /** @return {@link #getSVarS(int)} multiplied by the attribute's weight in {@code scale} */
    public double getScaledSS(int i, ClusAttributeWeights scale) {
        return this.getSVarS(i) * scale.getWeight(this.getAttribute(i));
    }

    /** @return {@link #getVariance(int)} multiplied by the attribute's weight in {@code scale} */
    public double getScaledVariance(int i, ClusAttributeWeights scale) {
        return this.getVariance(i) * scale.getWeight(this.getAttribute(i));
    }

    /** @return the square root of {@link #getScaledVariance(int, ClusAttributeWeights)} */
    public double getRootScaledVariance(int i, ClusAttributeWeights scale) {
        return Math.sqrt(this.getScaledVariance(i, scale));
    }

    /** @return a fresh array holding the root scaled variance of every attribute */
    public double[] getRootScaledVariances(ClusAttributeWeights scale) {
        int nb = this.getNbAttributes();
        double[] res = new double[nb];
        for (int i = 0; i < res.length; ++i) {
            res[i] = this.getRootScaledVariance(i, scale);
        }
        return res;
    }

    /** Not implemented: logs a message to standard error and returns positive infinity. */
    @Override
    public double getDispersion(ClusAttributeWeights scale, RowData data) {
        System.err.println(this.getClass().getName() + ": getDispersion(): Not yet implemented!");
        return Double.POSITIVE_INFINITY;
    }

    /**
     * Two-sided Welch t-test p-value comparing this statistic's mean of attribute {@code att} with
     * the training-set mean of the same attribute.
     *
     * <p>NOTE(review): both variances come from {@link #getVariance(int)}, which divides by the
     * total weight rather than by (n - 1) as the textbook Welch test does — confirm this is
     * intended. The Commons Math {@code TDistribution} constructor rejects non-positive degrees of
     * freedom, which {@link #df} can produce for degenerate weights.
     *
     * @param att          attribute position
     * @param stat_manager source of the training-set statistic (expected to be a {@code CombStat})
     * @return {@code 1 - P(-t < T <= t)} for the computed t statistic
     */
    public double getTTestPValue(int att, ClusStatManager stat_manager) {
        CombStat trainStat = (CombStat) stat_manager.getTrainSetStat();
        double globalMean = trainStat.m_RegStat.getMean(att);
        double globalVar = trainStat.m_RegStat.getVariance(att);
        double globalN = trainStat.getTotalWeight();
        double localMean = this.getMean(att);
        double localVar = this.getVariance(att);
        double localN = this.getTotalWeight();
        double t = Math.abs(localMean - globalMean)
                / Math.sqrt(localVar / localN + globalVar / globalN);
        double degreesOfFreedom = this.df(localVar, globalVar, localN, globalN);
        TDistribution tDistribution = new TDistribution(degreesOfFreedom);
        // Probability mass outside [-t, t], i.e. the two-sided p-value.
        return 1.0 - tDistribution.cumulativeProbability(-t, t);
    }

    /**
     * Welch–Satterthwaite approximation of the degrees of freedom for two samples with variances
     * {@code v1}, {@code v2} and sizes {@code n1}, {@code n2}.
     */
    protected double df(double v1, double v2, double n1, double n2) {
        return (v1 / n1 + v2 / n2) * (v1 / n1 + v2 / n2) / (v1 * v1 / (n1 * n1 * (n1 - 1.0)) + v2 * v2 / (n2 * n2 * (n2 - 1.0)));
    }

    /** @return the mean vector itself (not a copy); may be {@code null} if not yet computed */
    @Override
    public double[] getNumericPred() {
        return this.m_Means;
    }

    /** Numeric statistics have no class names; always returns the empty string. */
    @Override
    public String getPredictedClassName(int idx) {
        return "";
    }

    /** @return the number of numeric target attributes (all attributes here are numeric) */
    @Override
    public int getNbNumericAttributes() {
        return this.m_NbAttrs;
    }

    /** @return {@code getSVarS(scale)} */
    @Override
    public double getError(ClusAttributeWeights scale) {
        return this.getSVarS(scale);
    }

    /** @return {@code getSVarSDiff(scale, other)} */
    @Override
    public double getErrorDiff(ClusAttributeWeights scale, ClusStatistic other) {
        return this.getSVarSDiff(scale, other);
    }

    /** @return {@code sqrt(getSVarS(scale) / getTotalWeight())} (NaN when the weight is zero) */
    public double getRMSE(ClusAttributeWeights scale) {
        return Math.sqrt(this.getSVarS(scale) / this.getTotalWeight());
    }

    /**
     * For each attribute flagged in {@code shouldNormalize}, sets its weight to the reciprocal of
     * its variance, or to {@code 1.0} when the variance is not positive.
     *
     * @param weights         weights to update in place
     * @param shouldNormalize flags indexed by the attribute's {@code getIndex()}, not by its
     *                        position within this statistic
     */
    public void initNormalizationWeights(ClusAttributeWeights weights, boolean[] shouldNormalize) {
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            NumericAttrType attr = this.m_Attrs[i];
            if (shouldNormalize[attr.getIndex()]) {
                double var = this.getVariance(i);
                weights.setWeight(attr, var > 0.0 ? 1.0 / var : 1.0);
            }
        }
    }

    /**
     * Weighted squared distance between {@code tuple} and the mean vector, averaged over the
     * attributes. Returns {@code 0.0} (instead of NaN) when there are no attributes.
     */
    @Override
    public double getSquaredDistance(DataTuple tuple, ClusAttributeWeights weights) {
        int nb = this.getNbAttributes();
        if (nb == 0) {
            return 0.0;
        }
        double sum = 0.0;
        for (int i = 0; i < nb; ++i) {
            NumericAttrType type = this.getAttribute(i);
            double dist = type.getNumeric(tuple) - this.m_Means[i];
            sum += dist * dist * weights.getWeight(type);
        }
        return sum / (double) nb;
    }

    /**
     * Per-attribute squared difference between {@code tuple} and the mean vector.
     * NOTE(review): {@code weights} is currently not used.
     */
    public double[] getPointwiseSquaredDistance(DataTuple tuple, ClusAttributeWeights weights) {
        int nb = this.getNbAttributes();
        double[] distances = new double[nb];
        for (int i = 0; i < nb; ++i) {
            double diff = this.getAttribute(i).getNumeric(tuple) - this.m_Means[i];
            distances[i] = diff * diff;
        }
        return distances;
    }

    /**
     * Renders the means as {@code [m0,m1,...]} with six decimals; attributes are rendered as
     * {@code ?} when the mean vector has not been computed yet.
     */
    @Override
    public String getArrayOfStatistic() {
        ClusNumberFormat fr = ClusFormat.SIX_AFTER_DOT;
        StringBuilder buf = new StringBuilder();
        buf.append("[");
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            if (i != 0) {
                buf.append(",");
            }
            if (this.m_Means == null) {
                buf.append("?");
            } else {
                buf.append(fr.format(this.m_Means[i]));
            }
        }
        buf.append("]");
        return buf.toString();
    }

    /**
     * Comma-separated means in full precision, or {@code ?} per attribute when the mean vector
     * has not been computed yet (same rendering as {@link #getPredictWriterString()}).
     */
    @Override
    public String getPredictString() {
        return this.joinMeansOrUnknown();
    }

    /** Renders {@code [means][variances]} with three decimals, for debugging output. */
    @Override
    public String getDebugString() {
        ClusNumberFormat fr = ClusFormat.THREE_AFTER_DOT;
        StringBuilder buf = new StringBuilder();
        buf.append("[");
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            if (i != 0) {
                buf.append(",");
            }
            buf.append(fr.format(this.getMean(i)));
        }
        buf.append("]");
        buf.append("[");
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            if (i != 0) {
                buf.append(",");
            }
            buf.append(fr.format(this.getVariance(i)));
        }
        buf.append("]");
        return buf.toString();
    }

    /** Writes one line per attribute: padded name followed by {@code [mean,variance]}. */
    @Override
    public void printDistribution(PrintWriter wrt) throws IOException {
        ClusNumberFormat fr = ClusFormat.SIX_AFTER_DOT;
        for (int i = 0; i < this.m_Attrs.length; ++i) {
            wrt.print(StringUtils.printStr(this.m_Attrs[i].getName(), 35));
            wrt.print(" [");
            wrt.print(fr.format(this.getMean(i)));
            wrt.print(",");
            wrt.print(fr.format(this.getVariance(i)));
            wrt.println("]");
        }
    }

    /** Adds a cloned attribute named {@code <prefix>-p-<name>} to {@code schema} per target. */
    @Override
    public void addPredictWriterSchema(String prefix, ClusSchema schema) {
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            ClusAttrType type = this.m_Attrs[i].cloneType();
            type.setName(prefix + "-p-" + type.getName());
            schema.addAttrType(type);
        }
    }

    /** Comma-separated means, or {@code ?} per attribute when the means are not available. */
    @Override
    public String getPredictWriterString() {
        return this.joinMeansOrUnknown();
    }

    /**
     * Joins the means with commas using {@link String#valueOf(double)}, emitting {@code ?} for
     * every attribute when {@link #m_Means} is {@code null}.
     */
    private String joinMeansOrUnknown() {
        StringBuilder buf = new StringBuilder();
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            if (i != 0) {
                buf.append(",");
            }
            if (this.m_Means == null) {
                buf.append("?");
            } else {
                buf.append(String.valueOf(this.m_Means[i]));
            }
        }
        return buf.toString();
    }

    /** Writes every mean into the corresponding numeric attribute of {@code prediction}. */
    @Override
    public void predictTuple(DataTuple prediction) {
        for (int i = 0; i < this.m_NbAttrs; ++i) {
            this.m_Attrs[i].setNumeric(prediction, this.m_Means[i]);
        }
    }

    /** Writes {@code value} into attribute {@code i} of {@code tuple}. */
    @Override
    public void predictTupleOneComponent(DataTuple tuple, int i, double value) {
        NumericAttrType type = this.m_Attrs[i];
        type.setNumeric(tuple, value);
    }

    /**
     * Resets this statistic and sets each mean to the unweighted average of the votes' means.
     *
     * @param votes statistics to average; each must be a {@code RegressionStatBase}
     */
    @Override
    public void vote(ArrayList<ClusStatistic> votes) {
        this.reset();
        this.m_Means = new double[this.m_NbAttrs];
        int nbVotes = votes.size();
        for (ClusStatistic stat : votes) {
            RegressionStatBase vote = (RegressionStatBase) stat;
            for (int target = 0; target < this.m_NbAttrs; ++target) {
                this.m_Means[target] += vote.getMean(target) / (double) nbVotes;
            }
        }
    }

    /**
     * Resets this statistic and combines the votes' means using out-of-bag weights. The
     * configured ensemble voting type selects per-model weights ({@code OOBModelWeighted}) or
     * per-model, per-target weights ({@code OOBTargetWeighted}).
     *
     * @throws RuntimeException for any other voting type (only when {@code votes} is non-empty)
     */
    @Override
    public void vote(ArrayList<ClusStatistic> votes, ClusOOBWeights weights) {
        this.reset();
        SettingsEnsemble.EnsembleVotingType evt = this.getSettings().getEnsemble().getEnsembleVotingType();
        this.m_Means = new double[this.m_NbAttrs];
        int nbVotes = votes.size();
        for (int model = 0; model < nbVotes; ++model) {
            RegressionStatBase vote = (RegressionStatBase) votes.get(model);
            switch (evt) {
                case OOBModelWeighted:
                    for (int target = 0; target < this.m_NbAttrs; ++target) {
                        this.m_Means[target] += vote.getMean(target) * weights.getModelWeight(model);
                    }
                    break;
                case OOBTargetWeighted:
                    for (int target = 0; target < this.m_NbAttrs; ++target) {
                        this.m_Means[target] += vote.getMean(target) * weights.getComponentWeight(model, target);
                    }
                    break;
                default:
                    throw new RuntimeException("OOB voting not defined! si.ijs.kt.clus.statistic.RegressionStatBase.vote(ArrayList<ClusStatistic>, ClusOOBWeights)");
            }
        }
    }

    /**
     * Resets this statistic and combines the votes of a ROS (random output subspace) ensemble.
     * Targets a model learned are divided by that target's coverage; targets listed as not
     * learned are averaged over all votes.
     *
     * @throws RuntimeException for an unsupported ROS algorithm type (only when {@code votes} is
     *                          non-empty)
     */
    @Override
    public void vote(ArrayList<ClusStatistic> votes, ClusROSForestInfo ROSForestInfo) {
        this.reset();
        this.m_Means = new double[this.m_NbAttrs];
        double[] coverage = ROSForestInfo.getCoverage();
        SettingsEnsemble.EnsembleROSAlgorithmType at = this.getSettings().getEnsemble().getEnsembleROSAlgorithmType();
        int nbVotes = votes.size();
        for (int model = 0; model < nbVotes; ++model) {
            RegressionStatBase vote = (RegressionStatBase) votes.get(model);
            ClusROSModelInfo info = ROSForestInfo.getROSModelInfo(model);
            switch (at) {
                case FixedSubspaces:
                case DynamicSubspaces:
                    for (Integer boxed : info.getTargets()) {
                        int target = boxed;
                        this.m_Means[target] += vote.getMean(target) / coverage[target];
                    }
                    for (Integer boxed : ROSForestInfo.getTargetsNotLearned()) {
                        int target = boxed;
                        this.m_Means[target] += vote.getMean(target) / (double) nbVotes;
                    }
                    break;
                default:
                    throw new RuntimeException("ROS algorithm type not defined! si.ijs.kt.clus.statistic.RegressionStatBase.vote(ArrayList<ClusStatistic>, ClusROSForestInfo)");
            }
        }
    }

    /**
     * Resets this statistic and combines ROS ensemble votes using per-model, per-target
     * out-of-bag weights, both for the targets each model learned and for the targets listed as
     * not learned.
     */
    @Override
    public void vote(ArrayList<ClusStatistic> votes, ClusOOBWeights weights, ClusROSForestInfo ROSForestInfo) {
        this.reset();
        this.m_Means = new double[this.m_NbAttrs];
        int nbVotes = votes.size();
        for (int model = 0; model < nbVotes; ++model) {
            RegressionStatBase vote = (RegressionStatBase) votes.get(model);
            ClusROSModelInfo info = ROSForestInfo.getROSModelInfo(model);
            for (Integer boxed : info.getTargets()) {
                int target = boxed;
                this.m_Means[target] += vote.getMean(target) * weights.getComponentWeight(model, target);
            }
            for (Integer boxed : ROSForestInfo.getTargetsNotLearned()) {
                int target = boxed;
                this.m_Means[target] += vote.getMean(target) * weights.getComponentWeight(model, target);
            }
        }
    }
}

