package org.grouplens.lenskit.eval.crossfold;

import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.util.Collection;
import java.util.List;
import org.grouplens.lenskit.RatingPredictor;
import org.grouplens.lenskit.RecommenderNotAvailableException;
import org.grouplens.lenskit.data.UserRatingProfile;
import org.grouplens.lenskit.data.dao.RatingDataAccessObject;
import org.grouplens.lenskit.data.vector.SparseVector;
import org.grouplens.lenskit.eval.AlgorithmInstance;
import org.grouplens.lenskit.eval.CrossfoldOptions;
import org.grouplens.lenskit.eval.TaskTimer;
import org.grouplens.lenskit.tablewriter.CSVWriterBuilder;
import org.grouplens.lenskit.tablewriter.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/crossfold/CrossfoldEvaluator.class */
/**
 * Runs a set of algorithms over every fold of a crossfold split and records
 * per-fold, per-algorithm accuracy and timing figures in a CSV table.
 *
 * <p>For each fold and algorithm one row is written to the summary table with
 * the columns RunNumber, TrainSize, TestSize, Algorithm, MAE, RMSE, NTried,
 * NGood, Coverage, BuildTime and PredTime. When a prediction writer has been
 * configured, one additional row per probe rating (Fold, Algorithm, User,
 * Item, Rating, Prediction) is written to it.
 */
public class CrossfoldEvaluator implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(CrossfoldEvaluator.class);
    private final CrossfoldManager manager;
    private final int numFolds;
    private final UserRatingProfileSplitter profileSplitter;
    private final List<AlgorithmInstance> algorithms;
    // Summary table: one row per (fold, algorithm).
    private TableWriter writer;
    // Optional per-rating prediction table; null when disabled or after a write failure.
    private TableWriter predWriter;
    // Column indices of the summary table, assigned by makeWriter().
    private int colRunNumber;
    private int colTestSize;
    private int colTrainSize;
    private int colAlgo;
    private int colMAE;
    private int colRMSE;
    private int colNTry;
    private int colNGood;
    private int colCoverage;
    private int colBuildTime;
    private int colPredTime;

    /**
     * Creates an evaluator that writes only the summary table.
     *
     * @param dao        rating data handed to the {@link CrossfoldManager}
     * @param algorithms algorithms to benchmark on every fold
     * @param numFolds   number of folds to run
     * @param splitter   splits each test user's profile into query and probe vectors
     * @param output     destination of the summary CSV table
     * @throws IOException if the summary table writer cannot be created
     */
    public CrossfoldEvaluator(RatingDataAccessObject dao, List<AlgorithmInstance> algorithms, int numFolds, UserRatingProfileSplitter splitter, Writer output) throws IOException {
        this.numFolds = numFolds;
        this.algorithms = algorithms;
        this.writer = makeWriter(output);
        this.profileSplitter = splitter;
        this.manager = new CrossfoldManager(numFolds, dao);
    }

    /**
     * Creates an evaluator configured from {@link CrossfoldOptions}.
     *
     * <p>The splitter is timestamp-based when {@code options.timeSplit()} is true and
     * random otherwise; both use the configured holdout fraction. If
     * {@code options.predictionFile()} is non-empty, a per-rating prediction table is
     * additionally written to that file.
     *
     * @param dao        rating data handed to the {@link CrossfoldManager}
     * @param options    fold count, split strategy, holdout fraction and prediction file
     * @param algorithms algorithms to benchmark on every fold
     * @param output     destination of the summary CSV table
     * @throws IOException if a table writer or the prediction file cannot be created
     */
    public CrossfoldEvaluator(RatingDataAccessObject dao, CrossfoldOptions options, List<AlgorithmInstance> algorithms, Writer output) throws IOException {
        this(dao,
             algorithms,
             options.getNumFolds(),
             options.timeSplit()
                     ? new TimestampUserRatingProfileSplitter(options.getHoldoutFraction())
                     : new RandomUserRatingProfileSplitter(options.getHoldoutFraction()),
             output);
        if (options.predictionFile().isEmpty()) {
            return;
        }
        logger.info("Writing predictions to {}", options.predictionFile());
        CSVWriterBuilder builder = new CSVWriterBuilder();
        builder.addColumn("Fold");
        builder.addColumn("Algorithm");
        builder.addColumn("User");
        builder.addColumn("Item");
        builder.addColumn("Rating");
        builder.addColumn("Prediction");
        Writer predictionOutput = new FileWriter(options.predictionFile());
        this.predWriter = builder.makeWriter(predictionOutput);
    }

    /**
     * Benchmarks every algorithm on every fold, then finishes the output tables.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     *                          the summary table or finishing either table
     */
    @Override
    public void run() {
        for (int fold = 0; fold < numFolds; fold++) {
            int runNumber = fold + 1;
            RatingDataAccessObject trainingSet = manager.trainingSet(fold);
            try {
                Collection<UserRatingProfile> testSet = manager.testSet(fold);
                int trainUsers = trainingSet.getUserCount();
                logger.info(String.format("Running benchmark %d with %d training and %d test users", runNumber, trainUsers, testSet.size()));
                for (AlgorithmInstance algorithm : algorithms) {
                    writer.setValue(colRunNumber, runNumber);
                    writer.setValue(colTrainSize, trainUsers);
                    writer.setValue(colTestSize, testSet.size());
                    writer.setValue(colAlgo, algorithm.getName());
                    benchmarkAlgorithm(runNumber, algorithm, trainingSet, testSet);
                    writer.finishRow();
                }
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        try {
            writer.finish();
            if (predWriter != null) {
                predWriter.finish();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the summary table writer and records the index of each column.
     *
     * @param output destination of the summary CSV table
     * @return the table writer for the summary table
     * @throws IOException if the writer cannot be created
     */
    private TableWriter makeWriter(Writer output) throws IOException {
        CSVWriterBuilder builder = new CSVWriterBuilder();
        this.colRunNumber = builder.addColumn("RunNumber");
        this.colTrainSize = builder.addColumn("TrainSize");
        this.colTestSize = builder.addColumn("TestSize");
        this.colAlgo = builder.addColumn("Algorithm");
        this.colMAE = builder.addColumn("MAE");
        this.colRMSE = builder.addColumn("RMSE");
        this.colNTry = builder.addColumn("NTried");
        this.colNGood = builder.addColumn("NGood");
        this.colCoverage = builder.addColumn("Coverage");
        this.colBuildTime = builder.addColumn("BuildTime");
        this.colPredTime = builder.addColumn("PredTime");
        return builder.makeWriter(output);
    }

    /**
     * Builds one algorithm on the training data, predicts every probe rating of every
     * test user, and stores the resulting metrics in the current summary row.
     *
     * <p>A prediction of NaN counts as "tried but not good": it contributes to NTried
     * and to the coverage denominator but not to the error sums. MAE, RMSE and
     * coverage are NaN when their denominator is zero.
     *
     * @param fold      1-based fold number, used only for the prediction table
     * @param algorithm algorithm to build and test
     * @param trainData training data for this fold
     * @param testUsers test user profiles for this fold
     * @throws RuntimeException wrapping {@link RecommenderNotAvailableException} if the
     *                          recommender cannot be built
     */
    private void benchmarkAlgorithm(int fold, AlgorithmInstance algorithm, RatingDataAccessObject trainData, Collection<UserRatingProfile> testUsers) {
        TaskTimer timer = new TaskTimer();
        logger.debug("Benchmarking {}", algorithm.getName());
        logger.debug("Building recommender");
        try {
            RatingPredictor predictor = algorithm.getRecommenderService(trainData).getRatingPredictor();
            writer.setValue(colBuildTime, timer.elapsed());
            logger.debug("Built model {} model in {}", algorithm.getName(), timer.elapsedPretty());
            logger.debug("Testing recommender");
            TaskTimer predTimer = new TaskTimer();
            double absErrSum = 0.0d;
            double sqErrSum = 0.0d;
            int nTried = 0;
            int nGood = 0;
            for (UserRatingProfile profile : testUsers) {
                long user = profile.getUser();
                SplitUserRatingProfile split = profileSplitter.splitProfile(profile);
                SparseVector predictions = predictor.predict(user, split.getQueryVector(), split.getProbeVector().keySet());
                for (Long2DoubleMap.Entry probe : split.getProbeVector().fast()) {
                    long item = probe.getLongKey();
                    double rating = probe.getDoubleValue();
                    double prediction = predictions.get(item);
                    nTried++;
                    writePrediction(fold, algorithm, user, item, rating, prediction);
                    if (!Double.isNaN(prediction)) {
                        double err = prediction - rating;
                        nGood++;
                        absErrSum += Math.abs(err);
                        sqErrSum += err * err;
                    }
                }
            }
            double mae = absErrSum / nGood;
            // RMSE is the square root of the mean squared error.
            double rmse = Math.sqrt(sqErrSum / nGood);
            // Divide in floating point: int / int would truncate the ratio to 0 or 1
            // and throw ArithmeticException when nothing was tried.
            double coverage = (double) nGood / nTried;
            logger.info(String.format("Recommender %s finished in %s (cov=%f, mae=%f, rmse=%f)", algorithm.getName(), timer.elapsedPretty(), coverage, mae, rmse));
            writer.setValue(colPredTime, predTimer.elapsed());
            writer.setValue(colMAE, mae);
            writer.setValue(colRMSE, rmse);
            writer.setValue(colNTry, nTried);
            writer.setValue(colNGood, nGood);
            writer.setValue(colCoverage, coverage);
        } catch (RecommenderNotAvailableException e) {
            // Pass the exception as the throwable argument so the stack trace is logged.
            logger.error("Recommender not available", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes one row to the prediction table, if one is configured.
     *
     * <p>A NaN prediction is written as {@code null}. On a write failure the error is
     * logged and the prediction table is disabled for the rest of the run; the
     * evaluation itself continues.
     */
    private void writePrediction(int fold, AlgorithmInstance algorithm, long user, long item, double rating, double prediction) {
        if (predWriter == null) {
            return;
        }
        try {
            Object[] row = {
                Integer.valueOf(fold),
                algorithm.getName(),
                Long.valueOf(user),
                Long.valueOf(item),
                Double.valueOf(rating),
                Double.isNaN(prediction) ? null : Double.valueOf(prediction)
            };
            predWriter.writeRow(row);
        } catch (IOException e) {
            logger.error("Error writing to pred. table", e);
            predWriter = null;
        }
    }
}
