package org.grouplens.lenskit.mf.funksvd;

import it.unimi.dsi.fastutil.longs.LongIterator;
import mikera.vectorz.AVector;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.collections.CollectionUtils;
import org.grouplens.lenskit.collections.FastCollection;
import org.grouplens.lenskit.data.pref.IndexedPreference;
import org.grouplens.lenskit.data.pref.PreferenceDomain;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.vectors.MutableSparseVector;

/* loaded from: input_file:org/grouplens/lenskit/mf/funksvd/TrainingEstimator.class */
/**
 * Holds one running estimate for every rating in a {@link PreferenceSnapshot}.
 *
 * <p>Each estimate starts as the score a baseline {@link ItemScorer} assigns to the rated
 * item. {@link #update(AVector, AVector)} then adds, to every estimate, the product of a
 * per-user entry and a per-item entry. Results are optionally clamped to a
 * {@link PreferenceDomain}.
 *
 * <p>Estimates are stored in a flat array addressed by {@link IndexedPreference#getIndex()}.
 */
public final class TrainingEstimator {
    private final FastCollection<IndexedPreference> ratings;
    private final double[] estimates;
    private final PreferenceDomain domain;

    /**
     * Seeds an estimate for every rating in the snapshot from the baseline scorer.
     *
     * <p>The decompiler reports that this constructor was originally package-private. It is
     * kept public here so existing callers continue to compile.
     *
     * @param preferenceSnapshot the ratings to estimate; its ratings collection determines
     *     the size of the estimate array
     * @param itemScorer baseline scorer queried once per user for that user's rated items
     * @param preferenceDomain domain used by {@link #update} to clamp estimates, or
     *     {@code null} to disable clamping (the initial baseline scores are never clamped)
     */
    public TrainingEstimator(PreferenceSnapshot preferenceSnapshot, ItemScorer itemScorer, PreferenceDomain preferenceDomain) {
        ratings = preferenceSnapshot.getRatings();
        domain = preferenceDomain;
        estimates = new double[ratings.size()];
        for (LongIterator users = preferenceSnapshot.getUserIds().iterator(); users.hasNext(); ) {
            seedUser(preferenceSnapshot, itemScorer, users.nextLong());
        }
    }

    /**
     * Scores every item {@code user} rated and stores each score in that rating's slot.
     *
     * @param snapshot source of the user's rating vector and indexed ratings
     * @param scorer baseline scorer filling in the user's scores
     * @param user the user id being seeded
     */
    private void seedUser(PreferenceSnapshot snapshot, ItemScorer scorer, long user) {
        // The vector's key set is exactly the user's rated items, so the scorer is asked only
        // for those items.
        MutableSparseVector scores = MutableSparseVector.create(snapshot.userRatingVector(user).keySet());
        scorer.score(user, scores);
        for (IndexedPreference pref : CollectionUtils.fast(snapshot.getUserRatings(user))) {
            // NOTE(review): what happens if the scorer leaves an item unset depends on
            // MutableSparseVector.get — confirm against the LensKit version in use.
            estimates[pref.getIndex()] = scores.get(pref.getItemId());
        }
    }

    /**
     * Returns the current estimate for a rating.
     *
     * @param indexedPreference the rating; its {@code getIndex()} selects the estimate slot
     * @return the stored estimate for that rating
     */
    public double get(IndexedPreference indexedPreference) {
        final int slot = indexedPreference.getIndex();
        return estimates[slot];
    }

    /**
     * Advances every estimate by the product of one user entry and one item entry.
     *
     * <p>For each rating, {@code aVector} is read at the rating's user index and
     * {@code aVector2} at its item index. Their product is added to the stored estimate,
     * and the sum is clamped to the domain when one was supplied.
     *
     * @param aVector vector indexed by {@link IndexedPreference#getUserIndex()}
     * @param aVector2 vector indexed by {@link IndexedPreference#getItemIndex()}
     */
    public void update(AVector aVector, AVector aVector2) {
        // domain is final, so the null test gives the same answer on every iteration.
        final boolean clamp = domain != null;
        for (IndexedPreference pref : CollectionUtils.fast(ratings)) {
            final int slot = pref.getIndex();
            final double delta = aVector.get(pref.getUserIndex()) * aVector2.get(pref.getItemIndex());
            final double next = estimates[slot] + delta;
            estimates[slot] = clamp ? domain.clampValue(next) : next;
        }
    }
}
