package io.trino.testing;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.operator.OperatorStats;
import io.trino.plugin.base.metrics.DurationTiming;
import io.trino.spi.QueryId;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.metrics.Count;
import io.trino.sql.DynamicFilters;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.testing.QueryRunner;
import io.trino.tpch.TpchTable;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

/* loaded from: input_file:io/trino/testing/AbstractTestDynamicRowFiltering.class */
/**
 * Base class for connector tests of dynamic row filtering.
 *
 * <p>Each check runs a join query and compares the results and the {@code ScanFilterAndProjectOperator}
 * statistics of the dynamically filtered scan of a given {@code tpch} table. Runs use the session
 * property {@code enable_dynamic_row_filtering} set to {@code true} and/or {@code false}.
 * Subclasses supply the mapping from a connector table handle to its {@link SchemaTableName}.
 */
public abstract class AbstractTestDynamicRowFiltering extends AbstractTestQueryFramework {
    protected static final List<TpchTable<?>> REQUIRED_TPCH_TABLES = ImmutableList.of(TpchTable.CUSTOMER, TpchTable.NATION);

    // Names of the operator metrics inspected by the assertions below.
    private static final String DYNAMIC_FILTER_OUTPUT_POSITIONS = "Dynamic Filter output positions";
    private static final String DYNAMIC_FILTER_CPU_TIME = "Dynamic Filter CPU time";

    // Table whose scan statistics are inspected when the caller does not name one.
    private static final String DEFAULT_PROBE_TABLE = "customer";

    /**
     * Resolves the schema-qualified name of the table referenced by a connector-specific handle.
     * This is used to locate the scan of a particular table in the query plan.
     *
     * @param connectorTableHandle the connector handle taken from a {@link TableScanNode}
     * @return the schema and table name the handle refers to
     */
    protected abstract SchemaTableName getSchemaTableName(ConnectorTableHandle connectorTableHandle);

    @Test
    public void verifyDynamicFilteringEnabled() {
        assertQuery("SHOW SESSION LIKE 'enable_dynamic_filtering'", "VALUES ('enable_dynamic_filtering', 'true', 'true', 'boolean', 'Enable dynamic filtering')");
    }

    @Timeout(30)
    @Test
    public void testJoinWithSelectiveRowFiltering() {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertRowFiltering("SELECT * FROM customer c, nation n WHERE c.nationkey = n.nationkey and n.name = 'ALGERIA'", joinDistributionType);
        }
    }

    @Timeout(30)
    @Test
    public void testJoinWithNonSelectiveRowFiltering() {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertNoRowFiltering("SELECT * FROM  customer c, nation n WHERE c.nationkey = n.nationkey", joinDistributionType);
        }
    }

    @Timeout(30)
    @Test
    public void testRowFilteringWithStrings() {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertRowFiltering("SELECT * FROM customer c1, customer c2 WHERE c1.name = c2.name AND c2.acctbal > 9000", joinDistributionType);
            assertRowFiltering("SELECT * FROM customer c1, customer c2 WHERE c1.mktsegment = c2.mktsegment AND c2.custkey = 1", joinDistributionType);
            assertNoRowFiltering("SELECT * FROM customer c1, customer c2 WHERE c1.mktsegment = c2.mktsegment AND c2.custkey < 10", joinDistributionType);
        }
    }

    @Timeout(30)
    @Test
    public void testJoinWithMultipleDynamicFilters() {
        for (OptimizerConfig.JoinDistributionType joinDistributionType : OptimizerConfig.JoinDistributionType.values()) {
            assertNoRowFiltering("SELECT a.* FROM customer a INNER JOIN customer b ON a.nationkey = b.nationkey AND a.mktsegment = b.mktsegment", joinDistributionType);
            assertRowFiltering("SELECT * FROM (SELECT a.* FROM customer a INNER JOIN customer b ON a.mktsegment = b.mktsegment AND a.custkey = b.custkey) c INNER JOIN nation on c.nationkey = nation.nationkey AND  nation.name IN ('ALGERIA')", joinDistributionType);
        }
    }

    /**
     * Runs {@code sql} with and without dynamic row filtering.
     *
     * <p>Both runs must produce the expected result. For the dynamically filtered scan of
     * {@code tableName}, the run with row filtering must output fewer positions than it reads, and
     * fewer than the run without row filtering. The run without row filtering must output every
     * position it reads.
     *
     * @param sql the query to run
     * @param joinDistributionType join distribution type forced for both runs
     * @param tableName name of the table (in schema {@code tpch}) whose scan statistics are checked
     */
    protected void assertRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType, String tableName) {
        QueryRunner.MaterializedResultWithPlan withRowFiltering = getDistributedQueryRunner().executeWithPlan(dynamicRowFiltering(joinDistributionType), sql);
        QueryRunner.MaterializedResultWithPlan withoutRowFiltering = getDistributedQueryRunner().executeWithPlan(noDynamicRowFiltering(joinDistributionType), sql);

        MaterializedResult expected = computeExpected(sql, withRowFiltering.result().getTypes());
        QueryAssertions.assertEqualsIgnoreOrder(withRowFiltering.result(), expected, "For query: \n " + sql);
        QueryAssertions.assertEqualsIgnoreOrder(withoutRowFiltering.result(), expected, "For query: \n " + sql);

        OperatorStats filteredStats = getScanFilterAndProjectOperatorStats(withRowFiltering.queryId(), tableName);
        Assertions.assertThat(filteredStats.getInputPositions()).isEqualTo(filteredStats.getPhysicalInputPositions());
        Assertions.assertThat(filteredStats.getOutputPositions()).isLessThan(filteredStats.getInputPositions());

        OperatorStats unfilteredStats = getScanFilterAndProjectOperatorStats(withoutRowFiltering.queryId(), tableName);
        Assertions.assertThat(unfilteredStats.getInputPositions()).isEqualTo(unfilteredStats.getPhysicalInputPositions());
        Assertions.assertThat(unfilteredStats.getOutputPositions()).isEqualTo(unfilteredStats.getInputPositions());

        Assertions.assertThat(filteredStats.getOutputPositions()).isLessThan(unfilteredStats.getOutputPositions());

        Map<String, ?> metrics = filteredStats.getMetrics().getMetrics();
        Assertions.assertThat(getDynamicFilterOutputPositions(metrics)).isLessThan(filteredStats.getInputPositions());
        Assertions.assertThat(getDynamicFilterCpuTime(metrics)).isGreaterThan(Duration.ZERO);
    }

    private void assertRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType) {
        assertRowFiltering(sql, joinDistributionType, DEFAULT_PROBE_TABLE);
    }

    /**
     * Runs {@code sql} with dynamic row filtering enabled. The dynamically filtered scan of
     * {@code tableName} must read all of its physical input, and its output position count must equal
     * the {@code Dynamic Filter output positions} metric.
     *
     * @param sql the query to run
     * @param joinDistributionType join distribution type forced for the run
     * @param tableName name of the table (in schema {@code tpch}) whose scan statistics are checked
     */
    protected void assertNoRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType, String tableName) {
        QueryRunner.MaterializedResultWithPlan withRowFiltering = getDistributedQueryRunner().executeWithPlan(dynamicRowFiltering(joinDistributionType), sql);
        QueryAssertions.assertEqualsIgnoreOrder(withRowFiltering.result(), computeExpected(sql, withRowFiltering.result().getTypes()), "For query: \n " + sql);

        OperatorStats stats = getScanFilterAndProjectOperatorStats(withRowFiltering.queryId(), tableName);
        Assertions.assertThat(stats.getInputPositions()).isEqualTo(stats.getPhysicalInputPositions());

        Map<String, ?> metrics = stats.getMetrics().getMetrics();
        Assertions.assertThat(stats.getOutputPositions()).isEqualTo(getDynamicFilterOutputPositions(metrics));
        Assertions.assertThat(getDynamicFilterCpuTime(metrics)).isGreaterThan(Duration.ZERO);
    }

    private void assertNoRowFiltering(@Language("SQL") String sql, OptimizerConfig.JoinDistributionType joinDistributionType) {
        assertNoRowFiltering(sql, joinDistributionType, DEFAULT_PROBE_TABLE);
    }

    /**
     * Reads the {@code Dynamic Filter output positions} counter. The test fails with a descriptive
     * assertion error, rather than a NullPointerException, when the metric is absent.
     */
    private static long getDynamicFilterOutputPositions(Map<String, ?> metrics) {
        Assertions.assertThat(metrics).containsKey(DYNAMIC_FILTER_OUTPUT_POSITIONS);
        return ((Count<?>) metrics.get(DYNAMIC_FILTER_OUTPUT_POSITIONS)).getTotal();
    }

    /**
     * Reads the {@code Dynamic Filter CPU time} timing. The test fails with a descriptive assertion
     * error, rather than a NullPointerException, when the metric is absent.
     */
    private static Duration getDynamicFilterCpuTime(Map<String, ?> metrics) {
        Assertions.assertThat(metrics).containsKey(DYNAMIC_FILTER_CPU_TIME);
        return ((DurationTiming) metrics.get(DYNAMIC_FILTER_CPU_TIME)).getDuration();
    }

    /**
     * Finds the single plan node that is a {@link FilterNode} directly over a {@link TableScanNode}
     * of {@code tpch.<tableName>} and whose predicate carries at least one dynamic filter. Returns the
     * {@code ScanFilterAndProjectOperator} statistics recorded for that node.
     *
     * @param queryId the executed query
     * @param tableName table name within schema {@code tpch}
     * @return operator statistics for the matched node
     */
    private OperatorStats getScanFilterAndProjectOperatorStats(QueryId queryId, String tableName) {
        // Built once instead of once per visited plan node.
        SchemaTableName expectedTable = new SchemaTableName("tpch", tableName);
        return extractOperatorStatsForNodeId(
                queryId,
                PlanNodeSearcher.searchFrom(getDistributedQueryRunner().getQueryPlan(queryId).getRoot())
                        .where(planNode -> planNode instanceof FilterNode filterNode
                                && filterNode.getSource() instanceof TableScanNode tableScanNode
                                && !DynamicFilters.extractDynamicFilters(filterNode.getPredicate()).getDynamicConjuncts().isEmpty()
                                && getSchemaTableName(tableScanNode.getTable().connectorHandle()).equals(expectedTable))
                        .findOnlyElement()
                        .getId(),
                "ScanFilterAndProjectOperator");
    }

    /**
     * Session with join reordering disabled for the given distribution type, dynamic row filtering
     * enabled, and {@code dynamic_row_filtering_selectivity_threshold} set to {@code "1"}.
     */
    private Session dynamicRowFiltering(OptimizerConfig.JoinDistributionType joinDistributionType) {
        return Session.builder(noJoinReordering(joinDistributionType))
                .setSystemProperty("enable_dynamic_row_filtering", "true")
                .setSystemProperty("dynamic_row_filtering_selectivity_threshold", "1")
                .build();
    }

    /**
     * Session with join reordering disabled for the given distribution type and dynamic row
     * filtering disabled.
     */
    private Session noDynamicRowFiltering(OptimizerConfig.JoinDistributionType joinDistributionType) {
        return Session.builder(noJoinReordering(joinDistributionType))
                .setSystemProperty("enable_dynamic_row_filtering", "false")
                .build();
    }
}
