package org.grouplens.lenskit.eval.traintest;

import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.eventbus.EventBus;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.LongIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.collections.CollectionUtils;
import org.grouplens.lenskit.eval.Attributed;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.metrics.TestUserMetricAccumulator;
import org.grouplens.lenskit.eval.metrics.topn.ItemSelectors;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestJob.class */
/**
 * A single train-test evaluation job: builds one algorithm on one data set,
 * evaluates every test user with the configured per-user metrics, and writes
 * the per-user, prediction, recommendation, and aggregate result rows to the
 * job's output tables.
 *
 * <p>Subclasses supply the algorithm-specific steps ({@link #buildRecommender()},
 * {@link #getModelMeasurements()}, {@link #getUserResults(long)}, and
 * {@link #cleanup()}); this class drives them in order from {@link #call()}.
 */
public abstract class TrainTestJob implements Callable<Void> {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestJob.class);

    /** The algorithm under evaluation; used for logging and passed to metric accumulators. */
    protected final Attributed algorithmInfo;
    /** The train-test data set; its test data supplies the users to evaluate. */
    protected final TTDataSet dataSet;
    /** The metrics and output channels to measure. */
    protected final MeasurementSuite measurements;
    /** The table writers that receive this job's output rows. */
    private final ExperimentOutputs output;
    /** The task that owns this job; its project's event bus receives job events. */
    private final TrainTestEvalTask task;

    /**
     * Create a new job.
     *
     * @param trainTestEvalTask the task this job belongs to
     * @param attributed        the algorithm to evaluate
     * @param tTDataSet         the data set to evaluate on
     * @param measurementSuite  the measurements to take
     * @param experimentOutputs the outputs to write results to
     */
    public TrainTestJob(TrainTestEvalTask trainTestEvalTask, @Nonnull Attributed attributed, @Nonnull TTDataSet tTDataSet, @Nonnull MeasurementSuite measurementSuite, @Nonnull ExperimentOutputs experimentOutputs) {
        this.task = trainTestEvalTask;
        this.algorithmInfo = attributed;
        this.dataSet = tTDataSet;
        this.measurements = measurementSuite;
        this.output = experimentOutputs;
    }

    /**
     * Get the task this job belongs to.
     *
     * @return the task passed to the constructor
     */
    public TrainTestEvalTask getTask() {
        return this.task;
    }

    /**
     * Run the evaluation.
     *
     * @return always {@code null}
     * @throws IOException               if writing output fails
     * @throws RecommenderBuildException if the recommender cannot be built
     */
    @Override
    public Void call() throws IOException, RecommenderBuildException {
        runEvaluation();
        return null;
    }

    /**
     * Build the recommender, test every user, and write the aggregate results,
     * posting {@code started} and then either {@code finished} or {@code failed}
     * job events to the project's event bus.
     *
     * <p>{@link #cleanup()} is invoked exactly once, whether or not the
     * evaluation succeeded, and the closer is closed even if cleanup throws.
     * A failure during cleanup propagates to the caller but is not reported
     * as a {@code failed} event, since {@code finished} may already have been
     * posted by then.
     *
     * @throws IOException               if writing output fails
     * @throws RecommenderBuildException if the recommender cannot be built
     */
    private void runEvaluation() throws IOException, RecommenderBuildException {
        EventBus eventBus = this.task.getProject().getEventBus();
        eventBus.post(JobEvents.started(this));
        Closer closer = Closer.create();
        try {
            TableWriter userWriter = this.output.getUserWriter();

            logger.info("Building {} on {}", this.algorithmInfo, this.dataSet);
            StopWatch buildTimer = new StopWatch();
            buildTimer.start();
            buildRecommender();
            buildTimer.stop();
            logger.info("Built {} in {}", this.algorithmInfo.getName(), buildTimer);

            logger.info("Measuring {} on {}", this.algorithmInfo.getName(), this.dataSet.getName());
            List<Object> modelRow = collectModelMeasurements();

            logger.info("Testing {}", this.algorithmInfo.getName());
            StopWatch testTimer = new StopWatch();
            testTimer.start();
            List<TestUserMetricAccumulator> accumulators = Lists.newArrayList();
            for (TestUserMetric metric : this.measurements.getTestUserMetrics()) {
                accumulators.add(metric.makeAccumulator(this.algorithmInfo, this.dataSet));
            }
            evaluateTestUsers(userWriter, accumulators);
            testTimer.stop();
            logger.info("Tested {} in {}", this.algorithmInfo.getName(), testTimer);

            writeMetricValues(buildTimer, testTimer, modelRow, accumulators);
            eventBus.post(JobEvents.finished(this));
        } catch (Throwable th) {
            eventBus.post(JobEvents.failed(this, th));
            // Closer.rethrow propagates IOException, RecommenderBuildException and
            // unchecked throwables as-is; any other checked throwable (such as the
            // InterruptedException raised on interruption) is wrapped unchecked.
            throw closer.rethrow(th, RecommenderBuildException.class);
        } finally {
            try {
                cleanup();
            } finally {
                closer.close();
            }
        }
    }

    /**
     * Collect the model measurement columns for the aggregate result row.
     *
     * @return a fresh list holding the values from {@link #getModelMeasurements()},
     *         or one {@code null} per model column when that method returns {@code null}
     */
    private List<Object> collectModelMeasurements() {
        List<Object> row = Lists.newArrayList();
        List<Object> modelMeasurements = getModelMeasurements();
        if (modelMeasurements == null) {
            // No measurements available: pad so the result row keeps its column layout.
            int columnCount = this.measurements.getModelColumnCount();
            for (int i = 0; i < columnCount; i++) {
                row.add(null);
            }
        } else {
            assert modelMeasurements.size() == this.measurements.getModelColumnCount();
            row.addAll(modelMeasurements);
        }
        return row;
    }

    /**
     * Evaluate every user in the test data: feed the user's results to each
     * accumulator, write the per-user row (when a user writer is configured),
     * and write the user's predictions and recommendations.
     *
     * @param userWriter   the per-user output table, or {@code null} to skip per-user rows
     * @param accumulators the metric accumulators to update with each user's results
     * @throws IOException          if writing predictions or recommendations fails
     * @throws InterruptedException if the current thread was interrupted between users
     */
    private void evaluateTestUsers(TableWriter userWriter, List<TestUserMetricAccumulator> accumulators) throws IOException, InterruptedException {
        List<Object> userRow = Lists.newArrayList();
        LongIterator users = this.dataSet.getTestData().getUserDAO().getUserIds().iterator();
        while (users.hasNext()) {
            if (Thread.interrupted()) {
                throw new InterruptedException("eval job interrupted");
            }
            long userId = users.nextLong();
            userRow.add(Long.valueOf(userId));
            TestUser userResults = getUserResults(userId);
            for (TestUserMetricAccumulator accumulator : accumulators) {
                List<Object> values = accumulator.evaluate(userResults);
                if (values != null) {
                    userRow.addAll(values);
                }
            }
            if (userWriter != null) {
                try {
                    userWriter.writeRow(userRow);
                } catch (IOException e) {
                    // Deliberately surfaced unchecked, as this code always has;
                    // callers may rely on the exception type.
                    throw new RuntimeException("error writing user row", e);
                }
            }
            // The row buffer is reused for the next user.
            userRow.clear();
            writePredictions(userResults);
            writeRecommendations(userResults);
        }
    }

    /**
     * Build the recommender to evaluate. Called once, before any measurement;
     * its duration is reported as the build time.
     *
     * @throws RecommenderBuildException if the recommender cannot be built
     */
    protected abstract void buildRecommender() throws RecommenderBuildException;

    /**
     * Get the model measurements for the built recommender. Called once, after
     * {@link #buildRecommender()}.
     *
     * @return one value per model column, or {@code null} if no measurements are
     *         available (the columns are then filled with {@code null})
     */
    protected abstract List<Object> getModelMeasurements();

    /**
     * Get the test results for one user.
     *
     * @param j the ID of the test user
     * @return the user's results, passed to the metrics and output writers
     */
    protected abstract TestUser getUserResults(long j);

    /**
     * Release any resources held by the job. Called once at the end of the
     * evaluation, whether it succeeded or failed.
     */
    protected abstract void cleanup();

    /**
     * Write one row per test rating of the user to the prediction table. Does
     * nothing if there is no prediction writer or the user has no predictions.
     *
     * <p>Columns are: user, item, rating, prediction, then one column per
     * prediction channel; missing predictions and channel values are written
     * as {@code null}.
     *
     * @param testUser the user whose predictions to write
     * @throws IOException if writing a row fails
     */
    private void writePredictions(TestUser testUser) throws IOException {
        TableWriter predictionWriter = this.output.getPredictionWriter();
        if (predictionWriter == null) {
            return;
        }
        SparseVector predictions = testUser.getPredictions();
        if (predictions == null) {
            return;
        }
        SparseVector testRatings = testUser.getTestRatings();
        // The row buffer is reused for every rating; only the user column stays fixed.
        String[] row = new String[predictionWriter.getLayout().getColumnCount()];
        row[0] = Long.toString(testUser.getUserId());
        for (VectorEntry rating : testRatings.fast()) {
            long item = rating.getKey();
            row[1] = Long.toString(item);
            row[2] = Double.toString(rating.getValue());
            row[3] = predictions.containsKey(item) ? Double.toString(predictions.get(item)) : null;
            int column = 4;
            for (Pair<Symbol, String> channel : this.measurements.getPredictionChannels()) {
                Symbol symbol = channel.getLeft();
                if (predictions.hasChannelVector(symbol) && predictions.getChannelVector(symbol).containsKey(item)) {
                    row[column] = Double.toString(predictions.getChannelVector(symbol).get(item));
                } else {
                    row[column] = null;
                }
                column++;
            }
            predictionWriter.writeRow(row);
        }
    }

    /**
     * Write one row per recommended item to the recommendation table. Does
     * nothing if there is no recommendation writer or the user has no
     * recommendations.
     *
     * <p>Columns are: user, item, 1-based rank, score.
     *
     * @param testUser the user whose recommendations to write
     * @throws IOException if writing a row fails
     */
    private void writeRecommendations(TestUser testUser) throws IOException {
        TableWriter recommendationWriter = this.output.getRecommendationWriter();
        if (recommendationWriter == null) {
            return;
        }
        List<ScoredId> recommendations = testUser.getRecommendations(-1, ItemSelectors.allItems(), ItemSelectors.trainingItems());
        if (recommendations == null) {
            return;
        }
        String[] row = new String[recommendationWriter.getLayout().getColumnCount()];
        row[0] = Long.toString(testUser.getUserId());
        int rank = 1;
        for (ScoredId recommendation : CollectionUtils.fast(recommendations)) {
            row[1] = Long.toString(recommendation.getId());
            row[2] = String.valueOf(rank);
            rank++;
            row[3] = Double.toString(recommendation.getScore());
            recommendationWriter.writeRow(row);
        }
    }

    /**
     * Write the aggregate result row: build time, test time, the model
     * measurements, then the final results of each accumulator in order.
     *
     * @param stopWatch  the timer that measured the recommender build
     * @param stopWatch2 the timer that measured the user testing
     * @param list       the model measurement columns
     * @param list2      the accumulators whose final results to append
     * @throws IOException if writing the row fails
     */
    private void writeMetricValues(StopWatch stopWatch, StopWatch stopWatch2, List<Object> list, List<TestUserMetricAccumulator> list2) throws IOException {
        TableWriter resultsWriter = this.output.getResultsWriter();
        List<Object> row = Lists.newArrayList();
        row.add(Long.valueOf(stopWatch.getTime()));
        row.add(Long.valueOf(stopWatch2.getTime()));
        row.addAll(list);
        for (TestUserMetricAccumulator accumulator : list2) {
            row.addAll(accumulator.finalResults());
        }
        resultsWriter.writeRow(row);
    }

    /**
     * Describe this job by its algorithm and data set.
     *
     * @return a string of the form {@code "test <algorithm> on <data set>"}
     */
    @Override
    public String toString() {
        return String.format("test %s on %s", this.algorithmInfo, this.dataSet);
    }
}
