package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.grouplens.grapht.Component;
import org.grouplens.grapht.Dependency;
import org.grouplens.grapht.graph.DAGNode;
import org.grouplens.grapht.graph.MergePool;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.core.LenskitConfiguration;
import org.grouplens.lenskit.core.RecommenderConfigurationException;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.ExecutionInfo;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstance;
import org.grouplens.lenskit.eval.algorithm.AlgorithmInstanceBuilder;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.Metric;
import org.grouplens.lenskit.eval.traintest.FunctionMultiModelMetric;
import org.grouplens.lenskit.eval.traintest.JobGraph;
import org.grouplens.lenskit.eval.traintest.OutputPredictMetric;
import org.grouplens.lenskit.eval.traintest.OutputTopNMetric;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.util.parallel.TaskGraphExecutor;
import org.grouplens.lenskit.util.table.Table;
import org.grouplens.lenskit.util.table.TableBuilder;
import org.grouplens.lenskit.util.table.TableLayout;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.MultiplexedTableWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalTask.class */
/**
 * Train-test evaluation task.
 *
 * <p>Runs every configured LensKit algorithm and external algorithm against every configured
 * train-test data set, measures them with the configured metrics, and returns the aggregate
 * results as a {@link Table}. Aggregate results are also written as CSV to the output file
 * (default {@code train-test-results.csv}; see {@link #setOutput(File)}), and per-user results to
 * the user output file when one is set. Setting a predict or recommend output file adds a
 * corresponding output metric to the measurement suite.
 */
public class TrainTestEvalTask extends AbstractTask<Table> {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalTask.class);

    private final List<TTDataSet> dataSets;
    private final List<AlgorithmInstance> algorithms;
    private final List<ExternalAlgorithm> externalAlgorithms;
    private final List<MetricFactory> metrics;
    private final List<Pair<Symbol, String>> predictChannels;
    private boolean isolate;
    private boolean separateAlgorithms;
    private File outputFile;
    private File userOutputFile;
    private File predictOutputFile;
    private File recommendOutputFile;
    private File cacheDir;
    private File taskGraphFile;
    private File taskStatusFile;
    private boolean cacheAll;

    // Per-run state: only non-null while perform() is executing.
    private ExperimentSuite experiments;
    private MeasurementSuite measurements;
    private ExperimentOutputLayout layout;
    private ExperimentOutputs outputs;

    /** Creates a task named {@code train-test}. */
    public TrainTestEvalTask() {
        this("train-test");
    }

    /**
     * Creates a task with the given name.
     *
     * @param name the task name.
     */
    public TrainTestEvalTask(String name) {
        super(name);
        this.cacheAll = false;
        this.dataSets = Lists.newArrayList();
        this.algorithms = Lists.newArrayList();
        this.externalAlgorithms = Lists.newArrayList();
        this.metrics = Lists.newArrayList();
        this.predictChannels = Lists.newArrayList();
        this.outputFile = new File("train-test-results.csv");
        this.isolate = false;
    }

    /**
     * Adds a train-test data set to evaluate on.
     *
     * @param dataSet the data set.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addDataset(TTDataSet dataSet) {
        dataSets.add(dataSet);
        return this;
    }

    /**
     * Adds a LensKit algorithm to evaluate.
     *
     * @param algorithm the algorithm instance.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addAlgorithm(AlgorithmInstance algorithm) {
        algorithms.add(algorithm);
        return this;
    }

    /**
     * Adds a LensKit algorithm configured from a file via
     * {@link AlgorithmInstanceBuilder#configureFromFile}.
     *
     * @param config   the configuration map passed to the builder.
     * @param fileName the path of the configuration file.
     * @return this task, for chaining.
     * @throws IOException                       if the file cannot be read.
     * @throws RecommenderConfigurationException if the configuration is invalid.
     */
    public TrainTestEvalTask addAlgorithm(Map<String, Object> config, String fileName)
            throws IOException, RecommenderConfigurationException {
        AlgorithmInstance algorithm = new AlgorithmInstanceBuilder()
                .setProject(getProject())
                .configureFromFile(config, new File(fileName))
                .m6build();
        algorithms.add(algorithm);
        return this;
    }

    /**
     * Adds an external (non-LensKit) algorithm to evaluate.
     *
     * @param algorithm the external algorithm.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addExternalAlgorithm(ExternalAlgorithm algorithm) {
        externalAlgorithms.add(algorithm);
        return this;
    }

    /**
     * Adds a pre-instantiated metric.
     *
     * @param metric the metric.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addMetric(Metric metric) {
        metrics.add(MetricFactory.forMetric(metric));
        return this;
    }

    /**
     * Adds a metric by class, instantiating it with its no-argument constructor.
     *
     * @param metricClass the metric class.
     * @return this task, for chaining.
     * @throws IllegalAccessException if the constructor is not accessible.
     * @throws InstantiationException if the class cannot be instantiated.
     */
    public TrainTestEvalTask addMetric(Class<? extends Metric> metricClass)
            throws IllegalAccessException, InstantiationException {
        return addMetric(metricClass.newInstance());
    }

    /**
     * Adds a multi-row model metric computed by a function of the trained recommender.
     *
     * @param file     the file passed to {@link FunctionMultiModelMetric.Factory}.
     * @param columns  the column labels.
     * @param function the function producing rows from a recommender.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addMultiMetric(File file, List<String> columns,
                                            Function<Recommender, List<List<Object>>> function) {
        metrics.add(new FunctionMultiModelMetric.Factory(file, columns, function));
        return this;
    }

    /**
     * Adds a single-row model metric computed by a function of the trained recommender.
     *
     * @param columns  the column labels.
     * @param function the function producing a row from a recommender.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addMetric(List<String> columns, Function<Recommender, List<Object>> function) {
        addMetric(new FunctionModelMetric(columns, function));
        return this;
    }

    /**
     * Adds a prediction channel to write, labelled with the symbol's own name.
     *
     * @param channel the channel symbol.
     * @return this task, for chaining.
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channel) {
        return addWritePredictionChannel(channel, null);
    }

    /**
     * Adds a prediction channel to write.
     *
     * @param channel the channel symbol.
     * @param label   the output label, or {@code null} to use the symbol's name.
     * @return this task, for chaining.
     * @throws NullPointerException if {@code channel} is null.
     */
    public TrainTestEvalTask addWritePredictionChannel(@Nonnull Symbol channel, @Nullable String label) {
        Preconditions.checkNotNull(channel, "channel is null");
        String effectiveLabel = label != null ? label : channel.getName();
        predictChannels.add(Pair.of(channel, effectiveLabel));
        return this;
    }

    /** Sets the aggregate results CSV file ({@code null} disables it). */
    public TrainTestEvalTask setOutput(File file) {
        this.outputFile = file;
        return this;
    }

    /** Sets the aggregate results CSV file by path. */
    public TrainTestEvalTask setOutput(String path) {
        return setOutput(new File(path));
    }

    /** Sets the per-user results CSV file ({@code null} disables it). */
    public TrainTestEvalTask setUserOutput(File file) {
        this.userOutputFile = file;
        return this;
    }

    /** Sets the per-user results CSV file by path. */
    public TrainTestEvalTask setUserOutput(String path) {
        return setUserOutput(new File(path));
    }

    /** Sets the prediction output file; when non-null, a prediction output metric is added. */
    public TrainTestEvalTask setPredictOutput(File file) {
        this.predictOutputFile = file;
        return this;
    }

    /** Sets the prediction output file by path. */
    public TrainTestEvalTask setPredictOutput(String path) {
        return setPredictOutput(new File(path));
    }

    /** Sets the recommendation output file; when non-null, a top-N output metric is added. */
    public TrainTestEvalTask setRecommendOutput(File file) {
        this.recommendOutputFile = file;
        return this;
    }

    /** Sets the recommendation output file by path. */
    public TrainTestEvalTask setRecommendOutput(String path) {
        return setRecommendOutput(new File(path));
    }

    /** Sets the directory passed to the {@link ComponentCache}. */
    public TrainTestEvalTask setComponentCacheDirectory(File dir) {
        this.cacheDir = dir;
        return this;
    }

    /** Sets the component cache directory by path. */
    public TrainTestEvalTask setComponentCacheDirectory(String path) {
        return setComponentCacheDirectory(new File(path));
    }

    /** @return the component cache directory, possibly {@code null}. */
    public File getComponentCacheDirectory() {
        return cacheDir;
    }

    /** Sets whether all reachable component nodes are registered as shared with the cache. */
    public TrainTestEvalTask setCacheAllComponents(boolean cacheAll) {
        this.cacheAll = cacheAll;
        return this;
    }

    /** @return whether all components are registered with the cache. */
    public boolean getCacheAllComponents() {
        return cacheAll;
    }

    /**
     * Sets whether every data set is placed in its own isolation group.
     *
     * @deprecated Eval task isolation is deprecated; isolate data sets instead.
     */
    @Deprecated
    public TrainTestEvalTask setIsolate(boolean isolate) {
        logger.warn("Eval task isolation is deprecated. Isolate data sets instead.");
        this.isolate = isolate;
        return this;
    }

    /** Sets whether algorithm graphs are kept separate instead of merged within a data set. */
    public TrainTestEvalTask setSeparateAlgorithms(boolean separate) {
        this.separateAlgorithms = separate;
        return this;
    }

    List<TTDataSet> dataSources() {
        return dataSets;
    }

    List<AlgorithmInstance> getAlgorithms() {
        return algorithms;
    }

    List<ExternalAlgorithm> getExternalAlgorithms() {
        return externalAlgorithms;
    }

    List<MetricFactory> getMetricFactories() {
        return metrics;
    }

    List<Pair<Symbol, String>> getPredictionChannels() {
        return predictChannels;
    }

    File getOutput() {
        return outputFile;
    }

    /**
     * @return the prediction output file, possibly {@code null}.
     */
    // NOTE(review): the decompiler reported this as widened from package-private; kept public
    // so existing callers continue to link.
    public File getPredictOutput() {
        return predictOutputFile;
    }

    /**
     * @return the recommendation output file, possibly {@code null}.
     */
    // NOTE(review): the decompiler reported this as widened from package-private; kept public
    // so existing callers continue to link.
    public File getRecommendOutput() {
        return recommendOutputFile;
    }

    /** Sets the file to which the job graph description is written before running. */
    public TrainTestEvalTask setTaskGraphFile(File file) {
        this.taskGraphFile = file;
        return this;
    }

    /** Sets the task graph file by path. */
    public TrainTestEvalTask setTaskGraphFile(String path) {
        return setTaskGraphFile(new File(path));
    }

    /** @return the task graph file, possibly {@code null}. */
    public File getTaskGraphFile() {
        return taskGraphFile;
    }

    /** Sets the file handed to the {@link JobStatusWriter} registered on the event bus. */
    public TrainTestEvalTask setTaskStatusFile(File file) {
        this.taskStatusFile = file;
        return this;
    }

    /** Sets the task status file by path. */
    public TrainTestEvalTask setTaskStatusFile(String path) {
        return setTaskStatusFile(new File(path));
    }

    /** @return the task status file, possibly {@code null}. */
    public File getTaskStatusFile() {
        return taskStatusFile;
    }

    /**
     * Runs the evaluation.
     *
     * <p>Builds the experiment and measurement suites, opens the outputs, builds the job graph
     * (writing its description to the task graph file if one is set), and executes it. Every
     * output registered with the run's {@link Closer} is closed before this method returns,
     * whether or not the run succeeds, and the per-run state exposed by
     * {@link #getExperiments()}, {@link #getMeasurements()}, {@link #getOutputLayout()} and
     * {@link #getOutputs()} is cleared afterwards.
     *
     * @return the table of aggregate results.
     * @throws TaskExecutionException if the recommender configuration is invalid, an evaluation
     *                                job fails, or an I/O error occurs.
     * @throws InterruptedException   if the evaluation is interrupted.
     */
    @Override
    public Table perform() throws TaskExecutionException, InterruptedException {
        try {
            experiments = createExperimentSuite();
            measurements = createMeasurementSuite();
            layout = ExperimentOutputLayout.create(experiments, measurements);
            TableBuilder resultsBuilder = new TableBuilder(layout.getResultsLayout());
            logger.info("Starting evaluation of {} algorithms ({} from LensKit) on {} data sets",
                        new Object[]{
                                Iterables.size(experiments.getAllAlgorithms()),
                                experiments.getAlgorithms().size(),
                                experiments.getDataSets().size()});

            Closer closer = Closer.create();
            try {
                outputs = openExperimentOutputs(layout, measurements, resultsBuilder, closer);
                try {
                    DAGNode<JobGraph.Node, JobGraph.Edge> jobGraph = makeJobGraph(experiments);
                    if (taskGraphFile != null) {
                        logger.info("writing task graph to {}", taskGraphFile);
                        JobGraph.writeGraphDescription(jobGraph, taskGraphFile);
                    }
                    registerTaskListener(jobGraph);
                    runEvaluations(jobGraph);
                } catch (RecommenderConfigurationException e) {
                    throw new TaskExecutionException("Recommender configuration error", e);
                }
            } catch (Throwable th) {
                // Let the Closer record the primary failure so exceptions thrown while closing
                // are suppressed rather than masking it; IOExceptions reach the outer catch.
                throw closer.rethrow(th, TaskExecutionException.class, InterruptedException.class);
            } finally {
                closer.close();
            }

            logger.info("evaluation {} completed", getName());
            return resultsBuilder.m71build();
        } catch (IOException e) {
            throw new TaskExecutionException("I/O error", e);
        } finally {
            // Per-run state is only meaningful while the evaluation is in progress.
            experiments = null;
            measurements = null;
            outputs = null;
            layout = null;
        }
    }

    /**
     * Registers a {@link JobStatusWriter} for every job in the graph, if a status file is set.
     *
     * @param jobGraph the job graph about to be executed.
     */
    private void registerTaskListener(DAGNode<JobGraph.Node, JobGraph.Edge> jobGraph) {
        if (taskStatusFile == null) {
            return;
        }
        ImmutableSet.Builder<TrainTestJob> jobs = ImmutableSet.builder();
        for (DAGNode<JobGraph.Node, JobGraph.Edge> node : jobGraph.getReachableNodes()) {
            TrainTestJob job = node.getLabel().getJob();
            // Some graph nodes carry no job; only real jobs are tracked.
            if (job != null) {
                jobs.add(job);
            }
        }
        getProject().getEventBus().register(new JobStatusWriter(this, jobs.build(), taskStatusFile));
    }

    /**
     * @return the experiment suite of the running evaluation.
     * @throws IllegalStateException if no evaluation is in progress.
     */
    public ExperimentSuite getExperiments() {
        Preconditions.checkState(experiments != null, "evaluation not in progress");
        return experiments;
    }

    /**
     * @return the measurement suite of the running evaluation.
     * @throws IllegalStateException if no evaluation is in progress.
     */
    public MeasurementSuite getMeasurements() {
        Preconditions.checkState(measurements != null, "evaluation not in progress");
        return measurements;
    }

    /**
     * @return the output layout of the running evaluation.
     * @throws IllegalStateException if no evaluation is in progress.
     */
    public ExperimentOutputLayout getOutputLayout() {
        Preconditions.checkState(layout != null, "evaluation not in progress");
        return layout;
    }

    /**
     * @return the outputs of the running evaluation.
     * @throws IllegalStateException if no evaluation is in progress.
     */
    public ExperimentOutputs getOutputs() {
        Preconditions.checkState(outputs != null, "evaluation not in progress");
        return outputs;
    }

    ExperimentSuite createExperimentSuite() {
        return new ExperimentSuite(algorithms, externalAlgorithms, dataSets);
    }

    /**
     * Builds the measurement suite: the user-configured metrics, plus output metrics for
     * recommendations and predictions when the corresponding output files are set.
     */
    MeasurementSuite createMeasurementSuite() {
        ImmutableList.Builder<MetricFactory> factories = ImmutableList.builder();
        factories.addAll(metrics);
        if (recommendOutputFile != null) {
            factories.add(new OutputTopNMetric.Factory());
        }
        if (predictOutputFile != null) {
            factories.add(new OutputPredictMetric.Factory(predictChannels));
        }
        return new MeasurementSuite(factories.build());
    }

    /**
     * Executes the job graph, single-threaded when the project is configured for one thread.
     *
     * @param jobGraph the job graph.
     * @throws TaskExecutionException if a job fails (a job's own {@code TaskExecutionException}
     *                                is propagated unwrapped).
     * @throws InterruptedException   if execution is interrupted.
     */
    private void runEvaluations(DAGNode<JobGraph.Node, JobGraph.Edge> jobGraph)
            throws TaskExecutionException, InterruptedException {
        int threadCount = getProject().getConfig().getThreadCount();
        logger.info("Running evaluator with {} threads", threadCount);
        TaskGraphExecutor executor = threadCount == 1
                ? TaskGraphExecutor.singleThreaded()
                : TaskGraphExecutor.create(threadCount);
        try {
            executor.execute(jobGraph);
        } catch (ExecutionException e) {
            Throwables.propagateIfInstanceOf(e.getCause(), TaskExecutionException.class);
            throw new TaskExecutionException("error in evaluation job task", e.getCause());
        }
    }

    /**
     * Builds the job graph for an experiment suite.
     *
     * <p>Data sets are grouped by isolation group (in insertion order); with the deprecated
     * {@link #setIsolate(boolean)} flag every data set gets its own random group. Jobs for each
     * group are added and then the group is closed with a fence named after its single data set,
     * or after the group UUID when it holds several data sets.
     *
     * @param experimentSuite the experiments to schedule.
     * @return the root of the job graph.
     * @throws RecommenderConfigurationException if an algorithm's graph cannot be built.
     */
    DAGNode<JobGraph.Node, JobGraph.Edge> makeJobGraph(ExperimentSuite experimentSuite)
            throws RecommenderConfigurationException {
        LinkedHashMultimap<UUID, TTDataSet> groups = LinkedHashMultimap.create();
        for (TTDataSet dataSet : experimentSuite.getDataSets()) {
            UUID group = isolate ? UUID.randomUUID() : dataSet.getIsolationGroup();
            groups.put(group, dataSet);
        }

        ComponentCache cache = new ComponentCache(cacheDir, getProject().getClassLoader());
        JobGraphBuilder graphBuilder = new JobGraphBuilder(this, cache);
        for (UUID group : groups.keySet()) {
            Collection<TTDataSet> groupDataSets = groups.get(group);
            String fenceName = groupDataSets.size() == 1
                    ? groupDataSets.iterator().next().getName()
                    : group.toString();
            for (TTDataSet dataSet : groupDataSets) {
                addAlgorithmNodes(graphBuilder, dataSet, experimentSuite.getAlgorithms(), cache);
                for (ExternalAlgorithm external : experimentSuite.getExternalAlgorithms()) {
                    graphBuilder.addExternalJob(external, dataSet);
                }
            }
            graphBuilder.fence(fenceName);
        }
        return graphBuilder.getGraph();
    }

    /**
     * Adds a LensKit job for each algorithm on one data set.
     *
     * <p>Unless {@link #setSeparateAlgorithms(boolean)} is set, each algorithm's recommender graph
     * is merged with the graphs of previously added algorithms for this data set.
     */
    private void addAlgorithmNodes(JobGraphBuilder graphBuilder, TTDataSet dataSet,
                                   List<AlgorithmInstance> algorithmList, ComponentCache cache)
            throws RecommenderConfigurationException {
        MergePool<Component, Dependency> pool = MergePool.create();
        for (AlgorithmInstance algorithm : algorithmList) {
            logger.debug("building graph for algorithm {}", algorithm);
            LenskitConfiguration config = new LenskitConfiguration();
            config.addComponent(ExecutionInfo.newBuilder()
                                             .setAlgorithm(algorithm)
                                             .setDataSet(dataSet)
                                             .m2build());
            dataSet.configure(config);
            DAGNode<Component, Dependency> graph = algorithm.buildRecommenderGraph(config);
            if (!separateAlgorithms) {
                logger.debug("merging algorithm {} with previous graphs", algorithm);
                graph = pool.merge(graph);
            }
            if (cacheAll && cache != null) {
                cache.registerSharedNodes(graph.getReachableNodes());
            }
            graphBuilder.addLenskitJob(algorithm, dataSet, graph);
        }
    }

    /**
     * Opens the outputs for a run. Every file writer and metric created here is registered with
     * {@code closer}, so the caller is responsible for closing it.
     *
     * @param experimentOutputLayout the output layout.
     * @param measurementSuite       the metrics to instantiate.
     * @param tableWriter            the in-memory results writer; tee'd to the CSV output file
     *                               when one is set.
     * @param closer                 receives every opened resource.
     * @return the opened outputs.
     * @throws IOException if an output file cannot be opened.
     */
    ExperimentOutputs openExperimentOutputs(ExperimentOutputLayout experimentOutputLayout,
                                            MeasurementSuite measurementSuite,
                                            TableWriter tableWriter,
                                            Closer closer) throws IOException {
        TableLayout resultsLayout = experimentOutputLayout.getResultsLayout();
        TableWriter resultsWriter = tableWriter;
        if (outputFile != null) {
            TableWriter csvResults = closer.register(CSVWriter.open(outputFile, resultsLayout));
            resultsWriter = new MultiplexedTableWriter(resultsLayout, resultsWriter, csvResults);
        }

        TableWriter userWriter = null;
        if (userOutputFile != null) {
            userWriter = closer.register(CSVWriter.open(userOutputFile, experimentOutputLayout.getUserLayout()));
        }

        List<Metric<?>> metricInstances = Lists.newArrayList();
        for (MetricFactory factory : measurementSuite.getMetricFactories()) {
            metricInstances.add(closer.register(factory.createMetric2(this)));
        }
        return new ExperimentOutputs(experimentOutputLayout, resultsWriter, userWriter, metricInstances);
    }
}
