package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationFunctionAdapter;
import io.trino.operator.aggregation.state.GenericBooleanState;
import io.trino.operator.aggregation.state.GenericBooleanStateSerializer;
import io.trino.operator.aggregation.state.GenericDoubleState;
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.GenericSliceState;
import io.trino.operator.aggregation.state.GenericSliceStateSerializer;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.lambda.BinaryFunctionInterface;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;

/* loaded from: input_file:io/trino/operator/aggregation/ReduceAggregationFunction.class */
/**
 * The {@code reduce_agg} aggregation function.
 *
 * <p>The declared signature has two type variables, {@code T} (input value) and {@code S}
 * (state), and four arguments: the input value, the initial state, and two lambdas. The first
 * lambda is invoked with the current state and an input value and yields the next state; the
 * second lambda is invoked with two states and yields their merged state.
 *
 * <p>Only state types whose Java representation is {@code long}, {@code double},
 * {@code boolean} or {@link Slice} can be specialized; any other state type is rejected with
 * {@link StandardErrorCode#NOT_SUPPORTED}.
 */
public class ReduceAggregationFunction extends SqlAggregationFunction {
    private static final String NAME = "reduce_agg";

    // NOTE: the singleton is constructed during static initialization, before the method handle
    // constants below are assigned. The constructor therefore must only rely on NAME (a
    // compile-time constant) and on static helper methods, never on the later static fields.
    public static final ReduceAggregationFunction REDUCE_AGG = new ReduceAggregationFunction();

    // Handles onto the input(...) overloads of this class, one per supported state kind.
    private static final MethodHandle LONG_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericLongState.class, Object.class, long.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle DOUBLE_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericDoubleState.class, Object.class, double.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle BOOLEAN_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericBooleanState.class, Object.class, boolean.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle SLICE_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericSliceState.class, Object.class, Slice.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

    // Handles onto the combine(...) overloads of this class, one per supported state kind.
    private static final MethodHandle LONG_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericLongState.class, GenericLongState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle DOUBLE_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericDoubleState.class, GenericDoubleState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle BOOLEAN_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericBooleanState.class, GenericBooleanState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle SLICE_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericSliceState.class, GenericSliceState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

    // Handles onto the write(Type, state, BlockBuilder) method of each state class; the leading
    // Type parameter is bound to the concrete state type in specialize().
    private static final MethodHandle LONG_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class);
    private static final MethodHandle DOUBLE_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class);
    private static final MethodHandle BOOLEAN_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class);
    private static final MethodHandle SLICE_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericSliceState.class, "write", Type.class, GenericSliceState.class, BlockBuilder.class);

    /**
     * Creates the function with its generic signature and an intermediate type equal to the
     * state type variable {@code S}.
     */
    public ReduceAggregationFunction() {
        super(
                FunctionMetadata.aggregateBuilder(NAME)
                        .signature(reduceAggSignature())
                        .description("Reduce input elements into a single value")
                        .build(),
                AggregationFunctionMetadata.builder()
                        .intermediateType(typeVariableSignature("S"))
                        .build());
    }

    /**
     * Builds the declared signature: type variables {@code T} and {@code S}, return type
     * {@code S}, and the four argument types in declaration order.
     */
    private static Signature reduceAggSignature() {
        return Signature.builder()
                .typeVariable("T")
                .typeVariable("S")
                .returnType(typeVariableSignature("S"))
                .argumentType(typeVariableSignature("T"))
                .argumentType(typeVariableSignature("S"))
                .argumentType(TypeSignature.functionType(
                        typeVariableSignature("S"),
                        new TypeSignature[] {typeVariableSignature("T"), typeVariableSignature("S")}))
                .argumentType(TypeSignature.functionType(
                        typeVariableSignature("S"),
                        new TypeSignature[] {typeVariableSignature("S"), typeVariableSignature("S")}))
                .build();
    }

    /** Returns a fresh, parameterless type signature referring to the given type variable. */
    private static TypeSignature typeVariableSignature(String variableName) {
        return new TypeSignature(variableName, new TypeSignatureParameter[0]);
    }

    /**
     * Selects the implementation matching the Java representation of the bound state type
     * (the second argument of the bound signature).
     *
     * @param boundSignature the signature with concrete argument types
     * @return the implementation for the bound state type
     * @throws TrinoException with {@link StandardErrorCode#NOT_SUPPORTED} when the state type is
     *         not backed by {@code long}, {@code double}, {@code boolean} or {@link Slice}
     */
    @Override // io.trino.metadata.SqlAggregationFunction
    public AggregationImplementation specialize(BoundSignature boundSignature) {
        Type inputType = boundSignature.getArgumentTypes().get(0);
        Type stateType = boundSignature.getArgumentTypes().get(1);
        Class<?> stateJavaType = stateType.getJavaType();

        if (stateJavaType == long.class) {
            return specializeForLongState(boundSignature, inputType, stateType);
        }
        if (stateJavaType == double.class) {
            return specializeForDoubleState(boundSignature, inputType, stateType);
        }
        if (stateJavaType == boolean.class) {
            return specializeForBooleanState(boundSignature, inputType, stateType);
        }
        if (stateJavaType == Slice.class) {
            return specializeForSliceState(boundSignature, inputType, stateType);
        }
        throw new TrinoException(
                StandardErrorCode.NOT_SUPPORTED,
                String.format("State type not supported for %s: %s", NAME, stateType.getDisplayName()));
    }

    /** Assembles the implementation used when the state is stored as a {@code long}. */
    private static AggregationImplementation specializeForLongState(BoundSignature boundSignature, Type inputType, Type stateType) {
        return AggregationImplementation.builder()
                .inputFunction(normalizeInputMethod(boundSignature, inputType, LONG_STATE_INPUT_FUNCTION))
                .combineFunction(LONG_STATE_COMBINE_FUNCTION)
                .outputFunction(LONG_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                .accumulatorStateDescriptor(
                        GenericLongState.class,
                        new GenericLongStateSerializer(stateType),
                        StateCompiler.generateStateFactory(GenericLongState.class))
                .lambdaInterfaces(new Class<?>[] {BinaryFunctionInterface.class, BinaryFunctionInterface.class})
                .build();
    }

    /** Assembles the implementation used when the state is stored as a {@code double}. */
    private static AggregationImplementation specializeForDoubleState(BoundSignature boundSignature, Type inputType, Type stateType) {
        return AggregationImplementation.builder()
                .inputFunction(normalizeInputMethod(boundSignature, inputType, DOUBLE_STATE_INPUT_FUNCTION))
                .combineFunction(DOUBLE_STATE_COMBINE_FUNCTION)
                .outputFunction(DOUBLE_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                .accumulatorStateDescriptor(
                        GenericDoubleState.class,
                        new GenericDoubleStateSerializer(stateType),
                        StateCompiler.generateStateFactory(GenericDoubleState.class))
                .lambdaInterfaces(new Class<?>[] {BinaryFunctionInterface.class, BinaryFunctionInterface.class})
                .build();
    }

    /** Assembles the implementation used when the state is stored as a {@code boolean}. */
    private static AggregationImplementation specializeForBooleanState(BoundSignature boundSignature, Type inputType, Type stateType) {
        return AggregationImplementation.builder()
                .inputFunction(normalizeInputMethod(boundSignature, inputType, BOOLEAN_STATE_INPUT_FUNCTION))
                .combineFunction(BOOLEAN_STATE_COMBINE_FUNCTION)
                .outputFunction(BOOLEAN_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                .accumulatorStateDescriptor(
                        GenericBooleanState.class,
                        new GenericBooleanStateSerializer(stateType),
                        StateCompiler.generateStateFactory(GenericBooleanState.class))
                .lambdaInterfaces(new Class<?>[] {BinaryFunctionInterface.class, BinaryFunctionInterface.class})
                .build();
    }

    /** Assembles the implementation used when the state is stored as a {@link Slice}. */
    private static AggregationImplementation specializeForSliceState(BoundSignature boundSignature, Type inputType, Type stateType) {
        return AggregationImplementation.builder()
                .inputFunction(normalizeInputMethod(boundSignature, inputType, SLICE_STATE_INPUT_FUNCTION))
                .combineFunction(SLICE_STATE_COMBINE_FUNCTION)
                .outputFunction(SLICE_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                .accumulatorStateDescriptor(
                        GenericSliceState.class,
                        new GenericSliceStateSerializer(stateType),
                        StateCompiler.generateStateFactory(GenericSliceState.class))
                .lambdaInterfaces(new Class<?>[] {BinaryFunctionInterface.class, BinaryFunctionInterface.class})
                .build();
    }

    /**
     * Retypes the {@code Object} value parameter (index 1) of the given input handle to the Java
     * type of the bound input type, then hands it to
     * {@link AggregationFunctionAdapter#normalizeInputMethod} described as one state parameter
     * followed by two input channels, with a lambda count of 2.
     */
    private static MethodHandle normalizeInputMethod(BoundSignature boundSignature, Type type, MethodHandle methodHandle) {
        MethodHandle typedInputHandle = methodHandle.asType(
                methodHandle.type().changeParameterType(1, type.getJavaType()));
        return AggregationFunctionAdapter.normalizeInputMethod(
                typedInputHandle,
                boundSignature,
                ImmutableList.of(
                        AggregationFunctionAdapter.AggregationParameterKind.STATE,
                        AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL,
                        AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL),
                2);
    }

    /**
     * Folds one input value into a {@code long} state. A null state is first seeded with
     * {@code initialStateValue}; the input lambda is then applied to the current state and the
     * value, and its result becomes the new state. {@code combineFunction} is not used here.
     */
    public static void input(GenericLongState state, Object value, long initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        Object reduced = inputFunction.apply(state.getValue(), value);
        long nextState = (Long) reduced;
        state.setValue(nextState);
    }

    /**
     * Folds one input value into a {@code double} state. A null state is first seeded with
     * {@code initialStateValue}; the input lambda is then applied to the current state and the
     * value, and its result becomes the new state. {@code combineFunction} is not used here.
     */
    public static void input(GenericDoubleState state, Object value, double initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        Object reduced = inputFunction.apply(state.getValue(), value);
        double nextState = (Double) reduced;
        state.setValue(nextState);
    }

    /**
     * Folds one input value into a {@code boolean} state. A null state is first seeded with
     * {@code initialStateValue}; the input lambda is then applied to the current state and the
     * value, and its result becomes the new state. {@code combineFunction} is not used here.
     */
    public static void input(GenericBooleanState state, Object value, boolean initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        Object reduced = inputFunction.apply(state.getValue(), value);
        boolean nextState = (Boolean) reduced;
        state.setValue(nextState);
    }

    /**
     * Folds one input value into a {@link Slice} state. A null state is first seeded with
     * {@code initialStateValue}; the input lambda is then applied to the current state and the
     * value, and its result becomes the new state. {@code combineFunction} is not used here.
     */
    public static void input(GenericSliceState state, Object value, Slice initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        Slice nextState = (Slice) inputFunction.apply(state.getValue(), value);
        state.setValue(nextState);
    }

    /**
     * Merges {@code otherState} into {@code state}. A null {@code state} simply takes over
     * {@code otherState}; otherwise the combine lambda is applied to both values and its result
     * is stored in {@code state}. {@code inputFunction} is not used here.
     */
    public static void combine(GenericLongState state, GenericLongState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        Object merged = combineFunction.apply(state.getValue(), otherState.getValue());
        long mergedState = (Long) merged;
        state.setValue(mergedState);
    }

    /**
     * Merges {@code otherState} into {@code state}. A null {@code state} simply takes over
     * {@code otherState}; otherwise the combine lambda is applied to both values and its result
     * is stored in {@code state}. {@code inputFunction} is not used here.
     */
    public static void combine(GenericDoubleState state, GenericDoubleState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        Object merged = combineFunction.apply(state.getValue(), otherState.getValue());
        double mergedState = (Double) merged;
        state.setValue(mergedState);
    }

    /**
     * Merges {@code otherState} into {@code state}. A null {@code state} simply takes over
     * {@code otherState}; otherwise the combine lambda is applied to both values and its result
     * is stored in {@code state}. {@code inputFunction} is not used here.
     */
    public static void combine(GenericBooleanState state, GenericBooleanState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        Object merged = combineFunction.apply(state.getValue(), otherState.getValue());
        boolean mergedState = (Boolean) merged;
        state.setValue(mergedState);
    }

    /**
     * Merges {@code otherState} into {@code state}. A null {@code state} simply takes over
     * {@code otherState}; otherwise the combine lambda is applied to both values and its result
     * is stored in {@code state}. {@code inputFunction} is not used here.
     */
    public static void combine(GenericSliceState state, GenericSliceState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        Slice mergedState = (Slice) combineFunction.apply(state.getValue(), otherState.getValue());
        state.setValue(mergedState);
    }
}
