/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.semisupervised;

import java.io.IOException;
import java.io.PrintWriter;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.ClusData;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.ext.semisupervised.ClusSemiSupervisedInduce;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsSSL;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.selection.RandomSelection;
import si.ijs.kt.clus.selection.XValRandomSelection;
import si.ijs.kt.clus.selection.XValSelection;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Semi-supervised wrapper around a {@link ClusInductionAlgorithm} that weights the
 * clustering attributes with a parameter {@code w}: target attributes receive a weight
 * proportional to {@code w}, all other clustering attributes a weight proportional to
 * {@code 1 - w} (see {@link #setWeights}).
 *
 * <p>When more than one candidate value for {@code w} is configured (or internal
 * cross-validation is forced), the value is selected by an internal cross-validation over
 * the labeled training data; otherwise the single configured value is used.
 */
public class ClusSemiSupervisedPCTs
extends ClusSemiSupervisedInduce {
    /** The wrapped induction algorithm that builds the actual models. */
    ClusInductionAlgorithm m_Induce;
    ClusModel m_Model;
    /** Candidate values for the weight parameter {@code w}. */
    double[] m_ParameterValues;
    /** Number of folds the internal cross-validation splits the training set into. */
    int m_InternalXValFolds;
    /** Whether {@code w} is selected by internal cross-validation. */
    boolean m_IsInternalXVal;
    /** Whether models built during internal cross-validation are pruned before evaluation. */
    boolean m_Pruning;
    /** Path of the file the per-weight scores are written to; "NO" (any case) disables it. */
    String m_ScoresPath;
    /** Whether per-weight scores are written to {@link #m_ScoresPath}. */
    boolean m_SaveScores;
    /** Writer for the score file; only opened when {@link #m_SaveScores} is set. */
    PrintWriter writer;
    /** Indices of the internal folds that are actually evaluated for each candidate weight. */
    int[] m_InternalFolds;
    /** Whether the final (main) model is induced; if not, the JVM exits after tuning. */
    boolean m_InduceMain;

    /**
     * Creates the wrapper and reads the semi-supervised settings.
     *
     * @param clss_induce the induction algorithm to wrap
     */
    public ClusSemiSupervisedPCTs(ClusInductionAlgorithm clss_induce) throws ClusException, IOException {
        super(clss_induce);
        this.m_Induce = clss_induce;
        this.initialize(this.getSchema(), this.getSettings());
    }

    /**
     * Copies the semi-supervised learning options from the settings into this object.
     *
     * @param schema the data schema (not used by this method)
     * @param settx  the settings whose SSL section is read
     */
    public void initialize(ClusSchema schema, Settings settx) {
        SettingsSSL sett = settx.getSSL();
        this.m_ParameterValues = sett.getSSLPossibleWeights();
        this.m_InternalXValFolds = sett.getSSLInternalFolds();
        this.m_Pruning = sett.getSSLPruningWhenTuning();
        this.m_ScoresPath = sett.getSSLWeightScoresFile();
        this.m_SaveScores = !this.m_ScoresPath.equalsIgnoreCase("NO");
        // Tuning is only needed when there is a choice to make, unless explicitly forced.
        this.m_IsInternalXVal = sett.shouldForceInternalXVal() || this.m_ParameterValues.length > 1;
        if (this.m_IsInternalXVal && !this.m_SaveScores) {
            ClusLogger.info("It does not make any sense to force running the internal XVAL and not saving the score. Just saying.");
        }
        this.m_PercentageLabeled = (double)sett.getPercentageLabeled() / 100.0;
        this.m_InternalFolds = sett.getInternalFoldIndices();
        this.m_InduceMain = sett.shouldInduceMain();
    }

    /** Delegates initialization to the wrapped induction algorithm. */
    @Override
    public void initialize() throws ClusException, IOException {
        this.m_Induce.initialize();
    }

    /**
     * Partitions the data into labeled and unlabeled parts, selects the weight parameter
     * (by internal cross-validation when enabled), applies the resulting clustering weights
     * and induces the final model with the wrapped algorithm.
     *
     * <p>The given run is not modified: the method works on a copy of the run with a cloned
     * training set. If inducing the main model is disabled in the settings, the JVM is
     * terminated with exit code 0 once the weights have been set.
     *
     * @param crOriginal the run providing the training (and optionally unlabeled/test) data
     * @return the unpruned model induced by the wrapped algorithm
     * @throws IllegalStateException if no internal cross-validation is run and no candidate
     *                               weight is configured
     */
    @Override
    public ClusModel induceSingleUnpruned(ClusRun crOriginal) throws Exception {
        ClusRun cr = new ClusRun(crOriginal);
        cr.setTrainingSet(crOriginal.getTrainingSet().cloneData());
        this.partitionData(cr);
        if (this.m_SaveScores) {
            this.writer = new PrintWriter(this.m_ScoresPath);
        }
        NumericAttrType[] num = this.m_Schema.getNumericAttrUse(ClusAttrType.AttributeUseType.Target);
        NominalAttrType[] nom = this.m_Schema.getNominalAttrUse(ClusAttrType.AttributeUseType.Target);
        // The summed fold score is minimised only when there are numeric but no nominal
        // targets; in every other case it is maximised.
        boolean lowerIsBetter = nom.length == 0 && num.length != 0;
        ClusAttrType[] clustering = this.m_Schema.getAllAttrUse(ClusAttrType.AttributeUseType.Clustering);
        double bestWeight;
        try {
            if (this.m_IsInternalXVal) {
                bestWeight = this.selectWeightByInternalXVal(cr, clustering, lowerIsBetter);
                // For w == 1 the unlabeled examples are not added to the training set.
                if (bestWeight != 1.0) {
                    RowData trainingSet = (RowData)cr.getTrainingSet();
                    this.appendUnlabeledData(trainingSet);
                    cr.setTrainingSet(trainingSet);
                }
                ClusLogger.info();
                ClusLogger.info("Weight parameter w = " + bestWeight + " for SSL-PCT algorithm was selected via " + this.m_InternalXValFolds + "-fold internal cross validation");
                ClusLogger.info();
            } else {
                if (this.m_ParameterValues.length == 0) {
                    throw new IllegalStateException("No candidate weight configured for the SSL-PCT algorithm");
                }
                bestWeight = this.m_ParameterValues[0];
            }
        } finally {
            // Release the score file on every path, including failures inside a fold and
            // runs without internal cross-validation.
            if (this.writer != null) {
                this.writer.close();
            }
        }
        cr.getStatManager().initNormalizationWeights(cr.getStatManager().createStatistic(ClusAttrType.AttributeUseType.All), cr.getTrainingSet());
        ClusAttributeWeights clusteringWeights = cr.getStatManager().getClusteringWeights();
        this.setWeights(clusteringWeights, clustering, bestWeight);
        if (!this.m_InduceMain) {
            ClusLogger.info("Internal folds computed. Exiting now.");
            ClusLogger.info("Done.");
            // NOTE(review): terminates the whole JVM from inside the induction call.
            System.exit(0);
        }
        return this.m_Induce.induceSingleUnpruned(cr);
    }

    /**
     * Evaluates every candidate weight on the configured internal folds, optionally writes
     * the scores to the score file, and returns the weight with the best summed score.
     * On ties the candidate appearing later in {@link #m_ParameterValues} wins.
     *
     * @return the selected weight, or {@code 0.0} if no candidate produced a comparable score
     */
    private double selectWeightByInternalXVal(ClusRun cr, ClusAttrType[] clustering, boolean lowerIsBetter) throws Exception {
        // Start from the worst possible score. Double.MIN_VALUE must not be used here: it
        // is the smallest POSITIVE double, so a score of exactly 0.0 would never beat it.
        double bestError = lowerIsBetter ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        double bestWeight = 0.0;
        XValRandomSelection xvalmain = new XValRandomSelection(cr.getTrainingSet().getNbRows(), this.m_InternalXValFolds);
        for (double weight : this.m_ParameterValues) {
            double error = this.evaluateWeight(cr, xvalmain, clustering, weight);
            if (this.m_SaveScores) {
                this.writer.print(weight);
                this.writer.print(',');
                this.writer.println(error);
            }
            boolean improved = lowerIsBetter ? error <= bestError : error >= bestError;
            if (improved) {
                bestError = error;
                bestWeight = weight;
            }
        }
        if (this.m_SaveScores) {
            // Blank line followed by a summary line holding the selected weight and score.
            this.writer.println();
            this.writer.print(bestWeight);
            this.writer.print(',');
            this.writer.println(bestError);
        }
        return bestWeight;
    }

    /**
     * Builds and evaluates one model per configured internal fold for the given weight.
     *
     * @return the sum of the model errors over all evaluated folds
     */
    private double evaluateWeight(ClusRun cr, XValRandomSelection xvalmain, ClusAttrType[] clustering, double weight) throws Exception {
        double error = 0.0;
        for (int fold : this.m_InternalFolds) {
            ClusLogger.info(String.format("Inducing fold %d for weihgt %f", fold, weight));
            XValSelection sel = new XValSelection(xvalmain, fold);
            ClusRun foldRun = new ClusRun(cr);
            foldRun.setTrainingSet(foldRun.getTrainingSet().cloneData());
            // The tuples selected for this fold become the fold's test set.
            ClusData val = foldRun.getTrainingSet().select(sel);
            foldRun.setTestSet(((RowData)val).getIterator());
            RowData trainingSet = (RowData)foldRun.getTrainingSet();
            if (weight != 1.0) {
                this.appendUnlabeledData(trainingSet);
                foldRun.setTrainingSet(trainingSet);
            }
            this.adjustNumberOfTrees(foldRun);
            foldRun.getStatManager().initNormalizationWeights(foldRun.getStatManager().createStatistic(ClusAttrType.AttributeUseType.All), foldRun.getTrainingSet());
            ClusAttributeWeights clusteringWeights = foldRun.getStatManager().getClusteringWeights();
            this.setWeights(clusteringWeights, clustering, weight);
            ClusModel model = this.m_Induce.induceSingleUnpruned(foldRun);
            if (this.m_Pruning) {
                ClusNode orig = (ClusNode)model;
                orig.numberTree();
                PruneTree pruner = this.m_Induce.getStatManager().getTreePruner(trainingSet);
                pruner.setTrainingData(trainingSet);
                pruner.prune(orig);
            }
            error += this.calculateError(model, foldRun.getTestSet(), foldRun.getTestSet().getNbRows()).getModelError();
        }
        return error;
    }

    /** Adds every tuple of the unlabeled data set to the given data set. */
    private void appendUnlabeledData(RowData target) {
        int nbUnlabeled = this.m_UnlabeledData.getNbRows();
        for (int i = 0; i < nbUnlabeled; ++i) {
            target.add(this.m_UnlabeledData.getTuple(i));
        }
    }

    /**
     * Sets the clustering weights so that target attributes are scaled by {@code weight}
     * and all other clustering attributes by {@code 1 - weight}. Each group is additionally
     * multiplied by the number of clustering attributes divided by the size of the group.
     *
     * @param clusteringWeights the weights object to update
     * @param clustering        the clustering attributes
     * @param weight            the weight parameter {@code w}
     */
    public void setWeights(ClusAttributeWeights clusteringWeights, ClusAttrType[] clustering, double weight) throws ClusException {
        int nbClustering = clustering.length;
        double nbTarget = this.m_Schema.getNbTargetAttributes();
        double nbOther = (double)nbClustering - nbTarget;
        double targetWeight = weight * ((double)nbClustering / nbTarget);
        double otherWeight = (1.0 - weight) * ((double)nbClustering / nbOther);
        if (this.m_StatManager.getClusterMode() == ClusStatManager.Mode.HIER_CLASS_AND_REG) {
            int nbHierClasses = this.m_StatManager.getHier().getTotal();
            for (int i = 0; i < nbClustering; ++i) {
                ClusAttrType attrType = clustering[i];
                if (attrType.getStatus().equals(ClusAttrType.Status.Target)) {
                    // The target attribute itself gets weight 0; the weight is put on the
                    // nbHierClasses positions that follow it.
                    // NOTE(review): presumably those positions hold the hierarchy classes;
                    // later iterations may overwrite positions i+1.. - confirm the layout.
                    clusteringWeights.setWeight(i, 0.0);
                    for (int j = 1; j <= nbHierClasses; ++j) {
                        clusteringWeights.setWeight(i + j, targetWeight);
                    }
                } else {
                    clusteringWeights.setWeight(i, otherWeight);
                }
            }
        } else {
            for (ClusAttrType attrType : clustering) {
                boolean isTarget = attrType.getStatus().equals(ClusAttrType.Status.Target);
                clusteringWeights.setWeight(attrType, isTarget ? targetWeight : otherWeight);
            }
        }
    }

    /**
     * Splits the data of the run into labeled and unlabeled parts.
     *
     * <p>If the run has neither an unlabeled set nor unlabeled training tuples, unlabeled
     * data is simulated by stripping the labels of the training tuples not picked by a
     * random selection of {@code m_PercentageLabeled} of the rows; the original (labeled)
     * copies of those tuples become the test set when the run has none. Otherwise the given
     * unlabeled data is used. With internal cross-validation enabled, labeled and unlabeled
     * tuples are kept in separate sets; without it, they are merged into the training set.
     *
     * @param cr the run whose training (and possibly test) set is replaced
     */
    @Override
    public void partitionData(ClusRun cr) throws IOException, ClusException, InterruptedException {
        this.m_UnlabeledData = cr.getUnlabeledSet();
        this.m_TrainingSet = new RowData(cr.getStatManager().getSchema());
        RowData tempTestSet = new RowData(cr.getStatManager().getSchema());
        RowData tempTrainingSet = (RowData)cr.getTrainingSet();
        if (this.m_UnlabeledData == null && tempTrainingSet.getNbUnlabeled() == 0) {
            // No unlabeled data available at all: simulate it from the training set.
            this.m_UnlabeledData = new RowData(cr.getStatManager().getSchema());
            ClusLogger.info("UnlabeledData not set. Unlabeled examples will be selected from training set (Percentage labeled = " + this.m_PercentageLabeled + ")");
            RandomSelection randomSelection = new RandomSelection(tempTrainingSet.getNbRows(), this.m_PercentageLabeled, cr.getStatManager().getSettings().getGeneral().getRandomSeed());
            if (this.m_IsInternalXVal) {
                // Keep labeled and unlabeled tuples in separate sets.
                for (int i = 0; i < tempTrainingSet.getNbRows(); ++i) {
                    if (randomSelection.isSelected(i)) {
                        this.m_TrainingSet.add(tempTrainingSet.getTuple(i).deepCloneTuple());
                    } else {
                        DataTuple stripped = tempTrainingSet.getTuple(i).deepCloneTuple();
                        stripped.makeUnlabeled();
                        this.m_UnlabeledData.add(stripped);
                        tempTestSet.add(tempTrainingSet.getTuple(i).deepCloneTuple());
                    }
                }
                cr.setTrainingSet(this.m_TrainingSet);
            } else {
                // Replace the non-selected tuples in place by unlabeled copies.
                for (int i = 0; i < tempTrainingSet.getNbRows(); ++i) {
                    if (randomSelection.isSelected(i)) {
                        continue;
                    }
                    tempTestSet.add(tempTrainingSet.getTuple(i).deepCloneTuple());
                    DataTuple stripped = tempTrainingSet.getTuple(i).deepCloneTuple();
                    stripped.makeUnlabeled();
                    tempTrainingSet.setTuple(stripped, i);
                }
                cr.setTrainingSet(tempTrainingSet);
            }
            if (cr.getTestSet() == null) {
                ClusLogger.info("Testing data not set. Semi-supervised learning will be evaluated on unlabeled data.");
                cr.setTestSet(tempTestSet.getIterator());
            }
            return;
        }
        if (this.m_UnlabeledData != null) {
            if (this.m_IsInternalXVal) {
                // Make sure every tuple of the separate unlabeled set really is unlabeled.
                for (int i = 0; i < this.m_UnlabeledData.getNbRows(); ++i) {
                    DataTuple t = this.m_UnlabeledData.getTuple(i);
                    if (t.isUnlabeled()) {
                        continue;
                    }
                    t.makeUnlabeled();
                    this.m_UnlabeledData.setTuple(t, i);
                }
            } else {
                // Merge the unlabeled tuples into the training set and drop the separate set.
                for (int i = 0; i < this.m_UnlabeledData.getNbRows(); ++i) {
                    DataTuple t = this.m_UnlabeledData.getTuple(i);
                    if (!t.isUnlabeled()) {
                        t.makeUnlabeled();
                    }
                    tempTrainingSet.add(t.deepCloneTuple());
                }
                cr.setTrainingSet(tempTrainingSet);
                this.m_UnlabeledData = null;
            }
        }
        if (tempTrainingSet.getNbUnlabeled() > 0 && this.m_IsInternalXVal) {
            // The training set itself contains unlabeled tuples: move them to the unlabeled
            // set so that the training set holds labeled tuples only.
            if (this.m_UnlabeledData == null) {
                this.m_UnlabeledData = new RowData(cr.getStatManager().getSchema());
            }
            for (int i = 0; i < tempTrainingSet.getNbRows(); ++i) {
                DataTuple t = tempTrainingSet.getTuple(i);
                if (t.isUnlabeled()) {
                    this.m_UnlabeledData.add(t.deepCloneTuple());
                } else {
                    this.m_TrainingSet.add(t.deepCloneTuple());
                }
            }
            cr.setTrainingSet(this.m_TrainingSet);
        }
    }

    /**
     * Flags the given run as an internal cross-validation run.
     * NOTE(review): how that flag is used downstream (the method name suggests it affects
     * the number of trees) is not visible in this file.
     */
    private void adjustNumberOfTrees(ClusRun foldRun) {
        foldRun.setIsInternalXValRun(true);
    }
}

