package sklearn.preprocessing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;

/* loaded from: input_file:sklearn/preprocessing/TargetEncoder.class */
/**
 * Converter for a fitted scikit-learn {@code TargetEncoder}.
 *
 * <p>Each input feature is turned into a continuous, double-typed derived field whose value is
 * looked up from the fitted per-category encodings ({@code encodings_}). Lookups that match no
 * category fall back to the fitted global target mean ({@code target_mean_}).
 */
public class TargetEncoder extends BaseEncoder {
    private static final String TARGETTYPE_BINARY = "binary";
    private static final String TARGETTYPE_CONTINUOUS = "continuous";

    /**
     * Creates a new converter; both arguments are forwarded unchanged to {@link BaseEncoder}.
     */
    public TargetEncoder(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes every input feature as a continuous target-encoded derived field.
     *
     * <p>For feature {@code i}, the categories {@code categories_[i]} are mapped to
     * {@code encodings_[i]}. If the category list contains a missing-value category, that entry is
     * removed from the lookup table and its encoding is used as the {@code mapMissingTo} value
     * instead. Any other unmatched value maps to {@code target_mean_}.
     *
     * @param list the input features, one per fitted category list
     * @param skLearnEncoder the encoder used to register categorical domains and derived fields
     * @return one {@link ContinuousFeature} per input feature, in input order
     */
    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<Feature> list, SkLearnEncoder skLearnEncoder) {
        List<List<Object>> categories = getCategories();
        List<List<Number>> encodings = getEncodings();
        Number targetMean = getTargetMean();

        // The result is intentionally unused: this call exists to validate target_type_.
        // NOTE(review): presumably getEnum throws for values outside {"binary", "continuous"}
        // (e.g. a multiclass target), so do not remove it as dead code. Confirm in the superclass.
        getTargetType();

        ClassDictUtil.checkSize(list.size(), categories, encodings);

        List<Feature> result = new ArrayList<>(list.size());
        for (int i = 0; i < list.size(); i++) {
            Feature feature = list.get(i);
            List<Object> featureCategories = categories.get(i);
            List<Number> featureEncodings = encodings.get(i);
            ClassDictUtil.checkSize(featureCategories, featureEncodings);

            // A missing-value category is not a regular lookup entry. Its encoding becomes the
            // mapMissingTo value. Copies are taken so the fitted lists themselves are not mutated.
            // NOTE(review): when no NaN category exists, getMissingCategory returns null, and
            // indexOf(null) then matches a literal null category (if any). Presumably intentional,
            // since Python None also denotes "missing"; confirm.
            Number missingEncoding = null;
            int missingIndex = featureCategories.indexOf(getMissingCategory(featureCategories));
            if (missingIndex > -1) {
                featureCategories = new ArrayList<>(featureCategories);
                featureCategories.remove(missingIndex);
                featureEncodings = new ArrayList<>(featureEncodings);
                missingEncoding = featureEncodings.remove(missingIndex);
            }

            skLearnEncoder.toCategorical(feature.getName(), featureCategories);

            result.add(new ContinuousFeature(skLearnEncoder,
                    skLearnEncoder.createDerivedField(
                            createFieldName("targetEncoder", feature),
                            OpType.CONTINUOUS,
                            DataType.DOUBLE,
                            ExpressionUtil.createMapValues(feature.getName(), featureCategories, featureEncodings)
                                    .setMapMissingTo(missingEncoding)
                                    .setDefaultValue(targetMean))));
        }
        return result;
    }

    /** Returns the fitted per-feature category encodings ({@code encodings_}). */
    public List<List<Number>> getEncodings() {
        return getArrayList("encodings_", Number.class);
    }

    /** Returns the fitted global target mean ({@code target_mean_}). */
    public Number getTargetMean() {
        return getNumber("target_mean_");
    }

    /**
     * Returns the fitted target type ({@code target_type_}), restricted to
     * {@value #TARGETTYPE_BINARY} or {@value #TARGETTYPE_CONTINUOUS}.
     */
    public String getTargetType() {
        return (String) getEnum("target_type_", this::getString, Arrays.asList(TARGETTYPE_BINARY, TARGETTYPE_CONTINUOUS));
    }

    /**
     * Returns the first NaN element of {@code list}, or {@code null} if the list contains none.
     */
    private static Object getMissingCategory(List<?> list) {
        for (Object category : list) {
            if (ValueUtil.isNaN(category)) {
                return category;
            }
        }
        return null;
    }
}
