package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.util.ShrinkingMatrix;
import org.tribuo.math.optimisers.util.ShrinkingTensor;
import org.tribuo.math.optimisers.util.ShrinkingVector;

/* loaded from: input_file:org/tribuo/math/optimisers/Pegasos.class */
/**
 * A stochastic gradient optimiser using a Pegasos-style step-size schedule,
 * {@code eta_t = baseRate / (lambda * t)}, where {@code t} starts at 1 and is incremented once per
 * call to {@link #step(Tensor[], double)}.
 * <p>
 * {@link #initialise(Parameters)} replaces every {@link DenseVector} / {@link DenseMatrix} in the
 * supplied {@link Parameters} with a {@link ShrinkingVector} / {@link ShrinkingMatrix} built from
 * {@code baseRate} and {@code lambda}. {@link #finalise()} converts them back via
 * {@link ShrinkingTensor#convertToDense()}.
 * <p>
 * Instances hold mutable per-training state (the iteration counter and the parameters being
 * optimised). Use {@link #copy()} to obtain a fresh instance with the same hyperparameters.
 */
public class Pegasos implements StochasticGradientOptimiser {

    @Config(description = "Step size shrinkage.")
    private double lambda = 0.01d;

    @Config(description = "Base learning rate.")
    private double baseRate = 0.1d;

    /** 1-based step counter used in the step-size denominator; reset to 1 by {@link #reset()}. */
    private int iteration = 1;

    /** The parameters passed to {@link #initialise(Parameters)}, or {@code null} before that / after reset. */
    private Parameters parameters;

    /**
     * No-arg constructor; field initialisers supply the defaults ({@code lambda = 0.01},
     * {@code baseRate = 0.1}). Kept private — presumably for the OLCUT configuration system,
     * which populates the {@code @Config} fields reflectively.
     */
    private Pegasos() {
    }

    /**
     * Constructs a Pegasos optimiser.
     *
     * @param baseRate The base learning rate (numerator of the step size).
     * @param lambda   The step size shrinkage (multiplies the iteration count in the denominator).
     */
    public Pegasos(double baseRate, double lambda) {
        this.baseRate = baseRate;
        this.lambda = lambda;
    }

    /**
     * Stores the parameters and replaces each tensor with its shrinking counterpart.
     *
     * @param parameters The parameters to optimise; its tensors are replaced in place via
     *                   {@link Parameters#set(Tensor[])}.
     * @throws IllegalStateException If a tensor is neither a {@link DenseVector} nor a
     *                               {@link DenseMatrix}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void initialise(Parameters parameters) {
        this.parameters = parameters;
        Tensor[] current = parameters.get();
        Tensor[] shrinking = new Tensor[current.length];
        for (int i = 0; i < shrinking.length; i++) {
            if (current[i] instanceof DenseVector) {
                shrinking[i] = new ShrinkingVector((DenseVector) current[i], baseRate, lambda);
            } else if (current[i] instanceof DenseMatrix) {
                shrinking[i] = new ShrinkingMatrix((DenseMatrix) current[i], baseRate, lambda);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
        }
        parameters.set(shrinking);
    }

    /**
     * Scales each update in place by {@code weight * baseRate / (lambda * iteration)}, then
     * advances the iteration counter.
     *
     * @param updates The gradient updates; mutated in place.
     * @param weight  The example weight, multiplied into the step size.
     * @return The same {@code updates} array, now scaled.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public Tensor[] step(Tensor[] updates, double weight) {
        double learningRate = baseRate / (lambda * iteration);
        for (Tensor update : updates) {
            update.scaleInPlace(learningRate * weight);
        }
        iteration++;
        return updates;
    }

    @Override
    public String toString() {
        return "Pegasos(baseRate=" + baseRate + ",lambda=" + lambda + ")";
    }

    /**
     * Converts every shrinking tensor back to a dense tensor and writes the result into the
     * parameters supplied to {@link #initialise(Parameters)}.
     *
     * @throws IllegalStateException If this optimiser was never initialised, or a tensor is not a
     *                               {@link ShrinkingTensor}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void finalise() {
        // Without this check, finalising before initialise() would fail with a bare NPE.
        if (parameters == null) {
            throw new IllegalStateException("Finalising a Parameters which wasn't initialised with Pegasos");
        }
        Tensor[] current = parameters.get();
        Tensor[] dense = new Tensor[current.length];
        for (int i = 0; i < dense.length; i++) {
            if (!(current[i] instanceof ShrinkingTensor)) {
                throw new IllegalStateException("Finalising a Parameters which wasn't initialised with Pegasos");
            }
            dense[i] = ((ShrinkingTensor) current[i]).convertToDense();
        }
        parameters.set(dense);
    }

    /** Drops the stored parameters and restarts the iteration counter at 1. */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void reset() {
        parameters = null;
        iteration = 1;
    }

    /**
     * Returns a new optimiser with the same hyperparameters and fresh training state.
     *
     * @return A copy of this optimiser's configuration.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public Pegasos copy() {
        // The constructor order is (baseRate, lambda); the previous (lambda, baseRate) call
        // silently swapped the two hyperparameters in every copy.
        return new Pegasos(baseRate, lambda);
    }

    /**
     * Returns the configured-object provenance for this optimiser.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, retained for backward compatibility.
     *
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m35getProvenance() {
        return getProvenance();
    }
}
