package sklearn2pmml.preprocessing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasDefaultValue;
import org.dmg.pmml.HasInvalidValueTreatment;
import org.dmg.pmml.HasMapMissingTo;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.TypeInfo;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.SkLearnException;
import pandas.CategoricalDtypeUtil;
import pandas.core.CategoricalDtype;
import sklearn.Transformer;
import sklearn2pmml.util.EvaluatableUtil;

/**
 * Transformer that evaluates an expression (an inline string or an evaluatable object) against
 * the input columns, which are exposed to the expression under the name {@code X}, and encodes
 * the result as a single PMML feature.
 *
 * <p>The optional {@code map_missing_to}, {@code default_value} and
 * {@code invalid_value_treatment} attributes are attached to the top-level translated PMML
 * expression. The optional {@code dtype} attribute fixes the data type of the result. When it is
 * a {@link CategoricalDtype}, the resulting feature is additionally refined by
 * {@link CategoricalDtypeUtil}.
 */
public class ExpressionTransformer extends Transformer {
    private static final String INVALIDVALUETREATMENT_AS_MISSING = "as_missing";
    private static final String INVALIDVALUETREATMENT_RETURN_INVALID = "return_invalid";

    public ExpressionTransformer() {
        this("sklearn2pmml.preprocessing", "ExpressionTransformer");
    }

    public ExpressionTransformer(String module, String name) {
        super(module, name);
    }

    /**
     * Translates the configured expression and encodes it as exactly one feature.
     *
     * <p>Processing steps:
     * <ul>
     *   <li>If the translated expression is a reference to an existing derived field, that
     *       field's own expression is reused. The field's op type and data type are then updated
     *       in place instead of creating a new field.</li>
     *   <li>The data type is taken from {@code dtype} when set. Otherwise it is inferred from
     *       the expression, falling back to {@link DataType#DOUBLE}.</li>
     *   <li>A plain column reference whose field already has the resolved op type and data type
     *       (and no {@code map_missing_to}) is passed through without creating a new field.</li>
     * </ul>
     *
     * @param features the input features, visible to the expression as {@code X}
     * @param encoder the encoder that owns the derived fields
     * @return a single-element list holding the resulting feature
     * @throws SkLearnException if {@code map_missing_to}, {@code default_value} or
     *     {@code invalid_value_treatment} is set but cannot be attached unambiguously to the
     *     translated expression
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Object expr = getExpr();
        Object mapMissingTo = getMapMissingTo();
        Object defaultValue = getDefaultValue();
        InvalidValueTreatmentMethod invalidValueTreatment = parseInvalidValueTreatment(getInvalidValueTreatment());
        TypeInfo dType = getDType();

        // A NaN value is treated exactly like an absent value
        if (ValueUtil.isNaN(defaultValue)) {
            defaultValue = null;
        }
        if (ValueUtil.isNaN(mapMissingTo)) {
            mapMissingTo = null;
        }

        DataFrameScope scope = new DataFrameScope("X", features, encoder);
        Expression expression = EvaluatableUtil.translateExpression(expr, scope);

        DerivedField derivedField = null;
        if (expression instanceof FieldRef) {
            derivedField = encoder.getDerivedField(((FieldRef) expression).requireField());
            if (derivedField != null) {
                expression = derivedField.getExpression();
            }
        }

        // For a nested function call it is ambiguous which Apply element the attributes belong to
        if ((expression instanceof Apply) && !isConstantLike((Apply) expression)) {
            checkAttributeTargets(mapMissingTo, defaultValue, invalidValueTreatment);
        }

        applyAttributes(expression, mapMissingTo, defaultValue, invalidValueTreatment);

        DataType dataType = resolveDataType(dType, expression, scope);
        OpType opType = TypeUtil.getOpType(dataType);

        // Pass-through: reuse the referenced feature when it already has the desired types
        if ((expression instanceof FieldRef) && mapMissingTo == null) {
            Feature resolvedFeature = scope.resolveFeature(((FieldRef) expression).requireField());
            if (resolvedFeature != null) {
                Field field = resolvedFeature.getField();
                if (field.requireOpType() == opType && field.requireDataType() == dataType) {
                    return Collections.singletonList(refineFeature(resolvedFeature, dType, encoder));
                }
            }
        }

        if (derivedField != null) {
            derivedField.setOpType(opType).setDataType(dataType);
        } else {
            derivedField = encoder.createDerivedField(createFieldName("eval", EvaluatableUtil.toString(expr)), opType, dataType, expression);
        }

        Feature feature = FeatureUtil.createFeature(derivedField, encoder);
        return Collections.singletonList(refineFeature(feature, dType, encoder));
    }

    /**
     * Rejects attributes whose target PMML element is ambiguous.
     *
     * <p>Called only for nested function-call expressions. Throws if any of the attributes is set.
     */
    private void checkAttributeTargets(Object mapMissingTo, Object defaultValue, InvalidValueTreatmentMethod invalidValueTreatment) {
        List<String> members = new ArrayList<>();
        if (mapMissingTo != null) {
            members.add(ClassDictUtil.formatMember(this, "map_missing_to"));
        }
        if (defaultValue != null) {
            members.add(ClassDictUtil.formatMember(this, "default_value"));
        }
        if (invalidValueTreatment != null) {
            members.add(ClassDictUtil.formatMember(this, "invalid_value_treatment"));
        }

        if (!members.isEmpty()) {
            String quotedMembers = members.stream()
                .map(member -> "'" + member + "'")
                .collect(Collectors.joining(", "));

            throw new SkLearnException("The target PMML element for " + quotedMembers + " attribute(s) is unclear. Please refactor the expression from inline string representation to UDF representation");
        }
    }

    /**
     * Copies the non-null attributes onto the expression.
     *
     * <p>Fails with a descriptive error, rather than a {@code ClassCastException}, when the
     * expression type does not support an attribute.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void applyAttributes(Expression expression, Object mapMissingTo, Object defaultValue, InvalidValueTreatmentMethod invalidValueTreatment) {
        if (mapMissingTo != null) {
            if (!(expression instanceof HasMapMissingTo)) {
                throw createInapplicableAttributeException("map_missing_to", expression);
            }
            ((HasMapMissingTo) expression).setMapMissingTo(mapMissingTo);
        }
        if (defaultValue != null) {
            if (!(expression instanceof HasDefaultValue)) {
                throw createInapplicableAttributeException("default_value", expression);
            }
            ((HasDefaultValue) expression).setDefaultValue(defaultValue);
        }
        if (invalidValueTreatment != null) {
            if (!(expression instanceof HasInvalidValueTreatment)) {
                throw createInapplicableAttributeException("invalid_value_treatment", expression);
            }
            ((HasInvalidValueTreatment) expression).setInvalidValueTreatment(invalidValueTreatment);
        }
    }

    /** Builds the error raised when an attribute cannot be attached to the given expression. */
    private SkLearnException createInapplicableAttributeException(String name, Expression expression) {
        return new SkLearnException("Attribute " + ClassDictUtil.formatMember(this, name) + " is not applicable to " + expression.getClass().getSimpleName() + " expression");
    }

    /**
     * Resolves the result data type.
     *
     * <p>The explicit dtype wins when set. Otherwise the type inferred from the expression is
     * used, falling back to {@link DataType#DOUBLE} when nothing can be inferred.
     */
    private static DataType resolveDataType(TypeInfo dType, Expression expression, DataFrameScope scope) {
        if (dType != null) {
            return dType.getDataType();
        }

        DataType dataType = ExpressionUtil.getDataType(expression, scope);
        return (dataType != null) ? dataType : DataType.DOUBLE;
    }

    /** Applies categorical refinement when the dtype is a {@link CategoricalDtype}; otherwise returns the feature unchanged. */
    private static Feature refineFeature(Feature feature, TypeInfo dType, SkLearnEncoder encoder) {
        if (dType instanceof CategoricalDtype) {
            return CategoricalDtypeUtil.refineFeature(feature, (CategoricalDtype) dType, encoder);
        }
        return feature;
    }

    /** Returns the {@code default_value} attribute, or {@code null} when it is not set. */
    public Object getDefaultValue() {
        return getOptionalScalar("default_value");
    }

    public ExpressionTransformer setDefaultValue(Object defaultValue) {
        setattr("default_value", defaultValue);
        return this;
    }

    /** Returns the fitted {@code dtype_} attribute when present, otherwise the {@code dtype} parameter. */
    public TypeInfo getDType() {
        return hasattr("dtype_") ? super.getOptionalDType("dtype_", true) : super.getOptionalDType("dtype", true);
    }

    public ExpressionTransformer setDType(Object dtype) {
        setattr("dtype", dtype);
        return this;
    }

    /** Returns the fitted {@code expr_} string when present, otherwise the raw {@code expr} parameter. */
    public Object getExpr() {
        return hasattr("expr_") ? getString("expr_") : getObject("expr");
    }

    public ExpressionTransformer setExpr(String expr) {
        setattr("expr", expr);
        return this;
    }

    /** Returns {@code "as_missing"}, {@code "return_invalid"}, or {@code null} when not set. */
    public String getInvalidValueTreatment() {
        return (String) getOptionalEnum("invalid_value_treatment", this::getOptionalString, Arrays.asList(INVALIDVALUETREATMENT_AS_MISSING, INVALIDVALUETREATMENT_RETURN_INVALID));
    }

    public ExpressionTransformer setInvalidValueTreatment(String invalidValueTreatment) {
        setattr("invalid_value_treatment", invalidValueTreatment);
        return this;
    }

    /** Returns the {@code map_missing_to} attribute, or {@code null} when it is not set. */
    public Object getMapMissingTo() {
        return getOptionalScalar("map_missing_to");
    }

    public ExpressionTransformer setMapMissingTo(Object mapMissingTo) {
        setattr("map_missing_to", mapMissingTo);
        return this;
    }

    /**
     * Maps the Python-side invalid value treatment name to its PMML enum constant.
     *
     * @return {@code null} for a {@code null} input
     * @throws IllegalArgumentException for any unrecognized name
     */
    private static InvalidValueTreatmentMethod parseInvalidValueTreatment(String invalidValueTreatment) {
        if (invalidValueTreatment == null) {
            return null;
        }

        switch (invalidValueTreatment) {
            case INVALIDVALUETREATMENT_AS_MISSING:
                return InvalidValueTreatmentMethod.AS_MISSING;
            case INVALIDVALUETREATMENT_RETURN_INVALID:
                return InvalidValueTreatmentMethod.RETURN_INVALID;
            default:
                throw new IllegalArgumentException(invalidValueTreatment);
        }
    }

    /**
     * Tests whether a function call is "flat".
     *
     * <p>A call is flat when it has no arguments, or every argument is a {@link Constant} or a
     * {@link FieldRef}, i.e. there are no nested expressions. For a flat call, the call itself is
     * the unambiguous target for the missing/default/invalid attributes.
     */
    private static boolean isConstantLike(Apply apply) {
        if (!apply.hasExpressions()) {
            return true;
        }

        for (Expression expression : apply.getExpressions()) {
            if (!(expression instanceof Constant) && !(expression instanceof FieldRef)) {
                return false;
            }
        }
        return true;
    }
}
