package org.tribuo.math.distributions;

import java.util.Arrays;
import java.util.Optional;
import java.util.Random;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;

/* loaded from: input_file:org/tribuo/math/distributions/MultivariateNormalDistribution.class */
/**
 * A multivariate normal distribution with a fixed mean vector and covariance matrix, sampled
 * using a {@link Random} seeded at construction time.
 * <p>
 * Samples are produced as {@code mean + A z}, where {@code z} is a vector of independent draws
 * from {@link Random#nextGaussian()} and {@code A} is a sampling matrix derived from the
 * covariance. By default {@code A} is the factor returned by the covariance's Cholesky
 * factorization ({@code lMatrix()}). When eigen decomposition is requested, {@code A} is
 * {@code V sqrt(L) V^T}, built from the eigenvectors {@code V} and eigenvalues {@code L} of the
 * covariance.
 */
public final class MultivariateNormalDistribution {
    /** Seed used to construct {@link #rng}; reported by {@link #toString()}. */
    private final long seed;
    /** Source of the standard normal draws; advanced by every call to {@link #sampleVector()}. */
    private final Random rng;
    /** Defensive copy of the mean vector. */
    private final DenseVector means;
    /** Defensive copy of the covariance matrix. */
    private final DenseMatrix covariance;
    /** Matrix applied to the standard normal draws before the mean is added. */
    private final DenseMatrix samplingCovariance;
    /** True if {@link #samplingCovariance} came from an eigen decomposition rather than Cholesky. */
    private final boolean eigenDecomposition;

    /**
     * Constructs a distribution that samples using the Cholesky factorization of the covariance.
     *
     * @param means      the mean vector.
     * @param covariance the covariance matrix; must be square with side length {@code means.length}.
     * @param seed       the seed for the random number generator.
     * @throws IllegalArgumentException if the covariance shape does not match the mean vector, or
     *                                  the Cholesky factorization is unavailable.
     */
    public MultivariateNormalDistribution(double[] means, double[][] covariance, long seed) {
        this(DenseVector.createDenseVector(means), DenseMatrix.createDenseMatrix(covariance), seed);
    }

    /**
     * Constructs a distribution, choosing how the sampling matrix is derived from the covariance.
     *
     * @param means              the mean vector.
     * @param covariance         the covariance matrix; must be square with side length
     *                           {@code means.length}.
     * @param seed               the seed for the random number generator.
     * @param eigenDecomposition if true, use an eigen decomposition of the covariance; otherwise
     *                           use its Cholesky factorization.
     * @throws IllegalArgumentException if the covariance shape does not match the mean vector, or
     *                                  the requested factorization is unavailable / has
     *                                  non-positive eigenvalues.
     */
    public MultivariateNormalDistribution(double[] means, double[][] covariance, long seed, boolean eigenDecomposition) {
        this(DenseVector.createDenseVector(means), DenseMatrix.createDenseMatrix(covariance), seed, eigenDecomposition);
    }

    /**
     * Constructs a distribution that samples using the Cholesky factorization of the covariance.
     * The inputs are copied.
     *
     * @param means      the mean vector.
     * @param covariance the covariance matrix; must be square with side length {@code means.size()}.
     * @param seed       the seed for the random number generator.
     * @throws IllegalArgumentException if the covariance shape does not match the mean vector, or
     *                                  the Cholesky factorization is unavailable.
     */
    public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed) {
        this(means, covariance, seed, false);
    }

    /**
     * Constructs a distribution, choosing how the sampling matrix is derived from the covariance.
     * The inputs are copied, so later changes to them do not affect this distribution.
     *
     * @param means              the mean vector.
     * @param covariance         the covariance matrix; must be square with side length
     *                           {@code means.size()}.
     * @param seed               the seed for the random number generator.
     * @param eigenDecomposition if true, use an eigen decomposition of the covariance; otherwise
     *                           use its Cholesky factorization.
     * @throws IllegalArgumentException if the covariance shape does not match the mean vector, or
     *                                  the requested factorization is unavailable / has
     *                                  non-positive eigenvalues.
     */
    public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed, boolean eigenDecomposition) {
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = means.copy();
        this.covariance = covariance.copy();
        int dim = this.means.size();
        if (this.covariance.getDimension1Size() != dim || this.covariance.getDimension2Size() != dim) {
            throw new IllegalArgumentException("Covariance matrix must be square and the same dimension as the mean vector. Mean vector size = " + dim + ", covariance size = " + Arrays.toString(this.covariance.getShape()));
        }
        this.eigenDecomposition = eigenDecomposition;
        this.samplingCovariance = eigenDecomposition
                ? eigenSquareRoot(this.covariance)
                : choleskyFactor(this.covariance);
    }

    /**
     * Returns the {@code lMatrix()} factor of the covariance's Cholesky factorization.
     *
     * @param covariance the covariance matrix.
     * @return the Cholesky factor used as the sampling matrix.
     * @throws IllegalArgumentException if the factorization is not available.
     */
    private static DenseMatrix choleskyFactor(DenseMatrix covariance) {
        Optional<DenseMatrix.CholeskyFactorization> cholesky = covariance.choleskyFactorization();
        return cholesky
                .orElseThrow(() -> new IllegalArgumentException("Covariance matrix is not positive definite."))
                .lMatrix();
    }

    /**
     * Returns {@code V sqrt(L) V^T}, built from the eigenvectors {@code V} and eigenvalues
     * {@code L} of the covariance.
     *
     * @param covariance the covariance matrix.
     * @return the sampling matrix derived from the eigen decomposition.
     * @throws IllegalArgumentException if the decomposition is unavailable or reports
     *                                  non-positive eigenvalues.
     */
    private static DenseMatrix eigenSquareRoot(DenseMatrix covariance) {
        Optional<DenseMatrix.EigenDecomposition> eigen = covariance.eigenDecomposition()
                .filter(DenseMatrix.EigenDecomposition::positiveEigenvalues);
        DenseMatrix.EigenDecomposition decomposition =
                eigen.orElseThrow(() -> new IllegalArgumentException("Covariance matrix is not positive definite."));
        // Copy before the in-place sqrt so we never mutate storage owned by the decomposition.
        DenseVector sqrtEigenvalues = decomposition.eigenvalues().copy();
        sqrtEigenvalues.foreachInPlace(Math::sqrt);
        // Declared as DenseMatrix (not Matrix) so the chained multiplies return a DenseMatrix.
        DenseMatrix rotation = new DenseMatrix(decomposition.eigenvectors());
        Matrix scale = DenseSparseMatrix.createDiagonal(sqrtEigenvalues);
        return rotation.matrixMultiply(scale).matrixMultiply(rotation, false, true);
    }

    /**
     * Draws one sample from this distribution, advancing the internal random number generator.
     *
     * @return a new vector {@code mean + A z} with the same size as the mean vector.
     */
    public DenseVector sampleVector() {
        int dim = means.size();
        DenseVector standardNormal = new DenseVector(dim);
        for (int i = 0; i < dim; i++) {
            standardNormal.set(i, rng.nextGaussian());
        }
        SGDVector correlated = samplingCovariance.leftMultiply(standardNormal);
        return means.add(correlated);
    }

    /**
     * Draws one sample from this distribution as a primitive array.
     *
     * @return the sample as a {@code double[]}.
     */
    public double[] sampleArray() {
        return sampleVector().toArray();
    }

    @Override
    public String toString() {
        return "MultivariateNormal(mean=" + means + ",covariance=" + covariance + ",seed=" + seed + ",useEigenDecomposition=" + eigenDecomposition + ")";
    }
}
