package org.jpmml.converter.mining;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.dmg.pmml.Field;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.UnsupportedAttributeException;

/* loaded from: input_file:org/jpmml/converter/mining/MiningModelUtil.class */
/**
 * Static helpers for assembling {@link MiningModel} ensembles (model chains and multi-model chains) and for
 * post-processing the output fields of their member models.
 */
public class MiningModelUtil {

    private MiningModelUtil() {
    }

    /**
     * Chains {@code model} with a regression step. The regression step takes the last output field of
     * {@code model} as its only input, with coefficient 1 and no intercept, and applies
     * {@code normalizationMethod}.
     *
     * @param model the first model of the chain; must declare at least one output field
     * @param normalizationMethod normalization method handed to {@link RegressionModelUtil#createRegression}
     * @param schema schema used to build the prediction feature and the regression step
     * @return a {@code MODEL_CHAIN} mining model with {@code RETURN_MISSING} missing-prediction treatment.
     *         It carries {@code model}'s math context, as simplified by {@link ModelUtil#simplifyMathContext}.
     * @throws InvalidElementException if {@code model} declares no output fields
     */
    public static MiningModel createRegression(Model model, RegressionModel.NormalizationMethod normalizationMethod, Schema schema) {
        ContinuousFeature prediction = getPrediction(model, schema);
        MathContext mathContext = model.getMathContext();

        Model regressionModel = RegressionModelUtil.createRegression(mathContext, Collections.singletonList(prediction), Collections.singletonList(1d), null, normalizationMethod, schema);

        return createModelChain(Arrays.asList(model, regressionModel), Segmentation.MissingPredictionTreatment.RETURN_MISSING)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext));
    }

    /**
     * Chains {@code model} with a binary logistic classification step. The classification step takes the last
     * output field of {@code model} as its only input.
     *
     * @param model the first model of the chain; must declare at least one output field
     * @param coefficient coefficient applied to the prediction of {@code model}
     * @param intercept intercept of the logistic step
     * @param normalizationMethod normalization method handed to
     *        {@link RegressionModelUtil#createBinaryLogisticClassification}
     * @param hasProbabilityDistribution flag handed through to
     *        {@link RegressionModelUtil#createBinaryLogisticClassification}
     * @param schema schema used to build the prediction feature and the classification step
     * @return a {@code MODEL_CHAIN} mining model with {@code RETURN_MISSING} missing-prediction treatment.
     *         It carries {@code model}'s math context, as simplified by {@link ModelUtil#simplifyMathContext}.
     * @throws InvalidElementException if {@code model} declares no output fields
     */
    public static MiningModel createBinaryLogisticClassification(Model model, double coefficient, double intercept, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        ContinuousFeature prediction = getPrediction(model, schema);
        MathContext mathContext = model.getMathContext();

        Model classificationModel = RegressionModelUtil.createBinaryLogisticClassification(mathContext, Collections.singletonList(prediction), Collections.singletonList(Double.valueOf(coefficient)), Double.valueOf(intercept), normalizationMethod, hasProbabilityDistribution, schema);

        return createModelChain(Arrays.asList(model, classificationModel), Segmentation.MissingPredictionTreatment.RETURN_MISSING)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext));
    }

    /**
     * Builds a multi-class classification model chain.
     *
     * <p>The i-th model of {@code models} supplies the score for the i-th category of the schema's categorical
     * label. A final {@code CLASSIFICATION} regression model then has one regression table per category. Each
     * table uses the last output field of the corresponding model, with coefficient 1 and no intercept, and
     * the scores are normalized with {@code normalizationMethod}.
     *
     * @param models one model per target category, in category order
     * @param normalizationMethod {@code null}, {@code NONE}, {@code SIMPLEMAX} or {@code SOFTMAX}
     * @param hasProbabilityDistribution if {@code true}, the final regression model gets the output created by
     *        {@link ModelUtil#createProbabilityOutput}
     * @param schema schema whose label must be a {@link CategoricalLabel}
     * @return a {@code MODEL_CHAIN} mining model over {@code models} followed by the regression model
     * @throws IllegalArgumentException if the normalization method is unsupported; if there are too few
     *         categories for it (3 for {@code null}/{@code NONE}, 2 for {@code SIMPLEMAX}/{@code SOFTMAX}); or if
     *         the models' math contexts differ, where an absent math context counts as {@code DOUBLE}
     * @throws InvalidElementException if a model declares no output fields
     */
    public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        // NOTE(review): presumably rejects models.size() != categoricalLabel.size(); exception type not visible here
        SchemaUtil.checkSize(models.size(), categoricalLabel);

        checkNormalizationMethod(normalizationMethod, categoricalLabel.size());

        MathContext mathContext = null;

        List<ContinuousFeature> predictions = new ArrayList<>(categoricalLabel.size());
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Model model = models.get(i);

            // An absent math context is treated as double precision, so that it compares equal to an explicit DOUBLE
            MathContext modelMathContext = model.getMathContext();
            if (modelMathContext == null) {
                modelMathContext = MathContext.DOUBLE;
            }

            if (mathContext == null) {
                mathContext = modelMathContext;
            } else if (!Objects.equals(mathContext, modelMathContext)) {
                throw new IllegalArgumentException("Expected math context " + mathContext + ", got " + modelMathContext + " (model " + (i + 1) + ")");
            }

            predictions.add(getPrediction(model, schema));
        }

        // Effectively-final copy for capture by the lambda below
        MathContext tableMathContext = mathContext;

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel),
                IntStream.range(0, predictions.size())
                    .mapToObj(i -> RegressionModelUtil.createRegressionTable(tableMathContext, Collections.singletonList(predictions.get(i)), Collections.singletonList(1d), null)
                        .setTargetCategory(categoricalLabel.getValue(i)))
                    .collect(Collectors.toList()))
            .setNormalizationMethod(normalizationMethod)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext))
            .setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);

        List<Model> chain = new ArrayList<>(models);
        chain.add(regressionModel);

        return createModelChain(chain, Segmentation.MissingPredictionTreatment.RETURN_MISSING)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext));
    }

    /**
     * Validates that {@code size} target categories are enough for {@code normalizationMethod}.
     */
    private static void checkNormalizationMethod(RegressionModel.NormalizationMethod normalizationMethod, int size) {
        int minSize;

        if (normalizationMethod == null) {
            minSize = 3;
        } else {
            switch (normalizationMethod) {
                case NONE:
                    minSize = 3;
                    break;
                case SIMPLEMAX:
                case SOFTMAX:
                    minSize = 2;
                    break;
                default:
                    throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported");
            }
        }

        if (size < minSize) {
            throw new IllegalArgumentException("Expected " + minSize + " or more target categories, got " + size);
        }
    }

    /**
     * Wraps {@code models} in a {@code MODEL_CHAIN} mining model.
     *
     * @param models the chain members, in evaluation order; must not be empty
     * @param missingPredictionTreatment missing-prediction treatment of the segmentation
     * @return a mining model whose mining function is that of the last model in {@code models}
     * @throws IllegalArgumentException if {@code models} is empty
     */
    public static MiningModel createModelChain(List<? extends Model> models, Segmentation.MissingPredictionTreatment missingPredictionTreatment) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more models");
        }

        Model finalModel = Iterables.getLast(models);

        return new MiningModel(finalModel.requireMiningFunction(), createMiningSchema(models))
            .setSegmentation(createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, missingPredictionTreatment, models));
    }

    /**
     * Wraps {@code models} in a {@code MULTI_MODEL_CHAIN} mining model.
     *
     * @param models the chain members, in evaluation order; must not be empty
     * @param missingPredictionTreatment missing-prediction treatment of the segmentation
     * @return a mining model whose mining function is the one shared by all models, or {@code MIXED} if they
     *         disagree
     * @throws IllegalArgumentException if {@code models} is empty
     */
    public static MiningModel createMultiModelChain(List<? extends Model> models, Segmentation.MissingPredictionTreatment missingPredictionTreatment) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more models");
        }

        MiningFunction miningFunction = null;

        // Every model is visited, so requireMiningFunction() validates all of them
        for (Model model : models) {
            MiningFunction modelMiningFunction = model.requireMiningFunction();

            if (miningFunction == null) {
                miningFunction = modelMiningFunction;
            } else if (miningFunction != modelMiningFunction) {
                miningFunction = MiningFunction.MIXED;
            }
        }

        return new MiningModel(miningFunction, createMiningSchema(models))
            .setSegmentation(createSegmentation(Segmentation.MultipleModelMethod.MULTI_MODEL_CHAIN, missingPredictionTreatment, models));
    }

    /**
     * Builds a mining schema for an ensemble of {@code models}.
     *
     * <p>The schema has one {@code TARGET} mining field for each distinct field name that any model declares
     * with usage type {@code PREDICTED} or {@code TARGET}, in first-seen order.
     *
     * @param models the ensemble members
     * @return a new mining schema
     */
    public static MiningSchema createMiningSchema(List<? extends Model> models) {
        MiningSchema miningSchema = new MiningSchema();

        models.stream()
            .map(Model::requireMiningSchema)
            .map(MiningSchema::getMiningFields)
            .flatMap(List::stream)
            .filter(MiningModelUtil::isTargetField)
            .map(MiningField::getName)
            .distinct()
            .map(name -> ModelUtil.createMiningField(name, MiningField.UsageType.TARGET))
            .forEach(miningField -> miningSchema.addMiningFields(miningField));

        return miningSchema;
    }

    /**
     * Returns {@code true} if {@code miningField} has usage type {@code PREDICTED} or {@code TARGET}.
     */
    private static boolean isTargetField(MiningField miningField) {
        switch (miningField.getUsageType()) {
            case PREDICTED:
            case TARGET:
                return true;
            default:
                return false;
        }
    }

    /**
     * Equivalent to {@link #createSegmentation(Segmentation.MultipleModelMethod, Segmentation.MissingPredictionTreatment, List, List)}
     * without segment weights.
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, List<? extends Model> models) {
        return createSegmentation(multipleModelMethod, missingPredictionTreatment, models, null);
    }

    /**
     * Builds a segmentation with one {@code True}-predicated segment per model.
     *
     * <p>Segment ids are {@code "1"}, {@code "2"}, and so on.
     *
     * @param multipleModelMethod multiple-model method of the segmentation
     * @param missingPredictionTreatment missing-prediction treatment of the segmentation
     * @param models the segment models, in order
     * @param weights optional per-segment weights (may be {@code null}); {@code null} entries and weights equal
     *        to 1 are not written to the segment
     * @return a new segmentation
     * @throws IllegalArgumentException if {@code weights} is non-null and its size differs from {@code models}
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, List<? extends Model> models, List<? extends Number> weights) {
        if (weights != null && models.size() != weights.size()) {
            throw new IllegalArgumentException("Expected " + models.size() + " weights, got " + weights.size());
        }

        List<Segment> segments = new ArrayList<>(models.size());

        for (int i = 0; i < models.size(); i++) {
            Segment segment = new Segment(True.INSTANCE, models.get(i))
                .setId(String.valueOf(i + 1));

            Number weight = (weights != null) ? weights.get(i) : null;

            // A weight of 1 is the PMML default for Segment@weight, so it is left implicit
            if (weight != null && !ValueUtil.isOne(weight)) {
                segment.setWeight(weight);
            }

            segments.add(segment);
        }

        return new Segmentation(multipleModelMethod, segments)
            .setMissingPredictionTreatment(missingPredictionTreatment);
    }

    /**
     * Returns the model that produces the final prediction of {@code model}.
     *
     * <p>For a {@link MiningModel}, this delegates to {@link #getFinalModel(MiningModel)}; any other model is
     * returned as is.
     */
    public static Model getFinalModel(Model model) {
        return model instanceof MiningModel ? getFinalModel((MiningModel) model) : model;
    }

    /**
     * Returns the model that produces the final prediction of {@code miningModel}.
     *
     * <p>For a {@code MODEL_CHAIN} whose segments all carry {@code True} predicates, the method recurses into
     * the last segment's model. A {@code SELECT_ALL} segmentation is rejected. For any other multiple-model
     * method, {@code miningModel} itself is returned.
     *
     * @throws UnsupportedAttributeException if the multiple-model method is {@code SELECT_ALL}
     */
    public static Model getFinalModel(MiningModel miningModel) {
        Segmentation segmentation = miningModel.requireSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.requireMultipleModelMethod();

        switch (multipleModelMethod) {
            case SELECT_ALL:
                throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
            case MODEL_CHAIN:
                if (isChain(segmentation)) {
                    List<Segment> segments = segmentation.requireSegments();

                    Segment finalSegment = segments.get(segments.size() - 1);
                    finalSegment.requirePredicate(True.class);

                    return getFinalModel(finalSegment.requireModel());
                }
                break;
            default:
                break;
        }

        return miningModel;
    }

    /**
     * Returns {@code true} if every segment of {@code segmentation} has a {@code True} predicate.
     */
    public static boolean isChain(Segmentation segmentation) {
        for (Segment segment : segmentation.requireSegments()) {
            if (!(segment.requirePredicate() instanceof True)) {
                return false;
            }
        }

        return true;
    }

    /**
     * Moves common output fields from the segment models up to the output of {@code miningModel}.
     *
     * <p>A field is common when it is a {@code PROBABILITY} or {@code AFFINITY} output field and the final model
     * of every segment declares it identically, as judged by {@link ReflectionUtil#equals}. Common fields are
     * removed from the segment models and appended to the output of {@code miningModel}, in the order the first
     * segment declares them. A segment model left with no output fields has its output cleared. If no fields
     * are common, nothing is changed.
     *
     * @param miningModel a mining model with a segmentation
     */
    public static void optimizeOutputFields(MiningModel miningModel) {
        Segmentation segmentation = miningModel.requireSegmentation();

        Map<String, OutputField> commonOutputFields = collectCommonOutputFields(segmentation);
        if (commonOutputFields.isEmpty()) {
            return;
        }

        Output output = ModelUtil.ensureOutput(miningModel);

        removeCommonOutputFields(segmentation, commonOutputFields.keySet());

        output.getOutputFields().addAll(commonOutputFields.values());
    }

    /**
     * Collects the output fields, keyed by name, that are common to the final models of all segments.
     *
     * <p>Candidates are the first segment's {@code PROBABILITY}/{@code AFFINITY} fields. Each later segment must
     * declare an equal field of the same name. The returned map preserves the first segment's declaration order.
     */
    private static Map<String, OutputField> collectCommonOutputFields(Segmentation segmentation) {
        Map<String, OutputField> result = null;

        for (Segment segment : segmentation.requireSegments()) {
            Output output = getFinalModel(segment.requireModel()).getOutput();

            // A segment without outputs means nothing can be common
            if (output == null || !output.hasOutputFields()) {
                return Collections.emptyMap();
            }

            List<OutputField> outputFields = output.getOutputFields();

            if (result == null) {
                // Insertion-ordered, so hoisted fields keep the first segment's declaration order
                result = new LinkedHashMap<>();

                for (OutputField outputField : outputFields) {
                    if (!isHoistable(outputField)) {
                        continue;
                    }

                    String name = outputField.requireName();
                    if (result.putIfAbsent(name, outputField) != null) {
                        throw new IllegalStateException("Duplicate output field " + name);
                    }
                }
            } else {
                Set<String> names = new LinkedHashSet<>();

                for (OutputField outputField : outputFields) {
                    String name = outputField.requireName();
                    names.add(name);

                    OutputField candidate = result.get(name);
                    if (candidate != null && !ReflectionUtil.equals(outputField, candidate)) {
                        result.remove(name);
                    }
                }

                // Drop candidates that this segment does not declare at all
                result.keySet().retainAll(names);
            }

            if (result.isEmpty()) {
                break;
            }
        }

        // Guard against a segment list that yields no iterations (whether requireSegments() permits that is not visible here)
        return (result != null) ? result : Collections.emptyMap();
    }

    /**
     * Returns {@code true} if {@code outputField} is a candidate for hoisting ({@code PROBABILITY} or {@code AFFINITY}).
     */
    private static boolean isHoistable(OutputField outputField) {
        switch (outputField.getResultFeature()) {
            case PROBABILITY:
            case AFFINITY:
                return true;
            default:
                return false;
        }
    }

    /**
     * Removes the output fields named in {@code names} from the final model of every segment.
     *
     * <p>A final model left with no output fields has its output set to {@code null}.
     */
    private static void removeCommonOutputFields(Segmentation segmentation, Set<String> names) {
        for (Segment segment : segmentation.requireSegments()) {
            Model finalModel = getFinalModel(segment.requireModel());

            Output output = finalModel.getOutput();
            if (output == null || !output.hasOutputFields()) {
                continue;
            }

            List<OutputField> outputFields = output.getOutputFields();
            outputFields.removeIf(outputField -> names.contains(outputField.requireName()));

            if (outputFields.isEmpty()) {
                finalModel.setOutput((Output) null);
            }
        }
    }

    /**
     * Wraps the last output field of {@code model} as a continuous feature.
     *
     * @throws InvalidElementException if {@code model} declares no output fields
     */
    private static ContinuousFeature getPrediction(Model model, Schema schema) {
        Output output = model.getOutput();
        if (output == null || !output.hasOutputFields()) {
            throw new InvalidElementException(model);
        }

        return new ContinuousFeature(schema.getEncoder(), (Field<?>) Iterables.getLast(output.getOutputFields()));
    }
}
