package org.grouplens.lenskit.mf.funksvd;

import java.io.Serializable;
import javax.inject.Inject;
import org.grouplens.lenskit.baseline.BaselinePredictor;
import org.grouplens.lenskit.core.Shareable;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.iterative.LearningRate;
import org.grouplens.lenskit.iterative.RegularizationTerm;
import org.grouplens.lenskit.iterative.StoppingCondition;
import org.grouplens.lenskit.iterative.TrainingLoopController;
import org.grouplens.lenskit.transform.clamp.ClampingFunction;

/**
 * Update rule used while training a FunkSVD model by gradient descent.
 *
 * <p>Holds the training hyperparameters (learning rate, regularization weight, and whether a
 * trailing value is added to predictions) together with the baseline predictor, clamping
 * function and stopping condition. It exposes the per-rating arithmetic used in a training step:
 * the error of one prediction, and the regularized, learning-rate-scaled deltas for a user or
 * item feature value.
 */
@Shareable
public final class FunkSVDUpdateRule implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Step size multiplied into every delta returned by the update methods. */
    private final double learningRate;
    /** Weight of the penalty on the feature value being updated. */
    private final double trainingRegularization;
    /** When {@code true}, {@link #computeError} adds the trailing value to the prediction. */
    private final boolean useTrailingEstimate;
    /** Baseline predictor handed to estimators built by {@link #makeEstimator}. */
    private final BaselinePredictor baseline;
    /** Clamp applied to predictions in {@link #computeError} and passed to estimators. */
    private final ClampingFunction clampingFunction;
    /** Source of training-loop controllers. */
    private final StoppingCondition stoppingCondition;

    /**
     * Creates an update rule. Arguments are stored as given; no validation is performed.
     *
     * @param learningRate        step size applied to user and item deltas
     * @param regularization      regularization weight applied to the updated feature value
     * @param useTrailingEstimate whether error computation adds the trailing value
     * @param baseline            baseline predictor used by {@link #makeEstimator}
     * @param clampingFunction    function used to clamp predicted values
     * @param stoppingCondition   condition that creates training-loop controllers
     */
    @Inject
    public FunkSVDUpdateRule(@LearningRate double learningRate,
                             @RegularizationTerm double regularization,
                             @UseTrailingEstimate boolean useTrailingEstimate,
                             BaselinePredictor baseline,
                             ClampingFunction clampingFunction,
                             StoppingCondition stoppingCondition) {
        this.useTrailingEstimate = useTrailingEstimate;
        this.learningRate = learningRate;
        this.trainingRegularization = regularization;
        this.stoppingCondition = stoppingCondition;
        this.clampingFunction = clampingFunction;
        this.baseline = baseline;
    }

    /**
     * Builds a training estimator over a snapshot, using this rule's baseline and clamp.
     *
     * @param preferenceSnapshot the snapshot the estimator is built over
     * @return a new {@link TrainingEstimator}
     */
    public TrainingEstimator makeEstimator(PreferenceSnapshot preferenceSnapshot) {
        final TrainingEstimator estimator =
                new TrainingEstimator(preferenceSnapshot, baseline, clampingFunction);
        return estimator;
    }

    /** @return the learning rate (step size) */
    public double getLearningRate() {
        return learningRate;
    }

    /** @return the regularization weight */
    public double getTrainingRegularization() {
        return trainingRegularization;
    }

    /** @return the clamping function applied to predictions */
    public ClampingFunction getClampingFunction() {
        return clampingFunction;
    }

    /** @return the stopping condition */
    public StoppingCondition getStoppingCondition() {
        return stoppingCondition;
    }

    /**
     * Starts a fresh training loop.
     *
     * @return a new controller obtained from {@link StoppingCondition#newLoop()}
     */
    public TrainingLoopController getTrainingLoopController() {
        final TrainingLoopController controller = stoppingCondition.newLoop();
        return controller;
    }

    /**
     * Computes the error of a single prediction.
     *
     * <p>The prediction is {@code estimate + userFeatureValue * itemFeatureValue}, clamped. If
     * trailing estimates are enabled, {@code trailingValue} is then added and the sum is clamped
     * again.
     *
     * @param user             user identifier passed to the clamping function
     * @param item             item identifier passed to the clamping function
     * @param trailingValue    value added to the clamped prediction when trailing estimates are on
     * @param estimate         current estimate the feature product is added to
     * @param rating           observed value the prediction is compared against
     * @param userFeatureValue user's value for the feature being trained
     * @param itemFeatureValue item's value for the feature being trained
     * @return {@code rating} minus the final (clamped) prediction
     */
    public double computeError(long user, long item, double trailingValue, double estimate,
                               double rating, double userFeatureValue, double itemFeatureValue) {
        final double clamped =
                clampingFunction.apply(user, item, estimate + userFeatureValue * itemFeatureValue);
        // Optionally fold in the trailing value, re-clamping the combined result.
        final double prediction = useTrailingEstimate
                ? clampingFunction.apply(user, item, clamped + trailingValue)
                : clamped;
        return rating - prediction;
    }

    /**
     * Computes the delta to apply to a user's feature value.
     *
     * @param error            prediction error, as returned by {@link #computeError}
     * @param userFeatureValue current user feature value (the one being updated)
     * @param itemFeatureValue current item feature value
     * @return {@code (error * itemFeatureValue - regularization * userFeatureValue) * learningRate}
     */
    public double userUpdate(double error, double userFeatureValue, double itemFeatureValue) {
        final double gradient = error * itemFeatureValue - trainingRegularization * userFeatureValue;
        return gradient * learningRate;
    }

    /**
     * Computes the delta to apply to an item's feature value.
     *
     * @param error            prediction error, as returned by {@link #computeError}
     * @param userFeatureValue current user feature value
     * @param itemFeatureValue current item feature value (the one being updated)
     * @return {@code (error * userFeatureValue - regularization * itemFeatureValue) * learningRate}
     */
    public double itemUpdate(double error, double userFeatureValue, double itemFeatureValue) {
        final double gradient = error * userFeatureValue - trainingRegularization * itemFeatureValue;
        return gradient * learningRate;
    }
}
