package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import java.util.logging.Logger;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.SGD;
import org.tribuo.math.protos.ShrinkingDenseTensorProto;

/* loaded from: input_file:org/tribuo/math/optimisers/GradientOptimiserOptions.class */
/**
 * Command line options for selecting and configuring a
 * {@link StochasticGradientOptimiser}.
 * <p>
 * Each field is exposed as an option via its {@link Option} annotation, and
 * {@link #getOptimiser()} builds the optimiser described by the current field values.
 */
public class GradientOptimiserOptions implements Options {
    private static final Logger logger = Logger.getLogger(GradientOptimiserOptions.class.getName());

    /**
     * Which optimiser {@link #getOptimiser()} constructs.
     */
    @Option(longName = "sgo-type", usage = "Selects the gradient optimiser. Defaults to ADAGRAD.")
    private StochasticGradientOptimiserType optimiserType = StochasticGradientOptimiserType.ADAGRAD;

    /**
     * Learning rate, passed to AdaGrad, AdaGradRDA, Adam, Pegasos, RMSProp and the SGD factories.
     */
    @Option(longName = "sgo-learning-rate", usage = "Learning rate for AdaGrad, AdaGradRDA, Adam, Pegasos.")
    public double learningRate = 0.18d;

    /**
     * Epsilon, passed to AdaDelta, AdaGrad, AdaGradRDA and Adam.
     */
    @Option(longName = "sgo-epsilon", usage = "Epsilon for AdaDelta, AdaGrad, AdaGradRDA, Adam.")
    public double epsilon = 0.066d;

    /**
     * Rho, passed to AdaDelta, RMSProp and the SGD factories.
     */
    @Option(longName = "sgo-rho", usage = "Rho for RMSProp, AdaDelta, SGD with Momentum.")
    public double rho = 0.95d;

    /**
     * Lambda, passed to Pegasos.
     */
    @Option(longName = "sgo-lambda", usage = "Lambda for Pegasos.")
    public double lambda = 0.01d;

    /**
     * When true, the selected optimiser is wrapped in {@link ParameterAveraging}.
     */
    @Option(longName = "sgo-parameter-averaging", usage = "Use parameter averaging.")
    public boolean paramAve = false;

    /**
     * Momentum setting, passed to the SGD factories.
     */
    @Option(longName = "sgo-momentum", usage = "Use momentum in SGD.")
    public SGD.Momentum momentum = SGD.Momentum.NONE;

    /**
     * The gradient optimisers which can be selected with {@code --sgo-type}.
     */
    public enum StochasticGradientOptimiserType {
        ADADELTA,
        ADAGRAD,
        ADAGRADRDA,
        ADAM,
        PEGASOS,
        RMSPROP,
        CONSTANTSGD,
        LINEARSGD,
        SQRTSGD
    }

    /**
     * Builds the gradient optimiser described by the current option values.
     * <p>
     * If {@link #paramAve} is set, the optimiser is wrapped in a
     * {@link ParameterAveraging} instance and an info message is logged.
     *
     * @return the configured optimiser.
     * @throws IllegalArgumentException if the selected optimiser type has no
     *                                  matching case in this method.
     */
    public StochasticGradientOptimiser getOptimiser() {
        StochasticGradientOptimiser optimiser;
        // Dispatch on the enum constant itself rather than its ordinal, so the
        // mapping stays correct if the enum declaration order ever changes.
        switch (optimiserType) {
            case ADADELTA:
                optimiser = new AdaDelta(rho, epsilon);
                break;
            case ADAGRAD:
                optimiser = new AdaGrad(learningRate, epsilon);
                break;
            case ADAGRADRDA:
                optimiser = new AdaGradRDA(learningRate, epsilon);
                break;
            case ADAM:
                optimiser = new Adam(learningRate, epsilon);
                break;
            case PEGASOS:
                optimiser = new Pegasos(learningRate, lambda);
                break;
            case RMSPROP:
                optimiser = new RMSProp(learningRate, rho);
                break;
            case CONSTANTSGD:
                optimiser = SGD.getSimpleSGD(learningRate, rho, momentum);
                break;
            case LINEARSGD:
                optimiser = SGD.getLinearDecaySGD(learningRate, rho, momentum);
                break;
            case SQRTSGD:
                optimiser = SGD.getSqrtDecaySGD(learningRate, rho, momentum);
                break;
            default:
                throw new IllegalArgumentException("Unhandled StochasticGradientOptimiser type: " + optimiserType);
        }
        if (paramAve) {
            logger.info("Using parameter averaging");
            return new ParameterAveraging(optimiser);
        }
        return optimiser;
    }
}
