package org.grouplens.lenskit.mf.funksvd;

import it.unimi.dsi.fastutil.longs.LongSortedSet;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.inject.Inject;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.baseline.BaselineScorer;
import org.grouplens.lenskit.basic.AbstractItemScorer;
import org.grouplens.lenskit.collections.LongUtils;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.event.Ratings;
import org.grouplens.lenskit.data.history.History;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.pref.PreferenceDomain;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.grouplens.lenskit.mf.svd.BiasedMFKernel;
import org.grouplens.lenskit.mf.svd.DomainClampingKernel;
import org.grouplens.lenskit.mf.svd.DotProductKernel;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;

/* loaded from: input_file:org/grouplens/lenskit/mf/funksvd/FunkSVDItemScorer.class */
/**
 * Item scorer backed by a {@link FunkSVDModel}.
 *
 * <p>Scores are seeded with the output of the baseline scorer and then combined with the user's and
 * each item's feature vectors through a {@link BiasedMFKernel}. If a {@link PreferenceDomain} is
 * available, a {@link DomainClampingKernel} is used; otherwise a {@link DotProductKernel} is used.
 *
 * <p>If a runtime update rule is configured and the user has ratings, a copy of the user's feature
 * vector is re-trained against those ratings, one feature at a time, before scoring. The model
 * itself is not modified.
 */
public class FunkSVDItemScorer extends AbstractItemScorer {
    protected final FunkSVDModel model;
    protected final BiasedMFKernel kernel;
    private final UserEventDAO dao;
    private final ItemScorer baselineScorer;
    private final int featureCount;

    /** Rule used to re-train user vectors at scoring time; {@code null} disables re-training. */
    @Nullable
    private final FunkSVDUpdateRule rule;

    /**
     * Construct a FunkSVD item scorer.
     *
     * @param dao      DAO used to fetch the user's ratings at scoring time.
     * @param model    The trained FunkSVD model.
     * @param baseline Scorer that supplies the initial (baseline) estimates.
     * @param dom      Preference domain; if non-null, scores are computed with a
     *                 {@link DomainClampingKernel}, otherwise with a {@link DotProductKernel}.
     * @param rule     Update rule for re-training user vectors at scoring time, or {@code null}
     *                 to score with the stored (or average) user vector as-is.
     */
    @Inject
    public FunkSVDItemScorer(UserEventDAO dao, FunkSVDModel model,
                             @BaselineScorer ItemScorer baseline,
                             @Nullable PreferenceDomain dom,
                             @Nullable @RuntimeUpdate FunkSVDUpdateRule rule) {
        this.dao = dao;
        this.model = model;
        this.baselineScorer = baseline;
        this.rule = rule;
        this.kernel = (dom == null) ? new DotProductKernel() : new DomainClampingKernel(dom);
        this.featureCount = model.getFeatureCount();
    }

    /**
     * Get the runtime update rule.
     *
     * @return The rule used to re-train user vectors at scoring time, or {@code null} if none.
     */
    @Nullable
    public FunkSVDUpdateRule getUpdateRule() {
        return rule;
    }

    /**
     * Apply the kernel to every set entry of {@code scores}.
     *
     * <p>Each entry's current value (the baseline estimate) is combined with the user vector and the
     * item's feature vector. Items for which the model has no item vector keep their current value.
     *
     * @param user   The user ID (currently unused).
     * @param uprefs The user's feature vector.
     * @param scores The score vector, updated in place.
     */
    private void computeScores(long user, AVector uprefs, MutableSparseVector scores) {
        for (VectorEntry e : scores.fast()) {
            AVector ivec = model.getItemVector(e.getKey());
            if (ivec != null) {
                scores.set(e, kernel.apply(e.getValue(), uprefs, ivec));
            }
        }
    }

    /**
     * Compute baseline estimates for the items to score and for the items the user has rated.
     *
     * @param user    The user ID.
     * @param ratings The user's rating vector.
     * @param items   The items to be scored.
     * @return A new vector over the union of {@code items} and the rated items, filled in by the
     *         baseline scorer.
     */
    private MutableSparseVector initialEstimates(long user, SparseVector ratings, LongSortedSet items) {
        MutableSparseVector estimates =
                MutableSparseVector.create(LongUtils.setUnion(items, ratings.keySet()));
        baselineScorer.score(user, estimates);
        return estimates;
    }

    /**
     * Score items for a user.
     *
     * <p>The output is first seeded with baseline estimates. If the model has no vector for the user
     * and the user has no ratings, the baseline estimates are the final result. Otherwise the
     * stored user vector is used; if there is none, the model's average user vector is used.
     * When an update rule is configured and the user has ratings, a copy of that vector is
     * re-trained before the kernel is applied.
     *
     * @param user   The user ID.
     * @param scores The score vector whose key domain lists the items to score, updated in place.
     */
    @Override
    public void score(long user, @Nonnull MutableSparseVector scores) {
        UserHistory<Rating> history = dao.getEventsForUser(user, Rating.class);
        if (history == null) {
            history = History.forUser(user);
        }
        MutableSparseVector ratings = Ratings.userRatingVector(history);

        MutableSparseVector estimates = initialEstimates(user, ratings, scores.keyDomain());
        // Seed the output with the baseline estimates.
        scores.set(estimates);

        AVector uprefs = model.getUserVector(user);
        if (uprefs == null) {
            if (ratings.isEmpty()) {
                // Unknown user with no ratings: nothing to personalize, keep baseline scores.
                return;
            }
            uprefs = model.getAverageUserVector();
        }

        if (!ratings.isEmpty() && rule != null) {
            // Train a copy so the vectors held by the model are left untouched.
            AVector updated = Vector.create(uprefs);
            for (int f = 0; f < featureCount; f++) {
                trainUserFeature(user, updated, ratings, estimates, f);
            }
            uprefs = updated;
        }

        computeScores(user, uprefs, scores);
    }

    /**
     * Re-train a single feature of the user vector against the user's ratings.
     *
     * <p>For each rated item, this first computes the "tail": the dot product of the user and item
     * vectors over the features after {@code feature}. It then runs feature iterations until the
     * rule's training loop controller stops.
     *
     * <p>NOTE(review): {@code estimates} is only read here and in {@link #doFeatureIteration}; it is
     * never updated with the contribution of a feature once that feature is trained. The values passed
     * to the updater for feature {@code f} therefore cover the baseline estimate, feature {@code f}
     * itself, and the tail (features after {@code f}), but not features before {@code f}. Confirm
     * this is intended.
     *
     * @param user     The user ID (currently unused).
     * @param uprefs   The user feature vector being trained, modified in place.
     * @param ratings  The user's ratings.
     * @param estimates Baseline estimates for (at least) the rated items.
     * @param feature  The index of the feature to train, in {@code [0, featureCount)}.
     */
    private void trainUserFeature(long user, AVector uprefs, SparseVector ratings,
                                  MutableSparseVector estimates, int feature) {
        assert rule != null;
        assert uprefs.length() == featureCount;
        assert feature >= 0 && feature < featureCount;

        // Features strictly after `feature` form the tail.
        int tailStart = feature + 1;
        int tailSize = featureCount - feature - 1;
        AVector utail = uprefs.subVector(tailStart, tailSize);

        MutableSparseVector tails = MutableSparseVector.create(ratings.keySet());
        // EITHER: visit every key (all are unset at this point), so every rated item gets a value.
        for (VectorEntry e : tails.fast(VectorEntry.State.EITHER)) {
            AVector ivec = model.getItemVector(e.getKey());
            if (ivec == null) {
                tails.set(e, 0.0);
            } else {
                tails.set(e, utail.dotProduct(ivec.subVector(tailStart, tailSize)));
            }
        }

        double rmse = Double.MAX_VALUE;
        TrainingLoopController controller = rule.getTrainingLoopController();
        while (controller.keepTraining(rmse)) {
            rmse = doFeatureIteration(user, uprefs, ratings, estimates, feature, tails);
        }
    }

    /**
     * Run one pass over the user's ratings, updating a single user feature after each rating.
     *
     * @param user      The user ID (currently unused).
     * @param uprefs    The user feature vector, modified in place.
     * @param ratings   The user's ratings.
     * @param estimates Baseline estimates for the rated items.
     * @param feature   The index of the feature being trained.
     * @param tails     Per-item dot products over the features after {@code feature}.
     * @return The RMSE reported by the updater for this pass.
     */
    private double doFeatureIteration(long user, AVector uprefs, SparseVector ratings,
                                      MutableSparseVector estimates, int feature, SparseVector tails) {
        assert rule != null;

        FunkSVDUpdater updater = rule.createUpdater();
        for (VectorEntry e : ratings.fast()) {
            long item = e.getKey();
            AVector ivec = model.getItemVector(item);
            if (ivec == null) {
                // Items unknown to the model cannot contribute to training.
                continue;
            }
            updater.prepare(feature, e.getValue(), estimates.get(item),
                            uprefs.get(feature), ivec.get(feature), tails.get(item));
            uprefs.addAt(feature, updater.getUserFeatureUpdate());
        }
        return updater.getRMSE();
    }
}
