package org.grouplens.lenskit.mf.funksvd;

import java.util.ArrayList;
import java.util.Collection;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import javax.inject.Provider;
import mikera.matrixx.Matrix;
import mikera.matrixx.impl.ImmutableMatrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import org.apache.commons.lang3.time.StopWatch;
import org.grouplens.lenskit.collections.CollectionUtils;
import org.grouplens.lenskit.collections.FastCollection;
import org.grouplens.lenskit.core.Transient;
import org.grouplens.lenskit.data.pref.IndexedPreference;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.grouplens.lenskit.mf.funksvd.FeatureInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/grouplens/lenskit/mf/funksvd/FunkSVDModelBuilder.class */
/**
 * Builds a {@link FunkSVDModel} from a preference snapshot by training one
 * feature at a time.
 *
 * <p>For every feature a pair of scratch vectors (one entry per user, one per
 * item) is filled with the initial value, trained against the snapshot's
 * ratings using the configured {@link FunkSVDUpdateRule}, and then copied into
 * the corresponding column of the user and item feature matrices.
 */
public class FunkSVDModelBuilder implements Provider<FunkSVDModel> {
    private static final Logger logger = LoggerFactory.getLogger(FunkSVDModelBuilder.class);

    /** Largest allowed difference between a scratch vector's sum and its stored column's sum. */
    private static final double COLUMN_SUM_TOLERANCE = 1.0E-4d;

    /** Number of features (matrix columns) to train. */
    protected final int featureCount;
    /** Rating data the model is trained on. */
    protected final PreferenceSnapshot snapshot;
    /** Value every user/item feature entry starts from before training. */
    protected final double initialValue;
    /** Supplies the estimator, updater, loop controller and learning parameters. */
    protected final FunkSVDUpdateRule rule;

    /**
     * Creates a model builder.
     *
     * @param preferenceSnapshot the rating snapshot to train on
     * @param funkSVDUpdateRule  the update rule driving training
     * @param featureCount       the number of features to train
     * @param initialValue       the starting value for each feature entry
     */
    @Inject
    public FunkSVDModelBuilder(@Nonnull @Transient PreferenceSnapshot preferenceSnapshot, @Nonnull @Transient FunkSVDUpdateRule funkSVDUpdateRule, @FeatureCount int featureCount, @InitialFeatureValue double initialValue) {
        this.featureCount = featureCount;
        this.initialValue = initialValue;
        this.snapshot = preferenceSnapshot;
        this.rule = funkSVDUpdateRule;
    }

    /**
     * Trains all features and assembles the model. This is the
     * {@link Provider#get()} implementation for this class.
     *
     * @return a model wrapping the trained user and item feature matrices,
     *         the snapshot's user and item indexes, and per-feature summaries
     */
    public FunkSVDModel get() {
        final int userCount = snapshot.getUserIds().size();
        final Matrix userFeatures = Matrix.create(userCount, featureCount);
        final int itemCount = snapshot.getItemIds().size();
        final Matrix itemFeatures = Matrix.create(itemCount, featureCount);

        logger.debug("Learning rate is {}", rule.getLearningRate());
        logger.debug("Regularization term is {}", rule.getTrainingRegularization());
        logger.info("Building SVD with {} features for {} ratings", featureCount, snapshot.getRatings().size());

        final TrainingEstimator estimates = rule.makeEstimator(snapshot);
        final ArrayList<FeatureInfo> featureInfo = new ArrayList<FeatureInfo>(featureCount);

        // Scratch vectors are allocated once and refilled for every feature;
        // the trained values are copied into the matrices column by column.
        final Vector userVector = Vector.createLength(userCount);
        final Vector itemVector = Vector.createLength(itemCount);

        for (int feature = 0; feature < featureCount; feature++) {
            logger.debug("Training feature {}", feature);
            final StopWatch timer = new StopWatch();
            timer.start();

            userVector.fill(initialValue);
            itemVector.fill(initialValue);

            final FeatureInfo.Builder infoBuilder = new FeatureInfo.Builder(feature);
            trainFeature(feature, estimates, userVector, itemVector, infoBuilder);
            summarizeFeature(userVector, itemVector, infoBuilder);
            // m0build is the name FeatureInfo.Builder exposes in this code base.
            featureInfo.add(infoBuilder.m0build());

            // Let the estimator absorb the finished feature before the next one is trained.
            estimates.update(userVector, itemVector);

            userFeatures.setColumn(feature, userVector);
            // Sanity check (only with -ea): the stored column must sum to the scratch vector's sum.
            assert Math.abs(userFeatures.getColumnView(feature).elementSum() - userVector.elementSum()) < COLUMN_SUM_TOLERANCE
                    : "user column sum matches";
            itemFeatures.setColumn(feature, itemVector);
            assert Math.abs(itemFeatures.getColumnView(feature).elementSum() - itemVector.elementSum()) < COLUMN_SUM_TOLERANCE
                    : "item column sum matches";

            timer.stop();
            logger.info("Finished feature {} in {}", feature, timer);
        }

        return new FunkSVDModel(ImmutableMatrix.wrap(userFeatures), ImmutableMatrix.wrap(itemFeatures), snapshot.userIndex(), snapshot.itemIndex(), featureInfo);
    }

    /**
     * Decompiler-generated alias of {@link #get()}, retained so existing
     * callers keep working.
     *
     * @return the result of {@link #get()}
     * @deprecated use {@link #get()}
     */
    @Deprecated
    public FunkSVDModel m3get() {
        return get();
    }

    /**
     * Trains a single feature, repeating iterations for as long as the rule's
     * training loop controller asks for more.
     *
     * @param feature           zero-based index of the feature being trained
     * @param estimates         current per-rating estimates
     * @param userFeatureVector user values for this feature; updated in place
     * @param itemFeatureVector item values for this feature; updated in place
     * @param infoBuilder       receives the RMSE of every training round
     */
    protected void trainFeature(int feature, TrainingEstimator estimates, Vector userFeatureVector, Vector itemFeatureVector, FeatureInfo.Builder infoBuilder) {
        double rmse = Double.MAX_VALUE;
        // initialValue squared, times the number of features after this one.
        // NOTE(review): presumably the contribution of the still-untrained features - confirm against FunkSVDUpdater.
        final double trail = initialValue * initialValue * ((featureCount - feature) - 1);
        final TrainingLoopController controller = rule.getTrainingLoopController();
        final FastCollection<IndexedPreference> ratings = snapshot.getRatings();
        while (controller.keepTraining(rmse)) {
            rmse = doFeatureIteration(estimates, ratings, userFeatureVector, itemFeatureVector, trail);
            infoBuilder.addTrainingRound(rmse);
            logger.trace("iteration {} finished with RMSE {}", controller.getIterationCount(), rmse);
        }
    }

    /**
     * Performs one pass over all ratings for the feature being trained,
     * applying the updater's user and item adjustments after each rating.
     *
     * @param estimates         current per-rating estimates
     * @param ratings           the ratings to iterate over
     * @param userFeatureVector user values for this feature; updated in place
     * @param itemFeatureVector item values for this feature; updated in place
     * @param trail             the trailing value passed through to the updater
     * @return the RMSE reported by the updater after the pass
     */
    protected double doFeatureIteration(TrainingEstimator estimates, Collection<IndexedPreference> ratings, Vector userFeatureVector, Vector itemFeatureVector, double trail) {
        // A fresh updater is created for every pass.
        final FunkSVDUpdater updater = rule.createUpdater();
        for (IndexedPreference rating : CollectionUtils.fast(ratings)) {
            final int userIdx = rating.getUserIndex();
            final int itemIdx = rating.getItemIndex();
            // NOTE(review): the leading 0 is passed unchanged from the original; its meaning is defined by FunkSVDUpdater.prepare.
            updater.prepare(0, rating.getValue(), estimates.get(rating), userFeatureVector.get(userIdx), itemFeatureVector.get(itemIdx), trail);
            userFeatureVector.addAt(userIdx, updater.getUserFeatureUpdate());
            itemFeatureVector.addAt(itemIdx, updater.getItemFeatureUpdate());
        }
        return updater.getRMSE();
    }

    /**
     * Records summary statistics of a trained feature: the mean of the user
     * values, the mean of the item values, and the product of the two vectors'
     * magnitudes as the singular value.
     *
     * @param userFeatureVector trained user values for the feature
     * @param itemFeatureVector trained item values for the feature
     * @param infoBuilder       builder that receives the statistics
     */
    protected void summarizeFeature(AVector userFeatureVector, AVector itemFeatureVector, FeatureInfo.Builder infoBuilder) {
        final double userAverage = userFeatureVector.elementSum() / userFeatureVector.length();
        final double itemAverage = itemFeatureVector.elementSum() / itemFeatureVector.length();
        final double singularValue = userFeatureVector.magnitude() * itemFeatureVector.magnitude();
        infoBuilder.setUserAverage(userAverage).setItemAverage(itemAverage).setSingularValue(singularValue);
    }
}
