package org.grouplens.lenskit.eval.data.crossfold;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.dtree.DataNode;
import org.grouplens.lenskit.dtree.Trees;
import org.grouplens.lenskit.eval.EvaluatorConfigurationException;
import org.grouplens.lenskit.eval.data.DataSource;
import org.grouplens.lenskit.eval.data.DataSourceProvider;
import org.grouplens.lenskit.eval.data.traintest.TTDataProvider;
import org.grouplens.lenskit.eval.data.traintest.TTDataSet;
import org.grouplens.lenskit.util.spi.ConfigAlias;
import org.grouplens.lenskit.util.spi.ServiceFinder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Train-test data provider that builds cross-validation folds from one or more data sources.
 *
 * <p>Registered under the configuration alias {@code crossfold}. The configuration node passed
 * to {@link #configure(DataNode)} may contain these children (defaults in parentheses):
 * <ul>
 *   <li>{@code folds} — number of folds generated for each source; must be at least 1 (5).</li>
 *   <li>{@code mode} — the rating {@link Order} handed to the {@link Holdout}: {@code random}
 *       selects {@link RandomOrder}, {@code timestamp} selects {@link TimestampOrder};
 *       matched case-insensitively (random).</li>
 *   <li>{@code holdout} — the {@link PartitionAlgorithm}: a plain integer such as {@code 10}
 *       selects a {@link CountPartition} with that count, and a percentage such as {@code 20%}
 *       selects a {@link FractionPartition} with that fraction (10).</li>
 *   <li>{@code database} — if true, each fold is a {@link DBCrossfoldTTDataSet}; otherwise a
 *       {@link MemoryCrossfoldTTDataSet} (false).</li>
 *   <li>{@code sources} — a node whose children are each configured by the
 *       {@link DataSourceProvider} registered under the child's name.</li>
 * </ul>
 */
@ConfigAlias("crossfold")
public class CrossfoldTTDataProvider implements TTDataProvider {
    private static final Logger logger = LoggerFactory.getLogger(CrossfoldTTDataProvider.class);

    /** Matches a holdout specification that is an absolute count, e.g. {@code 10}. */
    static final Pattern intPat = Pattern.compile("^\\d+$");

    /**
     * Matches a holdout specification that is a percentage, e.g. {@code 20%}; group 1 is the
     * number. Anchored at both ends so trailing garbage such as {@code 20%x} is rejected.
     */
    static final Pattern pctPat = Pattern.compile("^(\\d+)%$");

    /**
     * Builds the crossfold train-test data sets described by {@code dataNode}.
     *
     * @param dataNode the {@code crossfold} configuration node.
     * @return one data set per (source, fold) pair, grouped by source in configuration order,
     *         with folds numbered from 1 within each source.
     * @throws EvaluatorConfigurationException if the fold count is below 1, the mode or holdout
     *         specification is invalid, a source type is unknown, or no sources are configured.
     */
    @Override
    public List<TTDataSet> configure(DataNode dataNode) throws EvaluatorConfigurationException {
        logger.debug("Configuring crossfolder from {}", dataNode);

        int folds = Trees.childValueInt(dataNode, "folds", 5);
        if (folds < 1) {
            // Zero or negative folds would otherwise silently produce no data sets at all.
            throw new EvaluatorConfigurationException("Invalid fold count " + folds);
        }

        // A single Holdout configuration is shared by every source's manager.
        Holdout holdout = new Holdout();
        holdout.setOrder(parseOrder(dataNode));
        holdout.setPartition(parsePartition(dataNode));

        boolean useDatabase = Trees.childValueBool(dataNode, "database", false);

        List<DataSource> sources = configureSources(dataNode);
        if (sources.isEmpty()) {
            throw new EvaluatorConfigurationException("No crossfold inputs configured");
        }

        List<TTDataSet> dataSets = new ArrayList<TTDataSet>();
        for (DataSource source : sources) {
            // One manager per source; all folds of that source are created from it.
            CrossfoldManager manager = new CrossfoldManager(source, folds, holdout);
            for (int fold = 1; fold <= folds; fold++) {
                if (useDatabase) {
                    dataSets.add(new DBCrossfoldTTDataSet(manager, fold));
                } else {
                    dataSets.add(new MemoryCrossfoldTTDataSet(manager, fold));
                }
            }
        }
        return dataSets;
    }

    /**
     * Configures every data source listed under the {@code sources} child of {@code dataNode}.
     *
     * @param dataNode the crossfold configuration node.
     * @return the configured sources in configuration order, or an empty list when there is no
     *         {@code sources} child.
     * @throws EvaluatorConfigurationException if no {@link DataSourceProvider} is registered for
     *         a child's name, or if a provider rejects its configuration.
     */
    private List<DataSource> configureSources(DataNode dataNode) throws EvaluatorConfigurationException {
        logger.debug("Looking for sources in {}", dataNode);
        DataNode sourcesNode = Trees.child(dataNode, "sources");
        if (sourcesNode == null) {
            return Collections.emptyList();
        }

        List<DataSource> sources = new ArrayList<DataSource>();
        ServiceFinder finder = ServiceFinder.get(DataSourceProvider.class);
        for (DataNode sourceNode : sourcesNode.getChildren()) {
            String providerName = sourceNode.getName();
            logger.debug("Getting data provider {}", providerName);
            // Each child's element name selects the provider that understands its configuration.
            DataSourceProvider provider = (DataSourceProvider) finder.findProvider(providerName);
            if (provider == null) {
                throw new EvaluatorConfigurationException("Unknown data source " + providerName);
            }
            sources.addAll(provider.configure(sourceNode));
        }
        return sources;
    }

    /**
     * Parses the {@code mode} child (default {@code random}) into a rating order.
     *
     * @param dataNode the crossfold configuration node.
     * @return a {@link RandomOrder} for {@code random} or a {@link TimestampOrder} for
     *         {@code timestamp}, compared case-insensitively.
     * @throws EvaluatorConfigurationException if the mode is not one of the above.
     */
    private Order<Rating> parseOrder(DataNode dataNode) throws EvaluatorConfigurationException {
        String mode = Trees.childValue(dataNode, "mode", "random");
        if (mode.equalsIgnoreCase("random")) {
            return new RandomOrder();
        }
        if (mode.equalsIgnoreCase("timestamp")) {
            return new TimestampOrder();
        }
        throw new EvaluatorConfigurationException("Unknown holdout mode " + mode);
    }

    /**
     * Parses the {@code holdout} child (default {@code "10"}) into a partition algorithm.
     *
     * <p>A bare integer {@code N} yields {@code new CountPartition(N)}. A percentage {@code P%}
     * yields {@code new FractionPartition(P / 100.0)}.
     *
     * @param dataNode the crossfold configuration node.
     * @return the partition algorithm described by the specification.
     * @throws EvaluatorConfigurationException if the specification matches neither form, or its
     *         number does not fit in an {@code int}.
     */
    private PartitionAlgorithm<Rating> parsePartition(DataNode dataNode) throws EvaluatorConfigurationException {
        String spec = Trees.childValue(dataNode, "holdout", "10");

        Matcher countMatcher = intPat.matcher(spec);
        if (countMatcher.find()) {
            int count = parseHoldoutNumber(countMatcher.group(0), spec);
            logger.info("Holding out {} ratings", count);
            return new CountPartition(count);
        }

        // Keep the matcher so the captured number can be read back out of group 1.
        Matcher pctMatcher = pctPat.matcher(spec);
        if (pctMatcher.find()) {
            // NOTE(review): percentages above 100 are passed through unchanged — confirm
            // FractionPartition's handling of fractions greater than 1.
            double fraction = parseHoldoutNumber(pctMatcher.group(1), spec) / 100.0d;
            logger.info("Holding out {} of the ratings", fraction);
            return new FractionPartition(fraction);
        }

        throw new EvaluatorConfigurationException("Invalid holdout specification " + spec);
    }

    /**
     * Converts the digit run captured from a holdout specification into an {@code int}.
     *
     * @param digits the captured digits (ASCII only, as matched by {@code \d}).
     * @param spec   the full specification, used in the error message.
     * @return the parsed value.
     * @throws EvaluatorConfigurationException if the value overflows an {@code int}.
     */
    private static int parseHoldoutNumber(String digits, String spec) throws EvaluatorConfigurationException {
        try {
            return Integer.parseInt(digits);
        } catch (NumberFormatException e) {
            // Report overflow as a configuration problem instead of leaking an unchecked exception.
            throw new EvaluatorConfigurationException("Invalid holdout specification " + spec);
        }
    }
}
