package org.jpmml.evaluator.testing;

import com.google.common.base.Equivalence;
import com.google.common.collect.MapDifference;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorFunction;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.HasGroupFields;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.Table;
import org.jpmml.evaluator.TableCollector;
import org.jpmml.evaluator.TargetField;

/* loaded from: input_file:org/jpmml/evaluator/testing/BatchUtil.class */
/**
 * Utility methods that evaluate a {@link Batch} and report the {@link Conflict}s between the
 * batch's expected output and the output actually produced by its evaluator.
 */
public class BatchUtil {

    private BatchUtil() {
    }

    /**
     * Evaluates every input row of a batch and compares the results with the expected output.
     *
     * <p>If the batch's evaluator implements {@link HasGroupFields}, the input rows are grouped
     * before evaluation. Only the result columns selected by the batch's column filter are
     * compared, using the batch's value equivalence.
     *
     * @param batch the batch to evaluate
     * @return one conflict per row whose evaluation failed or whose selected results differ from
     *     the expected ones, identified by the row index; an empty list when the input has no rows
     * @throws IllegalArgumentException if the (grouped) input, the expected output and the actual
     *     output do not all have the same number of rows
     * @throws Exception propagated from the batch accessors or from row grouping
     */
    public static List<Conflict> evaluate(Batch batch) throws Exception {
        Evaluator evaluator = batch.getEvaluator();

        Table input = batch.getInput();
        if (input.getNumberOfRows() == 0) {
            return Collections.emptyList();
        }

        if (evaluator instanceof HasGroupFields) {
            input = EvaluatorUtil.groupRows((HasGroupFields) evaluator, input);
        }

        Table expectedOutput = batch.getOutput();
        if (input.getNumberOfRows() != expectedOutput.getNumberOfRows()) {
            throw new IllegalArgumentException("Expected the same number of data rows, got " + input.getNumberOfRows() + " input data rows and " + expectedOutput.getNumberOfRows() + " expected output data rows");
        }

        Set<String> resultColumns = collectResultColumns(evaluator, batch.getColumnFilter());
        Equivalence<Object> equivalence = batch.getEquivalence();

        Table actualOutput = (Table) input.stream().map(new EvaluatorFunction(evaluator)).collect(new TableCollector());
        if (expectedOutput.getNumberOfRows() != actualOutput.getNumberOfRows()) {
            throw new IllegalArgumentException("Expected the same number of output data rows, got " + expectedOutput.getNumberOfRows() + " expected output data rows and " + actualOutput.getNumberOfRows() + " actual output data rows");
        }

        List<Conflict> conflicts = new ArrayList<>();

        Table.Row inputRow = input.createReaderRow(0);
        Table.Row expectedRow = expectedOutput.createReaderRow(0);
        Table.Row actualRow = actualOutput.createReaderRow(0);

        for (int i = 0, max = expectedOutput.getNumberOfRows(); i < max; i++) {
            // A row whose evaluation failed carries the failure instead of results
            Exception exception = actualRow.getException();
            if (exception != null) {
                conflicts.add(new Conflict(i, snapshot(inputRow), exception));
            } else {
                MapDifference<String, Object> difference = filteredDifference(expectedRow, actualRow, resultColumns, equivalence);
                if (!difference.areEqual()) {
                    conflicts.add(new Conflict(i, snapshot(inputRow), difference));
                }
            }

            inputRow.advance();
            expectedRow.advance();
            actualRow.advance();
        }

        return conflicts;
    }

    /**
     * Evaluates a batch that consists of exactly one (grouped) input row. The evaluator's result
     * for that row is converted by {@code function} into a table of actual output rows, which is
     * compared row by row with the expected output.
     *
     * <p>If the batch's evaluator implements {@link HasGroupFields}, the input rows are grouped
     * first. Only the result columns selected by the batch's column filter are compared, using the
     * batch's value equivalence.
     *
     * @param batch the batch to evaluate
     * @param function converts the evaluator's result map into the table of actual output rows
     * @return a single conflict with id {@code 0} if evaluation or {@code function} failed;
     *     otherwise one conflict per mismatching output row, identified as {@code "0/<row index>"}
     * @throws IllegalArgumentException if the (grouped) input does not have exactly one row, or
     *     if the expected and actual outputs have different numbers of rows
     * @throws Exception propagated from the batch accessors or from row grouping
     */
    public static List<Conflict> evaluateSingleton(Batch batch, Function<Map<String, ?>, Table> function) throws Exception {
        Evaluator evaluator = batch.getEvaluator();

        Table input = batch.getInput();
        if (evaluator instanceof HasGroupFields) {
            input = EvaluatorUtil.groupRows((HasGroupFields) evaluator, input);
        }

        if (input.getNumberOfRows() != 1) {
            throw new IllegalArgumentException("Expected exactly one input data row, got " + input.getNumberOfRows() + " input data rows");
        }

        Table expectedOutput = batch.getOutput();
        Set<String> resultColumns = collectResultColumns(evaluator, batch.getColumnFilter());
        Equivalence<Object> equivalence = batch.getEquivalence();

        Table.Row inputRow = input.createReaderRow(0);

        // Only evaluation (and result conversion) failures are reported as a conflict; the
        // row-count check below must propagate, exactly as it does in evaluate(Batch)
        Table actualOutput;
        try {
            actualOutput = function.apply(evaluator.evaluate(inputRow));
        } catch (Exception e) {
            return Collections.singletonList(new Conflict(0, inputRow, e));
        }

        if (expectedOutput.getNumberOfRows() != actualOutput.getNumberOfRows()) {
            throw new IllegalArgumentException("Expected the same number of output data rows, got " + expectedOutput.getNumberOfRows() + " expected output data rows and " + actualOutput.getNumberOfRows() + " actual output data rows");
        }

        List<Conflict> conflicts = new ArrayList<>();

        Table.Row expectedRow = expectedOutput.createReaderRow(0);
        Table.Row actualRow = actualOutput.createReaderRow(0);

        for (int i = 0, max = expectedOutput.getNumberOfRows(); i < max; i++) {
            MapDifference<String, Object> difference = filteredDifference(expectedRow, actualRow, resultColumns, equivalence);
            if (!difference.areEqual()) {
                conflicts.add(new Conflict("0/" + i, inputRow, difference));
            }

            expectedRow.advance();
            actualRow.advance();
        }

        return conflicts;
    }

    /**
     * Copies the current contents of a reader row.
     *
     * <p>The reader row is advanced in place after each loop iteration, so storing it directly in
     * a {@link Conflict} would presumably let later {@code advance()} calls change what an
     * already-recorded conflict reports.
     */
    private static Map<String, Object> snapshot(Table.Row row) {
        return new LinkedHashMap<>(row);
    }

    /**
     * Computes the difference between two maps, restricted to the given keys.
     *
     * @param expected the expected entries
     * @param actual the actual entries
     * @param keys the keys to compare; entries with other keys are ignored
     * @param equivalence decides whether two values are considered equal
     * @return the difference between the filtered maps
     */
    private static <K> MapDifference<K, Object> filteredDifference(Map<K, ?> expected, Map<K, ?> actual, Set<K> keys, Equivalence<Object> equivalence) {
        // The bound method references fail fast with a NullPointerException if keys is null
        return Maps.difference(Maps.filterKeys(expected, keys::contains), Maps.filterKeys(actual, keys::contains), equivalence);
    }

    /**
     * Collects, in declaration order, the names of the result fields that should be compared.
     * These are the non-synthetic target fields followed by the output fields, each accepted by
     * {@code predicate}.
     *
     * @param evaluator the evaluator whose result fields are inspected
     * @param predicate the column filter
     * @return an insertion-ordered set of result column names
     */
    private static Set<String> collectResultColumns(Evaluator evaluator, Predicate<ResultField> predicate) {
        Set<String> resultColumns = new LinkedHashSet<>();

        for (TargetField targetField : evaluator.getTargetFields()) {
            if (!targetField.isSynthetic() && predicate.test(targetField)) {
                resultColumns.add(targetField.getName());
            }
        }

        for (OutputField outputField : evaluator.getOutputFields()) {
            if (predicate.test(outputField)) {
                resultColumns.add(outputField.getName());
            }
        }

        return resultColumns;
    }
}
