package org.grouplens.lenskit.svd;

import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import java.util.Iterator;
import org.grouplens.lenskit.RecommenderComponentBuilder;
import org.grouplens.lenskit.baseline.BaselinePredictor;
import org.grouplens.lenskit.data.pref.IndexedPreference;
import org.grouplens.lenskit.data.snapshot.RatingSnapshot;
import org.grouplens.lenskit.data.vector.MutableSparseVector;
import org.grouplens.lenskit.data.vector.UserRatingVector;
import org.grouplens.lenskit.svd.params.ClampingFunction;
import org.grouplens.lenskit.svd.params.FeatureCount;
import org.grouplens.lenskit.svd.params.IterationCount;
import org.grouplens.lenskit.svd.params.LearningRate;
import org.grouplens.lenskit.svd.params.RegularizationTerm;
import org.grouplens.lenskit.svd.params.TrainingThreshold;
import org.grouplens.lenskit.util.DoubleFunction;
import org.grouplens.lenskit.util.FastCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/svd/FunkSVDModelBuilder.class */
/**
 * Builds a {@link FunkSVDModel} by training latent features one at a time with
 * per-rating gradient descent.
 *
 * <p>Every user and item feature value starts at 0.1. Training of a single feature
 * stops either after a fixed number of epochs (when the iteration count is
 * positive) or, otherwise, once at least {@code MIN_EPOCHS} epochs have run and the
 * RMSE improvement between consecutive epochs drops below the training threshold.
 */
public class FunkSVDModelBuilder extends RecommenderComponentBuilder<FunkSVDModel> {
    private static final Logger logger = LoggerFactory.getLogger(FunkSVDModelBuilder.class);

    /** Initial value of every user and item feature entry. */
    private static final double DEFAULT_FEATURE_VALUE = 0.1d;

    /** Minimum epochs per feature when training by threshold instead of iteration count. */
    private static final int MIN_EPOCHS = 50;

    private int featureCount;
    private double learningRate;
    private double trainingThreshold;
    private double trainingRegularization;
    private DoubleFunction clampingFunction;
    private int iterationCount;
    private BaselinePredictor baseline;

    /**
     * Sets the number of latent features to train.
     *
     * @param count the number of features
     */
    @FeatureCount
    public void setFeatureCount(int count) {
        this.featureCount = count;
    }

    /**
     * Sets the gradient-descent step size.
     *
     * @param rate the learning rate
     */
    @LearningRate
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the minimum per-epoch RMSE improvement required to keep training a feature.
     * Only consulted when the iteration count is not positive.
     *
     * @param threshold the improvement threshold
     */
    @TrainingThreshold
    public void setTrainingThreshold(double threshold) {
        this.trainingThreshold = threshold;
    }

    /**
     * Sets the regularization term applied to feature updates.
     *
     * @param regularization the regularization coefficient
     */
    @RegularizationTerm
    public void setGradientDescentRegularization(double regularization) {
        this.trainingRegularization = regularization;
    }

    /**
     * Sets the function used to clamp intermediate predictions.
     *
     * @param function the clamping function
     */
    @ClampingFunction
    public void setClampingFunction(DoubleFunction function) {
        this.clampingFunction = function;
    }

    /**
     * Sets a fixed number of epochs per feature; a value &lt;= 0 selects
     * threshold-based stopping instead.
     *
     * @param count epochs per feature
     */
    @IterationCount
    public void setIterationCount(int count) {
        this.iterationCount = count;
    }

    /**
     * Sets the baseline predictor used to seed the per-user estimates.
     *
     * @param baselinePredictor the baseline predictor
     */
    public void setBaseline(BaselinePredictor baselinePredictor) {
        this.baseline = baselinePredictor;
    }

    /**
     * Trains a model from the {@code snapshot} inherited from
     * {@link RecommenderComponentBuilder}.
     *
     * <p>This is the generic {@code build()} method of the superclass. The decompiled
     * form of this class had renamed it to {@code m1build}.
     *
     * @return the trained model
     */
    public FunkSVDModel build() {
        logger.debug("Setting up to build SVD recommender with {} features", featureCount);
        logger.debug("Learning rate is {}", learningRate);
        logger.debug("Regularization term is {}", trainingRegularization);
        if (iterationCount > 0) {
            logger.debug("Training each epoch for {} iterations", iterationCount);
        } else {
            logger.debug("Error epsilon is {}", trainingThreshold);
        }

        MutableSparseVector[] estimates = initializeEstimates(snapshot, baseline);
        FastCollection<IndexedPreference> ratings = snapshot.getRatings();
        logger.debug("Building SVD with {} features for {} ratings", featureCount, ratings.size());

        final int numUsers = snapshot.getUserIds().size();
        final int numItems = snapshot.getItemIds().size();
        // Feature-major layout: userFeatures[f][userIndex], itemFeatures[f][itemIndex].
        double[][] userFeatures = new double[featureCount][numUsers];
        double[][] itemFeatures = new double[featureCount][numItems];

        for (int feature = 0; feature < featureCount; feature++) {
            trainFeature(userFeatures, itemFeatures, estimates, ratings, feature);
        }

        return new FunkSVDModel(featureCount, itemFeatures, userFeatures, clampingFunction,
                snapshot.itemIndex(), snapshot.userIndex(), baseline);
    }

    /**
     * Legacy alias for {@link #build()}, kept for callers of the decompiled name.
     *
     * @return the trained model
     * @deprecated decompiler artifact; call {@link #build()} instead.
     */
    @Deprecated
    public FunkSVDModel m1build() {
        return build();
    }

    /**
     * Computes baseline predictions for every rating of every user.
     *
     * @param ratingSnapshot the snapshot supplying users and their ratings
     * @param baselinePredictor the predictor producing the initial estimates
     * @return one estimate vector per user, indexed by the snapshot's user index
     */
    private MutableSparseVector[] initializeEstimates(RatingSnapshot ratingSnapshot,
                                                      BaselinePredictor baselinePredictor) {
        final int userCount = ratingSnapshot.userIndex().getObjectCount();
        MutableSparseVector[] estimates = new MutableSparseVector[userCount];
        for (int uidx = 0; uidx < userCount; uidx++) {
            final long uid = ratingSnapshot.userIndex().getId(uidx);
            UserRatingVector userRatings =
                    UserRatingVector.fromPreferences(uid, ratingSnapshot.getUserRatings(uid));
            estimates[uidx] = baselinePredictor.predict(userRatings, userRatings.keySet());
        }
        return estimates;
    }

    /**
     * Trains a single feature until {@link #isDone} reports completion, then folds the
     * trained feature into {@code estimates} (clamped) so that later features train
     * on the residual.
     */
    private void trainFeature(double[][] userFeatures, double[][] itemFeatures,
                              MutableSparseVector[] estimates,
                              FastCollection<IndexedPreference> ratings, int feature) {
        logger.trace("Training feature {}", feature);
        final double[] ufv = userFeatures[feature];
        final double[] ifv = itemFeatures[feature];
        DoubleArrays.fill(ufv, DEFAULT_FEATURE_VALUE);
        DoubleArrays.fill(ifv, DEFAULT_FEATURE_VALUE);

        // Stand-in for the features not yet trained: each contributes the product of
        // two initial values to a prediction.
        final double trailingValue =
                (featureCount - feature - 1) * DEFAULT_FEATURE_VALUE * DEFAULT_FEATURE_VALUE;

        double rmse = Double.MAX_VALUE;
        double oldRmse = 0.0d;
        int epoch = 0;
        while (!isDone(epoch, rmse, oldRmse)) {
            logger.trace("Running epoch {} of feature {}", epoch, feature);
            oldRmse = rmse;
            rmse = trainFeatureIteration(ratings, ufv, ifv, estimates, trailingValue);
            logger.trace("Epoch {} had RMSE of {}", epoch, rmse);
            epoch++;
        }
        logger.debug("Finished feature {} in {} epochs, rmse={}",
                new Object[] {Integer.valueOf(feature), Integer.valueOf(epoch), Double.valueOf(rmse)});

        // Fold this feature's contribution into the running estimates.
        for (IndexedPreference pref : ratings.fast()) {
            final int uidx = pref.getUserIndex();
            final int iidx = pref.getItemIndex();
            final long iid = pref.getItemId();
            final double updated = estimates[uidx].get(iid) + ufv[uidx] * ifv[iidx];
            estimates[uidx].set(iid, clampingFunction.apply(updated));
        }
    }

    /**
     * Decides whether training of the current feature should stop.
     *
     * @param epoch number of epochs already run
     * @param rmse RMSE of the most recent epoch
     * @param oldRmse RMSE of the epoch before that
     * @return {@code true} when training should stop
     */
    protected final boolean isDone(int epoch, double rmse, double oldRmse) {
        if (iterationCount > 0) {
            return epoch >= iterationCount;
        }
        if (epoch < MIN_EPOCHS) {
            return false;
        }
        // A NaN RMSE (diverged training) never compares below the threshold. Without
        // this check the training loop would never terminate.
        return Double.isNaN(rmse) || oldRmse - rmse < trainingThreshold;
    }

    /**
     * Runs one epoch of gradient descent over all ratings.
     *
     * @return the RMSE of the epoch, or 0 when there are no ratings
     */
    private double trainFeatureIteration(FastCollection<IndexedPreference> ratings,
                                         double[] ufv, double[] ifv,
                                         MutableSparseVector[] estimates, double trailingValue) {
        final int count = ratings.size();
        if (count == 0) {
            // Avoid sqrt(0/0) = NaN, which would otherwise stall convergence detection.
            return 0.0d;
        }
        double sumSquaredError = 0.0d;
        for (IndexedPreference pref : ratings.fast()) {
            sumSquaredError += trainRating(ufv, ifv, estimates, trailingValue, pref);
        }
        return Math.sqrt(sumSquaredError / count);
    }

    /**
     * Applies one gradient-descent update for a single rating.
     *
     * <p>Both the user and the item update use the feature values from before this
     * call.
     *
     * @return the squared prediction error for this rating
     */
    private double trainRating(double[] ufv, double[] ifv, MutableSparseVector[] estimates,
                               double trailingValue, IndexedPreference pref) {
        final int uidx = pref.getUserIndex();
        final int iidx = pref.getItemIndex();

        // Prediction from earlier features plus the current one, clamped; then the
        // trailing-feature estimate is added and the result clamped again.
        final double withFeature = clampingFunction.apply(
                estimates[uidx].get(pref.getItemId()) + ufv[uidx] * ifv[iidx]);
        final double prediction = clampingFunction.apply(withFeature + trailingValue);
        final double err = pref.getValue() - prediction;

        final double oldUser = ufv[uidx];
        final double oldItem = ifv[iidx];
        ufv[uidx] = oldUser + (err * oldItem - trainingRegularization * oldUser) * learningRate;
        ifv[iidx] = oldItem + (err * oldUser - trainingRegularization * oldItem) * learningRate;

        return err * err;
    }
}
