package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.testing.Closeables;
import io.trino.Session;
import io.trino.connector.MockConnectorColumnHandle;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.cost.TaskCountEstimator;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TableHandle;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.function.OperatorType;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.assertions.PlanTestSymbol;
import io.trino.sql.planner.assertions.SetOperationOutputMatcher;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.PlanTester;
import io.trino.testing.TestingSession;
import io.trino.testing.TestingTransactionHandle;
import java.io.Closeable;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.class */
public class TestMultipleDistinctAggregationsToSubqueries extends BaseRuleTest {
    private static final String TEST_SCHEMA = "test_schema";
    private static final String MOCK_CATALOG = "mock_catalog";
    // Session pointing at the mock catalog and the test schema
    private static final Session MOCK_SESSION = TestingSession.testSessionBuilder()
            .setCatalog(MOCK_CATALOG)
            .setSchema(TEST_SCHEMA)
            .build();

    // Column names and their mock handles
    private static final String COLUMN_1 = "orderkey";
    private static final ColumnHandle COLUMN_1_HANDLE = new MockConnectorColumnHandle(COLUMN_1, BigintType.BIGINT);
    private static final String COLUMN_2 = "partkey";
    private static final ColumnHandle COLUMN_2_HANDLE = new MockConnectorColumnHandle(COLUMN_2, BigintType.BIGINT);
    private static final String COLUMN_3 = "linenumber";
    private static final ColumnHandle COLUMN_3_HANDLE = new MockConnectorColumnHandle(COLUMN_3, BigintType.BIGINT);
    private static final String COLUMN_4 = "shipdate";
    private static final ColumnHandle COLUMN_4_HANDLE = new MockConnectorColumnHandle(COLUMN_4, DateType.DATE);
    private static final String GROUPING_KEY_COLUMN = "suppkey";
    private static final ColumnHandle GROUPING_KEY_COLUMN_HANDLE = new MockConnectorColumnHandle(GROUPING_KEY_COLUMN, BigintType.BIGINT);
    private static final String GROUPING_KEY2_COLUMN = "comment";
    private static final ColumnHandle GROUPING_KEY2_COLUMN_HANDLE = new MockConnectorColumnHandle(GROUPING_KEY2_COLUMN, VarcharType.VARCHAR);

    private static final String TEST_TABLE = "test_table";
    private static final SchemaTableName TABLE_SCHEMA = new SchemaTableName(TEST_SCHEMA, TEST_TABLE);

    // Column metadata (name and type) for every column handle declared above, in declaration order.
    // NOTE(review): presumably exposed as the test table's columns by the mock connector (usage is outside this view).
    private static final List<ColumnMetadata> ALL_COLUMNS = Stream.of(
                    COLUMN_1_HANDLE,
                    COLUMN_2_HANDLE,
                    COLUMN_3_HANDLE,
                    COLUMN_4_HANDLE,
                    GROUPING_KEY_COLUMN_HANDLE,
                    GROUPING_KEY2_COLUMN_HANDLE)
            .map(MockConnectorColumnHandle.class::cast)
            .map(handle -> new ColumnMetadata(handle.getName(), handle.getType()))
            .collect(ImmutableList.toImmutableList());

    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    // Resolved `+` operator for (bigint, bigint), used to build projection expressions
    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));

    // Rule tester shared by the tests; created in the constructor and closed in tearDownTester()
    private RuleTester ruleTester;

    /**
     * Creates the test without any additional plugins and builds the shared {@link RuleTester}
     * via {@code tester(true)}.
     */
    public TestMultipleDistinctAggregationsToSubqueries() {
        super();
        ruleTester = tester(true);
    }

    /**
     * Closes the shared {@link RuleTester} after all tests have run and clears the field.
     */
    @AfterAll
    public final void tearDownTester() {
        Closeables.closeAllRuntimeException(ruleTester);
        ruleTester = null;
    }

    /**
     * Plans on which the rule must not fire: no or a single distinct aggregation, distinct aggregations
     * sharing one input, a hash symbol, mixed distinct/non-distinct aggregations, multiple grouping sets,
     * join-based sources, a tester built with {@code tester(false)}, a non-splitting strategy, and an
     * automatic decision with incomplete source statistics.
     */
    @Test
    public void testDoesNotFire() {
        // Aggregation with no aggregate functions
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input = p.symbol("inputSymbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.singleGroupingSet(input)
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input), ImmutableMap.of(input, COLUMN_1_HANDLE)));
                    });
                })
                .doesNotFire();

        // A single distinct aggregation
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input = p.symbol("inputSymbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "inputSymbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input), ImmutableMap.of(input, COLUMN_1_HANDLE)));
                    });
                })
                .doesNotFire();

        // Two distinct aggregations over the same input symbol
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1), ImmutableMap.of(input1, COLUMN_1_HANDLE)));
                    });
                })
                .doesNotFire();

        // Aggregation carrying a hash symbol
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .hashSymbol(p.symbol("hashSymbol", BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                    });
                })
                .doesNotFire();

        // Distinct aggregations mixed with a non-distinct one
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output3", BigintType.BIGINT), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                    });
                })
                .doesNotFire();

        // Multiple grouping sets
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.groupingSets(AggregationNode.groupingSets(ImmutableList.of(), 2, ImmutableSet.of(0, 1)))
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                    });
                })
                .doesNotFire();

        // Aggregation source is an inner join of two table scans
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.join(
                                        JoinType.INNER,
                                        p.tableScan(testTableHandle(ruleTester), ImmutableList.of(), ImmutableMap.of()),
                                        p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE))));
                    });
                })
                .doesNotFire();

        // Aggregation source is a filter over an inner join
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.filter(
                                        Booleans.TRUE,
                                        p.join(
                                                JoinType.INNER,
                                                p.tableScan(testTableHandle(ruleTester), ImmutableList.of(), ImmutableMap.of()),
                                                p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)))));
                    });
                })
                .doesNotFire();

        // Same plan shape as the positive case, but using a tester built with tester(false) (see tester(), not shown here).
        // This tester is created locally, so it must be closed here; the teardown only closes the shared one.
        RuleTester secondaryTester = tester(false);
        try {
            secondaryTester.assertThat(newMultipleDistinctAggregationsToSubqueries(secondaryTester))
                    .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                    .on(p -> {
                        Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                        Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                        return p.aggregation(builder -> {
                            builder.globalGrouping()
                                    .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                    .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                    .source(p.tableScan(testTableHandle(secondaryTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                        });
                    })
                    .doesNotFire();
        }
        finally {
            Closeables.closeAllRuntimeException(secondaryTester);
        }

        // Strategy other than split_to_subqueries
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "single_step")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                    });
                })
                .doesNotFire();

        // Automatic strategy where the source stats carry a grouping-key NDV but no output row count
        String aggregationSourceId = "aggregationSourceId";
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder()
                        .addSymbolStatistics(new Symbol(BigintType.BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build())
                        .build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.singleGroupingSet(p.symbol("groupingKey", BigintType.BIGINT))
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(scan -> {
                                    scan.setNodeId(new PlanNodeId(aggregationSourceId))
                                            .setTableHandle(testTableHandle(ruleTester))
                                            .setSymbols(ImmutableList.of(input1, input2))
                                            .setAssignments(ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE));
                                }));
                    });
                })
                .doesNotFire();
    }

    /**
     * Checks the automatic strategy for an aggregation directly over a table scan, driven by the
     * statistics overridden on the scan ({@code aggregationSourceId}) and on the aggregation
     * ({@code aggregationId}).
     */
    @Test
    public void testAutomaticDecisionForAggregationOnTableScan() {
        String aggregationSourceId = "aggregationSourceId";
        String aggregationId = "aggregationId";

        // Source stats carry only a grouping-key NDV and no output row count
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder()
                        .addSymbolStatistics(new Symbol(BigintType.BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build())
                        .build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.singleGroupingSet(p.symbol("groupingKey", BigintType.BIGINT))
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(scan -> {
                                    scan.setNodeId(new PlanNodeId(aggregationSourceId))
                                            .setTableHandle(testTableHandle(ruleTester))
                                            .setSymbols(ImmutableList.of(input1, input2))
                                            .setAssignments(ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE));
                                }));
                    });
                })
                .doesNotFire();

        // 100 source rows with 10 distinct grouping keys: each distinct aggregation becomes its own
        // single-step aggregation over the table, and the two are inner-joined on the grouping key
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder()
                        .setOutputRowCount(100)
                        .addSymbolStatistics(new Symbol(BigintType.BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())
                        .build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    return p.aggregation(builder -> {
                        builder.singleGroupingSet(groupingKey)
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(scan -> {
                                    scan.setNodeId(new PlanNodeId(aggregationSourceId))
                                            .setTableHandle(testTableHandle(ruleTester))
                                            .setSymbols(ImmutableList.of(input1, input2, groupingKey))
                                            .setAssignments(ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, groupingKey, GROUPING_KEY_COLUMN_HANDLE));
                                }));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")),
                                "group_by_key", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "left_groupingKey"))),
                        PlanMatchPattern.join(JoinType.INNER, join -> {
                            join.equiCriteria("left_groupingKey", "right_groupingKey")
                                    .left(PlanMatchPattern.aggregation(
                                            PlanMatchPattern.singleGroupingSet("left_groupingKey"),
                                            ImmutableMap.of(Optional.of("output1"), PlanMatchPattern.aggregationFunction("count", true, ImmutableList.of(PlanMatchPattern.symbol("input1Symbol")))),
                                            Optional.empty(),
                                            AggregationNode.Step.SINGLE,
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1, "left_groupingKey", GROUPING_KEY_COLUMN))))
                                    .right(PlanMatchPattern.aggregation(
                                            PlanMatchPattern.singleGroupingSet("right_groupingKey"),
                                            ImmutableMap.of(Optional.of("output2"), PlanMatchPattern.aggregationFunction("sum", true, ImmutableList.of(PlanMatchPattern.symbol("input2Symbol")))),
                                            Optional.empty(),
                                            AggregationNode.Step.SINGLE,
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2, "right_groupingKey", GROUPING_KEY_COLUMN))));
                        })));

        // Same as above, plus a second grouping key (VARCHAR) with a very large average row size
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder()
                        .setOutputRowCount(100)
                        .addSymbolStatistics(new Symbol(BigintType.BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())
                        // Key the stats with the same type the groupingKey2 symbol is declared with (VARCHAR below).
                        // NOTE(review): this was previously BIGINT; since Symbol carries its type, that key may never
                        // have matched the VARCHAR grouping-key symbol — confirm against Symbol equality.
                        .addSymbolStatistics(new Symbol(VarcharType.VARCHAR, "groupingKey2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build())
                        .build())
                .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    Symbol groupingKey2 = p.symbol("groupingKey2", VarcharType.VARCHAR);
                    return p.aggregation(builder -> {
                        builder.nodeId(new PlanNodeId(aggregationId))
                                .singleGroupingSet(groupingKey, groupingKey2)
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(scan -> {
                                    scan.setNodeId(new PlanNodeId(aggregationSourceId))
                                            .setTableHandle(testTableHandle(ruleTester))
                                            .setSymbols(ImmutableList.of(input1, input2, groupingKey, groupingKey2))
                                            .setAssignments(ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, groupingKey, GROUPING_KEY_COLUMN_HANDLE, groupingKey2, GROUPING_KEY2_COLUMN_HANDLE));
                                }));
                    });
                })
                .doesNotFire();
    }

    /**
     * Checks the automatic strategy when a projection sits between the aggregation and the table scan,
     * with one scan column given a very large average row size in the overridden source statistics.
     */
    @Test
    public void testAutomaticDecisionForAggregationOnProjectedTableScan() {
        String sourceNodeId = "aggregationSourceId";
        String aggregationNodeId = "aggregationId";

        // Both scenarios use the same statistics.
        // NOTE(review): projectionInput2 is declared as a VARCHAR symbol below, yet its stats key and its Reference in
        // the projection both use BIGINT (and the Cast is BIGINT -> BIGINT). They agree with each other; confirm intended.
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(100)
                .addSymbolStatistics(new Symbol(BigintType.BIGINT, "projectionInput1"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())
                .addSymbolStatistics(new Symbol(BigintType.BIGINT, "projectionInput2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build())
                .build();
        PlanNodeStatsEstimate aggregationStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(10)
                .build();

        // Grouping key is computed by the projection from the two projectionInput columns
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                .overrideStats(sourceNodeId, sourceStats)
                .overrideStats(aggregationNodeId, aggregationStats)
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    Symbol projectionInput1 = p.symbol("projectionInput1", BigintType.BIGINT);
                    Symbol projectionInput2 = p.symbol("projectionInput2", VarcharType.VARCHAR);
                    Expression computed = new Call(ADD_BIGINT, ImmutableList.of(
                            new Reference(BigintType.BIGINT, "projectionInput1"),
                            new Cast(new Reference(BigintType.BIGINT, "projectionInput2"), BigintType.BIGINT)));
                    return p.aggregation(builder -> {
                        builder.nodeId(new PlanNodeId(aggregationNodeId))
                                .singleGroupingSet(groupingKey)
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.project(
                                        Assignments.builder()
                                                .putIdentity(input1)
                                                .putIdentity(input2)
                                                .put(groupingKey, computed)
                                                .build(),
                                        p.tableScan(scan -> {
                                            scan.setNodeId(new PlanNodeId(sourceNodeId))
                                                    .setTableHandle(testTableHandle(ruleTester))
                                                    .setSymbols(ImmutableList.of(input1, input2, projectionInput1, projectionInput2))
                                                    .setAssignments(ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, projectionInput1, GROUPING_KEY_COLUMN_HANDLE, projectionInput2, GROUPING_KEY2_COLUMN_HANDLE));
                                        })));
                    });
                })
                .doesNotFire();

        // A distinct-aggregation input is computed by the projection; the grouping key comes straight from the scan
        ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                .overrideStats(sourceNodeId, sourceStats)
                .overrideStats(aggregationNodeId, aggregationStats)
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    Symbol projectionInput1 = p.symbol("projectionInput1", BigintType.BIGINT);
                    Symbol projectionInput2 = p.symbol("projectionInput2", VarcharType.VARCHAR);
                    Expression computed = new Call(ADD_BIGINT, ImmutableList.of(
                            new Reference(BigintType.BIGINT, "projectionInput1"),
                            new Cast(new Reference(BigintType.BIGINT, "projectionInput2"), BigintType.BIGINT)));
                    return p.aggregation(builder -> {
                        builder.nodeId(new PlanNodeId(aggregationNodeId))
                                .singleGroupingSet(groupingKey)
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.project(
                                        Assignments.builder()
                                                .put(input1, computed)
                                                .putIdentity(input2)
                                                .putIdentity(groupingKey)
                                                .build(),
                                        p.tableScan(scan -> {
                                            scan.setNodeId(new PlanNodeId(sourceNodeId))
                                                    .setTableHandle(testTableHandle(ruleTester))
                                                    .setSymbols(ImmutableList.of(groupingKey, input2, projectionInput1, projectionInput2))
                                                    .setAssignments(ImmutableMap.of(groupingKey, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, projectionInput1, GROUPING_KEY_COLUMN_HANDLE, projectionInput2, GROUPING_KEY2_COLUMN_HANDLE));
                                        })));
                    });
                })
                .doesNotFire();
    }

    @Test
    public void testAutomaticDecisionForAggregationOnFilteredTableScan() {
        String str = "aggregationSourceId";
        String str2 = "aggregationId";
        String str3 = "filterId";
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester)).setSystemProperty("distinct_aggregations_strategy", "automatic").overrideStats("aggregationSourceId", PlanNodeStatsEstimate.builder().setOutputRowCount(100.0d).addSymbolStatistics(new Symbol(VarcharType.VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1.0d).build()).build()).overrideStats("filterId", PlanNodeStatsEstimate.builder().setOutputRowCount(1.0d).build()).overrideStats("aggregationId", PlanNodeStatsEstimate.builder().setOutputRowCount(1.0d).build()).on(planBuilder -> {
            Symbol symbol = planBuilder.symbol("input1Symbol", BigintType.BIGINT);
            Symbol symbol2 = planBuilder.symbol("input2Symbol", BigintType.BIGINT);
            Symbol symbol3 = planBuilder.symbol("groupingKey", BigintType.BIGINT);
            Symbol symbol4 = planBuilder.symbol("filterInput", VarcharType.VARCHAR);
            return planBuilder.aggregation(aggregationBuilder -> {
                aggregationBuilder.nodeId(new PlanNodeId(str2)).singleGroupingSet(symbol3).addAggregation(planBuilder.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, (List<Expression>) ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT)).addAggregation(planBuilder.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, (List<Expression>) ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT)).source(planBuilder.filter(new PlanNodeId(str3), IrExpressions.not(this.ruleTester.getMetadata(), new IsNull(new Reference(VarcharType.VARCHAR, "filterInput"))), planBuilder.tableScan(tableScanBuilder -> {
                    tableScanBuilder.setNodeId(new PlanNodeId(str)).setTableHandle(testTableHandle(this.ruleTester)).setSymbols(ImmutableList.of(symbol, symbol2, symbol3, symbol4)).setAssignments(ImmutableMap.of(symbol, COLUMN_1_HANDLE, symbol2, COLUMN_2_HANDLE, symbol3, GROUPING_KEY_COLUMN_HANDLE, symbol4, GROUPING_KEY2_COLUMN_HANDLE));
                })));
            });
        }).doesNotFire();
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester)).setSystemProperty("distinct_aggregations_strategy", "automatic").overrideStats("aggregationSourceId", PlanNodeStatsEstimate.builder().setOutputRowCount(100.0d).addSymbolStatistics(new Symbol(VarcharType.VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1.0d).build()).build()).overrideStats("filterId", PlanNodeStatsEstimate.builder().setOutputRowCount(100.0d).build()).overrideStats("aggregationId", PlanNodeStatsEstimate.builder().setOutputRowCount(100.0d).build()).on(planBuilder2 -> {
            Symbol symbol = planBuilder2.symbol("input1Symbol", BigintType.BIGINT);
            Symbol symbol2 = planBuilder2.symbol("input2Symbol", BigintType.BIGINT);
            Symbol symbol3 = planBuilder2.symbol("groupingKey", BigintType.BIGINT);
            Symbol symbol4 = planBuilder2.symbol("filterInput", VarcharType.VARCHAR);
            return planBuilder2.aggregation(aggregationBuilder -> {
                aggregationBuilder.nodeId(new PlanNodeId(str2)).singleGroupingSet(symbol3).addAggregation(planBuilder2.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, (List<Expression>) ImmutableList.of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT)).addAggregation(planBuilder2.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, (List<Expression>) ImmutableList.of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT)).source(planBuilder2.filter(new PlanNodeId(str3), IrExpressions.not(this.ruleTester.getMetadata(), new IsNull(new Reference(VarcharType.VARCHAR, "filterInput"))), planBuilder2.tableScan(tableScanBuilder -> {
                    tableScanBuilder.setNodeId(new PlanNodeId(str)).setTableHandle(testTableHandle(this.ruleTester)).setSymbols(ImmutableList.of(symbol, symbol2, symbol3, symbol4)).setAssignments(ImmutableMap.of(symbol, COLUMN_1_HANDLE, symbol2, COLUMN_2_HANDLE, symbol3, GROUPING_KEY_COLUMN_HANDLE, symbol4, GROUPING_KEY2_COLUMN_HANDLE));
                })));
            });
        }).matches(PlanMatchPattern.project(ImmutableMap.of("final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")), "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")), "group_by_key", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "left_groupingKey"))), PlanMatchPattern.join(JoinType.INNER, builder -> {
            builder.equiCriteria("left_groupingKey", "right_groupingKey").left(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("left_groupingKey"), ImmutableMap.of(Optional.of("output1"), PlanMatchPattern.aggregationFunction("count", true, (List<PlanTestSymbol>) ImmutableList.of(PlanMatchPattern.symbol("input1Symbol")))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.filter(IrExpressions.not(this.ruleTester.getMetadata(), new IsNull(new Reference(BigintType.BIGINT, "left_filterInput"))), PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1, "left_groupingKey", GROUPING_KEY_COLUMN, "left_filterInput", GROUPING_KEY2_COLUMN))))).right(PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("right_groupingKey"), ImmutableMap.of(Optional.of("output2"), PlanMatchPattern.aggregationFunction("sum", true, (List<PlanTestSymbol>) ImmutableList.of(PlanMatchPattern.symbol("input2Symbol")))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.filter(IrExpressions.not(this.ruleTester.getMetadata(), new IsNull(new Reference(BigintType.BIGINT, "right_filterInput"))), PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2, "right_groupingKey", GROUPING_KEY_COLUMN, "right_filterInput", GROUPING_KEY2_COLUMN)))));
        })));
    }

    @Test
    public void testAutomaticDecisionForAggregationOnFilteredUnion() {
        // Under the "automatic" strategy, two distinct aggregations whose source is a union of two filtered
        // scans are expected to be left untouched: the rule must not fire.
        String aggregationNodeId = "aggregationId";
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "automatic")
                // NOTE(review): no node in the plan built below is given the id "aggregationSourceId" or "filterId",
                // and no "filterInput" symbol is declared; these two overrides look copied from the single-filter
                // test and are presumably inert — confirm and remove.
                .overrideStats("aggregationSourceId", PlanNodeStatsEstimate.builder()
                        .setOutputRowCount(100.0d)
                        .addSymbolStatistics(new Symbol(VarcharType.VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1.0d).build())
                        .build())
                .overrideStats("filterId", PlanNodeStatsEstimate.builder().setOutputRowCount(100.0d).build())
                .overrideStats(aggregationNodeId, PlanNodeStatsEstimate.builder().setOutputRowCount(100.0d).build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input1Branch1 = p.symbol("input1_1Symbol", BigintType.BIGINT);
                    Symbol input1Branch2 = p.symbol("input1_2Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol input2Branch1 = p.symbol("input2_1Symbol", BigintType.BIGINT);
                    Symbol input2Branch2 = p.symbol("input2_2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    Symbol groupingKeyBranch1 = p.symbol("groupingKey1", BigintType.BIGINT);
                    Symbol groupingKeyBranch2 = p.symbol("groupingKey2", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        agg.nodeId(new PlanNodeId(aggregationNodeId))
                                .singleGroupingSet(groupingKey)
                                // Type witnesses are required: a cast of ImmutableList.of(Reference) to List<Expression> does not compile.
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.union(
                                        // Without the witness, builder() at the head of a call chain infers <Object, Object>.
                                        ImmutableListMultimap.<Symbol, Symbol>builder()
                                                .put(input1, input1Branch1)
                                                .put(input1, input1Branch2)
                                                .put(input2, input2Branch1)
                                                .put(input2, input2Branch2)
                                                .put(groupingKey, groupingKeyBranch1)
                                                .put(groupingKey, groupingKeyBranch2)
                                                .build(),
                                        ImmutableList.of(
                                                p.filter(
                                                        new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input1_1Symbol"), new Constant(BigintType.BIGINT, 0L)),
                                                        p.tableScan(
                                                                testTableHandle(this.ruleTester),
                                                                ImmutableList.of(input1Branch1, input2Branch1, groupingKeyBranch1),
                                                                ImmutableMap.of(input1Branch1, COLUMN_1_HANDLE, input2Branch1, COLUMN_2_HANDLE, groupingKeyBranch1, GROUPING_KEY_COLUMN_HANDLE))),
                                                p.filter(
                                                        new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input2_2Symbol"), new Constant(BigintType.BIGINT, 2L)),
                                                        p.tableScan(
                                                                testTableHandle(this.ruleTester),
                                                                ImmutableList.of(input1Branch2, input2Branch2, groupingKeyBranch2),
                                                                ImmutableMap.of(input1Branch2, COLUMN_1_HANDLE, input2Branch2, COLUMN_2_HANDLE, groupingKeyBranch2, GROUPING_KEY_COLUMN_HANDLE))))));
                    });
                })
                .doesNotFire();
    }

    @Test
    public void testGlobalDistinctToSubqueries() {
        // Two distinct aggregations over different inputs with a global grouping are expected to become one
        // aggregation per input, combined by an INNER join and projected back to the final output names.
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        // Type witnesses are required: a cast of ImmutableList.of(Reference) to List<Expression> does not compile.
                        agg.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(this.ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2"))),
                        PlanMatchPattern.join(JoinType.INNER, joinBuilder -> {
                            joinBuilder
                                    .left(PlanMatchPattern.aggregation(
                                            ImmutableMap.of("output1", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input1Symbol")))),
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1))))
                                    .right(PlanMatchPattern.aggregation(
                                            ImmutableMap.of("output2", PlanMatchPattern.aggregationFunction("sum", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol")))),
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2))));
                        })));
    }

    @Test
    public void testGlobalWith3DistinctToSubqueries() {
        // Three distinct aggregations over three different inputs are expected to become three single-aggregation
        // subqueries combined by a right-nested chain of INNER joins.
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol input3 = p.symbol("input3Symbol", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        agg.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output3", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input3Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(this.ruleTester), ImmutableList.of(input1, input2, input3), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, input3, COLUMN_3_HANDLE)));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")),
                                "final_output3", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output3"))),
                        PlanMatchPattern.join(JoinType.INNER, joinBuilder -> {
                            joinBuilder
                                    .left(PlanMatchPattern.aggregation(
                                            ImmutableMap.of("output1", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input1Symbol")))),
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1))))
                                    // The nested lambda needs its own parameter name: Java forbids a lambda parameter
                                    // from redeclaring one that is still in scope from the enclosing lambda.
                                    .right(PlanMatchPattern.join(JoinType.INNER, nestedJoinBuilder -> {
                                        nestedJoinBuilder
                                                .left(PlanMatchPattern.aggregation(
                                                        ImmutableMap.of("output2", PlanMatchPattern.aggregationFunction("sum", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol")))),
                                                        PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2))))
                                                .right(PlanMatchPattern.aggregation(
                                                        ImmutableMap.of("output3", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input3Symbol")))),
                                                        PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input3Symbol", COLUMN_3))));
                                    }));
                        })));
    }

    @Test
    public void testGlobalWith4DistinctToSubqueries() {
        // Four distinct aggregations over four different inputs are expected to become four single-aggregation
        // subqueries combined by a right-nested chain of INNER joins.
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol input3 = p.symbol("input3Symbol", BigintType.BIGINT);
                    Symbol input4 = p.symbol("input4Symbol", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        agg.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output3", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input3Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output4", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input4Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(
                                        testTableHandle(this.ruleTester),
                                        ImmutableList.of(input1, input2, input3, input4),
                                        ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, input3, COLUMN_3_HANDLE, input4, COLUMN_4_HANDLE)));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")),
                                "final_output3", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output3")),
                                "final_output4", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output4"))),
                        PlanMatchPattern.join(JoinType.INNER, joinBuilder -> {
                            joinBuilder
                                    .left(PlanMatchPattern.aggregation(
                                            ImmutableMap.of("output1", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input1Symbol")))),
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1))))
                                    // Each nesting level gets a distinct lambda parameter name; Java forbids
                                    // redeclaring a lambda parameter that is still in scope.
                                    .right(PlanMatchPattern.join(JoinType.INNER, nestedJoinBuilder -> {
                                        nestedJoinBuilder
                                                .left(PlanMatchPattern.aggregation(
                                                        ImmutableMap.of("output2", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol")))),
                                                        PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2))))
                                                .right(PlanMatchPattern.join(JoinType.INNER, innermostJoinBuilder -> {
                                                    innermostJoinBuilder
                                                            .left(PlanMatchPattern.aggregation(
                                                                    ImmutableMap.of("output3", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input3Symbol")))),
                                                                    PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input3Symbol", COLUMN_3))))
                                                            .right(PlanMatchPattern.aggregation(
                                                                    ImmutableMap.of("output4", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input4Symbol")))),
                                                                    PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input4Symbol", COLUMN_4))));
                                                }));
                                    }));
                        })));
    }

    @Test
    public void testGlobal2DistinctOnTheSameInputToSubqueries() {
        // Two of the three distinct aggregations read the same input (input2Symbol). They are expected to share
        // one subquery, giving two aggregation subqueries joined by an INNER join rather than three.
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        agg.globalGrouping()
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output3", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(testTableHandle(this.ruleTester), ImmutableList.of(input1, input2), ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE)));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")),
                                "final_output3", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output3"))),
                        PlanMatchPattern.join(JoinType.INNER, joinBuilder -> {
                            joinBuilder
                                    .left(PlanMatchPattern.aggregation(
                                            ImmutableMap.of("output1", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input1Symbol")))),
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1))))
                                    .right(PlanMatchPattern.aggregation(
                                            ImmutableMap.of(
                                                    "output2", PlanMatchPattern.aggregationFunction("sum", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol"))),
                                                    "output3", PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol")))),
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2))));
                        })));
    }

    @Test
    public void testGroupByWithDistinctToSubqueries() {
        // A grouped query with two distinct aggregations is expected to become two grouped aggregations,
        // joined on the grouping key; the left side's key is projected as the final group-by key.
        String aggregationNodeId = "aggregationNodeId";
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .overrideStats(aggregationNodeId, PlanNodeStatsEstimate.builder().setOutputRowCount(100000.0d).build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        agg.nodeId(new PlanNodeId(aggregationNodeId))
                                .singleGroupingSet(groupingKey)
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.tableScan(
                                        testTableHandle(this.ruleTester),
                                        ImmutableList.of(input1, input2, groupingKey),
                                        ImmutableMap.of(input1, COLUMN_1_HANDLE, input2, COLUMN_2_HANDLE, groupingKey, GROUPING_KEY_COLUMN_HANDLE)));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")),
                                "group_by_key", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "left_groupingKey"))),
                        PlanMatchPattern.join(JoinType.INNER, joinBuilder -> {
                            joinBuilder
                                    .equiCriteria("left_groupingKey", "right_groupingKey")
                                    .left(PlanMatchPattern.aggregation(
                                            PlanMatchPattern.singleGroupingSet("left_groupingKey"),
                                            ImmutableMap.of(Optional.of("output1"), PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input1Symbol")))),
                                            Optional.empty(),
                                            AggregationNode.Step.SINGLE,
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1, "left_groupingKey", GROUPING_KEY_COLUMN))))
                                    .right(PlanMatchPattern.aggregation(
                                            PlanMatchPattern.singleGroupingSet("right_groupingKey"),
                                            ImmutableMap.of(Optional.of("output2"), PlanMatchPattern.aggregationFunction("sum", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol")))),
                                            Optional.empty(),
                                            AggregationNode.Step.SINGLE,
                                            PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2, "right_groupingKey", GROUPING_KEY_COLUMN))));
                        })));
    }

    @Test
    public void testGroupByWithDistinctOverUnionToSubqueries() {
        // A grouped query with two distinct aggregations over a union of two filtered scans is expected to become
        // two grouped aggregations, each over its own copy of the union (with fresh symbols), joined on the
        // grouping key.
        String aggregationNodeId = "aggregationNodeId";
        this.ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(this.ruleTester))
                .setSystemProperty("distinct_aggregations_strategy", "split_to_subqueries")
                .overrideStats(aggregationNodeId, PlanNodeStatsEstimate.builder().setOutputRowCount(100000.0d).build())
                .on(p -> {
                    Symbol input1 = p.symbol("input1Symbol", BigintType.BIGINT);
                    Symbol input1Branch1 = p.symbol("input1_1Symbol", BigintType.BIGINT);
                    Symbol input1Branch2 = p.symbol("input1_2Symbol", BigintType.BIGINT);
                    Symbol input2 = p.symbol("input2Symbol", BigintType.BIGINT);
                    Symbol input2Branch1 = p.symbol("input2_1Symbol", BigintType.BIGINT);
                    Symbol input2Branch2 = p.symbol("input2_2Symbol", BigintType.BIGINT);
                    Symbol groupingKey = p.symbol("groupingKey", BigintType.BIGINT);
                    Symbol groupingKeyBranch1 = p.symbol("groupingKey1", BigintType.BIGINT);
                    Symbol groupingKeyBranch2 = p.symbol("groupingKey2", BigintType.BIGINT);
                    return p.aggregation(agg -> {
                        agg.nodeId(new PlanNodeId(aggregationNodeId))
                                .singleGroupingSet(groupingKey)
                                .addAggregation(p.symbol("output1", BigintType.BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input1Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .addAggregation(p.symbol("output2", BigintType.BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "input2Symbol"))), ImmutableList.of(BigintType.BIGINT))
                                .source(p.union(
                                        // Without the witness, builder() at the head of a call chain infers <Object, Object>.
                                        ImmutableListMultimap.<Symbol, Symbol>builder()
                                                .put(input1, input1Branch1)
                                                .put(input1, input1Branch2)
                                                .put(input2, input2Branch1)
                                                .put(input2, input2Branch2)
                                                .put(groupingKey, groupingKeyBranch1)
                                                .put(groupingKey, groupingKeyBranch2)
                                                .build(),
                                        ImmutableList.of(
                                                p.filter(
                                                        new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input1_1Symbol"), new Constant(BigintType.BIGINT, 0L)),
                                                        p.tableScan(
                                                                testTableHandle(this.ruleTester),
                                                                ImmutableList.of(input1Branch1, input2Branch1, groupingKeyBranch1),
                                                                ImmutableMap.of(input1Branch1, COLUMN_1_HANDLE, input2Branch1, COLUMN_2_HANDLE, groupingKeyBranch1, GROUPING_KEY_COLUMN_HANDLE))),
                                                p.filter(
                                                        new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input2_2Symbol"), new Constant(BigintType.BIGINT, 2L)),
                                                        p.tableScan(
                                                                testTableHandle(this.ruleTester),
                                                                ImmutableList.of(input1Branch2, input2Branch2, groupingKeyBranch2),
                                                                ImmutableMap.of(input1Branch2, COLUMN_1_HANDLE, input2Branch2, COLUMN_2_HANDLE, groupingKeyBranch2, GROUPING_KEY_COLUMN_HANDLE))))));
                    });
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "final_output1", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output1")),
                                "final_output2", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "output2")),
                                "group_by_key", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "left_groupingKey"))),
                        PlanMatchPattern.join(JoinType.INNER, joinBuilder -> {
                            joinBuilder
                                    .equiCriteria("left_groupingKey", "right_groupingKey")
                                    .left(PlanMatchPattern.aggregation(
                                            PlanMatchPattern.singleGroupingSet("left_groupingKey"),
                                            ImmutableMap.of(Optional.of("output1"), PlanMatchPattern.aggregationFunction("count", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input1Symbol1")))),
                                            Optional.empty(),
                                            AggregationNode.Step.SINGLE,
                                            PlanMatchPattern.union(
                                                            PlanMatchPattern.filter(
                                                                    new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input1_1_1Symbol"), new Constant(BigintType.BIGINT, 0L)),
                                                                    PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1_1_1Symbol", COLUMN_1, "input2_1_1Symbol", COLUMN_2, "left_groupingKey1", GROUPING_KEY_COLUMN))),
                                                            PlanMatchPattern.filter(
                                                                    new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input2_2_1Symbol"), new Constant(BigintType.BIGINT, 2L)),
                                                                    PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1_2_1Symbol", COLUMN_1, "input2_2_1Symbol", COLUMN_2, "left_groupingKey2", GROUPING_KEY_COLUMN))))
                                                    // Bind the union's output columns (by position) to the aliases used above.
                                                    .withAlias("input1Symbol1", new SetOperationOutputMatcher(0))
                                                    .withAlias("input2Symbol1", new SetOperationOutputMatcher(1))
                                                    .withAlias("left_groupingKey", new SetOperationOutputMatcher(2))))
                                    .right(PlanMatchPattern.aggregation(
                                            PlanMatchPattern.singleGroupingSet("right_groupingKey"),
                                            ImmutableMap.of(Optional.of("output2"), PlanMatchPattern.aggregationFunction("sum", true, ImmutableList.<PlanTestSymbol>of(PlanMatchPattern.symbol("input2Symbol2")))),
                                            Optional.empty(),
                                            AggregationNode.Step.SINGLE,
                                            PlanMatchPattern.union(
                                                            PlanMatchPattern.filter(
                                                                    new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input1_1_2Symbol"), new Constant(BigintType.BIGINT, 0L)),
                                                                    PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1_1_2Symbol", COLUMN_1, "input2_1_2Symbol", COLUMN_2, "right_groupingKey1", GROUPING_KEY_COLUMN))),
                                                            PlanMatchPattern.filter(
                                                                    new Comparison(Comparison.Operator.GREATER_THAN, new Reference(BigintType.BIGINT, "input2_2_2Symbol"), new Constant(BigintType.BIGINT, 2L)),
                                                                    PlanMatchPattern.tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1_2_2Symbol", COLUMN_1, "input2_2_2Symbol", COLUMN_2, "right_groupingKey2", GROUPING_KEY_COLUMN))))
                                                    .withAlias("input1Symbol2", new SetOperationOutputMatcher(0))
                                                    .withAlias("input2Symbol2", new SetOperationOutputMatcher(1))
                                                    .withAlias("right_groupingKey", new SetOperationOutputMatcher(2))));
                        })));
    }

    /**
     * Creates the rule under test, using the tester's metadata.
     *
     * <p>The {@link TaskCountEstimator} is backed by a supplier that always returns {@link Integer#MAX_VALUE}.
     */
    private static MultipleDistinctAggregationsToSubqueries newMultipleDistinctAggregationsToSubqueries(RuleTester ruleTester) {
        TaskCountEstimator maxTaskCountEstimator = new TaskCountEstimator(() -> Integer.MAX_VALUE);
        return new MultipleDistinctAggregationsToSubqueries(maxTaskCountEstimator, ruleTester.getMetadata());
    }

    /**
     * Builds a {@link TableHandle} for the mock table {@code TABLE_SCHEMA} in the tester's current catalog.
     * The connector handle carries {@code TupleDomain.all()} and an empty optional third argument.
     * It is paired with a fresh testing transaction handle.
     */
    private static TableHandle testTableHandle(RuleTester ruleTester) {
        MockConnectorTableHandle connectorTableHandle = new MockConnectorTableHandle(TABLE_SCHEMA, TupleDomain.all(), Optional.empty());
        return new TableHandle(ruleTester.getCurrentCatalogHandle(), connectorTableHandle, TestingTransactionHandle.create());
    }

    /**
     * Creates a {@link RuleTester} over a {@link PlanTester} with a mock catalog {@code MOCK_CATALOG}.
     * The catalog's connector resolves every requested table name to a {@link MockConnectorTableHandle}
     * and reports {@code ALL_COLUMNS} for every table.
     *
     * @param allowSplittingReadIntoMultipleSubQueries forwarded to the mock connector's
     *        {@code withAllowSplittingReadIntoMultipleSubQueries} setting
     */
    private static RuleTester tester(boolean allowSplittingReadIntoMultipleSubQueries) {
        PlanTester planTester = PlanTester.create(MOCK_SESSION);
        MockConnectorFactory connectorFactory = MockConnectorFactory.builder()
                .withAllowSplittingReadIntoMultipleSubQueries(allowSplittingReadIntoMultipleSubQueries)
                .withGetTableHandle((connectorSession, tableName) -> new MockConnectorTableHandle(tableName))
                .withGetColumns(tableName -> ALL_COLUMNS)
                .build();
        planTester.createCatalog(MOCK_CATALOG, connectorFactory, ImmutableMap.of());
        return new RuleTester(planTester);
    }
}
