package org.grouplens.lenskit.mf.funksvd;

import javax.annotation.concurrent.NotThreadSafe;
import org.grouplens.lenskit.transform.clamp.ClampingFunction;
import org.grouplens.lenskit.util.iterative.StoppingCondition;

/**
 * Trains a single latent feature of a FunkSVD model with gradient-descent style updates.
 *
 * <p>For each training rating, call {@link #compute} with the current feature values, then read
 * {@link #getUserUpdate()} and {@link #getItemUpdate()} to obtain the computed updates for the
 * user and item feature values. After a pass over the ratings, call {@link #nextEpoch()} to
 * compute that epoch's RMSE and consult the configured {@link StoppingCondition}.
 *
 * <p>Instances hold mutable per-rating and per-epoch state and are not thread-safe.
 */
@NotThreadSafe
public final class FunkSVDFeatureTrainer {
    /** Multiplier applied to each computed update. */
    private final double learningRate;
    /** Weight of the regularization term in each update. */
    private final double trainingRegularization;
    /** Applied to the prediction twice in {@link #compute}. */
    private final ClampingFunction clampingFunction;
    /** Decides, from the epoch number and the RMSE improvement, when training stops. */
    private final StoppingCondition stopper;

    /** Epoch counter; incremented each time {@link #nextEpoch()} decides to continue. */
    private int epoch = 0;
    /** Number of {@link #compute} calls since the RMSE was last computed. */
    private int ratingCount = 0;
    /** Sum of squared prediction errors since the RMSE was last computed. */
    private double ssq = 0.0;
    /** Prediction error ({@code rating - prediction}) from the most recent compute call. */
    private double err = 0.0;
    /** RMSE of the epoch before the most recently measured one (0 until two epochs measured). */
    private double oldRmse = 0.0;
    /** RMSE of the most recently measured epoch; {@code Double.MAX_VALUE} until the first one. */
    private double rmse = Double.MAX_VALUE;
    /** User feature value passed to the most recent compute call. */
    private double ufv = 0.0;
    /** Item feature value passed to the most recent compute call. */
    private double ifv = 0.0;

    /**
     * Creates a trainer using the hyperparameters from a training configuration.
     *
     * @param config supplies the learning rate, training regularization, clamping function and
     *     stopping condition
     */
    public FunkSVDFeatureTrainer(FunkSVDTrainingConfig config) {
        this.learningRate = config.getLearningRate();
        this.trainingRegularization = config.getTrainingRegularization();
        this.clampingFunction = config.getClampingFunction();
        this.stopper = config.getStoppingCondition();
    }

    /**
     * Computes the prediction error for one rating and records the values needed by
     * {@link #getUserUpdate()} and {@link #getItemUpdate()}.
     *
     * <p>The prediction is {@code estimate + userFeatureValue * itemFeatureValue}, clamped; the
     * trailing value is then added and the sum clamped again. The squared error is accumulated
     * for the current epoch's RMSE.
     *
     * @param user the user ID, passed to the clamping function
     * @param item the item ID, passed to the clamping function
     * @param trailingValue value added to the clamped prediction before the second clamp
     *     (presumably the contribution of the not-yet-trained features — confirm with caller)
     * @param estimate the estimate to which this feature's contribution is added
     * @param rating the observed rating the prediction is compared against
     * @param userFeatureValue the user's current value for this feature
     * @param itemFeatureValue the item's current value for this feature
     */
    public void compute(long user, long item, double trailingValue, double estimate,
                        double rating, double userFeatureValue, double itemFeatureValue) {
        this.ufv = userFeatureValue;
        this.ifv = itemFeatureValue;

        // Clamp once after adding this feature's contribution, and again after adding the
        // trailing value.
        double prediction =
                clampingFunction.apply(user, item, estimate + userFeatureValue * itemFeatureValue);
        prediction = clampingFunction.apply(user, item, prediction + trailingValue);

        this.err = rating - prediction;
        this.ssq += err * err;
        this.ratingCount++;
    }

    /**
     * Returns the user-feature update for the most recent {@link #compute} call:
     * {@code (err * itemFeatureValue - trainingRegularization * userFeatureValue) * learningRate}.
     *
     * @return the computed update for the user feature value
     */
    public double getUserUpdate() {
        return (err * ifv - trainingRegularization * ufv) * learningRate;
    }

    /**
     * Returns the item-feature update for the most recent {@link #compute} call:
     * {@code (err * userFeatureValue - trainingRegularization * itemFeatureValue) * learningRate}.
     *
     * @return the computed update for the item feature value
     */
    public double getItemUpdate() {
        return (err * ufv - trainingRegularization * ifv) * learningRate;
    }

    /**
     * Returns the current epoch counter.
     *
     * @return the number of times {@link #nextEpoch()} has returned {@code true}
     */
    public int getEpoch() {
        return epoch;
    }

    /**
     * Returns the RMSE of the most recently measured epoch.
     *
     * @return the last computed RMSE, or {@code Double.MAX_VALUE} if none has been computed yet
     */
    public double getLastRMSE() {
        return rmse;
    }

    /**
     * Finishes the current epoch and decides whether training should continue.
     *
     * <p>If any ratings were computed since the last measurement, the RMSE over them is computed
     * and the accumulators are reset. The stopping condition is then consulted with the epoch
     * counter and the RMSE improvement ({@code previous RMSE - current RMSE}).
     *
     * @return {@code true} if training should continue (the epoch counter is incremented), or
     *     {@code false} if the stopping condition reports that training is finished
     */
    public boolean nextEpoch() {
        if (ratingCount > 0) {
            oldRmse = rmse;
            rmse = Math.sqrt(ssq / ratingCount);
            // Reset both accumulators together: previously ratingCount was only reset on the
            // "continue" path, so a repeated call after finishing recomputed rmse as 0.
            ssq = 0.0;
            ratingCount = 0;
        }
        // NOTE(review): if called before any rating has been computed, oldRmse is 0 and rmse is
        // Double.MAX_VALUE, so the reported improvement is -Double.MAX_VALUE.
        if (stopper.isFinished(epoch, oldRmse - rmse)) {
            return false;
        }
        epoch++;
        return true;
    }
}
