/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.featureRanking.relief.statistics;

import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;
import org.ejml.simple.SimpleMatrix;
import org.jblas.DoubleMatrix;
import org.jblas.Singular;
import org.jblas.Solve;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.ext.featureRanking.relief.ClusReliefFeatureRanking;
import si.ijs.kt.clus.ext.featureRanking.relief.nearestNeighbour.NearestNeighbour;
import si.ijs.kt.clus.ext.featureRanking.relief.statistics.Statistics;
import si.ijs.kt.clus.util.ClusUtil;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Relief statistic that scores each descriptive attribute by the "steepness" of the relation
 * between attribute distances and target distances to the nearest neighbours.
 *
 * <p>For every attribute a line through the origin, {@code targetDist = k * attrDist}, is fitted.
 * The fit uses either ordinary least squares (slope clamped to be non-negative) or total least
 * squares, and the slope {@code k} is the attribute's importance.
 *
 * <ul>
 *   <li>Micro mode ({@code isMacro == false}): each call to {@link #updateStatistics} fits one
 *       slope over the neighbours collected since the last {@link #resetTempFields}. The slopes are
 *       summed and then averaged in {@link #computeImportances}.
 *   <li>Macro mode ({@code isMacro == true}): each call to {@link #updateStatistics} stores only the
 *       mean attribute distance and mean target distance of the collected neighbours.
 *       {@link #computeImportances} then fits one slope over all of those stored means.
 * </ul>
 */
public class Steepness extends Statistics {
    /** Micro mode: summed slopes, indexed {@code [targetIndex + 1][neighbourCountIndex][attribute]}. */
    private final double[][][] m_Coefficient;
    /** Macro mode: mean attribute distance per update, indexed {@code [neighbourCountIndex][attribute][slot]}. */
    private final double[][][] mDistAttribute;
    /** Macro mode: mean target distance per update, indexed {@code [neighbourCountIndex][slot]}. */
    private final double[][] mDistTarget;
    // Running sums of attrDist^2 and attrDist * targetDist per attribute.
    // NOTE(review): accumulated and reset, but never read anywhere in this class.
    private final double[] tempNormX;
    private final double[] tempXtY;
    /** Target distance to each neighbour collected since the last reset. */
    private final double[] tempDistTarget;
    /** Attribute distance to each neighbour collected since the last reset, indexed {@code [attribute][neighbour]}. */
    private final double[][] tempDistAttribute;
    // NOTE(review): never written after construction in this class, so every entry is 0.
    // Confirm whether degenerate/constant fits were meant to be counted here.
    private final double[][] m_NbConstant;
    /** Whether slopes are fitted by total least squares (otherwise ordinary least squares). */
    private final boolean m_IsTLS;
    /** Whether importances are computed in macro mode (see class docs). */
    private final boolean m_IsMacro;
    /** Number of neighbours recorded in the temp buffers since the last reset. */
    private int m_NeighIndex;
    /** Macro mode: next free slot in {@link #mDistAttribute}, per {@code [neighbourCountIndex][attribute]}. */
    private final int[][] mWhereTo;
    /** Macro mode: next free slot in {@link #mDistTarget}, per neighbour-count index. */
    private final int[] mWhereToTarget;

    /**
     * Allocates all accumulators.
     *
     * @param relief owning Relief ranking; its {@code getMaxNbIterations()} sizes the macro-mode buffers
     * @param nbTargets size of the first dimension of the micro-mode slope accumulator
     * @param nbDiffNbNeighbours number of distinct neighbour-count settings tracked separately
     * @param nbDescriptiveAttributes number of descriptive attributes to score
     * @param maxNbNeighbours capacity of the per-update neighbour buffers
     * @param isTotalLeastSquares fit slopes by total least squares instead of ordinary least squares
     * @param isMacro use macro mode (fit once over per-update means) instead of micro mode
     */
    public Steepness(ClusReliefFeatureRanking relief, int nbTargets, int nbDiffNbNeighbours, int nbDescriptiveAttributes, int maxNbNeighbours, boolean isTotalLeastSquares, boolean isMacro) {
        this.initializeSuperFields(relief, nbDescriptiveAttributes);
        this.m_NbConstant = new double[nbDiffNbNeighbours][nbDescriptiveAttributes];
        this.m_Coefficient = new double[nbTargets][nbDiffNbNeighbours][nbDescriptiveAttributes];
        this.tempNormX = new double[nbDescriptiveAttributes];
        this.tempXtY = new double[nbDescriptiveAttributes];
        this.tempDistTarget = new double[maxNbNeighbours];
        this.tempDistAttribute = new double[nbDescriptiveAttributes][maxNbNeighbours];
        this.mDistAttribute = new double[nbDiffNbNeighbours][nbDescriptiveAttributes][relief.getMaxNbIterations()];
        this.mDistTarget = new double[nbDiffNbNeighbours][relief.getMaxNbIterations()];
        this.mWhereTo = new int[nbDiffNbNeighbours][nbDescriptiveAttributes];
        this.mWhereToTarget = new int[nbDiffNbNeighbours];
        this.m_IsTLS = isTotalLeastSquares;
        this.m_IsMacro = isMacro;
        this.m_NeighIndex = 0;
    }

    /**
     * Records the target distance and every attribute distance between {@code tuple} and one
     * neighbour in the next free slot of the temp buffers.
     *
     * <p>When {@code targetIndex >= 0} and {@code !isStdClassification}, the target distance is the
     * 1-D distance on target attribute {@code trueIndex}. Otherwise it is
     * {@code computeDistance(..., 1)}.
     *
     * @throws ClusException propagated from the distance computations
     */
    @Override
    public void updateTempStatistics(int targetIndex, boolean isStdClassification, DataTuple tuple, RowData data, NearestNeighbour neigh, double neighWeightNonnormalized, int trueIndex, int targetValue) throws ClusException {
        // Same neighbour tuple for the target and every attribute: look it up once.
        DataTuple neighbour = data.getTuple(neigh.getIndexInDataset());
        double targetDistance = targetIndex >= 0 && !isStdClassification
                ? this.mRelief.computeDistance1D(tuple, neighbour, this.mRelief.getTargetAttribute(trueIndex))
                : this.mRelief.computeDistance(tuple, neighbour, 1);
        this.tempDistTarget[this.m_NeighIndex] = targetDistance;
        for (int attrInd = 0; attrInd < this.m_NbDescriptiveAttrs; ++attrInd) {
            ClusAttrType attr = this.mRelief.getDescriptiveAttribute(attrInd);
            double distAttr = this.mRelief.computeDistance1D(tuple, neighbour, attr);
            this.tempDistAttribute[attrInd][this.m_NeighIndex] = distAttr;
            this.tempNormX[attrInd] += distAttr * distAttr;
            this.tempXtY[attrInd] += distAttr * targetDistance;
        }
        ++this.m_NeighIndex;
    }

    /**
     * Folds the neighbours collected since the last reset into the accumulators for
     * neighbour-count index {@code numNeighInd}.
     *
     * <p>In macro mode, the mean target distance and the mean distance of each attribute are
     * appended to the macro buffers. In micro mode, a slope is fitted per attribute and added to
     * the running sum.
     *
     * @throws RuntimeException if {@code targetIndex != -1} (only the overall ranking is supported)
     */
    @Override
    public void updateStatistics(int targetIndex, int numNeighInd, double sumNeighbourWeights) {
        if (targetIndex != -1) {
            throw new RuntimeException("Wrong target index: " + targetIndex);
        }
        double[] distT = Arrays.copyOf(this.tempDistTarget, this.m_NeighIndex);
        if (this.m_IsMacro) {
            this.mDistTarget[numNeighInd][this.mWhereToTarget[numNeighInd]] = ClusUtil.mean(distT);
            this.mWhereToTarget[numNeighInd]++;
        }
        for (int attrInd = 0; attrInd < this.m_NbDescriptiveAttrs; ++attrInd) {
            double[] distA = Arrays.copyOf(this.tempDistAttribute[attrInd], this.m_NeighIndex);
            if (this.m_IsMacro) {
                this.mDistAttribute[numNeighInd][attrInd][this.mWhereTo[numNeighInd][attrInd]] = ClusUtil.mean(distA);
                this.mWhereTo[numNeighInd][attrInd]++;
            } else {
                double k = this.m_IsTLS ? optimalViaTotalSquares(distA, distT) : optimalViaStandardSquares(distA, distT);
                // targetIndex is -1 here, so this always accumulates into slot 0.
                this.m_Coefficient[targetIndex + 1][numNeighInd][attrInd] += k;
            }
        }
    }

    /** Clears the per-update neighbour buffers and running sums so the next update starts empty. */
    @Override
    public void resetTempFields() {
        Arrays.fill(this.tempNormX, 0.0);
        Arrays.fill(this.tempXtY, 0.0);
        Arrays.fill(this.tempDistTarget, 0.0);
        for (double[] row : this.tempDistAttribute) {
            Arrays.fill(row, 0.0);
        }
        this.m_NeighIndex = 0;
    }

    /**
     * Returns the steepness importance of attribute {@code attrInd} for neighbour-count index
     * {@code nbNeighInd}.
     *
     * <p>Micro mode: summed slopes divided by
     * {@code successfulItearions[targetIndex + 1] - m_NbConstant[nbNeighInd][attrInd]}.
     *
     * <p>Macro mode: first checks that every macro slot has been filled, then fits one slope over
     * the stored per-update mean distances.
     *
     * @throws RuntimeException in macro mode, if the number of filled slots differs from the buffer size
     */
    @Override
    public double computeImportances(int targetIndex, int nbNeighInd, int attrInd, boolean isStdClassification, double[] successfulItearions) {
        if (!this.m_IsMacro) {
            return this.m_Coefficient[targetIndex + 1][nbNeighInd][attrInd] / (successfulItearions[targetIndex + 1] - this.m_NbConstant[nbNeighInd][attrInd]);
        }
        int filledAttr = this.mWhereTo[nbNeighInd][attrInd];
        double nbConstant = this.m_NbConstant[nbNeighInd][attrInd];
        int capacityAttr = this.mDistAttribute[nbNeighInd][attrInd].length;
        if ((double) filledAttr + nbConstant != (double) capacityAttr) {
            // nbConstant is a double: format it with %.0f (a %d here would throw IllegalFormatConversionException).
            throw new RuntimeException(String.format("Something went wrong with the indices: %d + %.0f != %d", filledAttr, nbConstant, capacityAttr));
        }
        if (this.mWhereToTarget[nbNeighInd] != this.mDistTarget[nbNeighInd].length) {
            throw new RuntimeException(String.format("Something went wrong with the indices: %d != %d", this.mWhereToTarget[nbNeighInd], this.mDistTarget[nbNeighInd].length));
        }
        double[] distA = Arrays.copyOf(this.mDistAttribute[nbNeighInd][attrInd], filledAttr);
        double[] distT = this.mDistTarget[nbNeighInd];
        return this.m_IsTLS ? optimalViaTotalSquares(distA, distT) : optimalViaStandardSquares(distA, distT);
    }

    /**
     * Total-least-squares slope of a line through the origin fitted to the points
     * {@code (xs[i], ys[i])}.
     *
     * <p>Takes the full SVD {@code A = U * diag(S) * V'} of the n-by-2 matrix {@code [xs ys]}.
     * Column 1 of {@code V} (the right singular vector for the smaller singular value under
     * LAPACK ordering) is the line's normal {@code (v0, v1)}, so the slope is {@code -v0 / v1}.
     * If {@code v1 == 0}, the matrix is printed to stdout as a diagnostic and the result is
     * non-finite.
     *
     * @param xs abscissae; presumably the same length as {@code ys}
     * @param ys ordinates
     */
    private static double optimalViaTotalSquares(double[] xs, double[] ys) {
        DoubleMatrix A = new DoubleMatrix(new double[][]{xs, ys}).transpose();
        DoubleMatrix[] svd = Singular.fullSVD(A);
        if (svd[2].get(1, 1) == 0.0) {
            System.out.println(A);
        }
        return -svd[2].get(0, 1) / svd[2].get(1, 1);
    }

    /**
     * Ordinary-least-squares slope {@code k} minimizing {@code ||k * xs - ys||^2}. The model has
     * no intercept, and the result is clamped to be non-negative.
     *
     * @param xs abscissae; presumably the same length as {@code ys}
     * @param ys ordinates
     */
    private static double optimalViaStandardSquares(double[] xs, double[] ys) {
        DoubleMatrix A = new DoubleMatrix(new double[][]{xs}).transpose();
        DoubleMatrix b = new DoubleMatrix(ys);
        DoubleMatrix solution = Solve.solveLeastSquares(A, b);
        return Math.max(solution.get(0, 0), 0.0);
    }

    /** Sets every cell of {@code m} in rows {@code [row0, row1)} and columns {@code [col0, col1)} to {@code value}. */
    public static void updateRange(SimpleMatrix m, int row0, int row1, int col0, int col1, double value) {
        for (int r = row0; r < row1; ++r) {
            for (int c = col0; c < col1; ++c) {
                m.set(r, c, value);
            }
        }
    }

    /**
     * Fills rows {@code [row0, row1)} and columns {@code [col0, col1)} of {@code m} with
     * {@code random.nextDouble()}, drawn in row-major order.
     */
    public static void updateRange(SimpleMatrix m, int row0, int row1, int col0, int col1, Random random) {
        for (int r = row0; r < row1; ++r) {
            for (int c = col0; c < col1; ++c) {
                m.set(r, c, random.nextDouble());
            }
        }
    }

    /** Copies all of {@code values} into {@code m}, placing {@code values(0, 0)} at {@code (row0, col0)}. */
    public static void updateRange(SimpleMatrix m, int row0, int col0, SimpleMatrix values) {
        for (int r = 0; r < values.numRows(); ++r) {
            for (int c = 0; c < values.numCols(); ++c) {
                m.set(r + row0, c + col0, values.get(r, c));
            }
        }
    }

    /** Returns {@code [0, 1, ..., i - 1]}; empty when {@code i <= 0}. */
    public static int[] range(int i) {
        // IntStream.range avoids the i - 1 overflow that rangeClosed(0, i - 1) has at Integer.MIN_VALUE.
        return IntStream.range(0, i).toArray();
    }

    /** Returns {@code [i, i + 1, ..., j - 1]}; empty when {@code j <= i}. */
    public static int[] range(int i, int j) {
        return IntStream.range(i, j).toArray();
    }
}

