/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.model.test;

import java.util.Arrays;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.model.test.NodeTest;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.ClusRandom;

/**
 * Binary (arity 2) test on a nominal attribute that checks membership of the attribute value in a
 * fixed subset of value indices.
 *
 * <p>Branch 0 is taken when the value is one of {@code m_Values}; branch 1 is taken otherwise. The
 * value index {@code attr.getNbValues()} (stored in {@code m_MissIndex}) is handled separately as
 * the "missing" value.
 */
public class SubsetTest
extends NodeTest {
    public static final long serialVersionUID = 1L;
    /** Attribute value indices belonging to the subset (branch 0). */
    protected int[] m_Values;
    /** The nominal attribute this test is evaluated on. */
    protected NominalAttrType m_Type;
    /**
     * Probability threshold used by {@link #nominalPredict(int)} for missing values.
     * NOTE(review): this field is declared here, but the constructor calls the inherited
     * setPosFreq(...). Confirm that setPosFreq writes this field; otherwise it stays 0.0.
     */
    protected double m_PosFreq;
    /** Value index treated as "missing"; set to {@code attr.getNbValues()} in the constructors. */
    protected int m_MissIndex;

    /**
     * Creates a subset test from a membership mask.
     *
     * @param attr the nominal attribute to test
     * @param nb the number of {@code true} entries in {@code isin} (size of the subset)
     * @param isin membership mask indexed by attribute value index
     * @param posfreq frequency passed to {@code setPosFreq}
     */
    public SubsetTest(NominalAttrType attr, int nb, boolean[] isin, double posfreq) {
        this.m_Type = attr;
        this.setArity(2);
        this.setPosFreq(posfreq);
        this.m_Values = this.initValues(nb, isin);
        this.m_MissIndex = attr.getNbValues();
    }

    /**
     * Creates a subset test with {@code nb} value slots, all initialised to index 0. The slots are
     * meant to be filled with {@link #setValue(int, int)}.
     *
     * @param attr the nominal attribute to test
     * @param nb the number of values in the subset
     */
    public SubsetTest(NominalAttrType attr, int nb) {
        this.setArity(2);
        this.m_Type = attr;
        this.m_Values = new int[nb];
        this.m_MissIndex = attr.getNbValues();
    }

    @Override
    public ClusAttrType getType() {
        return this.m_Type;
    }

    /** Sets the tested attribute; {@code type} must be a {@link NominalAttrType}. */
    @Override
    public void setType(ClusAttrType type) {
        this.m_Type = (NominalAttrType)type;
    }

    /**
     * Returns a readable form: {@code name = v} for one value, {@code name in {v1,v2}} for several,
     * and {@code name in ?} for an empty subset.
     */
    @Override
    public String getString() {
        if (this.m_Values.length == 1) {
            return this.m_Type.getName() + " = " + this.m_Type.getValue(this.m_Values[0]);
        }
        StringBuilder buffer = new StringBuilder();
        buffer.append(this.m_Type.getName());
        if (this.m_Values.length == 0) {
            buffer.append(" in ?");
        } else {
            buffer.append(" in {");
            for (int i = 0; i < this.m_Values.length; ++i) {
                if (i != 0) {
                    buffer.append(",");
                }
                buffer.append(this.m_Type.getValue(this.m_Values[i]));
            }
            buffer.append("}");
        }
        return buffer.toString();
    }

    /**
     * Returns a Python expression for this test applied to {@code xsElement}. Values are quoted as
     * string literals. An empty subset yields {@code xs in ?}, which is not valid Python.
     *
     * @param xsElement the Python expression denoting the attribute value
     */
    @Override
    public String getPythonString(String xsElement) {
        if (this.m_Values.length == 1) {
            return xsElement + " == '" + this.m_Type.getValue(this.m_Values[0]) + "'";
        }
        StringBuilder buffer = new StringBuilder();
        buffer.append(xsElement);
        if (this.m_Values.length == 0) {
            buffer.append(" in ?");
        } else {
            buffer.append(" in (");
            for (int i = 0; i < this.m_Values.length; ++i) {
                if (i != 0) {
                    buffer.append(",");
                }
                buffer.append('\'').append(this.m_Type.getValue(this.m_Values[i])).append('\'');
            }
            buffer.append(")");
        }
        return buffer.toString();
    }

    /** Returns {@code true} when the subset contains at least one value. */
    @Override
    public boolean hasConstants() {
        return this.m_Values.length > 0;
    }

    /** Returns the number of value indices in the subset. */
    public int getNbValues() {
        return this.m_Values.length;
    }

    /** Returns the {@code i}-th value index of the subset. */
    public int getValue(int i) {
        return this.m_Values[i];
    }

    /** Stores value index {@code val} in subset slot {@code idx}. */
    public void setValue(int idx, int val) {
        this.m_Values[idx] = val;
    }

    /**
     * Returns {@code true} when {@code test} is a {@code SubsetTest} on the same attribute (by
     * reference) with the same value indices in the same order.
     *
     * <p>Returns {@code false} for {@code null} and for other {@link NodeTest} kinds.
     */
    @Override
    public boolean equals(NodeTest test) {
        // Check the runtime type before casting: a different NodeTest subclass may test the same
        // attribute, and a blind cast would throw ClassCastException. This also rejects null.
        if (!(test instanceof SubsetTest)) {
            return false;
        }
        if (this.m_Type != test.getType()) {
            return false;
        }
        return Arrays.equals(this.m_Values, ((SubsetTest)test).m_Values);
    }

    /** Hash combining the attribute index, the sum of the value indices and the subset size. */
    @Override
    public int hashCode() {
        int code = this.m_Type.getIndex() * 1000;
        for (int i = 0; i < this.m_Values.length; ++i) {
            code += this.m_Values[i];
        }
        return code + this.m_Values.length;
    }

    /**
     * Returns the branch for a nominal value index.
     *
     * <p>A missing value ({@code m_MissIndex}) is sent to branch 0 with probability
     * {@code m_PosFreq}, using {@code ClusRandom.nextDouble(0)}, and to branch 1 otherwise.
     *
     * @return 0 if {@code value} is in the subset, 1 otherwise
     */
    @Override
    public int nominalPredict(int value) {
        if (value == this.m_MissIndex) {
            return ClusRandom.nextDouble(0) < this.m_PosFreq ? 0 : 1;
        }
        for (int i = 0; i < this.m_Values.length; ++i) {
            if (this.m_Values[i] != value) continue;
            return 0;
        }
        return 1;
    }

    /**
     * Returns the branch for a nominal value index, with explicit handling of missing values.
     *
     * @return 0 if {@code value} is in the subset and 1 if not. For a missing value, returns 2 when
     *     {@code hasUnknownBranch()} is true and -1 otherwise.
     */
    @Override
    public int nominalPredictWeighted(int value) {
        if (value == this.m_MissIndex) {
            return this.hasUnknownBranch() ? 2 : -1;
        }
        for (int i = 0; i < this.m_Values.length; ++i) {
            if (this.m_Values[i] != value) continue;
            return 0;
        }
        return 1;
    }

    /** Reads the tuple's nominal value for this attribute and delegates to {@link #nominalPredictWeighted(int)}. */
    @Override
    public int predictWeighted(DataTuple tuple) {
        int val = this.m_Type.getNominal(tuple);
        return this.nominalPredictWeighted(val);
    }

    /**
     * Returns the test describing branch {@code i}.
     *
     * <p>For branch 0 this test itself is returned. For any other {@code i}, a new test on the
     * complementary subset is returned, with pos-freq {@code 1 - getPosFreq()}.
     */
    @Override
    public NodeTest getBranchTest(int i) {
        if (i == 0) {
            return this;
        }
        int pos = 0;
        int nb = this.m_Type.getNbValues() - this.getNbValues();
        SubsetTest test = new SubsetTest(this.m_Type, nb);
        boolean[] isin = this.getIsInArray();
        for (int j = 0; j < isin.length; ++j) {
            if (isin[j]) continue;
            test.setValue(pos++, j);
        }
        test.setPosFreq(1.0 - this.getPosFreq());
        return test;
    }

    /**
     * Combines this test with {@code other} by intersecting the subsets.
     *
     * @return a new {@code SubsetTest} on the shared values, with pos-freq set to the minimum of
     *     both. Returns {@code null} if the attributes differ or {@code other} is not a
     *     {@code SubsetTest}.
     */
    @Override
    public NodeTest simplifyConjunction(NodeTest other) {
        if (this.getType() != other.getType()) {
            return null;
        }
        if (other instanceof SubsetTest) {
            SubsetTest oset = (SubsetTest)other;
            boolean[] isin_me = this.getIsInArray();
            boolean[] isin_other = oset.getIsInArray();
            // First pass: size of the intersection, needed to allocate the new test.
            int count = 0;
            for (int i = 0; i < isin_me.length; ++i) {
                if (!isin_me[i] || !isin_other[i]) continue;
                ++count;
            }
            // Second pass: fill the new test with the shared value indices in ascending order.
            int pos = 0;
            SubsetTest test = new SubsetTest(this.m_Type, count);
            for (int i = 0; i < isin_me.length; ++i) {
                if (!isin_me[i] || !isin_other[i]) continue;
                test.setValue(pos++, i);
            }
            test.setPosFreq(Math.min(this.getPosFreq(), oset.getPosFreq()));
            return test;
        }
        return null;
    }

    /**
     * Returns a membership mask of length {@code m_Type.getNbValues()}, with {@code true} at each
     * value index contained in the subset.
     */
    public boolean[] getIsInArray() {
        boolean[] res = new boolean[this.m_Type.getNbValues()];
        for (int i = 0; i < this.getNbValues(); ++i) {
            res[this.getValue(i)] = true;
        }
        return res;
    }

    /**
     * Collects the indices of the {@code true} entries of {@code isin} into an array of length
     * {@code nb}.
     *
     * <p>This is best-effort. If {@code nb} is too small, or {@code isin} is {@code null}, the
     * failure is logged and the partially filled array is returned instead of being propagated.
     */
    private int[] initValues(int nb, boolean[] isin) {
        int i = 0;
        int[] values = new int[nb];
        try {
            for (int j = 0; j < isin.length; ++j) {
                if (!isin[j]) continue;
                values[i++] = j;
            }
        }
        catch (Exception e) {
            // Keep the deliberate best-effort behaviour, but do not lose the cause.
            ClusLogger.info("initValues failed: " + e);
            ClusLogger.info("nb: " + nb);
            ClusLogger.info("isin: " + Arrays.toString(isin));
        }
        return values;
    }
}

