package io.trino.operator.project;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.block.BlockAssertions;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.TestingSourcePage;
import io.trino.operator.project.PageFieldsToInputParametersRewriter;
import io.trino.spi.block.Block;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.In;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SqlToRowExpressionTranslator;
import io.trino.transaction.TestingTransactionManager;
import io.trino.type.FunctionType;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/operator/project/TestPageFieldsToInputParametersRewriter.class */
/**
 * Tests for {@link PageFieldsToInputParametersRewriter}.
 *
 * <p>Each case translates an IR expression to a {@link RowExpression}, rewrites it, asks the result's
 * input channels to read from a {@link TestingSourcePage}, and then checks exactly which page
 * channels were loaded by that call.
 */
public class TestPageFieldsToInputParametersRewriter {
    private static final TestingTransactionManager TRANSACTION_MANAGER = new TestingTransactionManager();
    private static final PlannerContext PLANNER_CONTEXT = TestingPlannerContext.plannerContextBuilder().withTransactionManager(TRANSACTION_MANAGER).build();
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    private static final ResolvedFunction CEIL = FUNCTIONS.resolveFunction("ceil", TypeSignatureProvider.fromTypes(new Type[]{BigintType.BIGINT}));
    private static final ResolvedFunction ROUND = FUNCTIONS.resolveFunction("round", TypeSignatureProvider.fromTypes(new Type[]{BigintType.BIGINT}));
    private static final ResolvedFunction TRANSFORM = FUNCTIONS.resolveFunction("transform", TypeSignatureProvider.fromTypes(new Type[]{new ArrayType(BigintType.BIGINT), new FunctionType(ImmutableList.of(BigintType.BIGINT), IntegerType.INTEGER)}));
    private static final ResolvedFunction ZIP_WITH = FUNCTIONS.resolveFunction("zip_with", TypeSignatureProvider.fromTypes(new Type[]{new ArrayType(BigintType.BIGINT), new ArrayType(BigintType.BIGINT), new FunctionType(ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT), BigintType.BIGINT)}));
    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction MULTIPLY_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction DIVIDE_BIGINT = FUNCTIONS.resolveOperator(OperatorType.DIVIDE, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction MULTIPLY_INTEGER = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(IntegerType.INTEGER, IntegerType.INTEGER));
    private static final ResolvedFunction MODULUS_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MODULUS, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction NEGATION_BIGINT = FUNCTIONS.resolveOperator(OperatorType.NEGATION, ImmutableList.of(BigintType.BIGINT));

    /** Number of positions in every test block and in the test page; the two must agree. */
    private static final int POSITION_COUNT = 100;

    /**
     * Translates IR {@link Expression}s into {@link RowExpression}s via
     * {@link SqlToRowExpressionTranslator}, using a source layout in which each symbol added through
     * {@link #addSymbol} is mapped to the next channel index (0, 1, 2, ...).
     */
    private static final class RowExpressionBuilder {
        private final Map<Symbol, Integer> sourceLayout = new HashMap<>();
        private final List<Type> types = new LinkedList<>();

        private RowExpressionBuilder() {
        }

        private static RowExpressionBuilder create() {
            return new RowExpressionBuilder();
        }

        /**
         * Registers a symbol on the next free channel.
         *
         * @param name symbol name, referenced from expressions via {@link Reference}
         * @param type symbol type
         * @return this builder, for chaining
         */
        private RowExpressionBuilder addSymbol(String name, Type type) {
            sourceLayout.put(new Symbol(type, name), types.size());
            types.add(type);
            return this;
        }

        /** Translates {@code expression} against the symbols registered so far. */
        private RowExpression buildExpression(Expression expression) {
            return SqlToRowExpressionTranslator.translate(expression, sourceLayout, PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getTypeManager());
        }
    }

    @Test
    public void testEagerLoading() {
        RowExpressionBuilder builder = RowExpressionBuilder.create()
                .addSymbol("bigint0", BigintType.BIGINT)
                .addSymbol("bigint1", BigintType.BIGINT);

        // Single-column expressions: channel 0 is expected to be loaded.
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(bigintReference("bigint0"), bigintLiteral(5L)))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Cast(new Call(MULTIPLY_BIGINT, ImmutableList.of(bigintReference("bigint0"), bigintLiteral(10L))), IntegerType.INTEGER)),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Coalesce(new Call(MODULUS_BIGINT, ImmutableList.of(bigintReference("bigint0"), bigintLiteral(2L))), bigintReference("bigint0"))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new In(bigintReference("bigint0"), ImmutableList.of(bigintLiteral(1L), bigintLiteral(2L), bigintLiteral(3L)))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint0"), bigintLiteral(0L))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Comparison(Comparison.Operator.EQUAL, new Call(ADD_BIGINT, ImmutableList.of(bigintReference("bigint0"), bigintLiteral(1L))), bigintLiteral(0L))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Between(bigintReference("bigint0"), bigintLiteral(1L), bigintLiteral(10L))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Case(
                        ImmutableList.of(new WhenClause(new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint0"), bigintLiteral(0L)), bigintReference("bigint0"))),
                        new Constant(BigintType.BIGINT, null))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Switch(
                        bigintReference("bigint0"),
                        ImmutableList.of(new WhenClause(bigintLiteral(1L), bigintLiteral(1L))),
                        new Call(NEGATION_BIGINT, ImmutableList.of(bigintReference("bigint0"))))),
                1);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(new Coalesce(bigintLiteral(0L), bigintReference("bigint0")), bigintReference("bigint0")))),
                1);

        // Two-column expressions where both channels are expected to be loaded.
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Call(ADD_BIGINT, ImmutableList.of(bigintReference("bigint0"), new Call(MULTIPLY_BIGINT, ImmutableList.of(bigintLiteral(2L), bigintReference("bigint1")))))),
                2);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new NullIf(bigintReference("bigint0"), bigintReference("bigint1"))),
                2);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Coalesce(new Call(CEIL, ImmutableList.of(new Call(DIVIDE_BIGINT, ImmutableList.of(bigintReference("bigint0"), bigintReference("bigint1"))))), bigintLiteral(0L))),
                2);
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Case(
                        ImmutableList.of(new WhenClause(new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint0"), bigintReference("bigint1")), integerLiteral(1L))),
                        integerLiteral(0L))),
                2);

        // Two-column expressions where bigint1 appears only in a CASE result, a later COALESCE operand,
        // the second AND/OR operand, or the BETWEEN upper bound: only channel 0 is expected to be loaded.
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Case(
                        ImmutableList.of(new WhenClause(new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint0"), bigintLiteral(0L)), bigintReference("bigint1"))),
                        bigintLiteral(0L))),
                2,
                ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Coalesce(new Call(ROUND, ImmutableList.of(bigintReference("bigint0"))), bigintReference("bigint1"))),
                2,
                ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Logical(Logical.Operator.AND, ImmutableList.of(
                        new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint0"), bigintLiteral(0L)),
                        new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint1"), bigintLiteral(0L))))),
                2,
                ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Logical(Logical.Operator.OR, ImmutableList.of(
                        new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint0"), bigintLiteral(0L)),
                        new Comparison(Comparison.Operator.GREATER_THAN, bigintReference("bigint1"), bigintLiteral(0L))))),
                2,
                ImmutableSet.of(0));
        verifyEagerlyLoadedColumns(
                builder.buildExpression(new Between(bigintReference("bigint0"), bigintLiteral(0L), bigintReference("bigint1"))),
                2,
                ImmutableSet.of(0));

        // Calls taking lambda arguments: no channel is expected to be loaded.
        Type arrayOfBigint = new ArrayType(BigintType.BIGINT);
        RowExpressionBuilder arrayBuilder = RowExpressionBuilder.create()
                .addSymbol("array_bigint0", arrayOfBigint)
                .addSymbol("array_bigint1", arrayOfBigint);

        verifyEagerlyLoadedColumns(
                arrayBuilder.buildExpression(new Call(TRANSFORM, ImmutableList.of(
                        new Reference(arrayOfBigint, "array_bigint0"),
                        new Lambda(ImmutableList.of(new Symbol(BigintType.BIGINT, "x")), integerLiteral(1L))))),
                1,
                ImmutableSet.of());
        // NOTE(review): the lambda parameter x is declared BIGINT but referenced as INTEGER in the body.
        // Presumably harmless here, since only channel loading is checked; confirm before reusing this expression.
        verifyEagerlyLoadedColumns(
                arrayBuilder.buildExpression(new Call(TRANSFORM, ImmutableList.of(
                        new Reference(arrayOfBigint, "array_bigint0"),
                        new Lambda(
                                ImmutableList.of(new Symbol(BigintType.BIGINT, "x")),
                                new Call(MULTIPLY_INTEGER, ImmutableList.of(integerLiteral(2L), new Reference(IntegerType.INTEGER, "x"))))))),
                1,
                ImmutableSet.of());
        verifyEagerlyLoadedColumns(
                arrayBuilder.buildExpression(new Call(ZIP_WITH, ImmutableList.of(
                        new Reference(arrayOfBigint, "array_bigint0"),
                        new Reference(arrayOfBigint, "array_bigint1"),
                        new Lambda(
                                ImmutableList.of(new Symbol(BigintType.BIGINT, "x"), new Symbol(BigintType.BIGINT, "y")),
                                new Call(MULTIPLY_BIGINT, ImmutableList.of(bigintLiteral(2L), bigintReference("x"))))))),
                2,
                ImmutableSet.of());
    }

    /**
     * Verifies that all of the first {@code columnCount} channels are loaded.
     *
     * @param rowExpression expression to rewrite and evaluate input channels for
     * @param columnCount number of channels in the test page
     */
    private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount) {
        Set<Integer> allChannels = IntStream.range(0, columnCount)
                .boxed()
                .collect(ImmutableSet.toImmutableSet());
        verifyEagerlyLoadedColumns(rowExpression, columnCount, allChannels);
    }

    /**
     * Rewrites {@code rowExpression}, requests its input channels from a page with
     * {@code columnCount} channels, and asserts that exactly the channels in
     * {@code eagerlyLoadedChannels} were loaded.
     *
     * @param rowExpression expression to rewrite and evaluate input channels for
     * @param columnCount number of channels in the test page
     * @param eagerlyLoadedChannels channels expected to be loaded; each must lie in {@code [0, columnCount)}
     */
    private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int columnCount, Set<Integer> eagerlyLoadedChannels) {
        // The per-channel loop below only inspects [0, columnCount); an expectation outside that range
        // would otherwise be silently ignored.
        Assertions.assertThat(eagerlyLoadedChannels)
                .as("expected eagerly loaded channels must be within [0, %s)", columnCount)
                .allMatch(channel -> channel >= 0 && channel < columnCount);

        PageFieldsToInputParametersRewriter.Result result = PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters(rowExpression);

        Block[] blocks = new Block[columnCount];
        for (int channel = 0; channel < columnCount; channel++) {
            blocks[channel] = BlockAssertions.createLongSequenceBlock(0, POSITION_COUNT);
        }
        TestingSourcePage page = new TestingSourcePage(POSITION_COUNT, blocks);

        // The return value is not needed. The assertions below check which channels of `page`
        // this call loaded.
        result.getInputChannels().getInputChannels(page);

        for (int channel = 0; channel < columnCount; channel++) {
            Assertions.assertThat(page.wasLoaded(channel))
                    .as("channel %s loaded", channel)
                    .isEqualTo(eagerlyLoadedChannels.contains(channel));
        }
    }

    /** Returns a BIGINT-typed {@link Reference} to the symbol {@code name}. */
    private static Reference bigintReference(String name) {
        return new Reference(BigintType.BIGINT, name);
    }

    /** Returns a BIGINT {@link Constant} holding {@code value}. */
    private static Constant bigintLiteral(long value) {
        return new Constant(BigintType.BIGINT, value);
    }

    /** Returns an INTEGER {@link Constant} whose value is the boxed {@code long}, matching the original test data. */
    private static Constant integerLiteral(long value) {
        return new Constant(IntegerType.INTEGER, value);
    }
}
