package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.type.BigintType;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.RowNumberSymbolMatcher;
import io.trino.sql.planner.assertions.TopNRankingSymbolMatcher;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.WindowNode;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/optimizations/TestWindowFilterPushDown.class */
public class TestWindowFilterPushDown extends BasePlanTest {
    private static final String ROW_NUMBER_FUNCTION_NAME = "row_number";
    private static final String RANK_FUNCTION_NAME = "rank";

    /**
     * Ranking query over {@code nation}: the first {@code %s} is the ranking function name,
     * the second is the predicate applied to the {@code ranking} column.
     */
    private static final String NATION_RANKING_QUERY_TEMPLATE =
            "SELECT * FROM (SELECT name, %s() OVER(ORDER BY name) FROM nation) t(name, ranking) WHERE %s";

    /**
     * {@code row_number() OVER()} query over {@code nation}; the {@code %s} is the predicate
     * applied to the {@code row_number} column.
     */
    private static final String NATION_ROW_NUMBER_QUERY_TEMPLATE =
            "SELECT * FROM (SELECT name, row_number() OVER() FROM nation) t(name, row_number) WHERE %s";

    @Test
    public void testLimitAbovePartitionedWindow() {
        assertLimitAbovePartitionedWindow(ROW_NUMBER_FUNCTION_NAME, TopNRankingNode.RankingType.ROW_NUMBER);
        assertLimitAbovePartitionedWindow(RANK_FUNCTION_NAME, TopNRankingNode.RankingType.RANK);
    }

    /**
     * Checks a {@code LIMIT 10} above a partitioned ranking window.
     *
     * <p>With {@code optimize_top_n_ranking} enabled, a TopNRanking node capped at 10 rows per
     * partition is expected below the limit. With it disabled, a plain {@link WindowNode} is
     * expected instead.
     *
     * @param functionName ranking function plugged into the query
     * @param rankingType ranking type the TopNRanking node is expected to carry
     */
    private void assertLimitAbovePartitionedWindow(String functionName, TopNRankingNode.RankingType rankingType) {
        String sql = String.format(
                "SELECT %s() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10",
                functionName);

        assertPlanWithSession(sql, optimizeTopNRanking(true), true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.limit(10L,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.topNRanking(builder -> {
                                            builder.rankingType(rankingType).maxRankingPerPartition(10);
                                        }, anyLineitemScan())))));

        assertPlanWithSession(sql, optimizeTopNRanking(false), true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.limit(10L,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.node(WindowNode.class, anyLineitemScan())))));
    }

    @Test
    public void testLimitAboveUnpartitionedWindow() {
        // For row_number() the expected pattern has no limit node: output -> project -> TopNRanking.
        assertPlanWithSession(
                "SELECT row_number() OVER (ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10",
                optimizeTopNRanking(true),
                true,
                PlanMatchPattern.output(
                        PlanMatchPattern.project(
                                PlanMatchPattern.topNRanking(builder -> {
                                    builder.rankingType(TopNRankingNode.RankingType.ROW_NUMBER).maxRankingPerPartition(10);
                                }, anyLineitemScan()))));

        // For rank() the limit is still expected above the TopNRanking node.
        // Presumably this is because ties can make rank() emit more than 10 rows.
        assertPlanWithSession(
                "SELECT rank() OVER (ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10",
                optimizeTopNRanking(true),
                true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.limit(10L,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.topNRanking(builder -> {
                                            builder.rankingType(TopNRankingNode.RankingType.RANK).maxRankingPerPartition(10);
                                        }, anyLineitemScan())))));
    }

    @Test
    public void testFilterAboveWindow() {
        assertFilterAboveWindow(ROW_NUMBER_FUNCTION_NAME, TopNRankingNode.RankingType.ROW_NUMBER);
        assertFilterAboveWindow(RANK_FUNCTION_NAME, TopNRankingNode.RankingType.RANK);
    }

    /**
     * Checks how filters on a ranking column are pushed into a TopNRanking node.
     *
     * @param functionName ranking function plugged into the queries
     * @param rankingType ranking type the TopNRanking node is expected to carry
     */
    private void assertFilterAboveWindow(String functionName, TopNRankingNode.RankingType rankingType) {
        String sql = String.format(
                "SELECT * FROM (SELECT %s() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_ranking FROM lineitem) WHERE partition_ranking < 10",
                functionName);

        // Enabled: "< 10" becomes a per-partition cap of 9, and no FilterNode may remain above it.
        assertPlanWithSession(sql, optimizeTopNRanking(true), true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.anyNot(FilterNode.class,
                                PlanMatchPattern.topNRanking(builder -> {
                                    builder.rankingType(rankingType).maxRankingPerPartition(9);
                                }, anyLineitemScan()))));

        // Disabled: the filter stays above a plain window node.
        assertPlanWithSession(sql, optimizeTopNRanking(false), true,
                PlanMatchPattern.anyTree(
                        PlanMatchPattern.node(FilterNode.class,
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.node(WindowNode.class, anyLineitemScan())))));

        // A predicate no rank can satisfy is expected to plan as a VALUES node.
        assertPlanWithSession(nationRankingQuery(functionName, "ranking < 0"), optimizeTopNRanking(true), true,
                PlanMatchPattern.output(ImmutableList.of("name", "ranking"), PlanMatchPattern.values("name", "ranking")));

        // Predicates that only bound the rank to 1 are fully absorbed into a cap of 1.
        assertPlanWithSession(nationRankingQuery(functionName, "ranking > 0 AND ranking < 2"), optimizeTopNRanking(true), true,
                PlanMatchPattern.output(ImmutableList.of("name", "ranking"), nationTopNRanking(rankingType, 1)));
        assertPlanWithSession(nationRankingQuery(functionName, "ranking <= 1"), optimizeTopNRanking(true), true,
                PlanMatchPattern.output(ImmutableList.of("name", "ranking"), nationTopNRanking(rankingType, 1)));
        assertPlanWithSession(nationRankingQuery(functionName, "ranking <= 1 AND ranking > -10"), optimizeTopNRanking(true), true,
                PlanMatchPattern.output(ImmutableList.of("name", "ranking"), nationTopNRanking(rankingType, 2 - 1)));

        // A non-trivial lower bound keeps the filter, while the upper bound still caps the ranking at 2.
        assertPlanWithSession(nationRankingQuery(functionName, "ranking > 1 AND ranking < 3"), optimizeTopNRanking(true), true,
                PlanMatchPattern.output(ImmutableList.of("name", "ranking"),
                        PlanMatchPattern.filter(openRangePredicate("ranking", 1L, 3L), nationTopNRanking(rankingType, 2))));
    }

    @Test
    public void testFilterAboveRowNumber() {
        // Runs with the default session; row_number() OVER() is planned as a row-number node.
        assertPlan(nationRowNumberQuery("row_number < 0"),
                PlanMatchPattern.output(ImmutableList.of("name", "row_number"), PlanMatchPattern.values("name", "row_number")));
        assertPlan(nationRowNumberQuery("row_number < 2"),
                PlanMatchPattern.output(ImmutableList.of("name", "row_number"), nationRowNumber(1)));
        assertPlan(nationRowNumberQuery("row_number <= 1"),
                PlanMatchPattern.output(ImmutableList.of("name", "row_number"), nationRowNumber(1)));
        assertPlan(nationRowNumberQuery("row_number <= 1 AND row_number > -10"),
                PlanMatchPattern.output(ImmutableList.of("name", "row_number"), nationRowNumber(1)));
        assertPlan(nationRowNumberQuery("row_number > 1 AND row_number < 3"),
                PlanMatchPattern.output(ImmutableList.of("name", "row_number"),
                        PlanMatchPattern.filter(openRangePredicate("row_number", 1L, 3L), nationRowNumber(2))));
    }

    /** Builds a nation ranking query from {@link #NATION_RANKING_QUERY_TEMPLATE}. */
    private static String nationRankingQuery(String functionName, String predicate) {
        return String.format(NATION_RANKING_QUERY_TEMPLATE, functionName, predicate);
    }

    /** Builds a nation row-number query from {@link #NATION_ROW_NUMBER_QUERY_TEMPLATE}. */
    private static String nationRowNumberQuery(String predicate) {
        return String.format(NATION_ROW_NUMBER_QUERY_TEMPLATE, predicate);
    }

    /** Matches a scan of {@code lineitem} reached through any intermediate nodes. */
    private static PlanMatchPattern anyLineitemScan() {
        return PlanMatchPattern.anyTree(PlanMatchPattern.tableScan("lineitem"));
    }

    /** Matches a scan of {@code nation} exposing its {@code name} column, beneath any single node. */
    private static PlanMatchPattern nationNameScan() {
        return PlanMatchPattern.any(PlanMatchPattern.tableScan("nation", ImmutableMap.of("name", "name")));
    }

    /**
     * Matches the TopNRanking node expected for {@code OVER(ORDER BY name)} over {@code nation}.
     *
     * <p>The node has no partitioning and orders by {@code name} ascending, nulls last. Its
     * ranking output is bound to the alias {@code ranking}.
     */
    private static PlanMatchPattern nationTopNRanking(TopNRankingNode.RankingType rankingType, int maxRankingPerPartition) {
        return PlanMatchPattern.topNRanking(builder -> {
            builder.rankingType(rankingType)
                    .maxRankingPerPartition(maxRankingPerPartition)
                    .specification(ImmutableList.of(), ImmutableList.of("name"), ImmutableMap.of("name", SortOrder.ASC_NULLS_LAST));
        }, nationNameScan()).withAlias("ranking", new TopNRankingSymbolMatcher());
    }

    /**
     * Matches the row-number node expected for {@code row_number() OVER()} over {@code nation}.
     *
     * <p>The node is capped at the given row count, and its output is bound to the alias
     * {@code row_number}.
     */
    private static PlanMatchPattern nationRowNumber(int maxRowCountPerPartition) {
        return PlanMatchPattern.rowNumber(builder -> {
            builder.maxRowCountPerPartition(Optional.of(maxRowCountPerPartition));
        }, nationNameScan()).withAlias("row_number", new RowNumberSymbolMatcher());
    }

    /** Builds {@code symbol > lowerExclusive AND symbol < upperExclusive} over BIGINT values. */
    private static Logical openRangePredicate(String symbol, long lowerExclusive, long upperExclusive) {
        return new Logical(Logical.Operator.AND, ImmutableList.of(
                new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, symbol), new Constant(BigintType.BIGINT, lowerExclusive)),
                new Comparison(Comparison.Operator.LESS_THAN, new Reference(BigintType.BIGINT, symbol), new Constant(BigintType.BIGINT, upperExclusive))));
    }

    /** Returns a copy of the default session with {@code optimize_top_n_ranking} set to {@code enabled}. */
    private Session optimizeTopNRanking(boolean enabled) {
        return Session.builder(getPlanTester().getDefaultSession())
                .setSystemProperty("optimize_top_n_ranking", Boolean.toString(enabled))
                .build();
    }
}
