package io.trino.cost;

import com.google.common.base.Verify;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsCalculator;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.TopNRankingNode;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Stats rule for {@link TopNRankingNode}.
 *
 * <p>Partial nodes pass the source statistics through unchanged. For final nodes the rule
 * performs these steps:
 * <ul>
 *   <li>estimates the number of partitions from the partition-by symbols;</li>
 *   <li>derives how many rows per partition survive for the node's ranking type
 *       (ROW_NUMBER, RANK, or any other type, which is handled like DENSE_RANK);</li>
 *   <li>shrinks the distinct count of the leading order-by symbol accordingly;</li>
 *   <li>synthesizes statistics for the ranking symbol over the range
 *       {@code [1, maxRankingPerPartition]}.</li>
 * </ul>
 */
public class TopNRankingStatsRule extends SimpleStatsRule<TopNRankingNode> {
    /**
     * Decay applied to the exponent of each successive symbol's distinct count when estimating
     * how many value combinations a list of symbols produces. The first symbol counts fully,
     * and later ones count progressively less.
     */
    private static final double INDEPENDENCE_FACTOR = 0.9d;
    private static final Pattern<TopNRankingNode> PATTERN = Patterns.topNRanking();

    public TopNRankingStatsRule(StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
    }

    @Override
    public Pattern<TopNRankingNode> getPattern() {
        return PATTERN;
    }

    /**
     * Estimates the output statistics of a {@link TopNRankingNode}.
     *
     * @param topNRankingNode the node being estimated
     * @param context provides the statistics of the node's source
     * @return empty when the source row count is unknown; otherwise the estimated statistics.
     *     Partial nodes, and nodes whose partition count cannot be estimated, get the unchanged
     *     source statistics.
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(TopNRankingNode topNRankingNode, StatsCalculator.Context context) {
        PlanNodeStatsEstimate sourceStats = context.statsProvider().getStats(topNRankingNode.getSource());
        if (sourceStats.isOutputRowCountUnknown()) {
            return Optional.empty();
        }
        if (topNRankingNode.isPartial()) {
            return Optional.of(sourceStats);
        }

        double sourceRowCount = sourceStats.getOutputRowCount();
        if (sourceRowCount == 0.0d) {
            // Empty input: keep source stats and describe the ranking symbol as empty.
            return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats)
                    .addSymbolStatistics(topNRankingNode.getRankingSymbol(), SymbolStatsEstimate.zero())
                    .build());
        }

        // There cannot be more non-empty partitions than rows.
        double partitionCount = Math.min(
                sourceRowCount,
                estimateCorrelatedPartitionCount(topNRankingNode.getPartitionBy(), sourceStats));
        if (Double.isNaN(partitionCount)) {
            return Optional.of(sourceStats);
        }

        double rowsPerPartition = sourceRowCount / partitionCount;
        double maxRanking = topNRankingNode.getMaxRankingPerPartition();
        List<Symbol> orderBy = topNRankingNode.getOrderingScheme().orderBy();
        TopNRankingNode.RankingType rankingType = topNRankingNode.getRankingType();

        double keptRowsPerPartition;
        double rankingDistinctValues;
        if (rankingType == TopNRankingNode.RankingType.ROW_NUMBER) {
            // Every row gets its own number, so at most maxRanking rows are kept per partition.
            keptRowsPerPartition = Math.min(rowsPerPartition, maxRanking);
            rankingDistinctValues = keptRowsPerPartition;
        } else if (rankingType == TopNRankingNode.RankingType.RANK) {
            // Rows per distinct order-by combination approximate the size of a tie group.
            // Ranks advance by that size, so only ceil(maxRanking / groupSize) distinct ranks fit.
            double groupSize = rowsPerPartition
                    / Math.min(estimateCorrelatedPartitionCount(orderBy, sourceStats), rowsPerPartition);
            rankingDistinctValues = Math.ceil(maxRanking / groupSize);
            keptRowsPerPartition = Math.min(rowsPerPartition, rankingDistinctValues * groupSize);
        } else {
            // Dense ranks advance by one per distinct order-by combination, so up to maxRanking
            // whole tie groups are kept.
            double orderByValuesPerPartition =
                    Math.min(estimateCorrelatedPartitionCount(orderBy, sourceStats), rowsPerPartition);
            keptRowsPerPartition = Math.min(
                    rowsPerPartition,
                    maxRanking * (rowsPerPartition / orderByValuesPerPartition));
            rankingDistinctValues = Math.min(orderByValuesPerPartition, maxRanking);
        }

        // Unknown order-by statistics: conservatively assume every row of the partition is kept.
        if (Double.isNaN(keptRowsPerPartition)) {
            keptRowsPerPartition = rowsPerPartition;
        }
        if (Double.isNaN(rankingDistinctValues)) {
            rankingDistinctValues = keptRowsPerPartition;
        }

        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(sourceStats);
        adjustOrderBySymbolDistinctCount(orderBy, partitionCount, keptRowsPerPartition, rowsPerPartition, sourceStats, result);
        result.setOutputRowCount(partitionCount * keptRowsPerPartition);

        SymbolStatsEstimate rankingStats = SymbolStatsEstimate.builder()
                .setLowValue(1.0d)
                .setHighValue(maxRanking)
                .setDistinctValuesCount(rankingDistinctValues)
                .setNullsFraction(0.0d)
                .setAverageRowSize(BigintType.BIGINT.getFixedSize())
                .build();
        return Optional.of(result
                .addSymbolStatistics(topNRankingNode.getRankingSymbol(), rankingStats)
                .build());
    }

    /**
     * Scales down the distinct count of the first order-by symbol to reflect the rows that remain.
     *
     * <p>The fraction of distinct values per source partition is applied to the rows kept per
     * partition, rounded up, and multiplied back out by the partition count. The result is capped
     * at the symbol's original distinct count. Nothing is written when the estimate is NaN.
     *
     * @param orderBy order-by symbols of the node; must not be empty
     * @param partitionCount estimated number of partitions
     * @param outputRowsPerPartition estimated rows kept per partition
     * @param inputRowsPerPartition estimated source rows per partition
     * @param sourceStats statistics of the node's source
     * @param result builder that receives the adjusted symbol statistics
     */
    private static void adjustOrderBySymbolDistinctCount(
            List<Symbol> orderBy,
            double partitionCount,
            double outputRowsPerPartition,
            double inputRowsPerPartition,
            PlanNodeStatsEstimate sourceStats,
            PlanNodeStatsEstimate.Builder result) {
        Verify.verify(!orderBy.isEmpty(), "Order by symbols should not be empty for TopNRankingNode.");
        Symbol leadingSymbol = orderBy.getFirst();
        SymbolStatsEstimate leadingStats = sourceStats.getSymbolStatistics(leadingSymbol);
        double sourceDistinct = leadingStats.getDistinctValuesCount();

        double distinctFraction = Math.min(sourceDistinct, inputRowsPerPartition) / inputRowsPerPartition;
        double keptDistinctPerPartition = Math.ceil(distinctFraction * outputRowsPerPartition);
        if (!Double.isNaN(keptDistinctPerPartition)) {
            result.addSymbolStatistics(leadingSymbol, SymbolStatsEstimate.buildFrom(leadingStats)
                    .setDistinctValuesCount(Math.min(sourceDistinct, keptDistinctPerPartition * partitionCount))
                    .build());
        }
    }

    /**
     * Estimates how many distinct value combinations the given symbols produce.
     *
     * <p>Each symbol contributes its distinct count, plus one when its null fraction is not exactly
     * zero (this includes an unknown/NaN fraction). That count is raised to an exponent that starts
     * at 1 and is multiplied by {@link #INDEPENDENCE_FACTOR} for each following symbol. An empty
     * list yields 1, and a NaN distinct count yields NaN.
     */
    private static double estimateCorrelatedPartitionCount(List<Symbol> symbols, PlanNodeStatsEstimate stats) {
        double combinations = 1.0d;
        double weight = 1.0d;
        for (Symbol symbol : symbols) {
            SymbolStatsEstimate symbolStats = stats.getSymbolStatistics(symbol);
            int nullGroup = symbolStats.getNullsFraction() == 0.0d ? 0 : 1;
            combinations *= Math.pow(symbolStats.getDistinctValuesCount() + nullGroup, weight);
            weight *= INDEPENDENCE_FACTOR;
        }
        return combinations;
    }
}
