package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Lambda;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.assertj.core.util.VisibleForTesting;

/**
 * Optimizer rules that push a PARTIAL aggregation below an INNER join when every aggregation
 * input symbol is produced by one side of the join.
 * <p>
 * Two rules are exposed through {@link #rules()}:
 * <ul>
 *   <li>one matching {@code Aggregation(PARTIAL) -> Join}, and</li>
 *   <li>one matching {@code Aggregation(PARTIAL) -> Project -> Join}. That rule first pushes the
 *   projection through the join and then applies the same pushdown.</li>
 * </ul>
 * Pushdown is attempted only when the {@code push_partial_aggregation_through_join} session
 * property (via {@link SystemSessionProperties#isPushPartialAggregationThroughJoin}) is enabled.
 * It is skipped when the available statistics do not suggest a benefit.
 */
public class PushPartialAggregationThroughJoin {

    // Skip pushdown when the join emits more than this factor times the row count of the side the
    // aggregation would be pushed into (i.e. the join noticeably expands that side).
    private static final double MAX_JOIN_OUTPUT_TO_SOURCE_ROW_RATIO = 1.1d;

    // Skip pushdown unless every pushed-down grouping key has at most
    // (source row count / this factor) distinct values.
    private static final double MIN_ROWS_PER_DISTINCT_VALUE = 2.0d;

    private static final Pattern<AggregationNode> PATTERN_WITHOUT_PROJECTION = Patterns.aggregation()
            .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
            .with(Patterns.source().matching(Patterns.join()));

    private static final Pattern<AggregationNode> PATTERN_WITH_PROJECTION = Patterns.aggregation()
            .matching(PushPartialAggregationThroughJoin::isSupportedAggregationNode)
            .with(Patterns.source().matching(Patterns.project().with(Patterns.source().matching(Patterns.join()))));

    /** Rule for {@code Aggregation(PARTIAL) -> Project -> Join}. */
    private class PushPartialAggregationThroughJoinWithProjection implements Rule<AggregationNode> {

        private PushPartialAggregationThroughJoinWithProjection() {
        }

        @Override
        public Pattern<AggregationNode> getPattern() {
            return PATTERN_WITH_PROJECTION;
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isPushPartialAggregationThroughJoin(session);
        }

        @Override
        public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
            ProjectNode project = (ProjectNode) context.getLookup().resolve(aggregationNode.getSource());
            Optional<PlanNode> projectedJoin = PushProjectionThroughJoin.pushProjectionThroughJoin(
                    project, context.getLookup(), context.getIdAllocator());
            if (projectedJoin.isEmpty()) {
                // The projection could not be moved below the join, so the aggregation cannot reach it.
                return Rule.Result.empty();
            }
            AggregationNode aggregationOverJoin =
                    (AggregationNode) aggregationNode.replaceChildren(ImmutableList.of(projectedJoin.get()));
            return applyPushdown(aggregationOverJoin, context);
        }
    }

    /** Rule for {@code Aggregation(PARTIAL) -> Join}. */
    private class PushPartialAggregationThroughJoinWithoutProjection implements Rule<AggregationNode> {

        private PushPartialAggregationThroughJoinWithoutProjection() {
        }

        @Override
        public Pattern<AggregationNode> getPattern() {
            return PATTERN_WITHOUT_PROJECTION;
        }

        @Override
        public boolean isEnabled(Session session) {
            return SystemSessionProperties.isPushPartialAggregationThroughJoin(session);
        }

        @Override
        public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
            return applyPushdown(aggregationNode, context);
        }
    }

    /**
     * Accepts only non-streamable PARTIAL aggregations with no hash symbol and exactly one grouping set.
     */
    private static boolean isSupportedAggregationNode(AggregationNode aggregationNode) {
        return !aggregationNode.isStreamable()
                && aggregationNode.getHashSymbol().isEmpty()
                && aggregationNode.getStep() == AggregationNode.Step.PARTIAL
                && aggregationNode.getGroupingSetCount() == 1;
    }

    /** Returns both rules: the one without projection first, then the one with projection. */
    public Iterable<Rule<?>> rules() {
        return ImmutableList.of(pushPartialAggregationThroughJoinWithoutProjection(), pushPartialAggregationThroughJoinWithProjection());
    }

    // NOTE(review): @VisibleForTesting resolves to org.assertj.core.util.VisibleForTesting through this
    // file's imports; Guava's com.google.common.annotations.VisibleForTesting is likely intended — confirm.
    @VisibleForTesting
    Rule<?> pushPartialAggregationThroughJoinWithoutProjection() {
        return new PushPartialAggregationThroughJoinWithoutProjection();
    }

    @VisibleForTesting
    Rule<?> pushPartialAggregationThroughJoinWithProjection() {
        return new PushPartialAggregationThroughJoinWithProjection();
    }

    /**
     * Tries to push the partial aggregation below its source join.
     * <p>
     * Only INNER joins are handled. The left side is tried first. The right side is tried only if the
     * aggregation inputs are not all produced by the left side. If the chosen side's pushdown is
     * rejected, the other side is not tried.
     *
     * @param aggregationNode partial aggregation whose source resolves to a {@link JoinNode}
     * @return the rewritten plan, or an empty result when no pushdown applies
     */
    private Rule.Result applyPushdown(AggregationNode aggregationNode, Rule.Context context) {
        JoinNode joinNode = (JoinNode) context.getLookup().resolve(aggregationNode.getSource());
        if (joinNode.getType() != JoinType.INNER) {
            return Rule.Result.empty();
        }

        Map<Symbol, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();
        Optional<PlanNode> rewritten;
        if (allAggregationsOn(aggregations, joinNode.getLeft().getOutputSymbols())) {
            rewritten = pushPartialToLeftChild(aggregationNode, joinNode, context);
        }
        else if (allAggregationsOn(aggregations, joinNode.getRight().getOutputSymbols())) {
            rewritten = pushPartialToRightChild(aggregationNode, joinNode, context);
        }
        else {
            return Rule.Result.empty();
        }
        return rewritten.map(Rule.Result::ofPlanNode).orElse(Rule.Result.empty());
    }

    /**
     * Returns whether every symbol referenced by any of {@code aggregations} is contained in {@code symbols}.
     */
    private static boolean allAggregationsOn(Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> symbols) {
        Set<Symbol> inputs = aggregations.values().stream()
                .map(SymbolsExtractor::extractAll)
                .flatMap(extracted -> extracted.stream())
                .collect(ImmutableSet.toImmutableSet());
        // Hash-based lookup instead of List.containsAll, which is O(|symbols| * |inputs|).
        return ImmutableSet.copyOf(symbols).containsAll(inputs);
    }

    /** Places the pushed aggregation over the join's left input, keeping the right input unchanged. */
    private Optional<PlanNode> pushPartialToLeftChild(AggregationNode aggregationNode, JoinNode joinNode, Rule.Context context) {
        return getPushedAggregation(aggregationNode, joinNode, joinNode.getLeft(), context)
                .map(pushed -> replaceJoin(aggregationNode, pushed, joinNode, pushed, joinNode.getRight(), context));
    }

    /** Places the pushed aggregation over the join's right input, keeping the left input unchanged. */
    private Optional<PlanNode> pushPartialToRightChild(AggregationNode aggregationNode, JoinNode joinNode, Rule.Context context) {
        return getPushedAggregation(aggregationNode, joinNode, joinNode.getRight(), context)
                .map(pushed -> replaceJoin(aggregationNode, pushed, joinNode, joinNode.getLeft(), pushed, context));
    }

    /**
     * Builds a copy of {@code aggregationNode} that reads from {@code child} and groups by the keys
     * available on that side, plus the join symbols that side must keep exposing.
     *
     * @return the pushed aggregation, or empty when statistics suggest the pushdown does not pay off
     */
    private Optional<AggregationNode> getPushedAggregation(AggregationNode aggregationNode, JoinNode joinNode, PlanNode child, Rule.Context context) {
        Set<Symbol> childSymbols = ImmutableSet.copyOf(child.getOutputSymbols());
        Set<Symbol> requiredJoinSymbols = Sets.intersection(getJoinRequiredSymbols(joinNode), childSymbols);
        List<Symbol> pushedGroupingKeys = getPushedDownGroupingSet(aggregationNode, childSymbols, requiredJoinSymbols);
        AggregationNode pushedAggregation = replaceAggregationSource(aggregationNode, child, pushedGroupingKeys);
        if (skipPartialAggregationPushdown(joinNode, aggregationNode, pushedAggregation, context)) {
            return Optional.empty();
        }
        return Optional.of(pushedAggregation);
    }

    /**
     * Decides whether to skip the pushdown. It is skipped when:
     * <ul>
     *   <li>the row count of the pushed aggregation's source or of the join is unknown (NaN);</li>
     *   <li>the join emits more than {@code MAX_JOIN_OUTPUT_TO_SOURCE_ROW_RATIO} times the source rows;</li>
     *   <li>the pushed aggregation has more distinct grouping keys than the original; or</li>
     *   <li>any pushed grouping key has an unknown distinct-value count, or more than
     *   (source rows / {@code MIN_ROWS_PER_DISTINCT_VALUE}) distinct values.</li>
     * </ul>
     */
    private boolean skipPartialAggregationPushdown(JoinNode joinNode, AggregationNode originalAggregation, AggregationNode pushedAggregation, Rule.Context context) {
        PlanNodeStatsEstimate sourceStats = context.getStatsProvider().getStats(pushedAggregation.getSource());
        double sourceRowCount = sourceStats.getOutputRowCount();
        double joinRowCount = context.getStatsProvider().getStats(joinNode).getOutputRowCount();

        if (Double.isNaN(sourceRowCount) || Double.isNaN(joinRowCount)) {
            return true;
        }
        if (joinRowCount > MAX_JOIN_OUTPUT_TO_SOURCE_ROW_RATIO * sourceRowCount) {
            return true;
        }
        if (ImmutableSet.copyOf(originalAggregation.getGroupingKeys()).size() < ImmutableSet.copyOf(pushedAggregation.getGroupingKeys()).size()) {
            return true;
        }
        for (Symbol groupingKey : pushedAggregation.getGroupingKeys()) {
            double distinctValues = sourceStats.getSymbolStatistics(groupingKey).getDistinctValuesCount();
            if (Double.isNaN(distinctValues) || distinctValues * MIN_ROWS_PER_DISTINCT_VALUE > sourceRowCount) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects every symbol the join references, in this order: left criteria symbols, right
     * criteria symbols, filter symbols, then the left and right hash symbols when present.
     */
    private Set<Symbol> getJoinRequiredSymbols(JoinNode joinNode) {
        ImmutableSet.Builder<Symbol> required = ImmutableSet.builder();
        joinNode.getCriteria().forEach(clause -> required.add(clause.getLeft()));
        joinNode.getCriteria().forEach(clause -> required.add(clause.getRight()));
        joinNode.getFilter()
                .map(SymbolsExtractor::extractUnique)
                .ifPresent(filterSymbols -> required.addAll(filterSymbols));
        joinNode.getLeftHashSymbol().ifPresent(symbol -> required.add(symbol));
        joinNode.getRightHashSymbol().ifPresent(symbol -> required.add(symbol));
        return required.build();
    }

    /**
     * Builds the grouping keys for the pushed aggregation.
     * <p>
     * The result is the original grouping keys that appear in {@code availableSymbols}, in their
     * original order. Each symbol of {@code requiredJoinSymbols} that is not already included is
     * then appended.
     */
    private List<Symbol> getPushedDownGroupingSet(AggregationNode aggregationNode, Set<Symbol> availableSymbols, Set<Symbol> requiredJoinSymbols) {
        ImmutableList.Builder<Symbol> groupingSet = ImmutableList.builder();
        Set<Symbol> included = new HashSet<>();
        for (Symbol groupingKey : aggregationNode.getGroupingKeys()) {
            if (availableSymbols.contains(groupingKey)) {
                groupingSet.add(groupingKey);
                included.add(groupingKey);
            }
        }
        for (Symbol joinSymbol : requiredJoinSymbols) {
            if (included.add(joinSymbol)) {
                groupingSet.add(joinSymbol);
            }
        }
        return groupingSet.build();
    }

    /**
     * Copies {@code aggregationNode} onto {@code source} with a single grouping set of {@code groupingKeys}.
     * Pre-grouped symbols are cleared and the copy is not marked input-reducing.
     */
    private AggregationNode replaceAggregationSource(AggregationNode aggregationNode, PlanNode source, List<Symbol> groupingKeys) {
        return AggregationNode.builderFrom(aggregationNode)
                .setSource(source)
                .setGroupingSets(AggregationNode.singleGroupingSet(groupingKeys))
                .setPreGroupedSymbols(ImmutableList.of())
                .setIsInputReducingAggregation(false)
                .build();
    }

    /**
     * Rebuilds the join over the new children and restricts its outputs to the original aggregation's
     * outputs.
     * <p>
     * An INTERMEDIATE aggregation is added on top when the original was input-reducing and the pushed
     * grouping keys are not a subset of the original ones.
     */
    private PlanNode replaceJoin(AggregationNode originalAggregation, AggregationNode pushedAggregation, JoinNode joinNode, PlanNode newLeft, PlanNode newRight, Rule.Context context) {
        JoinNode newJoin = new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                newLeft,
                newRight,
                joinNode.getCriteria(),
                newLeft.getOutputSymbols(),
                newRight.getOutputSymbols(),
                joinNode.isMaySkipOutputDuplicates(),
                joinNode.getFilter(),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType(),
                joinNode.isSpillable(),
                joinNode.getDynamicFilters(),
                joinNode.getReorderJoinStatsAndCost());

        PlanNode result = Util.restrictOutputs(context.getIdAllocator(), newJoin, ImmutableSet.copyOf(originalAggregation.getOutputSymbols()))
                .orElse(newJoin);

        boolean groupingChanged = !ImmutableSet.copyOf(originalAggregation.getGroupingKeys()).containsAll(pushedAggregation.getGroupingKeys());
        if (originalAggregation.isInputReducingAggregation() && groupingChanged) {
            result = toIntermediateAggregation(originalAggregation, result, context);
        }
        return result;
    }

    /**
     * Wraps {@code source} in an INTERMEDIATE aggregation that combines the partial states produced
     * below it.
     * <p>
     * Each aggregation is called with its own output symbol as the first argument. Only lambda
     * arguments from the original call are kept after it. DISTINCT, filter, ordering and mask are
     * dropped.
     */
    private PlanNode toIntermediateAggregation(AggregationNode partialAggregation, PlanNode source, Rule.Context context) {
        // LinkedHashMap keeps the partial aggregation's symbol order, which feeds the node's output symbols.
        Map<Symbol, AggregationNode.Aggregation> intermediateAggregations = new LinkedHashMap<>();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : partialAggregation.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            List<Expression> arguments = ImmutableList.<Expression>builder()
                    .add(entry.getKey().toSymbolReference())
                    .addAll(aggregation.getArguments().stream()
                            .filter(Lambda.class::isInstance)
                            .collect(ImmutableList.toImmutableList()))
                    .build();
            intermediateAggregations.put(
                    entry.getKey(),
                    new AggregationNode.Aggregation(
                            aggregation.getResolvedFunction(),
                            arguments,
                            false,
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()));
        }

        return new AggregationNode(
                context.getIdAllocator().getNextId(),
                source,
                intermediateAggregations,
                partialAggregation.getGroupingSets(),
                ImmutableList.of(),
                AggregationNode.Step.INTERMEDIATE,
                Optional.empty(),
                partialAggregation.getGroupIdSymbol());
    }
}
