package io.trino.sql.gen.columnar;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.FieldDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.Constant;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.type.FunctionType;
import io.trino.util.CompilerUtils;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;

/* loaded from: input_file:io/trino/sql/gen/columnar/CallColumnarFilterGenerator.class */
/**
 * Generates a {@link ColumnarFilter} implementation that evaluates a single {@link CallExpression}
 * over the positions of a {@link SourcePage}, storing the boolean result of each invocation and
 * handing it to {@link ColumnarFilterCompiler} to update the output positions.
 * <p>
 * Only calls whose arguments are input references or non-null constants, and whose signature has
 * no lambda ({@link FunctionType}) arguments, are supported; anything else is rejected by the
 * constructor with {@link UnsupportedOperationException}.
 */
public class CallColumnarFilterGenerator {
    private final CallExpression callExpression;
    private final FunctionManager functionManager;

    /**
     * Lazily declares a single {@code private final} field ({@code __cachedInstance}) on the generated
     * class to hold the instance object required by a scalar function implementation, and emits the
     * constructor code that initializes that field from the function's instance factory.
     * <p>
     * Only one instance type can be cached per generated class.
     */
    public static final class CachedInstanceBinder {
        private final ClassDefinition classDefinition;
        private final CallSiteBinder callSiteBinder;
        private Optional<FieldDefinition> field = Optional.empty();
        private Optional<MethodHandle> method = Optional.empty();

        public CachedInstanceBinder(ClassDefinition classDefinition, CallSiteBinder callSiteBinder) {
            this.classDefinition = Objects.requireNonNull(classDefinition, "classDefinition is null");
            this.callSiteBinder = Objects.requireNonNull(callSiteBinder, "callSiteBinder is null");
        }

        /**
         * Returns the field that holds the cached instance, declaring it on first use.
         *
         * @param instanceFactory factory whose return type is the type of the cached instance
         * @return the field definition on the generated class
         * @throws IllegalStateException if a factory with a different return type than the first
         *         one is supplied, since only a single cached field exists
         */
        public FieldDefinition getCachedInstance(MethodHandle instanceFactory) {
            Class<?> instanceType = instanceFactory.type().returnType();
            if (field.isEmpty()) {
                field = Optional.of(classDefinition.declareField(Access.a(Access.PRIVATE, Access.FINAL), "__cachedInstance", instanceType));
                method = Optional.of(instanceFactory);
            } else {
                // The single field cannot hold a second, differently typed instance; fail instead of
                // silently generating a mistyped field access.
                Class<?> cachedType = method.orElseThrow().type().returnType();
                Preconditions.checkState(cachedType.equals(instanceType), "Cached instance type %s does not match requested type %s", cachedType, instanceType);
            }
            return field.get();
        }

        /**
         * Appends code that stores the result of invoking the cached instance factory into the
         * cached field. Does nothing if {@link #getCachedInstance} was never called.
         *
         * @param thisVariable the {@code this} reference of the generated constructor
         * @param block the constructor body to append to
         */
        public void generateInitializations(Variable thisVariable, BytecodeBlock block) {
            if (field.isPresent()) {
                block.append(thisVariable)
                        .append(BytecodeUtils.invoke(callSiteBinder.bind(method.orElseThrow()), "instanceFieldConstructor"))
                        .putField(field.get());
            }
        }
    }

    /**
     * @param callExpression the call to compile; every argument must be an input reference or a
     *        non-null constant, and no argument type may be a function (lambda) type
     * @param functionManager used to resolve the scalar function implementation
     * @throws UnsupportedOperationException if {@code callExpression} has an unsupported shape
     */
    public CallColumnarFilterGenerator(CallExpression callExpression, FunctionManager functionManager) {
        Objects.requireNonNull(callExpression, "callExpression is null");
        for (RowExpression argument : callExpression.arguments()) {
            if (argument instanceof ConstantExpression constant) {
                if (constant.value() == null) {
                    throw new UnsupportedOperationException("Call expressions with null constant are not supported");
                }
            } else if (!(argument instanceof InputReferenceExpression)) {
                throw new UnsupportedOperationException("Call expression with unsupported argument: " + argument);
            }
        }
        if (callExpression.resolvedFunction().signature().getArgumentTypes().stream().anyMatch(FunctionType.class::isInstance)) {
            throw new UnsupportedOperationException("Functions with lambda arguments are not supported");
        }
        this.callExpression = callExpression;
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    /**
     * Generates and loads a {@link ColumnarFilter} class for the call expression.
     *
     * @return a supplier of instances of the generated class
     */
    public Supplier<ColumnarFilter> generateColumnarFilter() {
        ClassDefinition classDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName(ColumnarFilter.class.getSimpleName() + callExpression.resolvedFunction().signature().getName(), Optional.empty()),
                ParameterizedType.type(Object.class),
                ParameterizedType.type(ColumnarFilter.class));
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);

        ColumnarFilterCompiler.generateGetInputChannels(callSiteBinder, classDefinition, callExpression);
        generateFilterRangeMethod(classDefinition, callSiteBinder, cachedInstanceBinder);
        generateFilterListMethod(classDefinition, callSiteBinder, cachedInstanceBinder);
        // The constructor must be generated last: it initializes whatever instance field the filter
        // methods caused the binder to declare.
        generateConstructor(classDefinition, cachedInstanceBinder);

        return ColumnarFilterCompiler.createClassInstance(callSiteBinder, classDefinition);
    }

    /**
     * Generates {@code int filterPositionsRange(session, outputPositions, offset, size, page)}, which
     * evaluates the call for every position in {@code [offset, offset + size)} and returns the final
     * value of {@code outputPositionsCount}.
     * <p>
     * When {@code generateBlockMayHaveNull} is true, positions for which
     * {@code generateBlockPositionNotNull} is false are skipped without invoking the function.
     */
    private void generateFilterRangeMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter outputPositions = Parameter.arg("outputPositions", int[].class);
        Parameter offset = Parameter.arg("offset", int.class);
        Parameter size = Parameter.arg("size", int.class);
        Parameter page = Parameter.arg("page", SourcePage.class);
        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "filterPositionsRange",
                ParameterizedType.type(int.class),
                ImmutableList.of(session, outputPositions, offset, size, page));

        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();
        ColumnarFilterCompiler.declareBlockVariables(callExpression.arguments(), page, scope, body);
        Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, BytecodeExpressions.constantInt(0));
        Variable position = scope.declareVariable(int.class, "position");
        Variable result = scope.declareVariable(boolean.class, "result");

        List<Boolean> argumentNullable = callExpression.resolvedFunction().functionNullability().getArgumentNullable();
        IfStatement mayHaveNull = new IfStatement()
                .condition(ColumnarFilterCompiler.generateBlockMayHaveNull(callExpression.arguments(), argumentNullable, scope));
        body.append(mayHaveNull);

        Function<MethodHandle, BytecodeNode> instance = instanceFactory -> scope.getThis().getField(cachedInstanceBinder.getCachedInstance(instanceFactory));

        // Null-checking loop: invoke only where all checked arguments are non-null.
        mayHaveNull.ifTrue(new ForLoop("nullable range based loop")
                .initialize(position.set(offset))
                .condition(BytecodeExpressions.lessThan(position, BytecodeExpressions.add(offset, size)))
                .update(position.increment())
                .body(new IfStatement()
                        .condition(ColumnarFilterCompiler.generateBlockPositionNotNull(callExpression.arguments(), argumentNullable, scope, position))
                        .ifTrue(invokeAndUpdateOutputPositions(instance, callSiteBinder, scope, position, result, outputPositions, outputPositionsCount))));

        // Fast path: no null checks needed, invoke for every position.
        mayHaveNull.ifFalse(new ForLoop("nullable function range based loop")
                .initialize(position.set(offset))
                .condition(BytecodeExpressions.lessThan(position, BytecodeExpressions.add(offset, size)))
                .update(position.increment())
                .body(invokeAndUpdateOutputPositions(instance, callSiteBinder, scope, position, result, outputPositions, outputPositionsCount)));

        body.append(outputPositionsCount.ret());
    }

    /**
     * Generates {@code int filterPositionsList(session, outputPositions, activePositions, offset, size, page)},
     * which evaluates the call for each position {@code activePositions[index]} with
     * {@code index} in {@code [offset, offset + size)} and returns the final value of
     * {@code outputPositionsCount}.
     * <p>
     * When {@code generateBlockMayHaveNull} is true, positions for which
     * {@code generateBlockPositionNotNull} is false are skipped without invoking the function.
     */
    private void generateFilterListMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter outputPositions = Parameter.arg("outputPositions", int[].class);
        Parameter activePositions = Parameter.arg("activePositions", int[].class);
        Parameter offset = Parameter.arg("offset", int.class);
        Parameter size = Parameter.arg("size", int.class);
        Parameter page = Parameter.arg("page", SourcePage.class);
        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "filterPositionsList",
                ParameterizedType.type(int.class),
                ImmutableList.of(session, outputPositions, activePositions, offset, size, page));

        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();
        ColumnarFilterCompiler.declareBlockVariables(callExpression.arguments(), page, scope, body);
        Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, BytecodeExpressions.constantInt(0));
        Variable index = scope.declareVariable(int.class, "index");
        Variable position = scope.declareVariable(int.class, "position");
        Variable result = scope.declareVariable(boolean.class, "result");

        List<Boolean> argumentNullable = callExpression.resolvedFunction().functionNullability().getArgumentNullable();
        IfStatement mayHaveNull = new IfStatement()
                .condition(ColumnarFilterCompiler.generateBlockMayHaveNull(callExpression.arguments(), argumentNullable, scope));
        body.append(mayHaveNull);

        Function<MethodHandle, BytecodeNode> instance = instanceFactory -> scope.getThis().getField(cachedInstanceBinder.getCachedInstance(instanceFactory));

        // Null-checking loop: invoke only where all checked arguments are non-null.
        mayHaveNull.ifTrue(new ForLoop("nullable positions loop")
                .initialize(index.set(offset))
                .condition(BytecodeExpressions.lessThan(index, BytecodeExpressions.add(offset, size)))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(activePositions.getElement(index)))
                        .append(new IfStatement()
                                .condition(ColumnarFilterCompiler.generateBlockPositionNotNull(callExpression.arguments(), argumentNullable, scope, position))
                                .ifTrue(invokeAndUpdateOutputPositions(instance, callSiteBinder, scope, position, result, outputPositions, outputPositionsCount)))));

        // Fast path: no null checks needed, invoke for every active position.
        mayHaveNull.ifFalse(new ForLoop("non-nullable positions loop")
                .initialize(index.set(offset))
                .condition(BytecodeExpressions.lessThan(index, BytecodeExpressions.add(offset, size)))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(activePositions.getElement(index)))
                        .append(invokeAndUpdateOutputPositions(instance, callSiteBinder, scope, position, result, outputPositions, outputPositionsCount))));

        body.append(outputPositionsCount.ret());
    }

    /**
     * Builds the shared loop body: invoke the function at {@code position}, store its value in
     * {@code result}, then let {@link ColumnarFilterCompiler#updateOutputPositions} record the position.
     * The invocation is generated before the output update, matching the original emission order.
     */
    private BytecodeBlock invokeAndUpdateOutputPositions(
            Function<MethodHandle, BytecodeNode> instance,
            CallSiteBinder callSiteBinder,
            Scope scope,
            Variable position,
            Variable result,
            Parameter outputPositions,
            Variable outputPositionsCount) {
        BytecodeBlock invocation = generateFullInvocation(functionManager, instance, callSiteBinder, callExpression, scope, position).putVariable(result);
        return new BytecodeBlock()
                .append(invocation)
                .append(ColumnarFilterCompiler.updateOutputPositions(result, position, outputPositions, outputPositionsCount));
    }

    /**
     * Generates an invocation of {@code callExpression} at {@code position} for functions that do not
     * need an instance object.
     *
     * @throws IllegalArgumentException if the function implementation has an instance factory
     */
    public static BytecodeBlock generateInvocation(FunctionManager functionManager, CallSiteBinder callSiteBinder, CallExpression callExpression, Scope scope, BytecodeExpression bytecodeExpression) {
        Function<MethodHandle, BytecodeNode> noInstance = instanceFactory -> {
            throw new IllegalArgumentException("Simple method invocation can not be used with functions that require an instance factory");
        };
        return generateFullInvocation(functionManager, noInstance, callSiteBinder, callExpression, scope, bytecodeExpression);
    }

    /**
     * Generates bytecode that pushes the function's leading parameters and arguments and then
     * invokes the bound implementation.
     *
     * @param instanceFactory maps the implementation's instance factory to the node that loads the instance
     * @param position expression evaluating to the position passed with each block argument
     * @throws UnsupportedOperationException if an argument is neither an input reference nor a constant
     */
    private static BytecodeBlock generateFullInvocation(FunctionManager functionManager, Function<MethodHandle, BytecodeNode> instanceFactory, CallSiteBinder callSiteBinder, CallExpression callExpression, Scope scope, BytecodeExpression position) {
        String functionName = callExpression.resolvedFunction().signature().getName().getFunctionName();
        BytecodeBlock block = new BytecodeBlock().setDescription("invoke " + functionName);

        ScalarFunctionImplementation implementation = getScalarFunctionImplementation(functionManager, callExpression);
        Binding binding = callSiteBinder.bind(implementation.getMethodHandle());
        Optional<BytecodeNode> instance = implementation.getInstanceFactory().map(instanceFactory);

        // When an instance factory exists, the instance is bound to the first parameter; any
        // ConnectorSession parameter is bound to the generated method's "session" argument.
        boolean instanceBound = false;
        for (Class<?> parameterType : binding.getType().parameterArray()) {
            if (instance.isPresent() && !instanceBound) {
                Class<?> instanceType = implementation.getInstanceFactory().orElseThrow().type().returnType();
                Preconditions.checkState(parameterType.equals(instanceType), "Mismatched type for instance parameter");
                block.append(instance.get());
                instanceBound = true;
            } else if (parameterType == ConnectorSession.class) {
                block.append(scope.getVariable("session"));
            }
        }

        // Arguments: input references push (block, position); constants push their value.
        for (RowExpression argument : callExpression.arguments()) {
            if (argument instanceof InputReferenceExpression input) {
                block.append(generateInputReference(scope.getVariable("block_" + input.field()), position));
            } else if (argument instanceof ConstantExpression constant) {
                block.append(generateConstant(callSiteBinder, constant));
            } else {
                throw new UnsupportedOperationException(String.format("CallExpression %s is not supported", callExpression));
            }
        }

        block.append(BytecodeUtils.invoke(binding, functionName));
        return block;
    }

    /**
     * Resolves the implementation using BLOCK_POSITION for input references, NEVER_NULL for constants,
     * DEFAULT_ON_NULL or FAIL_ON_NULL depending on return nullability, with session and instance
     * factory support enabled.
     */
    private static ScalarFunctionImplementation getScalarFunctionImplementation(FunctionManager functionManager, CallExpression callExpression) {
        ResolvedFunction resolvedFunction = callExpression.resolvedFunction();
        List<RowExpression> arguments = callExpression.arguments();

        ImmutableList.Builder<InvocationConvention.InvocationArgumentConvention> argumentConventions = ImmutableList.builderWithExpectedSize(arguments.size());
        for (RowExpression argument : arguments) {
            if (argument instanceof InputReferenceExpression) {
                argumentConventions.add(InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION);
            } else if (argument instanceof ConstantExpression) {
                argumentConventions.add(InvocationConvention.InvocationArgumentConvention.NEVER_NULL);
            } else {
                throw new UnsupportedOperationException(String.format("CallExpression %s is not supported", callExpression));
            }
        }

        InvocationConvention.InvocationReturnConvention returnConvention = resolvedFunction.functionNullability().isReturnNullable()
                ? InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL
                : InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
        return functionManager.getScalarFunctionImplementation(resolvedFunction, new InvocationConvention(argumentConventions.build(), returnConvention, true, true));
    }

    /** Pushes a block variable followed by the position expression. */
    private static BytecodeNode generateInputReference(BytecodeExpression block, BytecodeExpression position) {
        return new BytecodeBlock()
                .append(block)
                .append(position);
    }

    /**
     * Loads a constant: boolean, long, double and String values are emitted inline; any other
     * value is bound into the call site and loaded from there.
     */
    private static BytecodeNode generateConstant(CallSiteBinder callSiteBinder, ConstantExpression constant) {
        Object value = constant.value();
        Class<?> javaType = constant.type().getJavaType();

        BytecodeBlock block = new BytecodeBlock();
        block.comment("constant " + constant.type().getTypeSignature());
        if (javaType == boolean.class) {
            return block.append(Constant.loadBoolean((Boolean) value));
        }
        if (javaType == long.class) {
            return block.append(Constant.loadLong((Long) value));
        }
        if (javaType == double.class) {
            return block.append(Constant.loadDouble((Double) value));
        }
        if (javaType == String.class) {
            return block.append(Constant.loadString((String) value));
        }

        return new BytecodeBlock()
                .setDescription("constant " + constant.type())
                .comment(constant.toString())
                .append(BytecodeUtils.loadConstant(callSiteBinder.bind(value, constant.type().getJavaType())));
    }

    /** Generates a public no-arg constructor that calls {@code super()} and initializes cached fields. */
    private static void generateConstructor(ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder) {
        MethodDefinition constructor = classDefinition.declareConstructor(Access.a(Access.PUBLIC));
        BytecodeBlock body = constructor.getBody();
        Variable thisVariable = constructor.getThis();

        body.comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class);
        cachedInstanceBinder.generateInitializations(thisVariable, body);
        body.ret();
    }
}
