package io.trino.sql.gen.columnar;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.SwitchStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.JumpInstruction;
import io.airlift.bytecode.instruction.LabelNode;
import io.airlift.slice.Slice;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.InCodeGenerator;
import io.trino.sql.gen.SqlTypeBytecodeExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;
import io.trino.util.CompilerUtils;
import io.trino.util.FastutilSetHelper;
import java.lang.invoke.MethodHandle;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/gen/columnar/InColumnarFilterGenerator.class */
/**
 * Generates a {@link ColumnarFilter} implementation for an {@code IN} predicate of the form
 * {@code <input reference> IN (<constant>, <constant>, ...)}.
 * <p>
 * Two lookup strategies can be emitted; the choice is made once, in the constructor:
 * <ul>
 *     <li>a bytecode {@code switch} over the constants, used when the value's Java type is {@code long},
 *     the IN list has fewer than {@value #SWITCH_CASE_VALUES_LIMIT} entries and every non-null entry fits
 *     in an {@code int} (see {@link #useSwitchCaseGeneration(Type, List)});</li>
 *     <li>otherwise, a call to {@code FastutilSetHelper.in(value, set)} against a set built by
 *     {@code FastutilSetHelper.toFastutilHashSet}.</li>
 * </ul>
 * Only non-null constants for which the type's {@code INDETERMINATE} operator returns {@code false} are used
 * as matchable values. In the generated loops, positions whose input value is null are skipped without
 * performing any lookup.
 */
public class InColumnarFilterGenerator {
    // useSwitchCaseGeneration declines the switch strategy for IN lists of this size or larger.
    private static final int SWITCH_CASE_VALUES_LIMIT = 8;

    private final InputReferenceExpression valueExpression;
    private final boolean useSwitchCase;
    // De-duplicated, non-null, determinate constant values from the IN list.
    private final Set<Object> constantValues;
    private final MethodHandle equalsMethodHandle;
    private final MethodHandle hashCodeMethodHandle;

    /**
     * Validates the {@code IN} expression and resolves the operators needed to evaluate it.
     *
     * @param specialForm an {@code IN} special form whose first argument is an input reference and whose
     *         remaining arguments are constants; it must carry exactly three function dependencies, from which
     *         the EQUAL, HASH_CODE and INDETERMINATE operators are taken
     * @param functionManager resolves the scalar implementations of those operators
     * @throws IllegalArgumentException if the form is not {@code IN}, has fewer than two arguments, or does not
     *         have exactly three function dependencies
     * @throws UnsupportedOperationException if the first argument is not an input reference, any other argument
     *         is not a constant, or the value type has type parameters
     */
    public InColumnarFilterGenerator(SpecialForm specialForm, FunctionManager functionManager) {
        Preconditions.checkArgument(specialForm.form() == SpecialForm.Form.IN, "specialForm should be IN");
        List<RowExpression> arguments = specialForm.arguments();
        Preconditions.checkArgument(arguments.size() >= 2, "At least two arguments are required");
        if (!(arguments.getFirst() instanceof InputReferenceExpression inputReference)) {
            throw new UnsupportedOperationException("IN clause columnar evaluation is supported only on input references");
        }
        this.valueExpression = inputReference;

        // Every IN-list entry must be a constant; collect them with their proper static type.
        List<RowExpression> inList = arguments.subList(1, arguments.size());
        ImmutableList.Builder<ConstantExpression> constantsBuilder = ImmutableList.builder();
        for (RowExpression expression : inList) {
            if (!(expression instanceof ConstantExpression constant)) {
                throw new UnsupportedOperationException("IN clause columnar evaluation is supported only on input reference against constants");
            }
            constantsBuilder.add(constant);
        }
        List<ConstantExpression> constants = constantsBuilder.build();

        Preconditions.checkArgument(specialForm.functionDependencies().size() == 3);
        ResolvedFunction equalOperator = specialForm.getOperatorDependency(OperatorType.EQUAL);
        ResolvedFunction hashCodeOperator = specialForm.getOperatorDependency(OperatorType.HASH_CODE);
        ResolvedFunction indeterminateOperator = specialForm.getOperatorDependency(OperatorType.INDETERMINATE);
        this.equalsMethodHandle = functionManager.getScalarFunctionImplementation(
                        equalOperator,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();
        this.hashCodeMethodHandle = functionManager.getScalarFunctionImplementation(
                        hashCodeOperator,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();
        MethodHandle indeterminateMethodHandle = functionManager.getScalarFunctionImplementation(
                        indeterminateOperator,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();

        // Only non-null, determinate constants become matchable values.
        ImmutableSet.Builder<Object> values = ImmutableSet.builder();
        for (ConstantExpression constant : constants) {
            if (isDeterminateConstant(constant, indeterminateMethodHandle)) {
                values.add(constant.value());
            }
        }
        this.constantValues = values.build();
        this.useSwitchCase = useSwitchCaseGeneration(this.valueExpression.type(), inList);
    }

    /**
     * Generates and compiles a {@link ColumnarFilter} class implementing {@code filterPositionsRange} and
     * {@code filterPositionsList} for this IN predicate.
     *
     * @return the supplier produced by {@code ColumnarFilterCompiler.createClassInstance} for the generated class
     */
    public Supplier<ColumnarFilter> generateColumnarFilter() {
        ClassDefinition classDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName(ColumnarFilter.class.getSimpleName() + "_in", Optional.empty()),
                ParameterizedType.type(Object.class),
                ParameterizedType.type(ColumnarFilter.class));
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        classDefinition.declareDefaultConstructor(Access.a(Access.PUBLIC));
        ColumnarFilterCompiler.generateGetInputChannels(callSiteBinder, classDefinition, this.valueExpression);

        Set<?> valueSet = FastutilSetHelper.toFastutilHashSet(
                this.constantValues,
                this.valueExpression.type(),
                this.hashCodeMethodHandle,
                this.equalsMethodHandle);
        Binding valueSetBinding = callSiteBinder.bind(valueSet, valueSet.getClass());
        generateFilterRangeMethod(callSiteBinder, classDefinition, valueSet, valueSetBinding);
        generateFilterListMethod(callSiteBinder, classDefinition, valueSet, valueSetBinding);
        return ColumnarFilterCompiler.createClassInstance(callSiteBinder, classDefinition);
    }

    /**
     * Declares {@code int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage page)}.
     * The generated method tests every position in {@code [offset, offset + size)} and returns the final value of
     * {@code outputPositionsCount}, which is maintained by {@code ColumnarFilterCompiler.updateOutputPositions}.
     */
    private void generateFilterRangeMethod(CallSiteBinder callSiteBinder, ClassDefinition classDefinition, Set<?> set, Binding binding) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter outputPositions = Parameter.arg("outputPositions", int[].class);
        Parameter offset = Parameter.arg("offset", int.class);
        Parameter size = Parameter.arg("size", int.class);
        Parameter page = Parameter.arg("page", SourcePage.class);
        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "filterPositionsRange",
                ParameterizedType.type(int.class),
                ImmutableList.of(session, outputPositions, offset, size, page));
        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();

        ColumnarFilterCompiler.declareBlockVariables(ImmutableList.of(this.valueExpression), page, scope, body);
        Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, BytecodeExpressions.constantInt(0));
        Variable position = scope.declareVariable(int.class, "position");
        Variable result = scope.declareVariable(boolean.class, "result");

        // Pick the null-checking loop only when the block may contain nulls.
        IfStatement ifMayHaveNull = new IfStatement()
                .condition(ColumnarFilterCompiler.generateBlockMayHaveNull(ImmutableList.of(this.valueExpression), scope));
        body.append(ifMayHaveNull);

        ifMayHaveNull.ifTrue(new ForLoop("nullable range based loop")
                .initialize(position.set(offset))
                .condition(BytecodeExpressions.lessThan(position, BytecodeExpressions.add(offset, size)))
                .update(position.increment())
                .body(new IfStatement()
                        .condition(ColumnarFilterCompiler.generateBlockPositionNotNull(ImmutableList.of(this.valueExpression), scope, position))
                        .ifTrue(generateTestAndRecordPosition(callSiteBinder, scope, set, binding, position, result, outputPositions, outputPositionsCount))));
        ifMayHaveNull.ifFalse(new ForLoop("non-nullable range based loop")
                .initialize(position.set(offset))
                .condition(BytecodeExpressions.lessThan(position, BytecodeExpressions.add(offset, size)))
                .update(position.increment())
                .body(generateTestAndRecordPosition(callSiteBinder, scope, set, binding, position, result, outputPositions, outputPositionsCount)));

        body.append(outputPositionsCount.ret());
    }

    /**
     * Declares {@code int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage page)}.
     * The generated method tests the positions {@code activePositions[offset] .. activePositions[offset + size - 1]}
     * and returns the final value of {@code outputPositionsCount}, which is maintained by
     * {@code ColumnarFilterCompiler.updateOutputPositions}.
     */
    private void generateFilterListMethod(CallSiteBinder callSiteBinder, ClassDefinition classDefinition, Set<?> set, Binding binding) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter outputPositions = Parameter.arg("outputPositions", int[].class);
        Parameter activePositions = Parameter.arg("activePositions", int[].class);
        Parameter offset = Parameter.arg("offset", int.class);
        Parameter size = Parameter.arg("size", int.class);
        Parameter page = Parameter.arg("page", SourcePage.class);
        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "filterPositionsList",
                ParameterizedType.type(int.class),
                ImmutableList.of(session, outputPositions, activePositions, offset, size, page));
        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();

        ColumnarFilterCompiler.declareBlockVariables(ImmutableList.of(this.valueExpression), page, scope, body);
        Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, BytecodeExpressions.constantInt(0));
        Variable index = scope.declareVariable(int.class, "index");
        Variable position = scope.declareVariable(int.class, "position");
        Variable result = scope.declareVariable(boolean.class, "result");

        // Pick the null-checking loop only when the block may contain nulls.
        IfStatement ifMayHaveNull = new IfStatement()
                .condition(ColumnarFilterCompiler.generateBlockMayHaveNull(ImmutableList.of(this.valueExpression), scope));
        body.append(ifMayHaveNull);

        ifMayHaveNull.ifTrue(new ForLoop("nullable positions loop")
                .initialize(index.set(offset))
                .condition(BytecodeExpressions.lessThan(index, BytecodeExpressions.add(offset, size)))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(activePositions.getElement(index)))
                        .append(new IfStatement()
                                .condition(ColumnarFilterCompiler.generateBlockPositionNotNull(ImmutableList.of(this.valueExpression), scope, position))
                                .ifTrue(generateTestAndRecordPosition(callSiteBinder, scope, set, binding, position, result, outputPositions, outputPositionsCount)))));
        ifMayHaveNull.ifFalse(new ForLoop("non-nullable positions loop")
                .initialize(index.set(offset))
                .condition(BytecodeExpressions.lessThan(index, BytecodeExpressions.add(offset, size)))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(activePositions.getElement(index)))
                        .append(generateTestAndRecordPosition(callSiteBinder, scope, set, binding, position, result, outputPositions, outputPositionsCount))));

        body.append(outputPositionsCount.ret());
    }

    /**
     * Returns bytecode that stores the IN-list membership test for {@code position} into {@code result}. It then runs
     * {@code ColumnarFilterCompiler.updateOutputPositions(result, position, outputPositions, outputPositionsCount)}.
     * This body is shared by all four generated loops.
     */
    private BytecodeBlock generateTestAndRecordPosition(
            CallSiteBinder callSiteBinder,
            Scope scope,
            Set<?> set,
            Binding binding,
            Variable position,
            Variable result,
            Parameter outputPositions,
            Variable outputPositionsCount) {
        return new BytecodeBlock()
                .append(generateSetContainsCall(callSiteBinder, scope, set, binding, position, result))
                .append(ColumnarFilterCompiler.updateOutputPositions(result, position, outputPositions, outputPositionsCount));
    }

    /**
     * Returns bytecode that reads the value at {@code position} from the input block. It stores into {@code result}
     * whether that value is one of the IN-list constants. The lookup is either a switch (when {@link #useSwitchCase}
     * is set) or {@code FastutilSetHelper.in(value, set)} with the set loaded from {@code binding}.
     */
    private BytecodeBlock generateSetContainsCall(CallSiteBinder callSiteBinder, Scope scope, Set<?> set, Binding binding, BytecodeExpression position, Variable result) {
        Type type = this.valueExpression.type();
        Class<?> javaType = type.getJavaType();
        // The value is read with Type.get<Wrapper>(block, position), e.g. getLong for long.
        // Non-primitive types other than Slice are read through getObject and cast back to their Java type.
        Class<?> accessorType = javaType;
        if (!accessorType.isPrimitive() && accessorType != Slice.class) {
            accessorType = Object.class;
        }
        BytecodeExpression value = SqlTypeBytecodeExpression.constantType(callSiteBinder, type).invoke(
                "get" + Primitives.wrap(accessorType).getSimpleName(),
                accessorType,
                scope.getVariable("block_" + this.valueExpression.field()),
                position);
        if (accessorType != javaType) {
            value = value.cast(javaType);
        }

        if (this.useSwitchCase) {
            return generateSwitchLookup(scope, javaType, value, result);
        }

        Class<?> valueParameterType = javaType.isPrimitive() ? javaType : Object.class;
        return new BytecodeBlock()
                .comment("inListSet.contains(<stackValue>)")
                .append(new BytecodeBlock()
                        .comment("value")
                        .append(value)
                        .comment("set")
                        .append(BytecodeUtils.loadConstant(binding))
                        .invokeStatic(FastutilSetHelper.class, "in", boolean.class, valueParameterType, set.getClass())
                        .putVariable(result));
    }

    /**
     * Returns bytecode that sets {@code result} from a {@code switch} over {@link #constantValues}. It is only used
     * when {@link #useSwitchCase} is set, which implies a {@code long} Java type and constants in {@code int} range.
     * Runtime values for which {@code InCodeGenerator.isInteger} returns {@code false} jump straight to the
     * no-match branch before the narrowing cast to {@code int}.
     */
    private BytecodeBlock generateSwitchLookup(Scope scope, Class<?> javaType, BytecodeExpression value, Variable result) {
        LabelNode end = new LabelNode("end");
        LabelNode match = new LabelNode("match");
        LabelNode noMatch = new LabelNode("default");

        BytecodeBlock matchBlock = new BytecodeBlock()
                .setDescription("match")
                .visitLabel(match)
                .append(result.set(BytecodeExpressions.constantTrue()))
                .gotoLabel(end);
        BytecodeBlock noMatchBlock = new BytecodeBlock()
                .setDescription("default")
                .visitLabel(noMatch)
                .append(result.set(BytecodeExpressions.constantFalse()))
                .gotoLabel(end);

        SwitchStatement.SwitchBuilder switchBuilder = SwitchStatement.switchBuilder();
        for (Object constant : this.constantValues) {
            // useSwitchCaseGeneration already rejected constants outside the int range, so this cannot throw.
            switchBuilder.addCase(Math.toIntExact(((Number) constant).longValue()), JumpInstruction.jump(match));
        }
        switchBuilder.defaultCase(JumpInstruction.jump(noMatch));

        Variable switchValue = scope.createTempVariable(javaType);
        return new BytecodeBlock()
                .comment("lookupSwitch(<stackValue>))")
                .append(switchValue.set(value))
                .append(new IfStatement()
                        .condition(BytecodeExpressions.invokeStatic(InCodeGenerator.class, "isInteger", boolean.class, switchValue))
                        .ifFalse(new BytecodeBlock().gotoLabel(noMatch)))
                .append(switchBuilder.expression(switchValue.cast(int.class)).build())
                .append(matchBlock)
                .append(noMatchBlock)
                .visitLabel(end);
    }

    /**
     * Returns {@code true} when the constant's value is non-null and the {@code INDETERMINATE} operator returns
     * {@code false} for it. Checked exceptions thrown by the operator are wrapped in {@link RuntimeException}.
     */
    private static boolean isDeterminateConstant(ConstantExpression constant, MethodHandle indeterminateMethodHandle) {
        Object value = constant.value();
        if (value == null) {
            return false;
        }
        try {
            return !(boolean) indeterminateMethodHandle.invoke(value);
        } catch (Throwable t) {
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    /**
     * Decides whether the IN list can be evaluated with a bytecode {@code switch}.
     *
     * @param type the type of the tested value
     * @param values the IN-list expressions (all expected to be constants)
     * @return {@code true} only if {@code values} has fewer than {@value #SWITCH_CASE_VALUES_LIMIT} entries,
     *         {@code type}'s Java type is {@code long}, and every non-null constant lies within the {@code int} range
     * @throws UnsupportedOperationException if {@code type} has type parameters, or (once the size/type checks
     *         pass) any entry is not a {@link ConstantExpression}
     */
    static boolean useSwitchCaseGeneration(Type type, List<RowExpression> values) {
        if (!type.getTypeParameters().isEmpty()) {
            throw new UnsupportedOperationException("Structural type not supported");
        }
        if (values.size() >= SWITCH_CASE_VALUES_LIMIT || type.getJavaType() != long.class) {
            return false;
        }
        for (RowExpression expression : values) {
            if (!(expression instanceof ConstantExpression constant)) {
                throw new UnsupportedOperationException("IN clause columnar evaluation is supported only on input reference against constants");
            }
            Object value = constant.value();
            if (value == null) {
                continue;
            }
            long longValue = ((Number) value).longValue();
            if (longValue < Integer.MIN_VALUE || longValue > Integer.MAX_VALUE) {
                return false;
            }
        }
        return true;
    }
}
