package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Supplier;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.time.StopWatch;
import org.grouplens.lenskit.ItemRecommender;
import org.grouplens.lenskit.RatingPredictor;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.collections.ScoredLongList;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.UserHistory;
import org.grouplens.lenskit.data.dao.DataAccessObject;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.eval.AlgorithmInstance;
import org.grouplens.lenskit.eval.Job;
import org.grouplens.lenskit.eval.SharedPreferenceSnapshot;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.util.io.LKFileUtils;
import org.grouplens.lenskit.util.tablewriter.TableWriter;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalJob.class */
public class TrainTestEvalJob implements Job {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalJob.class);
    private final int numRecs;

    @Nonnull
    private final AlgorithmInstance algorithm;

    @Nonnull
    private final List<TestUserMetric> evaluators;

    @Nonnull
    private final TTDataSet data;

    @Nonnull
    private final Supplier<TableWriter> outputSupplier;

    @Nonnull
    private Supplier<TableWriter> userOutputSupplier;

    @Nonnull
    private Supplier<TableWriter> predictOutputSupplier;
    private final Supplier<SharedPreferenceSnapshot> snapshot;
    private final int outputColumnCount;

    /* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalJob$HistorySupplier.class */
    private class HistorySupplier implements Supplier<UserHistory<Rating>> {
        private final DataAccessObject dao;
        private final long user;

        public HistorySupplier(DataAccessObject dataAccessObject, long j) {
            this.dao = dataAccessObject;
            this.user = j;
        }

        /* renamed from: get, reason: merged with bridge method [inline-methods] */
        public UserHistory<Rating> m65get() {
            return this.dao.getUserHistory(this.user, Rating.class);
        }
    }

    /* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalJob$PredictionSupplier.class */
    private class PredictionSupplier implements Supplier<SparseVector> {
        private final RatingPredictor predictor;
        private final long user;
        private final LongSet items;

        public PredictionSupplier(RatingPredictor ratingPredictor, long j, LongSet longSet) {
            this.predictor = ratingPredictor;
            this.user = j;
            this.items = longSet;
        }

        /* renamed from: get, reason: merged with bridge method [inline-methods] */
        public SparseVector m66get() {
            if (this.predictor == null) {
                throw new IllegalArgumentException("cannot compute predictions without a predictor");
            }
            return this.predictor.score(this.user, this.items);
        }
    }

    /* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalJob$RecommendationSupplier.class */
    private class RecommendationSupplier implements Supplier<ScoredLongList> {
        private final ItemRecommender recommender;
        private final long user;
        private final LongSet items;

        public RecommendationSupplier(ItemRecommender itemRecommender, long j, LongSet longSet) {
            this.recommender = itemRecommender;
            this.user = j;
            this.items = longSet;
        }

        /* renamed from: get, reason: merged with bridge method [inline-methods] */
        public ScoredLongList m67get() {
            if (this.recommender == null) {
                throw new IllegalArgumentException("cannot compute recommendations without a recommender");
            }
            return this.recommender.recommend(this.user, TrainTestEvalJob.this.numRecs, this.items, (Set) null);
        }
    }

    public TrainTestEvalJob(AlgorithmInstance algorithmInstance, List<TestUserMetric> list, TTDataSet tTDataSet, Supplier<SharedPreferenceSnapshot> supplier, Supplier<TableWriter> supplier2, int i) {
        this.algorithm = algorithmInstance;
        this.evaluators = list;
        this.data = tTDataSet;
        this.snapshot = supplier;
        this.outputSupplier = supplier2;
        this.numRecs = i;
        int i2 = 2;
        for (TestUserMetric testUserMetric : list) {
            if (testUserMetric.getColumnLabels() != null) {
                i2 += testUserMetric.getColumnLabels().length;
            }
        }
        this.outputColumnCount = i2;
    }

    public void setUserOutput(Supplier<TableWriter> supplier) {
        this.userOutputSupplier = supplier;
    }

    public void setPredictOutput(Supplier<TableWriter> supplier) {
        this.predictOutputSupplier = supplier;
    }

    @Override // org.grouplens.lenskit.eval.Job
    public String getName() {
        return this.algorithm.getName();
    }

    /* JADX WARN: Finally extract failed */
    @Override // org.grouplens.lenskit.eval.Job, java.lang.Runnable
    public void run() {
        DataAccessObject snapshot = this.data.getTrainFactory().snapshot();
        try {
            try {
                TableWriter tableWriter = (TableWriter) this.userOutputSupplier.get();
                TableWriter tableWriter2 = (TableWriter) this.predictOutputSupplier.get();
                logger.info("Building {}", this.algorithm.getName());
                StopWatch stopWatch = new StopWatch();
                stopWatch.start();
                Recommender buildRecommender = this.algorithm.buildRecommender(snapshot, this.snapshot, this.data.getPreferenceDomain());
                RatingPredictor ratingPredictor = buildRecommender.getRatingPredictor();
                ItemRecommender itemRecommender = buildRecommender.getItemRecommender();
                stopWatch.stop();
                logger.info("Built {} in {}", this.algorithm.getName(), stopWatch);
                logger.info("Testing {}", this.algorithm.getName());
                StopWatch stopWatch2 = new StopWatch();
                stopWatch2.start();
                ArrayList arrayList = new ArrayList(this.evaluators.size());
                Object[] objArr = tableWriter != null ? new Object[tableWriter.getLayout().getColumnCount()] : null;
                DataAccessObject create = this.data.getTestFactory().create();
                try {
                    Iterator<TestUserMetric> it = this.evaluators.iterator();
                    while (it.hasNext()) {
                        arrayList.add(it.next().makeAccumulator(this.algorithm, this.data));
                    }
                    Cursor<UserHistory> userHistories = create.getUserHistories(Rating.class);
                    try {
                        for (UserHistory userHistory : userHistories) {
                            long userId = userHistory.getUserId();
                            SparseVector makeRatingVector = RatingVectorUserHistorySummarizer.makeRatingVector(userHistory);
                            TestUser testUser = new TestUser(userId, makeRatingVector, new HistorySupplier(snapshot, userId), new PredictionSupplier(ratingPredictor, userId, makeRatingVector.keySet()), new RecommendationSupplier(itemRecommender, userId, makeRatingVector.keySet()));
                            int i = 0;
                            Iterator<TestUserMetricAccumulator> it2 = arrayList.iterator();
                            while (it2.hasNext()) {
                                Object[] evaluate = it2.next().evaluate(testUser);
                                if (evaluate != null && objArr != null) {
                                    System.arraycopy(evaluate, 0, objArr, i, evaluate.length);
                                    i += evaluate.length;
                                }
                            }
                            if (objArr != null) {
                                try {
                                    tableWriter.writeRow(objArr);
                                } catch (IOException e) {
                                    throw new RuntimeException("error writing user row", e);
                                }
                            }
                            if (tableWriter2 != null) {
                                writePredictions(tableWriter2, userId, makeRatingVector, testUser.getPredictions());
                            }
                        }
                        userHistories.close();
                        create.close();
                        stopWatch2.stop();
                        logger.info("Tested {} in {}", this.algorithm.getName(), stopWatch2);
                        try {
                            writeOutput(stopWatch, stopWatch2, arrayList);
                        } catch (IOException e2) {
                            logger.error("Error writing output", e2);
                        }
                        LKFileUtils.close(new Closeable[]{tableWriter, tableWriter2});
                        snapshot.close();
                    } catch (Throwable th) {
                        userHistories.close();
                        throw th;
                    }
                } catch (Throwable th2) {
                    create.close();
                    throw th2;
                }
            } catch (RecommenderBuildException e3) {
                logger.error("error building recommender {}: {}", this.algorithm, e3);
                throw new RuntimeException((Throwable) e3);
            }
        } catch (Throwable th3) {
            LKFileUtils.close(new Closeable[]{null, null});
            snapshot.close();
            throw th3;
        }
    }

    private void writePredictions(TableWriter tableWriter, long j, SparseVector sparseVector, SparseVector sparseVector2) {
        String[] strArr = new String[4];
        strArr[0] = Long.toString(j);
        for (VectorEntry vectorEntry : sparseVector.fast()) {
            long key = vectorEntry.getKey();
            strArr[1] = Long.toString(key);
            strArr[2] = Double.toString(vectorEntry.getValue());
            if (sparseVector2.containsKey(key)) {
                strArr[3] = Double.toString(sparseVector2.get(key));
            } else {
                strArr[3] = null;
            }
            try {
                tableWriter.writeRow(strArr);
            } catch (IOException e) {
                throw new RuntimeException("error writing predictions", e);
            }
        }
    }

    private void writeOutput(StopWatch stopWatch, StopWatch stopWatch2, List<TestUserMetricAccumulator> list) throws IOException {
        Object[] objArr = new Object[this.outputColumnCount];
        objArr[0] = Long.valueOf(stopWatch.getTime());
        objArr[1] = Long.valueOf(stopWatch2.getTime());
        int i = 2;
        Iterator<TestUserMetricAccumulator> it = list.iterator();
        while (it.hasNext()) {
            Object[] finalResults = it.next().finalResults();
            if (finalResults != null) {
                int length = finalResults.length;
                System.arraycopy(finalResults, 0, objArr, i, length);
                i += length;
            }
        }
        TableWriter tableWriter = (TableWriter) this.outputSupplier.get();
        try {
            tableWriter.writeRow(objArr);
            tableWriter.close();
        } catch (Throwable th) {
            tableWriter.close();
            throw th;
        }
    }
}
