package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.sql.ir.Booleans;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.UnnestNode;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.class */
/**
 * Rewrites a correlated join whose subquery is a chain of projections and
 * single-step aggregations sitting on top of a LEFT unnest of correlation-derived
 * symbols.
 *
 * <p>The rule fires only when:
 * <ul>
 *   <li>the correlated join is INNER or LEFT, has a non-empty correlation list and a
 *       {@code TRUE} filter (see {@link #PATTERN});</li>
 *   <li>a global aggregation is reachable from the subquery root through projections
 *       and grouped aggregations only;</li>
 *   <li>an unnest accepted by {@link #isSupportedUnnest} is reachable from the subquery
 *       root through projections and global or grouped aggregations only.</li>
 * </ul>
 *
 * <p>The rewritten plan tags every input row with a unique id, unnests directly on top
 * of the join input, and re-applies the subquery's projections and aggregations with
 * the input symbols (including the unique id) carried through as identity assignments
 * and grouping keys. The result is finally restricted to the output symbols of the
 * original correlated join.
 */
public class DecorrelateLeftUnnestWithGlobalAggregation implements Rule<CorrelatedJoinNode> {
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(Booleans.TRUE))
            .matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT);

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Applies the rewrite described on the class.
     *
     * @param correlatedJoinNode the matched correlated join
     * @param captures unused
     * @param context supplies the lookup, the plan node id allocator and the symbol allocator
     * @return the rewritten plan, or {@code Rule.Result.empty()} when the subquery contains
     *         no reachable global aggregation or no reachable supported unnest
     */
    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();

        // A global aggregation must be reachable through projections and grouped aggregations only.
        boolean globalAggregationPresent = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
                .where(DecorrelateLeftUnnestWithGlobalAggregation::isGlobalAggregation)
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
                .findFirst()
                .isPresent();
        if (!globalAggregationPresent) {
            return Rule.Result.empty();
        }

        // The unnest may sit below any mix of projections and (global or grouped) aggregations.
        Optional<PlanNode> subqueryUnnest = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node) || isGroupedAggregation(node))
                .findFirst();
        if (subqueryUnnest.isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode unnestNode = (UnnestNode) subqueryUnnest.get();

        // Tag every input row with a unique id; it is later added to the grouping keys of the
        // subquery's aggregations.
        PlanNode input = new AssignUniqueId(
                context.getIdAllocator().getNextId(),
                correlatedJoinNode.getInput(),
                context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT));

        // If the unnest was fed by a projection, re-create that projection on top of the
        // join input, passing all input symbols through unchanged.
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // The new unnest replicates every input symbol and keeps the original mappings and ordinality.
        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                unnestNode.getMappings(),
                unnestNode.getOrdinalitySymbol(),
                JoinType.LEFT);

        PlanNode rewrittenPlan = rewriteNodeSequence(
                lookup.resolve(correlatedJoinNode.getSubquery()),
                input.getOutputSymbols(),
                rewrittenUnnest,
                unnestNode.getId(),
                lookup);

        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(rewrittenPlan));
    }

    /**
     * Returns {@code true} when {@code planNode} is a SINGLE-step aggregation for which
     * {@code hasSingleGlobalAggregation()} holds.
     */
    private static boolean isGlobalAggregation(PlanNode planNode) {
        if (!(planNode instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregation = (AggregationNode) planNode;
        return aggregation.hasSingleGlobalAggregation()
                && aggregation.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns {@code true} when {@code planNode} is a SINGLE-step aggregation with exactly
     * one grouping set, and that grouping set is non-empty.
     */
    private static boolean isGroupedAggregation(PlanNode planNode) {
        if (!(planNode instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregation = (AggregationNode) planNode;
        return aggregation.hasNonEmptyGroupingSet()
                && aggregation.getGroupingSetCount() == 1
                && aggregation.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Decides whether {@code node} is an unnest this rule can decorrelate.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>{@code node} is an {@link UnnestNode};</li>
     *   <li>its source is scalar according to {@link QueryCardinalityUtil#isScalar};</li>
     *   <li>it has no replicate symbols;</li>
     *   <li>its join type is {@link JoinType#LEFT};</li>
     *   <li>either every unnested input symbol is a correlation symbol, or the resolved
     *       source is a {@link ProjectNode} whose assignment expressions reference only
     *       correlation symbols.</li>
     * </ul>
     *
     * @param node the candidate plan node
     * @param correlation the correlation symbols of the enclosing correlated join
     * @param lookup used to resolve the unnest source and to evaluate its cardinality
     * @return {@code true} if the node satisfies all the conditions above
     */
    public static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
        if (!(node instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) node;

        List<Symbol> unnestSymbols = unnestNode.getMappings().stream()
                .map(mapping -> mapping.getInput())
                .collect(ImmutableList.toImmutableList());
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);

        // Either the unnest consumes correlation symbols directly, or it consumes the outputs
        // of a projection that is computed purely from correlation symbols.
        boolean basedOnCorrelation = correlationSymbols.containsAll(unnestSymbols);
        if (!basedOnCorrelation && unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            basedOnCorrelation = correlationSymbols.containsAll(
                    SymbolsExtractor.extractUnique(sourceProjection.getAssignments().getExpressions()));
        }

        return QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup)
                && unnestNode.getReplicateSymbols().isEmpty()
                && basedOnCorrelation
                && unnestNode.getJoinType() == JoinType.LEFT;
    }

    /**
     * Rebuilds the single-source chain of nodes from {@code planNode} down to the original
     * unnest, substituting the rewritten unnest and carrying extra symbols through every
     * node on the way back up.
     *
     * @param planNode the (already resolved) node currently being rewritten
     * @param list symbols to add as grouping keys of each aggregation and as identity
     *        assignments of each projection
     * @param planNode2 the replacement for the original unnest node
     * @param planNodeId id of the original unnest node; recursion stops at the node with this id
     * @param lookup used to resolve the source of each visited node
     * @return the rewritten chain rooted at the counterpart of {@code planNode}
     * @throws IllegalStateException if a node above the unnest is neither an aggregation
     *         nor a projection
     */
    private static PlanNode rewriteNodeSequence(PlanNode planNode, List<Symbol> list, PlanNode planNode2, PlanNodeId planNodeId, Lookup lookup) {
        if (planNode.getId().equals(planNodeId)) {
            return planNode2;
        }

        // Every node above the unnest is expected to have exactly one source.
        PlanNode onlySource = lookup.resolve(Iterables.getOnlyElement(planNode.getSources()));
        PlanNode rewrittenSource = rewriteNodeSequence(onlySource, list, planNode2, planNodeId, lookup);

        if (planNode instanceof AggregationNode) {
            return withGrouping((AggregationNode) planNode, list, rewrittenSource);
        }
        if (planNode instanceof ProjectNode) {
            ProjectNode projection = (ProjectNode) planNode;
            return new ProjectNode(
                    projection.getId(),
                    rewrittenSource,
                    Assignments.builder()
                            .putAll(projection.getAssignments())
                            .putIdentities(list)
                            .build());
        }
        throw new IllegalStateException("unexpected node: " + planNode);
    }

    /**
     * Creates a single aggregation with the id and aggregations of {@code aggregationNode},
     * reading from {@code planNode}, whose one grouping set consists of {@code list}
     * followed by the original grouping keys, with duplicates removed.
     */
    private static AggregationNode withGrouping(AggregationNode aggregationNode, List<Symbol> list, PlanNode planNode) {
        List<Symbol> groupingKeys = Streams.concat(list.stream(), aggregationNode.getGroupingKeys().stream())
                .distinct()
                .collect(ImmutableList.toImmutableList());
        return AggregationNode.singleAggregation(
                aggregationNode.getId(),
                planNode,
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys));
    }
}
