package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
import com.google.common.eventbus.EventBus;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.eval.Attributed;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.Metric;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A single train-test evaluation job: builds one algorithm on one data set, measures every test
 * user with the configured metrics, and writes per-user and aggregate result rows.
 *
 * <p>Lifecycle events ({@code started}, {@code finished}, {@code failed}) are posted to the
 * project's event bus. Subclasses supply the algorithm-specific pieces via
 * {@link #buildRecommender()}, {@link #getUserResults(long)} and {@link #cleanup()}.
 */
public abstract class TrainTestJob implements Callable<Void> {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestJob.class);

    /** Number of users between progress log messages. */
    private static final int PROGRESS_INTERVAL = 100;

    protected final Attributed algorithmInfo;
    protected final TTDataSet dataSet;
    private final TrainTestEvalTask task;
    /** Output writers for the current run; only non-null while {@link #call()} is executing. */
    private ExperimentOutputs outputs;

    /**
     * Pairs a metric with its accumulator so per-user measurements and final results are
     * computed against the same context object.
     *
     * @param <A> the metric's accumulator/context type
     */
    public static class MetricWithAccumulator<A> {
        private final Metric<A> metric;
        private final A accumulator;

        public MetricWithAccumulator(Metric<A> metric, A a) {
            this.metric = metric;
            this.accumulator = a;
        }

        /**
         * Measures one test user with the wrapped metric.
         *
         * @param testUser the user to measure
         * @return the metric's per-user values (may be {@code null}; callers skip null results)
         */
        public List<Object> measureUser(TestUser testUser) {
            return metric.measureUser(testUser, accumulator);
        }

        public Metric<A> getMetric() {
            return metric;
        }

        public A getAccumulator() {
            return accumulator;
        }

        /** Returns the metric's aggregate results computed from the accumulator. */
        public List<Object> getResults() {
            return metric.getResults(accumulator);
        }
    }

    /**
     * @param trainTestEvalTask the owning evaluation task (supplies event bus and outputs)
     * @param attributed        the algorithm being evaluated
     * @param tTDataSet         the train/test data set to evaluate on
     */
    public TrainTestJob(TrainTestEvalTask trainTestEvalTask, @Nonnull Attributed attributed, @Nonnull TTDataSet tTDataSet) {
        this.task = trainTestEvalTask;
        this.algorithmInfo = attributed;
        this.dataSet = tTDataSet;
    }

    public TrainTestEvalTask getTask() {
        return task;
    }

    /**
     * Runs the evaluation.
     *
     * @return always {@code null}
     * @throws IOException               if writing outputs fails
     * @throws RecommenderBuildException if building the recommender fails
     */
    @Override
    public Void call() throws IOException, RecommenderBuildException {
        runEvaluation();
        return null;
    }

    /**
     * Builds the recommender, measures all test users, and writes the aggregate row.
     *
     * <p>{@link #cleanup()} runs exactly once, whether the job succeeds or fails. After it runs,
     * the outputs reference is cleared and the closer is closed even if cleanup itself throws.
     * Failures in the main body post a {@code failed} event before propagating.
     */
    private void runEvaluation() throws IOException, RecommenderBuildException {
        EventBus bus = task.getProject().getEventBus();
        bus.post(JobEvents.started(this));
        Closer closer = Closer.create();
        try {
            outputs = task.getOutputs().getPrefixed(algorithmInfo, dataSet);
            TableWriter userWriter = outputs.getUserWriter();
            // Extra columns inserted between the timings and the metric results; currently empty.
            List<Object> outputRow = Lists.newArrayList();

            logger.info("Building {} on {}", algorithmInfo, dataSet);
            StopWatch buildTimer = new StopWatch();
            buildTimer.start();
            buildRecommender();
            buildTimer.stop();
            logger.info("Built {} in {}", algorithmInfo.getName(), buildTimer);

            logger.info("Measuring {} on {}", algorithmInfo.getName(), dataSet.getName());
            StopWatch testTimer = new StopWatch();
            testTimer.start();

            List<MetricWithAccumulator<?>> accumulators = Lists.newArrayList();
            for (Metric<?> metric : outputs.getMetrics()) {
                accumulators.add(makeMetricAccumulator(metric));
            }

            LongSet testUsers = dataSet.getTestData().getUserDAO().getUserIds();
            NumberFormat pctFormat = NumberFormat.getPercentInstance();
            pctFormat.setMaximumFractionDigits(2);
            pctFormat.setMinimumFractionDigits(2);
            int nusers = testUsers.size();
            logger.info("Testing {} on {} ({} users)", new Object[]{algorithmInfo, dataSet, nusers});

            int ndone = 0;
            for (LongIterator iter = testUsers.iterator(); iter.hasNext(); ) {
                // NOTE(review): Thread.interrupted() clears the interrupt flag; the exception is
                // surfaced through the catch below instead.
                if (Thread.interrupted()) {
                    throw new InterruptedException("eval job interrupted");
                }
                evaluateUser(iter.nextLong(), accumulators, userWriter);
                ndone += 1;
                if (ndone % PROGRESS_INTERVAL == 0) {
                    logProgress(testTimer, ndone, nusers, pctFormat);
                }
            }
            testTimer.stop();
            logger.info("Tested {} in {}", algorithmInfo.getName(), testTimer);

            writeMetricValues(buildTimer, testTimer, outputRow, accumulators);
            bus.post(JobEvents.finished(this));
        } catch (Throwable th) {
            bus.post(JobEvents.failed(this, th));
            // Propagates IOException / RecommenderBuildException / unchecked as-is; wraps others.
            throw closer.rethrow(th, RecommenderBuildException.class);
        } finally {
            try {
                cleanup();
            } finally {
                outputs = null;
                closer.close();
            }
        }
    }

    /**
     * Measures a single test user with every metric and writes the per-user row.
     *
     * <p>Row layout: user ID, elapsed seconds, train history size, test history size, then each
     * metric's per-user values in order (metrics returning {@code null} add no columns).
     *
     * @param uid          the user to measure
     * @param accumulators the metric accumulators, in output column order
     * @param userWriter   the per-user writer, or {@code null} to skip writing
     */
    private void evaluateUser(long uid, List<MetricWithAccumulator<?>> accumulators, TableWriter userWriter) {
        List<Object> row = Lists.newArrayList();
        row.add(uid);
        row.add(null); // placeholder for the per-user time, filled in once measuring is done

        Stopwatch userTimer = Stopwatch.createStarted();
        TestUser user = getUserResults(uid);
        row.add(user.getTrainHistory().size());
        row.add(user.getTestHistory().size());
        for (MetricWithAccumulator<?> accum : accumulators) {
            List<Object> values = accum.measureUser(user);
            if (values != null) {
                row.addAll(values);
            }
        }
        userTimer.stop();
        row.set(1, userTimer.elapsed(TimeUnit.MILLISECONDS) * 0.001d);

        if (userWriter != null) {
            try {
                userWriter.writeRow(row);
            } catch (IOException e) {
                throw new RuntimeException("error writing user row", e);
            }
        }
    }

    /**
     * Logs how many users have been tested, the completed fraction, and an estimated time
     * remaining extrapolated from the average time per user so far.
     *
     * @param testTimer a running stopwatch started when testing began
     * @param ndone     users tested so far (positive)
     * @param nusers    total users to test
     * @param pctFormat formatter for the completed fraction
     */
    private static void logProgress(StopWatch testTimer, int ndone, int nusers, NumberFormat pctFormat) {
        testTimer.split();
        // Floating-point math: integer division would truncate to 0% and lose ETA precision.
        double millisPerUser = (double) testTimer.getSplitTime() / ndone;
        long remainingMillis = (long) ((nusers - ndone) * millisPerUser);
        logger.info("tested {} of {} users ({}), ETA {}",
                    new Object[]{ndone, nusers,
                                 pctFormat.format((double) ndone / nusers),
                                 DurationFormatUtils.formatDurationHMS(remainingMillis)});
    }

    /**
     * Creates a metric's context for this algorithm and data set and pairs it with the metric.
     */
    protected <A> MetricWithAccumulator<A> makeMetricAccumulator(Metric<A> metric) {
        return new MetricWithAccumulator<>(metric, metric.createContext(algorithmInfo, dataSet, null));
    }

    /** Builds the recommender under test; called once, before any user is measured. */
    protected abstract void buildRecommender() throws RecommenderBuildException;

    /**
     * Produces the test record for one user. Its train/test histories and the metrics' per-user
     * measurements are taken from it (presumably includes the recommender's output — confirm
     * in subclasses).
     *
     * @param j the user ID
     * @return the user's test record
     */
    protected abstract TestUser getUserResults(long j);

    /** Releases job resources; called exactly once when the job finishes or fails. */
    protected abstract void cleanup();

    /**
     * Writes the aggregate result row: build time (ms), test time (ms), the extra measurements,
     * then each metric's aggregate results in order.
     */
    private void writeMetricValues(StopWatch stopWatch, StopWatch stopWatch2, List<Object> list, List<MetricWithAccumulator<?>> list2) throws IOException {
        TableWriter resultsWriter = outputs.getResultsWriter();
        List<Object> row = Lists.newArrayList();
        row.add(stopWatch.getTime());
        row.add(stopWatch2.getTime());
        row.addAll(list);
        for (MetricWithAccumulator<?> accum : list2) {
            row.addAll(accum.getResults());
        }
        resultsWriter.writeRow(row);
    }

    @Override
    public String toString() {
        return String.format("test %s on %s", algorithmInfo, dataSet);
    }
}
