package org.tribuo.math.optimisers.util;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.DenseTensorProto;
import org.tribuo.math.protos.ShrinkingDenseTensorProto;
import org.tribuo.math.protos.TensorProto;

/**
 * A {@link DenseMatrix} which shrinks all of its elements by a multiplicative factor before
 * each {@link #intersectAndAddInPlace} update, and which can optionally reproject itself onto
 * the L2 ball of radius {@code 1 / lambdaSqrt} after the update.
 * <p>
 * Scaling is lazy: the logical value of element {@code (i,j)} is
 * {@code values[i][j] * multiplier}. Because of this, {@link #scaleInPlace} only touches the
 * scalar multiplier until its magnitude drops below {@code 1e-6}; at that point the multiplier
 * is folded back into the stored values. The squared two-norm is maintained incrementally, so
 * {@link #twoNorm} does not rescan the matrix.
 */
/* loaded from: input_file:org/tribuo/math/optimisers/util/ShrinkingMatrix.class */
public class ShrinkingMatrix extends DenseMatrix implements ShrinkingTensor {
    /** Multiplier magnitude below which the multiplier is folded into the stored values. */
    private static final double MULTIPLIER_TOLERANCE = 1.0E-6d;

    /** Shrinkage rate used to compute the per-update scaling factor. */
    private final double baseRate;
    /** Square root of lambda; the reprojection radius is {@code 1 / lambdaSqrt}. */
    private final double lambdaSqrt;
    /** If true the shrinkage is {@code 1 - baseRate/iteration}, otherwise {@code 1 - baseRate}. */
    private final boolean scaleShrinking;
    /** If true, reproject onto the L2 ball of radius {@code 1 / lambdaSqrt} after each update. */
    private final boolean reproject;
    /** Running squared two-norm of the logical (multiplier-scaled) matrix. */
    private double squaredTwoNorm;
    /** Update counter, starting at 1 for freshly constructed matrices. */
    private int iteration;
    /** Lazy scale factor applied to every stored value. */
    private double multiplier;

    /**
     * Row-major iterator over every element of a {@link ShrinkingMatrix}, reporting the
     * logical (multiplier-scaled) values.
     * <p>
     * Every call to {@link #next()} returns the same {@link MatrixTuple} instance, overwritten
     * in place.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/math/optimisers/util/ShrinkingMatrix$ShrinkingMatrixIterator.class */
    public class ShrinkingMatrixIterator implements MatrixIterator {
        private final ShrinkingMatrix matrix;
        private final MatrixTuple tuple = new MatrixTuple();
        private int i = 0;
        private int j = 0;

        /**
         * Creates an iterator positioned at element (0,0).
         * @param shrinkingMatrix The matrix to iterate.
         */
        public ShrinkingMatrixIterator(ShrinkingMatrix shrinkingMatrix) {
            this.matrix = shrinkingMatrix;
        }

        /**
         * Returns the shared tuple that {@link #next()} writes into.
         * @return The reused tuple.
         */
        @Override
        public MatrixTuple getReference() {
            return this.tuple;
        }

        @Override
        public boolean hasNext() {
            return this.i < this.matrix.dim1 && this.j < this.matrix.dim2;
        }

        /**
         * Advances to the next element in row-major order.
         * @return The shared tuple, filled with the current indices and logical value.
         * @throws NoSuchElementException If every element has already been returned.
         */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override
        public MatrixTuple next() {
            if (!hasNext()) {
                throw new NoSuchElementException("ShrinkingMatrixIterator exhausted at (" + this.i + "," + this.j + ")");
            }
            this.tuple.i = this.i;
            this.tuple.j = this.j;
            this.tuple.value = this.matrix.get(this.i, this.j);
            // Use the iterated matrix's width, consistent with hasNext().
            if (this.j < this.matrix.dim2 - 1) {
                this.j++;
            } else {
                this.i++;
                this.j = 0;
            }
            return this.tuple;
        }
    }

    /**
     * Wraps a copy of the supplied matrix, shrinking it without reprojection.
     * @param denseMatrix The initial values.
     * @param baseRate The base shrinkage rate.
     * @param scaleShrinking If true, the shrinkage decays as {@code 1 - baseRate/iteration}.
     */
    public ShrinkingMatrix(DenseMatrix denseMatrix, double baseRate, boolean scaleShrinking) {
        super(denseMatrix);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = 0.0d;
        this.reproject = false;
        // Track the norm of the wrapped values rather than assuming they are all zero.
        this.squaredTwoNorm = sumOfSquares(this.values);
        this.iteration = 1;
        this.multiplier = 1.0d;
    }

    /**
     * Wraps a copy of the supplied matrix, shrinking with a decaying rate and reprojecting onto
     * the L2 ball of radius {@code 1 / sqrt(lambda)} after each update.
     * @param denseMatrix The initial values.
     * @param baseRate The base shrinkage rate.
     * @param lambda The regularisation strength; its square root sets the reprojection radius.
     */
    public ShrinkingMatrix(DenseMatrix denseMatrix, double baseRate, double lambda) {
        super(denseMatrix);
        this.baseRate = baseRate;
        this.scaleShrinking = true;
        this.lambdaSqrt = Math.sqrt(lambda);
        this.reproject = true;
        this.squaredTwoNorm = sumOfSquares(this.values);
        this.iteration = 1;
        this.multiplier = 1.0d;
    }

    /**
     * Restores a matrix from its serialized state. Used by {@link #deserializeFromProto}.
     * @throws IllegalStateException If {@code reproject} is false but {@code lambdaSqrt} is non-zero.
     * @throws IllegalArgumentException If {@code iteration} is negative.
     */
    private ShrinkingMatrix(DenseMatrix data, double baseRate, boolean scaleShrinking, double lambdaSqrt, boolean reproject, double squaredTwoNorm, int iteration, double multiplier) {
        super(data);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = lambdaSqrt;
        // The non-reprojecting constructor always sets lambdaSqrt to zero.
        if (!reproject && lambdaSqrt != 0.0d) {
            throw new IllegalStateException("Invalid ShrinkingMatrix, when reproject is false lambdaSqrt must be zero, found " + lambdaSqrt);
        }
        this.reproject = reproject;
        this.squaredTwoNorm = squaredTwoNorm;
        if (iteration < 0) {
            throw new IllegalArgumentException("Invalid ShrinkingMatrix, iteration must be non-negative, found " + iteration);
        }
        this.iteration = iteration;
        this.multiplier = multiplier;
    }

    /**
     * Sums the squares of every value in a 2D array.
     * @param rows The values.
     * @return The sum of squares.
     */
    private static double sumOfSquares(double[][] rows) {
        double total = 0.0d;
        for (double[] row : rows) {
            for (double value : row) {
                total += value * value;
            }
        }
        return total;
    }

    /**
     * Deserializes a ShrinkingMatrix from a {@link ShrinkingDenseTensorProto} packed in an {@link Any}.
     * @param version The serialization version; only 0 is supported.
     * @param className The class name (unused).
     * @param message The packed proto.
     * @return The deserialized matrix.
     * @throws InvalidProtocolBufferException If the message cannot be unpacked.
     * @throws IllegalArgumentException If the version is unsupported.
     */
    public static ShrinkingMatrix deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version != 0) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version 0");
        }
        ShrinkingDenseTensorProto proto = message.unpack(ShrinkingDenseTensorProto.class);
        return new ShrinkingMatrix(DenseMatrix.unpackProto(proto.getData()), proto.getBaseRate(), proto.getScaleShrinking(), proto.getLambdaSqrt(), proto.getReproject(), proto.getSquaredTwoNorm(), proto.getIteration(), proto.getMultiplier());
    }

    /**
     * Serializes this matrix, including its raw (unscaled) values and the multiplier, as a
     * little-endian double array inside a {@link ShrinkingDenseTensorProto}.
     * @return The tensor proto.
     */
    @Override
    /* renamed from: serialize */
    public TensorProto mo18serialize() {
        TensorProto.Builder tensorBuilder = TensorProto.newBuilder();
        tensorBuilder.setVersion(0);
        tensorBuilder.setClassName(ShrinkingMatrix.class.getName());

        DenseTensorProto.Builder dataBuilder = DenseTensorProto.newBuilder();
        dataBuilder.addDimensions(this.dim1);
        dataBuilder.addDimensions(this.dim2);
        ByteBuffer bytes = ByteBuffer.allocate(this.dim1 * this.dim2 * 8).order(ByteOrder.LITTLE_ENDIAN);
        DoubleBuffer doubles = bytes.asDoubleBuffer();
        for (double[] row : this.values) {
            doubles.put(row);
        }
        // Writes through the view leave the byte buffer's position at 0, so all bytes are copied.
        dataBuilder.setValues(ByteString.copyFrom(bytes));

        ShrinkingDenseTensorProto.Builder shrinkingBuilder = ShrinkingDenseTensorProto.newBuilder();
        shrinkingBuilder.setData(dataBuilder.m175build());
        shrinkingBuilder.setBaseRate(this.baseRate);
        shrinkingBuilder.setLambdaSqrt(this.lambdaSqrt);
        shrinkingBuilder.setScaleShrinking(this.scaleShrinking);
        shrinkingBuilder.setReproject(this.reproject);
        shrinkingBuilder.setSquaredTwoNorm(this.squaredTwoNorm);
        shrinkingBuilder.setIteration(this.iteration);
        shrinkingBuilder.setMultiplier(this.multiplier);
        tensorBuilder.setSerializedData(Any.pack(shrinkingBuilder.m692build()));
        return tensorBuilder.m833build();
    }

    /**
     * Returns a plain {@link DenseMatrix} copy of this matrix.
     * NOTE(review): this relies on DenseMatrix's copy constructor reading elements through
     * {@code get(i,j)} so that the multiplier is applied — confirm against DenseMatrix.
     * @return A dense copy.
     */
    @Override
    public DenseMatrix convertToDense() {
        return new DenseMatrix((DenseMatrix) this);
    }

    /**
     * Computes the matrix-vector product {@code this * input}.
     * @param sGDVector A vector of length {@code dim2}.
     * @return A dense vector of length {@code dim1}.
     * @throws IllegalArgumentException If the vector length does not equal {@code dim2}.
     */
    @Override
    public DenseVector leftMultiply(SGDVector sGDVector) {
        if (sGDVector.size() != this.dim2) {
            throw new IllegalArgumentException("input.size() != dim2");
        }
        double[] output = new double[this.dim1];
        for (VectorTuple entry : sGDVector) {
            for (int row = 0; row < output.length; row++) {
                output[row] += get(row, entry.index) * entry.value;
            }
        }
        return DenseVector.createDenseVector(output);
    }

    /**
     * Shrinks this matrix, then adds {@code f(other[i][j])} to each element that {@code other}
     * iterates over. If reprojection is enabled, it then rescales so the two-norm does not
     * exceed {@code 1 / lambdaSqrt}. Finally it increments the iteration counter.
     * @param tensor The update; must be a {@link Matrix} of the same dimensions.
     * @param doubleUnaryOperator The function applied to each update value.
     * @throws IllegalStateException If {@code tensor} is not a Matrix or has different dimensions.
     */
    @Override
    public void intersectAndAddInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
        if (!(tensor instanceof Matrix)) {
            throw new IllegalStateException("Adding a non-Matrix to a Matrix");
        }
        Matrix other = (Matrix) tensor;
        if (this.dim1 != other.getDimension1Size() || this.dim2 != other.getDimension2Size()) {
            throw new IllegalStateException("Matrices are not the same size, this(" + this.dim1 + "," + this.dim2 + "), other(" + other.getDimension1Size() + "," + other.getDimension2Size() + ")");
        }
        double shrinkage = this.scaleShrinking ? 1.0d - (this.baseRate / this.iteration) : 1.0d - this.baseRate;
        scaleInPlace(shrinkage);
        for (MatrixTuple update : other) {
            double delta = doubleUnaryOperator.applyAsDouble(update.value);
            double oldValue = this.values[update.i][update.j] * this.multiplier;
            double newValue = oldValue + delta;
            // Swap this element's contribution to the running squared norm.
            this.squaredTwoNorm -= oldValue * oldValue;
            this.squaredTwoNorm += newValue * newValue;
            // Store the raw value so that raw * multiplier == newValue.
            this.values[update.i][update.j] = newValue / this.multiplier;
        }
        if (this.reproject) {
            double projectionNormaliser = (1.0d / this.lambdaSqrt) / twoNorm();
            if (projectionNormaliser < 1.0d) {
                scaleInPlace(projectionNormaliser);
            }
        }
        this.iteration++;
    }

    /**
     * Returns the logical value of element {@code (i, i2)}, i.e. the stored value times the multiplier.
     * @param i The row index.
     * @param i2 The column index.
     * @return The element value.
     */
    @Override
    public double get(int i, int i2) {
        return this.values[i][i2] * this.multiplier;
    }

    /**
     * Lazily scales every element by {@code d}, keeping the tracked squared norm in sync.
     * @param d The scaling coefficient.
     */
    @Override
    public void scaleInPlace(double d) {
        this.multiplier *= d;
        // Every logical element is multiplied by d, so the squared norm scales by d^2.
        this.squaredTwoNorm *= d * d;
        if (Math.abs(this.multiplier) < MULTIPLIER_TOLERANCE) {
            reifyMultiplier();
        }
    }

    /**
     * Folds the multiplier into the stored values and resets it to 1, so repeated shrinkage
     * does not drive it towards zero. The logical values and the norm are unchanged.
     */
    private void reifyMultiplier() {
        for (double[] row : this.values) {
            for (int col = 0; col < this.dim2; col++) {
                row[col] *= this.multiplier;
            }
        }
        this.multiplier = 1.0d;
    }

    /**
     * Returns the two-norm of the logical matrix from the incrementally tracked squared norm.
     * @return The two-norm.
     */
    @Override
    public double twoNorm() {
        // Incremental updates can drift fractionally below zero; avoid returning NaN.
        return Math.sqrt(Math.max(0.0d, this.squaredTwoNorm));
    }

    /**
     * Returns a row-major iterator over the logical values.
     * @return A new iterator.
     */
    @Override
    /* renamed from: iterator */
    public Iterator<MatrixTuple> iterator2() {
        return new ShrinkingMatrixIterator(this);
    }
}
