package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.inject.Provider;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.data.snapshot.PreferenceSnapshot;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.ExternalAlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.LenskitAlgorithmInstance;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.Metric;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.parallel.TaskGroupRunner;
import org.grouplens.lenskit.util.table.Table;
import org.grouplens.lenskit.util.table.TableBuilder;
import org.grouplens.lenskit.util.table.TableLayout;
import org.grouplens.lenskit.util.table.TableLayoutBuilder;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.MultiplexedTableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalTask.class */
public class TrainTestEvalTask extends AbstractTask<Table> {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalTask.class);
    // --- Configuration, populated through the add*/set* builder methods ---
    // Train-test data sets; makeJobGroups() creates one list of jobs per entry.
    private List<TTDataSet> dataSets;
    // Algorithms to evaluate; every algorithm is paired with every data set in makeJobs().
    private List<AlgorithmInstance> algorithms;
    // Per-user test metrics; they contribute columns to the aggregate and per-user layouts.
    private List<TestUserMetric> metrics;
    // (channel symbol, column label) pairs; each adds one column to the prediction layout.
    private List<Pair<Symbol, String>> predictChannels;
    // When true, jobs stay grouped per data set; when false, all jobs are merged into one group.
    private boolean isolate;
    // Aggregate CSV destination; the constructor defaults it to "train-test-results.csv".
    // A null value means no aggregate CSV writer is opened.
    private File outputFile;
    // Optional per-user CSV destination; null (the default) means no writer is opened.
    private File userOutputFile;
    // Optional prediction CSV destination; null (the default) means no writer is opened.
    private File predictOutputFile;
    // Passed to every TrainTestEvalJob; defaults to 5.
    private int numRecs;
    // --- Run state, set by setupTableLayouts() / prepareEval() and cleared by cleanUp() ---
    // Number of columns produced by layoutCommonColumns(); size of the prefix array in prefixTable().
    private int commonColumnCount;
    private TableLayout outputLayout;
    private TableLayout userLayout;
    private TableLayout predictLayout;
    // Multiplexed aggregate writer; non-null only between prepareEval() and cleanUp().
    private TableWriter output;
    // In-memory copy of the aggregate output; built and returned by perform().
    private TableBuilder outputInMemory;
    private TableWriter userOutput;
    private TableWriter predictOutput;
    // Data-set attribute name -> index of its column among the common columns.
    private Map<String, Integer> dataColumns;
    // Algorithm attribute name -> index of its column among the common columns.
    private Map<String, Integer> algoColumns;
    // Assigned from 'metrics' in setupTableLayouts(); these get start/finishEvaluation callbacks.
    private List<TestUserMetric> predictMetrics;
    // Layout containing only the common (algorithm + attribute) columns.
    private TableLayout masterLayout;
    // Model-level metrics; they contribute columns to the aggregate layout.
    private List<ModelMetric> modelMetrics;

    /**
     * Creates a task using the default task name {@code "train-test"}.
     */
    public TrainTestEvalTask() {
        this("train-test");
    }

    /**
     * Creates a task with the given name and default settings: empty data set,
     * algorithm, metric and channel lists, aggregate output written to
     * {@code train-test-results.csv}, 5 recommendations, and no isolation.
     *
     * @param str the task name, forwarded to the superclass constructor
     */
    public TrainTestEvalTask(String str) {
        super(str);
        // Explicit element types instead of raw collections.
        dataSets = new LinkedList<TTDataSet>();
        algorithms = new LinkedList<AlgorithmInstance>();
        metrics = new LinkedList<TestUserMetric>();
        modelMetrics = new LinkedList<ModelMetric>();
        predictChannels = new LinkedList<Pair<Symbol, String>>();
        outputFile = new File("train-test-results.csv");
        numRecs = 5;
        isolate = false;
    }

    /**
     * Adds a train-test data set to evaluate on.
     *
     * @param tTDataSet the data set to append
     * @return this task, for chaining
     */
    public TrainTestEvalTask addDataset(TTDataSet tTDataSet) {
        dataSets.add(tTDataSet);
        return this;
    }

    /**
     * Adds a LensKit algorithm to the list of algorithms to evaluate.
     *
     * @param lenskitAlgorithmInstance the algorithm to append
     * @return this task, for chaining
     */
    public TrainTestEvalTask addAlgorithm(LenskitAlgorithmInstance lenskitAlgorithmInstance) {
        algorithms.add(lenskitAlgorithmInstance);
        return this;
    }

    /**
     * Adds an external algorithm to the same list that holds the LensKit algorithms.
     *
     * @param externalAlgorithmInstance the algorithm to append
     * @return this task, for chaining
     */
    public TrainTestEvalTask addExternalAlgorithm(ExternalAlgorithmInstance externalAlgorithmInstance) {
        algorithms.add(externalAlgorithmInstance);
        return this;
    }

    /**
     * Adds a per-user test metric.
     *
     * @param testUserMetric the metric to append
     * @return this task, for chaining
     */
    public TrainTestEvalTask addMetric(TestUserMetric testUserMetric) {
        metrics.add(testUserMetric);
        return this;
    }

    /**
     * Instantiates a metric class through its no-argument constructor and adds the instance.
     *
     * @param cls the metric class to instantiate
     * @return this task, for chaining
     * @throws IllegalAccessException if the class or its no-argument constructor is not accessible
     * @throws InstantiationException if the class cannot be instantiated
     */
    public TrainTestEvalTask addMetric(Class<? extends TestUserMetric> cls) throws IllegalAccessException, InstantiationException {
        TestUserMetric instance = cls.newInstance();
        return addMetric(instance);
    }

    /**
     * Adds a model metric backed by a {@link FunctionMultiModelMetric} built from the arguments.
     *
     * @param file     passed to the metric; presumably its output file (confirm against
     *                 {@code FunctionMultiModelMetric})
     * @param list     passed to the metric; presumably its column labels
     * @param function function applied to a recommender, yielding a list of rows
     * @return this task, for chaining
     */
    public TrainTestEvalTask addMultiMetric(File file, List<String> list, Function<Recommender, List<List<Object>>> function) {
        ModelMetric metric = new FunctionMultiModelMetric(file, list, function);
        modelMetrics.add(metric);
        return this;
    }

    /**
     * Adds a model metric backed by a {@link FunctionModelMetric} built from the arguments.
     *
     * @param list     passed to the metric; presumably its column labels (confirm against
     *                 {@code FunctionModelMetric})
     * @param function function applied to a recommender, yielding a list of values
     * @return this task, for chaining
     */
    public TrainTestEvalTask addMetric(List<String> list, Function<Recommender, List<Object>> function) {
        ModelMetric metric = new FunctionModelMetric(list, function);
        modelMetrics.add(metric);
        return this;
    }

    /**
     * Adds a prediction channel whose column label defaults to the symbol's name.
     *
     * @param symbol the channel symbol; must not be null
     * @return this task, for chaining
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol symbol) {
        // A null label makes the two-argument overload fall back to symbol.getName().
        return addWritePredictionChannel(symbol, null);
    }

    /**
     * Adds a prediction channel together with the label of the column it is written to.
     *
     * @param symbol the channel symbol; must not be null
     * @param str    the column label, or {@code null} to use the symbol's name
     * @return this task, for chaining
     * @throws NullPointerException if {@code symbol} is null
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol symbol, @Nullable String str) {
        Preconditions.checkNotNull(symbol, "channel is null");
        String label = (str != null) ? str : symbol.getName();
        predictChannels.add(Pair.of(symbol, label));
        return this;
    }

    /**
     * Sets the file the aggregate results are written to.
     *
     * @param file the output file; when null, no aggregate CSV writer is opened at run time
     * @return this task, for chaining
     */
    public TrainTestEvalTask setOutput(File file) {
        outputFile = file;
        return this;
    }

    /**
     * Sets the aggregate output file by path name.
     *
     * @param str path of the output file
     * @return this task, for chaining
     */
    public TrainTestEvalTask setOutput(String str) {
        File target = new File(str);
        return setOutput(target);
    }

    /**
     * Sets the file the per-user results are written to.
     *
     * @param file the output file; when null, no per-user writer is opened at run time
     * @return this task, for chaining
     */
    public TrainTestEvalTask setUserOutput(File file) {
        userOutputFile = file;
        return this;
    }

    /**
     * Sets the per-user output file by path name.
     *
     * @param str path of the output file
     * @return this task, for chaining
     */
    public TrainTestEvalTask setUserOutput(String str) {
        File target = new File(str);
        return setUserOutput(target);
    }

    /**
     * Sets the file the predictions are written to.
     *
     * @param file the output file; when null, no prediction writer is opened at run time
     * @return this task, for chaining
     */
    public TrainTestEvalTask setPredictOutput(File file) {
        predictOutputFile = file;
        return this;
    }

    /**
     * Sets the prediction output file by path name.
     *
     * @param str path of the output file
     * @return this task, for chaining
     */
    public TrainTestEvalTask setPredictOutput(String str) {
        File target = new File(str);
        return setPredictOutput(target);
    }

    /**
     * Controls how jobs are grouped: when {@code true}, the jobs of each data set form
     * their own group; when {@code false}, all jobs are merged into a single group.
     *
     * @param z the new isolation flag
     * @return this task, for chaining
     */
    public TrainTestEvalTask setIsolate(boolean z) {
        isolate = z;
        return this;
    }

    /**
     * @return the live (not copied) list of configured data sets
     */
    List<TTDataSet> dataSources() {
        return dataSets;
    }

    /**
     * @return the live (not copied) list of configured algorithms
     */
    List<AlgorithmInstance> getAlgorithms() {
        return algorithms;
    }

    /**
     * @return the live (not copied) list of configured per-user test metrics
     */
    List<TestUserMetric> getMetrics() {
        return metrics;
    }

    /**
     * @return the live (not copied) list of (channel symbol, column label) pairs
     */
    List<Pair<Symbol, String>> getPredictionChannels() {
        return predictChannels;
    }

    /**
     * @return the aggregate output file; may be null
     */
    File getOutput() {
        return outputFile;
    }

    /**
     * @return the prediction output file; null unless one was set
     */
    File getPredictOutput() {
        return predictOutputFile;
    }

    /**
     * @return the recommendation count handed to each evaluation job (5 unless changed)
     */
    public int getNumRecs() {
        return numRecs;
    }

    /**
     * Sets the recommendation count handed to each evaluation job.
     *
     * @param i the new count
     * @return this task, for chaining
     */
    public TrainTestEvalTask setNumRecs(int i) {
        numRecs = i;
        return this;
    }

    /**
     * Runs the evaluation: builds the job groups and table layouts, opens the output
     * writers, runs every job, and returns the aggregate results as an in-memory table.
     *
     * <p>Resource handling follows the usage pattern documented for Guava's
     * {@link Closer}: failures are routed through {@code rethrow} so that an exception
     * raised while closing the writers is recorded against the primary failure instead
     * of replacing it, and the closer is closed exactly once. {@link #cleanUp()} runs
     * exactly once, and only if {@code prepareEval} completed.
     *
     * @return the aggregate result table
     * @throws TaskExecutionException if a job fails or an I/O error occurs
     */
    @Override // org.grouplens.lenskit.eval.AbstractTask
    public Table perform() throws TaskExecutionException {
        List<List<TrainTestEvalJob>> jobGroups = makeJobGroups();
        setupTableLayouts();
        logger.info("Starting evaluation");
        Closer closer = Closer.create();
        try {
            try {
                prepareEval(closer);
                try {
                    runEvaluations(jobGroups);
                } finally {
                    // Runs once whether or not the jobs succeeded.
                    cleanUp();
                }
            } catch (Throwable th) {
                // Registers th as the primary failure so close() can suppress its own errors.
                throw closer.rethrow(th, TaskExecutionException.class);
            } finally {
                closer.close();
            }
        } catch (IOException e) {
            throw new TaskExecutionException("I/O error", e);
        }
        // Built only after all writers have been closed, as before.
        return this.outputInMemory.m52build();
    }

    /**
     * Executes the job groups using the thread count configured on the project.
     *
     * <p>With exactly one thread, every job runs inline on the calling thread. Otherwise
     * a fixed-size pool is created and the groups are processed one at a time: each
     * group is submitted as a whole and waited for before the next group starts.
     *
     * @param list the job groups to run, in order
     * @throws TaskExecutionException if any job fails; a {@code TrainTestJobException}
     *         delivered by a job (inline, or as the cause of an {@link ExecutionException})
     *         is unwrapped to its cause first
     */
    private void runEvaluations(List<List<TrainTestEvalJob>> list) throws TaskExecutionException {
        final int nthreads = getProject().getConfig().getThreadCount();
        logger.info("Running evaluator with {} threads", nthreads);
        if (nthreads == 1) {
            // Single-threaded: no executor, just run everything in sequence.
            for (TrainTestEvalJob job : Iterables.concat(list)) {
                try {
                    job.run();
                } catch (TrainTestJobException wrapped) {
                    throw new TaskExecutionException(wrapped.getCause());
                } catch (RuntimeException failure) {
                    throw new TaskExecutionException(failure);
                }
            }
        } else {
            ExecutorService pool = Executors.newFixedThreadPool(nthreads);
            try {
                for (List<TrainTestEvalJob> group : list) {
                    TaskGroupRunner runner = TaskGroupRunner.create(pool);
                    runner.submitAll(group);
                    try {
                        runner.waitForAll();
                    } catch (ExecutionException ee) {
                        Throwable cause = ee.getCause();
                        throw new TaskExecutionException(
                                cause instanceof TrainTestJobException ? cause.getCause() : cause);
                    } catch (TrainTestJobException jobFailure) {
                        // Thrown directly by the runner: wrapped as-is, not unwrapped.
                        throw new TaskExecutionException(jobFailure);
                    }
                }
            } finally {
                pool.shutdown();
            }
        }
    }

    /**
     * Builds the evaluation jobs, grouped according to the isolation flag.
     *
     * @return one job list per data set when isolation is on; otherwise a singleton
     *         list holding all jobs of all data sets in one group
     */
    List<List<TrainTestEvalJob>> makeJobGroups() {
        List<List<TrainTestEvalJob>> perDataSet = Lists.newArrayList();
        for (TTDataSet dataSet : dataSets) {
            perDataSet.add(makeJobs(dataSet));
        }
        if (isolate) {
            return perDataSet;
        }
        List<TrainTestEvalJob> merged = Lists.newArrayList(Iterables.concat(perDataSet));
        return Collections.singletonList(merged);
    }

    /**
     * Creates one evaluation job per configured algorithm for the given data set.
     *
     * <p>All jobs of a data set share one preference-snapshot provider. The table
     * writers are handed over as suppliers, so they are looked up (and prefixed with
     * the algorithm/data-set columns) only when a job asks for them.
     *
     * @param tTDataSet the data set the jobs run on
     * @return the jobs, in algorithm order
     */
    private List<TrainTestEvalJob> makeJobs(TTDataSet tTDataSet) {
        List<TrainTestEvalJob> jobs = Lists.newArrayListWithCapacity(algorithms.size());
        Provider<PreferenceSnapshot> snapshot = SharedPreferenceSnapshot.provider(tTDataSet);
        for (AlgorithmInstance algorithm : algorithms) {
            Function<TableWriter, TableWriter> prefix = prefixFunction(algorithm, tTDataSet);
            Supplier<TableWriter> aggregateWriter = Suppliers.compose(prefix, outputTableSupplier());
            Supplier<TableWriter> userWriter = Suppliers.compose(prefix, userTableSupplier());
            Supplier<TableWriter> predictWriter = Suppliers.compose(prefix, predictTableSupplier());
            jobs.add(new TrainTestEvalJob(algorithm, metrics, modelMetrics, predictChannels,
                    tTDataSet, snapshot, aggregateWriter, userWriter, predictWriter, numRecs));
        }
        return jobs;
    }

    /**
     * Computes the table layouts for this run: the common-column master layout, then
     * the aggregate, per-user and prediction layouts derived from the same builder.
     * Also records the common column count and assigns {@code predictMetrics}.
     */
    private void setupTableLayouts() {
        TableLayoutBuilder common = new TableLayoutBuilder();
        layoutCommonColumns(common);
        masterLayout = common.m54build();
        commonColumnCount = common.getColumnCount();
        // Each layout* method works on its own clone of the common builder.
        outputLayout = layoutAggregateOutput(common);
        userLayout = layoutUserTable(common);
        predictLayout = layoutPredictionTable(common);
        predictMetrics = metrics;
    }

    /**
     * @return the layout containing only the common columns; null until the table
     *         layouts have been set up
     */
    public TableLayout getMasterLayout() {
        return masterLayout;
    }

    /**
     * Adds the columns shared by every output table: {@code "Algorithm"}, then one
     * column per distinct data-set attribute name, then one per distinct algorithm
     * attribute name. The column index of each attribute is recorded in
     * {@code dataColumns} / {@code algoColumns}.
     *
     * @param tableLayoutBuilder the builder to add the columns to
     */
    private void layoutCommonColumns(TableLayoutBuilder tableLayoutBuilder) {
        tableLayoutBuilder.addColumn("Algorithm");

        dataColumns = new HashMap<String, Integer>();
        for (TTDataSet dataSet : dataSets) {
            for (String attribute : dataSet.getAttributes().keySet()) {
                if (dataColumns.containsKey(attribute)) {
                    continue;
                }
                // The current column count is the index the new column will receive.
                dataColumns.put(attribute, tableLayoutBuilder.getColumnCount());
                tableLayoutBuilder.addColumn(attribute);
            }
        }

        algoColumns = new HashMap<String, Integer>();
        for (AlgorithmInstance algorithm : algorithms) {
            for (String attribute : algorithm.getAttributes().keySet()) {
                if (algoColumns.containsKey(attribute)) {
                    continue;
                }
                algoColumns.put(attribute, tableLayoutBuilder.getColumnCount());
                tableLayoutBuilder.addColumn(attribute);
            }
        }
    }

    /**
     * Builds the aggregate output layout: the given columns followed by
     * {@code "BuildTime"}, {@code "TestTime"}, the model metric columns, and the
     * aggregate columns of every test metric that reports any.
     *
     * @param tableLayoutBuilder builder holding the common columns; it is cloned, not modified
     * @return the aggregate layout
     */
    private TableLayout layoutAggregateOutput(TableLayoutBuilder tableLayoutBuilder) {
        TableLayoutBuilder layout = tableLayoutBuilder.m53clone();
        layout.addColumn("BuildTime");
        layout.addColumn("TestTime");
        for (ModelMetric modelMetric : modelMetrics) {
            for (String label : modelMetric.getColumnLabels()) {
                layout.addColumn(label);
            }
        }
        for (TestUserMetric metric : metrics) {
            List<String> labels = metric.getColumnLabels();
            if (labels == null) {
                // This metric contributes no aggregate columns.
                continue;
            }
            for (String label : labels) {
                layout.addColumn(label);
            }
        }
        return layout.m54build();
    }

    /**
     * Builds the per-user layout: the given columns followed by {@code "User"} and
     * the per-user columns of every test metric that reports any.
     *
     * @param tableLayoutBuilder builder holding the common columns; it is cloned, not modified
     * @return the per-user layout
     */
    private TableLayout layoutUserTable(TableLayoutBuilder tableLayoutBuilder) {
        TableLayoutBuilder layout = tableLayoutBuilder.m53clone();
        layout.addColumn("User");
        for (TestUserMetric metric : metrics) {
            List<String> labels = metric.getUserColumnLabels();
            if (labels == null) {
                // This metric contributes no per-user columns.
                continue;
            }
            for (String label : labels) {
                layout.addColumn(label);
            }
        }
        return layout.m54build();
    }

    /**
     * Builds the prediction layout: the given columns followed by {@code "User"},
     * {@code "Item"}, {@code "Rating"}, {@code "Prediction"} and one column per
     * configured prediction channel, labelled with the channel's column label.
     *
     * @param tableLayoutBuilder builder holding the common columns; it is cloned, not modified
     * @return the prediction layout
     */
    private TableLayout layoutPredictionTable(TableLayoutBuilder tableLayoutBuilder) {
        TableLayoutBuilder layout = tableLayoutBuilder.m53clone();
        layout.addColumn("User");
        layout.addColumn("Item");
        layout.addColumn("Rating");
        layout.addColumn("Prediction");
        for (Pair<Symbol, String> channel : predictChannels) {
            layout.addColumn(channel.getRight());
        }
        return layout.m54build();
    }

    /**
     * Opens the output writers and notifies all metrics that the evaluation starts.
     *
     * <p>Aggregate output always goes to an in-memory table builder and, when an
     * output file is configured, also to a CSV writer; both sit behind one multiplexed
     * writer. The per-user and prediction writers are opened only when their files are
     * configured. Every CSV writer is registered with {@code closer}.
     *
     * @param closer receives every writer that has to be closed after the run
     * @throws IOException declared for failures while opening an output file
     */
    private void prepareEval(Closer closer) throws IOException {
        logger.info("Starting evaluation");

        outputInMemory = new TableBuilder(outputLayout);
        List<TableWriter> aggregateSinks = new ArrayList<TableWriter>();
        aggregateSinks.add(outputInMemory);
        if (outputFile != null) {
            aggregateSinks.add(closer.register(CSVWriter.open(outputFile, outputLayout)));
        }
        output = new MultiplexedTableWriter(outputLayout, aggregateSinks);

        if (userOutputFile != null) {
            userOutput = closer.register(CSVWriter.open(userOutputFile, userLayout));
        }
        if (predictOutputFile != null) {
            predictOutput = closer.register(CSVWriter.open(predictOutputFile, predictLayout));
        }

        for (Object metric : Iterables.concat(predictMetrics, modelMetrics)) {
            ((Metric) metric).startEvaluation(this);
        }
    }

    /**
     * Notifies all metrics that the evaluation finished and drops the writer references.
     * The writers themselves are not closed here.
     *
     * @throws IOException declared in the signature; nothing in this body throws it
     *         unless a metric callback does
     * @throws IllegalStateException if no evaluation is running (the aggregate writer
     *         is null); note that the metrics are notified before this check
     */
    private void cleanUp() throws IOException {
        for (Object metric : Iterables.concat(predictMetrics, modelMetrics)) {
            ((Metric) metric).finishEvaluation();
        }
        if (output == null) {
            throw new IllegalStateException("evaluation not running");
        }
        logger.info("Evaluation finished");
        output = null;
        userOutput = null;
        predictOutput = null;
    }

    /**
     * Returns a supplier that resolves the aggregate output writer lazily, so it can
     * be created before the evaluation starts and queried while it runs.
     *
     * @return a supplier whose {@code get()} returns the current aggregate writer and
     *         throws {@link IllegalStateException} when no evaluation is running
     */
    @Nonnull
    Supplier<TableWriter> outputTableSupplier() {
        return new Supplier<TableWriter>() { // from class: org.grouplens.lenskit.eval.traintest.TrainTestEvalTask.1
            // Must be named get(): the decompiler's renamed method left Supplier.get() unimplemented.
            @Override
            public TableWriter get() {
                Preconditions.checkState(TrainTestEvalTask.this.output != null, "evaluation not running");
                return TrainTestEvalTask.this.output;
            }
        };
    }

    /**
     * Returns a supplier that resolves the prediction writer lazily.
     *
     * @return a supplier whose {@code get()} returns the current prediction writer,
     *         or null when none is open
     */
    @Nonnull
    Supplier<TableWriter> predictTableSupplier() {
        return new Supplier<TableWriter>() { // from class: org.grouplens.lenskit.eval.traintest.TrainTestEvalTask.2
            // Must be named get(): the decompiler's renamed method left Supplier.get() unimplemented.
            @Override
            public TableWriter get() {
                return TrainTestEvalTask.this.predictOutput;
            }
        };
    }

    /**
     * Returns a supplier that resolves the per-user writer lazily.
     *
     * @return a supplier whose {@code get()} returns the current per-user writer,
     *         or null when none is open
     */
    @Nonnull
    Supplier<TableWriter> userTableSupplier() {
        return new Supplier<TableWriter>() { // from class: org.grouplens.lenskit.eval.traintest.TrainTestEvalTask.3
            // Must be named get(): the decompiler's renamed method left Supplier.get() unimplemented.
            @Override
            public TableWriter get() {
                return TrainTestEvalTask.this.userOutput;
            }
        };
    }

    /**
     * Returns a function that wraps a table writer via
     * {@link #prefixTable(TableWriter, AlgorithmInstance, TTDataSet)} for the given
     * algorithm and data set.
     *
     * @param algorithmInstance the algorithm whose name and attributes form the prefix
     * @param tTDataSet         the data set whose attributes form the prefix
     * @return the wrapping function; it maps a null writer to null
     */
    public Function<TableWriter, TableWriter> prefixFunction(final AlgorithmInstance algorithmInstance, final TTDataSet tTDataSet) {
        return new Function<TableWriter, TableWriter>() { // from class: org.grouplens.lenskit.eval.traintest.TrainTestEvalTask.4
            @Override
            public TableWriter apply(TableWriter tableWriter) {
                return prefixTable(tableWriter, algorithmInstance, tTDataSet);
            }
        };
    }

    /**
     * Wraps a writer so that the common columns are filled in for the given algorithm
     * and data set. Column 0 holds the algorithm name; every data-set and algorithm
     * attribute goes to the column index recorded for its name when the common
     * columns were laid out. Columns without a value stay null.
     *
     * @param tableWriter       the writer to wrap; may be null
     * @param algorithmInstance the algorithm supplying the name and attributes
     * @param tTDataSet         the data set supplying attributes
     * @return the prefixed writer, or null if {@code tableWriter} is null
     */
    public TableWriter prefixTable(TableWriter tableWriter, AlgorithmInstance algorithmInstance, TTDataSet tTDataSet) {
        if (tableWriter == null) {
            return null;
        }
        Object[] prefix = new Object[commonColumnCount];
        prefix[0] = algorithmInstance.getName();
        for (Map.Entry<String, Object> attribute : tTDataSet.getAttributes().entrySet()) {
            int column = dataColumns.get(attribute.getKey());
            prefix[column] = attribute.getValue();
        }
        for (Map.Entry<String, Object> attribute : algorithmInstance.getAttributes().entrySet()) {
            int column = algoColumns.get(attribute.getKey());
            prefix[column] = attribute.getValue();
        }
        return TableWriters.prefixed(tableWriter, prefix);
    }
}
