package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.io.Closeables;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nonnull;
import org.grouplens.lenskit.eval.AbstractCommand;
import org.grouplens.lenskit.eval.AlgorithmInstance;
import org.grouplens.lenskit.eval.CommandException;
import org.grouplens.lenskit.eval.IsolationLevel;
import org.grouplens.lenskit.eval.JobGroup;
import org.grouplens.lenskit.eval.JobGroupExecutor;
import org.grouplens.lenskit.eval.MergedJobGroupExecutor;
import org.grouplens.lenskit.eval.SequentialJobGroupExecutor;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.TestUserMetric;
import org.grouplens.lenskit.eval.util.table.TableImpl;
import org.grouplens.lenskit.util.tablewriter.CSVWriter;
import org.grouplens.lenskit.util.tablewriter.InMemoryWriter;
import org.grouplens.lenskit.util.tablewriter.MultiplexedTableWriter;
import org.grouplens.lenskit.util.tablewriter.TableLayout;
import org.grouplens.lenskit.util.tablewriter.TableLayoutBuilder;
import org.grouplens.lenskit.util.tablewriter.TableWriter;
import org.grouplens.lenskit.util.tablewriter.TableWriters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/TrainTestEvalCommand.class */
/**
 * Command that runs a train-test evaluation: every configured algorithm is run
 * against every configured train-test data set, and the configured metrics are
 * written to an in-memory table plus optional CSV files.
 *
 * <p>Configure the command with the {@code add*}/{@code set*} methods (all of
 * which return {@code this} for chaining) and then invoke {@link #call()}.
 */
public class TrainTestEvalCommand extends AbstractCommand<TableImpl> {
    private static final Logger logger = LoggerFactory.getLogger(TrainTestEvalCommand.class);

    // --- configuration ---
    private List<TTDataSet> dataSources;
    private List<AlgorithmInstance> algorithms;
    private List<TestUserMetric> metrics;
    private IsolationLevel isolationLevel;
    /** Requested thread count; values &lt;= 0 mean "use the number of available processors". */
    private int nThread;
    private File outputFile;
    private File userOutputFile;
    private File predictOutputFile;
    private int numRecs;

    // --- state built by setupJobs() ---
    /** Number of leading columns (algorithm name + attributes) shared by all three layouts. */
    private int commonColumnCount;
    private TableLayout outputLayout;
    private TableLayout userLayout;
    private TableLayout predictLayout;
    private List<JobGroup> jobGroups;
    /** Maps a data set attribute name to its column index in the common prefix. */
    private Map<String, Integer> dataColumns;
    /** Maps an algorithm attribute name to its column index in the common prefix. */
    private Map<String, Integer> algoColumns;
    private List<TestUserMetric> predictMetrics;

    // --- state built by initialize() and torn down by cleanUp() ---
    /** Non-null only while an evaluation is running. */
    private TableWriter output;
    private InMemoryWriter outputInMemory;
    private TableWriter userOutput;
    private TableWriter predictOutput;

    /**
     * Creates a command named {@code "Traintest"}.
     */
    public TrainTestEvalCommand() {
        this("Traintest");
    }

    /**
     * Creates a command with the given name. Defaults: 5 recommendations,
     * output file {@code train-test-results.csv}, isolation level
     * {@link IsolationLevel#NONE}, no data sets, algorithms or metrics.
     *
     * @param str the command name, passed to the superclass constructor
     */
    public TrainTestEvalCommand(String str) {
        super(str);
        this.numRecs = 5;
        this.dataSources = new LinkedList<TTDataSet>();
        this.algorithms = new LinkedList<AlgorithmInstance>();
        this.metrics = new LinkedList<TestUserMetric>();
        this.outputFile = new File("train-test-results.csv");
        this.isolationLevel = IsolationLevel.NONE;
    }

    /**
     * Adds a train-test data set to evaluate on.
     *
     * @param tTDataSet the data set to add
     * @return this command
     */
    public TrainTestEvalCommand addDataset(TTDataSet tTDataSet) {
        this.dataSources.add(tTDataSet);
        return this;
    }

    /**
     * Adds an algorithm to evaluate.
     *
     * @param algorithmInstance the algorithm to add
     * @return this command
     */
    public TrainTestEvalCommand addAlgorithm(AlgorithmInstance algorithmInstance) {
        this.algorithms.add(algorithmInstance);
        return this;
    }

    /**
     * Adds a metric to compute.
     *
     * @param testUserMetric the metric to add
     * @return this command
     */
    public TrainTestEvalCommand addMetric(TestUserMetric testUserMetric) {
        this.metrics.add(testUserMetric);
        return this;
    }

    /**
     * Sets the CSV file for the main results table. If {@code null}, results
     * are only collected in memory.
     *
     * @param file the output file, or {@code null}
     * @return this command
     */
    public TrainTestEvalCommand setOutput(File file) {
        this.outputFile = file;
        return this;
    }

    /**
     * Sets the CSV file for the per-user table. If {@code null} (the default),
     * no per-user table is opened.
     *
     * @param file the per-user output file, or {@code null}
     * @return this command
     */
    public TrainTestEvalCommand setUserOutput(File file) {
        this.userOutputFile = file;
        return this;
    }

    /**
     * Sets the CSV file for the prediction table. If {@code null} (the
     * default), no prediction table is opened.
     *
     * @param file the prediction output file, or {@code null}
     * @return this command
     */
    public TrainTestEvalCommand setPredictOutput(File file) {
        this.predictOutputFile = file;
        return this;
    }

    /**
     * Sets the isolation level, which selects the job group executor used by
     * {@link #call()}.
     *
     * @param isolationLevel the isolation level
     * @return this command
     */
    public TrainTestEvalCommand setIsolation(IsolationLevel isolationLevel) {
        this.isolationLevel = isolationLevel;
        return this;
    }

    /**
     * Sets the number of threads handed to the executor. A value &lt;= 0 makes
     * {@link #call()} use {@link Runtime#availableProcessors()}.
     *
     * @param i the thread count
     * @return this command
     */
    public TrainTestEvalCommand setThread(int i) {
        this.nThread = i;
        return this;
    }

    /** @return the live list of configured data sets */
    List<TTDataSet> dataSources() {
        return this.dataSources;
    }

    /** @return the live list of configured algorithms */
    List<AlgorithmInstance> getAlgorithms() {
        return this.algorithms;
    }

    /** @return the live list of configured metrics */
    List<TestUserMetric> getMetrics() {
        return this.metrics;
    }

    /** @return the main results file, possibly {@code null} */
    File getOutput() {
        return this.outputFile;
    }

    /** @return the prediction output file, possibly {@code null} */
    File getPredictOutput() {
        return this.predictOutputFile;
    }

    /** @return the configured number of recommendations (passed to each job group) */
    public int getNumRecs() {
        return this.numRecs;
    }

    /**
     * Sets the number of recommendations passed to each job group.
     *
     * @param i the number of recommendations
     * @return this command
     */
    public TrainTestEvalCommand setNumRecs(int i) {
        this.numRecs = i;
        return this;
    }

    /**
     * Runs the evaluation: builds the job groups and table layouts, opens the
     * output writers, executes all job groups, and returns the in-memory
     * results table.
     *
     * <p>Once the writers have been opened, {@code cleanUp()} is always run,
     * whether the evaluation succeeds or fails, so that metrics are finished
     * and writers are closed.
     *
     * @return the results collected by the in-memory writer
     * @throws CommandException if the executor reports an {@link ExecutionException}
     * @throws RuntimeException if the isolation level is not recognised or an
     *         output table cannot be opened or closed
     */
    @Override
    public TableImpl call() throws CommandException {
        setupJobs();
        int threadCount = this.nThread;
        if (threadCount <= 0) {
            threadCount = Runtime.getRuntime().availableProcessors();
        }
        logger.info("Starting evaluation");
        initialize();
        try {
            logger.info("Running evaluator with {} threads", Integer.valueOf(threadCount));
            JobGroupExecutor executor;
            switch (this.isolationLevel) {
                case NONE:
                    executor = new MergedJobGroupExecutor(threadCount);
                    break;
                case JOB_GROUP:
                    executor = new SequentialJobGroupExecutor(threadCount);
                    break;
                default:
                    throw new RuntimeException("Invalid isolation level " + this.isolationLevel);
            }
            for (JobGroup group : getJobGroups()) {
                executor.add(group);
            }
            executor.run();
        } catch (ExecutionException e) {
            throw new CommandException("Error running the evaluation", e);
        } finally {
            // Runs on success and on every failure path after initialize().
            logger.info("Finishing evaluation");
            cleanUp();
        }
        return this.outputInMemory.getResult();
    }

    /**
     * Builds the column layouts and one job group per data set.
     *
     * <p>The common prefix consists of an {@code "Algorithm"} column followed by
     * one column per distinct data set attribute and one per distinct
     * algorithm attribute. The three layouts (aggregate, per-user, prediction)
     * all start with that prefix.
     *
     * @throws CommandException declared for overriding subclasses
     */
    protected void setupJobs() throws CommandException {
        TableLayoutBuilder commonLayout = new TableLayoutBuilder();
        commonLayout.addColumn("Algorithm");

        this.dataColumns = new HashMap<String, Integer>();
        for (TTDataSet dataSet : this.dataSources) {
            registerAttributeColumns(commonLayout, this.dataColumns, dataSet.getAttributes().keySet());
        }

        this.algoColumns = new HashMap<String, Integer>();
        for (AlgorithmInstance algorithm : this.algorithms) {
            registerAttributeColumns(commonLayout, this.algoColumns, algorithm.getAttributes().keySet());
        }

        this.jobGroups = new ArrayList<JobGroup>(this.dataSources.size());
        int groupIndex = 0;
        for (TTDataSet dataSet : this.dataSources) {
            this.jobGroups.add(new TrainTestEvalJobGroup(
                    this, this.algorithms, this.metrics, dataSet, groupIndex, this.numRecs));
            groupIndex++;
        }

        this.commonColumnCount = commonLayout.getColumnCount();

        // NOTE(review): m64clone()/m65build() look like decompiler-renamed
        // clone()/build() on TableLayoutBuilder - confirm against that class.
        TableLayoutBuilder aggregateLayout = commonLayout.m64clone();
        aggregateLayout.addColumn("BuildTime");
        aggregateLayout.addColumn("TestTime");
        TableLayoutBuilder perUserLayout = commonLayout.m64clone();
        for (TestUserMetric metric : this.metrics) {
            String[] columnLabels = metric.getColumnLabels();
            if (columnLabels != null) {
                for (String label : columnLabels) {
                    aggregateLayout.addColumn(label);
                }
            }
            String[] userColumnLabels = metric.getUserColumnLabels();
            if (userColumnLabels != null) {
                for (String label : userColumnLabels) {
                    perUserLayout.addColumn(label);
                }
            }
        }
        this.outputLayout = aggregateLayout.m65build();
        this.userLayout = perUserLayout.m65build();

        TableLayoutBuilder predictionLayout = commonLayout.m64clone();
        predictionLayout.addColumn("User");
        predictionLayout.addColumn("Item");
        predictionLayout.addColumn("Rating");
        predictionLayout.addColumn("Prediction");
        this.predictLayout = predictionLayout.m65build();

        this.predictMetrics = this.metrics;
    }

    /**
     * Adds a column for every attribute name not yet present in {@code columns}
     * and records the column index returned by the layout builder.
     */
    private static void registerAttributeColumns(TableLayoutBuilder layout,
                                                 Map<String, Integer> columns,
                                                 Iterable<String> attributeNames) {
        for (String name : attributeNames) {
            if (!columns.containsKey(name)) {
                columns.put(name, Integer.valueOf(layout.addColumn(name)));
            }
        }
    }

    /**
     * Opens the output writers and notifies every metric that the evaluation
     * is starting. Writers opened earlier are closed quietly if a later one
     * fails to open.
     *
     * @throws RuntimeException if any output file cannot be opened
     */
    private void initialize() {
        logger.info("Starting evaluation");
        List<TableWriter> aggregateWriters = new ArrayList<TableWriter>();
        this.outputInMemory = new InMemoryWriter(this.outputLayout);
        aggregateWriters.add(this.outputInMemory);
        if (this.outputFile != null) {
            try {
                aggregateWriters.add(CSVWriter.open(this.outputFile, this.outputLayout));
            } catch (IOException e) {
                throw new RuntimeException("Error opening output table", e);
            }
        }
        this.output = new MultiplexedTableWriter(this.outputLayout, aggregateWriters);

        if (this.userOutputFile != null) {
            try {
                this.userOutput = CSVWriter.open(this.userOutputFile, this.userLayout);
            } catch (IOException e) {
                Closeables.closeQuietly(this.output);
                throw new RuntimeException("Error opening user output table", e);
            }
        }

        if (this.predictOutputFile != null) {
            try {
                this.predictOutput = CSVWriter.open(this.predictOutputFile, this.predictLayout);
            } catch (IOException e) {
                Closeables.closeQuietly(this.userOutput);
                Closeables.closeQuietly(this.output);
                throw new RuntimeException("Error opening prediction table", e);
            }
        }

        for (TestUserMetric metric : this.predictMetrics) {
            metric.startEvaluation(this);
        }
    }

    /**
     * Notifies every metric that the evaluation is finished and closes the
     * output writers. The writers are closed even if a metric throws.
     *
     * @throws IllegalStateException if no evaluation is running
     * @throws RuntimeException if a writer fails to close
     */
    private void cleanUp() {
        try {
            for (TestUserMetric metric : this.predictMetrics) {
                metric.finishEvaluation();
            }
        } finally {
            closeWriters();
        }
    }

    /**
     * Closes the aggregate, per-user and prediction writers. Each writer is
     * closed independently so one failing {@code close()} does not leave the
     * others open; the first {@link IOException} is rethrown wrapped in a
     * {@link RuntimeException} and later ones are logged.
     */
    private void closeWriters() {
        if (this.output == null) {
            throw new IllegalStateException("evaluation not running");
        }
        logger.info("Evaluation finished");
        IOException firstFailure = null;
        try {
            TableWriter[] writers = {this.output, this.userOutput, this.predictOutput};
            for (TableWriter writer : writers) {
                if (writer == null) {
                    continue;
                }
                try {
                    writer.close();
                } catch (IOException e) {
                    if (firstFailure == null) {
                        firstFailure = e;
                    } else {
                        logger.error("Error closing output", e);
                    }
                }
            }
        } finally {
            // Marks the evaluation as no longer running.
            this.output = null;
        }
        if (firstFailure != null) {
            throw new RuntimeException("Error closing output", firstFailure);
        }
    }

    /**
     * Returns a supplier for the aggregate output writer. The supplier throws
     * {@link IllegalStateException} if queried while no evaluation is running.
     *
     * @return a supplier of the aggregate table writer
     */
    @Nonnull
    public Supplier<TableWriter> outputTableSupplier() {
        return new Supplier<TableWriter>() {
            @Override
            public TableWriter get() {
                Preconditions.checkState(TrainTestEvalCommand.this.output != null, "evaluation not running");
                return TrainTestEvalCommand.this.output;
            }
        };
    }

    /**
     * Returns a supplier for the prediction writer. The supplied value is
     * {@code null} if no prediction writer has been opened.
     *
     * @return a supplier of the prediction table writer
     */
    @Nonnull
    public Supplier<TableWriter> predictTableSupplier() {
        return new Supplier<TableWriter>() {
            @Override
            public TableWriter get() {
                return TrainTestEvalCommand.this.predictOutput;
            }
        };
    }

    /**
     * Returns a supplier for the per-user writer. The supplied value is
     * {@code null} if no per-user writer has been opened.
     *
     * @return a supplier of the per-user table writer
     */
    @Nonnull
    public Supplier<TableWriter> userTableSupplier() {
        return new Supplier<TableWriter>() {
            @Override
            public TableWriter get() {
                return TrainTestEvalCommand.this.userOutput;
            }
        };
    }

    /**
     * Returns the job groups built by {@link #setupJobs()}, one per data set.
     *
     * @return the job groups
     */
    @Nonnull
    public List<JobGroup> getJobGroups() {
        return this.jobGroups;
    }

    /**
     * Returns a function that applies
     * {@link #prefixTable(TableWriter, AlgorithmInstance, TTDataSet)} with the
     * given algorithm and data set.
     *
     * @param algorithmInstance the algorithm whose name and attributes fill the prefix
     * @param tTDataSet the data set whose attributes fill the prefix
     * @return a function wrapping a writer with the common-column prefix
     */
    public Function<TableWriter, TableWriter> prefixFunction(final AlgorithmInstance algorithmInstance, final TTDataSet tTDataSet) {
        return new Function<TableWriter, TableWriter>() {
            @Override
            public TableWriter apply(TableWriter tableWriter) {
                return TrainTestEvalCommand.this.prefixTable(tableWriter, algorithmInstance, tTDataSet);
            }
        };
    }

    /**
     * Wraps a writer so that each row is prefixed with the common columns:
     * the algorithm name in column 0, then the data set and algorithm
     * attribute values at the indices assigned by {@link #setupJobs()}.
     * Attribute columns that this algorithm or data set does not define are
     * left {@code null}.
     *
     * @param tableWriter the writer to wrap; may be {@code null}
     * @param algorithmInstance the algorithm supplying the name and attributes
     * @param tTDataSet the data set supplying attributes
     * @return the prefixed writer, or {@code null} if {@code tableWriter} is {@code null}
     */
    public TableWriter prefixTable(TableWriter tableWriter, AlgorithmInstance algorithmInstance, TTDataSet tTDataSet) {
        if (tableWriter == null) {
            return null;
        }
        Object[] prefix = new Object[this.commonColumnCount];
        prefix[0] = algorithmInstance.getName();
        for (Map.Entry<String, Object> attribute : tTDataSet.getAttributes().entrySet()) {
            int column = this.dataColumns.get(attribute.getKey()).intValue();
            prefix[column] = attribute.getValue();
        }
        for (Map.Entry<String, Object> attribute : algorithmInstance.getAttributes().entrySet()) {
            int column = this.algoColumns.get(attribute.getKey()).intValue();
            prefix[column] = attribute.getValue();
        }
        return TableWriters.prefixed(tableWriter, prefix);
    }
}
