package sklearn.tree;

import com.google.common.primitives.Doubles;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.HasContinuousDomain;
import org.dmg.pmml.HasExtensions;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.NodeTransformer;
import org.dmg.pmml.tree.SimplifyingNodeTransformer;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ScoreDistributionManager;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ThresholdFeatureUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnException;
import sklearn.Estimator;
import sklearn.VersionUtil;
import sklearn.tree.visitors.TreeModelCompactor;
import sklearn.tree.visitors.TreeModelFlattener;
import sklearn.tree.visitors.TreeModelPruner;

/**
 * Utility methods for converting fitted scikit-learn decision trees ({@link Tree}) into PMML
 * {@link TreeModel} elements, and for post-processing tree-based models according to the
 * estimator's {@link HasTreeOptions} options.
 */
public final class TreeUtil {

    private TreeUtil() {
    }

    /**
     * Checks whether the estimator's recorded scikit-learn version is 1.3.0 or newer.
     *
     * <p>{@link #encodeTreeModel} uses this (via the estimator) to decide whether per-node
     * {@code missing_go_to_left} information is encoded as PMML default children.
     *
     * @param e the tree-based estimator
     * @return {@code true} if the version is known and at least 1.3.0, {@code false} otherwise
     */
    public static <E extends Estimator & HasTree> boolean hasMissingValueSupport(E e) {
        String skLearnVersion = e.getSkLearnVersion();
        return skLearnVersion != null && VersionUtil.compareVersion(skLearnVersion, "1.3.0") >= 0;
    }

    /**
     * Encodes the estimator's fitted tree as a binary-split PMML {@link TreeModel}.
     *
     * <p>When the estimator reports missing value support, every split node records a default
     * child derived from the tree's {@code missing_go_to_left} array, and the model uses the
     * {@code defaultChild} missing value strategy. Otherwise it uses {@code nullPrediction}.
     * After encoding, the tree's content is cleared via {@link ClassDictUtil#clearContent}.
     *
     * @param e the fitted tree-based estimator
     * @param miningFunction {@link MiningFunction#CLASSIFICATION} or {@link MiningFunction#REGRESSION}
     * @param predicateManager factory/cache for split predicates
     * @param scoreDistributionManager factory/cache for leaf score distributions (classification)
     * @param schema the model schema; for classification its label must be a {@link CategoricalLabel}
     * @return the encoded tree model
     * @throws IllegalArgumentException if the mining function is neither classification nor
     *     regression, or if a split threshold is inconsistent with its feature type
     */
    public static <E extends Estimator & HasTree> TreeModel encodeTreeModel(E e, MiningFunction miningFunction, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        Tree tree = e.getTree();
        boolean missingValueSupport = e.hasMissingValueSupport();

        int[] missingGoToLeft = missingValueSupport ? tree.getMissingGoToLeft() : null;

        TreeModel treeModel = new TreeModel(
                miningFunction,
                ModelUtil.createMiningSchema(schema.getLabel()),
                encodeNode(0, True.INSTANCE, miningFunction, tree.getChildrenLeft(), tree.getChildrenRight(), tree.getFeature(), tree.getThreshold(), tree.getValues(), missingGoToLeft, new CategoryManager(), predicateManager, scoreDistributionManager, schema))
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(missingValueSupport ? TreeModel.MissingValueStrategy.DEFAULT_CHILD : TreeModel.MissingValueStrategy.NULL_PREDICTION);

        // The tree arrays are no longer needed once encoded (presumably cleared to release memory)
        ClassDictUtil.clearContent(tree);

        return treeModel;
    }

    /**
     * Recursively encodes the subtree rooted at the node with the given index.
     *
     * @param index node index into the flat per-node arrays
     * @param predicate predicate that routes records into this node
     * @param miningFunction classification or regression
     * @param leftChildren per-node left child index
     * @param rightChildren per-node right child index
     * @param features per-node split feature index; negative values mark leaf nodes
     * @param thresholds per-node split threshold
     * @param values per-node values: one entry per node for regression; for classification
     *     {@code leftChildren.length * numClasses} entries, one row of class values per node
     * @param missingGoToLeft per-node missing value direction ({@code 1} = left), or {@code null}
     *     if no default children should be encoded
     * @param categoryManager tracks which categorical values can still reach this node
     * @param predicateManager factory/cache for split predicates
     * @param scoreDistributionManager factory/cache for leaf score distributions
     * @param schema the model schema
     * @return the encoded node, with its numeric index as the node id
     * @throws IllegalArgumentException on an unsupported mining function or an invalid threshold
     */
    private static Node encodeNode(int index, Predicate predicate, MiningFunction miningFunction, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, int[] missingGoToLeft, CategoryManager categoryManager, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        Integer id = Integer.valueOf(index);

        int featureIndex = features[index];
        if (featureIndex < 0) {
            return encodeLeafNode(id, index, predicate, miningFunction, leftChildren.length, values, scoreDistributionManager, schema);
        }

        Feature feature = schema.getFeature(featureIndex);
        double threshold = thresholds[index];

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        // null = no default child; TRUE/FALSE = missing values go to the left/right child
        Boolean defaultLeft = null;
        if (missingGoToLeft != null) {
            defaultLeft = Boolean.valueOf(missingGoToLeft[index] == 1);
        }

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            if (threshold < 0d || threshold > 1d) {
                throw new IllegalArgumentException("Expected a binary feature split threshold in [0, 1], got " + threshold);
            }

            Object value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);

            // For binary features, missing values are always routed to the left (NOT_EQUAL) child
            if (missingGoToLeft != null) {
                defaultLeft = Boolean.TRUE;
            }
        } else if (feature instanceof MissingValueFeature) {
            MissingValueFeature missingValueFeature = (MissingValueFeature) feature;

            if (threshold != 0.5d) {
                throw new IllegalArgumentException("Expected a missing value indicator split threshold of 0.5, got " + threshold);
            }

            leftPredicate = predicateManager.createSimplePredicate(missingValueFeature, SimplePredicate.Operator.IS_NOT_MISSING, (Object) null);
            rightPredicate = predicateManager.createSimplePredicate(missingValueFeature, SimplePredicate.Operator.IS_MISSING, (Object) null);
        } else if (feature instanceof ThresholdFeature) {
            ThresholdFeature thresholdFeature = (ThresholdFeature) feature;

            String name = thresholdFeature.getName();
            Object missingValue = thresholdFeature.getMissingValue();

            // Only consider values that can still reach this node
            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
            // Exclude NaN from the category sets unless NaN itself is the missing value marker
            if (!ValueUtil.isNaN(missingValue)) {
                valueFilter = valueFilter.and(value -> !ValueUtil.isNaN(value));
            }

            List<?> leftValues = thresholdFeature.getValues(number -> toSplitValue(number) <= threshold).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());
            List<?> rightValues = thresholdFeature.getValues(number -> toSplitValue(number) > threshold).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
            rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);
        } else {
            ContinuousFeature continuousFeature = toContinuousFeature(feature);

            // Positive infinity is written as the string "INF"
            Object value = (threshold == Double.POSITIVE_INFINITY) ? "INF" : Double.valueOf(threshold);

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        Node leftChild = encodeNode(leftChildren[index], leftPredicate, miningFunction, leftChildren, rightChildren, features, thresholds, values, missingGoToLeft, leftCategoryManager, predicateManager, scoreDistributionManager, schema);
        Node rightChild = encodeNode(rightChildren[index], rightPredicate, miningFunction, leftChildren, rightChildren, features, thresholds, values, missingGoToLeft, rightCategoryManager, predicateManager, scoreDistributionManager, schema);

        Node node;
        if (miningFunction == MiningFunction.CLASSIFICATION) {
            node = new ClassifierNode((Object) null, predicate);
        } else if (miningFunction == MiningFunction.REGRESSION) {
            node = new BranchNode(Double.valueOf(values[index]), predicate);
        } else {
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        node.setId(id).addNodes(leftChild, rightChild);

        // Default children reference the child node ids assigned above
        if (defaultLeft != null) {
            node.setDefaultChild(defaultLeft.booleanValue() ? leftChild.getId() : rightChild.getId());
        }

        return node;
    }

    /**
     * Encodes a leaf node.
     *
     * <p>For classification, the node's row of class values is used as record counts. The score
     * is the class with the largest value, the record count is their sum, and a score
     * distribution is attached. For regression, the node's single value becomes the score.
     *
     * @param id the node id
     * @param index node index into {@code values}
     * @param predicate predicate that routes records into this node
     * @param miningFunction classification or regression
     * @param nodeCount total number of nodes in the tree
     * @param values flat per-node values array
     * @param scoreDistributionManager factory/cache for score distributions
     * @param schema the model schema
     * @return the encoded leaf node
     * @throws IllegalArgumentException on an unsupported mining function, or if {@code values}
     *     has an unexpected length for classification
     */
    private static Node encodeLeafNode(Integer id, int index, Predicate predicate, MiningFunction miningFunction, int nodeCount, double[] values, ScoreDistributionManager scoreDistributionManager, Schema schema) {
        if (miningFunction == MiningFunction.CLASSIFICATION) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

            double[] row = getRow(values, nodeCount, categoricalLabel.size(), index);

            // Lazy view over the row that passes each element through ValueUtil.narrow
            List<Number> recordCounts = new AbstractList<Number>() {
                @Override
                public int size() {
                    return row.length;
                }

                @Override
                public Number get(int i) {
                    return ValueUtil.narrow(row[i]);
                }
            };

            double totalRecordCount = 0d;
            for (Number recordCount : recordCounts) {
                totalRecordCount += recordCount.doubleValue();
            }

            Object score = categoricalLabel.getValue(ScoreDistributionManager.indexOfMax(Doubles.asList(row)));

            Node leaf = new ClassifierNode(score, predicate)
                .setId(id)
                .setRecordCount(ValueUtil.narrow(totalRecordCount));

            scoreDistributionManager.addScoreDistributions(leaf, categoricalLabel.getValues(), recordCounts, (List<Number>) null);

            return leaf;
        } else if (miningFunction == MiningFunction.REGRESSION) {
            return new LeafNode(Double.valueOf(values[index]), predicate).setId(id);
        } else {
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }
    }

    /**
     * Transforms the schema features according to the estimator's numeric and input-float options.
     *
     * @param e the estimator supplying {@link HasTreeOptions#OPTION_NUMERIC} (default
     *     {@code TRUE}) and {@link HasTreeOptions#OPTION_INPUT_FLOAT} (default {@code null})
     * @param schema the original schema
     * @return the transformed schema
     */
    public static <E extends Estimator & HasTreeOptions> Schema configureSchema(E e, Schema schema) {
        Boolean numeric = (Boolean) e.getOption(HasTreeOptions.OPTION_NUMERIC, Boolean.TRUE);
        Boolean inputFloat = (Boolean) e.getOption(HasTreeOptions.OPTION_INPUT_FLOAT, null);

        return toTreeModelSchema(numeric, inputFloat, schema);
    }

    /**
     * Post-processes a tree-based model ({@link TreeModel}, or a {@link MiningModel} of
     * TreeModel segments) according to the estimator's tree options.
     *
     * <p>Options read: allow_missing, winner_id, node_extensions, node_id, node_score, compact,
     * flat and prune. If node-level information is requested (node extensions, node ids or node
     * scores), compact and prune default to off. Enabling compact, flat or prune together with
     * such a request throws a {@link SkLearnException}. Visitors are applied in this order:
     * missing-value stripping, pruning, compaction, flattening, node extensions, id stripping,
     * score stripping.
     *
     * @param e the estimator supplying the options
     * @param m the model to modify in place
     * @return {@code m}
     * @throws SkLearnException if the options conflict
     */
    public static <E extends Estimator & HasTreeOptions, M extends Model> M configureModel(E e, M m) {
        Boolean allowMissing = (Boolean) e.getOption(HasTreeOptions.OPTION_ALLOW_MISSING, Boolean.FALSE);
        Boolean winnerId = (Boolean) e.getOption(HasTreeOptions.OPTION_WINNER_ID, Boolean.FALSE);

        Map<?, ?> nodeExtensions = (Map<?, ?>) e.getOption(HasTreeOptions.OPTION_NODE_EXTENSIONS, null);
        Boolean nodeId = (Boolean) e.getOption(HasTreeOptions.OPTION_NODE_ID, winnerId);
        Boolean nodeScore = (Boolean) e.getOption(HasTreeOptions.OPTION_NODE_SCORE, winnerId.booleanValue() ? Boolean.TRUE : null);

        boolean nodeInfoRequested = nodeExtensions != null || Boolean.TRUE.equals(nodeId) || Boolean.TRUE.equals(nodeScore);

        Boolean compact = (Boolean) e.getOption(HasTreeOptions.OPTION_COMPACT, nodeInfoRequested ? Boolean.FALSE : Boolean.TRUE);
        Boolean flat = (Boolean) e.getOption(HasTreeOptions.OPTION_FLAT, Boolean.FALSE);
        Boolean prune = (Boolean) e.getOption(HasTreeOptions.OPTION_PRUNE, nodeInfoRequested ? Boolean.FALSE : Boolean.TRUE);

        if (compact.booleanValue() || flat.booleanValue() || prune.booleanValue()) {
            if (nodeInfoRequested) {
                throw new SkLearnException("Conflicting tree model options");
            }

            // Past the check above, nodeExtensions is known to be null.
            // Keep node ids when winner ids are requested or when allow_missing keeps defaultChild
            // attributes (which reference node ids); presumably so those references stay valid.
            nodeId = winnerId.booleanValue() ? Boolean.TRUE : allowMissing;
            nodeScore = winnerId.booleanValue() ? Boolean.TRUE : null;
        }

        if (Boolean.TRUE.equals(winnerId)) {
            encodeNodeId(e, m);
        }

        List<Visitor> visitors = new ArrayList<>();

        if (Boolean.FALSE.equals(allowMissing)) {
            // Strip missing value handling: no model-level strategy and no per-node default children
            visitors.add(new AbstractVisitor() {
                @Override
                public VisitorAction visit(TreeModel treeModel) {
                    treeModel.setMissingValueStrategy((TreeModel.MissingValueStrategy) null);
                    return super.visit(treeModel);
                }

                @Override
                public VisitorAction visit(Node node) {
                    if (node.getDefaultChild() != null) {
                        node.setDefaultChild((Object) null);
                    }
                    return super.visit(node);
                }
            });
        }

        if (Boolean.TRUE.equals(prune)) {
            visitors.add(new TreeModelPruner());
        }

        if (Boolean.TRUE.equals(compact)) {
            visitors.add(new TreeModelCompactor());
        }

        if (Boolean.TRUE.equals(flat)) {
            visitors.add(new TreeModelFlattener());
        }

        if (nodeExtensions != null) {
            for (Map.Entry<?, ?> entry : nodeExtensions.entrySet()) {
                String name = (String) entry.getKey();
                // Extension values keyed by integer node id
                Map<?, ?> extensionValues = (Map<?, ?>) entry.getValue();

                visitors.add(new AbstractExtender(name) {
                    private final NodeTransformer nodeTransformer = SimplifyingNodeTransformer.INSTANCE;

                    @Override
                    public VisitorAction visit(TreeModel treeModel) {
                        treeModel.setNode(ensureExtensibility(treeModel.getNode()));
                        return super.visit(treeModel);
                    }

                    @Override
                    public VisitorAction visit(Node node) {
                        // Upgrade children before they are visited, so they can accept extensions
                        if (node.hasNodes()) {
                            ListIterator<Node> children = node.getNodes().listIterator();
                            while (children.hasNext()) {
                                children.set(ensureExtensibility(children.next()));
                            }
                        }

                        Object value = getValue(node);
                        if (value != null) {
                            addExtension((HasExtensions<?>) node, ValueUtil.asString(ScalarUtil.decode(value)));
                        }
                        return super.visit(node);
                    }

                    /** Converts a node to its complex form if it needs an extension but cannot hold one. */
                    private Node ensureExtensibility(Node node) {
                        if (!(node instanceof HasExtensions) && getValue(node) != null) {
                            return this.nodeTransformer.toComplexNode(node);
                        }
                        return node;
                    }

                    /** Looks up this extension's value for the node, by integer node id. */
                    private Object getValue(Node node) {
                        return extensionValues.get(ValueUtil.asInteger((Number) node.getId()));
                    }
                });
            }
        }

        if (Boolean.FALSE.equals(nodeId)) {
            visitors.add(new AbstractVisitor() {
                @Override
                public VisitorAction visit(Node node) {
                    node.setId((Object) null);
                    return super.visit(node);
                }
            });
        }

        if (Boolean.FALSE.equals(nodeScore)) {
            // Remove scores and score distributions from non-leaf nodes only
            visitors.add(new AbstractVisitor() {
                @Override
                public VisitorAction visit(Node node) {
                    if (node.hasNodes()) {
                        node.setScore((Object) null);
                        if (node.hasScoreDistributions()) {
                            node.getScoreDistributions().clear();
                        }
                    }
                    return super.visit(node);
                }
            });
        }

        for (Visitor visitor : visitors) {
            visitor.applyTo(m);
        }

        return m;
    }

    /**
     * Maps every schema feature to a form usable in tree split predicates.
     *
     * <p>Binary and missing-value indicator features are kept as-is. Threshold features are kept
     * only when {@code numeric} is explicitly {@code FALSE}. Every other feature becomes a
     * continuous feature. If {@code inputFloat} is {@code TRUE}, a non-float field is retyped to
     * FLOAT in place, with its interval margins rounded to float precision. Otherwise the feature
     * is converted with {@code toContinuousFeature(DataType.FLOAT)}.
     *
     * @param numeric the numeric option (may be {@code null})
     * @param inputFloat the input-float option (may be {@code null})
     * @param schema the original schema
     * @return the transformed schema
     */
    static Schema toTreeModelSchema(final Boolean numeric, final Boolean inputFloat, Schema schema) {
        return schema.toTransformedSchema(new Function<Feature, Feature>() {
            @Override
            public Feature apply(Feature feature) {
                if (feature instanceof BinaryFeature || feature instanceof MissingValueFeature) {
                    return feature;
                }

                if (feature instanceof ThresholdFeature && numeric != null && !numeric.booleanValue()) {
                    return feature;
                }

                if (inputFloat == null || !inputFloat.booleanValue()) {
                    return feature.toContinuousFeature(DataType.FLOAT);
                }

                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                if (continuousFeature.getDataType() == DataType.FLOAT) {
                    return continuousFeature;
                }

                // Retype the underlying field itself (mutates the field)
                org.dmg.pmml.Field<?> field = continuousFeature.getField();
                field.setDataType(DataType.FLOAT);

                if (field instanceof HasContinuousDomain) {
                    HasContinuousDomain<?> hasContinuousDomain = (HasContinuousDomain<?>) field;

                    if (hasContinuousDomain.hasIntervals()) {
                        // Round interval margins to float precision to match the new data type
                        for (Interval interval : hasContinuousDomain.getIntervals()) {
                            Number leftMargin = interval.getLeftMargin();
                            Number rightMargin = interval.getRightMargin();

                            if (leftMargin != null) {
                                interval.setLeftMargin(Double.valueOf(leftMargin.floatValue()));
                            }

                            if (rightMargin != null) {
                                interval.setRightMargin(Double.valueOf(rightMargin.floatValue()));
                            }
                        }
                    }
                }

                return new ContinuousFeature(continuousFeature.getEncoder(), field);
            }
        });
    }

    /**
     * Maps schema features for feature importance reporting. Binary, missing-value indicator and
     * threshold features are kept as-is; all others are converted with
     * {@link #toContinuousFeature(Feature)}.
     *
     * @param schema the original schema
     * @return the transformed schema
     */
    public static Schema toTreeModelFeatureImportanceSchema(Schema schema) {
        return schema.toTransformedSchema(new Function<Feature, Feature>() {
            @Override
            public Feature apply(Feature feature) {
                boolean keepAsIs = feature instanceof BinaryFeature
                    || feature instanceof MissingValueFeature
                    || feature instanceof ThresholdFeature;

                return keepAsIs ? feature : TreeUtil.toContinuousFeature(feature);
            }
        });
    }

    /**
     * Converts a feature to a continuous feature, first as FLOAT and then as DOUBLE.
     * NOTE(review): presumably the float step rounds to float precision before widening; confirm
     * against {@code Feature#toContinuousFeature(DataType)}.
     */
    private static ContinuousFeature toContinuousFeature(Feature feature) {
        return feature.toContinuousFeature(DataType.FLOAT).toContinuousFeature(DataType.DOUBLE);
    }

    /**
     * Rounds a value to float precision before comparing it with a split threshold
     * (presumably mirroring scikit-learn's float32 feature representation in trees).
     */
    private static double toSplitValue(Number number) {
        return number.floatValue();
    }

    /**
     * Extracts one row from a row-major flattened matrix.
     *
     * @param values the flattened matrix
     * @param rowCount number of rows
     * @param columnCount number of columns
     * @param rowIndex the row to copy
     * @return a fresh array holding row {@code rowIndex}
     * @throws IllegalArgumentException if {@code values.length != rowCount * columnCount}
     */
    private static double[] getRow(double[] values, int rowCount, int columnCount, int rowIndex) {
        if (values.length != rowCount * columnCount) {
            throw new IllegalArgumentException("Expected " + (rowCount * columnCount) + " element(s), got " + values.length + " element(s)");
        }

        double[] row = new double[columnCount];
        System.arraycopy(values, rowIndex * columnCount, row, 0, columnCount);
        return row;
    }

    /**
     * Registers integer "apply" (winner node id) outputs on a {@link TreeModel}, or on every
     * segment of a {@link MiningModel} whose segments must all hold TreeModels and carry ids.
     *
     * @throws IllegalArgumentException if the model is neither a TreeModel nor a MiningModel
     * @throws UnsupportedElementException if a segment has no id
     */
    private static void encodeNodeId(Estimator estimator, Model model) {
        if (model instanceof TreeModel) {
            estimator.encodeApplyOutput((TreeModel) model, DataType.INTEGER);
            return;
        }

        if (!(model instanceof MiningModel)) {
            throw new IllegalArgumentException("Expected a TreeModel or MiningModel, got " + (model != null ? model.getClass().getName() : null));
        }

        MiningModel miningModel = (MiningModel) model;
        List<Segment> segments = miningModel.requireSegmentation().requireSegments();

        List<String> segmentIds = new ArrayList<>(segments.size());
        for (Segment segment : segments) {
            // Validates that each segment holds a TreeModel
            segment.requireModel(TreeModel.class);

            String id = segment.getId();
            if (id == null) {
                throw new UnsupportedElementException(segment);
            }
            segmentIds.add(id);
        }

        estimator.encodeMultiApplyOutput(miningModel, DataType.INTEGER, segmentIds);
    }
}
