/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.rules;

import java.io.IOException;
import java.util.ArrayList;
import si.ijs.kt.clus.algo.rules.ClusRule;
import si.ijs.kt.clus.algo.rules.ClusRuleSet;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.ext.ensemble.ros.ClusROSModelInfo;
import si.ijs.kt.clus.ext.optiontree.ClusOptionNode;
import si.ijs.kt.clus.ext.optiontree.ClusSplitNode;
import si.ijs.kt.clus.ext.optiontree.MyNode;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;
import si.ijs.kt.clus.main.settings.section.SettingsOutput;
import si.ijs.kt.clus.main.settings.section.SettingsRules;
import si.ijs.kt.clus.model.test.NodeTest;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.tools.optimization.OptimizationAlgorithm;
import si.ijs.kt.clus.util.tools.optimization.OptimizationProblem;
import si.ijs.kt.clus.util.tools.optimization.de.DEAlgorithm;
import si.ijs.kt.clus.util.tools.optimization.gd.GDAlgorithm;

public class ClusRulesFromTree {
    protected boolean m_Validated;
    protected SettingsOutput.ConvertRules m_Mode;

    public ClusRulesFromTree(boolean onlyValidated, SettingsOutput.ConvertRules mode) {
        this.m_Validated = onlyValidated;
        this.m_Mode = mode;
    }

    public ClusRuleSet constructRules(ClusRun cr, ClusNode node, ClusStatManager mgr, boolean computeDispersion, SettingsRules.RulePredictionMethod optimizeRuleWeights) throws ClusException, IOException, InterruptedException {
        ClusRuleSet ruleSet = this.constructRules(node, mgr);
        RowData data = (RowData)cr.getTrainingSet();
        if (optimizeRuleWeights.equals((Object)SettingsRules.RulePredictionMethod.Optimized) || optimizeRuleWeights.equals((Object)SettingsRules.RulePredictionMethod.GDOptimized)) {
            OptimizationAlgorithm optAlg = null;
            OptimizationProblem.OptimizationParameter param = ruleSet.giveFormForWeightOptimization(null, data);
            optAlg = optimizeRuleWeights.equals((Object)SettingsRules.RulePredictionMethod.GDOptimized) ? new GDAlgorithm(mgr, param, ruleSet) : new DEAlgorithm(mgr, param, ruleSet);
            ArrayList<Double> weights = optAlg.optimize();
            for (int j = 0; j < ruleSet.getModelSize(); ++j) {
                ruleSet.getRule(j).setOptWeight(weights.get(j));
            }
            ruleSet.removeLowWeightRules();
        }
        if (computeDispersion) {
            ruleSet.addDataToRules(data);
            ruleSet.computeDispersion(0);
            ruleSet.removeDataFromRules();
            if (cr.getTestIter() != null) {
                RowData testdata = cr.getTestSet();
                ruleSet.addDataToRules(testdata);
                ruleSet.computeDispersion(1);
                ruleSet.removeDataFromRules();
            }
        }
        ruleSet.numberRules();
        return ruleSet;
    }

    public ClusRuleSet constructOptionRules(ClusRun cr, MyNode node, ClusStatManager mgr, boolean computeDispersion, SettingsRules.RulePredictionMethod optimizeRuleWeights) throws ClusException, IOException, InterruptedException {
        ClusRuleSet ruleSet = this.constructOptionRules(node, mgr);
        RowData data = (RowData)cr.getTrainingSet();
        if (optimizeRuleWeights.equals((Object)SettingsRules.RulePredictionMethod.Optimized) || optimizeRuleWeights.equals((Object)SettingsRules.RulePredictionMethod.GDOptimized)) {
            OptimizationAlgorithm optAlg = null;
            OptimizationProblem.OptimizationParameter param = ruleSet.giveFormForWeightOptimization(null, data);
            optAlg = optimizeRuleWeights.equals((Object)SettingsRules.RulePredictionMethod.GDOptimized) ? new GDAlgorithm(mgr, param, ruleSet) : new DEAlgorithm(mgr, param, ruleSet);
            ArrayList<Double> weights = optAlg.optimize();
            System.out.print("The weights for rules from trees:");
            for (int j = 0; j < ruleSet.getModelSize(); ++j) {
                ruleSet.getRule(j).setOptWeight(weights.get(j));
                System.out.print(weights.get(j) + "; ");
            }
            System.out.print("\n");
            ruleSet.removeLowWeightRules();
        }
        if (computeDispersion) {
            ruleSet.addDataToRules(data);
            ruleSet.computeDispersion(0);
            ruleSet.removeDataFromRules();
            if (cr.getTestIter() != null) {
                RowData testdata = cr.getTestSet();
                ruleSet.addDataToRules(testdata);
                ruleSet.computeDispersion(1);
                ruleSet.removeDataFromRules();
            }
        }
        ruleSet.numberRules();
        return ruleSet;
    }

    public ClusRuleSet constructRules(ClusNode node, ClusStatManager mgr) throws ClusException {
        ClusRuleSet ruleSet = new ClusRuleSet(mgr);
        ClusRule init = new ClusRule(mgr);
        this.constructRecursive(node, init, ruleSet);
        boolean useRulesWithTotalAveraging = mgr.getSettings().getRules().getROSAddRulesWithTotalAveraging();
        if (mgr.getSettings().getEnsemble().isEnsembleROSEnabled() && mgr.getSettings().getEnsemble().getEnsembleROSVotingType().equals((Object)SettingsEnsemble.EnsembleROSVotingType.SubspaceAveraging)) {
            ClusROSModelInfo info = node.getROSModelInfo();
            ClusRuleSet rsOriginal = null;
            if (useRulesWithTotalAveraging && info.getTargets().size() != node.getTargetStat().getNbAttributes()) {
                rsOriginal = ruleSet.cloneDeep();
            }
            for (ClusRule rule : ruleSet.getRules()) {
                rule.setROSModelInfo(info);
                rule.postProc();
            }
            if (rsOriginal != null) {
                ruleSet.addRuleSet(rsOriginal);
            }
        }
        ruleSet.removeEmptyRules();
        ruleSet.simplifyRules();
        ruleSet.setTargetStat(node.getTargetStat());
        return ruleSet;
    }

    public ClusRuleSet constructOptionRules(MyNode node, ClusStatManager mgr) {
        ClusRuleSet ruleSet = new ClusRuleSet(mgr);
        ClusRule init = new ClusRule(mgr);
        this.constructRecursiveOption(node, init, ruleSet);
        ruleSet.removeEmptyRules();
        ruleSet.simplifyRules();
        ruleSet.setTargetStat(node.getTargetStat());
        return ruleSet;
    }

    public void constructRecursive(ClusNode node, ClusRule rule, ClusRuleSet set) {
        if ((node.atBottomLevel() || this.m_Mode.equals((Object)SettingsOutput.ConvertRules.AllNodes)) && (!this.m_Validated || node.getTargetStat().isValidPrediction())) {
            rule.setTargetStat(node.getTargetStat());
            rule.setID(node.getID());
            set.add(rule);
        }
        NodeTest test = node.getTest();
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            NodeTest branchTest = test.getBranchTest(i);
            ClusRule child_rule = rule.cloneRule();
            child_rule.addTest(branchTest);
            this.constructRecursive(child, child_rule, set);
        }
    }

    public void constructRecursiveOption(MyNode node, ClusRule rule, ClusRuleSet set) {
        if (node instanceof ClusOptionNode) {
            for (int i = 0; i < node.getNbChildren(); ++i) {
                this.constructRecursiveOption(node.getChild(i), rule, set);
            }
        } else {
            if ((node.atBottomLevel() || this.m_Mode.equals((Object)SettingsOutput.ConvertRules.AllNodes)) && (!this.m_Validated || node.getTargetStat().isValidPrediction())) {
                rule.setTargetStat(node.getTargetStat());
                rule.setID(node.getID());
                set.add(rule);
            }
            ClusSplitNode testnode = (ClusSplitNode)node;
            NodeTest test = testnode.getTest();
            for (int i = 0; i < testnode.getNbChildren(); ++i) {
                MyNode child = testnode.getChild(i);
                NodeTest branchTest = test.getBranchTest(i);
                ClusRule child_rule = rule.cloneRule();
                child_rule.addTest(branchTest);
                this.constructRecursiveOption(child, child_rule, set);
            }
        }
    }
}

