package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import javax.annotation.Nonnull;
import javax.inject.Provider;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.eval.ExecutionInfo;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.RecommenderInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A single train-test evaluation job. It does the following:
 * <ol>
 *   <li>builds one algorithm instance against one train/test data set;</li>
 *   <li>runs the configured model metrics over the built recommender;</li>
 *   <li>evaluates every test user with the configured per-user metrics;</li>
 *   <li>writes the aggregate row, and optionally per-user and per-prediction rows, to table
 *       writers.</li>
 * </ol>
 */
public class TrainTestEvalJob implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalJob.class);

    /** Number of recommendations requested per test user. */
    private final int numRecs;

    @Nonnull
    private final AlgorithmInstance algorithm;

    /** Per-user metrics; each one contributes an accumulator for this job. */
    @Nonnull
    private final List<TestUserMetric> evaluators;

    /** Metrics measured once over the built recommender. */
    @Nonnull
    private final List<ModelMetric> modelMetrics;

    /** Prediction channels to emit as extra prediction-table columns (only the symbol is read here). */
    @Nonnull
    private final List<Pair<Symbol, String>> channels;

    @Nonnull
    private final TTDataSet data;

    /** Supplies the writer for the single aggregate result row. */
    @Nonnull
    private final Supplier<TableWriter> outputSupplier;

    /** Supplies the per-user writer; it may supply {@code null}, which disables per-user output. */
    @Nonnull
    private final Supplier<TableWriter> userOutputSupplier;

    /** Supplies the per-prediction writer; it may supply {@code null}, which disables prediction output. */
    @Nonnull
    private final Supplier<TableWriter> predictOutputSupplier;

    /** Snapshot provider handed to the algorithm when the recommender is built. */
    private final Provider<PreferenceSnapshot> snapshot;

    /** Lazily fetches a user's event history from the recommender instance's event DAO. */
    private static class HistorySupplier implements Supplier<UserHistory<Event>> {
        private final UserEventDAO userEventDAO;
        private final long user;

        public HistorySupplier(UserEventDAO userEventDAO, long user) {
            this.userEventDAO = userEventDAO;
            this.user = user;
        }

        @Override
        public UserHistory<Event> get() {
            return userEventDAO.getEventsForUser(user);
        }
    }

    /** Lazily computes predictions for a user's test items. */
    private static class PredictionSupplier implements Supplier<SparseVector> {
        private final RecommenderInstance predictor;
        private final long user;
        private final LongSet items;

        public PredictionSupplier(RecommenderInstance predictor, long user, LongSet items) {
            this.predictor = predictor;
            this.user = user;
            this.items = items;
        }

        /**
         * @throws IllegalArgumentException if there is no predictor or it yields no predictions
         */
        @Override
        public SparseVector get() {
            if (predictor == null) {
                throw new IllegalArgumentException("cannot compute predictions without a predictor");
            }
            SparseVector predictions = predictor.getPredictions(user, items);
            if (predictions == null) {
                throw new IllegalArgumentException("no predictions");
            }
            return predictions;
        }
    }

    /**
     * Lazily computes up to {@code numRecs} recommendations for a user, drawn from the user's
     * test items. It is an inner class because it reads the enclosing job's {@code numRecs}.
     */
    private class RecommendationSupplier implements Supplier<List<ScoredId>> {
        private final RecommenderInstance recommender;
        private final long user;
        private final LongSet items;

        public RecommendationSupplier(RecommenderInstance recommender, long user, LongSet items) {
            this.recommender = recommender;
            this.user = user;
            this.items = items;
        }

        /**
         * @throws IllegalArgumentException if there is no recommender or it yields no recommendations
         */
        @Override
        public List<ScoredId> get() {
            if (recommender == null) {
                throw new IllegalArgumentException("cannot compute recommendations without a recommender");
            }
            List<ScoredId> recommendations = recommender.getRecommendations(user, items, numRecs);
            if (recommendations == null) {
                throw new IllegalArgumentException("no recommendations");
            }
            return recommendations;
        }
    }

    public TrainTestEvalJob(@Nonnull AlgorithmInstance algo,
                            @Nonnull List<TestUserMetric> evals,
                            @Nonnull List<ModelMetric> modelMetrics,
                            @Nonnull List<Pair<Symbol, String>> channels,
                            @Nonnull TTDataSet data,
                            Provider<PreferenceSnapshot> snapshot,
                            @Nonnull Supplier<TableWriter> output,
                            @Nonnull Supplier<TableWriter> userOutput,
                            @Nonnull Supplier<TableWriter> predictOutput,
                            int numRecs) {
        this.algorithm = algo;
        this.evaluators = evals;
        this.modelMetrics = modelMetrics;
        this.channels = channels;
        this.data = data;
        this.snapshot = snapshot;
        this.outputSupplier = output;
        this.userOutputSupplier = userOutput;
        this.predictOutputSupplier = predictOutput;
        this.numRecs = numRecs;
    }

    /**
     * Runs the evaluation, wrapping any failure in a {@link TrainTestJobException}.
     */
    @Override
    public void run() {
        try {
            runEvaluation();
        } catch (Exception e) {
            throw new TrainTestJobException(e);
        }
    }

    /**
     * Builds the recommender, measures it, evaluates every test user, and writes all outputs.
     *
     * <p>The per-user writer, the prediction writer and the test-user cursor are registered with
     * a {@link Closer}. They are closed exactly once on every path, and the primary failure is
     * rethrown. The aggregate writer is closed separately in
     * {@link #writeOutput}.
     *
     * @throws IOException if writing output or closing a resource fails
     * @throws RecommenderBuildException if the algorithm cannot be built
     */
    private void runEvaluation() throws IOException, RecommenderBuildException {
        Closer closer = Closer.create();
        try {
            TableWriter userTable = userOutputSupplier.get();
            if (userTable != null) {
                closer.register(userTable);
            }
            TableWriter predictTable = predictOutputSupplier.get();
            if (predictTable != null) {
                closer.register(predictTable);
            }

            List<Object> modelResults = Lists.newArrayList();
            ExecutionInfo execInfo = buildExecInfo();

            logger.info("Building {}", algorithm.getName());
            StopWatch buildTimer = new StopWatch();
            buildTimer.start();
            RecommenderInstance rec = algorithm.makeTestableRecommender(data, snapshot, execInfo);
            buildTimer.stop();
            logger.info("Built {} in {}", algorithm.getName(), buildTimer);

            logger.info("Measuring {}", algorithm.getName());
            for (ModelMetric metric : modelMetrics) {
                // NOTE(review): decompiler-renamed accessor, presumably getRecommender() in the
                // original source; kept as-is to match RecommenderInstance in this code base.
                modelResults.addAll(metric.measureAlgorithm(algorithm, data, rec.mo6getRecommender()));
            }

            logger.info("Testing {}", algorithm.getName());
            StopWatch testTimer = new StopWatch();
            testTimer.start();
            List<TestUserMetricAccumulator> accumulators =
                    new ArrayList<TestUserMetricAccumulator>(evaluators.size());
            // One output row per test user, reused across users and cleared after each write.
            List<Object> userRow = new ArrayList<Object>();
            UserEventDAO testUsers = data.getTestData().getUserEventDAO();
            for (TestUserMetric metric : evaluators) {
                accumulators.add(metric.makeAccumulator(algorithm, data));
            }

            for (UserHistory<Event> profile : closer.register(testUsers.streamEventsByUser())) {
                assert userRow.isEmpty();
                long uid = profile.getUserId();
                userRow.add(uid);

                LongSet testItems = profile.itemSet();
                TestUser test = new TestUser(uid,
                                             new HistorySupplier(rec.getUserEventDAO(), uid),
                                             Suppliers.ofInstance(profile),
                                             new PredictionSupplier(rec, uid, testItems),
                                             new RecommendationSupplier(rec, uid, testItems));

                for (TestUserMetricAccumulator accum : accumulators) {
                    Object[] userResults = accum.evaluate(test);
                    if (userResults != null) {
                        userRow.addAll(Arrays.asList(userResults));
                    }
                }
                if (userTable != null) {
                    try {
                        userTable.writeRow(userRow);
                    } catch (IOException e) {
                        throw new RuntimeException("error writing user row", e);
                    }
                }
                userRow.clear();

                if (predictTable != null) {
                    writePredictions(predictTable, uid,
                                     RatingVectorUserHistorySummarizer.makeRatingVector(profile),
                                     test.getPredictions());
                }
            }
            testTimer.stop();
            logger.info("Tested {} in {}", algorithm.getName(), testTimer);

            writeOutput(buildTimer, testTimer, modelResults, accumulators);
        } catch (Throwable th) {
            throw closer.rethrow(th, RecommenderBuildException.class);
        } finally {
            closer.close();
        }
    }

    /** Builds the execution info (algorithm and data-set names and attributes) for this job. */
    private ExecutionInfo buildExecInfo() {
        ExecutionInfo.Builder builder = new ExecutionInfo.Builder();
        builder.setAlgoName(algorithm.getName())
               .setAlgoAttributes(algorithm.getAttributes())
               .setDataName(data.getName())
               .setDataAttributes(data.getAttributes());
        // NOTE(review): decompiler-renamed builder method, presumably build() in the original
        // source; kept as-is to match ExecutionInfo.Builder in this code base.
        return builder.m2build();
    }

    /**
     * Writes one prediction-table row per rated test item.
     *
     * <p>Columns are, in order: user id; item id; rating; prediction; then one column per
     * configured channel. The prediction and channel cells are {@code null} when there is no
     * value for the item.
     *
     * @param table the prediction table writer
     * @param uid the user id
     * @param ratings the user's test ratings
     * @param predictions the predictions for the user's test items
     */
    private void writePredictions(TableWriter table, long uid, SparseVector ratings,
                                  SparseVector predictions) throws IOException {
        String[] row = new String[table.getLayout().getColumnCount()];
        row[0] = Long.toString(uid);
        for (VectorEntry rating : ratings.fast()) {
            long item = rating.getKey();
            row[1] = Long.toString(item);
            row[2] = Double.toString(rating.getValue());
            row[3] = predictions.containsKey(item) ? Double.toString(predictions.get(item)) : null;
            int col = 4;
            for (Pair<Symbol, String> channel : channels) {
                Symbol symbol = channel.getLeft();
                if (predictions.hasChannelVector(symbol)
                        && predictions.getChannelVector(symbol).containsKey(item)) {
                    row[col] = Double.toString(predictions.getChannelVector(symbol).get(item));
                } else {
                    // Explicitly clear the cell, because the row array is reused across items.
                    row[col] = null;
                }
                col++;
            }
            table.writeRow(row);
        }
    }

    /**
     * Writes the aggregate result row and then closes the aggregate writer.
     *
     * <p>The row contains, in order:
     * <ul>
     *   <li>the build time;</li>
     *   <li>the test time;</li>
     *   <li>the model-metric results;</li>
     *   <li>each accumulator's final results, skipping accumulators that return {@code null}.</li>
     * </ul>
     */
    private void writeOutput(StopWatch buildTimer, StopWatch testTimer, List<Object> modelResults,
                             List<TestUserMetricAccumulator> accumulators) throws IOException {
        TableWriter output = outputSupplier.get();
        try {
            Object[] row = new Object[output.getLayout().getColumnCount()];
            row[0] = buildTimer.getTime();
            row[1] = testTimer.getTime();
            int col = 2;
            for (Object value : modelResults) {
                row[col++] = value;
            }
            for (TestUserMetricAccumulator accum : accumulators) {
                Object[] results = accum.finalResults();
                if (results != null) {
                    System.arraycopy(results, 0, row, col, results.length);
                    col += results.length;
                }
            }
            output.writeRow(row);
        } finally {
            output.close();
        }
    }
}
