package org.grouplens.lenskit.eval.data.pack;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import javax.annotation.Nullable;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.data.dao.packed.BinaryFormatFlag;
import org.grouplens.lenskit.data.dao.packed.BinaryRatingPacker;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.data.CSVDataSource;
import org.grouplens.lenskit.eval.data.DataSource;
import org.grouplens.lenskit.eval.data.traintest.GenericTTDataSet;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.util.io.UpToDateChecker;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/data/pack/PackTask.class */
/**
 * Evaluation task that converts rating data sources into the binary packed format.
 *
 * <p>Every registered train-test data set has both its training and its test data packed, and
 * every registered stand-alone data source is packed on its own. Packing is skipped for a source
 * whose output file {@link UpToDateChecker} reports as already up to date.
 */
public class PackTask extends AbstractTask<List<Object>> {
    private static final Logger logger = LoggerFactory.getLogger(PackTask.class);

    /** Train-test data sets whose training and test data will be packed. */
    private List<TTDataSet> trainTestSets = Lists.newArrayList();
    /** Stand-alone data sources that will be packed. */
    private List<DataSource> dataSources = Lists.newArrayList();
    /** Maps each input data source to the file its packed form is written to. */
    private Function<DataSource, File> packFileFunction = new DefaultOutputFunction();
    /** Format flags passed to the packer; timestamps are included by default. */
    private EnumSet<BinaryFormatFlag> binaryFlags = EnumSet.of(BinaryFormatFlag.TIMESTAMPS);

    /**
     * Default mapping from a data source to its pack file.
     *
     * <p>A {@link CSVDataSource} is packed next to its CSV file, with {@code .pack} appended to the
     * file name. Any other source is packed into a randomly named {@code .pack} file in the
     * {@code packed} subdirectory of the project's data directory.
     *
     * <p>NOTE(review): since a fresh random name is chosen on every call, such non-CSV outputs are
     * never found up to date and earlier pack files are never reused; supply an explicit function
     * through {@link #setOutputFile(Function)} if that matters.
     */
    private class DefaultOutputFunction implements Function<DataSource, File> {
        @Override
        @Nullable
        public File apply(@Nullable DataSource dataSource) {
            assert dataSource != null;
            if (dataSource instanceof CSVDataSource) {
                File csvFile = ((CSVDataSource) dataSource).getFile();
                return new File(csvFile.getParentFile(), csvFile.getName() + ".pack");
            }
            File dataDir = new File(PackTask.this.getProject().getConfig().getDataDir());
            File packDir = new File(dataDir, "packed");
            return new File(packDir, UUID.randomUUID().toString() + ".pack");
        }
    }

    /**
     * Registers a train-test data set whose training and test data should be packed.
     *
     * @param tTDataSet the data set to pack; must not be null
     * @throws NullPointerException if {@code tTDataSet} is null
     */
    public void addDataset(TTDataSet tTDataSet) {
        Preconditions.checkNotNull(tTDataSet, "data source");
        this.trainTestSets.add(tTDataSet);
    }

    /**
     * Registers a stand-alone data source to pack.
     *
     * @param dataSource the data source to pack; must not be null
     * @throws NullPointerException if {@code dataSource} is null
     */
    public void addDataset(DataSource dataSource) {
        Preconditions.checkNotNull(dataSource, "data source");
        this.dataSources.add(dataSource);
    }

    /**
     * Controls whether rating timestamps are written to the packed files.
     *
     * @param z {@code true} to include timestamps, {@code false} to omit them
     */
    public void setIncludeTimestamps(boolean z) {
        if (z) {
            this.binaryFlags.add(BinaryFormatFlag.TIMESTAMPS);
        } else {
            this.binaryFlags.remove(BinaryFormatFlag.TIMESTAMPS);
        }
    }

    /**
     * Reports whether rating timestamps will be written to the packed files.
     *
     * @return {@code true} if the {@link BinaryFormatFlag#TIMESTAMPS} flag is set
     */
    public boolean getIncludeTimestamps() {
        return this.binaryFlags.contains(BinaryFormatFlag.TIMESTAMPS);
    }

    /**
     * Replaces the function that chooses the output file for each data source.
     *
     * @param function the output-file function; must not be null, and must not return null
     *                 when this task runs
     * @throws NullPointerException if {@code function} is null
     */
    public void setOutputFile(Function<DataSource, File> function) {
        Preconditions.checkNotNull(function, "output file function");
        this.packFileFunction = function;
    }

    /**
     * Packs every registered data set and data source.
     *
     * @return the packed train-test data sets, in registration order, followed by the packed
     *         stand-alone data sources, in registration order
     * @throws TaskExecutionException if any source cannot be packed
     */
    @Override // org.grouplens.lenskit.eval.AbstractTask
    public List<Object> perform() throws TaskExecutionException, InterruptedException {
        List<Object> results = Lists.newArrayList();
        for (TTDataSet dataSet : this.trainTestSets) {
            DataSource packedTrain = packDataSource(dataSet.getTrainingData());
            DataSource packedTest = packDataSource(dataSet.getTestData());
            results.add(GenericTTDataSet.copyBuilder(dataSet)
                                        .setTrain(packedTrain)
                                        .setTest(packedTest)
                                        .m19build());
        }
        for (DataSource source : this.dataSources) {
            results.add(packDataSource(source));
        }
        return results;
    }

    /**
     * Packs one data source, unless its output file is already up to date.
     *
     * <p>Data is first written to a {@code .tmp} sibling of the output file and then moved over
     * the output, so a failed run never leaves a half-written pack file under the final name.
     *
     * @param dataSource the source to pack
     * @return a data source reading the packed output file
     * @throws TaskExecutionException if writing or moving the pack file fails
     * @throws NullPointerException if the output-file function returns null
     */
    private DataSource packDataSource(DataSource dataSource) throws TaskExecutionException {
        File file = this.packFileFunction.apply(dataSource);
        Preconditions.checkNotNull(file, "output file");
        PackedDataSource packedDataSource =
                new PackedDataSource(dataSource.getName(), file, dataSource.getPreferenceDomain());

        UpToDateChecker checker = new UpToDateChecker();
        checker.addInput(dataSource.lastModified());
        checker.addOutput(file);
        if (checker.isUpToDate()) {
            logger.info("{} is up to date", file);
            return packedDataSource;
        }

        logger.info("packing {} to {}", dataSource, file);
        // Same directory as the target, so the final move stays on one file system.
        File tmpFile = new File(file.getParentFile(), file.getName() + ".tmp");
        try {
            File parent = file.getAbsoluteFile().getParentFile();
            if (parent != null) {
                // The default output directory (<dataDir>/packed) may not exist yet.
                Files.createDirectories(parent.toPath());
            }
            writePackFile(dataSource, tmpFile);
            // Unlike File.renameTo, this reports failure and replaces an existing (stale) output.
            Files.move(tmpFile.toPath(), file.toPath(), StandardCopyOption.REPLACE_EXISTING);
        } catch (IOException e) {
            logger.error("error packing {}: {}", file, e);
            throw new TaskExecutionException("error packing " + file, e);
        } finally {
            // No-op after a successful move; otherwise discards any partial output.
            tmpFile.delete();
        }
        return packedDataSource;
    }

    /**
     * Writes all ratings of a data source to a binary pack file, closing everything it opens.
     *
     * @param source the data source whose ratings are streamed
     * @param target the file to write
     * @throws IOException if opening, writing or closing fails
     */
    private void writePackFile(DataSource source, File target) throws IOException {
        BinaryRatingPacker packer = BinaryRatingPacker.open(target, this.binaryFlags);
        try {
            Cursor<Rating> ratings = source.getEventDAO().streamEvents(Rating.class);
            try {
                packer.writeRatings(ratings);
            } finally {
                ratings.close();
            }
        } finally {
            packer.close();
        }
    }
}
