package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/* loaded from: input_file:org/tribuo/math/optimisers/Adam.class */
/**
 * Adam stochastic gradient optimiser (Kingma and Ba, "Adam: A Method for Stochastic
 * Optimization").
 * <p>
 * The optimiser keeps an exponentially decaying average of the gradients (the first moment)
 * and of the squared gradients (the second moment), one {@link Tensor} of each per parameter
 * tensor. Each call to {@link #step(Tensor[], double)} rewrites the supplied gradients in place
 * as {@code learningRate * firstMoment / (sqrt(secondMoment) + epsilon)}, where
 * {@code learningRate} already carries the bias correction for both moments.
 * <p>
 * The moment buffers are created by {@link #initialise(Parameters)} and dropped by
 * {@link #reset()}; {@code step} must not be called while they are absent.
 */
public class Adam implements StochasticGradientOptimiser {

    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.001d;

    /** First moment decay rate used by the convenience constructors. */
    private static final double DEFAULT_BETA_ONE = 0.9d;

    /** Second moment decay rate used by the convenience constructors. */
    private static final double DEFAULT_BETA_TWO = 0.999d;

    /** Numerical stability constant used by the no-argument constructor. */
    private static final double DEFAULT_EPSILON = 1.0E-6d;

    @Config(description = "Learning rate to scale the gradients by.")
    private double initialLearningRate;

    @Config(description = "The beta one parameter.")
    private double betaOne;

    @Config(description = "The beta two parameter.")
    private double betaTwo;

    @Config(description = "Epsilon for numerical stability.")
    private double epsilon;

    /** Number of calls to {@link #step(Tensor[], double)} since the last initialise/reset. */
    private int iterations;

    /** Decaying average of the gradients, one tensor per parameter tensor. */
    private Tensor[] firstMoment;

    /** Decaying average of the squared gradients, one tensor per parameter tensor. */
    private Tensor[] secondMoment;

    /**
     * Constructs an Adam optimiser with fully specified hyperparameters.
     *
     * @param initialLearningRate the base learning rate, before bias correction.
     * @param betaOne the decay rate of the first moment estimate.
     * @param betaTwo the decay rate of the second moment estimate.
     * @param epsilon the constant added to the root of the second moment for stability.
     */
    public Adam(double initialLearningRate, double betaOne, double betaTwo, double epsilon) {
        this.initialLearningRate = initialLearningRate;
        this.betaOne = betaOne;
        this.betaTwo = betaTwo;
        this.epsilon = epsilon;
        this.iterations = 0;
    }

    /**
     * Constructs an Adam optimiser with betaOne = 0.9 and betaTwo = 0.999.
     *
     * @param initialLearningRate the base learning rate, before bias correction.
     * @param epsilon the constant added to the root of the second moment for stability.
     */
    public Adam(double initialLearningRate, double epsilon) {
        this(initialLearningRate, DEFAULT_BETA_ONE, DEFAULT_BETA_TWO, epsilon);
    }

    /**
     * Constructs an Adam optimiser with learning rate 0.001, betaOne 0.9, betaTwo 0.999 and
     * epsilon 1e-6.
     */
    public Adam() {
        this(DEFAULT_LEARNING_RATE, DEFAULT_BETA_ONE, DEFAULT_BETA_TWO, DEFAULT_EPSILON);
    }

    /**
     * Allocates the moment buffers from empty copies of the parameters and restarts the
     * iteration counter.
     *
     * @param parameters the parameters which will be optimised.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void initialise(Parameters parameters) {
        this.firstMoment = parameters.getEmptyCopy();
        this.secondMoment = parameters.getEmptyCopy();
        this.iterations = 0;
    }

    /**
     * Folds the gradients into the moment estimates and overwrites each gradient tensor with
     * the Adam update.
     *
     * @param updates the gradients, one tensor per parameter tensor; modified in place.
     * @param weight the example weight; not used by this optimiser.
     * @return the {@code updates} array, now holding the Adam updates.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public Tensor[] step(Tensor[] updates, double weight) {
        this.iterations++;

        // Bias correction for both moments folded into the step size:
        // lr_t = lr * sqrt(1 - betaTwo^t) / (1 - betaOne^t).
        final double learningRate = this.initialLearningRate
                * Math.sqrt(1.0d - Math.pow(this.betaTwo, this.iterations))
                / (1.0d - Math.pow(this.betaOne, this.iterations));

        // The operators do not depend on the loop index, so build them once.
        final DoubleUnaryOperator scaleByLearningRate = m -> m * learningRate;
        final DoubleUnaryOperator firstMomentIncrement = g -> g * (1.0d - this.betaOne);
        final DoubleUnaryOperator secondMomentIncrement = g -> g * g * (1.0d - this.betaTwo);
        // Adam divides by (sqrt(v) + epsilon). The division is expressed as a multiplication by
        // the reciprocal, assuming hadamardProductInPlace multiplies each element of the
        // receiver by f(other) as its name indicates -- TODO confirm against Tensor.
        final DoubleUnaryOperator inverseRootSecondMoment = v -> 1.0d / (Math.sqrt(v) + this.epsilon);

        for (int i = 0; i < updates.length; i++) {
            // m = betaOne * m + (1 - betaOne) * g
            this.firstMoment[i].scaleInPlace(this.betaOne);
            this.firstMoment[i].intersectAndAddInPlace(updates[i], firstMomentIncrement);

            // v = betaTwo * v + (1 - betaTwo) * g^2
            this.secondMoment[i].scaleInPlace(this.betaTwo);
            this.secondMoment[i].intersectAndAddInPlace(updates[i], secondMomentIncrement);

            // Zero the gradient values, then rebuild the tensor as lr_t * m / (sqrt(v) + epsilon).
            updates[i].scaleInPlace(0.0d);
            updates[i].intersectAndAddInPlace(this.firstMoment[i], scaleByLearningRate);
            updates[i].hadamardProductInPlace(this.secondMoment[i], inverseRootSecondMoment);
        }
        return updates;
    }

    /**
     * Describes the optimiser and its hyperparameters.
     *
     * @return a string of the form {@code Adam(learningRate=..,betaOne=..,betaTwo=..,epsilon=..)}.
     */
    @Override
    public String toString() {
        return "Adam(learningRate=" + this.initialLearningRate + ",betaOne=" + this.betaOne + ",betaTwo=" + this.betaTwo + ",epsilon=" + this.epsilon + ")";
    }

    /**
     * Drops the moment buffers and restarts the iteration counter. {@link #initialise(Parameters)}
     * must be called again before the next {@link #step(Tensor[], double)}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void reset() {
        this.firstMoment = null;
        this.secondMoment = null;
        this.iterations = 0;
    }

    /**
     * Creates a new optimiser with the same hyperparameters and no optimisation state.
     *
     * @return a fresh, uninitialised copy of this optimiser.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public Adam copy() {
        return new Adam(this.initialLearningRate, this.betaOne, this.betaTwo, this.epsilon);
    }

    /**
     * Builds the provenance describing this optimiser's configuration.
     * <p>
     * This restores the method name which the decompiler recorded as the original
     * ("renamed from: getProvenance").
     * NOTE(review): add {@code @Override} once it is confirmed that the supertype declares
     * {@code getProvenance()}; that declaration is not visible from this file.
     *
     * @return the provenance of this optimiser.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, kept so existing callers of the
     * mangled name continue to work.
     *
     * @return the provenance of this optimiser.
     * @deprecated use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m31getProvenance() {
        return getProvenance();
    }
}
