/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.semisupervised;

import java.io.IOException;
import java.io.PrintWriter;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.DataPreprocs;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.ext.semisupervised.ClusSemiSupervisedInduce;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Semi-supervised self-training wrapper around another {@link ClusInductionAlgorithm}.
 *
 * <p>On the first pass the wrapped algorithm induces a model, its predictions are written into
 * every tuple of {@code m_UnlabeledData}, and those tuples are appended to {@code m_TrainingSet}.
 * On each later pass a new model is induced and the unlabeled tuples are relabeled in place. The
 * weighted squared distance between each tuple's old and new labels is summed. Iteration stops
 * when that sum is no longer above {@link #m_Threshold} or when {@link #m_Iterations} is exceeded.
 * Per-pass error estimates are written to {@code <appName>_SelfTrainingFTFErrors.csv}.
 */
public class ClusSelfTrainingFTFInduce extends ClusSemiSupervisedInduce {

    /** The wrapped (base) induction algorithm that builds each model. */
    ClusInductionAlgorithm m_Induce;

    /** Stop once the summed squared label change on the unlabeled data is at or below this value. */
    double m_Threshold;

    /** Iteration limit; the main loop runs at most {@code m_Iterations + 1} passes. */
    int m_Iterations;

    /**
     * Creates the self-training wrapper and reads its SSL settings.
     *
     * @param schema      data schema, passed to the superclass
     * @param sett        settings, passed to the superclass
     * @param clss_induce base induction algorithm to wrap
     * @throws ClusException if the superclass constructor fails
     * @throws IOException   if the superclass constructor fails
     */
    public ClusSelfTrainingFTFInduce(ClusSchema schema, Settings sett, ClusInductionAlgorithm clss_induce) throws ClusException, IOException {
        super(schema, sett);
        // m_Induce must be set before initialize(): getSchema()/getSettings() delegate to it.
        this.m_Induce = clss_induce;
        this.initialize(this.getSchema(), this.getSettings());
    }

    /**
     * Reads the confidence threshold, iteration limit and labeled percentage from the SSL settings.
     *
     * @param schema unused here
     * @param sett   settings to read from
     */
    public void initialize(ClusSchema schema, Settings sett) {
        this.m_Threshold = sett.getSSL().getConfidenceThreshold();
        this.m_Iterations = sett.getSSL().getIterations();
        // The setting is a percentage; store it as a fraction.
        this.m_PercentageLabeled = (double) sett.getSSL().getPercentageLabeled() / 100.0;
    }

    /**
     * Runs the self-training loop and returns the final model.
     *
     * <p>Model info slot 0 of {@code cr} receives the model from the first pass. Slot 1 receives
     * the final model. The CSV error log is always closed, even if an exception interrupts the loop.
     *
     * @param cr the run to train on
     * @return the model induced on the last pass
     * @throws Exception propagated from the wrapped algorithm, error estimation, or file creation
     */
    @Override
    public ClusModel induceSingleUnpruned(ClusRun cr) throws Exception {
        this.partitionData(cr);
        int iterations = 0;
        // Start above the threshold so that the first pass always runs.
        double deltaUnlabeled = this.m_Threshold + 1.0;
        boolean first = true;
        // Row count of the training set before any pseudo-labeled tuples are appended.
        int origLabeledMax = this.m_TrainingSet.getNbRows();
        // NOTE(review): presumably this run shares m_TrainingSet, so appended tuples are
        // visible to later inductions; confirm against partitionData/ClusRun copy semantics.
        ClusRun myClusRun = new ClusRun(cr);
        ClusAttributeWeights targetWeights = cr.getStatManager().getNormalizationWeights();
        double originalError = 0.0;
        PrintWriter writer = null;
        try {
            while (iterations <= this.m_Iterations && deltaUnlabeled > this.m_Threshold) {
                ClusLogger.info();
                ClusLogger.info("SelfTrainingFTF iteration: " + ++iterations);
                ClusLogger.info();
                this.m_Model = this.m_Induce.induceSingleUnpruned(myClusRun);
                if (first) {
                    first = false;
                    ClusModelInfo defInfo = cr.addModelInfo(0);
                    defInfo.setModel(this.m_Model);
                    // Test-set error of the first model, repeated in every CSV row for comparison.
                    originalError = this.calculateError(cr.getTestSet()).getModelError();
                    writer = new PrintWriter(cr.getStatManager().getSettings().getGeneric().getAppName() + "_SelfTrainingFTFErrors.csv", "UTF-8");
                    writer.println("DeltaUnlabeled,errorSSL,errorSupervised,errorOOBLabeled,errorOOBTrainingSet,errorTrainingSet,UnlabeledModelError");
                    // Pseudo-label each unlabeled tuple and append it to the training set.
                    for (int i = 0; i < this.m_UnlabeledData.getNbRows(); ++i) {
                        DataTuple tuple = this.m_UnlabeledData.getTuple(i);
                        ClusStatistic stat = this.m_Model.predictWeighted(tuple);
                        stat.computePrediction();
                        stat.predictTuple(tuple);
                        this.m_TrainingSet.add(tuple);
                    }
                } else {
                    // Relabel in place and accumulate how far the labels moved.
                    deltaUnlabeled = 0.0;
                    for (int i = 0; i < this.m_UnlabeledData.getNbRows(); ++i) {
                        DataTuple tuple = this.m_UnlabeledData.getTuple(i);
                        DataTuple previous = tuple.deepCloneTuple();
                        ClusStatistic stat = this.m_Model.predictWeighted(tuple);
                        stat.computePrediction();
                        stat.predictTuple(tuple);
                        deltaUnlabeled += stat.getSquaredDistance(previous, targetWeights);
                    }
                }
                // Auxiliary model trained only on the (pseudo-labeled) unlabeled tuples; used only for logging.
                ClusRun tempRun = new ClusRun(myClusRun);
                tempRun.setTrainingSet(this.m_UnlabeledData);
                ClusModel tempModel = this.m_Induce.induceSingleUnpruned(tempRun);
                writer.println(this.formatErrorRow(cr, deltaUnlabeled, originalError, origLabeledMax, tempModel));
            }
        } finally {
            // Close on every exit path; writer is null if the loop never reached its creation.
            if (writer != null) {
                writer.close();
            }
        }
        ClusModelInfo origInfo = cr.addModelInfo(1);
        String additionalInfo = "Semi-supervised Self-training FTF\n\t Iterations performed = " + iterations + "\n\t Base model: ";
        // NOTE(review): assumes the wrapped algorithm yields a ClusForest; otherwise this cast throws.
        ((ClusForest) this.m_Model).setModelInfo(additionalInfo);
        origInfo.setModel(this.m_Model);
        return this.m_Model;
    }

    /**
     * Builds one CSV row of the per-pass error log, matching the header columns.
     *
     * <p>Error measures are evaluated left to right, in column order.
     *
     * @param cr             the outer run (supplies the test set)
     * @param deltaUnlabeled summed squared label change for this pass
     * @param originalError  test-set error of the first-pass model
     * @param origLabeledMax training-set row count before pseudo-labeled tuples were appended
     * @param tempModel      auxiliary model induced on the unlabeled tuples only
     * @return the comma-separated row, without a line terminator
     * @throws Exception propagated from error estimation
     */
    private String formatErrorRow(ClusRun cr, double deltaUnlabeled, double originalError, int origLabeledMax, ClusModel tempModel) throws Exception {
        return deltaUnlabeled
                + "," + this.calculateError(cr.getTestSet()).getModelError()
                + "," + originalError
                + "," + this.getOOBError(this.m_TrainingSet, origLabeledMax).getModelError()
                + "," + this.getOOBError(this.m_TrainingSet, this.m_TrainingSet.getNbRows()).getModelError()
                + "," + this.calculateError(this.m_TrainingSet).getModelError()
                + "," + this.calculateError(tempModel, this.m_TrainingSet, origLabeledMax).getModelError();
    }

    /** Delegates initialization to the wrapped algorithm. */
    @Override
    public void initialize() throws ClusException, IOException {
        this.m_Induce.initialize();
    }

    /** Delegates heuristic initialization to the wrapped algorithm. */
    @Override
    public void initializeHeuristic() {
        this.m_Induce.initializeHeuristic();
    }

    /** @return the wrapped algorithm's schema */
    @Override
    public ClusSchema getSchema() {
        return this.m_Induce.getSchema();
    }

    /** @return the wrapped algorithm's statistics manager */
    @Override
    public ClusStatManager getStatManager() {
        return this.m_Induce.getStatManager();
    }

    /** @return the settings held by the wrapped algorithm's statistics manager */
    @Override
    public Settings getSettings() {
        return this.getStatManager().getSettings();
    }

    /**
     * Collects preprocessing steps from the wrapped algorithm's statistics manager.
     *
     * @param pps collection to add preprocessing steps to
     */
    @Override
    public void getPreprocs(DataPreprocs pps) {
        this.getStatManager().getPreprocs(pps);
    }
}

