package org.grouplens.lenskit.svd;

import com.google.inject.Inject;
import it.unimi.dsi.fastutil.doubles.DoubleArrays;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.grouplens.lenskit.RatingPredictor;
import org.grouplens.lenskit.RecommenderBuilder;
import org.grouplens.lenskit.RecommenderService;
import org.grouplens.lenskit.data.Index;
import org.grouplens.lenskit.data.IndexedRating;
import org.grouplens.lenskit.data.context.BuildContext;
import org.grouplens.lenskit.data.vector.MutableSparseVector;
import org.grouplens.lenskit.data.vector.SparseVector;
import org.grouplens.lenskit.svd.params.ClampingFunction;
import org.grouplens.lenskit.svd.params.FeatureCount;
import org.grouplens.lenskit.svd.params.FeatureTrainingThreshold;
import org.grouplens.lenskit.svd.params.GradientDescentRegularization;
import org.grouplens.lenskit.svd.params.IterationCount;
import org.grouplens.lenskit.svd.params.LearningRate;
import org.grouplens.lenskit.util.DoubleFunction;
import org.grouplens.lenskit.util.FastCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/svd/GradientDescentSVDRecommenderBuilder.class */
/**
 * Builds an SVD-based recommender by training latent features one at a time
 * with gradient descent over the ratings in a {@link BuildContext}.
 *
 * <p>Each rating's running prediction starts from the baseline produced by the
 * supplied {@link RatingPredictor}. Features are then trained sequentially; once
 * a feature has finished training, its contribution is folded into every
 * rating's cached prediction. Finally each feature's user and item vectors are
 * normalized, and the product of their norms is recorded as that feature's
 * singular value.
 *
 * <p>Training of a feature stops either after a fixed number of epochs (when the
 * configured iteration count is positive) or, otherwise, once at least
 * {@value #MIN_EPOCHS} epochs have run and the RMSE improvement drops below the
 * configured training threshold.
 */
public class GradientDescentSVDRecommenderBuilder implements RecommenderBuilder {
    private static final Logger logger = LoggerFactory.getLogger(GradientDescentSVDRecommenderBuilder.class);
    /** Initial value assigned to every user and item feature entry before training. */
    private static final double DEFAULT_FEATURE_VALUE = 0.1d;
    /** Minimum number of epochs per feature when training to a threshold. */
    private static final int MIN_EPOCHS = 50;
    /** Feature vectors whose norm is at or below this are left unnormalized. */
    private static final double MIN_FEAT_NORM = 1.0E-10d;
    final int featureCount;
    final double learningRate;
    final double trainingThreshold;
    final double trainingRegularization;
    final DoubleFunction clampingFunction;
    final int iterationCount;

    /**
     * Mutable holder for the intermediate and final state produced during a build.
     */
    public static final class Model {
        /** Baseline predictions per user index; each vector is looked up by item ID. */
        ArrayList<SparseVector> userBaselines;
        /** User feature values, indexed {@code [feature][userIndex]}. */
        double[][] userFeatures;
        /** Item feature values, indexed {@code [feature][itemIndex]}. */
        double[][] itemFeatures;
        /** Per-feature singular values (product of the user and item feature norms). */
        double[] singularValues;

        private Model() {
        }
    }

    /**
     * A single training rating plus the running prediction built from the
     * baseline and all features trained so far.
     */
    public final class SVDRating {
        public final long uid;
        public final long iid;
        public final int user;
        public final int item;
        public final double value;
        /** Clamped prediction from the baseline plus every previously trained feature. */
        public double cachedValue;

        /**
         * Wraps an indexed rating.
         *
         * @param model currently unused; retained for signature compatibility
         * @param indexedRating the rating whose IDs, indices and value are copied
         */
        public SVDRating(Model model, IndexedRating indexedRating) {
            this.uid = indexedRating.getUserId();
            this.iid = indexedRating.getItemId();
            this.user = indexedRating.getUserIndex();
            this.item = indexedRating.getItemIndex();
            this.value = indexedRating.getRating();
        }

        /**
         * Performs one regularized gradient-descent update of the current
         * feature for this rating's user and item.
         *
         * @param userFeatureValues current feature's values, indexed by user index (updated in place)
         * @param itemFeatureValues current feature's values, indexed by item index (updated in place)
         * @param trailingValue estimated contribution of the features not yet trained
         * @return the squared prediction error computed before the update
         */
        public double trainStep(double[] userFeatureValues, double[] itemFeatureValues, double trailingValue) {
            final double ufv = userFeatureValues[user];
            final double ifv = itemFeatureValues[item];
            // Clamp the prediction from trained features, then again after adding
            // the estimate for the remaining (untrained) features.
            double prediction = clampingFunction.apply(cachedValue + ufv * ifv);
            prediction = clampingFunction.apply(prediction + trailingValue);
            final double err = value - prediction;
            // Both updates use the pre-update feature values.
            userFeatureValues[user] = ufv + (err * ifv - trainingRegularization * ufv) * learningRate;
            itemFeatureValues[item] = ifv + (err * ufv - trainingRegularization * ifv) * learningRate;
            return err * err;
        }

        /**
         * Folds the (finished) current feature into {@link #cachedValue}, clamping the result.
         *
         * @param userFeatureValues finished feature values indexed by user index
         * @param itemFeatureValues finished feature values indexed by item index
         */
        public void updateCachedValue(double[] userFeatureValues, double[] itemFeatureValues) {
            cachedValue = clampingFunction.apply(cachedValue + userFeatureValues[user] * itemFeatureValues[item]);
        }
    }

    /**
     * Creates a builder with the given training parameters.
     *
     * @param featureCount number of latent features to train
     * @param learningRate gradient-descent step size
     * @param trainingThreshold minimum RMSE improvement per epoch to keep training
     *        (used only when {@code iterationCount} is not positive)
     * @param iterationCount fixed number of epochs per feature; if not positive,
     *        threshold-based stopping is used instead
     * @param regularization regularization weight applied in each update
     * @param clampingFunction function used to clamp intermediate predictions
     */
    @Inject
    public GradientDescentSVDRecommenderBuilder(@FeatureCount int featureCount, @LearningRate double learningRate,
                                                @FeatureTrainingThreshold double trainingThreshold,
                                                @IterationCount int iterationCount,
                                                @GradientDescentRegularization double regularization,
                                                @ClampingFunction DoubleFunction clampingFunction) {
        this.featureCount = featureCount;
        this.learningRate = learningRate;
        this.trainingThreshold = trainingThreshold;
        this.trainingRegularization = regularization;
        this.clampingFunction = clampingFunction;
        this.iterationCount = iterationCount;
    }

    /**
     * Trains all features over the context's ratings and returns the resulting service.
     *
     * @param buildContext source of ratings and user/item indexes
     * @param ratingPredictor baseline predictor used to seed each rating's prediction
     * @return an {@link SVDRecommenderService} built from the trained item features
     *         and singular values
     */
    public RecommenderService build(BuildContext buildContext, RatingPredictor ratingPredictor) {
        logger.debug("Setting up to build SVD recommender with {} features", featureCount);
        logger.debug("Learning rate is {}", learningRate);
        logger.debug("Regularization term is {}", trainingRegularization);
        if (iterationCount > 0) {
            logger.debug("Training each epoch for {} iterations", iterationCount);
        } else {
            logger.debug("Error epsilon is {}", trainingThreshold);
        }

        Model model = new Model();
        List<SVDRating> ratings = indexData(buildContext, ratingPredictor, model);
        // Seed each rating's running prediction with its baseline.
        for (SVDRating rating : ratings) {
            rating.cachedValue = model.userBaselines.get(rating.user).get(rating.iid);
        }
        logger.debug("Building SVD with {} features for {} ratings", featureCount, ratings.size());

        int userCount = buildContext.getUserIds().size();
        int itemCount = buildContext.getItemIds().size();
        model.userFeatures = new double[featureCount][userCount];
        model.itemFeatures = new double[featureCount][itemCount];
        for (int feature = 0; feature < featureCount; feature++) {
            trainFeature(model, ratings, feature);
        }

        logger.debug("Extracting singular values");
        model.singularValues = new double[featureCount];
        for (int feature = 0; feature < featureCount; feature++) {
            double userNorm = normalize(model.userFeatures[feature]);
            double itemNorm = normalize(model.itemFeatures[feature]);
            model.singularValues[feature] = userNorm * itemNorm;
        }
        return new SVDRecommenderService(featureCount, buildContext.itemIndex(), ratingPredictor,
                model.itemFeatures, model.singularValues, clampingFunction);
    }

    /**
     * Scales {@code values} in place to unit Euclidean length unless its norm is at
     * most {@link #MIN_FEAT_NORM}, and returns the norm computed before scaling.
     */
    private static double normalize(double[] values) {
        double sumOfSquares = 0.0d;
        for (double v : values) {
            sumOfSquares += v * v;
        }
        double norm = Math.sqrt(sumOfSquares);
        if (norm > MIN_FEAT_NORM) {
            for (int i = 0; i < values.length; i++) {
                values[i] /= norm;
            }
        }
        return norm;
    }

    /**
     * Trains one feature to completion, then folds it into every rating's cached prediction.
     */
    private void trainFeature(Model model, List<SVDRating> ratings, int feature) {
        logger.trace("Training feature {}", feature);
        double[] userFeatureValues = model.userFeatures[feature];
        double[] itemFeatureValues = model.itemFeatures[feature];
        DoubleArrays.fill(userFeatureValues, DEFAULT_FEATURE_VALUE);
        DoubleArrays.fill(itemFeatureValues, DEFAULT_FEATURE_VALUE);
        // Estimate for the not-yet-trained features: each contributes DEFAULT * DEFAULT.
        double trailingValue = (featureCount - feature - 1) * DEFAULT_FEATURE_VALUE * DEFAULT_FEATURE_VALUE;

        double rmse = Double.MAX_VALUE;
        double oldRmse = 0.0d;
        int epoch = 0;
        while (!isFeatureDone(epoch, rmse, oldRmse)) {
            logger.trace("Running epoch {} of feature {}", epoch, feature);
            oldRmse = rmse;
            rmse = trainFeatureIteration(ratings, userFeatureValues, itemFeatureValues, trailingValue);
            logger.trace("Epoch {} had RMSE of {}", epoch, rmse);
            epoch++;
        }
        logger.debug("Finished feature {} in {} epochs", feature, epoch);
        logger.debug("Final RMSE for feature {} was {}", feature, rmse);

        for (SVDRating rating : ratings) {
            rating.updateCachedValue(userFeatureValues, itemFeatureValues);
        }
    }

    /**
     * Decides whether training of the current feature should stop.
     *
     * @param epoch number of epochs already run
     * @param rmse RMSE of the most recent epoch
     * @param oldRmse RMSE of the epoch before that
     */
    private boolean isFeatureDone(int epoch, double rmse, double oldRmse) {
        if (iterationCount > 0) {
            return epoch >= iterationCount;
        }
        // Deliberately !(a < b) rather than a >= b: a NaN RMSE (divergence)
        // must stop training instead of looping forever.
        return epoch >= MIN_EPOCHS && !(rmse < oldRmse - trainingThreshold);
    }

    /**
     * Runs one epoch of updates over all ratings.
     *
     * @return the RMSE observed during the epoch, or 0 if there are no ratings
     */
    private double trainFeatureIteration(List<SVDRating> ratings, double[] userFeatureValues,
                                         double[] itemFeatureValues, double trailingValue) {
        if (ratings.isEmpty()) {
            // Avoid 0/0 = NaN; with no ratings there is no error to measure.
            return 0.0d;
        }
        double sumSquaredError = 0.0d;
        for (SVDRating rating : ratings) {
            sumSquaredError += rating.trainStep(userFeatureValues, itemFeatureValues, trailingValue);
        }
        return Math.sqrt(sumSquaredError / ratings.size());
    }

    /**
     * Wraps every rating in the context as an {@link SVDRating} and computes the
     * per-user baseline predictions into {@code model.userBaselines}.
     *
     * @return the wrapped ratings, in the context's iteration order
     */
    private List<SVDRating> indexData(BuildContext buildContext, RatingPredictor ratingPredictor, Model model) {
        // Per-user rating maps (item ID -> rating value), indexed by user index.
        List<Long2DoubleMap> userRatings = new ArrayList<>(buildContext.getUserIds().size());
        FastCollection ratings = buildContext.getRatings();
        int ratingCount = ratings.size();
        logger.debug("pre-processing {} ratings", ratingCount);

        ArrayList<SVDRating> svdRatings = new ArrayList<>(ratingCount);
        Iterator<?> it = ratings.iterator();
        while (it.hasNext()) {
            SVDRating rating = new SVDRating(model, (IndexedRating) it.next());
            svdRatings.add(rating);
            // Grow the per-user list so this user's index is addressable.
            while (rating.user >= userRatings.size()) {
                userRatings.add(new Long2DoubleOpenHashMap());
            }
            userRatings.get(rating.user).put(rating.iid, rating.value);
        }

        model.userBaselines = new ArrayList<>(userRatings.size());
        Index userIndex = buildContext.userIndex();
        for (int i = 0; i < userRatings.size(); i++) {
            MutableSparseVector userVector = new MutableSparseVector(userRatings.get(i));
            model.userBaselines.add(ratingPredictor.predict(userIndex.getId(i), userVector, userVector.keySet()));
        }
        svdRatings.trimToSize();
        return svdRatings;
    }
}
