package sktree.tree;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;

/* loaded from: input_file:sktree/tree/ProjectionManager.class */
/**
 * Creates and caches features that represent weighted sums ("projections") of other features.
 *
 * <p>Each projection is described by parallel lists of features and weights. Zero weights are
 * dropped; every remaining weight must be {@code +1} or {@code -1}. Each distinct combination
 * of non-zero terms is encoded as a PMML derived field at most once. Later requests for the
 * same combination return the previously created {@link Feature}.
 *
 * <p>This class performs no synchronization.
 */
public class ProjectionManager {

    /** Already-encoded projections, keyed by their non-zero (feature, weight) terms in input order. */
    private final Map<List<Vector>, Feature> projections = new LinkedHashMap<>();

    /**
     * One non-zero term of a projection: a feature paired with its weight.
     *
     * <p>Instances are elements of cache keys, hence the value-based {@link #equals(Object)} and
     * {@link #hashCode()}. Mutating an instance (e.g. via {@link #setFeature(Feature)}) after it
     * has become part of a key would break cache lookups for that key.
     */
    public static final class Vector {

        private Feature feature;

        private Number weight;

        private Vector(Feature feature, Number weight) {
            setFeature(feature);
            setWeight(weight);
        }

        public Feature getFeature() {
            return this.feature;
        }

        /**
         * @throws NullPointerException if {@code feature} is {@code null}
         */
        public void setFeature(Feature feature) {
            this.feature = Objects.requireNonNull(feature);
        }

        public Number getWeight() {
            return this.weight;
        }

        private void setWeight(Number weight) {
            this.weight = Objects.requireNonNull(weight);
        }

        @Override
        public int hashCode() {
            return Objects.hash(getFeature(), getWeight());
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Vector)) {
                return false;
            }
            Vector that = (Vector) obj;
            // Note: Number.equals is type-sensitive, so Integer 1 and Double 1.0 are distinct weights here
            return Objects.equals(getFeature(), that.getFeature())
                && Objects.equals(getWeight(), that.getWeight());
        }
    }

    /**
     * Returns the feature representing {@code sum(features[i] * weights[i])}, creating and caching
     * it on the first request for a given combination of non-zero terms.
     *
     * @param name name of the derived field; only used when a new field has to be created
     * @param features the projected features
     * @param weights the per-feature weights, parallel to {@code features}
     * @param encoder the encoder that receives any newly created derived field
     * @return the cached or newly created feature; the input feature itself when the projection is
     *     a single term with weight {@code 1}; or {@code null} when every weight is zero
     * @throws IllegalArgumentException if a non-zero weight other than {@code 1} or {@code -1}
     *     occurs (the size check on the two lists is delegated to {@code ClassDictUtil.checkSize})
     */
    public Feature getOrCreateFeature(String name, List<Feature> features, List<Number> weights, SkLearnEncoder encoder) {
        ClassDictUtil.checkSize(new Collection<?>[]{features, weights});

        List<Vector> key = createKey(features, weights);
        if (key.isEmpty()) {
            // All weights are zero: there is nothing to project
            return null;
        }

        // encodeFeature never returns null, so a null lookup result unambiguously means "not cached"
        Feature feature = this.projections.get(key);
        if (feature == null) {
            feature = encodeFeature(name, key, encoder);
            this.projections.put(key, feature);
        }
        return feature;
    }

    /**
     * Encodes the given non-empty list of terms as a feature.
     *
     * <p>Terms with weight {@code +1} are summed, terms with weight {@code -1} are summed
     * separately, and the result is {@code positives - negatives}. If there are only negative
     * terms, the result is their sum negated.
     */
    private static Feature encodeFeature(String name, List<Vector> key, SkLearnEncoder encoder) {
        // A single term with unit weight is just the feature itself; no derived field is needed
        if (key.size() == 1) {
            Vector only = key.get(0);
            if (only.getWeight().doubleValue() == 1.0d) {
                return only.getFeature();
            }
        }

        List<Expression> positiveTerms = new ArrayList<>();
        List<Expression> negativeTerms = new ArrayList<>();
        for (Vector vector : key) {
            // Build the term before validating the weight, matching the original evaluation order
            Expression term = toTermExpression(vector.getFeature());
            double weight = vector.getWeight().doubleValue();
            if (weight == 1.0d) {
                positiveTerms.add(term);
            } else if (weight == -1.0d) {
                negativeTerms.add(term);
            } else {
                throw new IllegalArgumentException("Expected a projection weight of 1 or -1, got " + vector.getWeight());
            }
        }

        Expression positive = aggregate(positiveTerms);
        Expression negative = aggregate(negativeTerms);

        Expression expression;
        if (positive != null && negative != null) {
            expression = ExpressionUtil.createApply("-", new Expression[]{positive, negative});
        } else if (positive != null) {
            expression = positive;
        } else if (negative != null) {
            expression = ExpressionUtil.toNegative(negative);
        } else {
            // Unreachable for a non-empty key: every term lands in one of the two lists or throws
            throw new IllegalArgumentException("Projection has no terms");
        }

        return new ContinuousFeature(encoder, encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.FLOAT, expression));
    }

    /** Returns the PMML expression that evaluates a single feature as a numeric term. */
    private static Expression toTermExpression(Feature feature) {
        if (feature instanceof BinaryFeature) {
            // NormDiscrete yields 1 when the field equals the given value and 0 otherwise
            BinaryFeature binaryFeature = (BinaryFeature) feature;
            return new NormDiscrete(binaryFeature.getName(), binaryFeature.getValue());
        }
        return feature.toContinuousFeature().ref();
    }

    /**
     * Sums the given terms. Returns {@code null} for no terms and the term itself for exactly one
     * term; otherwise returns a PMML {@code sum} application over all terms.
     */
    private static Expression aggregate(List<Expression> terms) {
        if (terms.isEmpty()) {
            return null;
        }
        if (terms.size() == 1) {
            return terms.get(0);
        }
        Apply sum = ExpressionUtil.createApply("sum", new Expression[0]);
        sum.getExpressions().addAll(terms);
        return sum;
    }

    /**
     * Builds the cache key: the (feature, weight) pairs whose weight is non-zero, in input order.
     * Assumes {@code features} and {@code weights} have equal size (checked by the caller).
     */
    private static List<Vector> createKey(List<Feature> features, List<Number> weights) {
        List<Vector> key = new ArrayList<>();
        for (int i = 0; i < features.size(); i++) {
            Number weight = weights.get(i);
            if (!ValueUtil.isZero(weight)) {
                key.add(new Vector(features.get(i), weight));
            }
        }
        return key;
    }
}
