/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.pruning;

import org.apache.commons.math3.distribution.NormalDistribution;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.model.test.NodeTest;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.statistic.ClassificationStat;
import si.ijs.kt.clus.util.ClusUtil;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Pruner that simplifies a classification tree bottom-up using pessimistic
 * error estimates: for every internal node it compares the estimated errors of
 * (a) turning the node into a leaf, (b) keeping the subtree, and, when subtree
 * raising is enabled, (c) replacing the node by its largest branch.
 *
 * <p>The training data must be supplied through {@link #setTrainingData(RowData)}
 * before {@link #prune(ClusNode)} is called, because the data is re-partitioned
 * along the tests of the tree while pruning.
 */
public class C45Pruner
extends PruneTree {
    /** Data set that is routed through the tree while pruning. */
    RowData m_TrainingData;
    /** When {@code true}, a node may be replaced by its largest branch. */
    boolean m_SubTreeRaising = true;
    /** Confidence factor passed to {@link #addErrs(double, double, double)}. */
    double m_ConfidenceFactor = 0.25;
    /** Cached z-score; valid for the confidence factor in {@link #m_ZScoreConfidence}. */
    double m_ZScore = 0.0;
    /**
     * Confidence factor for which {@link #m_ZScore} was computed. Starts as NaN so
     * that the first lookup always computes the z-score instead of using 0.0.
     */
    private double m_ZScoreConfidence = Double.NaN;

    /**
     * Prunes the tree rooted at {@code node} in place.
     *
     * @param node root of the tree to prune
     * @throws ClusException propagated from the z-score computation or the recursive pruning
     */
    @Override
    public void prune(ClusNode node) throws ClusException {
        this.m_ZScore = this.computeZScore();
        // Remember which confidence factor the cached z-score belongs to.
        this.m_ZScoreConfidence = this.m_ConfidenceFactor;
        node.safePrune();
        node.pruneByTrainErr(null);
        this.pruneC45Recursive(node, this.m_TrainingData);
    }

    /**
     * @return always 1
     */
    @Override
    public int getNbResults() {
        return 1;
    }

    /**
     * Recursively prunes the subtree rooted at {@code node}: children are pruned
     * first, then the node itself is either collapsed into a leaf, replaced by its
     * largest branch, or left as is, depending on the estimated errors.
     *
     * @param node root of the subtree to prune (modified in place)
     * @param data the data reaching {@code node}
     * @throws ClusException propagated from the data/statistic operations
     */
    public void pruneC45Recursive(ClusNode node, RowData data) throws ClusException {
        if (node.atBottomLevel()) {
            return;
        }
        // Prune every child on the part of the data its branch receives.
        NodeTest tst = node.getTest();
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            RowData subset = data.applyWeighted(tst, i);
            this.pruneC45Recursive(child, subset);
        }
        int indexOfLargestBranch = node.getLargestBranchIndex();
        double errorsLargestBranch;
        if (this.m_SubTreeRaising) {
            // Estimate how the largest branch would do if it received ALL data of this node.
            ClusNode largest = (ClusNode)node.getChild(indexOfLargestBranch);
            errorsLargestBranch = this.getEstimatedErrorsForBranch(largest, data);
        } else {
            // Subtree raising disabled: make the branch option never win.
            errorsLargestBranch = Double.MAX_VALUE;
        }
        double errorsLeaf = this.getEstimatedErrorsForDistribution((ClassificationStat)node.getTargetStat());
        double errorsTree = this.getEstimatedErrors(node);
        // The 0.1 slack biases the decision towards the simpler alternative.
        if (ClusUtil.smOrEq(errorsLeaf, errorsTree + 0.1) && ClusUtil.smOrEq(errorsLeaf, errorsLargestBranch + 0.1)) {
            node.makeLeaf();
            return;
        }
        if (ClusUtil.smOrEq(errorsLargestBranch, errorsTree + 0.1)) {
            // Replace this node by its largest branch: take over the branch's test
            // and children, refresh the node for the full data, and prune again.
            ClusNode largest = (ClusNode)node.getChild(indexOfLargestBranch);
            node.makeLeaf();
            node.setTest(largest.getTest());
            node.setNbChildren(largest.getNbChildren());
            for (int i = 0; i < largest.getNbChildren(); ++i) {
                node.setChild(largest.getChild(i), i);
            }
            node.adaptToData(data);
            this.pruneC45Recursive(node, data);
        }
    }

    /**
     * Estimates the number of errors for a class distribution: the observed error
     * of {@code stat} plus the correction from {@link #addErrs(double, double, double)}.
     *
     * @param stat class distribution to evaluate
     * @return the estimated errors, or 0.0 if the total weight of {@code stat} is zero
     */
    public double getEstimatedErrorsForDistribution(ClassificationStat stat) {
        double totalWeight = stat.getTotalWeight();
        if (ClusUtil.eq(totalWeight, 0.0)) {
            return 0.0;
        }
        double nb_incorrect = stat.getError();
        return nb_incorrect + this.addErrs(totalWeight, nb_incorrect, this.m_ConfidenceFactor);
    }

    /**
     * Estimates the errors the subtree rooted at {@code node} would make on
     * {@code data}, recomputing leaf statistics from that data rather than using
     * the statistics stored in the tree.
     *
     * @param node root of the branch to evaluate (not modified; statistics are cloned)
     * @param data data to route through the branch
     * @return sum of the estimated errors over all leaves of the branch
     * @throws ClusException propagated from the data/statistic operations
     */
    public double getEstimatedErrorsForBranch(ClusNode node, RowData data) throws ClusException {
        if (node.atBottomLevel()) {
            // Work on a clone so the statistic stored in the node stays untouched.
            ClassificationStat stat = (ClassificationStat)node.getTargetStat().cloneStat();
            data.calcTotalStatBitVector(stat);
            return this.getEstimatedErrorsForDistribution(stat);
        }
        double sum = 0.0;
        NodeTest tst = node.getTest();
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            RowData subset = data.applyWeighted(tst, i);
            sum += this.getEstimatedErrorsForBranch(child, subset);
        }
        return sum;
    }

    /**
     * Estimates the errors of the subtree rooted at {@code node} from the target
     * statistics already stored in its leaves.
     *
     * @param node root of the subtree to evaluate
     * @return sum of the estimated errors over all leaves of the subtree
     */
    public double getEstimatedErrors(ClusNode node) {
        if (node.atBottomLevel()) {
            return this.getEstimatedErrorsForDistribution((ClassificationStat)node.getTargetStat());
        }
        double sum = 0.0;
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            sum += this.getEstimatedErrors(child);
        }
        return sum;
    }

    /**
     * Computes the number of additional errors to add to {@code e} observed errors
     * out of {@code N} cases to obtain a pessimistic error estimate at confidence
     * factor {@code CF}.
     *
     * <p>The z-score used for the general case is the one that belongs to the
     * {@code CF} passed in; it is cached and only recomputed when {@code CF}
     * differs from the confidence factor of the cached value.
     *
     * @param N total weight (number of cases)
     * @param e observed number of errors
     * @param CF confidence factor; values above 0.5 yield no correction
     * @return the extra errors to add to {@code e}
     */
    public double addErrs(double N, double e, double CF) {
        if (CF > 0.5) {
            // Confidence factor too high to be meaningful: no correction.
            return 0.0;
        }
        if (e < 1.0) {
            // Fewer than one error: closed form for e == 0, and linear
            // interpolation between the values for e == 0 and e == 1 otherwise.
            double base = N * (1.0 - Math.pow(CF, 1.0 / N));
            if (e == 0.0) {
                return base;
            }
            return base + e * (this.addErrs(N, 1.0, CF) - base);
        }
        if (e + 0.5 >= N) {
            // (Almost) everything is an error: never report more errors than cases.
            return Math.max(N - e, 0.0);
        }
        double z = this.zScoreFor(CF);
        // Error rate with a 0.5 continuity correction, then the upper bound r on the
        // error rate derived from it with z.
        double f = (e + 0.5) / N;
        double r = (f + z * z / (2.0 * N) + z * Math.sqrt(f / N - f * f / N + z * z / (4.0 * N * N))) / (1.0 + z * z / N);
        return r * N - e;
    }

    /**
     * Sets the data that {@link #prune(ClusNode)} routes through the tree.
     *
     * @param data training data
     */
    @Override
    public void setTrainingData(RowData data) {
        this.m_TrainingData = data;
    }

    /**
     * Computes the standard-normal quantile at {@code 1 - m_ConfidenceFactor}.
     *
     * @return the z-score for the configured confidence factor
     * @throws ClusException declared for callers; not thrown by the visible code
     */
    public double computeZScore() throws ClusException {
        return new NormalDistribution().inverseCumulativeProbability(1.0 - this.m_ConfidenceFactor);
    }

    /** Returns the z-score for {@code CF}, recomputing the cached value only when it belongs to another confidence factor. */
    private double zScoreFor(double CF) {
        // NaN never compares equal, so an uninitialised cache is always refreshed.
        if (CF != this.m_ZScoreConfidence) {
            this.m_ZScore = new NormalDistribution().inverseCumulativeProbability(1.0 - CF);
            this.m_ZScoreConfidence = CF;
        }
        return this.m_ZScore;
    }
}

