package org.jpmml.xgboost;

import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;

/* loaded from: input_file:org/jpmml/xgboost/Classification.class */
/**
 * Base {@link ObjFunction} for classification objectives.
 *
 * <p>Holds the expected number of target classes ({@code num_class}) and encodes the
 * target as a categorical label whose category count must match it.
 */
public abstract class Classification extends ObjFunction {

    /** Number of target categories this objective expects. */
    private final int num_class;

    /**
     * @param str value forwarded to the {@link ObjFunction} constructor
     *            (presumably the objective name — confirm against ObjFunction)
     * @param i   number of target classes
     */
    public Classification(String str, int i) {
        super(str);
        this.num_class = i;
    }

    /**
     * Creates the categorical target data field and wraps it in a {@link CategoricalLabel}.
     *
     * <p>When {@code list} is {@code null}, the field is INTEGER-typed and its categories come
     * from {@code LabelUtil.createTargetCategories(num_class)}. Otherwise the field is
     * STRING-typed over the supplied categories.
     *
     * @param str          name of the target field
     * @param list         explicit target categories, or {@code null} to generate them
     * @param modelEncoder encoder that creates the data field
     * @return a categorical label backed by the new data field
     * @throws IllegalArgumentException if {@code list} is non-null and its size differs from {@code num_class}
     */
    @Override // org.jpmml.xgboost.ObjFunction
    public Label encodeLabel(String str, List<?> list, ModelEncoder modelEncoder) {
        DataField dataField;
        if (list == null) {
            dataField = modelEncoder.createDataField(str, OpType.CATEGORICAL, DataType.INTEGER, LabelUtil.createTargetCategories(this.num_class));
        } else {
            if (list.size() != this.num_class) {
                throw new IllegalArgumentException("Expected " + this.num_class + " target categories, got " + list.size() + " target categories");
            }
            dataField = modelEncoder.createDataField(str, OpType.CATEGORICAL, DataType.STRING, list);
        }
        return new CategoricalLabel(dataField);
    }

    /**
     * Encodes the trees through the five-argument {@code encodeModel} overload.
     *
     * <p>When {@code i} is not {@code -1}, the final model's PROBABILITY output fields are
     * replaced. Each is re-created as a FLOAT-typed probability field, one per label category.
     *
     * @return the encoded mining model
     * @throws IllegalArgumentException if {@code i != -1} and the final model has no output fields
     */
    @Override // org.jpmml.xgboost.ObjFunction
    public MiningModel encodeModel(int i, List<RegTree> list, List<Float> list2, float f, Integer num, Schema schema) {
        MiningModel miningModel = encodeModel(list, list2, f, num, schema);
        if (i != -1) {
            Output output = MiningModelUtil.getFinalModel(miningModel).getOutput();
            if (output == null || !output.hasOutputFields()) {
                throw new IllegalArgumentException("Expected the final model to declare output fields");
            }

            // Drop the existing probability fields, then re-create them with a FLOAT data type.
            output.getOutputFields().removeIf(outputField -> outputField.getResultFeature() == ResultFeature.PROBABILITY);

            CategoricalLabel label = (CategoricalLabel) schema.getLabel();
            for (Object value : label.getValues()) {
                output.getOutputFields().add(ModelUtil.createProbabilityField(FieldNameUtil.create("probability", new Object[]{label.getName(), value}), DataType.FLOAT, value));
            }
        }
        return miningModel;
    }

    /** @return the number of target classes this objective expects */
    public int num_class() {
        return this.num_class;
    }
}
