package io.trino.cost;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingSession;
import io.trino.type.UnknownType;
import java.util.List;
import java.util.function.Function;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/cost/TestFilterProjectAggregationStatsRule.class */
/**
 * Stats-calculator tests for a filter placed over an aggregation, either directly or through a
 * project, run with the {@code non_estimatable_predicate_approximation_enabled} session property
 * switched on and off.
 */
public class TestFilterProjectAggregationStatsRule extends BaseStatsCalculatorTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();
    // Resolved "+" operator for (integer, integer); used to build the projection x_1 := x + 1.
    private static final ResolvedFunction ADD_INTEGER = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(IntegerType.INTEGER, IntegerType.INTEGER));

    private static final SymbolStatsEstimate SYMBOL_STATS_ESTIMATE_X = SymbolStatsEstimate.builder()
            .setLowValue(0.0)
            .setHighValue(100.0)
            .setDistinctValuesCount(10.0)
            .setNullsFraction(0.1)
            .build();
    private static final SymbolStatsEstimate SYMBOL_STATS_ESTIMATE_Y = SymbolStatsEstimate.builder()
            .setLowValue(0.0)
            .setHighValue(10.0)
            .setDistinctValuesCount(10.0)
            .setNullsFraction(0.0)
            .build();

    private static final Session APPROXIMATION_ENABLED = TestingSession.testSessionBuilder()
            .setSystemProperty("non_estimatable_predicate_approximation_enabled", "true")
            .build();
    private static final Session APPROXIMATION_DISABLED = TestingSession.testSessionBuilder()
            .setSystemProperty("non_estimatable_predicate_approximation_enabled", "false")
            .build();

    /**
     * Filter placed directly on top of {@code count_on_x := count(x) GROUP BY y} over {@code values(x, y)}.
     */
    @Test
    public void testFilterOverAggregationStats() {
        // filter(count_on_x > 0) <- aggregation(count_on_x := count(x) GROUP BY y) <- values(x, y)
        // NOTE(review): count_on_x is declared BIGINT, but the predicate references it (and its literal)
        // as INTEGER. Kept as-is; confirm this is intentional before changing it, since it may affect
        // the expected estimates.
        Function<PlanBuilder, PlanNode> countFilterPlan = pb -> pb.filter(
                new Comparison(Comparison.Operator.GREATER_THAN, new Reference(IntegerType.INTEGER, "count_on_x"), new Constant(IntegerType.INTEGER, 0L)),
                pb.aggregation(ab -> ab
                        .addAggregation(
                                pb.symbol("count_on_x", BigintType.BIGINT),
                                PlanBuilder.aggregation("count", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "x"))),
                                ImmutableList.of(BigintType.BIGINT))
                        .singleGroupingSet(pb.symbol("y", BigintType.BIGINT))
                        .source(pb.values(pb.symbol("x", BigintType.BIGINT), pb.symbol("y", BigintType.BIGINT)))));

        // Filter source: 100 rows; y has 10 distinct values in [0, 10] and no nulls.
        // NOTE(review): the statistics symbol "y" is typed DOUBLE. y is BIGINT in countFilterPlan and
        // DOUBLE only in the last plan below; confirm this is intentional.
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(100.0)
                .addSymbolStatistics(new Symbol(DoubleType.DOUBLE, "y"), SYMBOL_STATS_ESTIMATE_Y)
                .build();

        // Approximation enabled: 90 of the 100 source rows are expected, and count_on_x has no stats.
        tester().assertStatsFor(APPROXIMATION_ENABLED, countFilterPlan)
                .withSourceStats(sourceStats)
                .check(stats -> stats
                        .outputRowsCount(90.0)
                        .symbolStatsUnknown("count_on_x", DoubleType.DOUBLE));

        // Approximation disabled: the same plan yields an unknown row count.
        tester().assertStatsFor(APPROXIMATION_DISABLED, countFilterPlan)
                .withSourceStats(sourceStats)
                .check(stats -> stats.outputRowsCountUnknown());

        // Approximation enabled, but the source row count is not set: the result stays unknown.
        tester().assertStatsFor(APPROXIMATION_ENABLED, countFilterPlan)
                .withSourceStats(PlanNodeStatsEstimate.builder()
                        .addSymbolStatistics(new Symbol(UnknownType.UNKNOWN, "y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50.0).build())
                        .build())
                .check(stats -> stats.outputRowsCountUnknown());

        // Same shape with DOUBLE columns and a predicate on the grouping key (y = 1.0):
        // 10 of the 100 source rows are expected.
        tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> pb.filter(
                        new Comparison(Comparison.Operator.EQUAL, new Reference(DoubleType.DOUBLE, "y"), new Constant(DoubleType.DOUBLE, 1.0)),
                        pb.aggregation(ab -> ab
                                .addAggregation(
                                        pb.symbol("count_on_x", DoubleType.DOUBLE),
                                        PlanBuilder.aggregation("count", ImmutableList.<Expression>of(new Reference(DoubleType.DOUBLE, "x"))),
                                        ImmutableList.of(DoubleType.DOUBLE))
                                .singleGroupingSet(pb.symbol("y", DoubleType.DOUBLE))
                                .source(pb.values(pb.symbol("x", DoubleType.DOUBLE), pb.symbol("y", DoubleType.DOUBLE))))))
                .withSourceStats(sourceStats)
                .check(stats -> stats.outputRowsCount(10.0));
    }

    /**
     * Filter over a project over {@code count_on_x := count(x) GROUP BY y}. The aggregation's own
     * output stats are pinned to 50 rows through its explicit node id.
     */
    @Test
    public void testFilterAndProjectOverAggregationStats() {
        PlanNodeId aggregationId = new PlanNodeId("aggregation");
        PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(100.0)
                .addSymbolStatistics(new Symbol(UnknownType.UNKNOWN, "x"), SYMBOL_STATS_ESTIMATE_X)
                .addSymbolStatistics(new Symbol(UnknownType.UNKNOWN, "y"), SYMBOL_STATS_ESTIMATE_Y)
                .build();
        PlanNodeStatsEstimate aggregationStats = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(50.0)
                .build();

        // Identity projection of count_on_x: 45 of the aggregation's 50 rows are expected.
        tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> {
                    Symbol countOnX = pb.symbol("count_on_x", BigintType.BIGINT);
                    return pb.filter(
                            new Comparison(Comparison.Operator.GREATER_THAN, new Reference(IntegerType.INTEGER, "count_on_x"), new Constant(IntegerType.INTEGER, 0L)),
                            pb.project(
                                    Assignments.identity(countOnX),
                                    pb.aggregation(ab -> ab
                                            .addAggregation(
                                                    countOnX,
                                                    PlanBuilder.aggregation("count", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "x"))),
                                                    ImmutableList.of(BigintType.BIGINT))
                                            .singleGroupingSet(pb.symbol("y", BigintType.BIGINT))
                                            .source(pb.values(pb.symbol("x", BigintType.BIGINT), pb.symbol("y", BigintType.BIGINT)))
                                            .nodeId(aggregationId))));
                })
                .withSourceStats(sourceStats)
                .withSourceStats(aggregationId, aggregationStats)
                .check(stats -> stats.outputRowsCount(45.0));

        // Project that also computes x_1 := x + 1 alongside count_on_x: the row count is expected to be unknown.
        // NOTE(review): x_1 references x as INTEGER, while values(x, y) declares x as BIGINT; confirm intent.
        tester().assertStatsFor(APPROXIMATION_ENABLED, pb -> {
                    Symbol countOnX = pb.symbol("count_on_x", BigintType.BIGINT);
                    return pb.filter(
                            new Comparison(Comparison.Operator.GREATER_THAN, new Reference(IntegerType.INTEGER, "count_on_x"), new Constant(IntegerType.INTEGER, 0L)),
                            pb.project(
                                    Assignments.of(
                                            pb.symbol("x_1", IntegerType.INTEGER),
                                            new Call(ADD_INTEGER, ImmutableList.of(new Reference(IntegerType.INTEGER, "x"), new Constant(IntegerType.INTEGER, 1L))),
                                            countOnX,
                                            countOnX.toSymbolReference()),
                                    pb.aggregation(ab -> ab
                                            .addAggregation(
                                                    countOnX,
                                                    PlanBuilder.aggregation("count", ImmutableList.<Expression>of(new Reference(BigintType.BIGINT, "x"))),
                                                    ImmutableList.of(BigintType.BIGINT))
                                            .singleGroupingSet(pb.symbol("y", BigintType.BIGINT))
                                            .source(pb.values(pb.symbol("x", BigintType.BIGINT), pb.symbol("y", BigintType.BIGINT)))
                                            .nodeId(aggregationId))));
                })
                .withSourceStats(sourceStats)
                .withSourceStats(aggregationId, aggregationStats)
                .check(stats -> stats.outputRowsCountUnknown());
    }
}
