package org.tribuo.math.optimisers;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.AdaGradRDADenseTensorProto;
import org.tribuo.math.protos.DenseTensorProto;
import org.tribuo.math.protos.TensorProto;

/* loaded from: input_file:org/tribuo/math/optimisers/AdaGradRDA.class */
public class AdaGradRDA implements StochasticGradientOptimiser {
    /** Logger named after this class. */
    private static final Logger logger = Logger.getLogger(AdaGradRDA.class.getName());

    /** Learning rate handed to every wrapped tensor by {@code initialise}; mandatory in configuration. */
    @Config(mandatory = true, description = "Initial learning rate used to scale the gradients.")
    private double initialLearningRate;

    /** Added to the root of the accumulated squared gradient by the wrapped tensors; the constructors default it to 1.0E-6. */
    @Config(description = "Epsilon for numerical stability around zero.")
    private double epsilon;

    /** l1 penalty; {@code initialise} divides it by {@code numExamples} before passing it to the tensors. */
    @Config(description = "l1 regularization penalty.")
    private double l1;

    /** l2 penalty; {@code initialise} divides it by {@code numExamples} before passing it to the tensors. */
    @Config(description = "l2 regularization penalty.")
    private double l2;

    /** Divisor applied to {@code l1} and {@code l2} in {@code initialise}; the constructors default it to 1. */
    @Config(description = "Number of examples to scale the l1 and l2 penalties by.")
    private int numExamples;
    /** Parameters captured by {@code initialise}, read by {@code finalise} and cleared by {@code reset}. */
    private Parameters parameters;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/math/optimisers/AdaGradRDA$AdaGradRDAMatrix.class */
    public static class AdaGradRDAMatrix extends DenseMatrix implements AdaGradRDATensor {
        /** Divides the adaptive term {@code sqrt(gradSquares) + epsilon} inside {@code get}. */
        private final double learningRate;
        /** Added to the square root of the accumulated squared gradient inside {@code get}. */
        private final double epsilon;
        /** Multiplied by {@code iteration} to form the truncation threshold inside {@code get}. */
        private final double l1;
        /** Multiplied by {@code iteration} and added to the denominator inside {@code get}. */
        private final double l2;
        /** Per-element running sum of squared updates, written by {@code intersectAndAddInPlace}. */
        private final double[][] gradSquares;
        /** Step counter that scales {@code l1} and {@code l2} inside {@code get}; also serialized. */
        private int iteration;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/tribuo/math/optimisers/AdaGradRDA$AdaGradRDAMatrix$RDAMatrixIterator.class */
        /**
         * Row-major iterator over an {@link AdaGradRDAMatrix} which reads each element through
         * {@link AdaGradRDAMatrix#get(int, int)}.
         * <p>
         * A single {@link MatrixTuple} instance is reused and overwritten on every call to {@link #next()}.
         */
        public static class RDAMatrixIterator implements MatrixIterator {
            private final AdaGradRDAMatrix matrix;
            private final int dim2;
            private final MatrixTuple tuple = new MatrixTuple();
            private int row = 0;
            private int col = 0;

            /**
             * Creates an iterator positioned at element (0, 0).
             *
             * @param adaGradRDAMatrix the matrix to walk
             */
            public RDAMatrixIterator(AdaGradRDAMatrix adaGradRDAMatrix) {
                this.matrix = adaGradRDAMatrix;
                this.dim2 = adaGradRDAMatrix.dim2;
            }

            /**
             * @return the shared tuple that {@link #next()} mutates
             */
            @Override
            public MatrixTuple getReference() {
                return this.tuple;
            }

            /**
             * @return true while the cursor is inside both dimensions of the matrix
             */
            @Override
            public boolean hasNext() {
                final boolean rowInRange = this.row < this.matrix.dim1;
                return rowInRange && this.col < this.matrix.dim2;
            }

            /**
             * Loads the element under the cursor into the shared tuple, then advances the cursor
             * along the row, wrapping to the start of the following row after the last column.
             *
             * @return the shared tuple holding the element's indices and value
             */
            @Override
            public MatrixTuple next() {
                final MatrixTuple out = this.tuple;
                out.i = this.row;
                out.j = this.col;
                out.value = this.matrix.get(this.row, this.col);
                // Only move the cursor once the read above has succeeded.
                final boolean lastColumn = this.col >= this.dim2 - 1;
                if (lastColumn) {
                    this.row++;
                    this.col = 0;
                } else {
                    this.col++;
                }
                return out;
            }
        }

        AdaGradRDAMatrix(DenseMatrix denseMatrix, double d, double d2, double d3, double d4) {
            super(denseMatrix);
            this.learningRate = d;
            this.epsilon = d2;
            this.l1 = d3;
            this.l2 = d4;
            this.gradSquares = new double[denseMatrix.getDimension1Size()][denseMatrix.getDimension2Size()];
            this.iteration = 0;
        }

        private AdaGradRDAMatrix(DenseMatrix denseMatrix, double d, double d2, double d3, double d4, double[][] dArr, int i) {
            super(denseMatrix);
            this.learningRate = d;
            this.epsilon = d2;
            this.l1 = d3;
            this.l2 = d4;
            this.gradSquares = dArr;
            if (dArr.length != this.dim1 || dArr[0].length != this.dim2) {
                throw new IllegalArgumentException("Invalid AdaGradRDAMatrix, value matrix is a different shape to gradient matrix, value [" + this.dim1 + ", " + this.dim2 + "], gradient [" + dArr.length + ", " + dArr[0].length + "]");
            }
            for (int i2 = 0; i2 < dArr.length; i2++) {
                if (dArr[i2].length != this.dim2) {
                    throw new IllegalArgumentException("Invalid AdaGradRDAMatrix, gradient matrix is ragged, expected " + this.dim2 + ", found " + dArr[i2].length + " at index " + i2);
                }
                for (int i3 = 0; i3 < dArr[i2].length; i3++) {
                    if (dArr[i2][i3] < 0.0d) {
                        throw new IllegalArgumentException("Invalid AdaGradRDAMatrix, squared gradient is negative at index [" + i2 + ", " + i3 + "] = " + dArr[i2][i3]);
                    }
                }
            }
            this.iteration = i;
            if (i < 0) {
                throw new IllegalArgumentException("Invalid AdaGradRDAMatrix, iteration must be non-negative, found " + i);
            }
        }

        /**
         * Rebuilds an {@code AdaGradRDAMatrix} from its protobuf form.
         * <p>
         * The squared gradients are decoded as little-endian doubles, matching the byte order
         * the serializer of this class writes them in.
         *
         * @param i   the serialized version, only 0 is accepted
         * @param str the serialized class name (not used by this method)
         * @param any the packed {@code AdaGradRDADenseTensorProto}
         * @return the deserialized matrix including its optimiser state
         * @throws InvalidProtocolBufferException if {@code any} cannot be unpacked
         * @throws IllegalArgumentException if the version is unknown or the number of stored squared
         *                                  gradients does not match the matrix size
         */
        public static AdaGradRDAMatrix deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
            if (i != 0) {
                throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
            }
            AdaGradRDADenseTensorProto unpack = any.unpack(AdaGradRDADenseTensorProto.class);
            DenseMatrix unpackProto = DenseMatrix.unpackProto(unpack.getData());
            // A fresh ByteBuffer is big-endian by default, but the gradients were written little-endian,
            // so the order must be set before creating the double view.
            DoubleBuffer asDoubleBuffer = unpack.getGradNorms().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer();
            final int rows = unpackProto.getDimension1Size();
            final int columns = unpackProto.getDimension2Size();
            if (asDoubleBuffer.remaining() != rows * columns) {
                throw new IllegalArgumentException("Invalid proto, claimed " + (rows * columns) + ", but only had " + asDoubleBuffer.remaining() + " values");
            }
            double[][] dArr = new double[rows][columns];
            for (double[] dArr2 : dArr) {
                // Each bulk get consumes one row's worth of doubles from the buffer.
                asDoubleBuffer.get(dArr2);
            }
            return new AdaGradRDAMatrix(unpackProto, unpack.getLearningRate(), unpack.getEpsilon(), unpack.getL1(), unpack.getL2(), dArr, unpack.getIteration());
        }

        /**
         * Serializes the raw stored values, the hyperparameters, the squared gradients and the
         * iteration count into a {@code TensorProto} tagged with version 0 and this class's name.
         * Both double arrays are written row by row as little-endian bytes.
         * <p>
         * Decompiler note: renamed from {@code serialize}.
         *
         * @return the protobuf form of this matrix
         */
        @Override
        public TensorProto mo18serialize() {
            final int byteCount = this.dim1 * this.dim2 * 8;

            // Raw stored values, not the values computed by get(i, j).
            ByteBuffer valueBytes = ByteBuffer.allocate(byteCount).order(ByteOrder.LITTLE_ENDIAN);
            DoubleBuffer valueView = valueBytes.asDoubleBuffer();
            for (double[] row : this.values) {
                valueView.put(row);
            }
            valueView.rewind();
            DenseTensorProto.Builder dataBuilder = DenseTensorProto.newBuilder();
            dataBuilder.addDimensions(this.dim1);
            dataBuilder.addDimensions(this.dim2);
            dataBuilder.setValues(ByteString.copyFrom(valueBytes));

            AdaGradRDADenseTensorProto.Builder stateBuilder = AdaGradRDADenseTensorProto.newBuilder();
            stateBuilder.setData(dataBuilder.m175build());
            stateBuilder.setLearningRate(this.learningRate);
            stateBuilder.setEpsilon(this.epsilon);
            stateBuilder.setL1(this.l1);
            stateBuilder.setL2(this.l2);

            ByteBuffer gradBytes = ByteBuffer.allocate(byteCount).order(ByteOrder.LITTLE_ENDIAN);
            DoubleBuffer gradView = gradBytes.asDoubleBuffer();
            for (double[] row : this.gradSquares) {
                gradView.put(row);
            }
            gradView.rewind();
            stateBuilder.setGradNorms(ByteString.copyFrom(gradBytes));
            stateBuilder.setIteration(this.iteration);

            TensorProto.Builder tensorBuilder = TensorProto.newBuilder();
            tensorBuilder.setVersion(0);
            tensorBuilder.setClassName(AdaGradRDAMatrix.class.getName());
            tensorBuilder.setSerializedData(Any.pack(stateBuilder.m81build()));
            return tensorBuilder.m833build();
        }

        /**
         * Builds a plain {@link DenseMatrix} from this matrix via the {@code DenseMatrix} copy constructor.
         *
         * @return a new dense matrix
         */
        @Override
        public DenseMatrix convertToDense() {
            final DenseMatrix self = this;
            return new DenseMatrix(self);
        }

        /**
         * Multiplies this matrix by the supplied vector, reading every matrix element through
         * {@link #get(int, int)}.
         *
         * @param sGDVector the vector to multiply by, its size must equal the second dimension
         * @return a dense vector with one entry per row of this matrix
         * @throws IllegalArgumentException if the vector size differs from the second dimension
         */
        @Override
        public DenseVector leftMultiply(SGDVector sGDVector) {
            if (sGDVector.size() == this.dim2) {
                final double[] output = new double[this.dim1];
                for (VectorTuple entry : sGDVector) {
                    final int column = entry.index;
                    final double weight = entry.value;
                    for (int row = 0; row < output.length; row++) {
                        output[row] += get(row, column) * weight;
                    }
                }
                return DenseVector.createDenseVector(output);
            }
            throw new IllegalArgumentException("input.size() != dim2");
        }

        /**
         * Applies one update: each element of {@code tensor} is transformed by the operator, added to
         * the stored value, and its square is added to the squared-gradient accumulator.
         * <p>
         * The iteration counter is advanced once per successful call. {@link #get(int, int)} multiplies
         * the l1 and l2 penalties by that counter, so without the increment the penalties would always
         * be scaled by zero and never applied to matrix parameters.
         *
         * @param tensor the update, must be a {@link Matrix} of the same shape
         * @param doubleUnaryOperator the function applied to every update element before it is added
         * @throws IllegalStateException if {@code tensor} is not a matrix or has a different shape
         */
        @Override
        public void intersectAndAddInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
            if (!(tensor instanceof Matrix)) {
                throw new IllegalStateException("Adding a non-Matrix to a Matrix");
            }
            Matrix matrix = (Matrix) tensor;
            if (this.dim1 != matrix.getDimension1Size() || this.dim2 != matrix.getDimension2Size()) {
                throw new IllegalStateException("Matrices are not the same size, this(" + this.dim1 + "," + this.dim2 + "), other(" + matrix.getDimension1Size() + "," + matrix.getDimension2Size() + ")");
            }
            // Count the step only after the input has been validated.
            this.iteration++;
            for (MatrixTuple matrixTuple : matrix) {
                double applyAsDouble = doubleUnaryOperator.applyAsDouble(matrixTuple.value);
                this.values[matrixTuple.i][matrixTuple.j] += applyAsDouble;
                this.gradSquares[matrixTuple.i][matrixTuple.j] += applyAsDouble * applyAsDouble;
            }
        }

        /**
         * Returns the effective value of an element.
         * <p>
         * If no squared gradient has been accumulated for the element, the raw stored value is returned.
         * Otherwise the stored value is truncated by {@code iteration * l1} and scaled by
         * {@code 1 / ((sqrt(gradSquares) + epsilon) / learningRate + iteration * l2)}.
         *
         * @param i  the row index
         * @param i2 the column index
         * @return the effective value at {@code [i, i2]}
         */
        @Override
        public double get(int i, int i2) {
            final double accumulated = this.gradSquares[i][i2];
            if (accumulated == 0.0d) {
                return this.values[i][i2];
            }
            final double adaptiveTerm = (Math.sqrt(accumulated) + this.epsilon) / this.learningRate;
            final double denominator = adaptiveTerm + (this.iteration * this.l2);
            final double shrunk = AdaGradRDATensor.truncate(this.values[i][i2], this.iteration * this.l1);
            return (1.0d / denominator) * shrunk;
        }

        /**
         * Creates an iterator which reads elements through {@link #get(int, int)}.
         * <p>
         * Decompiler note: renamed from {@code iterator}.
         *
         * @return a new {@link RDAMatrixIterator} over this matrix
         */
        @Override
        public Iterator<MatrixTuple> iterator2() {
            final RDAMatrixIterator walker = new RDAMatrixIterator(this);
            return walker;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/math/optimisers/AdaGradRDA$AdaGradRDATensor.class */
    public interface AdaGradRDATensor {
        /**
         * Converts this optimiser-specific tensor into a plain tensor.
         * {@code AdaGradRDA.finalise()} calls this on every parameter tensor.
         *
         * @return the converted tensor
         */
        Tensor convertToDense();

        /**
         * Shrinks {@code d} towards zero by {@code d2}.
         * <p>
         * Values above {@code d2} are reduced by {@code d2}, values below {@code -d2} are increased
         * by {@code d2}, and everything else (including NaN) maps to zero.
         *
         * @param d  the value to shrink
         * @param d2 the shrinkage threshold
         * @return the shrunk value
         */
        static double truncate(double d, double d2) {
            final boolean aboveBand = d > d2;
            final boolean belowBand = d < (-d2);
            return aboveBand ? d - d2 : (belowBand ? d + d2 : 0.0d);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/math/optimisers/AdaGradRDA$AdaGradRDAVector.class */
    public static class AdaGradRDAVector extends DenseVector implements AdaGradRDATensor {
        /** Divides the adaptive term {@code sqrt(gradSquares) + epsilon} inside {@code get}. */
        private final double learningRate;
        /** Added to the square root of the accumulated squared gradient inside {@code get}. */
        private final double epsilon;
        /** Multiplied by {@code iteration} to form the truncation threshold inside {@code get}. */
        private final double l1;
        /** Multiplied by {@code iteration} and added to the denominator inside {@code get}. */
        private final double l2;
        /** Per-element running sum of squared updates, written by {@code intersectAndAddInPlace}. */
        private final double[] gradSquares;
        /** Step counter incremented by {@code intersectAndAddInPlace}; scales {@code l1} and {@code l2} inside {@code get}. */
        private int iteration;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/tribuo/math/optimisers/AdaGradRDA$AdaGradRDAVector$RDAVectorIterator.class */
        /**
         * Iterator over an {@link AdaGradRDAVector} which reads each element through
         * {@link AdaGradRDAVector#get(int)}.
         * <p>
         * A single {@link VectorTuple} instance is reused and overwritten on every call to {@link #next()}.
         */
        public static class RDAVectorIterator implements VectorIterator {
            private final AdaGradRDAVector vector;
            private final VectorTuple tuple = new VectorTuple();
            private int cursor = 0;

            /**
             * Creates an iterator positioned at index 0.
             *
             * @param adaGradRDAVector the vector to walk
             */
            public RDAVectorIterator(AdaGradRDAVector adaGradRDAVector) {
                this.vector = adaGradRDAVector;
            }

            /**
             * @return true while the cursor is below the vector's size
             */
            @Override
            public boolean hasNext() {
                final int limit = this.vector.size();
                return this.cursor < limit;
            }

            /**
             * Loads the element under the cursor into the shared tuple and then advances the cursor.
             *
             * @return the shared tuple holding the element's index and value
             */
            @Override
            public VectorTuple next() {
                final VectorTuple out = this.tuple;
                out.index = this.cursor;
                out.value = this.vector.get(this.cursor);
                // Only move the cursor once the read above has succeeded.
                this.cursor += 1;
                return out;
            }

            /**
             * @return the shared tuple that {@link #next()} mutates
             */
            @Override
            public VectorTuple getReference() {
                return this.tuple;
            }
        }

        AdaGradRDAVector(DenseVector denseVector, double d, double d2, double d3, double d4) {
            super(denseVector);
            this.learningRate = d;
            this.epsilon = d2;
            this.l1 = d3;
            this.l2 = d4;
            this.gradSquares = new double[denseVector.size()];
            this.iteration = 0;
        }

        private AdaGradRDAVector(double[] dArr, double d, double d2, double d3, double d4, double[] dArr2, int i) {
            super(dArr);
            this.learningRate = d;
            this.epsilon = d2;
            this.l1 = d3;
            this.l2 = d4;
            this.gradSquares = dArr2;
            if (dArr2.length != dArr.length) {
                throw new IllegalArgumentException("Invalid AdaGradRDAVector, value vector is a different shape to gradient vector, value [" + dArr.length + "], gradient [" + dArr2.length + "]");
            }
            for (int i2 = 0; i2 < dArr2.length; i2++) {
                if (dArr2[i2] < 0.0d) {
                    throw new IllegalArgumentException("Invalid AdaGradRDAVector, squared gradient is negative at index [" + i2 + "] = " + dArr2[i2]);
                }
            }
            this.iteration = i;
            if (i < 0) {
                throw new IllegalArgumentException("Invalid AdaGradRDAVector, iteration must be non-negative, found " + i);
            }
        }

        /**
         * Rebuilds an {@code AdaGradRDAVector} from its protobuf form.
         * <p>
         * The squared gradients are decoded as little-endian doubles, matching the byte order
         * the serializer of this class writes them in.
         *
         * @param i   the serialized version, only 0 is accepted
         * @param str the serialized class name (not used by this method)
         * @param any the packed {@code AdaGradRDADenseTensorProto}
         * @return the deserialized vector including its optimiser state
         * @throws InvalidProtocolBufferException if {@code any} cannot be unpacked
         * @throws IllegalArgumentException if the version is unknown or the number of stored squared
         *                                  gradients does not match the vector size
         */
        public static AdaGradRDAVector deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
            if (i != 0) {
                throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
            }
            AdaGradRDADenseTensorProto unpack = any.unpack(AdaGradRDADenseTensorProto.class);
            DenseVector unpackProto = DenseVector.unpackProto(unpack.getData());
            // A fresh ByteBuffer is big-endian by default, but the gradients were written little-endian,
            // so the order must be set before creating the double view.
            DoubleBuffer asDoubleBuffer = unpack.getGradNorms().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer();
            final int length = unpackProto.size();
            if (asDoubleBuffer.remaining() != length) {
                throw new IllegalArgumentException("Invalid proto, claimed " + length + ", but only had " + asDoubleBuffer.remaining() + " values");
            }
            double[] dArr = new double[length];
            asDoubleBuffer.get(dArr);
            return new AdaGradRDAVector(unpackProto.toArray(), unpack.getLearningRate(), unpack.getEpsilon(), unpack.getL1(), unpack.getL2(), dArr, unpack.getIteration());
        }

        /**
         * Serializes the raw stored elements, the hyperparameters, the squared gradients and the
         * iteration count into a {@code TensorProto} tagged with version 0 and this class's name.
         * Both double arrays are written as little-endian bytes.
         * <p>
         * Decompiler note: renamed from {@code serialize}.
         *
         * @return the protobuf form of this vector
         */
        @Override
        public TensorProto mo20serialize() {
            // Raw stored elements, not the values computed by get(i).
            ByteBuffer valueBytes = ByteBuffer.allocate(this.elements.length * 8).order(ByteOrder.LITTLE_ENDIAN);
            DoubleBuffer valueView = valueBytes.asDoubleBuffer();
            valueView.put(this.elements);
            valueView.rewind();
            DenseTensorProto.Builder dataBuilder = DenseTensorProto.newBuilder();
            dataBuilder.addDimensions(size());
            dataBuilder.setValues(ByteString.copyFrom(valueBytes));

            AdaGradRDADenseTensorProto.Builder stateBuilder = AdaGradRDADenseTensorProto.newBuilder();
            stateBuilder.setData(dataBuilder.m175build());
            stateBuilder.setLearningRate(this.learningRate);
            stateBuilder.setEpsilon(this.epsilon);
            stateBuilder.setL1(this.l1);
            stateBuilder.setL2(this.l2);

            ByteBuffer gradBytes = ByteBuffer.allocate(this.gradSquares.length * 8).order(ByteOrder.LITTLE_ENDIAN);
            DoubleBuffer gradView = gradBytes.asDoubleBuffer();
            gradView.put(this.gradSquares);
            gradView.rewind();
            stateBuilder.setGradNorms(ByteString.copyFrom(gradBytes));
            stateBuilder.setIteration(this.iteration);

            TensorProto.Builder tensorBuilder = TensorProto.newBuilder();
            tensorBuilder.setVersion(0);
            tensorBuilder.setClassName(AdaGradRDAVector.class.getName());
            tensorBuilder.setSerializedData(Any.pack(stateBuilder.m81build()));
            return tensorBuilder.m833build();
        }

        /**
         * Builds a plain {@link DenseVector} holding the effective values produced by {@link #toArray()}.
         *
         * @return a new dense vector
         */
        @Override
        public DenseVector convertToDense() {
            final double[] effectiveValues = toArray();
            return DenseVector.createDenseVector(effectiveValues);
        }

        /**
         * Copies this vector, duplicating both the raw elements and the squared-gradient accumulator
         * and carrying over the hyperparameters and iteration count.
         *
         * @return an independent copy with the same optimiser state
         */
        @Override
        public AdaGradRDAVector copy() {
            final double[] elementCopy = this.elements.clone();
            final double[] gradCopy = this.gradSquares.clone();
            return new AdaGradRDAVector(elementCopy, this.learningRate, this.epsilon, this.l1, this.l2, gradCopy, this.iteration);
        }

        /**
         * Materialises the effective values, reading every element through {@link #get(int)}.
         *
         * @return a new array of the same length as the stored elements
         */
        @Override
        public double[] toArray() {
            final int length = this.elements.length;
            final double[] effective = new double[length];
            int position = 0;
            while (position < length) {
                effective[position] = get(position);
                position++;
            }
            return effective;
        }

        /**
         * Returns the effective value of an element.
         * <p>
         * If no squared gradient has been accumulated for the element, the raw stored value is returned.
         * Otherwise the stored value is truncated by {@code iteration * l1} and scaled by
         * {@code 1 / ((sqrt(gradSquares) + epsilon) / learningRate + iteration * l2)}.
         *
         * @param i the element index
         * @return the effective value at {@code i}
         */
        @Override
        public double get(int i) {
            final double accumulated = this.gradSquares[i];
            if (accumulated == 0.0d) {
                return this.elements[i];
            }
            final double adaptiveTerm = (Math.sqrt(accumulated) + this.epsilon) / this.learningRate;
            final double denominator = adaptiveTerm + (this.iteration * this.l2);
            final double shrunk = AdaGradRDATensor.truncate(this.elements[i], this.iteration * this.l1);
            return (1.0d / denominator) * shrunk;
        }

        /**
         * Sums the effective values, reading every element through {@link #get(int)}.
         *
         * @return the sum, or 0 for an empty vector
         */
        @Override
        public double sum() {
            final int length = this.elements.length;
            double total = 0.0d;
            int position = 0;
            while (position < length) {
                total += get(position);
                position++;
            }
            return total;
        }

        /**
         * Applies one update: the iteration counter is advanced, then each element of {@code tensor}
         * is transformed by the operator, added to the stored element, and its square is added to the
         * squared-gradient accumulator.
         *
         * @param tensor the update, cast to {@link SGDVector} without a prior type check
         * @param doubleUnaryOperator the function applied to every update element before it is added
         */
        @Override
        public void intersectAndAddInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
            this.iteration += 1;
            final SGDVector update = (SGDVector) tensor;
            for (VectorTuple entry : update) {
                final double delta = doubleUnaryOperator.applyAsDouble(entry.value);
                this.elements[entry.index] += delta;
                this.gradSquares[entry.index] += delta * delta;
            }
        }

        /**
         * Finds the index of the largest effective value, reading every element through {@link #get(int)}.
         * Ties resolve to the lowest index because only strictly larger values replace the current best.
         *
         * @return the index of the maximum, or 0 if no element exceeds negative infinity
         */
        @Override
        public int indexOfMax() {
            final int length = this.elements.length;
            double best = Double.NEGATIVE_INFINITY;
            int bestIndex = 0;
            for (int position = 0; position < length; position++) {
                final double candidate = get(position);
                if (candidate > best) {
                    best = candidate;
                    bestIndex = position;
                }
            }
            return bestIndex;
        }

        /**
         * Finds the largest effective value, reading every element through {@link #get(int)}.
         *
         * @return the maximum, or negative infinity for an empty vector
         */
        @Override
        public double maxValue() {
            final int length = this.elements.length;
            double best = Double.NEGATIVE_INFINITY;
            for (int position = 0; position < length; position++) {
                final double candidate = get(position);
                // Strict comparison: NaN candidates never replace the running best.
                if (candidate > best) {
                    best = candidate;
                }
            }
            return best;
        }

        /**
         * Finds the smallest effective value, reading every element through {@link #get(int)}.
         *
         * @return the minimum, or positive infinity for an empty vector
         */
        @Override
        public double minValue() {
            final int length = this.elements.length;
            double best = Double.POSITIVE_INFINITY;
            for (int position = 0; position < length; position++) {
                final double candidate = get(position);
                // Strict comparison: NaN candidates never replace the running best.
                if (candidate < best) {
                    best = candidate;
                }
            }
            return best;
        }

        /**
         * Computes the dot product with another vector by iterating the other vector's tuples and
         * reading this vector's matching elements through {@link #get(int)}.
         *
         * @param sGDVector the other vector
         * @return the dot product
         */
        @Override
        public double dot(SGDVector sGDVector) {
            double accumulator = 0.0d;
            for (VectorTuple entry : sGDVector) {
                final double mine = get(entry.index);
                accumulator += mine * entry.value;
            }
            return accumulator;
        }

        /**
         * Creates an iterator which reads elements through {@link #get(int)}.
         * <p>
         * Decompiler note: renamed from {@code iterator}.
         *
         * @return a new {@link RDAVectorIterator} over this vector
         */
        @Override
        public Iterator<VectorTuple> iterator2() {
            final RDAVectorIterator walker = new RDAVectorIterator(this);
            return walker;
        }
    }

    /**
     * Creates an optimiser with every hyperparameter supplied explicitly.
     *
     * @param d  the initial learning rate
     * @param d2 the epsilon
     * @param d3 the l1 penalty
     * @param d4 the l2 penalty
     * @param i  the number of examples the l1 and l2 penalties are divided by
     */
    public AdaGradRDA(double d, double d2, double d3, double d4, int i) {
        // Every configurable field is assigned from an argument, so no defaults are needed here.
        this.parameters = null;
        this.numExamples = i;
        this.l2 = d4;
        this.l1 = d3;
        this.epsilon = d2;
        this.initialLearningRate = d;
    }

    /**
     * Creates an optimiser without regularisation: l1 and l2 are zero and the example count is one.
     *
     * @param d  the initial learning rate
     * @param d2 the epsilon
     */
    public AdaGradRDA(double d, double d2) {
        this(d, d2, 0.0, 0.0, 1);
    }

    /**
     * Creates an optimiser holding only the default values: epsilon 1.0E-6, no l1 or l2 penalty and
     * an example count of one. The learning rate is left unset.
     * NOTE(review): presumably used by the configuration system, which fills the {@code @Config} fields.
     */
    private AdaGradRDA() {
        this.parameters = null;
        this.numExamples = 1;
        this.l2 = 0.0d;
        this.l1 = 0.0d;
        this.epsilon = 1.0E-6d;
    }

    /**
     * Remembers the supplied parameters and replaces each of their tensors with an AdaGradRDA wrapper
     * carrying the learning rate, epsilon, and the l1 and l2 penalties divided by the example count.
     *
     * @param parameters the parameters to optimise; its tensors are replaced via {@code set}
     * @throws IllegalStateException if a tensor is neither a {@link DenseVector} nor a {@link DenseMatrix}
     */
    @Override
    public void initialise(Parameters parameters) {
        this.parameters = parameters;
        final Tensor[] current = parameters.get();
        final Tensor[] wrapped = new Tensor[current.length];
        final double scaledL1 = this.l1 / this.numExamples;
        final double scaledL2 = this.l2 / this.numExamples;
        int slot = 0;
        for (Tensor tensor : current) {
            if (tensor instanceof DenseVector) {
                wrapped[slot] = new AdaGradRDAVector((DenseVector) tensor, this.initialLearningRate, this.epsilon, scaledL1, scaledL2);
            } else if (tensor instanceof DenseMatrix) {
                wrapped[slot] = new AdaGradRDAMatrix((DenseMatrix) tensor, this.initialLearningRate, this.epsilon, scaledL1, scaledL2);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
            slot++;
        }
        parameters.set(wrapped);
    }

    /**
     * Scales every supplied tensor in place by {@code d} and hands back the same array.
     *
     * @param tensorArr the tensors to scale
     * @param d the scaling factor
     * @return {@code tensorArr} itself, after scaling
     */
    @Override
    public Tensor[] step(Tensor[] tensorArr, double d) {
        for (int idx = 0; idx < tensorArr.length; idx++) {
            tensorArr[idx].scaleInPlace(d);
        }
        return tensorArr;
    }

    /**
     * Replaces every tensor of the remembered parameters with the result of its
     * {@code convertToDense()} call.
     *
     * @throws IllegalStateException if any tensor is not an {@code AdaGradRDATensor}
     */
    @Override
    public void finalise() {
        final Tensor[] wrapped = this.parameters.get();
        final Tensor[] dense = new Tensor[wrapped.length];
        int slot = 0;
        for (Tensor candidate : wrapped) {
            if (candidate instanceof AdaGradRDATensor) {
                dense[slot] = ((AdaGradRDATensor) candidate).convertToDense();
                slot++;
            } else {
                throw new IllegalStateException("Finalising a Parameters which wasn't initialised with AdaGradRDA");
            }
        }
        this.parameters.set(dense);
    }

    /**
     * Describes this optimiser by its learning rate, epsilon, l1 and l2 values.
     * The example count is not part of the description.
     *
     * @return the textual description
     */
    public String toString() {
        return "AdaGradRDA(initialLearningRate=" + this.initialLearningRate
                + ",epsilon=" + this.epsilon
                + ",l1=" + this.l1
                + ",l2=" + this.l2
                + ")";
    }

    /**
     * Forgets the parameters remembered by {@code initialise}; the hyperparameters are untouched.
     */
    @Override
    public void reset() {
        parameters = null;
    }

    /**
     * Creates a new optimiser with the same hyperparameters and no remembered parameters.
     *
     * @return the copy
     */
    @Override
    public AdaGradRDA copy() {
        final AdaGradRDA duplicate = new AdaGradRDA(this.initialLearningRate, this.epsilon, this.l1, this.l2, this.numExamples);
        return duplicate;
    }

    /**
     * Builds the provenance of this optimiser under the host type name "StochasticGradientOptimiser".
     * <p>
     * Decompiler note: renamed from {@code getProvenance} (merged with its bridge method).
     *
     * @return a new provenance object describing this instance
     */
    public ConfiguredObjectProvenance m30getProvenance() {
        final ConfiguredObjectProvenance provenance = new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
        return provenance;
    }
}
