package org.grouplens.lenskit.mf.funksvd;

import it.unimi.dsi.fastutil.longs.LongSortedSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.inject.Inject;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.baseline.BaselineScorer;
import org.grouplens.lenskit.basic.AbstractItemScorer;
import org.grouplens.lenskit.collections.LongUtils;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.event.Ratings;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.grouplens.lenskit.transform.clamp.ClampingFunction;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;

/* loaded from: input_file:org/grouplens/lenskit/mf/funksvd/FunkSVDItemScorer.class */
public class FunkSVDItemScorer extends AbstractItemScorer {
    protected final FunkSVDModel model;
    private UserEventDAO dao;
    private final ItemScorer baselineScorer;
    private final int featureCount;
    private final ClampingFunction clamp;

    @Nullable
    private final FunkSVDUpdateRule rule;
    static final /* synthetic */ boolean $assertionsDisabled;

    @Inject
    public FunkSVDItemScorer(UserEventDAO userEventDAO, FunkSVDModel funkSVDModel, @BaselineScorer ItemScorer itemScorer, @Nullable @RuntimeUpdate FunkSVDUpdateRule funkSVDUpdateRule) {
        this.dao = userEventDAO;
        this.model = funkSVDModel;
        this.baselineScorer = itemScorer;
        this.rule = funkSVDUpdateRule;
        this.featureCount = funkSVDModel.getFeatureCount();
        this.clamp = funkSVDModel.getClampingFunction();
    }

    @Nullable
    public FunkSVDUpdateRule getUpdateRule() {
        return this.rule;
    }

    private void predict(long j, double[] dArr, MutableSparseVector mutableSparseVector) {
        for (VectorEntry vectorEntry : mutableSparseVector.fast()) {
            long key = vectorEntry.getKey();
            int index = this.model.getItemIndex().getIndex(key);
            if (index >= 0) {
                double value = vectorEntry.getValue();
                for (int i = 0; i < this.featureCount; i++) {
                    value = this.clamp.apply(j, key, value + (dArr[i] * this.model.getItemFeatures()[i][index]));
                }
                mutableSparseVector.set(vectorEntry, value);
            }
        }
    }

    private MutableSparseVector initialEstimates(long j, SparseVector sparseVector, LongSortedSet longSortedSet) {
        MutableSparseVector create = MutableSparseVector.create(LongUtils.setUnion(longSortedSet, sparseVector.keySet()));
        this.baselineScorer.score(j, create);
        return create;
    }

    public void score(long j, @Nonnull MutableSparseVector mutableSparseVector) {
        double[] dArr;
        int index = this.model.getUserIndex().getIndex(j);
        MutableSparseVector userRatingVector = Ratings.userRatingVector(this.dao.getEventsForUser(j, Rating.class));
        MutableSparseVector initialEstimates = initialEstimates(j, userRatingVector, mutableSparseVector.keyDomain());
        mutableSparseVector.set(initialEstimates);
        if (index >= 0 || !userRatingVector.isEmpty()) {
            if (index < 0) {
                dArr = new double[this.model.getFeatureCount()];
                for (int i = 0; i < this.model.getFeatureCount(); i++) {
                    dArr[i] = this.model.getFeatureInfo(i).getUserAverage();
                }
            } else {
                dArr = new double[this.featureCount];
                for (int i2 = 0; i2 < this.featureCount; i2++) {
                    dArr[i2] = this.model.getUserFeatures()[i2][index];
                }
            }
            if (!userRatingVector.isEmpty() && this.rule != null) {
                for (int i3 = 0; i3 < this.featureCount; i3++) {
                    trainUserFeature(j, dArr, userRatingVector, initialEstimates, i3);
                }
            }
            predict(j, dArr, mutableSparseVector);
        }
    }

    private void trainUserFeature(long j, double[] dArr, SparseVector sparseVector, MutableSparseVector mutableSparseVector, int i) {
        if (!$assertionsDisabled && this.rule == null) {
            throw new AssertionError();
        }
        double d = Double.MAX_VALUE;
        TrainingLoopController trainingLoopController = this.rule.getTrainingLoopController();
        while (trainingLoopController.keepTraining(d)) {
            d = doFeatureIteration(j, dArr, sparseVector, mutableSparseVector, i);
        }
        Iterator it = sparseVector.fast().iterator();
        while (it.hasNext()) {
            long key = ((VectorEntry) it.next()).getKey();
            mutableSparseVector.set(key, this.clamp.apply(j, key, mutableSparseVector.get(key) + (dArr[i] * this.model.getItemFeature(key, i))));
        }
    }

    private double doFeatureIteration(long j, double[] dArr, SparseVector sparseVector, MutableSparseVector mutableSparseVector, int i) {
        if (!$assertionsDisabled && this.rule == null) {
            throw new AssertionError();
        }
        double d = 0.0d;
        int i2 = 0;
        for (VectorEntry vectorEntry : sparseVector.fast()) {
            long key = vectorEntry.getKey();
            int index = this.model.getItemIndex().getIndex(key);
            double d2 = 0.0d;
            for (int i3 = i + 1; i3 < this.featureCount; i3++) {
                d2 += dArr[i3] * this.model.getItemFeatures()[i3][index];
            }
            double d3 = dArr[i];
            double d4 = this.model.getItemFeatures()[i][index];
            double computeError = this.rule.computeError(j, key, d2, mutableSparseVector.get(key), vectorEntry.getValue(), d3, d4);
            dArr[i] = dArr[i] + this.rule.userUpdate(computeError, d3, d4);
            d += computeError * computeError;
            i2++;
        }
        return Math.sqrt(d / i2);
    }

    static {
        $assertionsDisabled = !FunkSVDItemScorer.class.desiredAssertionStatus();
    }
}
