/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.pruning;

import java.util.ArrayList;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.heuristic.EncodingCost;
import si.ijs.kt.clus.model.test.NodeTest;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Tree pruner that greedily collapses internal nodes and keeps the tree whose
 * leaf partition yields the lowest encoding cost.
 *
 * <p>Starting from the full tree, each step tentatively collapses every node
 * whose children are all leaves, measures the resulting encoding cost through
 * {@link EncodingCost}, and permanently collapses the candidate with the
 * largest cost reduction. Steps are repeated until no candidate is left; the
 * cheapest tree seen along the way is then written back into the root.
 *
 * <p>The instance keeps per-run state in public fields and is therefore not
 * safe for concurrent use.
 */
public class EncodingCostPruning extends PruneTree {
    /** Encoding cost of the tree as it stands at the start of the current pruning step. */
    public double m_Ecc;
    /** Largest cost reduction found among the candidates of the current pruning step. */
    public double m_EccGain = Double.NEGATIVE_INFINITY;
    /** Lowest encoding cost observed so far in the current {@link #prune(ClusNode)} run. */
    public double m_BestEcc = Double.MAX_VALUE;
    /** Clone of the tree that achieved {@link #m_BestEcc}; {@code null} until one is recorded. */
    public ClusNode m_BestTreeSoFar;
    /** Candidate selected for collapsing in the current step; {@code null} if none was found. */
    public ClusNode m_BestNodeToPrune;
    /** Training data that is routed through the tree to form the leaf clusters. */
    public RowData m_Data;
    /** Encoding cost calculator shared by all evaluations. */
    public EncodingCost m_EC = new EncodingCost();

    /**
     * Stores the training data and hands its descriptive attributes to the
     * encoding cost calculator.
     *
     * @param data the training data used for all cost evaluations
     */
    @Override
    public void setTrainingData(RowData data) {
        this.m_Data = data;
        this.m_EC.setAttributes(this.m_Data.getSchema().getDescriptiveAttributes());
    }

    /**
     * @return always 1: this pruner produces a single result
     */
    @Override
    public int getNbResults() {
        return 1;
    }

    /**
     * Prunes the tree rooted at {@code node} in place and prints the resulting
     * leaf clusters to standard output.
     *
     * @param node root of the tree to prune
     * @throws ClusException declared for callers; propagated from {@link #doPrune(ClusNode)}
     */
    @Override
    public void prune(ClusNode node) throws ClusException {
        ClusLogger.info("Encoding cost pruning started");
        // Start every run from a clean slate. Without this, a second call on the
        // same instance would compare against (and possibly restore) the best
        // tree of the previous run.
        this.m_BestEcc = Double.MAX_VALUE;
        this.m_BestTreeSoFar = null;
        this.m_EccGain = Double.NEGATIVE_INFINITY;
        this.m_BestNodeToPrune = null;

        node.numberCompleteTree();
        int totalNbNodes = node.getTotalTreeSize();
        this.m_EC.initializeLogPMatrix(totalNbNodes);
        this.doPrune(node);
        ClusLogger.info("Encoding cost pruning resulted in the following clusters (1 per line):");
        this.printInstanceLabels(node, this.m_Data);
    }

    /**
     * Runs the greedy pruning steps on the tree rooted at {@code node} and then
     * writes the cheapest tree seen back into {@code node}.
     *
     * <p>Implemented as a loop rather than one recursive call per pruned node,
     * so the stack depth does not grow with the number of pruning steps. Unlike
     * {@link #prune(ClusNode)}, this method does not reset {@link #m_BestEcc}
     * or {@link #m_BestTreeSoFar}.
     *
     * @param node root of the tree to prune
     * @throws ClusException declared for callers and overriding subclasses
     */
    public void doPrune(ClusNode node) throws ClusException {
        while (true) {
            this.m_Ecc = this.calculateEncodingCost(node, this.m_Data);
            if (this.m_Ecc < this.m_BestEcc) {
                this.m_BestEcc = this.m_Ecc;
                this.m_BestTreeSoFar = node.cloneTreeWithVisitors();
            }
            this.traverseTreeAndRecordEncodingCostIfLeafChildren(node, node, this.m_Data);
            if (this.m_BestNodeToPrune == null) {
                break;
            }
            ClusLogger.info("Pruning node such that ECC drops with " + this.m_EccGain);
            this.m_BestNodeToPrune.makeLeaf();
            this.m_EccGain = Double.NEGATIVE_INFINITY;
            this.m_BestNodeToPrune = null;
        }
        this.restoreBestTree(node);
    }

    /**
     * Copies the test and the children of {@link #m_BestTreeSoFar} into {@code node}.
     * The children are attached as-is, not cloned.
     */
    private void restoreBestTree(ClusNode node) {
        if (this.m_BestTreeSoFar == null) {
            // No cost ever compared below the initial bound (for instance because
            // it was NaN), so there is nothing to restore; keep the tree as it is.
            ClusLogger.info("Encoding cost pruning found no tree to restore; tree left unchanged");
            return;
        }
        if (node.getNbChildren() > 0) {
            // The loop normally ends with the root collapsed to a leaf. If it ended
            // early (no candidate produced a comparable gain), drop the remaining
            // children first so the restored ones are not appended next to them.
            node.makeLeaf();
        }
        node.setTest(this.m_BestTreeSoFar.getTest());
        for (ClusNode child : this.m_BestTreeSoFar.getChildren()) {
            node.addChild(child);
        }
    }

    /**
     * Prints, for every leaf cluster, the key attribute value of each of its rows
     * on one line of standard output, separated by single spaces. A cluster
     * without rows produces an empty line.
     *
     * @param node root of the tree whose leaves define the clusters
     * @param data data routed through the tree
     * @return always 0
     */
    public int printInstanceLabels(ClusNode node, RowData data) {
        ArrayList<RowData> clusters = new ArrayList<RowData>();
        ArrayList<Integer> clusterIds = new ArrayList<Integer>();
        this.getLeafClusters(node, data, clusters, clusterIds);
        for (RowData cluster : clusters) {
            int nbRows = cluster.getNbRows();
            StringBuilder line = new StringBuilder();
            // Starting at row 0 inside the loop (instead of reading row 0 up front)
            // keeps an empty cluster from being indexed at all.
            for (int r = 0; r < nbRows; ++r) {
                if (r > 0) {
                    line.append(' ');
                }
                line.append(cluster.getSchema().getKeyAttribute()[0].getString(cluster.getTuple(r)));
            }
            line.append("\n");
            System.out.print(line);
        }
        return 0;
    }

    /**
     * Computes the encoding cost of the partition induced by the leaves of the
     * tree rooted at {@code node}.
     *
     * @param node root of the tree
     * @param data data routed through the tree; its row count is passed to the
     *             calculator as the number of sequences
     * @return the value reported by {@link EncodingCost#getEncodingCostValue()}
     */
    public double calculateEncodingCost(ClusNode node, RowData data) {
        ArrayList<RowData> clusters = new ArrayList<RowData>();
        ArrayList<Integer> clusterIds = new ArrayList<Integer>();
        this.getLeafClusters(node, data, clusters, clusterIds);
        this.m_EC.setClusters(clusters, clusterIds);
        this.m_EC.setNbSequences(data.getNbRows());
        return this.m_EC.getEncodingCostValue();
    }

    /**
     * Walks the subtree below {@code node} and evaluates, as a pruning candidate,
     * every node all of whose children are leaves. The best candidate is recorded
     * in {@link #m_BestNodeToPrune} and its gain in {@link #m_EccGain}.
     *
     * @param node     current node of the traversal
     * @param rootNode root of the whole tree, used for the cost evaluation
     * @param rootData data belonging to {@code rootNode}
     * @return 1 if {@code node} has no children, 0 otherwise
     */
    public int traverseTreeAndRecordEncodingCostIfLeafChildren(ClusNode node, ClusNode rootNode, RowData rootData) {
        int arity = node.getNbChildren();
        if (arity <= 0) {
            return 1;
        }
        int nbLeafChildren = 0;
        for (int i = 0; i < arity; ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            nbLeafChildren += this.traverseTreeAndRecordEncodingCostIfLeafChildren(child, rootNode, rootData);
        }
        if (nbLeafChildren == arity) {
            this.evaluateCollapse(node, rootNode, rootData, false);
        }
        return 0;
    }

    /**
     * Walks the subtree below {@code node} and evaluates every internal node as
     * a pruning candidate (children first), logging each evaluated cost.
     *
     * @param node     current node of the traversal
     * @param rootNode root of the whole tree, used for the cost evaluation
     * @param rootData data belonging to {@code rootNode}
     */
    public void traverseTreeAndRecordEncodingCost(ClusNode node, ClusNode rootNode, RowData rootData) {
        int arity = node.getNbChildren();
        if (arity <= 0) {
            return;
        }
        for (int i = 0; i < arity; ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            this.traverseTreeAndRecordEncodingCost(child, rootNode, rootData);
        }
        this.evaluateCollapse(node, rootNode, rootData, true);
    }

    /**
     * Temporarily collapses {@code node}, measures the encoding cost of the whole
     * tree, records the node if it beats the best gain so far, and re-attaches the
     * node's test and children.
     */
    private void evaluateCollapse(ClusNode node, ClusNode rootNode, RowData rootData, boolean verbose) {
        // Capture the test and children before makeLeaf() so they can be put back.
        ClusNode[] children = node.getChildren();
        NodeTest test = node.getTest();
        node.makeLeaf();
        double ecc = this.calculateEncodingCost(rootNode, rootData);
        if (verbose) {
            ClusLogger.info("new ecc = " + ecc);
        }
        double eccGain = this.m_Ecc - ecc;
        // A NaN gain never satisfies this comparison, so such a candidate is skipped.
        if (eccGain > this.m_EccGain) {
            this.m_EccGain = eccGain;
            this.m_BestNodeToPrune = node;
            if (verbose) {
                ClusLogger.info("better!");
            }
        }
        node.setTest(test);
        for (ClusNode child : children) {
            node.addChild(child);
        }
    }

    /**
     * Routes {@code data} down the tree and collects, per leaf, the subset that
     * reaches it together with the leaf's node ID minus one.
     *
     * @param node       current node
     * @param data       data reaching {@code node}
     * @param clusters   output list receiving one data subset per leaf
     * @param clusterIds output list receiving the matching cluster IDs, in the same order
     */
    public void getLeafClusters(ClusNode node, RowData data, ArrayList<RowData> clusters, ArrayList<Integer> clusterIds) {
        if (node.atBottomLevel()) {
            clusters.add(data);
            // IDs are shifted down by one; presumably node IDs are 1-based while the
            // calculator expects 0-based cluster indices. TODO confirm.
            clusterIds.add(Integer.valueOf(node.getID() - 1));
            return;
        }
        int arity = node.getNbChildren();
        for (int i = 0; i < arity; ++i) {
            RowData subset = data.applyWeighted(node.getTest(), i);
            this.getLeafClusters((ClusNode)node.getChild(i), subset, clusters, clusterIds);
        }
    }
}

