/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.addon.hmc.HMCNodeWiseModels.hmcnwmodels;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.Hashtable;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.algo.tdidt.tune.CDTTuneFTest;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.MemoryTupleIterator;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.hierarchies.ClassesAttrType;
import si.ijs.kt.clus.ext.ensemble.ClusEnsembleClassifier;
import si.ijs.kt.clus.ext.hierarchical.ClassHierarchy;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.ext.hierarchical.ClassesTuple;
import si.ijs.kt.clus.ext.hierarchical.ClassesValue;
import si.ijs.kt.clus.main.ClusOutput;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.io.ClusModelCollectionIO;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.util.array.StringTable;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgs;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgsProvider;

/**
 * Command-line driver that walks the class hierarchy of a hierarchical multi-label data set and,
 * for every node, trains one binary model per child class.
 *
 * <p>Each child model is trained on the training tuples that carry the parent node's class. The
 * target is relabelled to {@code "1"} when the tuple also carries the child class and to the empty
 * class set otherwise. By default a {@link ClusDecisionTree} is induced. The {@code -forest}
 * option switches to a {@link ClusEnsembleClassifier}. When the settings hold a vector of F-test
 * values, the learner is wrapped in {@link CDTTuneFTest}.
 *
 * <p>Per-model output is written to {@code hsc/out/<name>.out} and models to
 * {@code hsc/model/<name>.model}, where {@code <name>} is
 * {@code <appName>-<nodePath>-<childPath>} with path components joined by {@code '='}.
 */
public class HMCNodeWiseModels
implements CMDLineArgsProvider {
    /** Extra command-line options understood by this tool; {@code forest} selects ensembles. */
    private static final String[] g_Options = new String[]{"forest"};
    /** Arity of each entry of {@link #g_Options}; {@code forest} is a flag without argument. */
    private static final int[] g_OptionArities = new int[]{0};
    protected Clus m_Clus;
    protected CMDLineArgs m_Cargs;
    protected StringTable m_Table = new StringTable();
    // Not used by this class; kept as-is because it is part of the protected surface.
    protected Hashtable m_Mappings;
    /** F-test levels to tune over, or {@code null} when the settings do not hold a vector. */
    protected double[] m_FTests;

    /**
     * Parses the command line, prepares the output directories, initialises Clus and trains all
     * node-wise models.
     *
     * @param args command-line arguments; the single main argument is used as the application name
     * @throws IOException if an output directory cannot be created
     * @throws Exception propagated from Clus initialisation or model induction
     */
    public void run(String[] args) throws Exception {
        this.m_Clus = new Clus();
        Settings sett = this.m_Clus.getSettings();
        this.m_Cargs = new CMDLineArgs(this);
        this.m_Cargs.process(args);
        if (!this.m_Cargs.allOK()) {
            ClusLogger.info("m_Cargs nok");
            return;
        }
        // mkdirs also creates the common parent "hsc".
        ensureDirectory(new File("hsc/out"));
        ensureDirectory(new File("hsc/model"));
        sett.getGeneric().setDate(new Date());
        sett.getGeneric().setAppName(this.m_Cargs.getMainArg(0));
        this.m_Clus.initSettings(this.m_Cargs);
        ClusDecisionTree clss = new ClusDecisionTree(this.m_Clus);
        if (sett.getTree().getFTestArray().isVector()) {
            // Remember the levels so that every per-child learner in doOneNode is tuned the same way.
            this.m_FTests = sett.getTree().getFTestArray().getDoubleVector();
            clss = new CDTTuneFTest(clss, this.m_FTests);
        }
        this.m_Clus.initialize(this.m_Cargs, clss);
        this.doRun();
    }

    /**
     * Makes sure {@code dir} exists as a directory, creating missing parents.
     *
     * @throws IOException if the directory does not exist and cannot be created
     */
    private static void ensureDirectory(File dir) throws IOException {
        // Try to create first and only then check, so a concurrently created directory is accepted.
        if (!dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Could not create output directory: " + dir.getPath());
        }
    }

    /**
     * Selects the tuples of {@code train} whose hierarchical target contains class {@code nodeid}.
     *
     * <p>NOTE(review): assumes the hierarchical target is stored as object attribute 0 of every
     * tuple; a different layout causes a ClassCastException.
     *
     * @param train the data to filter; the tuples themselves are shared, not copied
     * @param nodeid class index to filter on, or {@code -1} to keep every tuple
     * @return a new {@link RowData} with the selected tuples and the schema of {@code train}
     */
    public RowData getNodeData(RowData train, int nodeid) {
        ArrayList<DataTuple> selected = new ArrayList<DataTuple>();
        for (int i = 0; i < train.getNbRows(); ++i) {
            DataTuple tuple = train.getTuple(i);
            ClassesTuple target = (ClassesTuple)tuple.getObjVal(0);
            if (nodeid == -1 || target.hasClass(nodeid)) {
                selected.add(tuple);
            }
        }
        return new RowData(selected, train.getSchema());
    }

    /**
     * Builds the binary training set for one child class.
     *
     * <p>Registers the single class value {@code "1"} in the hierarchy of {@code ctype} and
     * (re)initialises that hierarchy. It then deep-clones every tuple of {@code nodeData}, attaches
     * the schema of {@code ctype} and replaces object attribute 0 with either {@code {1}} (when
     * the original target contains {@code childid}) or the empty class set.
     *
     * <p>NOTE(review): {@code ctype.getSchema()} is presumably set by adding {@code ctype} to a
     * schema (see {@link #createChildSchema}). This method is called up to three times per
     * {@code ctype} (train, validation, test); confirm that repeated {@code addClass("1")} and
     * {@code initialize()} calls are idempotent.
     *
     * @param nodeData tuples of the parent node
     * @param ctype the single-class target attribute of the child model
     * @param childid hierarchy index of the child class
     * @return a new data set of the same size as {@code nodeData}
     * @throws ClusException propagated from hierarchy initialisation
     */
    public RowData createChildData(RowData nodeData, ClassesAttrType ctype, int childid) throws ClusException {
        ClassHierarchy chier = ctype.getHier();
        ClassesValue one = new ClassesValue("1", ctype.getTable());
        chier.addClass(one);
        chier.initialize();
        one.addHierarchyIndices(chier);
        RowData childData = new RowData(ctype.getSchema(), nodeData.getNbRows());
        for (int j = 0; j < nodeData.getNbRows(); ++j) {
            DataTuple tuple = nodeData.getTuple(j);
            ClassesTuple target = (ClassesTuple)tuple.getObjVal(0);
            ClassesTuple clss;
            if (target.hasClass(childid)) {
                clss = new ClassesTuple(1);
                clss.addItem(new ClassesValue(one.getTerm()));
            } else {
                clss = new ClassesTuple(0);
            }
            // Clone so the relabelled target does not overwrite the shared original tuple.
            DataTuple new_tuple = tuple.deepCloneTuple();
            new_tuple.setSchema(ctype.getSchema());
            new_tuple.setObjectVal(clss, 0);
            childData.setTuple(new_tuple, j);
        }
        return childData;
    }

    /**
     * Creates a schema holding copies of all non-hierarchical attributes of {@code oschema},
     * followed by {@code ctype} as the new target attribute.
     *
     * @param oschema the original schema
     * @param ctype the child's single-class target attribute, appended last
     * @param name relation name of the new schema
     * @return the initialised schema, marked sparse when {@code oschema} is sparse
     * @throws ClusException propagated from Clus schema handling
     * @throws IOException propagated from Clus schema handling
     */
    public ClusSchema createChildSchema(ClusSchema oschema, ClassesAttrType ctype, String name) throws ClusException, IOException {
        ClusSchema cschema = new ClusSchema(name);
        for (int j = 0; j < oschema.getNbAttributes(); ++j) {
            ClusAttrType atype = oschema.getAttrType(j);
            // Drop the original hierarchical target(s); ctype replaces them.
            if (atype instanceof ClassesAttrType) continue;
            ClusAttrType copy_atype = atype.cloneType();
            cschema.addAttrType(copy_atype);
        }
        cschema.addAttrType(ctype);
        cschema.initializeSettings(this.m_Clus.getSettings());
        if (oschema.isSparse()) {
            cschema.setSparse();
        }
        return cschema;
    }

    /**
     * Trains, evaluates and stores one binary model for every child of {@code node}.
     *
     * @param node the parent node whose training tuples are used
     * @param hier the class hierarchy (not used by this method)
     * @param train training data
     * @param valid pruning/validation data, or {@code null}
     * @param test test data, or {@code null}
     * @throws Exception propagated from schema creation, induction or output
     */
    public void doOneNode(ClassTerm node, ClassHierarchy hier, RowData train, RowData valid, RowData test) throws Exception {
        RowData nodeData = this.getNodeData(train, node.getIndex());
        // These subsets depend only on the parent node, so compute them once for all children.
        RowData validNodeData = valid != null ? this.getNodeData(valid, node.getIndex()) : null;
        RowData testNodeData = test != null ? this.getNodeData(test, node.getIndex()) : null;
        String nodeName = node.toPathString("=");
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusInductionAlgorithmType clss;
            ClassTerm child = (ClassTerm)node.getChild(i);
            String childName = child.toPathString("=");
            ClassesAttrType ctype = new ClassesAttrType(nodeName + "-" + childName);
            ClusSchema cschema = this.createChildSchema(train.getSchema(), ctype, "REL-" + nodeName + "-" + childName);
            RowData childData = this.createChildData(nodeData, ctype, child.getIndex());
            if (this.m_Cargs.hasOption("forest")) {
                this.m_Clus.getSettings().getEnsemble().setEnsembleMode(true);
                clss = new ClusEnsembleClassifier(this.m_Clus);
            } else {
                clss = new ClusDecisionTree(this.m_Clus);
            }
            if (this.m_FTests != null) {
                clss = new CDTTuneFTest(clss, this.m_FTests);
            }
            this.m_Clus.recreateInduce(this.m_Cargs, clss, cschema, childData);
            String name = this.m_Clus.getSettings().getGeneric().getAppName() + "-" + nodeName + "-" + childName;
            ClusRun cr = new ClusRun(childData.cloneData(), this.m_Clus.getSummary());
            cr.copyTrainingData();
            if (validNodeData != null) {
                RowData validChildData = this.createChildData(validNodeData, ctype, child.getIndex());
                cr.setPruneSet(validChildData, null);
            }
            if (testNodeData != null) {
                RowData testChildData = this.createChildData(testNodeData, ctype, child.getIndex());
                MemoryTupleIterator iter = testChildData.getIterator();
                cr.setTestSet(iter);
            }
            this.m_Clus.initializeSummary(clss);
            ClusOutput output = new ClusOutput("hsc/out/" + name + ".out", cschema, this.m_Clus.getSettings());
            try {
                this.m_Clus.getStatManager().computeTrainSetStat((RowData)cr.getTrainingSet());
                this.m_Clus.induce(cr, clss);
                this.m_Clus.calcError(cr, null);
                output.writeHeader();
                output.writeOutput(cr, true, this.m_Clus.getSettings().getOutput().isOutTrainError());
            } finally {
                // Release the output file even when induction or evaluation fails.
                output.close();
            }
            ClusModelCollectionIO io = new ClusModelCollectionIO();
            io.addModel(cr.addModelInfo(1));
            io.save("hsc/model/" + name + ".model");
        }
    }

    /**
     * Processes {@code node} and, depth-first, all of its descendants. Each node is handled at
     * most once, even when several parents reach it.
     *
     * @param computed per-node flags indexed by {@link ClassTerm#getIndex()}; updated in place
     * @throws Exception propagated from {@link #doOneNode}
     */
    public void computeRecursive(ClassTerm node, ClassHierarchy hier, RowData train, RowData valid, RowData test, boolean[] computed) throws Exception {
        if (computed[node.getIndex()]) {
            return;
        }
        // Mark before recursing so a node reachable via several paths is processed only once.
        computed[node.getIndex()] = true;
        this.doOneNode(node, hier, train, valid, test);
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClassTerm child = (ClassTerm)node.getChild(i);
            this.computeRecursive(child, hier, train, valid, test, computed);
        }
    }

    /**
     * Processes the hierarchy root and then its descendants via {@link #computeRecursive}.
     *
     * <p>The root is not recorded in {@code computed}. NOTE(review): presumably this is because
     * its index is {@code -1} (see the {@code -1} case in {@link #getNodeData}); confirm.
     *
     * @throws Exception propagated from {@link #doOneNode}
     */
    public void computeRecursiveRoot(ClassTerm node, ClassHierarchy hier, RowData train, RowData valid, RowData test, boolean[] computed) throws Exception {
        this.doOneNode(node, hier, train, valid, test);
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClassTerm child = (ClassTerm)node.getChild(i);
            this.computeRecursive(child, hier, train, valid, test, computed);
        }
    }

    /**
     * Partitions the data into training, pruning and test sets, then trains the models of the
     * whole hierarchy starting at its root.
     *
     * @throws Exception propagated from data partitioning or model training
     */
    public void doRun() throws Exception {
        ClusRun cr = this.m_Clus.partitionData();
        RowData train = (RowData)cr.getTrainingSet();
        RowData valid = (RowData)cr.getPruneSet();
        RowData test = cr.getTestSet();
        ClusStatManager mgr = this.m_Clus.getStatManager();
        ClassHierarchy hier = mgr.getHier();
        ClassTerm root = hier.getRoot();
        boolean[] computed = new boolean[hier.getTotal()];
        this.computeRecursiveRoot(root, hier, train, valid, test, computed);
    }

    /** @return the extra option names accepted by this tool */
    @Override
    public String[] getOptionArgs() {
        return g_Options;
    }

    /** @return the arity of each option returned by {@link #getOptionArgs()} */
    @Override
    public int[] getOptionArgArities() {
        return g_OptionArities;
    }

    /** @return {@code 1}: the single main argument is used as the application name */
    @Override
    public int getNbMainArgs() {
        return 1;
    }

    /** Prints nothing; this tool provides no extra help text. */
    @Override
    public void showHelp() {
    }

    /**
     * Entry point. Runs the tool and logs any failure instead of propagating it.
     *
     * @param args command-line arguments, see {@link #run(String[])}
     * @throws IOException declared for compatibility; failures are caught and logged
     */
    public static void main(String[] args) throws IOException {
        try {
            HMCNodeWiseModels m = new HMCNodeWiseModels();
            m.run(args);
        }
        catch (IOException io) {
            ClusLogger.info("IO Error: " + io.getMessage());
            io.printStackTrace();
        }
        catch (ClusException cl) {
            ClusLogger.info("Error: " + cl.getMessage());
            cl.printStackTrace();
        }
        catch (ClassNotFoundException cn) {
            ClusLogger.info("Error: " + cn.getMessage());
            cn.printStackTrace();
        }
        catch (InterruptedException ie) {
            ClusLogger.info("Error: " + ie.getMessage());
            ie.printStackTrace();
            // Restore the interrupt status that catching the exception cleared.
            Thread.currentThread().interrupt();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        ClusLogger.info("Finished. Have a nice day, Stevanche.");
    }
}

