package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/* loaded from: input_file:org/tribuo/math/optimisers/SGD.class */
/**
 * Base class for stochastic gradient descent optimisers, with optional standard or Nesterov
 * momentum.
 * <p>
 * Subclasses supply the learning-rate schedule through {@link #learningRate()} and a short
 * schedule name through {@link #sgdType()}. Instances are stateful: {@link #step} increments
 * {@link #iteration} and, when momentum is enabled, mutates per-parameter momentum buffers, with
 * no synchronization.
 */
public abstract class SGD implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(SGD.class.getName());

    @Config(mandatory = true, description = "Initial learning rate.")
    protected double initialLearningRate;

    @Config(mandatory = true, description = "Momentum type to use.")
    protected Momentum useMomentum;

    @Config(description = "Momentum scaling factor.")
    protected double rho = 0.0;

    /** Number of {@link #step} calls since construction or the last {@link #reset()}. */
    protected int iteration = 0;

    /**
     * Per-parameter momentum accumulators. Only allocated by {@link #initialise} when momentum is
     * enabled; cleared by {@link #reset()}.
     */
    private Tensor[] momentum;

    /** The momentum variants supported by {@link SGD}. */
    public enum Momentum {
        /** No momentum: the update is the gradient scaled by the learning rate. */
        NONE,
        /** Standard momentum: the update is the scaled momentum accumulator. */
        STANDARD,
        /** Nesterov momentum: the update is the scaled gradient plus the scaled, decayed accumulator. */
        NESTEROV
    }

    /**
     * Constructs an SGD optimiser without momentum.
     *
     * @param learningRate The initial learning rate.
     */
    public SGD(double learningRate) {
        this(learningRate, 0.0, Momentum.NONE);
    }

    /**
     * Constructs an SGD optimiser with the requested momentum variant.
     *
     * @param learningRate The initial learning rate.
     * @param rho The momentum scaling factor.
     * @param useMomentum The momentum variant to apply.
     */
    public SGD(double learningRate, double rho, Momentum useMomentum) {
        this.initialLearningRate = learningRate;
        this.useMomentum = useMomentum;
        this.rho = rho;
    }

    /**
     * No-argument constructor. Presumably used by the configuration system, which injects the
     * {@code @Config} fields after construction.
     */
    public SGD() {
    }

    /**
     * Prepares the optimiser for a training run. When momentum is enabled this allocates one
     * momentum buffer per parameter tensor, using {@link Parameters#getEmptyCopy()}.
     *
     * @param parameters The parameters that will be optimised.
     */
    @Override
    public void initialise(Parameters parameters) {
        if (useMomentum != Momentum.NONE) {
            momentum = parameters.getEmptyCopy();
        }
    }

    /**
     * Converts the supplied gradients into parameter updates, in place.
     * <p>
     * With learning rate {@code lr} (from {@link #learningRate()}), example weight {@code w},
     * gradient {@code g} and momentum buffer {@code v}:
     * <ul>
     *   <li>{@code NONE}: {@code update = lr * w * g}</li>
     *   <li>{@code STANDARD}: {@code v = rho * v + g; update = lr * w * v}</li>
     *   <li>{@code NESTEROV}: {@code v = rho * v + g; update = lr * w * g + lr * w * rho * v}</li>
     * </ul>
     * Element-wise semantics for sparse tensors follow {@code Tensor.intersectAndAddInPlace}.
     *
     * @param updates The gradients, one tensor per parameter. Overwritten with the updates.
     * @param weight The weight applied to this step's gradients.
     * @return The same {@code updates} array, now holding the scaled updates.
     * @throws IllegalStateException If momentum is enabled but {@link #initialise} has not been
     *     called since construction or the last {@link #reset()}.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        if (useMomentum != Momentum.NONE && momentum == null) {
            throw new IllegalStateException(
                    "Momentum buffers are not initialised, call initialise(Parameters) before step.");
        }
        iteration++;
        double learningRate = learningRate();
        // Operand order matches the original implementation so results are bit-identical.
        DoubleUnaryOperator learningRateFunc = a -> a * learningRate * weight;
        DoubleUnaryOperator nesterovFunc = a -> a * learningRate * weight * rho;

        for (int i = 0; i < updates.length; i++) {
            switch (useMomentum) {
                case STANDARD:
                    // v = rho * v + g
                    momentum[i].scaleInPlace(rho);
                    momentum[i].intersectAndAddInPlace(updates[i]);
                    // update = lr * w * v (clear the gradient, then add the scaled momentum)
                    updates[i].scaleInPlace(0.0);
                    updates[i].intersectAndAddInPlace(momentum[i], learningRateFunc);
                    break;
                case NESTEROV:
                    // v = rho * v + g
                    momentum[i].scaleInPlace(rho);
                    momentum[i].intersectAndAddInPlace(updates[i]);
                    // update = lr * w * g + lr * w * rho * v
                    updates[i].scaleInPlace(weight * learningRate);
                    updates[i].intersectAndAddInPlace(momentum[i], nesterovFunc);
                    break;
                case NONE:
                default:
                    updates[i].scaleInPlace(weight * learningRate);
                    break;
            }
        }
        return updates;
    }

    /**
     * Computes the learning rate for the current step. Called once per {@link #step}, after
     * {@link #iteration} has been incremented.
     *
     * @return The learning rate.
     */
    public abstract double learningRate();

    /**
     * Short name of this optimiser's learning-rate schedule, used by {@link #toString()}.
     *
     * @return The schedule name.
     */
    protected abstract String sgdType();

    /**
     * Describes the optimiser, including {@link #rho} when momentum is enabled. An unconfigured
     * ({@code null}) momentum type is rendered as the no-momentum form.
     */
    @Override
    public String toString() {
        Momentum type = (useMomentum == null) ? Momentum.NONE : useMomentum;
        switch (type) {
            case STANDARD:
                return "SGD+Momentum(type=" + sgdType() + ",initialLearningRate="
                        + initialLearningRate + ",rho=" + rho + ")";
            case NESTEROV:
                return "SGD+NesterovMomentum(type=" + sgdType() + ",initialLearningRate="
                        + initialLearningRate + ",rho=" + rho + ")";
            case NONE:
            default:
                return "SGD(type=" + sgdType() + ",initialLearningRate=" + initialLearningRate + ")";
        }
    }

    /** Discards the momentum buffers and restarts the iteration counter. */
    @Override
    public void reset() {
        momentum = null;
        iteration = 0;
    }

    /**
     * Returns the configuration-based provenance of this optimiser.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Decompiler-generated alias kept for backward compatibility.
     *
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m39getProvenance() {
        return getProvenance();
    }

    /**
     * Creates a {@code SimpleSGD} optimiser without momentum.
     *
     * @param learningRate The initial learning rate.
     * @return A new optimiser.
     */
    public static SGD getSimpleSGD(double learningRate) {
        return new SimpleSGD(learningRate);
    }

    /**
     * Creates a {@code SimpleSGD} optimiser with the requested momentum.
     *
     * @param learningRate The initial learning rate.
     * @param rho The momentum scaling factor.
     * @param momentumType The momentum variant.
     * @return A new optimiser.
     */
    public static SGD getSimpleSGD(double learningRate, double rho, Momentum momentumType) {
        return new SimpleSGD(learningRate, rho, momentumType);
    }

    /**
     * Creates a {@code LinearDecaySGD} optimiser without momentum.
     *
     * @param learningRate The initial learning rate.
     * @return A new optimiser.
     */
    public static SGD getLinearDecaySGD(double learningRate) {
        return new LinearDecaySGD(learningRate);
    }

    /**
     * Creates a {@code LinearDecaySGD} optimiser with the requested momentum.
     *
     * @param learningRate The initial learning rate.
     * @param rho The momentum scaling factor.
     * @param momentumType The momentum variant.
     * @return A new optimiser.
     */
    public static SGD getLinearDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new LinearDecaySGD(learningRate, rho, momentumType);
    }

    /**
     * Creates a {@code SqrtDecaySGD} optimiser without momentum.
     *
     * @param learningRate The initial learning rate.
     * @return A new optimiser.
     */
    public static SGD getSqrtDecaySGD(double learningRate) {
        return new SqrtDecaySGD(learningRate);
    }

    /**
     * Creates a {@code SqrtDecaySGD} optimiser with the requested momentum.
     *
     * @param learningRate The initial learning rate.
     * @param rho The momentum scaling factor.
     * @param momentumType The momentum variant.
     * @return A new optimiser.
     */
    public static SGD getSqrtDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new SqrtDecaySGD(learningRate, rho, momentumType);
    }
}
