package sklearn2pmml.preprocessing;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.BlockIndicator;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Lag;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;

/* loaded from: input_file:sklearn2pmml/preprocessing/RollingAggregateTransformer.class */
public class RollingAggregateTransformer extends Transformer {
    private static final String FUNCTION_AVG = "avg";
    private static final String FUNCTION_MAX = "max";
    private static final String FUNCTION_MEAN = "mean";
    private static final String FUNCTION_MIN = "min";
    private static final String FUNCTION_PROD = "prod";
    private static final String FUNCTION_PRODUCT = "product";
    private static final String FUNCTION_SUM = "sum";

    public RollingAggregateTransformer(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes every non-block-indicator input feature as a PMML {@code Lag} aggregate.
     *
     * <p>When the {@code block_indicators} attribute is present, the features it selects are
     * removed from the inputs. They are attached as {@code BlockIndicator}s to every generated
     * {@code Lag} instead of being aggregated themselves.
     *
     * @param list the incoming features; not modified
     * @param skLearnEncoder the encoder used to create the derived fields
     * @return one continuous, double-typed feature per aggregated input feature
     * @throws IllegalArgumentException if the {@code function} attribute is not a supported
     *     aggregate name
     */
    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<Feature> list, SkLearnEncoder skLearnEncoder) {
        String function = getFunction();
        Integer n = getN();
        List<Object> blockIndicators = getBlockIndicators();

        Lag.Aggregate aggregate = parseFunction(function);

        // Work on a copy so the caller's list is never mutated.
        List<Feature> features = list;
        BlockIndicator[] blockIndicatorArray = null;
        if (blockIndicators != null) {
            List<Feature> blockIndicatorFeatures = BlockIndicatorUtil.selectFeatures(blockIndicators, list);

            features = new ArrayList<>(list);
            features.removeAll(blockIndicatorFeatures);

            blockIndicatorArray = BlockIndicatorUtil.toBlockIndicators(blockIndicatorFeatures);
        }

        List<Feature> result = new ArrayList<>(features.size());
        for (Feature feature : features) {
            Lag lag = new Lag(feature.getField().requireName())
                    .setAggregate(aggregate)
                    .setN(n);

            if (blockIndicatorArray != null) {
                lag = lag.addBlockIndicators(blockIndicatorArray);
            }

            // Derived field name is built from the aggregate name, the source feature and the window size.
            String name = FieldNameUtil.create(aggregate.value(), new Object[]{feature, n});
            result.add(new ContinuousFeature(skLearnEncoder, skLearnEncoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, lag)));
        }

        return result;
    }

    /**
     * Returns the {@code function} attribute, restricted to the supported aggregate names
     * ({@code avg}, {@code max}, {@code mean}, {@code min}, {@code prod}, {@code product},
     * {@code sum}).
     */
    public String getFunction() {
        return (String) getEnum("function", this::getString, Arrays.asList(FUNCTION_AVG, FUNCTION_MAX, FUNCTION_MEAN, FUNCTION_MIN, FUNCTION_PROD, FUNCTION_PRODUCT, FUNCTION_SUM));
    }

    /** Returns the {@code n} attribute, which is passed to {@code Lag.setN}. */
    public Integer getN() {
        return getInteger("n");
    }

    /**
     * Returns the {@code block_indicators} attribute, or {@code null} when the attribute is
     * absent.
     */
    public List<Object> getBlockIndicators() {
        if (hasattr("block_indicators")) {
            return getObjectList("block_indicators");
        }
        return null;
    }

    /**
     * Maps an aggregate function name to the corresponding PMML {@code Lag} aggregate.
     * {@code avg}/{@code mean} and {@code prod}/{@code product} are synonyms.
     *
     * @param str the function name
     * @return the matching {@link Lag.Aggregate}
     * @throws IllegalArgumentException if {@code str} is not a supported name
     */
    private static Lag.Aggregate parseFunction(String str) {
        switch (str) {
            case FUNCTION_AVG:
            case FUNCTION_MEAN:
                return Lag.Aggregate.AVG;
            case FUNCTION_MAX:
                return Lag.Aggregate.MAX;
            case FUNCTION_MIN:
                return Lag.Aggregate.MIN;
            case FUNCTION_PROD:
            case FUNCTION_PRODUCT:
                return Lag.Aggregate.PRODUCT;
            case FUNCTION_SUM:
                return Lag.Aggregate.SUM;
            default:
                throw new IllegalArgumentException(str);
        }
    }
}
