package org.grouplens.lenskit.svd;

import it.unimi.dsi.fastutil.longs.LongBidirectionalIterator;
import it.unimi.dsi.fastutil.longs.LongSortedSet;
import java.util.Collection;
import org.grouplens.common.cursors.Cursors;
import org.grouplens.lenskit.AbstractRatingPredictor;
import org.grouplens.lenskit.data.Ratings;
import org.grouplens.lenskit.data.dao.RatingDataAccessObject;
import org.grouplens.lenskit.data.vector.MutableSparseVector;
import org.grouplens.lenskit.data.vector.SparseVector;
import org.grouplens.lenskit.util.DoubleFunction;
import org.grouplens.lenskit.util.LongSortedArraySet;

/* loaded from: input_file:org/grouplens/lenskit/svd/FunkSVDRatingPredictor.class */
/**
 * Rating predictor backed by a {@link FunkSVDModel}.
 *
 * <p>Predictions start from the model's baseline predictor. For every requested item
 * that the model knows, the baseline value is then refined feature by feature. At each
 * step it adds {@code userFeature[f] * itemFeature[f]} and passes the running total
 * through the model's clamping function.
 *
 * <p>Users unknown to the model receive the baseline predictions unchanged.
 */
public class FunkSVDRatingPredictor extends AbstractRatingPredictor {
    /** Trained model supplying features, user/item indexes, baseline, and clamping. */
    protected final FunkSVDModel model;
    /** Source of user rating histories, fed to the baseline predictor. */
    private final RatingDataAccessObject dao;

    /**
     * Create a predictor over a trained model.
     *
     * @param ratingDataAccessObject DAO used to look up a user's ratings.
     * @param funkSVDModel           the trained FunkSVD model.
     */
    public FunkSVDRatingPredictor(RatingDataAccessObject ratingDataAccessObject, FunkSVDModel funkSVDModel) {
        this.dao = ratingDataAccessObject;
        this.model = funkSVDModel;
    }

    /**
     * @return the number of latent features in the underlying model.
     */
    public int getFeatureCount() {
        return this.model.featureCount;
    }

    /**
     * Compute predictions for a user from an explicit user-feature vector.
     *
     * @param user         the user ID.
     * @param userFeatures the user's feature values, one per model feature, or {@code null}
     *                     when the user is not in the model (baseline-only predictions).
     * @param ratings      the user's rating vector, or {@code null} to load it from the DAO.
     * @param items        the items to predict.
     * @return the baseline predictions, refined by the SVD features when
     *         {@code userFeatures} is non-null.
     */
    private MutableSparseVector predict(long user, double[] userFeatures, SparseVector ratings, Collection<Long> items) {
        if (ratings == null) {
            ratings = Ratings.userRatingVector(Cursors.makeList(this.dao.getUserRatings(user)));
        }
        MutableSparseVector predictions = this.model.baseline.predict(user, ratings, items);
        if (userFeatures == null) {
            // No learned features for this user: the baseline is the best we have.
            return predictions;
        }

        final int featureCount = this.model.featureCount;
        final DoubleFunction clamp = this.model.clampingFunction;
        LongSortedSet itemSet = items instanceof LongSortedSet
                ? (LongSortedSet) items
                : new LongSortedArraySet(items);

        LongBidirectionalIterator iter = itemSet.iterator();
        while (iter.hasNext()) {
            long item = iter.nextLong();
            int itemIndex = this.model.getItemIndex(item);
            if (itemIndex < 0) {
                // Item not in the model; keep its baseline prediction.
                continue;
            }
            double score = predictions.get(item);
            for (int f = 0; f < featureCount; f++) {
                // Clamp after every feature, not just at the end.
                score = clamp.apply(score + userFeatures[f] * this.model.getItemFeatureValue(itemIndex, f));
            }
            predictions.set(item, score);
        }
        return predictions;
    }

    /**
     * Predict ratings for a user over a collection of items.
     *
     * @param j          the user ID.
     * @param collection the item IDs to predict.
     * @return the predictions. If the user is not in the model's user index, these are
     *         the baseline predictions only.
     */
    public MutableSparseVector predict(long j, Collection<Long> collection) {
        int index = this.model.userIndex.getIndex(j);
        if (index < 0) {
            // NOTE(review): assumes getIndex reports unknown users with a negative index,
            // matching the item-index convention used above. Previously this case
            // crashed with ArrayIndexOutOfBoundsException.
            return predict(j, null, null, collection);
        }
        double[] userFeatures = new double[this.model.featureCount];
        for (int f = 0; f < userFeatures.length; f++) {
            userFeatures[f] = this.model.userFeatures[f][index];
        }
        return predict(j, userFeatures, null, collection);
    }

    /**
     * Bridge with the supertype's {@code SparseVector}-returning signature.
     *
     * <p>This mirrors the compiler-generated bridge for the covariant override above. It
     * was renamed by the decompiler and simply delegates to
     * {@link #predict(long, Collection)}.
     */
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ SparseVector m2predict(long j, Collection collection) {
        return predict(j, (Collection<Long>) collection);
    }
}
