package org.grouplens.lenskit.eval.traintest;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.longs.Long2DoubleMap;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.grouplens.lenskit.Recommender;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.collections.CollectionUtils;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.EventDAO;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.history.History;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.eval.data.CSVDataSource;
import org.grouplens.lenskit.eval.data.traintest.GenericTTDataSet;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.eval.metrics.topn.ItemSelector;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.util.DelimitedTextCursor;
import org.grouplens.lenskit.util.io.LoggingStreamSlurper;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.vectors.ImmutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/grouplens/lenskit/eval/traintest/ExternalEvalJob.class */
public class ExternalEvalJob extends TrainTestJob {
    private static final Logger logger = LoggerFactory.getLogger(ExternalEvalJob.class);

    /** The external algorithm definition (command line, working directory, output delimiter). */
    private final ExternalAlgorithm algorithm;
    /** Random key that makes this job's staging directory name unique. */
    private final UUID key;
    /** Predictions parsed from the external program's output, by user ID; {@code null} until built. */
    private Long2ObjectMap<SparseVector> userPredictions;
    /** Per-user view of the training data; set by {@link #buildRecommender()}. */
    private UserEventDAO userTrainingEvents;
    /** Per-user view of the test data; set by {@link #buildRecommender()}. */
    private UserEventDAO userTestEvents;

    /**
     * Per-user result view backed by the predictions read from the external program's output.
     */
    class TestUserImpl extends AbstractTestUser {
        private final long userId;

        public TestUserImpl(long uid) {
            userId = uid;
        }

        /**
         * Get the user's training history.
         *
         * @return the training events for this user, or an empty history if the user has none.
         */
        @Override
        public UserHistory<Event> getTrainHistory() {
            UserHistory<Event> history = userTrainingEvents.getEventsForUser(userId);
            if (history == null) {
                return History.forUser(userId);
            }
            return history;
        }

        /**
         * Get the user's test history.
         *
         * @return the test events for this user, or an empty history if the user has none
         *         (consistent with {@link #getTrainHistory()}).
         */
        @Override
        public UserHistory<Event> getTestHistory() {
            UserHistory<Event> history = userTestEvents.getEventsForUser(userId);
            if (history == null) {
                return History.forUser(userId);
            }
            return history;
        }

        /**
         * Get the predictions the external program produced for this user.
         *
         * @return the user's predicted scores, or {@code null} if the output file contained
         *         no rows for this user.
         */
        @Override
        public SparseVector getPredictions() {
            return userPredictions.get(userId);
        }

        /**
         * External algorithms only produce predictions in this job, so recommendation lists
         * are not available.
         *
         * @return always {@code null}.
         */
        @Override
        public List<ScoredId> getRecommendations(int n, ItemSelector candidates, ItemSelector exclude) {
            return null;
        }

        /**
         * There is no in-process recommender for an external algorithm.
         *
         * @return always {@code null}.
         */
        @Override
        public Recommender getRecommender() {
            return null;
        }
    }

    public ExternalEvalJob(TrainTestEvalTask trainTestEvalTask, @Nonnull ExternalAlgorithm externalAlgorithm, @Nonnull TTDataSet tTDataSet, @Nonnull MeasurementSuite measurementSuite, @Nonnull ExperimentOutputs experimentOutputs) {
        super(trainTestEvalTask, externalAlgorithm, tTDataSet, measurementSuite, experimentOutputs);
        this.algorithm = externalAlgorithm;
        this.key = UUID.randomUUID();
    }

    /**
     * Build the "recommender" by running the external program and loading its predictions.
     *
     * <p>Training and test data are written as CSV files in the staging directory (or an
     * existing comma-delimited training file is reused). The placeholders {@code {OUTPUT}},
     * {@code {TRAIN_DATA}} and {@code {TEST_DATA}} in the configured command are then replaced
     * with absolute paths. The command runs in the algorithm's working directory, and the
     * output file it writes is parsed.
     *
     * @throws RecommenderBuildException if the data files cannot be prepared, the process
     *         cannot be started, is interrupted, exits with a non-zero status, or produces
     *         missing or malformed output.
     */
    @Override
    protected void buildRecommender() throws RecommenderBuildException {
        Preconditions.checkState(userPredictions == null, "recommender already built");
        prepareStagingDir();

        final File train;
        try {
            train = trainingFile(dataSet);
        } catch (IOException e) {
            throw new RecommenderBuildException("error preparing training file", e);
        }
        final File test;
        try {
            test = testFile(dataSet);
        } catch (IOException e) {
            throw new RecommenderBuildException("error preparing test file", e);
        }
        final File output = getFile("predictions.csv");

        List<String> command = Lists.transform(algorithm.getCommand(), new Function<String, String>() {
            @Nullable
            @Override
            public String apply(@Nullable String arg) {
                if (arg == null) {
                    throw new NullPointerException("command element");
                }
                return arg.replace("{OUTPUT}", output.getAbsolutePath())
                          .replace("{TRAIN_DATA}", train.getAbsolutePath())
                          .replace("{TEST_DATA}", test.getAbsolutePath());
            }
        });
        logger.info("running {}", StringUtils.join(command, " "));

        runCommand(command);

        try {
            userPredictions = readPredictions(output);
        } catch (FileNotFoundException e) {
            logger.error("cannot find expected output file {}", output);
            throw new RecommenderBuildException("recommender produced no output", e);
        }
        userTrainingEvents = dataSet.getTrainingData().getUserEventDAO();
        userTestEvents = dataSet.getTestData().getUserEventDAO();
    }

    /**
     * Make sure the staging directory exists, creating it if necessary.
     *
     * @return the staging directory.
     * @throws RecommenderBuildException if the directory does not exist and cannot be created.
     */
    private File prepareStagingDir() throws RecommenderBuildException {
        File dir = getStagingDir();
        logger.info("using output/staging directory {}", dir);
        if (!dir.exists()) {
            logger.info("creating directory {}", dir);
            // mkdirs() returns false both on failure and if another process created the directory
            // concurrently, so re-check before giving up.
            if (!dir.mkdirs() && !dir.isDirectory()) {
                throw new RecommenderBuildException("cannot create staging directory " + dir);
            }
        }
        return dir;
    }

    /**
     * Run the external command to completion.
     *
     * <p>The child's stderr is merged into its stdout, and the merged stream is drained into
     * the log. If either pipe were left unread, a program writing enough output could fill the
     * OS pipe buffer and block forever. The child's stdin is closed immediately because nothing
     * is ever written to it.
     *
     * @param command the fully-substituted command line.
     * @throws RecommenderBuildException if the process cannot be started, the calling thread is
     *         interrupted while waiting, or the process exits with a non-zero status.
     */
    private void runCommand(List<String> command) throws RecommenderBuildException {
        Process proc;
        try {
            proc = new ProcessBuilder(command)
                    .directory(algorithm.getWorkDir())
                    .redirectErrorStream(true)
                    .start();
        } catch (IOException e) {
            throw new RecommenderBuildException("error creating process", e);
        }

        try {
            proc.getOutputStream().close();
        } catch (IOException e) {
            // Not fatal: the program may simply never read stdin.
            logger.debug("could not close subprocess stdin", e);
        }

        // Runs synchronously: reads the merged output stream until the child closes it.
        new LoggingStreamSlurper("external-algo", proc.getInputStream(), logger, "external: ").run();

        int status;
        try {
            status = proc.waitFor();
        } catch (InterruptedException e) {
            logger.info("thread interrupted, killing subprocess");
            proc.destroy();
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RecommenderBuildException("recommender build interrupted", e);
        }
        if (status != 0) {
            logger.error("external command exited with status {}", status);
            throw new RecommenderBuildException("recommender exited with code " + status);
        }
    }

    /**
     * Get per-user results for a test user.
     *
     * @param j the user ID.
     * @return a result view over the loaded predictions.
     * @throws IllegalStateException if the recommender has not been built.
     */
    @Override
    protected TestUser getUserResults(long j) {
        Preconditions.checkState(userPredictions != null, "recommender not built");
        return new TestUserImpl(j);
    }

    /** Release the loaded predictions and data access objects. */
    @Override
    protected void cleanup() {
        userPredictions = null;
        userTrainingEvents = null;
        userTestEvents = null;
    }

    /** The per-job staging directory: {@code <workDir>/<algorithmName>-<key>}. */
    private File getStagingDir() {
        return new File(algorithm.getWorkDir(), String.format("%s-%s", algorithm.getName(), key));
    }

    /** Resolve a file name inside the staging directory. */
    private File getFile(String str) {
        return new File(getStagingDir(), str);
    }

    /**
     * Get a comma-delimited training file for the external program.
     *
     * <p>If the data set is a {@link GenericTTDataSet} whose training data is already a
     * comma-delimited {@link CSVDataSource}, that file is reused. Otherwise the training
     * ratings are written to {@code train.csv} in the staging directory.
     *
     * @param data the train-test data set.
     * @return the training data file.
     * @throws IOException if the training file cannot be written.
     */
    private File trainingFile(TTDataSet data) throws IOException {
        if (data instanceof GenericTTDataSet) {
            Object source = ((GenericTTDataSet) data).getTrainingData();
            if (source instanceof CSVDataSource) {
                CSVDataSource csv = (CSVDataSource) source;
                if (",".equals(csv.getDelimiter())) {
                    File file = csv.getFile();
                    logger.debug("using training file {}", file);
                    return file;
                }
            }
        }
        File file = makeCSV(data.getTrainingDAO(), getFile("train.csv"), true);
        logger.debug("wrote training file {}", file);
        return file;
    }

    /**
     * Write the test user/item pairs (without rating values) to {@code test.csv} in the
     * staging directory.
     *
     * @param data the train-test data set.
     * @return the test data file.
     * @throws IOException if the file cannot be written.
     */
    private File testFile(TTDataSet data) throws IOException {
        File file = makeCSV(data.getTestDAO(), getFile("test.csv"), false);
        logger.debug("wrote test file {}", file);
        return file;
    }

    /**
     * Write the ratings from a DAO to a CSV file.
     *
     * <p>Each row is {@code user,item} or, if {@code withRatings} is set, {@code user,item,value}.
     * Ratings with no preference value are skipped.
     *
     * @param eventDAO    the source of rating events.
     * @param file        the file to write.
     * @param withRatings whether to include the rating value column.
     * @return {@code file}.
     * @throws IOException if writing fails.
     */
    private File makeCSV(EventDAO eventDAO, File file, boolean withRatings) throws IOException {
        // The row buffer is reused for every line.
        Object[] row = new Object[withRatings ? 3 : 2];
        CSVWriter writer = CSVWriter.open(file, null);
        try {
            Cursor<Rating> ratings = eventDAO.streamEvents(Rating.class);
            try {
                // fast() may reuse the same Rating object, so use each one only inside this iteration.
                for (Rating rating : ratings.fast()) {
                    Preference pref = rating.getPreference();
                    if (pref == null) {
                        continue;
                    }
                    row[0] = rating.getUserId();
                    row[1] = rating.getItemId();
                    if (withRatings) {
                        row[2] = pref.getValue();
                    }
                    writer.writeRow(row);
                }
            } finally {
                ratings.close();
            }
        } finally {
            writer.close();
        }
        return file;
    }

    /**
     * Parse the external program's predictions file.
     *
     * <p>Each non-blank row must have at least three fields, {@code user}, {@code item} and
     * {@code score}, separated by the algorithm's output delimiter. Any extra fields are ignored.
     *
     * @param file the predictions file.
     * @return the predicted scores for each user.
     * @throws FileNotFoundException     if the file does not exist.
     * @throws RecommenderBuildException if a row is too short or contains an unparseable number.
     */
    private Long2ObjectMap<SparseVector> readPredictions(File file) throws FileNotFoundException, RecommenderBuildException {
        Long2ObjectMap<Long2DoubleMap> scores = new Long2ObjectOpenHashMap<Long2DoubleMap>();
        Cursor<String[]> rows = new DelimitedTextCursor(file, algorithm.getOutputDelimiter());
        try {
            int lineNo = 0;
            for (String[] fields : rows) {
                lineNo++;
                if (fields.length == 0 || (fields.length == 1 && fields[0].equals(""))) {
                    continue;  // blank line
                }
                if (fields.length < 3) {
                    logger.error("predictions line {}: invalid row {}", lineNo, StringUtils.join(fields, ","));
                    throw new RecommenderBuildException("invalid prediction row");
                }
                long user;
                long item;
                double score;
                try {
                    user = Long.parseLong(fields[0]);
                    item = Long.parseLong(fields[1]);
                    score = Double.parseDouble(fields[2]);
                } catch (NumberFormatException e) {
                    logger.error("predictions line {}: invalid row {}", lineNo, StringUtils.join(fields, ","));
                    throw new RecommenderBuildException("invalid prediction row", e);
                }
                Long2DoubleMap userScores = scores.get(user);
                if (userScores == null) {
                    userScores = new Long2DoubleOpenHashMap();
                    scores.put(user, userScores);
                }
                userScores.put(item, score);
            }
        } finally {
            rows.close();
        }

        Long2ObjectMap<SparseVector> result = new Long2ObjectOpenHashMap<SparseVector>(scores.size());
        for (Long2ObjectMap.Entry<Long2DoubleMap> entry : CollectionUtils.fast(scores.long2ObjectEntrySet())) {
            result.put(entry.getLongKey(), ImmutableSparseVector.create(entry.getValue()));
        }
        return result;
    }
}
