package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/* loaded from: input_file:org/tribuo/math/optimisers/ParameterAveraging.class */
/**
 * A {@link StochasticGradientOptimiser} which wraps an inner optimiser and averages the
 * parameters across the steps it takes.
 * <p>
 * Each call to {@link #step(Tensor[], double)} delegates to the inner optimiser. It then adds
 * the returned update, scaled by the 1-based step count {@code t}, into an accumulator shaped
 * like the parameters. {@link #finalise()} subtracts {@code accumulator / T} from the live
 * parameters, where {@code T} is the number of steps taken.
 * <p>
 * NOTE(review): the following assumes the caller adds each returned update to the parameters,
 * which is not visible in this class — confirm against the trainer. Under that assumption, the
 * finalised parameters equal the mean of the iterates {@code theta_0 ... theta_{T-1}}.
 * <p>
 * The class holds mutable per-run state and performs no synchronisation.
 */
public class ParameterAveraging implements StochasticGradientOptimiser {

    @Config(mandatory = true, description = "Inner optimiser to average parameters across.")
    private StochasticGradientOptimiser optimiser;

    /** Number of steps taken since the last {@link #initialise} or {@link #reset}. */
    private int iterations = 0;

    /** Running sum of {@code t * update_t}; null until {@link #initialise} is called. */
    private Tensor[] weights;

    /** The parameters being optimised; null until {@link #initialise} is called. */
    private Parameters parameters;

    /**
     * Creates a parameter averaging optimiser around the supplied inner optimiser.
     *
     * @param stochasticGradientOptimiser The inner optimiser which produces the per-step updates.
     */
    public ParameterAveraging(StochasticGradientOptimiser stochasticGradientOptimiser) {
        this.optimiser = stochasticGradientOptimiser;
    }

    /**
     * No-arg constructor. Presumably used by the OLCUT configuration system, which populates
     * {@code optimiser} through the {@link Config} annotation.
     */
    private ParameterAveraging() {
    }

    /**
     * Initialises the inner optimiser and allocates a zeroed accumulator shaped like the
     * parameters. The step counter is also reset, so that a re-initialised optimiser never
     * divides fresh sums by a stale count.
     *
     * @param parameters The parameters to optimise.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void initialise(Parameters parameters) {
        this.optimiser.initialise(parameters);
        this.weights = parameters.getEmptyCopy();
        this.parameters = parameters;
        this.iterations = 0;
    }

    /**
     * Delegates to the inner optimiser, then accumulates the returned update weighted by the
     * current step number.
     *
     * @param updates The gradient updates passed to the inner optimiser.
     * @param weight  The example weight passed to the inner optimiser.
     * @return The updates produced by the inner optimiser, unchanged.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public Tensor[] step(Tensor[] updates, double weight) {
        // Increment before delegating, matching the original ordering.
        final int stepNumber = ++this.iterations;
        Tensor[] output = this.optimiser.step(updates, weight);
        for (int i = 0; i < output.length; i++) {
            this.weights[i].intersectAndAddInPlace(output[i], v -> v * stepNumber);
        }
        return output;
    }

    /**
     * Writes the averaged parameters back into the live parameters by subtracting
     * {@code accumulator / iterations}. Does nothing if no steps were taken.
     *
     * @throws IllegalStateException if called before {@link #initialise} or after {@link #reset}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void finalise() {
        if (this.parameters == null || this.weights == null) {
            throw new IllegalStateException(
                    "ParameterAveraging.finalise() called before initialise() or after reset().");
        }
        if (this.iterations == 0) {
            // With no steps the accumulator is all zeros, and (-0.0 / 0) is NaN. That NaN
            // could otherwise be written into the parameters, so leave them untouched.
            return;
        }
        final int stepCount = this.iterations;
        Tensor[] current = this.parameters.get();
        for (int i = 0; i < current.length; i++) {
            current[i].intersectAndAddInPlace(this.weights[i], v -> -v / stepCount);
        }
    }

    @Override
    public String toString() {
        return "ParameterAveraging(optimiser=" + this.optimiser.toString() + ")";
    }

    /**
     * Resets the inner optimiser and discards all averaging state, including the reference to
     * the parameters. {@link #initialise} must be called again before further use.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public void reset() {
        this.optimiser.reset();
        this.iterations = 0;
        this.weights = null;
        this.parameters = null;
    }

    /**
     * Returns a new, uninitialised averaging optimiser wrapping a copy of the inner optimiser.
     * No averaging state is copied.
     *
     * @return A fresh {@code ParameterAveraging}.
     */
    @Override // org.tribuo.math.StochasticGradientOptimiser
    public ParameterAveraging copy() {
        return new ParameterAveraging(this.optimiser.copy());
    }

    /**
     * Returns the configured-object provenance of this optimiser.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Decompiler-renamed alias of {@link #getProvenance()}, retained for backward compatibility.
     *
     * @return The provenance.
     */
    public ConfiguredObjectProvenance m34getProvenance() {
        return getProvenance();
    }
}
