/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.tdidt;

import java.util.Arrays;
import java.util.Random;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.data.ClusData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.error.Accuracy;
import si.ijs.kt.clus.error.PearsonCorrelation;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.common.ClusErrorOutput;
import si.ijs.kt.clus.error.hmlc.HierClassWiseAccuracy;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.ClusSummary;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsSIT;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.selection.XValMainSelection;
import si.ijs.kt.clus.selection.XValRandomSelection;
import si.ijs.kt.clus.selection.XValSelection;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.ResourceInfo;
import si.ijs.kt.clus.util.jeans.util.IntervalCollection;

/**
 * Decision-tree learner that searches for a good set of target attributes to switch on in the
 * stat manager's clustering weights when learning trees for one designated main target.
 *
 * <p>Candidate weight vectors are scored with a cross-validation estimate of the main target's
 * error component (see {@link #doParamXVal}). Higher scores are treated as better, which matches
 * the accuracy / correlation measures built by {@link #createTuneError}.
 *
 * <p>The search strategies ({@link #findBestSupportTasks}, {@link #twoSidedSit},
 * {@link #superSit}, {@link #sweepSit}) leave their selected weight vector installed in
 * {@code getStatManager().getClusteringWeights().m_Weights}. All of them mutate that shared array
 * without synchronization.
 */
public class ClusSITDecisionTree extends ClusDecisionTree {
    /** Number of cross-validation folds used to score one candidate weight vector. */
    private static final int XVAL_FOLDS = 10;
    /** Fixed seed: every candidate is scored on the same fold split of the same training set. */
    private static final long XVAL_SEED = 0L;
    /**
     * Model-info slot the per-fold models are stored in and read back from
     * (presumably the pruned-model slot -- TODO confirm against the ClusModel index constants).
     */
    private static final int XVAL_MODEL_INFO = 2;
    /** Largest number of targets {@link #exhaustiveSearch} can enumerate before {@code 1 << n} overflows. */
    private static final int MAX_EXHAUSTIVE_TARGETS = 30;

    protected ClusInductionAlgorithmType m_Class;

    /**
     * Creates the SIT learner.
     *
     * @param clss induction algorithm used both for the per-fold models during the search and for
     *     the final model in {@link #induceAll}
     */
    public ClusSITDecisionTree(ClusInductionAlgorithmType clss) {
        super(clss.getClus());
        this.m_Class = clss;
    }

    @Override
    public void printInfo() {
        ClusLogger.info("---------SIT---------");
        ClusLogger.info("Heuristic: " + this.getStatManager().getHeuristicName());
    }

    /**
     * Builds the error list used to score candidate weight vectors.
     *
     * <p>Hierarchical mode uses a single {@link HierClassWiseAccuracy}. Otherwise, an
     * {@link Accuracy} is added for nominal targets and then a {@link PearsonCorrelation} for numeric
     * targets, each only if such targets exist.
     * NOTE(review): {@link #doParamXVal} only uses the first error in this list, so with mixed
     * nominal+numeric targets only the accuracy measure is consulted -- confirm this is intended.
     *
     * @param mgr stat manager providing the target mode, hierarchy and schema
     * @return a new error list
     */
    public ClusErrorList createTuneError(ClusStatManager mgr) {
        ClusErrorList parent = new ClusErrorList();
        if (mgr.getTargetMode() == ClusStatManager.Mode.HIERARCHICAL) {
            parent.addError(new HierClassWiseAccuracy(parent, mgr.getHier()));
            return parent;
        }
        NumericAttrType[] num = mgr.getSchema().getNumericAttrUse(ClusAttrType.AttributeUseType.Target);
        NominalAttrType[] nom = mgr.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target);
        if (nom.length != 0) {
            parent.addError(new Accuracy(parent, nom));
        }
        if (num.length != 0) {
            parent.addError(new PearsonCorrelation(parent, num));
        }
        return parent;
    }

    /** Prints the 1-based fold number as a space-separated progress indicator. */
    private static void showFold(int fold) {
        if (fold != 0) {
            System.out.print(" ");
        }
        System.out.print(String.valueOf(fold + 1));
        System.out.flush();
    }

    /**
     * Estimates the quality of the currently installed clustering weights by
     * {@value #XVAL_FOLDS}-fold cross-validation on {@code trset}.
     *
     * <p>Verbosity is set to 0 for the duration of the run. It is restored afterwards, even when a
     * fold throws.
     *
     * @param trset training data that is split into folds
     * @param pruneset prune set handed to each fold's partitioning
     * @return the first error of the summary's test-error list (see {@link #createTuneError})
     * @throws Exception propagated from partitioning, induction or error calculation
     */
    public ClusError doParamXVal(ClusData trset, ClusData pruneset) throws Exception {
        int prevVerb = this.getSettings().getGeneral().enableVerbose(0);
        ClusError err;
        try {
            ClusStatManager mgr = this.getStatManager();
            ClusSummary summ = new ClusSummary();
            summ.setTestError(this.createTuneError(mgr));
            Random random = new Random(XVAL_SEED);
            XValRandomSelection sel = new XValRandomSelection(trset.getNbRows(), XVAL_FOLDS, random);
            for (int fold = 0; fold < XVAL_FOLDS; ++fold) {
                showFold(fold);
                XValSelection msel = new XValSelection(sel, fold);
                ClusRun cr = this.m_Clus.partitionDataBasic(trset, msel, pruneset, summ, fold + 1);
                ClusModel model = this.m_Class.induceSingle(cr);
                cr.addModelInfo(XVAL_MODEL_INFO).setModel(model);
                this.m_Clus.calcError(cr, summ, null);
            }
            ClusModelInfo mi = summ.getModelInfo(XVAL_MODEL_INFO);
            err = mi.getTestError().getFirstError();
        } finally {
            // Restore on every path so a failing fold cannot leave logging silenced.
            this.getSettings().getGeneral().enableVerbose(prevVerb);
        }
        ClusLogger.info();
        return err;
    }

    /** Zeroes all clustering weights, then switches on only {@code mainTarget}. */
    private void resetWeights(int mainTarget) {
        this.resetWeights();
        this.getStatManager().getClusteringWeights().m_Weights[mainTarget] = 1.0;
    }

    /** Sets every clustering weight to 0.0 in place. */
    private void resetWeights() {
        Arrays.fill(this.getStatManager().getClusteringWeights().m_Weights, 0.0);
    }

    /** Returns a copy of the currently installed clustering weights. */
    private double[] currentWeightsCopy() {
        return this.getStatManager().getClusteringWeights().m_Weights.clone();
    }

    /** Returns the 0-based index of the configured main target (settings store it 1-based). */
    private int getMainTargetIndex() {
        return Integer.valueOf(this.getStatManager().getSettings().getSIT().getMainTarget()) - 1;
    }

    /** Returns the inclusive 0-based index range {@code {min, max}} of the configured targets. */
    private int[] getSupportRange() {
        IntervalCollection targets =
                new IntervalCollection(this.getStatManager().getSettings().getAttribute().getTarget());
        return new int[] {targets.getMinIndex() - 1, targets.getMaxIndex() - 1};
    }

    /** Prints {@code prefix} followed by the 1-based indices whose weight is exactly 1.0 (no newline). */
    private static void printSelectedTargets(String prefix, double[] weights) {
        StringBuilder line = new StringBuilder(prefix);
        for (int j = 0; j < weights.length; ++j) {
            if (weights[j] == 1.0) {
                line.append(j + 1).append(' ');
            }
        }
        System.out.print(line);
    }

    /**
     * One greedy step: starting from {@code weights}, tries setting each target in
     * {@code supportRange} (inclusive) that is not already {@code value} to {@code value}, one at a
     * time. The best-scoring vector, which may be the unchanged start, is installed in the stat
     * manager.
     *
     * @param weights starting weight vector (not modified)
     * @param value weight to try on each target: 1.0 adds a target, 0.0 removes one
     * @param emc index of the main target's component in the cross-validated error
     *     (assumes components are indexed from the first target -- TODO confirm)
     * @return the score of the installed vector (higher is better)
     */
    private double toggleBestSupportTask(double[] weights, double value, int emc, int[] supportRange,
            ClusData trset, ClusData pruneset) throws Exception {
        ClusStatManager mgr = this.getStatManager();
        // Score the starting vector itself so the baseline always matches best weights.
        mgr.getClusteringWeights().m_Weights = weights.clone();
        double[] bestWeights = weights.clone();
        double bestScore = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
        printSelectedTargets("Current best Target error: " + bestScore + " for targets ", bestWeights);
        ClusLogger.info();
        for (int i = supportRange[0]; i <= supportRange[1]; ++i) {
            if (weights[i] == value) {
                continue;
            }
            double[] candidate = weights.clone();
            candidate[i] = value;
            mgr.getClusteringWeights().m_Weights = candidate;
            printSelectedTargets("Testing targets: ", candidate);
            ClusLogger.info();
            double score = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
            ClusLogger.info("Correlation: " + score);
            if (score > bestScore) {
                bestScore = score;
                bestWeights = mgr.getClusteringWeights().m_Weights.clone();
            }
            ClusLogger.info();
        }
        ClusLogger.info("Best error: " + bestScore);
        printSelectedTargets("Best targets:", bestWeights);
        ClusLogger.info();
        mgr.getClusteringWeights().m_Weights = bestWeights;
        return bestScore;
    }

    /** Greedy forward step: tries switching on one more target; installs and scores the best. */
    private double addBestSupportTasks(double[] weights, int emc, int[] support_range, ClusData trset, ClusData pruneset) throws Exception {
        return this.toggleBestSupportTask(weights, 1.0, emc, support_range, trset, pruneset);
    }

    /** Greedy backward step: tries switching off one target; installs and scores the best. */
    private double substractBestSupportTasks(double[] weights, int emc, int[] support_range, ClusData trset, ClusData pruneset) throws Exception {
        return this.toggleBestSupportTask(weights, 0.0, emc, support_range, trset, pruneset);
    }

    /**
     * Forward selection starting from the main target alone. It runs one add step and, when SIT
     * recursion is enabled, repeats add steps while the score strictly improves.
     */
    public void findBestSupportTasks(ClusData trset, ClusData pruneset) throws Exception {
        ClusStatManager mgr = this.getStatManager();
        int mainTarget = this.getMainTargetIndex();
        int[] supportRange = this.getSupportRange();
        int emc = mainTarget - supportRange[0];
        boolean recursive = mgr.getSettings().getSIT().getRecursive();
        this.resetWeights(mainTarget);
        double bestScore = this.addBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
        if (recursive) {
            ClusLogger.info("\n---recursive sit---");
            double newScore = this.addBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
            while (newScore > bestScore) {
                bestScore = newScore;
                newScore = this.addBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
            }
        }
        ClusLogger.info();
    }

    /**
     * Compares the main target alone (ST) with all targets switched on (MT). If MT scores higher,
     * targets are removed greedily while the score improves. Otherwise targets are added greedily
     * starting from the main target alone.
     */
    public void twoSidedSit(ClusData trset, ClusData pruneset) throws Exception {
        ClusStatManager mgr = this.getStatManager();
        int mainTarget = this.getMainTargetIndex();
        int[] supportRange = this.getSupportRange();
        int emc = mainTarget - supportRange[0];
        this.resetWeights(mainTarget);
        double stScore = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
        ClusLogger.info("Estimated ST error: " + stScore);
        double[] all = mgr.getClusteringWeights().m_Weights;
        for (int i = supportRange[0]; i <= supportRange[1]; ++i) {
            all[i] = 1.0;
        }
        double mtScore = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
        ClusLogger.info("Estimated MT error: " + mtScore);
        double bestScore;
        if (mtScore > stScore) {
            bestScore = mtScore;
            ClusLogger.info("\n---recursive sub sit---");
            double newScore = this.substractBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
            while (newScore > bestScore) {
                bestScore = newScore;
                newScore = this.substractBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
            }
        } else {
            bestScore = stScore;
            ClusLogger.info("\n---recursive add sit---");
            this.resetWeights(mainTarget);
            double newScore = this.addBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
            while (newScore > bestScore) {
                bestScore = newScore;
                newScore = this.addBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
            }
        }
        ClusLogger.info();
    }

    /**
     * Seeds the search with every target that, paired only with the main target, beats the main
     * target alone. It then alternates an add step and a remove step (each starting from the
     * previous step's result) while the score strictly improves.
     */
    public void superSit(ClusData trset, ClusData pruneset) throws Exception {
        ClusStatManager mgr = this.getStatManager();
        int mainTarget = this.getMainTargetIndex();
        int[] supportRange = this.getSupportRange();
        int emc = mainTarget - supportRange[0];
        this.resetWeights(mainTarget);
        double stScore = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
        double bestScore = stScore;
        ClusLogger.info("Estimated ST error: " + stScore);
        double[] startingWeights = this.currentWeightsCopy();
        for (int i = supportRange[0]; i <= supportRange[1]; ++i) {
            double[] live = mgr.getClusteringWeights().m_Weights;
            double previous = live[i];
            live[i] = 1.0;
            double pairScore = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
            // Undo so the next candidate is again evaluated together with the main target only.
            live[i] = previous;
            if (pairScore > stScore) {
                startingWeights[i] = 1.0;
                ClusLogger.info("Adding target " + (i + 1) + " to starting set");
            }
        }
        mgr.getClusteringWeights().m_Weights = startingWeights;
        printSelectedTargets("Starting from targets ", startingWeights);
        ClusLogger.info("\n---recursive sit---");
        double newScore = this.addThenRemove(startingWeights.clone(), emc, supportRange, trset, pruneset);
        while (newScore > bestScore) {
            bestScore = newScore;
            newScore = this.addThenRemove(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
        }
        ClusLogger.info();
    }

    /** Runs one add step, then one remove step starting from the add step's installed result. */
    private double addThenRemove(double[] weights, int emc, int[] supportRange, ClusData trset,
            ClusData pruneset) throws Exception {
        this.addBestSupportTasks(weights, emc, supportRange, trset, pruneset);
        return this.substractBestSupportTasks(this.currentWeightsCopy(), emc, supportRange, trset, pruneset);
    }

    /**
     * Starting from the main target alone, it repeatedly sweeps over all targets. Each target's
     * weight is flipped and the flip is kept when it strictly improves the score. Sweeping stops
     * after a sweep with no kept flip.
     */
    public void sweepSit(ClusData trset, ClusData pruneset) throws Exception {
        ClusStatManager mgr = this.getStatManager();
        int mainTarget = this.getMainTargetIndex();
        int[] supportRange = this.getSupportRange();
        int emc = mainTarget - supportRange[0];
        this.resetWeights(mainTarget);
        double bestScore = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
        double[] selected = this.currentWeightsCopy();
        printSelectedTargets("Set before sweeping: ", selected);
        ClusLogger.info();
        boolean improved = true;
        while (improved) {
            improved = false;
            for (int i = supportRange[0]; i <= supportRange[1]; ++i) {
                double[] live = mgr.getClusteringWeights().m_Weights;
                boolean adding = live[i] == 0.0;
                double flipped = adding ? 1.0 : 0.0;
                live[i] = flipped;
                double score = this.doParamXVal(trset, pruneset).getModelErrorComponent(emc);
                if (score > bestScore) {
                    bestScore = score;
                    selected[i] = flipped;
                    printSelectedTargets(adding
                            ? "Adding target " + (i + 1) + " to selected set: "
                            : "Removing target " + (i + 1) + " from selected set: ", selected);
                    ClusLogger.info();
                    improved = true;
                }
                // Always continue from the accepted selection, discarding rejected flips.
                mgr.getClusteringWeights().m_Weights = selected.clone();
            }
        }
        printSelectedTargets("Final targets ", selected);
        ClusLogger.info();
    }

    /**
     * Enumerates every on/off combination of the targets' clustering weights and writes each one
     * via {@link ClusErrorOutput#writeOutput}. It then runs one cross-validation with the last
     * combination installed.
     * NOTE(review): {@code errOutput} is never closed here -- confirm whether ClusErrorOutput
     * holds a resource that needs closing.
     *
     * @throws IllegalStateException if there are more than {@value #MAX_EXHAUSTIVE_TARGETS}
     *     targets (the enumeration would overflow an {@code int} mask)
     */
    public void exhaustiveSearch(ClusRun cr) throws Exception {
        ClusStatManager mgr = this.getStatManager();
        Settings settings = mgr.getSettings();
        int[] supportRange = this.getSupportRange();
        int n = supportRange[1] - supportRange[0] + 1;
        if (n > MAX_EXHAUSTIVE_TARGETS) {
            throw new IllegalStateException("Exhaustive search over " + n
                    + " targets is not supported (at most " + MAX_EXHAUSTIVE_TARGETS + ")");
        }
        this.resetWeights();
        double[] weights = mgr.getClusteringWeights().m_Weights;
        ClusErrorOutput errOutput = new ClusErrorOutput(settings.getGeneric().getAppName() + ".err", settings);
        for (int mask = 0; mask < (1 << n); ++mask) {
            for (int b = 0; b < n; ++b) {
                weights[supportRange[0] + b] = (mask & (1 << b)) != 0 ? 1.0 : 0.0;
            }
            ClusLogger.info();
            printSelectedTargets("", weights);
            errOutput.writeOutput(cr, false, false, weights);
        }
        this.doParamXVal(cr.getTrainingSet(), cr.getPruneSet());
    }

    /**
     * Builds the final model with the wrapped algorithm and records the elapsed wall time
     * as the run's induction time.
     */
    @Override
    public void induceAll(ClusRun cr) throws Exception {
        long startTime = ResourceInfo.getTime();
        ClusLogger.info("----------Building final model------------");
        this.m_Class.induceAll(cr);
        // Measure after induction; previously the end time was taken before the work ran.
        long doneTime = ResourceInfo.getTime();
        cr.setInductionTime(doneTime - startTime);
    }
}

