package org.grouplens.lenskit.eval.data.crossfold;

import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import it.unimi.dsi.fastutil.longs.Long2IntMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongListIterator;
import it.unimi.dsi.fastutil.longs.LongLists;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.grouplens.lenskit.cursors.Cursor;
import org.grouplens.lenskit.cursors.Cursors;
import org.grouplens.lenskit.data.dao.UserDAO;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.data.pref.Preference;
import org.grouplens.lenskit.eval.AbstractTask;
import org.grouplens.lenskit.eval.TaskExecutionException;
import org.grouplens.lenskit.eval.data.CSVDataSourceBuilder;
import org.grouplens.lenskit.eval.data.DataSource;
import org.grouplens.lenskit.eval.data.traintest.GenericTTDataBuilder;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.util.io.UpToDateChecker;
import org.grouplens.lenskit.util.table.writer.CSVWriter;
import org.grouplens.lenskit.util.table.writer.TableWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/grouplens/lenskit/eval/data/crossfold/CrossfoldTask.class */
/**
 * Evaluation task that splits a {@link DataSource} into {@code partitionCount}
 * train/test CSV file pairs and returns them as {@link TTDataSet}s.
 *
 * <p>Two splitting strategies are implemented:
 * <ul>
 *   <li>by users (the default): each user is assigned to one partition; that
 *       user's ratings are divided between the partition's train and test file
 *       by the configured {@link Holdout}, and written in full to the train
 *       file of every other partition;</li>
 *   <li>by ratings: ratings are shuffled and dealt round-robin to the test
 *       file of one partition and the train files of all others.</li>
 * </ul>
 *
 * <p>All randomness is drawn from {@code getProject().getRandom()}.
 */
public class CrossfoldTask extends AbstractTask<List<TTDataSet>> {
    private static final Logger logger = LoggerFactory.getLogger(CrossfoldTask.class);

    private DataSource source;
    private int partitionCount = 5;
    private String trainFilePattern;
    private String testFilePattern;
    private Order<Rating> order = new RandomOrder<Rating>();
    private PartitionAlgorithm<Rating> partition = new HoldoutNPartition<Rating>(10);
    private boolean isForced;
    private boolean splitUsers = true;
    private boolean cacheOutput = true;

    /**
     * Creates an unnamed crossfold task; {@link #getName()} then falls back to
     * the name of the data source.
     */
    public CrossfoldTask() {
        super(null);
    }

    /**
     * Creates a named crossfold task.
     *
     * @param str the task name
     */
    public CrossfoldTask(String str) {
        super(str);
    }

    /**
     * Sets the number of train/test partitions to generate (default 5).
     *
     * @param i the partition count
     * @return this task, for chaining
     */
    public CrossfoldTask setPartitions(int i) {
        this.partitionCount = i;
        return this;
    }

    /**
     * Sets the {@link String#format} pattern for train file names; it is
     * formatted with the partition index as its single argument.
     *
     * @param str the pattern, or {@code null} to use the default location
     * @return this task, for chaining
     */
    public CrossfoldTask setTrain(String str) {
        this.trainFilePattern = str;
        return this;
    }

    /**
     * Sets the {@link String#format} pattern for test file names; it is
     * formatted with the partition index as its single argument.
     *
     * @param str the pattern, or {@code null} to use the default location
     * @return this task, for chaining
     */
    public CrossfoldTask setTest(String str) {
        this.testFilePattern = str;
        return this;
    }

    /**
     * Sets the ordering handed to the {@link Holdout} used when splitting by
     * users (default: a {@link RandomOrder}).
     *
     * @param order the rating order
     * @return this task, for chaining
     */
    public CrossfoldTask setOrder(Order<Rating> order) {
        this.order = order;
        return this;
    }

    /**
     * Replaces the partition algorithm with a {@link HoldoutNPartition}
     * constructed from {@code i}.
     *
     * @param i the argument passed to {@code HoldoutNPartition}
     * @return this task, for chaining
     */
    public CrossfoldTask setHoldout(int i) {
        this.partition = new HoldoutNPartition<Rating>(i);
        return this;
    }

    /**
     * Replaces the partition algorithm with a {@link RetainNPartition}
     * constructed from {@code i}.
     *
     * @param i the argument passed to {@code RetainNPartition}
     * @return this task, for chaining
     */
    public CrossfoldTask setRetain(int i) {
        this.partition = new RetainNPartition<Rating>(i);
        return this;
    }

    /**
     * Replaces the partition algorithm with a {@link FractionPartition}
     * constructed from {@code d}.
     *
     * @param d the argument passed to {@code FractionPartition}
     * @return this task, for chaining
     */
    public CrossfoldTask setHoldoutFraction(double d) {
        this.partition = new FractionPartition<Rating>(d);
        return this;
    }

    /**
     * Sets the data source to be split.
     *
     * @param dataSource the input data
     * @return this task, for chaining
     */
    public CrossfoldTask setSource(DataSource dataSource) {
        this.source = dataSource;
        return this;
    }

    /**
     * Sets whether the output files are regenerated even when they appear up
     * to date.
     *
     * @param z {@code true} to always regenerate
     * @return this task, for chaining
     */
    public CrossfoldTask setForce(boolean z) {
        this.isForced = z;
        return this;
    }

    /**
     * Chooses between splitting by users ({@code true}, the default) and
     * splitting by individual ratings ({@code false}).
     *
     * <p>Unlike the other setters this one returns nothing; the signature is
     * kept as-is for compatibility with existing callers.
     *
     * @param z {@code true} to split by users
     */
    public void setSplitUsers(boolean z) {
        this.splitUsers = z;
    }

    /**
     * Sets the cache flag passed to the {@link CSVDataSourceBuilder}s that
     * wrap the generated files (default {@code true}).
     *
     * @param z the cache flag
     * @return this task, for chaining
     */
    public CrossfoldTask setCache(boolean z) {
        this.cacheOutput = z;
        return this;
    }

    /**
     * Returns the explicitly configured task name, or the data source's name
     * when none was given. A source must be set in the latter case.
     *
     * @return the task name
     */
    @Override
    public String getName() {
        String name = super.getName();
        if (name == null) {
            name = this.source.getName();
        }
        return name;
    }

    /**
     * Returns the train file name pattern: the configured one if set,
     * otherwise {@code <dataDir>/<name>-crossfold/train.%d.csv}.
     *
     * @return the train file pattern
     */
    public String getTrainPattern() {
        if (this.trainFilePattern != null) {
            return this.trainFilePattern;
        }
        return defaultPattern("train.%d.csv");
    }

    /**
     * Returns the test file name pattern: the configured one if set,
     * otherwise {@code <dataDir>/<name>-crossfold/test.%d.csv}.
     *
     * @return the test file pattern
     */
    public String getTestPattern() {
        if (this.testFilePattern != null) {
            return this.testFilePattern;
        }
        return defaultPattern("test.%d.csv");
    }

    /**
     * Builds the default output pattern for {@code fileName} under the
     * project's data directory, falling back to {@code "."} when the project
     * has no data directory configured.
     */
    private String defaultPattern(String fileName) {
        String dataDir = getProject().getConfig().getDataDir();
        if (dataDir == null) {
            dataDir = ".";
        }
        return dataDir + File.separator + getName() + "-crossfold" + File.separator + fileName;
    }

    /**
     * @return the data source to be split
     */
    public DataSource getSource() {
        return this.source;
    }

    /**
     * @return the number of partitions to generate
     */
    public int getPartitionCount() {
        return this.partitionCount;
    }

    /**
     * Creates a new {@link Holdout} from the current order and partition
     * algorithm.
     *
     * @return a fresh holdout
     */
    public Holdout getHoldout() {
        return new Holdout(this.order, this.partition);
    }

    /**
     * @return {@code true} if regeneration is forced on this task or by the
     *         project configuration
     */
    public boolean getForce() {
        return this.isForced || getProject().getConfig().force();
    }

    /**
     * @return {@code true} if data is split by users, {@code false} if by
     *         ratings
     */
    public boolean getSplitUsers() {
        return this.splitUsers;
    }

    /**
     * Generates the train/test files, unless they are up to date with respect
     * to the source and regeneration is not forced, and returns the data sets
     * that wrap them.
     *
     * @return one data set per partition
     * @throws TaskExecutionException if writing the files fails with an
     *         {@link IOException}
     */
    @Override
    public List<TTDataSet> perform() throws TaskExecutionException {
        if (!getForce() && outputsAreUpToDate()) {
            logger.info("crossfold {} up to date", getName());
            return getTTFiles();
        }
        try {
            createTTFiles();
            return getTTFiles();
        } catch (IOException e) {
            throw new TaskExecutionException("Error writing data sets", e);
        }
    }

    /**
     * Asks an {@link UpToDateChecker} whether all train and test files are up
     * to date relative to the source's last-modified time.
     */
    private boolean outputsAreUpToDate() {
        UpToDateChecker checker = new UpToDateChecker();
        checker.addInput(this.source.lastModified());
        for (File trainFile : getFiles(getTrainPattern())) {
            checker.addOutput(trainFile);
        }
        for (File testFile : getFiles(getTestPattern())) {
            checker.addOutput(testFile);
        }
        return checker.isUpToDate();
    }

    /**
     * Expands a file name pattern into one file per partition.
     *
     * @param str a {@link String#format} pattern taking the partition index
     * @return an array of {@code partitionCount} files, indexed by partition
     */
    protected File[] getFiles(String str) {
        File[] files = new File[this.partitionCount];
        for (int i = 0; i < files.length; i++) {
            files[i] = new File(String.format(str, i));
        }
        return files;
    }

    /**
     * Opens a CSV writer for every train and test file, writes the split data
     * through them, and closes all writers afterwards.
     *
     * <p>The writers must stay open until the split has been written, so the
     * {@link Closer} is closed exactly once, after writing finishes or fails.
     * Any failure goes through {@link Closer#rethrow(Throwable)}, which
     * rethrows {@link IOException}s, runtime exceptions and errors as they are
     * and wraps other checked exceptions (such as
     * {@link TaskExecutionException}) in a {@link RuntimeException}.
     *
     * @throws IOException if a file cannot be opened, written or closed
     */
    protected void createTTFiles() throws IOException {
        File[] trainFiles = getFiles(getTrainPattern());
        File[] testFiles = getFiles(getTestPattern());
        TableWriter[] trainWriters = new TableWriter[this.partitionCount];
        TableWriter[] testWriters = new TableWriter[this.partitionCount];
        Closer closer = Closer.create();
        try {
            for (int i = 0; i < this.partitionCount; i++) {
                trainWriters[i] = closer.register(CSVWriter.open(trainFiles[i], null));
                testWriters[i] = closer.register(CSVWriter.open(testFiles[i], null));
            }
            if (getSplitUsers()) {
                writeTTFilesByUsers(trainWriters, testWriters);
            } else {
                writeTTFilesByRatings(trainWriters, testWriters);
            }
        } catch (Throwable th) {
            throw closer.rethrow(th);
        } finally {
            closer.close();
        }
    }

    /**
     * Writes the user-based split. Each user belongs to one partition (see
     * {@link #splitUsers(UserDAO)}). In that partition the user's ratings are
     * divided at the index returned by the holdout: ratings before it go to
     * the train file, the rest to the test file. In every other partition all
     * of the user's ratings go to the train file.
     *
     * @param tableWriterArr  train writers, indexed by partition
     * @param tableWriterArr2 test writers, indexed by partition
     * @throws TaskExecutionException if writing a row fails
     */
    protected void writeTTFilesByUsers(TableWriter[] tableWriterArr, TableWriter[] tableWriterArr2) throws TaskExecutionException {
        logger.info("splitting data source {} to {} partitions by users", getName(), this.partitionCount);
        Long2IntMap userFolds = splitUsers(this.source.getUserDAO());
        Holdout holdout = getHoldout();
        Cursor<? extends UserHistory<?>> histories = this.source.getUserEventDAO().streamEventsByUser();
        try {
            for (UserHistory<?> history : histories) {
                int testFold = userFolds.get(history.getUserId());
                List<Rating> ratings = new ArrayList<Rating>(history.filter(Rating.class));
                int splitPoint = holdout.partition(ratings, getProject().getRandom());
                int count = ratings.size();
                for (int fold = 0; fold < this.partitionCount; fold++) {
                    if (fold == testFold) {
                        writeRange(tableWriterArr[fold], ratings, 0, splitPoint);
                        writeRange(tableWriterArr2[fold], ratings, splitPoint, count);
                    } else {
                        writeRange(tableWriterArr[fold], ratings, 0, count);
                    }
                }
            }
        } catch (IOException e) {
            throw new TaskExecutionException("Error writing to the train test files", e);
        } finally {
            histories.close();
        }
    }

    /**
     * Writes {@code ratings[from..to)} to {@code writer}, in list order.
     */
    private void writeRange(TableWriter writer, List<Rating> ratings, int from, int to) throws IOException {
        for (int i = from; i < to; i++) {
            writeRating(writer, ratings.get(i));
        }
    }

    /**
     * Writes the rating-based split. All ratings are loaded, shuffled with the
     * project's random number generator, and the rating at position {@code k}
     * is written to the test file of partition {@code k % partitionCount} and
     * to the train file of every other partition.
     *
     * @param tableWriterArr  train writers, indexed by partition
     * @param tableWriterArr2 test writers, indexed by partition
     * @throws TaskExecutionException if writing a row fails
     */
    protected void writeTTFilesByRatings(TableWriter[] tableWriterArr, TableWriter[] tableWriterArr2) throws TaskExecutionException {
        logger.info("splitting data source {} to {} partitions by ratings", getName(), this.partitionCount);
        List<Rating> ratings = Cursors.makeList(this.source.getEventDAO().streamEvents(Rating.class));
        // Use the project's RNG, like the user-based split, so runs are reproducible.
        Collections.shuffle(ratings, getProject().getRandom());
        try {
            int count = ratings.size();
            for (int idx = 0; idx < count; idx++) {
                Rating rating = ratings.get(idx);
                int testFold = idx % this.partitionCount;
                for (int fold = 0; fold < this.partitionCount; fold++) {
                    TableWriter target = (fold == testFold) ? tableWriterArr2[fold] : tableWriterArr[fold];
                    writeRating(target, rating);
                }
            }
        } catch (IOException e) {
            throw new TaskExecutionException("Error writing to the train test files", e);
        }
    }

    /**
     * Writes one rating as a four-column row: user ID, item ID, preference
     * value ({@code "NaN"} when the rating has no preference) and timestamp.
     *
     * @param tableWriter the destination
     * @param rating      the rating to write
     * @throws IOException if the writer fails
     */
    protected void writeRating(TableWriter tableWriter, Rating rating) throws IOException {
        Preference preference = rating.getPreference();
        String value = (preference != null) ? Double.toString(preference.getValue()) : "NaN";
        tableWriter.writeRow(Lists.newArrayList(
                Long.toString(rating.getUserId()),
                Long.toString(rating.getItemId()),
                value,
                Long.toString(rating.getTimestamp())));
    }

    /**
     * Assigns every user to a partition. The user IDs are shuffled with the
     * project's random number generator and then dealt round-robin.
     *
     * @param userDAO the source of user IDs
     * @return a map from user ID to partition index in
     *         {@code [0, partitionCount)}
     */
    protected Long2IntMap splitUsers(UserDAO userDAO) {
        Long2IntMap userFolds = new Long2IntOpenHashMap();
        LongArrayList users = new LongArrayList(userDAO.getUserIds());
        LongLists.shuffle(users, getProject().getRandom());
        LongListIterator iter = users.listIterator();
        while (iter.hasNext()) {
            long user = iter.nextLong();
            // nextIndex() is read after advancing, so positions run 1..n here;
            // the modulo still spreads users evenly across the partitions.
            userFolds.put(user, iter.nextIndex() % this.partitionCount);
        }
        logger.info("Partitioned {} users", userFolds.size());
        return userFolds;
    }

    /**
     * Builds the train/test data set descriptors for the output files. This
     * only wraps the file locations; it does not check that the files exist.
     *
     * @return one {@link TTDataSet} per partition, named
     *         {@code <name>.<index>} and tagged with {@code DataSet} and
     *         {@code Partition} attributes
     */
    public List<TTDataSet> getTTFiles() {
        List<TTDataSet> dataSets = new ArrayList<TTDataSet>(this.partitionCount);
        File[] trainFiles = getFiles(getTrainPattern());
        File[] testFiles = getFiles(getTestPattern());
        for (int i = 0; i < this.partitionCount; i++) {
            DataSource testData = csvSource(testFiles[i]);
            DataSource trainData = csvSource(trainFiles[i]);
            dataSets.add(new GenericTTDataBuilder(getName() + "." + i)
                    .setTest(testData)
                    .setTrain(trainData)
                    .setAttribute("DataSet", getName())
                    .setAttribute("Partition", Integer.valueOf(i))
                    .m16build());
        }
        return dataSets;
    }

    /**
     * Wraps {@code file} in a CSV data source that uses the input source's
     * preference domain and this task's cache flag.
     */
    private DataSource csvSource(File file) {
        return new CSVDataSourceBuilder()
                .setDomain(this.source.getPreferenceDomain())
                .setCache(this.cacheOutput)
                .setFile(file)
                .m12build();
    }

    /**
     * @return a short description including the data source
     */
    @Override
    public String toString() {
        return String.format("{CXManager %s}", this.source);
    }
}
