package org.grouplens.lenskit.mf.funksvd;

import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import javax.inject.Provider;
import org.apache.commons.lang3.time.StopWatch;
import org.grouplens.lenskit.baseline.BaselinePredictor;
import org.grouplens.lenskit.collections.CollectionUtils;
import org.grouplens.lenskit.collections.FastCollection;
import org.grouplens.lenskit.core.Transient;
import org.grouplens.lenskit.data.pref.IndexedPreference;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.grouplens.lenskit.mf.funksvd.FeatureInfo;
import org.grouplens.lenskit.vectors.MutableVec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/mf/funksvd/FunkSVDModelBuilder.class */
/**
 * Trains a {@link FunkSVDModel} from a preference snapshot, learning one latent feature at a time.
 *
 * <p>For each feature, every user and item value starts at {@code initialValue}. The values are
 * then refined by repeated passes over the snapshot's ratings, using the error and update terms
 * supplied by the {@link FunkSVDUpdateRule}. The rule's {@link TrainingLoopController} decides how
 * many passes each feature gets. Once a feature is finished, the {@link TrainingEstimator} created
 * by the rule is updated with that feature's values. Its estimates are an input to the error
 * computation for all later features.
 */
public class FunkSVDModelBuilder implements Provider<FunkSVDModel> {
    private static final Logger logger = LoggerFactory.getLogger(FunkSVDModelBuilder.class);

    /** Number of latent features to train. */
    protected final int featureCount;
    /** Baseline predictor handed through to the built model. */
    protected final BaselinePredictor baseline;
    /** Ratings (and user/item indexes) the model is trained on. */
    protected final PreferenceSnapshot snapshot;
    /** Starting value for every user and item feature value. */
    protected final double initialValue;
    /** Supplies learning parameters, error/update terms, loop control and clamping. */
    protected final FunkSVDUpdateRule rule;

    /**
     * Creates a model builder.
     *
     * @param preferenceSnapshot the ratings to train on
     * @param funkSVDUpdateRule  the rule supplying error/update computations, the training-loop
     *                           controller and the clamping function
     * @param baselinePredictor  baseline predictor stored in the resulting model
     * @param featureCount       number of latent features to train
     * @param initialValue       starting value for every user and item feature value
     */
    @Inject
    public FunkSVDModelBuilder(@Nonnull @Transient PreferenceSnapshot preferenceSnapshot,
                               @Nonnull @Transient FunkSVDUpdateRule funkSVDUpdateRule,
                               @Nonnull BaselinePredictor baselinePredictor,
                               @FeatureCount int featureCount,
                               @InitialFeatureValue double initialValue) {
        this.featureCount = featureCount;
        this.initialValue = initialValue;
        this.baseline = baselinePredictor;
        this.snapshot = preferenceSnapshot;
        this.rule = funkSVDUpdateRule;
    }

    /**
     * Trains all features in sequence and assembles the resulting model.
     *
     * <p>This is the {@link Provider#get()} implementation required by the class's
     * {@code implements Provider<FunkSVDModel>} declaration.
     *
     * @return a model holding the learned item and user feature matrices (one row per feature),
     *         the rule's clamping function, the snapshot's indexes, the baseline and one
     *         {@link FeatureInfo} summary per feature
     */
    public FunkSVDModel get() {
        final int userCount = snapshot.getUserIds().size();
        final int itemCount = snapshot.getItemIds().size();
        double[][] userFeatures = new double[featureCount][userCount];
        double[][] itemFeatures = new double[featureCount][itemCount];

        logger.debug("Setting up to build SVD recommender with {} features", featureCount);
        logger.debug("Learning rate is {}", rule.getLearningRate());
        logger.debug("Regularization term is {}", rule.getTrainingRegularization());
        logger.debug("Building SVD with {} features for {} ratings",
                     featureCount, snapshot.getRatings().size());

        TrainingEstimator estimator = rule.makeEstimator(snapshot);
        List<FeatureInfo> featureInfo = new ArrayList<FeatureInfo>(featureCount);

        for (int feature = 0; feature < featureCount; feature++) {
            logger.trace("Training feature {}", feature);
            StopWatch timer = new StopWatch();
            timer.start();

            double[] userValues = userFeatures[feature];
            double[] itemValues = itemFeatures[feature];
            DoubleArrays.fill(userValues, initialValue);
            DoubleArrays.fill(itemValues, initialValue);

            FeatureInfo.Builder info = new FeatureInfo.Builder(feature);
            trainFeature(feature, estimator, userValues, itemValues, info);
            summarizeFeature(userValues, itemValues, info);
            featureInfo.add(info.m0build());

            // Fold the finished feature into the estimator so later features see its effect.
            estimator.update(userValues, itemValues);

            timer.stop();
            logger.debug("Finished feature {} in {}", feature, timer);
        }

        return new FunkSVDModel(featureCount, itemFeatures, userFeatures,
                                rule.getClampingFunction(),
                                snapshot.itemIndex(), snapshot.userIndex(),
                                baseline, featureInfo);
    }

    /**
     * Legacy alias for {@link #get()}, kept so existing callers of this name keep compiling.
     *
     * @return the same result as {@link #get()}
     * @deprecated use {@link #get()} instead.
     */
    @Deprecated
    public FunkSVDModel m3get() {
        return get();
    }

    /**
     * Computes {@code (featureCount - feature - 1) * initialValue * initialValue}.
     *
     * <p>This is presumably the dot-product contribution of the features after {@code feature},
     * assuming each still holds {@code initialValue} on both the user and the item side. It is
     * passed to {@link FunkSVDUpdateRule#computeError} during training.
     *
     * @param feature index of the feature currently being trained
     * @return the trailing value for that feature
     */
    protected double computeTrailingValue(int feature) {
        return ((featureCount - feature) - 1) * initialValue * initialValue;
    }

    /**
     * Trains one feature by iterating over the ratings until the rule's loop controller stops.
     *
     * @param feature    index of the feature being trained
     * @param estimator  estimator whose current estimates are used in the error computation
     * @param userValues this feature's value for each user index; updated in place
     * @param itemValues this feature's value for each item index; updated in place
     * @param info       receives the RMSE of every training round
     */
    protected void trainFeature(int feature, TrainingEstimator estimator, double[] userValues,
                                double[] itemValues, FeatureInfo.Builder info) {
        final double trailingValue = computeTrailingValue(feature);
        final TrainingLoopController controller = rule.getTrainingLoopController();
        final FastCollection<IndexedPreference> ratings = snapshot.getRatings();

        // Start with the largest possible error so the controller allows the first round.
        double rmse = Double.MAX_VALUE;
        while (controller.keepTraining(rmse)) {
            rmse = doFeatureIteration(estimator, ratings, userValues, itemValues, trailingValue);
            info.addTrainingRound(rmse);
            logger.trace("iteration {} finished with RMSE {}", controller.getIterationCount(), rmse);
        }
    }

    /**
     * Performs one pass over the ratings, updating a single feature's user and item values.
     *
     * <p>For each rating, the rule computes an error from the trailing value, the estimator's
     * estimate, the rating value and the current feature values. Both the user update and the item
     * update are computed from the feature values as they were <em>before</em> this rating was
     * processed.
     *
     * @param estimator     source of the current estimate for each rating
     * @param ratings       ratings to iterate over
     * @param userValues    this feature's value for each user index; updated in place
     * @param itemValues    this feature's value for each item index; updated in place
     * @param trailingValue value passed through to {@link FunkSVDUpdateRule#computeError}
     * @return the root-mean-square of the per-rating errors in this pass, or {@code 0.0} when
     *         {@code ratings} is empty (avoids a NaN from dividing by zero)
     */
    protected double doFeatureIteration(TrainingEstimator estimator,
                                        FastCollection<IndexedPreference> ratings,
                                        double[] userValues, double[] itemValues,
                                        double trailingValue) {
        double squaredErrorSum = 0.0;
        int count = 0;
        for (IndexedPreference pref : CollectionUtils.fast(ratings)) {
            final int userIdx = pref.getUserIndex();
            final int itemIdx = pref.getItemIndex();
            // Snapshot the current values so both updates see the pre-update state.
            final double userValue = userValues[userIdx];
            final double itemValue = itemValues[itemIdx];

            final double error = rule.computeError(pref.getUserId(), pref.getItemId(),
                                                   trailingValue, estimator.get(pref),
                                                   pref.getValue(), userValue, itemValue);

            userValues[userIdx] += rule.userUpdate(error, userValue, itemValue);
            itemValues[itemIdx] += rule.itemUpdate(error, userValue, itemValue);

            squaredErrorSum += error * error;
            count++;
        }
        if (count == 0) {
            return 0.0;
        }
        return Math.sqrt(squaredErrorSum / count);
    }

    /**
     * Records summary statistics for a trained feature.
     *
     * <p>The statistics are the mean of the user values, the mean of the item values, and the
     * product of the two vectors' norms, which is stored as the feature's singular value.
     *
     * @param userValues this feature's value for each user
     * @param itemValues this feature's value for each item
     * @param info       builder that receives the statistics
     */
    protected void summarizeFeature(double[] userValues, double[] itemValues,
                                    FeatureInfo.Builder info) {
        MutableVec users = MutableVec.wrap(userValues);
        MutableVec items = MutableVec.wrap(itemValues);
        info.setUserAverage(users.mean())
            .setItemAverage(items.mean())
            .setSingularValue(users.norm() * items.norm());
    }
}
