package org.grouplens.lenskit.eval.traintest;

import java.io.PrintStream;
import java.sql.Connection;
import java.util.List;
import org.grouplens.common.cursors.Cursor;
import org.grouplens.common.cursors.Cursors;
import org.grouplens.lenskit.RatingPredictor;
import org.grouplens.lenskit.data.Ratings;
import org.grouplens.lenskit.data.UserRatingProfile;
import org.grouplens.lenskit.data.dao.RatingCollectionDAO;
import org.grouplens.lenskit.data.snapshot.PackedRatingSnapshot;
import org.grouplens.lenskit.data.sql.BasicSQLStatementFactory;
import org.grouplens.lenskit.data.sql.JDBCRatingDAO;
import org.grouplens.lenskit.data.vector.MutableSparseVector;
import org.grouplens.lenskit.eval.AlgorithmInstance;
import org.grouplens.lenskit.eval.SharedRatingSnapshot;
import org.grouplens.lenskit.eval.results.AlgorithmTestAccumulator;
import org.grouplens.lenskit.eval.results.ResultAccumulator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestPredictEvaluator.class */
/**
 * Evaluates rating predictors on a train/test split stored in two database tables.
 *
 * <p>Each algorithm is built from the ratings in the training table. For every user profile in
 * the test table, the resulting predictor is asked to predict that user's test items, and the
 * predictions are reported to an {@link AlgorithmTestAccumulator}.
 */
public class TrainTestPredictEvaluator {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestPredictEvaluator.class);

    private final Connection connection;
    private final String trainingTable;
    private final String testTable;
    /** Whether the training table's timestamp column is read; defaults to {@code true}. */
    private boolean timestamp = true;
    /** Optional sink for progress output; {@code null} disables progress reporting. */
    private PrintStream progressStream;

    /**
     * Creates an evaluator that reads both tables through the same JDBC connection.
     *
     * @param connection    connection used to open the training and test DAOs
     * @param trainingTable name of the table holding training ratings
     * @param testTable     name of the table holding test ratings
     */
    public TrainTestPredictEvaluator(Connection connection, String trainingTable, String testTable) {
        this.connection = connection;
        this.trainingTable = trainingTable;
        this.testTable = testTable;
    }

    /**
     * Reports whether the training table's timestamp column is used.
     *
     * @return {@code true} if timestamps are read from the training table
     */
    public boolean isTimestampEnabled() {
        return this.timestamp;
    }

    /**
     * Enables or disables reading the training table's timestamp column.
     *
     * <p>When disabled, the training SQL statement factory is configured with no timestamp column.
     *
     * @param z {@code true} to read timestamps, {@code false} to ignore them
     */
    public void setTimestampEnabled(boolean z) {
        this.timestamp = z;
    }

    /**
     * Sets the stream that receives per-user progress lines during testing.
     *
     * @param printStream progress sink, or {@code null} to disable progress output
     */
    public void setProgressStream(PrintStream printStream) {
        this.progressStream = printStream;
    }

    /**
     * Builds and tests each algorithm against the training and test tables.
     *
     * <p>A packed rating snapshot of the training data is built once and shared by all algorithms.
     * Algorithms whose {@code getPreload()} is true are built from an in-memory copy of the
     * training ratings. That copy is created lazily on first use and reused for later preloading
     * algorithms. All DAOs opened here are closed before returning, including on failure.
     *
     * @param list              algorithms to evaluate, in order
     * @param resultAccumulator factory for the per-algorithm result accumulators
     */
    public void evaluateAlgorithms(List<AlgorithmInstance> list, ResultAccumulator resultAccumulator) {
        BasicSQLStatementFactory trainingSql = new BasicSQLStatementFactory();
        trainingSql.setTableName(this.trainingTable);
        if (!this.timestamp) {
            trainingSql.setTimestampColumn((String) null);
        }
        RatingCollectionDAO trainingDao =
                new JDBCRatingDAO.Manager((String) null, trainingSql).open(this.connection);
        // Both are opened lazily inside the try so that every failure path still closes trainingDao.
        RatingCollectionDAO preloadedDao = null;
        JDBCRatingDAO testDao = null;
        try {
            logger.debug("Preloading rating snapshot data");
            SharedRatingSnapshot snapshot =
                    new SharedRatingSnapshot(new PackedRatingSnapshot.Builder(trainingDao).build());

            BasicSQLStatementFactory testSql = new BasicSQLStatementFactory();
            testSql.setTableName(this.testTable);
            // The test table is always read without a timestamp column, regardless of the flag.
            testSql.setTimestampColumn((String) null);
            testDao = new JDBCRatingDAO.Manager((String) null, testSql).open(this.connection);

            int userCount = testDao.getUserCount();
            logger.debug("Evaluating algorithms with {} users", Integer.valueOf(userCount));
            for (AlgorithmInstance algorithm : list) {
                AlgorithmTestAccumulator accumulator = resultAccumulator.makeAlgorithmAccumulator(algorithm);

                RatingCollectionDAO buildDao;
                if (algorithm.getPreload()) {
                    if (preloadedDao == null) {
                        logger.info("Preloading rating data for {}", algorithm.getName());
                        preloadedDao = new RatingCollectionDAO.Manager(
                                Cursors.makeList(trainingDao.getRatings())).open();
                    }
                    buildDao = preloadedDao;
                } else {
                    buildDao = trainingDao;
                }

                logger.debug("Building {}", algorithm.getName());
                accumulator.startBuildTimer();
                RatingPredictor predictor =
                        algorithm.buildRecommender(buildDao, snapshot).getRatingPredictor();
                accumulator.finishBuild();

                logger.info("Testing {}", algorithm.getName());
                accumulator.startTestTimer();
                testPredictions(predictor, testDao, userCount, accumulator);
                accumulator.finish();
            }
        } finally {
            // Nested so that a failure while closing one DAO does not leak the others.
            try {
                if (preloadedDao != null) {
                    preloadedDao.close();
                }
            } finally {
                try {
                    trainingDao.close();
                } finally {
                    if (testDao != null) {
                        testDao.close();
                    }
                }
            }
        }
    }

    /**
     * Predicts every test user's ratings with {@code predictor} and reports them to
     * {@code accumulator}.
     *
     * <p>The test-profile cursor is closed even if prediction or accumulation throws.
     *
     * @param predictor   predictor built for the algorithm under test
     * @param testDao     DAO over the test table
     * @param userCount   number of test users, used only for progress output
     * @param accumulator receives one prediction result per test user
     */
    private void testPredictions(RatingPredictor predictor, JDBCRatingDAO testDao, int userCount,
                                 AlgorithmTestAccumulator accumulator) {
        Cursor<UserRatingProfile> profiles = testDao.getUserRatingProfiles();
        try {
            int tested = 0;
            for (UserRatingProfile profile : profiles) {
                if (this.progressStream != null) {
                    this.progressStream.format("users: %d / %d\r",
                            Integer.valueOf(tested), Integer.valueOf(userCount));
                }
                MutableSparseVector ratings = Ratings.userRatingVector(profile.getRatings());
                // Predict exactly the items this user rated in the test set.
                accumulator.evaluatePrediction(profile.getUser(), ratings,
                        predictor.predict(profile.getUser(), ratings.keySet()));
                tested++;
            }
            if (this.progressStream != null) {
                this.progressStream.format("tested users: %d / %d\n",
                        Integer.valueOf(tested), Integer.valueOf(userCount));
            }
        } finally {
            profiles.close();
        }
    }
}
