package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/evaluation/EvaluationAggregator.class */
/**
 * Static helpers that aggregate metric values across several models, datasets,
 * metrics or evaluations, either into {@link DescriptiveStats} accumulators or
 * into the result of {@link Util#argmax}.
 * <p>
 * This class only holds static methods and cannot be instantiated.
 */
public final class EvaluationAggregator {
    // Utility class: no instances.
    private EvaluationAggregator() {
    }

    /**
     * Computes one metric for each model on a single dataset and collects the values.
     *
     * @param evaluationMetric the metric to compute
     * @param list the models to evaluate
     * @param dataset the dataset every model is evaluated on
     * @param <T> the output type
     * @param <C> the metric context type
     * @return a {@link DescriptiveStats} holding one value per model, added in list order
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T, C> evaluationMetric, List<? extends Model<T>> list, Dataset<T> dataset) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Model<T> model : list) {
            C context = evaluationMetric.createContext(model, dataset);
            summary.addValue(evaluationMetric.compute(context));
        }
        return summary;
    }

    /**
     * Runs the evaluator for each model on a single dataset and summarizes the
     * resulting evaluations per metric.
     *
     * @param evaluator the evaluator to apply
     * @param list the models to evaluate
     * @param dataset the dataset every model is evaluated on
     * @param <T> the output type
     * @param <R> the evaluation type
     * @return a map from metric id to the values that metric took across the models
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T, R> evaluator, List<? extends Model<T>> list, Dataset<T> dataset) {
        List<R> evaluations = list.stream()
                .map(model -> evaluator.evaluate(model, dataset))
                .collect(Collectors.toList());
        return summarize(evaluations);
    }

    /**
     * Computes one metric for a single model on each dataset and collects the values.
     *
     * @param evaluationMetric the metric to compute
     * @param model the model to evaluate
     * @param list the datasets to evaluate the model on
     * @param <T> the output type
     * @param <C> the metric context type
     * @return a {@link DescriptiveStats} holding one value per dataset, added in list order
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(EvaluationMetric<T, C> evaluationMetric, Model<T> model, List<? extends Dataset<T>> list) {
        DescriptiveStats summary = new DescriptiveStats();
        for (Dataset<T> dataset : list) {
            C context = evaluationMetric.createContext(model, dataset);
            summary.addValue(evaluationMetric.compute(context));
        }
        return summary;
    }

    /**
     * Computes each metric for a single model on a single dataset and collects the values.
     * The model's predictions are computed once and reused for every metric.
     *
     * @param list the metrics to compute
     * @param model the model to evaluate
     * @param dataset the dataset to predict on
     * @param <T> the output type
     * @param <C> the metric context type
     * @return a {@link DescriptiveStats} holding one value per metric, added in list order
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T, C>> list, Model<T> model, Dataset<T> dataset) {
        // Predict once up front so the model is not re-run for each metric.
        List<Prediction<T>> predictions = model.predict(dataset);
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T, C> metric : list) {
            C context = metric.createContext(model, predictions);
            summary.addValue(metric.compute(context));
        }
        return summary;
    }

    /**
     * Computes each metric for a single model over already-computed predictions
     * and collects the values.
     *
     * @param list the metrics to compute
     * @param model the model that is passed to each metric's context
     * @param list2 the predictions the metrics are computed over
     * @param <T> the output type
     * @param <C> the metric context type
     * @return a {@link DescriptiveStats} holding one value per metric, added in list order
     */
    public static <T extends Output<T>, C extends MetricContext<T>> DescriptiveStats summarize(List<? extends EvaluationMetric<T, C>> list, Model<T> model, List<Prediction<T>> list2) {
        DescriptiveStats summary = new DescriptiveStats();
        for (EvaluationMetric<T, C> metric : list) {
            C context = metric.createContext(model, list2);
            summary.addValue(metric.compute(context));
        }
        return summary;
    }

    /**
     * Runs the evaluator for a single model on each dataset and summarizes the
     * resulting evaluations per metric.
     *
     * @param evaluator the evaluator to apply
     * @param model the model to evaluate
     * @param list the datasets to evaluate the model on
     * @param <T> the output type
     * @param <R> the evaluation type
     * @return a map from metric id to the values that metric took across the datasets
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(Evaluator<T, R> evaluator, Model<T> model, List<? extends Dataset<T>> list) {
        List<R> evaluations = list.stream()
                .map(dataset -> evaluator.evaluate(model, dataset))
                .collect(Collectors.toList());
        return summarize(evaluations);
    }

    /**
     * Summarizes a list of evaluations per metric, using the entries of each
     * evaluation's {@code asMap()} view.
     *
     * @param list the evaluations to summarize
     * @param <T> the output type
     * @param <R> the evaluation type
     * @return a new mutable map from metric id to the values seen for that metric;
     *         empty if the list is empty
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarize(List<R> list) {
        Map<MetricID<T>, DescriptiveStats> results = new HashMap<>();
        for (R evaluation : list) {
            accumulate(results, evaluation);
        }
        return results;
    }

    /**
     * Summarizes the evaluation half of each (evaluation, model) pair per metric.
     * The model half of each pair is not used.
     *
     * @param list the (evaluation, model) pairs, presumably one per cross-validation fold
     * @param <T> the output type
     * @param <R> the evaluation type
     * @return a new mutable map from metric id to the values seen for that metric;
     *         empty if the list is empty
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Map<MetricID<T>, DescriptiveStats> summarizeCrossValidation(List<Pair<R, Model<T>>> list) {
        Map<MetricID<T>, DescriptiveStats> results = new HashMap<>();
        for (Pair<R, Model<T>> pair : list) {
            accumulate(results, pair.getA());
        }
        return results;
    }

    /**
     * Extracts one value from each evaluation and collects the values.
     *
     * @param list the evaluations to summarize
     * @param toDoubleFunction extracts the value of interest from an evaluation
     * @param <T> the output type
     * @param <R> the evaluation type
     * @return a {@link DescriptiveStats} holding one value per evaluation, added in list order
     */
    public static <T extends Output<T>, R extends Evaluation<T>> DescriptiveStats summarize(List<R> list, ToDoubleFunction<R> toDoubleFunction) {
        DescriptiveStats summary = new DescriptiveStats();
        for (R evaluation : list) {
            summary.addValue(toDoubleFunction.applyAsDouble(evaluation));
        }
        return summary;
    }

    /**
     * Computes one metric for each model on a single dataset and passes the
     * values, in list order, to {@link Util#argmax}.
     *
     * @param evaluationMetric the metric to compute
     * @param list the models to evaluate
     * @param dataset the dataset every model is evaluated on
     * @param <T> the output type
     * @param <C> the metric context type
     * @return the result of {@code Util.argmax} over the metric values
     *         (presumably the index of the best model and its value)
     */
    public static <T extends Output<T>, C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T, C> evaluationMetric, List<? extends Model<T>> list, Dataset<T> dataset) {
        List<Double> values = list.stream()
                .map(model -> evaluationMetric.compute(evaluationMetric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Computes one metric for a single model on each dataset and passes the
     * values, in list order, to {@link Util#argmax}.
     *
     * @param evaluationMetric the metric to compute
     * @param model the model to evaluate
     * @param list the datasets to evaluate the model on
     * @param <T> the output type
     * @param <C> the metric context type
     * @return the result of {@code Util.argmax} over the metric values
     *         (presumably the index of the best dataset and its value)
     */
    public static <T extends Output<T>, C extends MetricContext<T>> Pair<Integer, Double> argmax(EvaluationMetric<T, C> evaluationMetric, Model<T> model, List<? extends Dataset<T>> list) {
        List<Double> values = list.stream()
                .map(dataset -> evaluationMetric.compute(evaluationMetric.createContext(model, dataset)))
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Extracts one value from each evaluation and passes the values, in list
     * order, to {@link Util#argmax}.
     *
     * @param list the evaluations to compare
     * @param function extracts the value of interest from an evaluation
     * @param <T> the output type
     * @param <R> the evaluation type
     * @return the result of {@code Util.argmax} over the extracted values
     *         (presumably the index of the best evaluation and its value)
     */
    public static <T extends Output<T>, R extends Evaluation<T>> Pair<Integer, Double> argmax(List<R> list, Function<R, Double> function) {
        List<Double> values = list.stream()
                .map(function)
                .collect(Collectors.toList());
        return Util.argmax(values);
    }

    /**
     * Adds every metric value of one evaluation to the accumulator for its metric id,
     * creating the accumulator the first time a metric id is seen.
     */
    private static <T extends Output<T>> void accumulate(Map<MetricID<T>, DescriptiveStats> results, Evaluation<T> evaluation) {
        for (Map.Entry<MetricID<T>, Double> entry : evaluation.asMap().entrySet()) {
            // computeIfAbsent only allocates a DescriptiveStats for a new key.
            results.computeIfAbsent(entry.getKey(), key -> new DescriptiveStats())
                    .addValue(entry.getValue());
        }
    }
}
