package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/* loaded from: input_file:org/tribuo/math/optimisers/AdaDelta.class */
/**
 * An AdaDelta stochastic gradient optimiser (Zeiler, 2012, "ADADELTA: An Adaptive Learning Rate
 * Method").
 * <p>
 * For each parameter tensor it keeps two exponentially decaying running averages: one of the
 * squared gradients and one of the squared updates. Each incoming gradient is rescaled by
 * {@code sqrt(E[dx^2] + epsilon) / sqrt(E[g^2] + epsilon)}, so no explicit learning rate is used.
 * <p>
 * Instances hold mutable per-parameter state and perform no synchronisation.
 * {@link #initialise(Parameters)} must be called before {@link #step(Tensor[], double)}.
 */
public class AdaDelta implements StochasticGradientOptimiser {

    /** Default decay rate of the running averages. */
    private static final double DEFAULT_RHO = 0.95d;

    /** Default constant added inside the square roots. */
    private static final double DEFAULT_EPSILON = 1.0E-6d;

    @Config(description = "Momentum value.")
    private double rho = DEFAULT_RHO;

    @Config(description = "Epsilon for numerical stability.")
    private double epsilon = DEFAULT_EPSILON;

    /** Running average of squared gradients, one tensor per parameter; null until initialised. */
    private Tensor[] gradsSquared;

    /** Running average of squared updates, one tensor per parameter; null until initialised. */
    private Tensor[] velocitySquared;

    /**
     * Constructs an AdaDelta optimiser.
     *
     * @param rho     decay rate applied to both running averages on every step.
     * @param epsilon constant added inside the square roots for numerical stability.
     */
    public AdaDelta(double rho, double epsilon) {
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Constructs an AdaDelta optimiser with {@code rho = 0.95}.
     *
     * @param epsilon constant added inside the square roots for numerical stability.
     */
    public AdaDelta(double epsilon) {
        this(DEFAULT_RHO, epsilon);
    }

    /** Constructs an AdaDelta optimiser with {@code rho = 0.95} and {@code epsilon = 1e-6}. */
    public AdaDelta() {
        this(DEFAULT_RHO, DEFAULT_EPSILON);
    }

    /**
     * Allocates the running-average state as empty copies of the supplied parameters.
     *
     * @param parameters the parameters this optimiser will update.
     */
    @Override
    public void initialise(Parameters parameters) {
        this.gradsSquared = parameters.getEmptyCopy();
        this.velocitySquared = parameters.getEmptyCopy();
    }

    /**
     * Converts the supplied gradients into AdaDelta updates in place and advances the running
     * averages.
     *
     * @param updates the gradients, one tensor per parameter; each is overwritten with its update.
     * @param weight  not used by this optimiser.
     * @return {@code updates}, now holding the scaled updates.
     * @throws IllegalStateException if called before {@link #initialise(Parameters)} or after
     *                               {@link #reset()}.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        if (gradsSquared == null || velocitySquared == null) {
            throw new IllegalStateException(
                    "AdaDelta.step called before initialise(Parameters), or after reset()");
        }
        // Snapshot the hyperparameters once rather than re-reading fields inside every lambda call.
        final double decay = rho;
        final double oneMinusDecay = 1.0d - rho;
        final double eps = epsilon;
        for (int i = 0; i < updates.length; i++) {
            Tensor update = updates[i];
            // E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
            gradsSquared[i].scaleInPlace(decay);
            gradsSquared[i].intersectAndAddInPlace(update, g -> g * g * oneMinusDecay);
            // update <- g * sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps); E[dx^2] is still last step's value.
            update.hadamardProductInPlace(velocitySquared[i], v -> Math.sqrt(v + eps));
            update.hadamardProductInPlace(gradsSquared[i], s -> 1.0d / Math.sqrt(s + eps));
            // E[dx^2] <- rho * E[dx^2] + (1 - rho) * dx^2, using the update just computed.
            velocitySquared[i].scaleInPlace(decay);
            velocitySquared[i].intersectAndAddInPlace(update, dx -> dx * dx * oneMinusDecay);
        }
        return updates;
    }

    @Override
    public String toString() {
        return "AdaDelta(rho=" + this.rho + ",epsilon=" + this.epsilon + ")";
    }

    /** Discards the accumulated state; {@link #initialise(Parameters)} must be called again. */
    @Override
    public void reset() {
        this.gradsSquared = null;
        this.velocitySquared = null;
    }

    /**
     * Returns a new optimiser with the same {@code rho} and {@code epsilon} but no accumulated state.
     *
     * @return an uninitialised copy of this optimiser.
     */
    @Override
    public AdaDelta copy() {
        return new AdaDelta(this.rho, this.epsilon);
    }

    /**
     * Returns the configuration provenance of this optimiser.
     *
     * @return a provenance recording this object's configured fields.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, kept only for backward compatibility.
     *
     * @return the same value as {@link #getProvenance()}.
     * @deprecated use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m26getProvenance() {
        return getProvenance();
    }
}
