/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.hierarchical;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.hierarchies.ClassesAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.ext.hierarchical.ClassesTuple;
import si.ijs.kt.clus.ext.hierarchical.ClassesValue;
import si.ijs.kt.clus.ext.hierarchical.HierNodeWeights;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsHMLC;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.math.SingleStat;
import si.ijs.kt.clus.util.jeans.tree.CompleteTreeIterator;
import si.ijs.kt.clus.util.jeans.util.array.StringTable;

/**
 * Hierarchy of class labels used for hierarchical multi-label classification.
 *
 * <p>The hierarchy is rooted at {@link #getRoot()} and is either a tree ({@link #TREE}) or a
 * DAG ({@link #DAG}). After {@link #numberHierarchy()} every non-root term is stored in a flat
 * list and its index equals its position in that list; the root keeps index {@code -1}.
 */
public class ClassHierarchy
implements Serializable {
    public static final long serialVersionUID = 1L;
    public static final int TEST = 0;
    public static final int ERROR = 1;
    /** Hierarchy type: not (yet) known. */
    public static final int UNKNOWN = -1;
    /** Hierarchy type: every term has at most one parent. */
    public static final int TREE = 0;
    /** Hierarchy type: terms may have several parents. */
    public static final int DAG = 1;
    protected int m_MaxDepth = 0;
    protected int m_HierType = TREE;
    protected ClassesTuple m_Eval;
    /** Flat list of all non-root terms; position in the list equals the term's index. */
    protected ArrayList<ClassTerm> m_ClassList = new ArrayList<>();
    /** Lookup of terms by ID (filled while loading a DAG definition). */
    protected HashMap<String, ClassTerm> m_ClassMap = new HashMap<>();
    protected ClassTerm m_Root;
    protected NumericAttrType[] m_DummyTypes;
    protected boolean m_IsLocked;
    protected double[] m_Weights;
    protected transient ClassesAttrType m_Type;
    /** DFS colour: term not visited yet. */
    public static final char DFS_WHITE = '\u0000';
    /** DFS colour: term is on the current DFS path. */
    public static final char DFS_GRAY = '\u0001';
    /** DFS colour: term and all its descendants are finished. */
    public static final char DFS_BLACK = '\u0002';

    /** Creates a hierarchy without a root term. */
    public ClassHierarchy() {
    }

    /**
     * Creates a hierarchy with a fresh root term, bound to the given attribute type.
     *
     * @param type the hierarchical class attribute this hierarchy belongs to
     */
    public ClassHierarchy(ClassesAttrType type) {
        this(new ClassTerm());
        this.setType(type);
    }

    /**
     * Creates a hierarchy with the given root term.
     *
     * @param root the root term
     */
    public ClassHierarchy(ClassTerm root) {
        this.m_Root = root;
    }

    /** @return the settings of the attached attribute type */
    public Settings getSettings() {
        return this.m_Type.getSettings();
    }

    public final void setType(ClassesAttrType type) {
        this.m_Type = type;
    }

    public final ClassesAttrType getType() {
        return this.m_Type;
    }

    /**
     * Adds the class described by {@code val} below the root. The call is silently ignored once
     * the hierarchy is locked (see {@link #numberHierarchy()}).
     *
     * @param val the class value to insert
     */
    public final void addClass(ClassesValue val) {
        if (!this.isLocked()) {
            this.m_Root.addClass(val, 0, this);
        }
    }

    /** Prints the hierarchy starting at the root, without counts or weights. */
    public final void print(PrintWriter wrt) {
        this.m_Root.print(0, wrt, null, null);
    }

    /** Prints the hierarchy starting at the root with the given counts and weights. */
    public final void print(PrintWriter wrt, double[] counts, double[] weights) {
        this.m_Root.print(0, wrt, counts, weights);
    }

    /** Prints the hierarchy starting at the root with the given counts and this hierarchy's weights. */
    public final void print(PrintWriter wrt, double[] counts) {
        this.m_Root.print(0, wrt, counts, this.m_Weights);
    }

    /** @return the maximum depth as computed by the root term (not the cached {@link #getDepth()}) */
    public final int getMaxDepth() {
        return this.m_Root.getMaxDepth();
    }

    public final ClassTerm getRoot() {
        return this.m_Root;
    }

    /**
     * Appends {@code term} and (pre-order) all its descendants to the class list, sorting each
     * term's children by ID first. Intended for tree hierarchies.
     */
    public final void initClassListRecursiveTree(ClassTerm term) {
        this.m_ClassList.add(term);
        term.sortChildrenByID();
        for (int i = 0; i < term.getNbChildren(); ++i) {
            this.initClassListRecursiveTree((ClassTerm) term.getChild(i));
        }
    }

    /**
     * DAG variant of {@link #initClassListRecursiveTree}: terms whose ID is already in
     * {@code set} are skipped, so each term is listed only once even if it has several parents.
     *
     * @param term the term to add
     * @param set IDs of terms already fully processed
     */
    public final void initClassListRecursiveDAG(ClassTerm term, HashSet<String> set) {
        if (!set.contains(term.getID())) {
            this.m_ClassList.add(term);
            term.sortChildrenByID();
            for (int i = 0; i < term.getNbChildren(); ++i) {
                this.initClassListRecursiveDAG((ClassTerm) term.getChild(i), set);
            }
            set.add(term.getID());
        }
    }

    /**
     * Rebuilds the flat class list from the root, assigns each term its list position as index
     * (the root gets {@code -1}), and locks the hierarchy against further {@link #addClass}.
     */
    public final void numberHierarchy() {
        this.m_Root.setIndex(-1);
        this.m_Root.sortChildrenByID();
        this.m_ClassList.clear();
        if (this.isDAG()) {
            HashSet<String> set = new HashSet<String>();
            for (int i = 0; i < this.m_Root.getNbChildren(); ++i) {
                this.initClassListRecursiveDAG((ClassTerm) this.m_Root.getChild(i), set);
            }
        } else {
            for (int i = 0; i < this.m_Root.getNbChildren(); ++i) {
                this.initClassListRecursiveTree((ClassTerm) this.m_Root.getChild(i));
            }
        }
        for (int i = 0; i < this.getTotal(); ++i) {
            this.getTermAt(i).setIndex(i);
        }
        this.setLocked(true);
    }

    /**
     * Collects {@code "parentID/childID"} for every edge below {@code node}. Each child's own
     * subtree is expanded only once ({@code visited}), but every incoming edge is reported.
     */
    void getAllParentChildTuplesRecursive(ClassTerm node, boolean[] visited, ArrayList<String> parentchilds) {
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClassTerm child = (ClassTerm) node.getChild(i);
            parentchilds.add(node.getID() + "/" + child.getID());
            if (visited[child.getIndex()]) {
                continue;
            }
            visited[child.getIndex()] = true;
            this.getAllParentChildTuplesRecursive(child, visited, parentchilds);
        }
    }

    /**
     * Returns all parent/child edges reachable from the root, formatted as {@code "parent/child"}.
     * Requires a numbered hierarchy (see {@link #numberHierarchy()}).
     */
    public ArrayList<String> getAllParentChildTuples() {
        ArrayList<String> parentchilds = new ArrayList<String>();
        boolean[] visited = new boolean[this.getTotal()];
        this.getAllParentChildTuplesRecursive(this.m_Root, visited, parentchilds);
        return parentchilds;
    }

    /**
     * Collects slash-separated ID paths below {@code node}. Paths start below the root (the root,
     * index {@code -1}, contributes no prefix). Each child is expanded at most once.
     */
    void getAllPathsRecursive(ClassTerm node, String crpath, boolean[] visited, ArrayList<String> paths) {
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClassTerm child = (ClassTerm) node.getChild(i);
            String prefix = node.getIndex() == -1 ? "" : crpath + "/";
            String newPath = prefix + child.getID();
            paths.add(newPath);
            if (visited[child.getIndex()]) {
                continue;
            }
            visited[child.getIndex()] = true;
            this.getAllPathsRecursive(child, newPath, visited, paths);
        }
    }

    /**
     * Returns the slash-separated ID path of every edge reachable from the root. Requires a
     * numbered hierarchy (see {@link #numberHierarchy()}).
     */
    public ArrayList<String> getAllPaths() {
        ArrayList<String> paths = new ArrayList<String>();
        boolean[] visited = new boolean[this.getTotal()];
        this.getAllPathsRecursive(this.m_Root, "", visited, paths);
        return paths;
    }

    /**
     * Resizes {@code tuple} to the number of {@code true} entries in {@code matrix} and fills
     * it, in index order, with a {@link ClassesValue} (constructed with {@code 1.0}) for each
     * selected term.
     */
    public void addAllClasses(ClassesTuple tuple, boolean[] matrix) {
        int idx = 0;
        tuple.setSize(ClassHierarchy.countOnes(matrix));
        for (int i = 0; i < this.getTotal(); ++i) {
            if (matrix[i]) {
                tuple.setItemAt(new ClassesValue(this.getTermAt(i), 1.0), idx++);
            }
        }
    }

    /**
     * Sets {@code matrix[i]} to {@code true} for every term whose mean is at least
     * {@code treshold / 100} (the threshold is given in percent). Entries that are already
     * {@code true} are left unchanged; a NaN mean never selects a term.
     */
    public void fillBooleanMatrixMaj(double[] mean, boolean[] matrix, double treshold) {
        double limit = treshold / 100.0;
        for (int i = 0; i < this.getTotal(); ++i) {
            int idx = this.getTermAt(i).getIndex();
            if (mean[idx] >= limit) {
                matrix[idx] = true;
            }
        }
    }

    /**
     * For every selected term in the subtree of {@code node}, clears the selection of its chain
     * of selected ancestors (following {@code getCTParent()}, stopping at the root or at the
     * first unselected ancestor).
     */
    public static void removeParentNodesRec(ClassTerm node, boolean[] matrix) {
        if (matrix[node.getIndex()]) {
            ClassTerm parent = node.getCTParent();
            while (parent.getIndex() != -1 && matrix[parent.getIndex()]) {
                matrix[parent.getIndex()] = false;
                parent = parent.getCTParent();
            }
        }
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClassHierarchy.removeParentNodesRec((ClassTerm) node.getChild(i), matrix);
        }
    }

    /** Applies {@link #removeParentNodesRec} to every child of {@code node}. */
    public static void removeParentNodes(ClassTerm node, boolean[] matrix) {
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClassHierarchy.removeParentNodesRec((ClassTerm) node.getChild(i), matrix);
        }
    }

    /**
     * Clears the selection of all selected ancestors of {@code term}, following every parent
     * (so it works for DAGs). Recursion stops at unselected parents and at the root.
     */
    public void removeParentNodesRecursive(ClassTerm term, boolean[] array) {
        for (int i = 0; i < term.getNbParents(); ++i) {
            ClassTerm par = term.getParent(i);
            if (par.getIndex() == -1 || !array[par.getIndex()]) {
                continue;
            }
            array[par.getIndex()] = false;
            this.removeParentNodesRecursive(par, array);
        }
    }

    /** Keeps only the most specific selected terms by deselecting ancestors of selected terms. */
    public void removeParentNodes(boolean[] array) {
        for (int i = 0; i < this.getTotal(); ++i) {
            ClassTerm term = this.getTermAt(i);
            if (term.getIndex() == -1 || !array[term.getIndex()]) {
                continue;
            }
            this.removeParentNodesRecursive(term, array);
        }
    }

    /** @return the number of {@code true} entries in {@code matrix} */
    public static int countOnes(boolean[] matrix) {
        int count = 0;
        for (boolean b : matrix) {
            if (b) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Like {@link #getBestTupleMaj} but, afterwards, deselects ancestors of selected terms (tree
     * parent chain) so only the most specific classes remain.
     */
    public ClassesTuple getBestTupleMajNoParents(double[] mean, double treshold) {
        boolean[] classes = new boolean[this.getTotal()];
        this.fillBooleanMatrixMaj(mean, classes, treshold);
        ClassHierarchy.removeParentNodes(this.getRoot(), classes);
        ClassesTuple tuple = new ClassesTuple();
        this.addAllClasses(tuple, classes);
        return tuple;
    }

    /** Returns a tuple with every class whose mean reaches {@code treshold} percent. */
    public ClassesTuple getBestTupleMaj(double[] mean, double treshold) {
        boolean[] classes = new boolean[this.getTotal()];
        this.fillBooleanMatrixMaj(mean, classes, treshold);
        ClassesTuple tuple = new ClassesTuple();
        this.addAllClasses(tuple, classes);
        return tuple;
    }

    /** Returns a tuple with the classes whose index is {@code true} in {@code nodes}. */
    public ClassesTuple getTuple(boolean[] nodes) {
        ClassesTuple result = new ClassesTuple();
        this.addAllClasses(result, nodes);
        return result;
    }

    /** @return an iterator over the hierarchy that has already skipped the root node */
    public final CompleteTreeIterator getNoRootIter() {
        CompleteTreeIterator iter = new CompleteTreeIterator(this.m_Root);
        if (iter.hasMoreNodes()) {
            iter.getNextNode();
        }
        return iter;
    }

    /** @return an iterator over the hierarchy starting at the root */
    public final CompleteTreeIterator getRootIter() {
        return new CompleteTreeIterator(this.m_Root);
    }

    public final double[] getWeights() {
        return this.m_Weights;
    }

    /** Computes per-class weights from the HMLC weighting type and parameter in the settings. */
    public final void calcWeights() {
        HierNodeWeights ws = new HierNodeWeights();
        SettingsHMLC.HierarchyWeight wtype = this.getSettings().getHMLC().getHierWType();
        double widec = this.getSettings().getHMLC().getHierWParam();
        ws.initExponentialDepthWeights(this, wtype, widec);
        this.m_Weights = ws.getWeights();
    }

    /** Caches the maximum of {@code getMaxDepth()} over all listed terms (see {@link #getDepth()}). */
    public final void calcMaxDepth() {
        this.m_MaxDepth = 0;
        for (int i = 0; i < this.getTotal(); ++i) {
            this.m_MaxDepth = Math.max(this.m_MaxDepth, this.getTermAt(i).getMaxDepth());
        }
    }

    public final SingleStat getMeanBranch(boolean[] enabled) {
        SingleStat stat = new SingleStat();
        this.m_Root.getMeanBranch(enabled, stat);
        return stat;
    }

    /** @return the number of terms in the class list (the root is not counted) */
    public final int getTotal() {
        return this.m_ClassList.size();
    }

    /** @return the depth cached by {@link #calcMaxDepth()} */
    public final int getDepth() {
        return this.m_MaxDepth;
    }

    /**
     * Counts terms per depth level (root at level 0). In a DAG a term reachable via several
     * paths is counted once per path.
     *
     * @return an array of size {@code getDepth() + 2} with the count for each level
     */
    public final int[] getClassesByLevel() {
        int[] res = new int[this.getDepth() + 2];
        this.countClassesRecursive(this.m_Root, 0, res);
        return res;
    }

    /** Increments {@code cls[depth]} for {@code root} and recurses into its children. */
    public final void countClassesRecursive(ClassTerm root, int depth, int[] cls) {
        cls[depth]++;
        for (int i = 0; i < root.getNbChildren(); ++i) {
            this.countClassesRecursive((ClassTerm) root.getChild(i), depth + 1, cls);
        }
    }

    /**
     * Numbers the hierarchy and computes weights and depth. Then creates one dummy numeric
     * attribute {@code "H<i>"} per class, with indices that continue after the schema's
     * existing attributes.
     */
    public final void initialize() {
        this.numberHierarchy();
        this.calcWeights();
        this.calcMaxDepth();
        ClusSchema schema = this.m_Type.getSchema();
        int maxIndex = schema.getNbAttributes();
        this.m_DummyTypes = new NumericAttrType[this.getTotal()];
        for (int i = 0; i < this.getTotal(); ++i) {
            NumericAttrType type = new NumericAttrType("H" + i);
            type.setIndex(maxIndex++);
            type.setSchema(schema);
            this.m_DummyTypes[i] = type;
        }
    }

    /**
     * Detaches every class whose mean in {@code stat} is zero or below {@code minfreq} from all
     * its parents.
     *
     * <p>Note: both the class list and the ID map are cleared. For DAGs the map is refilled with
     * the surviving terms only (a {@code "root"} entry added by {@link #loadDAG} is not restored).
     * The class list is left empty. NOTE(review): callers presumably renumber the hierarchy
     * afterwards (e.g. via {@link #numberHierarchy()}) — confirm.
     *
     * @return a flag per original class index, {@code true} if that class was removed
     */
    public boolean[] removeInfrequentClasses(WHTDStatistic stat, double minfreq) {
        boolean[] removed = new boolean[this.getTotal()];
        ArrayList<ClassTerm> kept = new ArrayList<ClassTerm>();
        for (int i = 0; i < this.getTotal(); ++i) {
            double mean = stat.getMean(i);
            ClassTerm trm = this.getTermAt(i);
            if (mean == 0.0 || mean < minfreq) {
                for (int j = 0; j < trm.getNbParents(); ++j) {
                    ClassTerm par = trm.getParent(j);
                    par.removeChild(trm);
                }
                removed[trm.getIndex()] = true;
            } else {
                kept.add(trm);
            }
        }
        this.m_ClassList.clear();
        this.m_ClassMap.clear();
        if (this.isDAG()) {
            for (ClassTerm trm : kept) {
                this.m_ClassMap.put(trm.getID(), trm);
            }
        }
        return removed;
    }

    public final NumericAttrType[] getDummyAttrs() {
        return this.m_DummyTypes;
    }

    /** Logs depth, number of nodes, leaves per top-level child and the total number of leaves. */
    public final void showSummary() {
        int leaves = 0;
        int depth = this.getMaxDepth();
        ClusLogger.info("Depth: " + depth);
        ClusLogger.info("Nodes: " + this.getTotal());
        ClassTerm root = this.getRoot();
        int nb = root.getNbChildren();
        for (int i = 0; i < nb; ++i) {
            ClassTerm chi = (ClassTerm) root.getChild(i);
            int nbl = chi.getNbLeaves();
            ClusLogger.info("Child " + i + ": " + chi.getID() + " " + nbl);
            leaves += nbl;
        }
        ClusLogger.info("Leaves: " + leaves);
    }

    /**
     * Resolves a tree class value by following its level IDs from the root. A level ID of
     * {@code "0"} ends the walk early and returns the term reached so far.
     *
     * @throws ClusException if some level ID is not a child of the current term
     */
    public final ClassTerm getClassTermTree(ClassesValue vl) throws ClusException {
        int nbLevel = vl.getNbLevels();
        ClassTerm subterm = this.m_Root;
        for (int pos = 0; pos < nbLevel; ++pos) {
            String lookup = vl.getClassID(pos);
            if (lookup.equals("0")) {
                return subterm;
            }
            ClassTerm found = subterm.getByName(lookup);
            if (found == null) {
                throw new ClusException("Classes value not in tree hierarchy: " + vl.toPathString() + " (lookup: " + lookup + ", term: " + subterm.toPathString() + ", subterms: " + subterm.getKeysVector() + ")");
            }
            subterm = found;
        }
        return subterm;
    }

    /**
     * Resolves a DAG class value by looking up its most specific class ID.
     *
     * @throws ClusException if the ID is unknown
     */
    public final ClassTerm getClassTermDAG(ClassesValue vl) throws ClusException {
        ClassTerm term = this.getClassTermByName(vl.getMostSpecificClass());
        if (term == null) {
            throw new ClusException("Classes value not in DAG hierarchy: " + vl.toPathString());
        }
        return term;
    }

    /** Resolves {@code vl} using the tree or DAG strategy depending on the hierarchy type. */
    public final ClassTerm getClassTerm(ClassesValue vl) throws ClusException {
        if (this.isTree()) {
            return this.getClassTermTree(vl);
        }
        return this.getClassTermDAG(vl);
    }

    public final int getClassIndex(ClassesValue vl) throws ClusException {
        return this.getClassTerm(vl).getIndex();
    }

    public final double getWeight(int idx) {
        return this.m_Weights[idx];
    }

    public final void setEvalClasses(ClassesTuple eval) {
        this.m_Eval = eval;
    }

    public final ClassesTuple getEvalClasses() {
        return this.m_Eval;
    }

    /** @return the evaluation-class mask; all {@code true} when no evaluation classes are set */
    public final boolean[] getEvalClassesVector() {
        if (this.m_Eval == null) {
            boolean[] res = new boolean[this.getTotal()];
            Arrays.fill(res, true);
            return res;
        }
        return this.m_Eval.getVectorBoolean(this);
    }

    /** Attaches every mapped term (other than the root) that is at top level as a child of the root. */
    public void addChildrenToRoot() {
        for (ClassTerm term : this.m_ClassMap.values()) {
            if (term == this.m_Root || !term.atTopLevel()) {
                continue;
            }
            this.m_Root.addChildCheckAndParent(term);
        }
    }

    /**
     * Adds the edge {@code parent -> child}, creating either term if it does not exist yet.
     *
     * @throws ClusException if the edge already exists
     */
    public void addParentChildTuple(String parent, String child) throws ClusException {
        ClassTerm parentTerm = this.getClassTermByNameAddIfNotIn(parent);
        ClassTerm childTerm = this.getClassTermByNameAddIfNotIn(child);
        if (parentTerm.getByName(child) != null) {
            throw new ClusException("Duplicate parent-child relation '" + parent + "' -> '" + child + "' in DAG definition in .arff file");
        }
        parentTerm.addChildCheckAndParent(childTerm);
    }

    /**
     * Builds a DAG from {@code "parent/child"} strings (whitespace around {@code /} is allowed).
     * The root is registered under the ID {@code "root"}.
     *
     * @throws ClusException if an entry is not exactly one parent/child pair
     */
    public void loadDAG(String[] cls) throws IOException, ClusException {
        this.addClassTerm("root", this.getRoot());
        for (String tuple : cls) {
            String[] rel = tuple.split("\\s*\\/\\s*");
            if (rel.length != 2) {
                throw new ClusException("Illegal parent child tuple in .arff: '" + tuple + "'");
            }
            this.addParentChildTuple(rel[0], rel[1]);
        }
        this.addChildrenToRoot();
    }

    /**
     * Builds a DAG from a file with one {@code parent,child} pair per line. Blank lines are
     * skipped. The file is read with the platform default charset.
     *
     * @throws ClusException if a non-blank line is not exactly one parent/child pair
     */
    public void loadDAG(String fname) throws IOException, ClusException {
        try (LineNumberReader rdr = new LineNumberReader(new FileReader(fname))) {
            String line;
            while ((line = rdr.readLine()) != null) {
                line = line.trim();
                if (line.equals("")) {
                    continue;
                }
                String[] rel = line.split("\\s*\\,\\s*");
                if (rel.length != 2) {
                    throw new ClusException("Illegal line '" + line + "' in DAG definition file: '" + fname + "'");
                }
                this.addParentChildTuple(rel[0], rel[1]);
            }
        }
        this.addChildrenToRoot();
    }

    /**
     * Depth-first search from {@code term} that reports every back edge (cycle) it finds.
     *
     * @param term the term to expand
     * @param visited DFS colour per term index ({@link #DFS_WHITE}/{@link #DFS_GRAY}/{@link #DFS_BLACK})
     * @param pi DFS predecessor per term index
     * @param hasCycle single-element flag set to {@code true} when a cycle is found
     */
    public void findCycleRecursive(ClassTerm term, char[] visited, ClassTerm[] pi, boolean[] hasCycle) {
        visited[term.getIndex()] = DFS_GRAY;
        for (int i = 0; i < term.getNbChildren(); ++i) {
            ClassTerm child = (ClassTerm) term.getChild(i);
            char state = visited[child.getIndex()];
            if (state == DFS_WHITE) {
                pi[child.getIndex()] = term;
                this.findCycleRecursive(child, visited, pi, hasCycle);
            } else if (state == DFS_GRAY) {
                // Back edge: child is on the current DFS path, so term -> child closes a cycle.
                StringBuilder cycle = new StringBuilder("Cycle: ");
                cycle.append('(').append(term.getID()).append(',').append(child.getID()).append(')');
                // Walk predecessors back to child. A plain loop (instead of do-while) so a
                // self-loop (term == child) does not dereference pi[term], which may be null.
                for (ClassTerm w = term; w != child; w = pi[w.getIndex()]) {
                    cycle.append("; (").append(w.getID()).append(',').append(pi[w.getIndex()].getID()).append(')');
                }
                // Log the whole cycle as one message instead of mixing the logger and System.out.
                ClusLogger.info(cycle.toString());
                hasCycle[0] = true;
            }
        }
        visited[term.getIndex()] = DFS_BLACK;
    }

    /**
     * Checks the listed terms for cycles and logs each one found.
     *
     * @throws ClusException if at least one cycle exists
     */
    public void findCycle() throws ClusException {
        char[] visited = new char[this.getTotal()];
        ClassTerm[] pi = new ClassTerm[this.getTotal()];
        boolean[] hasCycle = new boolean[1];
        Arrays.fill(visited, DFS_WHITE);
        for (int i = 0; i < this.m_ClassList.size(); ++i) {
            ClassTerm term = this.getTermAt(i);
            if (visited[term.getIndex()] == DFS_WHITE) {
                this.findCycleRecursive(term, visited, pi, hasCycle);
            }
        }
        if (hasCycle[0]) {
            throw new ClusException("Class hierarchy contains a cycle (see log for details)");
        }
    }

    /**
     * Writes the class weights to {@code name.weights} and, per data row, the key attributes,
     * the target tuple and its node-and-ancestors vector to {@code name.targets}.
     *
     * @throws IOException if a file cannot be opened or a write error occurs
     */
    public void writeTargets(RowData data, ClusSchema schema, String name) throws ClusException, IOException {
        double[] wis = this.getWeights();
        try (PrintWriter wrt = new PrintWriter(new FileWriter(name + ".weights"))) {
            wrt.print("weights(X) :- X = [");
            ClassHierarchy.printCommaSeparated(wrt, wis);
            wrt.println("].");
            wrt.println();
            ClassTerm[] terms = new ClassTerm[wis.length];
            CompleteTreeIterator iter = this.getRootIter();
            while (iter.hasMoreNodes()) {
                ClassTerm node = (ClassTerm) iter.getNextNode();
                if (node.getIndex() != -1) {
                    terms[node.getIndex()] = node;
                }
            }
            for (int i = 0; i < wis.length; ++i) {
                wrt.print("% class " + terms[i] + ": ");
                wrt.println(wis[i]);
            }
            // PrintWriter swallows IOExceptions; surface them to the caller.
            if (wrt.checkError()) {
                throw new IOException("Error writing " + name + ".weights");
            }
        }
        ClusAttrType[] keys = schema.getAllAttrUse(ClusAttrType.AttributeUseType.Key);
        int sidx = this.getType().getArrayIndex();
        try (PrintWriter wrt = new PrintWriter(new FileWriter(name + ".targets"))) {
            for (int i = 0; i < data.getNbRows(); ++i) {
                DataTuple tuple = data.getTuple(i);
                for (int j = 0; j < keys.length; ++j) {
                    if (j != 0) {
                        wrt.print(",");
                    }
                    wrt.print(keys[j].getString(tuple));
                }
                ClassesTuple target = (ClassesTuple) tuple.getObjVal(sidx);
                double[] vec = target.getVectorNodeAndAncestors(this);
                wrt.print(",");
                wrt.print(target.toString());
                wrt.print(",[");
                ClassHierarchy.printCommaSeparated(wrt, vec);
                wrt.println("]");
            }
            if (wrt.checkError()) {
                throw new IOException("Error writing " + name + ".targets");
            }
        }
    }

    /** Prints {@code values} separated by commas (no brackets, no trailing newline). */
    private static void printCommaSeparated(PrintWriter wrt, double[] values) {
        for (int i = 0; i < values.length; ++i) {
            if (i != 0) {
                wrt.print(",");
            }
            wrt.print(values[i]);
        }
    }

    public void setLocked(boolean lock) {
        this.m_IsLocked = lock;
    }

    public boolean isLocked() {
        return this.m_IsLocked;
    }

    /** @param type one of {@link #TREE}, {@link #DAG} or {@link #UNKNOWN} */
    public void setHierType(int type) {
        this.m_HierType = type;
    }

    /** Sets the hierarchy type from the settings enum; unrecognised values map to {@link #UNKNOWN}. */
    public void setHierTypeFromSettings(SettingsHMLC.HierarchyType t) {
        if (t == SettingsHMLC.HierarchyType.DAG) {
            this.m_HierType = DAG;
        } else if (t == SettingsHMLC.HierarchyType.Tree) {
            this.m_HierType = TREE;
        } else {
            this.m_HierType = UNKNOWN;
        }
    }

    public boolean isTree() {
        return this.m_HierType == TREE;
    }

    public boolean isDAG() {
        return this.m_HierType == DAG;
    }

    /** Returns the term with the given ID, creating and registering a new one if absent. */
    public ClassTerm getClassTermByNameAddIfNotIn(String id) {
        ClassTerm found = this.getClassTermByName(id);
        if (found == null) {
            found = new ClassTerm(id);
            this.addClassTerm(id, found);
        }
        return found;
    }

    /** @return the term registered under {@code id}, or {@code null} */
    public ClassTerm getClassTermByName(String id) {
        return this.m_ClassMap.get(id);
    }

    /** Registers {@code term} in the ID map under {@code id}. */
    public void addClassTerm(String id, ClassTerm term) {
        this.m_ClassMap.put(id, term);
    }

    /** Appends {@code term} to the flat class list. */
    public void addClassTerm(ClassTerm term) {
        this.m_ClassList.add(term);
    }

    public ClassTerm getTermAt(int i) {
        return this.m_ClassList.get(i);
    }

    /** Parses {@code name} into a class value and binds it to its term in this hierarchy. */
    public ClassesValue createValueByName(String name, StringTable table) throws ClusException {
        ClassesValue val = new ClassesValue(name, table);
        ClassTerm term = this.getClassTerm(val);
        val.setClassTerm(term);
        return val;
    }

    /** @return per class index, whether that term is a leaf ({@code atBottomLevel()}) */
    public boolean[] getIsLeafVector() {
        boolean[] answer = new boolean[this.m_ClassList.size()];
        for (int classInd = 0; classInd < answer.length; ++classInd) {
            answer[classInd] = this.getTermAt(classInd).atBottomLevel();
        }
        return answer;
    }
}

