/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.rules.probabilistic;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import java.util.function.Function;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.rules.ClusRuleInduce;
import si.ijs.kt.clus.algo.rules.ClusRuleSet;
import si.ijs.kt.clus.algo.rules.ClusRulesFromTree;
import si.ijs.kt.clus.algo.rules.probabilistic.ClusRuleHelperMethods;
import si.ijs.kt.clus.algo.rules.probabilistic.ClusRuleProbabilisticRuleSetInduceWeights;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.ext.ensemble.ClusEnsembleInduce;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.ext.optiontree.DepthFirstInduceWithOptions;
import si.ijs.kt.clus.ext.optiontree.MyNode;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.ClusSummary;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsOutput;
import si.ijs.kt.clus.main.settings.section.SettingsRules;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.tools.optimization.sls.OptSmoothLocalSearch;

/**
 * Rule induction that first builds a pool of candidate rules (from a random forest or an option
 * tree trained on a random training split) and then selects a subset of them with
 * {@link OptSmoothLocalSearch}, maximizing a weighted objective that is evaluated partly on a
 * held-out validation split.
 *
 * <p>When {@link #estimateWeights} is set, {@link #induceAll(ClusRun)} first grid-searches the
 * weight of the rule-set-size objective and then runs the final selection with the best weights
 * found.
 */
public class ClusRuleProbabilisticRuleSetInduce
extends ClusRuleInduce {
    /** Seed for {@link #m_randGen}; replaced by the general random-seed setting when one is set. */
    long m_seed = 1L;
    /** Rules with a larger model size are dropped from the initial rule set; {@code null} means no limit. */
    Integer m_maxRuleCardinality = null;
    /** Maximum rule-set size handed to {@link OptSmoothLocalSearch}. */
    Integer m_maxRulesNb = null;
    /** Fraction of the learning data used for validation (overwritten from the rules settings). */
    double m_validationRatio = 0.05;
    /** Random generator used for the train/validation split and the local search. */
    Random m_randGen;
    /** The run passed to {@link #induceAll(ClusRun)}. */
    ClusRun m_mainClusRun = null;
    Clus m_mainClus = null;
    /** Training part of the most recent split (see {@link #splitData(double)}). */
    RowData m_trainingData = null;
    /** Validation part of the most recent split (see {@link #splitData(double)}). */
    RowData m_validationData = null;
    /** Training set of {@link #m_mainClusRun} as it was when {@link #induceAll(ClusRun)} started. */
    RowData m_originalFullData = null;
    /** Configured number of bagging sets, restored after the single-bag weight-estimation pass. */
    int m_ensembleSize;
    /**
     * When true, {@link #induceAll(ClusRun)} first searches for the size-objective weight; it is
     * cleared after that search. {@link #getRandomForestRules(RowData)} only registers model infos
     * when this is false.
     */
    boolean estimateWeights = true;
    /** Output folder for the {@link ClusRuleHelperMethods} debug printers. */
    String m_folderName = "rules.debug";

    /**
     * Creates the inducer and reads its parameters from {@code sett}.
     *
     * @param schema data schema
     * @param sett settings providing the seed, rule constraints, validation ratio and ensemble size
     * @param clus main Clus instance, used when building the initial ensemble / option tree
     */
    public ClusRuleProbabilisticRuleSetInduce(ClusSchema schema, Settings sett, Clus clus) throws ClusException, IOException {
        super(schema, sett);
        if (sett.getGeneral().hasRandomSeed()) {
            this.m_seed = sett.getGeneral().getRandomSeed();
        }
        this.m_randGen = new Random(this.m_seed);
        this.m_mainClus = clus;
        this.m_maxRuleCardinality = sett.getRules().getMaxRuleCardinality();
        this.m_maxRulesNb = sett.getRules().getMaxRulesNb();
        this.m_validationRatio = sett.getRules().getValidationSetPercentage();
        this.m_ensembleSize = sett.getEnsemble().getNbBaggingSets().getInt();
    }

    /**
     * Randomly splits {@link #m_originalFullData} into {@link #m_trainingData} and
     * {@link #m_validationData}.
     *
     * <p>Validation examples are drawn without replacement using {@link #m_randGen}. All remaining
     * examples form the training set, kept in their original order. Tuples are cloned, so the
     * original data is not modified.
     *
     * @param validationRatio fraction of examples placed in the validation set; the resulting size
     *     is clamped to {@code [0, number of examples]}
     */
    void splitData(double validationRatio) {
        RowData all_data = (RowData)this.m_originalFullData.cloneData();
        all_data.addIndices();
        int nbRows = all_data.getNbRows();
        // Clamp so the rejection-sampling loop below is guaranteed to terminate.
        int validationSize = Math.max(0, Math.min(nbRows, (int)((double)nbRows * validationRatio)));
        int trainSize = nbRows - validationSize;
        if (trainSize < 1) {
            System.err.println("Validation set is too big!");
        }
        ClusLogger.info(String.format("Splitting learning data: Train set: %s examples | Validation set: %s examples", trainSize, validationSize));
        RowData validationData = new RowData(this.getSchema());
        RowData trainData = new RowData(this.getSchema());
        // O(1) membership test; the former ArrayList.contains made the split quadratic.
        boolean[] inValidation = new boolean[nbRows];
        int drawn = 0;
        while (drawn < validationSize) {
            // nextInt's bound is exclusive: nbRows (not nbRows - 1) lets every example be chosen.
            int candidate = this.m_randGen.nextInt(nbRows);
            if (inValidation[candidate]) {
                continue;
            }
            inValidation[candidate] = true;
            validationData.add(all_data.getTuple(candidate).cloneTuple());
            ++drawn;
        }
        for (int i = 0; i < nbRows; ++i) {
            if (!inValidation[i]) {
                trainData.add(all_data.getTuple(i).cloneTuple());
            }
        }
        this.m_trainingData = trainData;
        this.m_validationData = validationData;
    }

    /**
     * Induces the final rule set and registers it in {@code mainClusRun} as model info 2 ("Rules").
     *
     * @param mainClusRun the run whose training set is used as the full learning data
     */
    @Override
    public void induceAll(ClusRun mainClusRun) throws Exception {
        this.m_mainClusRun = mainClusRun;
        this.m_originalFullData = (RowData)mainClusRun.getTrainingSet();
        ClusLogger.info("Constraints: MaxRuleCardinality = " + this.m_maxRuleCardinality + " MaxRuleSetSize = " + this.m_maxRulesNb);
        ClusRuleProbabilisticRuleSetInduceWeights bestWeights = new ClusRuleProbabilisticRuleSetInduceWeights();
        if (this.estimateWeights) {
            this.searchObjectiveSizeWeight(bestWeights);
            this.estimateWeights = false;
        }
        this.getSettings().getEnsemble().setNbBags(this.m_ensembleSize);
        ClusRuleSet initialRuleSet = this.getInitialRuleSet(this.m_validationRatio);
        initialRuleSet.calculateCoverageBitVectors(this.m_originalFullData);
        this.calculateDefaultRuleAndPrototypesForRuleSet(initialRuleSet);
        bestWeights.WEIGHT_OBJECTIVE_ACCURACY = 100.0;
        Function<ClusRuleSet, Double> objectiveFunction = this.getObjectiveFunction(initialRuleSet, bestWeights);
        ClusRuleSet currentSet = this.runOnce(initialRuleSet, objectiveFunction);
        ClusRuleHelperMethods.debug_PrintRuleSet(currentSet, this.m_folderName, "final_rules.txt", true, false);
        ClusRuleHelperMethods.debug_PrintRuleSet(currentSet, this.m_folderName, "final_tests.txt", false, true);
        this.m_mainClusRun.addModelInfo(0);
        this.m_mainClusRun.addModelInfo(2);
        ClusModelInfo rules_model_info = this.m_mainClusRun.getModelInfo(2);
        rules_model_info.setName("Rules");
        rules_model_info.setModel(currentSet);
        System.err.println("INDUCED: " + currentSet.getModelSize() + "; RMSE: " + currentSet.computeErrorScore(this.m_validationData));
    }

    /**
     * Grid-searches the rule-set-size objective weight (accuracy weight fixed at 1.0) on a
     * single-bag initial rule set and copies the best-scoring weights into {@code bestWeights}.
     *
     * <p>The number of bags is set to 1 for the search and restored to {@link #m_ensembleSize}
     * on exit, including on failure.
     */
    private void searchObjectiveSizeWeight(ClusRuleProbabilisticRuleSetInduceWeights bestWeights) throws Exception {
        ClusLogger.info("ESTIMATING WEIGHTS");
        ClusRuleProbabilisticRuleSetInduceWeights objWeights = new ClusRuleProbabilisticRuleSetInduceWeights();
        this.getSettings().getEnsemble().setNbBags(1);
        try {
            // NOTE(review): this passes 1 - m_validationRatio as the *validation* ratio, so the
            // candidate rules are learned on the small remainder. Confirm this is intended.
            ClusRuleSet initialRuleSet = this.getInitialRuleSet(1.0 - this.m_validationRatio);
            initialRuleSet.calculateCoverageBitVectors(this.m_originalFullData);
            this.calculateDefaultRuleAndPrototypesForRuleSet(initialRuleSet);
            ClusRuleSet bestSet = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            String bestWeightsString = "";
            double lowerBound = 0.0;
            double upperBound = 2.0;
            int upperMax = 1000;
            int lowerMin = -1000;
            double step = 10.0;
            // NOTE(review): lbd is advanced by repeated addition of 0.1, which accumulates
            // rounding error (0.1 added twenty times to 0.0 gives 2.0000000000000004). lbd
            // therefore never equals upperBound, so the window-shifting code at the bottom of the
            // loop never runs and lbd = 2.0 itself is never evaluated. Stepping an integer counter
            // would activate the shifting, which may iterate up to upperMax / lowerMin. Confirm
            // the intended search before changing it.
            for (double lbd = lowerBound; lbd <= upperBound && lowerBound > (double)lowerMin && upperBound < (double)upperMax; lbd += 0.1) {
                objWeights.objectiveAccuracyEnabled = true;
                objWeights.objectiveSizeEnabled = true;
                objWeights.WEIGHT_OBJECTIVE_ACCURACY = 1.0;
                objWeights.WEIGHT_OBJECTIVE_SIZE = lbd;
                Function<ClusRuleSet, Double> objectiveFunction = this.getObjectiveFunction(initialRuleSet, objWeights);
                ClusRuleSet currentSet = this.runOnce(initialRuleSet, objectiveFunction);
                double score = objectiveFunction.apply(currentSet);
                // The decompiled code compared an uninitialized local here; the score is what is meant.
                if (score > bestScore) {
                    bestSet = currentSet;
                    bestScore = score;
                    bestWeightsString = objWeights.getWeightsString();
                    bestWeights.setWeights(objWeights.getWeights());
                }
                System.err.println(lbd + ";" + currentSet.computeErrorScore(this.m_validationData) + ";" + currentSet.getModelSize() + ";" + score);
                if (lbd != upperBound) {
                    continue;
                }
                if (bestWeights.WEIGHT_OBJECTIVE_SIZE == upperBound) {
                    // Best weight sits at the top of the window: move the window up by `step`.
                    lowerBound = upperBound;
                    upperBound += step;
                    lbd = lowerBound - 1.0;
                } else if (bestWeights.WEIGHT_OBJECTIVE_SIZE == 1.0) {
                    // Best weight is 1.0: move the window down by `step`.
                    upperBound = lowerBound;
                    lowerBound -= step;
                    lbd = lowerBound;
                }
            }
            if (bestSet == null) {
                System.err.println("BEST: no rule set selected (all objective values were NaN or -Infinity).");
            } else {
                System.err.println("BEST:\nWeights: " + bestWeightsString + "\nRule set size: " + bestSet.getModelSize() + "\nRMSE: " + bestSet.computeErrorScore(this.m_validationData));
            }
        } finally {
            this.getSettings().getEnsemble().setNbBags(this.m_ensembleSize);
        }
    }

    /**
     * Runs {@link OptSmoothLocalSearch} twice from {@code initialRuleSet}, with two different
     * bias settings, and keeps the result with the higher objective value.
     *
     * @return the selected rule set, with its default rule set via
     *     {@link #calculateDefaultRuleAndPrototypesForRuleSet(ClusRuleSet)}
     */
    ClusRuleSet runOnce(ClusRuleSet initialRuleSet, Function<ClusRuleSet, Double> objectiveFunction) throws ClusException, InterruptedException {
        OptSmoothLocalSearch optAlg = new OptSmoothLocalSearch(this.m_maxRulesNb, this.m_randGen, this.getStatManager(), this.m_mainClusRun, this);
        double bias = 1.0 / 3.0;
        ClusRuleSet optimizedSet1 = optAlg.SmoothLocalSearch(initialRuleSet, bias, bias, objectiveFunction);
        // The second run passes -1.0 as the second bias; its meaning is defined by OptSmoothLocalSearch.
        ClusRuleSet optimizedSet2 = optAlg.SmoothLocalSearch(initialRuleSet, bias, -1.0, objectiveFunction);
        double score1 = objectiveFunction.apply(optimizedSet1);
        double score2 = objectiveFunction.apply(optimizedSet2);
        ClusRuleSet finalSet = score1 > score2 ? optimizedSet1 : optimizedSet2;
        this.calculateDefaultRuleAndPrototypesForRuleSet(finalSet);
        return finalSet;
    }

    /**
     * Sets the rule set's target statistic to a copy of the stat manager's training-set target
     * statistic (with its mean computed), then calls {@code postProc()} on the rule set.
     *
     * @return {@code ruleSet}, modified in place
     */
    public ClusRuleSet calculateDefaultRuleAndPrototypesForRuleSet(ClusRuleSet ruleSet) throws ClusException, InterruptedException {
        ClusStatistic trainTargetStat = this.getStatManager().getTrainSetStat(ClusAttrType.AttributeUseType.Target);
        ClusStatistic defaultRule = trainTargetStat.cloneStat();
        defaultRule.copy(trainTargetStat);
        defaultRule.calcMean();
        ruleSet.setTargetStat(defaultRule);
        ruleSet.postProc();
        return ruleSet;
    }

    /**
     * Builds the weighted objective that the local search maximizes.
     *
     * <p>Each enabled component of {@code weights} contributes {@code weight * component(t)}:
     * <ul>
     *   <li>accuracy: {@code 1 / |validationError + 1|}, or 0.0 if the error cannot be computed;</li>
     *   <li>size: relative reduction of the rule count with respect to {@code initialRuleSet};</li>
     *   <li>cardinality: sum over rules of (largest initial rule size - number of tests);</li>
     *   <li>overlap: sum over rule pairs of (training size - pairwise overlap);</li>
     *   <li>correct cover: number of validation tuples with at least one correctly covering rule;</li>
     *   <li>incorrect cover: (training size * rule count) - incorrect cover across all rules.</li>
     * </ul>
     * The returned function reads {@link #m_validationData} each time it is applied.
     */
    Function<ClusRuleSet, Double> getObjectiveFunction(ClusRuleSet initialRuleSet, final ClusRuleProbabilisticRuleSetInduceWeights weights) {
        int tmpMaxGlobalRuleCardinality = 0;
        for (int r = 0; r < initialRuleSet.getRules().size(); ++r) {
            tmpMaxGlobalRuleCardinality = Math.max(tmpMaxGlobalRuleCardinality, initialRuleSet.getRule(r).getModelSize());
        }
        final int nbTrainingExamples = this.m_trainingData.getNbRows();
        final int maxGlobalRuleCardinality = tmpMaxGlobalRuleCardinality;
        final int initialRuleSetSize = initialRuleSet.getModelSize();
        final Function<ClusRuleSet, Double> objectiveRuleSetSize = t -> {
            if (initialRuleSetSize == 0) {
                // Avoid 0/0 = NaN poisoning the whole objective.
                return 0.0;
            }
            return (double)(initialRuleSetSize - t.getRules().size()) / (double)initialRuleSetSize;
        };
        final Function<ClusRuleSet, Double> objectiveCardinality = t -> {
            double f = 0.0;
            for (int r = 0; r < t.getRules().size(); ++r) {
                f += (double)(maxGlobalRuleCardinality - t.getRule(r).getTests().size());
            }
            return f;
        };
        final Function<ClusRuleSet, Double> objectiveRuleOverlap = t -> {
            double f = 0.0;
            for (int r1 = 0; r1 < t.getRules().size(); ++r1) {
                for (int r2 = r1 + 1; r2 < t.getRules().size(); ++r2) {
                    f += (double)((float)nbTrainingExamples - t.getRule(r1).overlap(t.getRule(r2)));
                }
            }
            return f;
        };
        final Function<ClusRuleSet, Double> objectiveAccuracy = t -> {
            try {
                double errorScore = t.computeErrorScore(this.m_validationData);
                return 1.0 / Math.abs(errorScore + 1.0);
            }
            catch (ClusException e) {
                e.printStackTrace();
                // The objective is maximized: a rule set that cannot be scored must rank worst,
                // not best (it previously returned +Infinity).
                return 0.0;
            }
        };
        final Function<ClusRuleSet, Double> objectiveCorrectCover = t -> {
            int correctCoverCount = 0;
            for (int example = 0; example < this.m_validationData.getNbRows(); ++example) {
                DataTuple tuple = this.m_validationData.getTuple(example);
                if (t.correctCoverRuleCount(tuple) < 1) {
                    continue;
                }
                ++correctCoverCount;
            }
            // Explicit widening: an int cannot be boxed directly to Double.
            return (double)correctCoverCount;
        };
        // Multiply in double so (examples * rules) cannot overflow int.
        final Function<ClusRuleSet, Double> objectiveIncorrectCover = t -> (double)nbTrainingExamples * t.getRules().size() - t.incorrectCoverAcrossAllRules();
        return t -> {
            double sum = 0.0;
            if (weights.isEnabledObjectiveAccuracy()) {
                sum += weights.WEIGHT_OBJECTIVE_ACCURACY * objectiveAccuracy.apply(t);
            }
            if (weights.isEnabledObjectiveSize()) {
                sum += weights.WEIGHT_OBJECTIVE_SIZE * objectiveRuleSetSize.apply(t);
            }
            if (weights.isEnabledObjectiveCardinality()) {
                sum += weights.WEIGHT_OBJECTIVE_CARDINALITY * objectiveCardinality.apply(t);
            }
            if (weights.isEnabledObjectiveOverlap()) {
                sum += weights.WEIGHT_OBJECTIVE_OVERLAP * objectiveRuleOverlap.apply(t);
            }
            if (weights.isEnabledObjectiveCorrectCover()) {
                sum += weights.WEIGHT_OBJECTIVE_CORRECT_COVER * objectiveCorrectCover.apply(t);
            }
            if (weights.isEnabledObjectiveIncorrectCover()) {
                sum += weights.WEIGHT_OBJECTIVE_INCORRECT_COVER * objectiveIncorrectCover.apply(t);
            }
            return sum;
        };
    }

    /**
     * Splits the data and builds the candidate rule pool with the configured generation method.
     *
     * <p>Rules whose model size exceeds {@link #m_maxRuleCardinality} are removed (unless it is
     * {@code null}), and the remaining rules are renumbered and written to the debug folder.
     * Verbosity is lowered to 0 while the pool is built and restored afterwards, even on failure.
     *
     * @param validationRatio forwarded to {@link #splitData(double)}
     * @throws RuntimeException if the method is unknown or no rules were induced
     */
    ClusRuleSet getInitialRuleSet(double validationRatio) throws Exception {
        this.splitData(validationRatio);
        int tmpVerbose = this.getSettings().getGeneral().enableVerbose(0);
        ClusRuleSet initialRules = null;
        try {
            SettingsRules.InitialRuleGeneratingMethod generatingMode = this.getSettings().getRules().getInitialRuleGeneratingMethod();
            switch (generatingMode) {
                case RandomForest: {
                    initialRules = this.getRandomForestRules(this.m_trainingData);
                    break;
                }
                case OptionTree: {
                    initialRules = this.getOptionTreeRules(this.m_trainingData);
                    break;
                }
                default: {
                    throw new RuntimeException("Unknown initial rules generation method.");
                }
            }
        } finally {
            this.getSettings().getGeneral().enableVerbose(tmpVerbose);
        }
        if (initialRules.getModelSize() == 0) {
            throw new RuntimeException(String.format("No rules have been induced! Decrease validation set size (current value: %s).", this.getSettings().getRules().getValidationSetPercentage()));
        }
        if (this.m_maxRuleCardinality != null) {
            // Unbox once; a null limit previously caused a NullPointerException here.
            int maxCardinality = this.m_maxRuleCardinality;
            // Iterate backwards so removals do not shift the rules still to be checked.
            for (int r = initialRules.getModelSize() - 1; r >= 0; --r) {
                if (initialRules.getRule(r).getModelSize() > maxCardinality) {
                    initialRules.remove(initialRules.getRule(r));
                }
            }
        }
        initialRules.numberRules();
        ClusRuleHelperMethods.debug_PrintInitialRules(initialRules, this.m_folderName, "InitialRuleSet", true);
        ClusRuleHelperMethods.debug_PrintInitialRules(initialRules, this.m_folderName, "InitialRuleSet_Reduced", false);
        return initialRules;
    }

    /**
     * Induces a random forest on {@code dataToUse} and converts every node of every tree into rules.
     *
     * <p>Ensemble settings and rule-induction parameters are switched temporarily and restored in
     * {@code finally} blocks. The main run's training set is replaced by {@code dataToUse} and is
     * not restored. When {@link #estimateWeights} is false, the forest's model infos and the rule
     * set are added to {@link #m_mainClusRun}.
     */
    ClusRuleSet getRandomForestRules(RowData dataToUse) throws Exception {
        ClusLogger.info("Inducing random forest for initial rule set");
        ClusLogger.info("==============================================================================");
        boolean ensembleMode = this.getSettings().getEnsemble().isEnsembleMode();
        boolean sectionEnsembleEnabled = this.getSettings().getEnsemble().isSectionEnsembleEnabled();
        this.getSettings().getEnsemble().setSectionEnsembleEnabled(true);
        this.getSettings().getEnsemble().setEnsembleMode(true);
        ClusRun forestRun;
        ClusRuleSet ruleSet;
        try {
            this.getSettings().getRules().disableRuleInduceParams();
            try {
                this.m_mainClusRun.setTrainingSet(dataToUse);
                ClusEnsembleInduce ensemble = new ClusEnsembleInduce((ClusInductionAlgorithm)this, this.m_mainClus);
                forestRun = new ClusRun(this.m_mainClusRun);
                ensemble.induceAll(forestRun);
            } finally {
                this.getSettings().getRules().returnRuleInduceParams();
            }
            ClusForest forestModel = (ClusForest)forestRun.getModel(1);
            ClusRulesFromTree treeTransform = new ClusRulesFromTree(true, SettingsOutput.ConvertRules.AllNodes);
            ruleSet = new ClusRuleSet(this.getStatManager());
            int numberOfUniqueRules = 0;
            for (int iTree = 0; iTree < forestModel.getNbModels(); ++iTree) {
                ClusNode treeRootNode = (ClusNode)forestModel.getModel(iTree);
                numberOfUniqueRules += ruleSet.addRuleSet(treeTransform.constructRules(treeRootNode, this.getStatManager()));
            }
            ClusLogger.info("Transformed " + forestModel.getNbModels() + " trees in ensemble into rules.\n\tCreated " + ruleSet.getModelSize() + " rules. (" + numberOfUniqueRules + " of them are unique.)");
        } finally {
            this.getSettings().getEnsemble().setSectionEnsembleEnabled(sectionEnsembleEnabled);
            this.getSettings().getEnsemble().setEnsembleMode(ensembleMode);
        }
        ClusLogger.info("==============================================================================");
        if (!this.estimateWeights) {
            ClusModelInfo ensemble_model_info_default = forestRun.getModelInfo(0);
            this.m_mainClusRun.addModelInfo(ensemble_model_info_default);
            ClusModelInfo ensemble_model_info_original = forestRun.getModelInfo(1);
            ensemble_model_info_original.setName("RF Ensemble");
            this.m_mainClusRun.addModelInfo(ensemble_model_info_original);
            ClusModelInfo rules_model_info = new ClusModelInfo("Rules");
            rules_model_info.setModel(ruleSet);
            this.m_mainClusRun.addModelInfo(rules_model_info);
        }
        return ruleSet;
    }

    /**
     * Induces an unpruned option tree on {@code dataToUse} and converts it into rules.
     *
     * <p>Registers a default model (info 0) and the option tree (info 1) in
     * {@link #m_mainClusRun}. The option-tree section setting is restored on exit.
     */
    ClusRuleSet getOptionTreeRules(RowData dataToUse) throws ClusException, IOException, InterruptedException {
        boolean sectionOptionEnabled = this.m_mainClus.getSettings().getOptionTree().isSectionOptionEnabled();
        this.m_mainClus.getSettings().getOptionTree().setSectionOptionEnabled(true);
        try {
            // NOTE(review): rule-induce-only mode is switched off and never restored. Confirm this is intended.
            this.m_mainClus.getStatManager().setRuleInduceOnly(false);
            ClusStatManager mngr = new ClusStatManager(this.getSchema(), this.getSettings());
            ClusSummary summary = new ClusSummary();
            summary.setStatManager(mngr);
            // Use the supplied data (the parameter was previously ignored in favour of m_trainingData).
            ClusRun optionRun = new ClusRun(dataToUse, summary);
            DepthFirstInduceWithOptions inducer = new DepthFirstInduceWithOptions(optionRun.getStatManager().getSchema(), optionRun.getStatManager().getSettings());
            inducer.initialize();
            inducer.getStatManager().createClusteringStat();
            inducer.getStatManager().createTargetStat();
            inducer.getStatManager().initClusteringWeights();
            inducer.getStatManager().initBeamSearchHeuristic();
            inducer.getStatManager().initHeuristic();
            inducer.getStatManager().initStopCriterion();
            MyNode optionModel = (MyNode)inducer.induceSingleUnpruned(optionRun);
            ClusModelInfo defInfo = this.m_mainClusRun.addModelInfo(0);
            ClusModel defModel = ClusDecisionTree.induceDefault(this.m_mainClusRun);
            defInfo.setModel(defModel);
            ClusModelInfo origInfo = this.m_mainClusRun.addModelInfo(1);
            origInfo.setName("Option Tree");
            origInfo.setModel(optionModel);
            ClusRulesFromTree treeTransform = new ClusRulesFromTree(true, SettingsOutput.ConvertRules.AllNodes);
            ClusRuleSet ruleSet = new ClusRuleSet(this.getStatManager());
            int numberOfUniqueRules = ruleSet.addRuleSet(treeTransform.constructOptionRules(optionModel, this.getStatManager()));
            for (int i = 0; i < ruleSet.getModelSize(); ++i) {
                ruleSet.getRule(i).getTargetStat().calcMean();
            }
            ClusLogger.info("Transformed an Option Tree into " + ruleSet.getModelSize() + " rules. (" + numberOfUniqueRules + " of them are unique.)");
            return ruleSet;
        } finally {
            this.m_mainClus.getSettings().getOptionTree().setSectionOptionEnabled(sectionOptionEnabled);
        }
    }

    /**
     * Not supported: this inducer only works through {@link #induceAll(ClusRun)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ClusModel induceSingleUnpruned(ClusRun cr) throws ClusException, IOException {
        throw new UnsupportedOperationException("induceSingleUnpruned()");
    }
}

