package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.grouplens.common.cursors.Cursor;
import org.grouplens.common.cursors.Cursors;
import org.grouplens.lenskit.RatingPredictor;
import org.grouplens.lenskit.data.UserHistory;
import org.grouplens.lenskit.data.dao.DAOFactory;
import org.grouplens.lenskit.data.dao.DataAccessObject;
import org.grouplens.lenskit.data.dao.EventCollectionDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.snapshot.PackedRatingSnapshot;
import org.grouplens.lenskit.data.snapshot.RatingSnapshot;
import org.grouplens.lenskit.data.sql.BasicSQLStatementFactory;
import org.grouplens.lenskit.data.sql.JDBCRatingDAO;
import org.grouplens.lenskit.data.vector.UserRatingVector;
import org.grouplens.lenskit.eval.AlgorithmInstance;
import org.grouplens.lenskit.eval.SharedRatingSnapshot;
import org.grouplens.lenskit.eval.results.AlgorithmTestAccumulator;
import org.grouplens.lenskit.eval.results.ResultAccumulator;
import org.grouplens.lenskit.util.LazyValue;
import org.grouplens.lenskit.util.parallel.ExecHelpers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Train/test evaluator for rating predictors whose data lives in a JDBC-accessible database.
 *
 * <p>Each algorithm of an {@link EvaluationRecipe} is built from the ratings in the training
 * table. It is then asked to predict, for every user found in the test table, that user's
 * test-set items. Each prediction result is passed to the recipe's {@link ResultAccumulator}.
 * The per-algorithm tasks are run with {@link ExecHelpers#parallelRun} on a fixed-size thread
 * pool (see {@link #getThreadCount()}).
 */
public class TrainTestPredictEvaluator {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestPredictEvaluator.class);

    /** System property consulted when no explicit thread count has been set. */
    private static final String THREAD_COUNT_PROPERTY = "lenskit.eval.thread.count";

    private final String databaseUrl;
    private final String trainingTable;
    private final String testTable;
    private boolean timestamp = true;
    private int threadCount = 0;
    private String name;

    /**
     * Builds one algorithm on the training data and evaluates its predictions on the test
     * data. Results are reported through an algorithm accumulator obtained from a shared
     * {@link ResultAccumulator}.
     */
    public class EvalTask implements Runnable {
        private final AlgorithmInstance algorithm;
        private final ResultAccumulator resultAccumulator;
        private final DAOFactory daoManager;
        private final DAOFactory testDaoManager;
        private final LazyValue<List<Event>> ratingCache;
        private final LazyValue<? extends RatingSnapshot> ratingSnapshot;

        /**
         * @param trainingDaos  factory for training-data DAOs (used when the algorithm does not
         *                      request preloading)
         * @param testDaos      factory for test-data DAOs
         * @param accumulator   accumulator that supplies this algorithm's result accumulator
         * @param algorithm     the algorithm to build and test
         * @param cachedEvents  lazily loaded training events, used when the algorithm preloads
         * @param snapshot      lazily loaded training rating snapshot passed to the builder
         */
        public EvalTask(DAOFactory trainingDaos, DAOFactory testDaos, ResultAccumulator accumulator,
                        AlgorithmInstance algorithm, LazyValue<List<Event>> cachedEvents,
                        LazyValue<? extends RatingSnapshot> snapshot) {
            this.daoManager = trainingDaos;
            this.testDaoManager = testDaos;
            this.resultAccumulator = accumulator;
            this.algorithm = algorithm;
            this.ratingCache = cachedEvents;
            this.ratingSnapshot = snapshot;
        }

        /**
         * Builds the algorithm's rating predictor, times the build, then times and records its
         * predictions for the test users. The training DAO is always closed, even on failure.
         */
        @Override
        public void run() {
            AlgorithmTestAccumulator acc = resultAccumulator.makeAlgorithmAccumulator(algorithm);
            // Preloading algorithms get an in-memory DAO over the shared cached events;
            // the others open a fresh DAO from the training factory.
            DataAccessObject trainDao;
            if (algorithm.getPreload()) {
                trainDao = new EventCollectionDAO(ratingCache.get());
            } else {
                trainDao = daoManager.create();
            }
            try {
                logger.info("Building {}", algorithm.getName());
                acc.startBuildTimer();
                RatingPredictor predictor =
                        algorithm.buildRecommender(trainDao, ratingSnapshot.get()).getRatingPredictor();
                acc.finishBuild();

                logger.info("Testing {}", algorithm.getName());
                acc.startTestTimer();
                evaluateTestUsers(acc, predictor);
                acc.finish();
            } finally {
                trainDao.close();
            }
        }

        /**
         * Predicts, for every test user, the items in that user's test ratings and feeds each
         * result to {@code acc}. The test DAO and its cursor are each closed exactly once.
         */
        private void evaluateTestUsers(AlgorithmTestAccumulator acc, RatingPredictor predictor) {
            DataAccessObject testDao = testDaoManager.create();
            try {
                Cursor<UserHistory<Rating>> histories = testDao.getUserHistories(Rating.class);
                try {
                    for (UserHistory<Rating> history : histories) {
                        UserRatingVector ratings = UserRatingVector.fromRatings(history);
                        acc.evaluatePrediction(history.getUserId(), ratings,
                                predictor.predict(history.getUserId(), ratings.keySet()));
                    }
                } finally {
                    histories.close();
                }
            } finally {
                testDao.close();
            }
        }
    }

    /**
     * @param dbUrl database URL handed to {@link JDBCRatingDAO.Factory}; it is also the default
     *              evaluation name (see {@link #setName(String)})
     * @param train name of the table holding training ratings
     * @param test  name of the table holding test ratings
     */
    public TrainTestPredictEvaluator(String dbUrl, String train, String test) {
        this.databaseUrl = dbUrl;
        this.name = dbUrl;
        this.trainingTable = train;
        this.testTable = test;
    }

    /** @return whether the training DAO is configured with a timestamp column (default true). */
    public boolean isTimestampEnabled() {
        return timestamp;
    }

    /**
     * @param enabled if false, the training DAO is configured with no timestamp column.
     *                Test DAOs never use one.
     */
    public void setTimestampEnabled(boolean enabled) {
        this.timestamp = enabled;
    }

    /** @param count number of worker threads; a value &lt;= 0 selects the default (see {@link #getThreadCount()}). */
    public void setThreadCount(int count) {
        this.threadCount = count;
    }

    /** @return the evaluation name used for the result accumulator and log messages. */
    public String getName() {
        return name;
    }

    /** @param newName the evaluation name used for the result accumulator and log messages. */
    public void setName(String newName) {
        this.name = newName;
    }

    /**
     * Resolves the number of worker threads. The first positive value wins: the explicitly set
     * count, then the {@value #THREAD_COUNT_PROPERTY} system property, then the number of
     * available processors. A malformed property value is logged and ignored.
     *
     * @return a positive thread count
     */
    public int getThreadCount() {
        if (threadCount > 0) {
            return threadCount;
        }
        String configured = System.getProperty(THREAD_COUNT_PROPERTY, "0");
        int fromProperty;
        try {
            fromProperty = Integer.parseInt(configured.trim());
        } catch (NumberFormatException e) {
            logger.warn("Ignoring invalid {} value '{}'", THREAD_COUNT_PROPERTY, configured);
            fromProperty = 0;
        }
        return fromProperty > 0 ? fromProperty : Runtime.getRuntime().availableProcessors();
    }

    /** @return a DAO factory over the training table, honoring {@link #isTimestampEnabled()}. */
    protected JDBCRatingDAO.Factory trainingDAOManager() {
        return makeDAOFactory(trainingTable, timestamp);
    }

    /** @return a DAO factory over the test table; timestamps are never read for test data. */
    protected JDBCRatingDAO.Factory testDAOManager() {
        return makeDAOFactory(testTable, false);
    }

    /** Shared setup for the training and test DAO factories. */
    private JDBCRatingDAO.Factory makeDAOFactory(String table, boolean useTimestamps) {
        BasicSQLStatementFactory sql = new BasicSQLStatementFactory();
        sql.setTableName(table);
        if (!useTimestamps) {
            sql.setTimestampColumn((String) null);
        }
        return new JDBCRatingDAO.Factory(databaseUrl, sql);
    }

    /**
     * Runs the evaluation described by a recipe.
     *
     * @deprecated use {@link #runEvaluation(EvaluationRecipe)}, which this delegates to.
     */
    @Deprecated
    public void evaluateAlgorithms(EvaluationRecipe evaluationRecipe) {
        runEvaluation(evaluationRecipe);
    }

    /**
     * Evaluates every algorithm of the recipe on a fixed-size thread pool.
     * The pool is always shut down, whether or not evaluation succeeds.
     *
     * @throws RuntimeException wrapping the underlying cause if any evaluation task fails
     */
    public void runEvaluation(EvaluationRecipe evaluationRecipe) {
        List<Runnable> tasks = makeEvalTasks(evaluationRecipe);
        ExecutorService pool = Executors.newFixedThreadPool(getThreadCount());
        try {
            ExecHelpers.parallelRun(pool, tasks);
        } catch (ExecutionException e) {
            throw new RuntimeException("Error evaluating recommenders",
                    ExecHelpers.unwrapExecutionException(e));
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Creates one {@link EvalTask} per algorithm in the recipe. All tasks share the same DAO
     * factories, result accumulator, and lazily loaded training data.
     *
     * @return a list of tasks (one per algorithm, in recipe order), materialized eagerly
     */
    public List<Runnable> makeEvalTasks(EvaluationRecipe evaluationRecipe) {
        final JDBCRatingDAO.Factory trainingDaos = trainingDAOManager();
        final JDBCRatingDAO.Factory testDaos = testDAOManager();
        final ResultAccumulator accumulator = evaluationRecipe.makeAccumulator(getName());

        // Shared by all tasks. NOTE(review): LazyValue presumably runs the callable on first
        // get() and caches the result; confirm its thread-safety for concurrent tasks.
        final LazyValue<SharedRatingSnapshot> snapshot = new LazyValue<SharedRatingSnapshot>(
                new Callable<SharedRatingSnapshot>() {
                    @Override
                    public SharedRatingSnapshot call() {
                        logger.info("Loading snapshot for {}", name);
                        return loadSnapshot(trainingDaos);
                    }
                });
        final LazyValue<List<Event>> cachedEvents = new LazyValue<List<Event>>(
                new Callable<List<Event>>() {
                    @Override
                    public List<Event> call() {
                        logger.info("Preloading ratings for {}", name);
                        DataAccessObject dao = trainingDaos.create();
                        try {
                            return Cursors.makeList(dao.getEvents());
                        } finally {
                            dao.close();
                        }
                    }
                });

        // Build the tasks once: a Lists.transform view would construct a new EvalTask
        // on every element access.
        List<Runnable> tasks = new ArrayList<Runnable>();
        for (AlgorithmInstance algorithm : evaluationRecipe.getAlgorithms()) {
            tasks.add(new EvalTask(trainingDaos, testDaos, accumulator, algorithm,
                    cachedEvents, snapshot));
        }
        return tasks;
    }

    /**
     * Builds a packed rating snapshot from a fresh DAO and wraps it in a
     * {@link SharedRatingSnapshot}. The DAO is closed before returning, even on failure.
     */
    public SharedRatingSnapshot loadSnapshot(DAOFactory daoFactory) {
        DataAccessObject dao = daoFactory.create();
        try {
            return new SharedRatingSnapshot(new PackedRatingSnapshot.Builder(dao).build());
        } finally {
            dao.close();
        }
    }
}
