/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.util.tools.optimization.gd;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import si.ijs.kt.clus.algo.rules.ClusRuleSet;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.section.SettingsRules;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.format.ClusFormat;
import si.ijs.kt.clus.util.format.ClusNumberFormat;
import si.ijs.kt.clus.util.tools.optimization.OptimizationAlgorithm;
import si.ijs.kt.clus.util.tools.optimization.OptimizationProblem;
import si.ijs.kt.clus.util.tools.optimization.gd.GDProblem;

public class GDAlgorithm
extends OptimizationAlgorithm {
    private GDProblem m_GDProbl;
    protected ArrayList<Double> m_weights;
    protected int m_earlyStopStep;
    protected int m_earlyStopStepsizeReducedNb;
    private ArrayList<Integer> m_iOscillatingWeights;
    private double m_minStepSizeReduction;
    private double[] m_prevChange;
    private double[] m_newChange;
    private int[] m_iPrevDimension;
    private int[] m_iNewDimension;

    public GDAlgorithm(ClusStatManager stat_mgr, OptimizationProblem.OptimizationParameter dataInformation, ClusRuleSet rset) {
        super(stat_mgr);
        this.m_GDProbl = new GDProblem(stat_mgr, dataInformation, rset);
        this.initGDForNewRunWithSamePredictions();
        this.m_earlyStopStep = 100;
        if (GDProblem.m_printGDDebugInformation) {
            String fname = this.getSettings().getData().getDataFile();
            try (PrintWriter wrt_pred = new PrintWriter(new OutputStreamWriter(new FileOutputStream(fname + ".gd-pred")));
                 PrintWriter wrt_true = new PrintWriter(new OutputStreamWriter(new FileOutputStream(fname + ".gd-true")));){
                this.m_GDProbl.printPredictionsToFile(wrt_pred);
                wrt_pred.close();
                this.m_GDProbl.printTrueValuesToFile(wrt_true);
                wrt_true.close();
            }
            catch (FileNotFoundException e) {
                e.printStackTrace();
                System.exit(1);
            }
        }
    }

    public void initGDForNewRunWithSamePredictions() {
        this.m_GDProbl.initGDForNewRunWithSamePredictions();
        this.m_weights = this.m_GDProbl.getInitialWeightVector();
        this.m_prevChange = null;
        this.m_iPrevDimension = null;
        this.m_iNewDimension = null;
        this.m_newChange = null;
        this.m_minStepSizeReduction = 1.0;
        this.m_earlyStopStepsizeReducedNb = 0;
        this.m_iOscillatingWeights = this.m_GDProbl.m_bannedWeights != null ? new ArrayList() : null;
    }

    @Override
    public ArrayList<Double> optimize() {
        ClusLogger.info("Gradient descent: Optimizing rule weights (" + this.getSettings().getRules().getOptGDMaxIter() + ") ");
        PrintWriter wrt_log = null;
        if (GDProblem.m_printGDDebugInformation) {
            try {
                wrt_log = new PrintWriter(new OutputStreamWriter(new FileOutputStream("gradDesc.log")));
            }
            catch (Exception e) {
                e.printStackTrace();
                ClusLogger.severe("Log file could not be opened. Logging omitted.");
            }
        }
        if (this.m_GDProbl.isClassifTask()) {
            try {
                throw new ClusException("Classification not yeat implemented for gradient descent. Skipping the optimization.");
            }
            catch (Exception s) {
                s.printStackTrace();
                return null;
            }
        }
        this.m_GDProbl.fullGradientComputation(this.m_weights);
        int nbOfIterations = 0;
        while (nbOfIterations < this.getSettings().getRules().getOptGDMaxIter()) {
            int iiGradient;
            boolean debugPrint;
            if ((double)nbOfIterations % Math.ceil((double)this.getSettings().getRules().getOptGDMaxIter() / 50.0) == 0.0) {
                ClusLogger.finer(String.format("%s %%", (double)(nbOfIterations * 100) / Math.ceil(this.getSettings().getRules().getOptGDMaxIter())));
            }
            if (nbOfIterations % this.m_earlyStopStep == 0 && this.getSettings().getRules().getOptGDEarlyStopAmount() > 0.0 && this.m_GDProbl.isEarlyStop(this.m_weights)) {
                if (GDProblem.m_printGDDebugInformation) {
                    wrt_log.println("Increase in test fitness. Reducing step size or stopping.");
                }
                ClusLogger.fine("Overfitting after " + nbOfIterations + " iterations.");
                if (!this.getSettings().getRules().isOptGDIsDynStepsize() && this.m_earlyStopStepsizeReducedNb < this.getSettings().getRules().getOptGDNbOfStepSizeReduce()) {
                    ++this.m_earlyStopStepsizeReducedNb;
                    this.m_GDProbl.dropStepSize(0.1);
                    this.m_GDProbl.restoreBestWeight(this.m_weights);
                    this.m_GDProbl.fullGradientComputation(this.m_weights);
                    ClusLogger.fine("Reducing step, continuing.");
                } else {
                    ClusLogger.info("Stopping.");
                    if (!GDProblem.m_printGDDebugInformation) break;
                    wrt_log.println("Early stopping detected after " + nbOfIterations + " iterations.");
                    break;
                }
            }
            this.OutputLog(nbOfIterations, wrt_log);
            int[] maxGradients = this.m_GDProbl.getMaxGradients(nbOfIterations);
            boolean oscillation = false;
            this.storeGradientsForOscillation(maxGradients);
            double[] valueChange = new double[maxGradients.length];
            if (this.getSettings().getGeneral().getVerbose() > 0) {
                // empty if block
            }
            if (debugPrint = false) {
                System.out.println("\nDEBUG: Computing covariances, total " + maxGradients.length);
            }
            for (iiGradient = 0; iiGradient < maxGradients.length; ++iiGradient) {
                if (debugPrint) {
                    System.out.print(iiGradient % 10);
                }
                this.m_GDProbl.computeCovariancesIfNeeded(maxGradients[iiGradient]);
                valueChange[iiGradient] = this.m_GDProbl.howMuchWeightChanges(maxGradients[iiGradient]);
                if (nbOfIterations >= 100) continue;
                oscillation = this.detectOscillation(iiGradient, valueChange[iiGradient]) || oscillation;
            }
            if (debugPrint) {
                System.out.println("\nDEBUG: Computing covariances ended");
            }
            if (oscillation && !this.getSettings().getRules().isOptGDIsDynStepsize()) {
                if (GDProblem.m_printGDDebugInformation) {
                    wrt_log.println("Detected oscillation, reducing step size of: " + this.m_GDProbl.m_stepSize);
                }
                if (debugPrint) {
                    ClusLogger.info("DEBUG: Detected oscillation on iteration " + nbOfIterations + ", reducing step size of: " + this.m_GDProbl.m_stepSize);
                }
                if (nbOfIterations > 10 && (this.getSettings().getRules().getOptGDMTGradientCombine().equals((Object)SettingsRules.OptimizationGDMTCombineGradient.MaxLoss) || this.getSettings().getRules().getOptGDMTGradientCombine().equals((Object)SettingsRules.OptimizationGDMTCombineGradient.MaxLossFast))) {
                    this.putOscillatingWeightsToBan(nbOfIterations);
                    continue;
                }
                this.reversePreviousStep();
                this.reduceStepSizeDueOscillation();
                continue;
            }
            if (nbOfIterations < 100) {
                this.storeTheOscillationData();
            }
            for (iiGradient = 0; iiGradient < maxGradients.length; ++iiGradient) {
                this.m_weights.set(maxGradients[iiGradient], this.m_weights.get(maxGradients[iiGradient]) + valueChange[iiGradient]);
            }
            this.m_GDProbl.modifyGradients(maxGradients, this.m_weights);
            ++nbOfIterations;
        }
        ClusLogger.info("Done!");
        if (this.getSettings().getRules().getOptGDEarlyStopAmount() > 0.0) {
            this.m_GDProbl.isEarlyStop(this.m_weights);
            this.m_GDProbl.restoreBestWeight(this.m_weights);
        }
        if (GDProblem.m_printGDDebugInformation) {
            wrt_log.println("The result of optimization");
        }
        this.OutputLog(nbOfIterations, wrt_log);
        if (GDProblem.m_printGDDebugInformation) {
            wrt_log.close();
        }
        return this.m_weights;
    }

    private void putOscillatingWeightsToBan(int iterationNb) {
        for (int iWeight = 0; iWeight < this.m_iOscillatingWeights.size(); ++iWeight) {
            this.m_GDProbl.m_bannedWeights[this.m_iOscillatingWeights.get((int)iWeight).intValue()] = iterationNb + 50;
        }
        this.m_iOscillatingWeights.clear();
    }

    private void reduceStepSizeDueOscillation() {
        this.m_GDProbl.dropStepSize(this.m_minStepSizeReduction * 0.99);
        this.m_minStepSizeReduction = 1.0;
    }

    private void storeTheOscillationData() {
        if (this.getSettings().getRules().isOptGDIsDynStepsize()) {
            return;
        }
        this.m_prevChange = (double[])this.m_newChange.clone();
        this.m_iPrevDimension = (int[])this.m_iNewDimension.clone();
    }

    private void reversePreviousStep() {
        for (int iiGradient = 0; iiGradient < this.m_iPrevDimension.length; ++iiGradient) {
            this.m_weights.set(this.m_iPrevDimension[iiGradient], this.m_weights.get(this.m_iPrevDimension[iiGradient]) - this.m_prevChange[iiGradient]);
        }
        this.m_GDProbl.fullGradientComputation(this.m_weights);
        this.m_prevChange = null;
        this.m_iPrevDimension = null;
        if (this.m_iOscillatingWeights != null) {
            this.m_iOscillatingWeights.clear();
        }
    }

    private void storeGradientsForOscillation(int[] maxGradients) {
        if (this.getSettings().getRules().isOptGDIsDynStepsize()) {
            return;
        }
        this.m_iNewDimension = (int[])maxGradients.clone();
        this.m_newChange = new double[maxGradients.length];
    }

    private boolean detectOscillation(int iiNewChange, double valueChange) {
        if (this.getSettings().getRules().isOptGDIsDynStepsize()) {
            return false;
        }
        boolean detectOscillation = false;
        this.m_newChange[iiNewChange] = valueChange;
        for (int iiPrevChange = 0; this.m_prevChange != null && iiPrevChange < this.m_prevChange.length; ++iiPrevChange) {
            double needReduction;
            if (this.m_iPrevDimension[iiPrevChange] != this.m_iNewDimension[iiNewChange]) continue;
            if (!(valueChange * this.m_prevChange[iiPrevChange] < 0.0) || !(Math.abs(valueChange) > Math.abs(this.m_prevChange[iiPrevChange]))) break;
            if (this.m_iOscillatingWeights != null) {
                this.m_iOscillatingWeights.add(this.m_iPrevDimension[iiPrevChange]);
            }
            if ((needReduction = Math.abs(this.m_prevChange[iiPrevChange]) / Math.abs(valueChange)) < this.m_minStepSizeReduction) {
                this.m_minStepSizeReduction = needReduction;
            }
            detectOscillation = true;
            break;
        }
        return detectOscillation;
    }

    public void OutputLog(int iterNro, PrintWriter wrt) {
        if (!GDProblem.m_printGDDebugInformation) {
            return;
        }
        ClusNumberFormat fr = ClusFormat.SIX_AFTER_DOT;
        double trainingFitness = this.m_GDProbl.calcFitness(this.m_weights, this.m_GDProbl.m_ruleSet);
        double testFitness = 0.0;
        if (this.getSettings().getRules().getOptGDEarlyStopAmount() > 0.0) {
            testFitness = this.m_GDProbl.m_earlyStopProbl.calcFitness(this.m_weights, this.m_GDProbl.m_ruleSet);
        }
        wrt.print("Iteration " + iterNro + " ");
        if (this.getSettings().getRules().isOptGDIsDynStepsize()) {
            wrt.print("Step size: " + this.m_GDProbl.m_stepSize + " ");
        }
        wrt.print("(" + fr.format(trainingFitness) + ", " + fr.format(testFitness) + "): ");
        for (int i = 0; i < this.m_weights.size(); ++i) {
            wrt.print(fr.format(this.m_weights.get(i)) + "\t");
        }
        wrt.print("\n");
    }

    public double getBestFitness() {
        return this.m_GDProbl.getBestFitness();
    }

    @Override
    public void postProcess(ClusRuleSet rset) {
        this.m_GDProbl.changeRuleSetToUndoNormNormalization(rset);
    }
}

