package sklearn.ensemble.gradient_boosting;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.IntFunction;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.AttributeException;
import org.jpmml.python.PythonObject;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;
import sklearn.HasMultiDecisionFunctionField;
import sklearn.HasPriorProbability;
import sklearn.SkLearnClassifier;
import sklearn.VersionUtil;
import sklearn.loss.BaseLoss;
import sklearn.loss.HalfLogitLink;
import sklearn.loss.Link;
import sklearn.tree.HasTreeOptions;
import sklearn.tree.TreeRegressor;
import sklearn.tree.TreeUtil;
import sklearn2pmml.EstimatorProxy;

/* loaded from: input_file:sklearn/ensemble/gradient_boosting/GradientBoostingClassifier.class */
/**
 * PMML converter for Scikit-Learn's {@code GradientBoostingClassifier}.
 *
 * <p>A binary target is encoded as a single boosted-tree regression whose score feeds a binary
 * logistic classification. A target with more than two categories is encoded as one boosted-tree
 * regression per category, combined via {@link MiningModelUtil#createClassification}.
 */
public class GradientBoostingClassifier extends SkLearnClassifier implements HasEstimatorEnsemble<TreeRegressor>, HasMultiDecisionFunctionField, HasTreeOptions {

    /**
     * Estimator view whose {@link #getEstimator()} is the enclosing classifier. Concrete
     * (anonymous) subclasses supply a subset of the classifier's trees through
     * {@code getEstimators()}.
     *
     * <p>Deliberately a non-static inner class: it needs {@code GradientBoostingClassifier.this}.
     */
    private abstract class GradientBoostingClassifierProxy extends EstimatorProxy implements HasEstimatorEnsemble<TreeRegressor>, HasTreeOptions {

        private GradientBoostingClassifierProxy() {
        }

        @Override // sklearn2pmml.EstimatorProxy, sklearn.HasEstimator
        public Estimator getEstimator() {
            return GradientBoostingClassifier.this;
        }
    }

    public GradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the {@code n_features} attribute when the pickled object has one; otherwise defers to
     * the superclass. (Presumably {@code n_features} is written by older Scikit-Learn versions —
     * confirm.)
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        return hasattr("n_features") ? getInteger("n_features").intValue() : super.getNumberOfFeatures();
    }

    /**
     * Reports {@link DataType#FLOAT} as the input data type.
     *
     * <p>NOTE(review): presumably because Scikit-Learn trees evaluate splits on float32 inputs.
     */
    @Override // sklearn.Estimator, sklearn.HasType
    public DataType getDataType() {
        return DataType.FLOAT;
    }

    /**
     * Encodes the fitted ensemble as a PMML classification model.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code encodeModel} (bridge-method
     * merge). The {@code @Override} relies on {@code sklearn.Estimator} declaring this exact name.
     *
     * @param schema the classifier schema; its label must have at least two categories
     * @return a {@link MiningModel} carrying the decision-function output field(s) and the
     *     predict-proba outputs
     * @throws IllegalArgumentException if the label has fewer than two categories; if a
     *     {@code BaseLoss} loss object is paired with an unknown or pre-1.4.0 Scikit-Learn version;
     *     or if the number of estimators is not a multiple of the number of categories
     */
    @Override // sklearn.Estimator
    public MiningModel mo7encodeModel(Schema schema) {
        String skLearnVersion = getSkLearnVersion();
        HasPriorProbability init = getInit();
        Number learningRate = getLearningRate();
        PythonObject loss = getLoss();

        Objects.requireNonNull(init, "init_");

        IntFunction<? extends Number> initialPredictions = resolveInitialPredictions(skLearnVersion, init, loss);

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel label = schema.getLabel();

        // Legacy loss objects contribute a transformation to each decision-function output field
        Transformation[] transformations = (loss instanceof LossFunction)
            ? new Transformation[]{((LossFunction) loss).mo20createTransformation()}
            : new Transformation[0];

        int numClasses = label.size();
        if (numClasses < 2) {
            throw new IllegalArgumentException("Expected 2 or more target categories, got " + numClasses);
        }

        MiningModel miningModel;

        if (numClasses == 2) {
            // A single ensemble scores the second category
            MiningModel model = GradientBoostingUtil.encodeGradientBoosting(this, initialPredictions.apply(1), learningRate, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(getMultiDecisionFunctionField(label.getValue(1)), OpType.CONTINUOUS, DataType.DOUBLE, transformations));

            double coefficient = 1.0d;
            RegressionModel.NormalizationMethod normalizationMethod = RegressionModel.NormalizationMethod.NONE;

            if (loss instanceof BaseLoss) {
                Link link = ((BaseLoss) loss).getLink();

                normalizationMethod = RegressionModel.NormalizationMethod.LOGIT;

                // A HalfLogitLink is compensated for with a coefficient of 2 on the raw score
                if (link instanceof HalfLogitLink) {
                    coefficient = 2.0d;
                }
            }

            miningModel = MiningModelUtil.createBinaryLogisticClassification(model, coefficient, 0.0d, normalizationMethod, false, schema);
        } else {
            List<? extends TreeRegressor> estimators = getEstimators();

            // estimators_ is a flattened (rows x numClasses) matrix; refuse to silently drop trees
            if (estimators.size() % numClasses != 0) {
                throw new IllegalArgumentException("Expected the number of estimators (" + estimators.size() + ") to be a multiple of the number of target categories (" + numClasses + ")");
            }

            int rows = estimators.size() / numClasses;

            List<Model> models = new ArrayList<>(numClasses);

            for (int i = 0; i < numClasses; i++) {
                List<? extends TreeRegressor> column = CMatrixUtil.getColumn(estimators, rows, numClasses, i);

                // Proxy that exposes only this category's trees while reporting this classifier as its estimator
                GradientBoostingClassifierProxy proxy = new GradientBoostingClassifierProxy() {

                    @Override // sklearn.HasEstimatorEnsemble
                    public List<? extends TreeRegressor> getEstimators() {
                        return column;
                    }
                };

                MiningModel model = GradientBoostingUtil.encodeGradientBoosting(proxy, initialPredictions.apply(i), learningRate, segmentSchema)
                    .setOutput(ModelUtil.createPredictedOutput(getMultiDecisionFunctionField(label.getValue(i)), OpType.CONTINUOUS, DataType.DOUBLE, transformations));

                models.add(model);
            }

            RegressionModel.NormalizationMethod normalizationMethod = (loss instanceof BaseLoss)
                ? RegressionModel.NormalizationMethod.SOFTMAX
                : RegressionModel.NormalizationMethod.SIMPLEMAX;

            miningModel = MiningModelUtil.createClassification(models, normalizationMethod, false, schema);
        }

        encodePredictProbaOutput(miningModel, DataType.DOUBLE, label);

        return miningModel;
    }

    /**
     * Chooses the source of the per-category initial (raw) scores.
     *
     * <ul>
     *   <li>Legacy {@code LossFunction} with Scikit-Learn 0.21 or newer: scores computed by the loss
     *       function from the init estimator.</li>
     *   <li>{@code BaseLoss}: scores computed by the loss' link function. This requires Scikit-Learn
     *       1.4.0 or newer.</li>
     *   <li>Otherwise: the init estimator's prior probabilities, used as-is.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if a {@code BaseLoss} is paired with an unknown or pre-1.4.0 version
     */
    private static IntFunction<? extends Number> resolveInitialPredictions(String skLearnVersion, HasPriorProbability init, PythonObject loss) {

        if (loss instanceof LossFunction) {
            LossFunction lossFunction = (LossFunction) loss;

            if (skLearnVersion != null && VersionUtil.compareVersion(skLearnVersion, "0.21") >= 0) {
                List<? extends Number> predictions = Objects.requireNonNull(lossFunction.computeInitialPredictions(init), "initial predictions");

                return predictions::get;
            }
        } else if (loss instanceof BaseLoss) {
            BaseLoss baseLoss = (BaseLoss) loss;

            if (skLearnVersion == null || VersionUtil.compareVersion(skLearnVersion, "1.4.0") < 0) {
                throw new IllegalArgumentException("Loss object " + loss.getClass().getName() + " requires Scikit-Learn version 1.4.0 or newer, got " + skLearnVersion);
            }

            List<? extends Number> predictions = Objects.requireNonNull(baseLoss.getLink().computeInitialPredictions(baseLoss.getNumClasses(), init), "initial predictions");

            return predictions::get;
        }

        return init::getPriorProbability;
    }

    @Override // sklearn.Estimator
    public Schema configureSchema(Schema schema) {
        return TreeUtil.configureSchema(this, schema);
    }

    @Override // sklearn.Estimator
    public Model configureModel(Model model) {
        return TreeUtil.configureModel(this, model);
    }

    /**
     * Returns the fitted trees from the {@code estimators_} attribute as a flat list.
     */
    @Override // sklearn.HasEstimatorEnsemble
    public List<? extends TreeRegressor> getEstimators() {
        return getArray("estimators_", TreeRegressor.class);
    }

    /**
     * Returns the fitted init estimator ({@code init_}).
     */
    public HasPriorProbability getInit() {
        return (HasPriorProbability) get("init_", HasPriorProbability.class);
    }

    /**
     * Returns the {@code learning_rate} hyperparameter (shrinkage applied to each tree).
     */
    public Number getLearningRate() {
        return getNumber("learning_rate");
    }

    /**
     * Returns the fitted loss object.
     *
     * <p>Attributes are probed in this order: {@code loss_} as a legacy {@code LossFunction}, then
     * {@code _loss} as a {@code LossFunction}, then {@code _loss} as a {@code BaseLoss}.
     * (Presumably older pickles use {@code loss_} and newer ones {@code _loss} — confirm.)
     */
    public PythonObject getLoss() {
        if (hasattr("loss_")) {
            return (PythonObject) get("loss_", LossFunction.class);
        }

        try {
            return (PythonObject) get("_loss", LossFunction.class);
        } catch (AttributeException legacyTypeMismatch) {
            // Not a legacy loss object; fall back to the newer loss hierarchy (errors from here propagate)
            return (PythonObject) get("_loss", BaseLoss.class);
        }
    }
}
