package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/DistinctAggregationStrategyChooser.class */
/**
 * Decides which execution strategy to use for an {@link AggregationNode} containing distinct aggregations.
 *
 * <p>If the session selects a strategy other than {@code AUTOMATIC}, that strategy is used when the
 * aggregation is eligible for it and {@code SINGLE_STEP} is used otherwise. With {@code AUTOMATIC},
 * the choice is driven by the distinct-value-count estimates of the grouping keys, the number of
 * grouping keys, and the estimated number of concurrent aggregation threads.
 */
public class DistinctAggregationStrategyChooser {
    private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8;
    private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * 8;
    // 100 * 1024 * 1024; compared against an output size expressed in bytes.
    private static final double MAX_JOIN_GROUPING_KEYS_SIZE = 100 * 1024 * 1024;
    // Splitting is rejected when the estimated bytes read would exceed this multiple of a single scan.
    private static final double MAX_READ_OVERHEAD_RATIO = 1.5;
    // A filter is treated as selective when it keeps less than this fraction of its input rows.
    private static final double SELECTIVE_FILTER_ROW_RATIO = 0.5;
    // Above this number of grouping keys PRE_AGGREGATE is no longer chosen automatically.
    private static final int PRE_AGGREGATE_MAX_GROUPING_KEYS = 2;

    private final TaskCountEstimator taskCountEstimator;
    private final Metadata metadata;

    /**
     * @param taskCountEstimator used to estimate the number of hashed tasks for the session; must not be null
     * @param metadata used to ask whether a table read may be split into multiple sub-queries; must not be null
     * @throws NullPointerException if either argument is null
     */
    public DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Static factory equivalent to calling the public constructor.
     */
    public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata) {
        return new DistinctAggregationStrategyChooser(taskCountEstimator, metadata);
    }

    /**
     * @return true when the chosen strategy for {@code aggregationNode} is {@code MARK_DISTINCT}
     */
    public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
    }

    /**
     * @return true when the chosen strategy for {@code aggregationNode} is {@code PRE_AGGREGATE}
     */
    public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
    }

    /**
     * @return true when the chosen strategy for {@code aggregationNode} is {@code SPLIT_TO_SUBQUERIES}
     */
    public boolean shouldSplitToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
    }

    /**
     * Picks the strategy for the given aggregation. Never returns {@code AUTOMATIC}.
     */
    private OptimizerConfig.DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        OptimizerConfig.DistinctAggregationsStrategy requested = SystemSessionProperties.distinctAggregationsStrategy(session);
        if (requested != OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC) {
            return resolveRequestedStrategy(requested, aggregationNode, session, lookup);
        }

        double distinctValues = getMinDistinctValueCountEstimate(aggregationNode, statsProvider);
        int maxThreads = getMaxNumberOfConcurrentThreadsForAggregation(session);
        int groupingKeyCount = aggregationNode.getGroupingKeys().size();

        // Stay with SINGLE_STEP only for a grouped aggregation whose distinct-value estimate is known
        // and large relative to the available threads. The NaN check comes first, so the ">" comparisons
        // below are exact negations of the corresponding "<=" tests.
        if (!aggregationNode.getGroupingKeys().isEmpty()
                && !Double.isNaN(distinctValues)
                && (distinctValues > PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxThreads
                        || (distinctValues > MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxThreads
                                && groupingKeyCount > PRE_AGGREGATE_MAX_GROUPING_KEYS))) {
            return OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
        }

        if (MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries(aggregationNode)
                && shouldSplitAggregationToSubqueries(aggregationNode, session, statsProvider, lookup)) {
            return OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
        }

        // PRE_AGGREGATE is preferred only for a small number of grouping keys; otherwise try MARK_DISTINCT.
        if (OptimizeMixedDistinctAggregations.canUsePreAggregate(aggregationNode) && groupingKeyCount <= PRE_AGGREGATE_MAX_GROUPING_KEYS) {
            return OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
        }
        if (MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct(aggregationNode)) {
            return OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
        }
        return OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
    }

    /**
     * Honors an explicitly requested (non-AUTOMATIC) strategy when the aggregation is eligible for it;
     * falls back to {@code SINGLE_STEP} in every other case.
     */
    private OptimizerConfig.DistinctAggregationsStrategy resolveRequestedStrategy(OptimizerConfig.DistinctAggregationsStrategy requested, AggregationNode aggregationNode, Session session, Lookup lookup) {
        if (requested == OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT
                && MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct(aggregationNode)) {
            return OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT;
        }
        if (requested == OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE
                && OptimizeMixedDistinctAggregations.canUsePreAggregate(aggregationNode)) {
            return OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE;
        }
        if (requested == OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES
                && MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries(aggregationNode)
                && isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), session, lookup)) {
            return OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES;
        }
        return OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP;
    }

    /**
     * @return estimated hashed task count multiplied by the session's task concurrency
     */
    private int getMaxNumberOfConcurrentThreadsForAggregation(Session session) {
        return this.taskCountEstimator.estimateHashedTaskCount(session) * SystemSessionProperties.getTaskConcurrency(session);
    }

    /**
     * Returns the largest known distinct-values count among the grouping keys, taken from the stats of
     * the aggregation source, or NaN when no grouping key has a known count.
     *
     * <p>NOTE(review): the name says "min" while the code takes the maximum; presumably the maximum
     * single-key NDV serves as a lower bound for the combined NDV of all keys - confirm.
     */
    private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode, StatsProvider statsProvider) {
        PlanNodeStatsEstimate sourceStats = statsProvider.getStats(aggregationNode.getSource());
        return aggregationNode.getGroupingKeys().stream()
                .mapToDouble(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount())
                .filter(count -> !Double.isNaN(count))
                .max()
                .orElse(Double.NaN);
    }

    /**
     * Decides, for the AUTOMATIC strategy, whether splitting into sub-queries is worthwhile.
     * Returns false if the source shape is unsupported, contains a union, contains a selective
     * filter, or the extra read cost is too high. Otherwise returns true for a single global
     * aggregation, or when the estimated size of the grouping-key output is known and within
     * {@code MAX_JOIN_GROUPING_KEYS_SIZE}.
     */
    private boolean shouldSplitAggregationToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) {
        PlanNode source = aggregationNode.getSource();
        if (!isAggregationSourceSupportedForSubqueries(source, session, lookup)) {
            return false;
        }
        if (PlanNodeSearcher.searchFrom(source, lookup).whereIsInstanceOfAny(UnionNode.class).findFirst().isPresent()) {
            return false;
        }
        boolean hasSelectiveFilter = PlanNodeSearcher.searchFrom(source, lookup)
                .where(node -> node instanceof FilterNode && isSelective((FilterNode) node, statsProvider))
                .matches();
        if (hasSelectiveFilter) {
            return false;
        }
        if (isAdditionalReadOverheadTooExpensive(aggregationNode, statsProvider, lookup)) {
            return false;
        }
        if (aggregationNode.hasSingleGlobalAggregation()) {
            return true;
        }
        double groupingKeysSize = statsProvider.getStats(aggregationNode).getOutputSizeInBytes(aggregationNode.getGroupingKeys());
        return !Double.isNaN(groupingKeysSize) && groupingKeysSize <= MAX_JOIN_GROUPING_KEYS_SIZE;
    }

    /**
     * Estimates the bytes read if the aggregation were split into one sub-query per unique distinct
     * argument, and compares that with a single scan of the table.
     *
     * <p>Table-scan output symbols that are not arguments of a distinct aggregation are counted once
     * per sub-query; the remaining scan output is counted once. Expects exactly one
     * {@link TableScanNode} below the aggregation ({@code findOnlyElement}).
     *
     * @return true when either estimate is NaN or the split would read more than 1.5 times the single-scan size
     */
    private static boolean isAdditionalReadOverheadTooExpensive(AggregationNode aggregationNode, StatsProvider statsProvider, Lookup lookup) {
        Set<Symbol> distinctInputs = aggregationNode.getAggregations().values().stream()
                .filter(AggregationNode.Aggregation::isDistinct)
                .flatMap(aggregation -> aggregation.getArguments().stream())
                .filter(Reference.class::isInstance)
                .map(Symbol::from)
                .collect(ImmutableSet.toImmutableSet());

        TableScanNode tableScanNode = (TableScanNode) PlanNodeSearcher.searchFrom(aggregationNode.getSource(), lookup)
                .whereIsInstanceOfAny(TableScanNode.class)
                .findOnlyElement();
        Set<Symbol> additionalColumns = Sets.difference(ImmutableSet.copyOf(tableScanNode.getOutputSymbols()), distinctInputs);

        double singleTableScanDataSize = statsProvider.getStats(tableScanNode).getOutputSizeInBytes(tableScanNode.getOutputSymbols());
        double additionalColumnsDataSize = statsProvider.getStats(tableScanNode).getOutputSizeInBytes(additionalColumns);
        double subqueriesTotalDataSize =
                (additionalColumnsDataSize * OptimizeMixedDistinctAggregations.distinctAggregationsUniqueArgumentCount(aggregationNode))
                        + (singleTableScanDataSize - additionalColumnsDataSize);

        return Double.isNaN(subqueriesTotalDataSize)
                || Double.isNaN(singleTableScanDataSize)
                || subqueriesTotalDataSize / singleTableScanDataSize > MAX_READ_OVERHEAD_RATIO;
    }

    /**
     * @return true when the filter's estimated output row count is less than half of its source's;
     *         a NaN ratio (unknown stats, or 0/0) compares false and is therefore not selective
     */
    private static boolean isSelective(FilterNode filterNode, StatsProvider statsProvider) {
        double filterOutputRowCount = statsProvider.getStats(filterNode).getOutputRowCount();
        double filterSourceRowCount = statsProvider.getStats(filterNode.getSource()).getOutputRowCount();
        return filterOutputRowCount / filterSourceRowCount < SELECTIVE_FILTER_ROW_RATIO;
    }

    /**
     * Checks whether the plan below the aggregation can be duplicated into sub-queries: it may contain
     * only table scans, filters, projections and unions, must contain at least one table scan, and
     * every scanned table must allow its read to be split into multiple sub-queries.
     */
    private boolean isAggregationSourceSupportedForSubqueries(PlanNode planNode, Session session, Lookup lookup) {
        boolean hasUnsupportedNode = PlanNodeSearcher.searchFrom(planNode, lookup)
                .where(node -> !(node instanceof TableScanNode
                        || node instanceof FilterNode
                        || node instanceof ProjectNode
                        || node instanceof UnionNode))
                .findFirst()
                .isPresent();
        if (hasUnsupportedNode) {
            return false;
        }

        List<PlanNode> tableScans = PlanNodeSearcher.searchFrom(planNode, lookup)
                .whereIsInstanceOfAny(TableScanNode.class)
                .findAll();
        if (tableScans.isEmpty()) {
            return false;
        }
        return tableScans.stream()
                .allMatch(scan -> this.metadata.allowSplittingReadIntoMultipleSubQueries(session, ((TableScanNode) scan).getTable()));
    }
}
