package org.grouplens.lenskit.svd;

import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.util.Arrays;
import java.util.Collection;
import org.grouplens.lenskit.AbstractRecommenderService;
import org.grouplens.lenskit.RatingPredictor;
import org.grouplens.lenskit.data.Index;
import org.grouplens.lenskit.data.ScoredId;
import org.grouplens.lenskit.data.vector.MutableSparseVector;
import org.grouplens.lenskit.data.vector.SparseVector;
import org.grouplens.lenskit.util.CollectionUtils;
import org.grouplens.lenskit.util.DoubleFunction;

/**
 * Rating predictor that layers a truncated SVD model on top of a baseline predictor.
 *
 * <p>The model is an item-feature matrix {@code itemFeatures[f][i]} (indexed by feature,
 * then by the item's position in the item {@link Index}) plus one singular value per
 * feature. A user's ratings are folded into feature space as residuals against the
 * baseline. Each prediction is the baseline score plus the reconstructed residual. The
 * clamping function is applied after every feature's contribution.
 */
public class SVDRecommenderService extends AbstractRecommenderService implements RatingPredictor {
    /** Maps item ids to column positions in {@link #itemFeatures}; negative means unknown. */
    private final Index itemIndex;
    /** Predictor whose scores the SVD residuals are added to. */
    private final RatingPredictor baseline;
    private final int numFeatures;
    /** Item-feature matrix, indexed {@code [feature][itemIndex]}. */
    private final double[][] itemFeatures;
    /** One singular value per feature. */
    private final double[] singularValues;
    /** Applied to the running score after each feature's contribution is added. */
    private final DoubleFunction clampingFunction;

    /**
     * Creates a service over a precomputed SVD model.
     *
     * <p>NOTE(review): the decompiler reports this constructor as package-private in the
     * original build. It is kept public here so existing callers keep compiling.
     *
     * @param nfeatures        number of latent features. Must equal
     *                         {@code itemFeatures.length} and {@code singularValues.length},
     *                         but this is only checked when assertions are enabled.
     * @param itemIndex        item id to matrix-column index mapping
     * @param baseline         predictor supplying the scores the residuals are relative to
     * @param itemFeatures     item-feature matrix, indexed {@code [feature][itemIndex]}
     * @param singularValues   singular value for each feature
     * @param clampingFunction applied to the running score after each feature is added
     */
    public SVDRecommenderService(int nfeatures, Index itemIndex, RatingPredictor baseline,
                                 double[][] itemFeatures, double[] singularValues,
                                 DoubleFunction clampingFunction) {
        this.numFeatures = nfeatures;
        this.itemIndex = itemIndex;
        this.baseline = baseline;
        this.itemFeatures = itemFeatures;
        this.singularValues = singularValues;
        this.clampingFunction = clampingFunction;
        assert itemFeatures.length == nfeatures : "itemFeatures has wrong feature count";
        assert singularValues.length == nfeatures : "singularValues has wrong feature count";
    }

    /**
     * Returns the number of latent features in the model.
     *
     * @return the feature count supplied at construction
     */
    public int getFeatureCount() {
        return numFeatures;
    }

    /**
     * Projects a user's rating residuals into feature space.
     *
     * <p>For each rated item known to the model, the residual {@code rating - base} is
     * accumulated as {@code residual * itemFeatures[f][idx] / singularValues[f]} into every
     * feature {@code f}. Rated items missing from the item index are ignored.
     *
     * @param user    the user id. Not used by this implementation.
     * @param ratings the user's ratings, keyed by item id
     * @param base    baseline scores. Expected to contain every rated item; a missing
     *                entry's value comes from {@code SparseVector.get}, which is
     *                presumably NaN (TODO confirm).
     * @return a new array of length {@link #getFeatureCount()} holding the user's
     *         feature preferences
     */
    protected double[] foldIn(long user, SparseVector ratings, SparseVector base) {
        // Java zero-initializes new arrays, so no explicit fill is needed.
        double[] prefs = new double[numFeatures];
        for (Long2DoubleMap.Entry entry : ratings.fast()) {
            final long item = entry.getLongKey();
            final int idx = itemIndex.getIndex(item);
            if (idx < 0) {
                continue;
            }
            // getDoubleValue() avoids boxing through the generic getValue().
            final double residual = entry.getDoubleValue() - base.get(item);
            for (int f = 0; f < numFeatures; f++) {
                prefs[f] += residual * itemFeatures[f][idx] / singularValues[f];
            }
        }
        return prefs;
    }

    /**
     * Predicts a single item for a user.
     *
     * @param user    the user id
     * @param ratings the user's ratings, keyed by item id
     * @param item    the item to score
     * @return the scored item, or {@code null} if no prediction could be made (for example,
     *         when the item is unknown to the model)
     */
    public ScoredId predict(long user, SparseVector ratings, long item) {
        LongArrayList target = new LongArrayList(1);
        target.add(item);
        MutableSparseVector scores = predict(user, ratings, (Collection<Long>) target);
        return scores.containsId(item) ? new ScoredId(item, scores.get(item)) : null;
    }

    /**
     * Predicts scores for a collection of items.
     *
     * <p>The result is keyed by the distinct requested ids in ascending order. Items that are
     * not in the model's item index are omitted from the result rather than given a
     * placeholder score.
     *
     * @param user    the user id
     * @param ratings the user's ratings, keyed by item id
     * @param items   the items to score. Duplicates are allowed and collapsed.
     * @return a new vector of predicted scores for the known requested items
     */
    public MutableSparseVector predict(long user, SparseVector ratings, Collection<Long> items) {
        // The baseline must cover the rated items (for fold-in residuals) as well as the targets.
        LongOpenHashSet baselineIds = new LongOpenHashSet(ratings.keySet());
        baselineIds.addAll(items);
        SparseVector base = baseline.predict(user, ratings, baselineIds);
        double[] userPrefs = foldIn(user, ratings, base);

        // wrap() is given sorted keys; de-duplicate too, so a repeated request cannot
        // produce a repeated key.
        long[] keys = new LongOpenHashSet(items).toLongArray();
        Arrays.sort(keys);
        double[] values = new double[keys.length];
        int size = 0;
        for (int i = 0; i < keys.length; i++) {
            final long item = keys[i];
            final int idx = itemIndex.getIndex(item);
            if (idx < 0) {
                // Unknown to the model: omit it instead of reporting a fabricated 0.0 score.
                continue;
            }
            // Compact in place. Because size <= i, no unread key is overwritten.
            keys[size] = item;
            values[size] = score(base.get(item), userPrefs, idx);
            size++;
        }
        if (size < keys.length) {
            keys = Arrays.copyOf(keys, size);
            values = Arrays.copyOf(values, size);
        }
        return MutableSparseVector.wrap(keys, values);
    }

    /**
     * Adds the SVD reconstruction for one item to its baseline score.
     *
     * <p>The clamping function is applied after each feature rather than once at the end.
     * This order is preserved from the original implementation and matters whenever the
     * clamp is non-linear.
     *
     * @param baseScore the baseline score for the item
     * @param userPrefs the user's folded-in feature preferences
     * @param idx       the item's column in {@link #itemFeatures}
     * @return the clamped prediction
     */
    private double score(double baseScore, double[] userPrefs, int idx) {
        double result = baseScore;
        for (int f = 0; f < numFeatures; f++) {
            result = clampingFunction.apply(
                    result + userPrefs[f] * singularValues[f] * itemFeatures[f][idx]);
        }
        return result;
    }

    /**
     * Returns this service as its rating predictor.
     *
     * @return {@code this}
     */
    public RatingPredictor getRatingPredictor() {
        return this;
    }

    /**
     * Raw-typed alias of {@link #predict(long, SparseVector, Collection)}.
     *
     * <p>NOTE(review): this is the decompiler's rendering of the compiler-generated bridge
     * for the covariant return type. It is kept only so that any existing callers continue
     * to compile.
     */
    @SuppressWarnings("unchecked")
    public SparseVector m2predict(long user, SparseVector ratings, Collection collection) {
        return predict(user, ratings, (Collection<Long>) collection);
    }
}
