package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimateMath;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.StreamPreferredProperties;
import io.trino.sql.planner.optimizations.StreamPropertyDerivations;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

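/* loaded from: input_file:io/trino/sql/planner/iterative/rule/AdaptiveReorderPartitionedJoin.class */
/**
 * Optimizer rule that swaps the two sides of a partitioned join when statistics show that the
 * right (build) side is sufficiently larger than the left (probe) side.
 *
 * <p>The rule matches partitioned joins without hash symbols whose right source is a non-gather
 * local exchange, optionally below a partial aggregation. After flipping the join it rewrites the
 * local exchanges on both sides so they fit their new roles. The rule is enabled only when the
 * retry policy is {@code TASK} and adaptive join reordering is switched on for the session.
 */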
public class AdaptiveReorderPartitionedJoin implements Rule<JoinNode> {
    /** Captures the local exchange found on the right (build) side of the join by {@link #PATTERN}. */
    private static final Capture<ExchangeNode> LOCAL_EXCHANGE_NODE = Capture.newCapture();

    /**
     * Matches a partitioned join without hash symbols whose right source is either
     * <ul>
     *   <li>a non-gather local exchange, or</li>
     *   <li>a partial aggregation whose source is a non-gather local exchange.</li>
     * </ul>
     * In both cases the exchange is captured as {@link #LOCAL_EXCHANGE_NODE}.
     *
     * <p>The alternatives are passed to {@code or(...)} as plain varargs lambdas. Wrapping them in a
     * raw {@code Function[]} erases the lambda parameter type to {@code Object}, so
     * {@code pattern.with(...)} would not type-check.
     */
    private static final Pattern<JoinNode> PATTERN = Patterns.join()
            .matching(AdaptiveReorderPartitionedJoin::isPartitionedJoinWithNoHashSymbols)
            .or(
                    // Alternative 1: the right source is the local exchange itself.
                    joinPattern -> joinPattern.with(Patterns.Join.right().matching(
                            Patterns.exchange()
                                    .matching(AdaptiveReorderPartitionedJoin::isNonGatherLocalExchange)
                                    .capturedAs(LOCAL_EXCHANGE_NODE))),
                    // Alternative 2: a partial aggregation sits between the join and the local exchange.
                    joinPattern -> joinPattern.with(Patterns.Join.right().matching(
                            Patterns.aggregation()
                                    .matching(aggregationNode -> aggregationNode.getStep() == AggregationNode.Step.PARTIAL)
                                    .with(Patterns.source().matching(
                                            Patterns.exchange()
                                                    .matching(AdaptiveReorderPartitionedJoin::isNonGatherLocalExchange)
                                                    .capturedAs(LOCAL_EXCHANGE_NODE))))));

    /** Metadata handed to stream-property derivation when the flipped join is repaired. */
    private final Metadata metadata;

    /** Returns {@code true} for an exchange with {@code LOCAL} scope whose type is not {@code GATHER}. */
    private static boolean isNonGatherLocalExchange(ExchangeNode exchangeNode) {
        return exchangeNode.getScope().equals(ExchangeNode.Scope.LOCAL)
                && !exchangeNode.getType().equals(ExchangeNode.Type.GATHER);
    }

    /**
     * Rewrites the subtree that becomes the left (probe) side after the join is flipped.
     *
     * <p>Only two node kinds are accepted: a partial aggregation, whose sources are rewritten in
     * turn, and the local exchange identified by {@code localExchangeNodeId}. Any other node makes
     * {@link #visitPlan} throw.
     */
    public static class BuildToProbeLocalExchangeRewriter extends SimplePlanRewriter<Void> {
        private final PlanNodeId localExchangeNodeId;
        private final Rule.Context context;

        private BuildToProbeLocalExchangeRewriter(PlanNodeId planNodeId, Rule.Context context) {
            this.localExchangeNodeId = Objects.requireNonNull(planNodeId, "localExchangeNodeId is null");
            this.context = Objects.requireNonNull(context, "context is null");
        }

        /** Rejects every node type that has no dedicated visit method in this rewriter. */
        @Override
        public PlanNode visitPlan(PlanNode planNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            String nodeTypeName = planNode.getClass().getSimpleName();
            throw new UnsupportedOperationException("Unexpected plan node: " + nodeTypeName);
        }

        /** Verifies the aggregation is partial and rewrites its sources with this rewriter. */
        @Override
        public PlanNode visitAggregation(AggregationNode aggregationNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            AggregationNode.Step step = aggregationNode.getStep();
            Verify.verify(step == AggregationNode.Step.PARTIAL, "Unexpected aggregation step: %s", step);
            return AdaptiveReorderPartitionedJoin.rewriteSources(this, aggregationNode, this.context);
        }

        /**
         * Replaces the expected local exchange: a single source is returned directly, several
         * sources are combined by a new local round-robin exchange with the same output symbols.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode exchangeNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            boolean isExpectedLocalExchange = exchangeNode.getScope().equals(ExchangeNode.Scope.LOCAL)
                    && exchangeNode.getId().equals(this.localExchangeNodeId);
            Verify.verify(isExpectedLocalExchange, "Unexpected exchange node: %s", exchangeNode.getId());

            List<PlanNode> sources = exchangeNode.getSources();
            if (sources.size() == 1) {
                // With one source the exchange is simply dropped.
                return sources.getFirst();
            }
            return ExchangeNode.roundRobinExchange(
                    this.context.getIdAllocator().getNextId(),
                    ExchangeNode.Scope.LOCAL,
                    sources,
                    exchangeNode.getOutputSymbols());
        }
    }

    /**
     * Rewrites the subtree that becomes the right (build) side after the join is flipped, so that
     * it ends in a local exchange partitioned on {@code buildSymbols}.
     */
    public static class ProbeToBuildLocalExchangeRewriter extends SimplePlanRewriter<Void> {
        private final Rule.Context context;
        private final List<Symbol> buildSymbols;

        private ProbeToBuildLocalExchangeRewriter(List<Symbol> list, Rule.Context context) {
            this.buildSymbols = Objects.requireNonNull(list, "buildSymbols is null");
            this.context = Objects.requireNonNull(context, "context is null");
        }

        /** Wraps the node in a new local exchange partitioned on the build symbols, without a hash symbol. */
        @Override
        public PlanNode visitPlan(PlanNode planNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            Optional<Symbol> noHashSymbol = Optional.empty();
            return ExchangeNode.partitionedExchange(
                    this.context.getIdAllocator().getNextId(),
                    ExchangeNode.Scope.LOCAL,
                    planNode,
                    this.buildSymbols,
                    noHashSymbol);
        }

        /**
         * A local exchange with more than one source and {@code FIXED_ARBITRARY_DISTRIBUTION}
         * partitioning is replaced by a local partitioned exchange over the same sources and output
         * symbols. Every other exchange is handled like any other node by {@link #visitPlan}.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode exchangeNode, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            boolean isReplaceable = exchangeNode.getScope().equals(ExchangeNode.Scope.LOCAL)
                    && exchangeNode.getSources().size() > 1
                    && exchangeNode.getPartitioningScheme().getPartitioning().getHandle()
                            .equals(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION);
            if (!isReplaceable) {
                return visitPlan(exchangeNode, rewriteContext);
            }
            return ExchangeNode.partitionedExchange(
                    this.context.getIdAllocator().getNextId(),
                    ExchangeNode.Scope.LOCAL,
                    exchangeNode.getSources(),
                    this.buildSymbols,
                    exchangeNode.getOutputSymbols());
        }
    }

    /**
     * Creates the rule.
     *
     * @param metadata metadata used later for stream-property derivation; must not be {@code null}
     * @throws NullPointerException if {@code metadata} is {@code null}
     */
    public AdaptiveReorderPartitionedJoin(Metadata metadata) {
        Objects.requireNonNull(metadata, "metadata is null");
        this.metadata = metadata;
    }

    /**
     * Tells whether the join is explicitly {@code PARTITIONED} and carries neither a right nor a
     * left hash symbol.
     */
    private static boolean isPartitionedJoinWithNoHashSymbols(JoinNode joinNode) {
        if (!joinNode.getDistributionType().equals(Optional.of(JoinNode.DistributionType.PARTITIONED))) {
            return false;
        }
        if (joinNode.getRightHashSymbol().isPresent()) {
            return false;
        }
        return joinNode.getLeftHashSymbol().isEmpty();
    }

    /** Returns the shared, statically built {@code PATTERN} that this rule matches against. */
    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * The rule runs only when the session's retry policy is {@code TASK} and adaptive join
     * reordering for fault-tolerant execution is enabled.
     */
    @Override
    public boolean isEnabled(Session session) {
        if (SystemSessionProperties.getRetryPolicy(session) != RetryPolicy.TASK) {
            return false;
        }
        return SystemSessionProperties.isFaultTolerantExecutionAdaptiveJoinReorderingEnabled(session);
    }

    /**
     * Flips the join when the captured exchange is a build-side local exchange partitioned on the
     * right-hand join symbols and the size statistics favor a flip. Otherwise returns an empty
     * result and leaves the plan unchanged.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        ExchangeNode localExchange = captures.get(LOCAL_EXCHANGE_NODE);
        Set<Symbol> rightJoinSymbols = ImmutableSet.copyOf(
                Lists.transform(joinNode.getCriteria(), clause -> clause.getRight()));

        if (!isBuildSideLocalExchangeNode(localExchange, rightJoinSymbols)) {
            return Rule.Result.empty();
        }
        // Statistics are consulted only after the structural check has passed.
        if (!flipJoinBasedOnStats(joinNode, context)) {
            return Rule.Result.empty();
        }
        JoinNode flippedJoin = flipJoinAndFixLocalExchanges(joinNode, localExchange.getId(), this.metadata, context);
        return Rule.Result.ofPlanNode(flippedJoin);
    }

    /**
     * Tells whether the exchange has {@code LOCAL} scope, is partitioned on exactly the given
     * symbols, and has no hash column.
     *
     * @param exchangeNode the exchange to inspect
     * @param set the symbols the exchange must be partitioned on
     */
    private static boolean isBuildSideLocalExchangeNode(ExchangeNode exchangeNode, Set<Symbol> set) {
        if (exchangeNode.getScope() != ExchangeNode.Scope.LOCAL) {
            return false;
        }
        var partitioningScheme = exchangeNode.getPartitioningScheme();
        boolean partitionedOnGivenSymbols = partitioningScheme.getPartitioning().getColumns().equals(set);
        return partitionedOnGivenSymbols && partitioningScheme.getHashColumn().isEmpty();
    }

    /**
     * Flips the children of {@code joinNode} and repairs the local exchanges on both sides.
     *
     * <ul>
     *   <li>The new left side, formerly the build side, is rewritten with
     *       {@link BuildToProbeLocalExchangeRewriter}, which replaces the local exchange identified
     *       by {@code planNodeId}.</li>
     *   <li>The new right side is rewritten with {@link ProbeToBuildLocalExchangeRewriter} only if
     *       its derived stream properties do not already satisfy partitioning on the right-hand
     *       join symbols.</li>
     * </ul>
     *
     * @param joinNode the join to flip
     * @param planNodeId id of the local exchange captured on the original build side
     * @param metadata metadata used for stream-property derivation
     * @param context rule context supplying the lookup, session and id allocator
     * @return a new join node with flipped, repaired children; all other attributes come from the
     *         flipped join
     */
    private static JoinNode flipJoinAndFixLocalExchanges(JoinNode joinNode, PlanNodeId planNodeId, Metadata metadata, Rule.Context context) {
        JoinNode flippedJoin = joinNode.flipChildren();

        // Repair the local exchange on the new probe (left) side.
        PlanNode probeSide = SimplePlanRewriter.rewriteWith(
                new BuildToProbeLocalExchangeRewriter(planNodeId, context),
                context.getLookup().resolve(flippedJoin.getLeft()));

        // Repair the new build (right) side only if it is not yet partitioned on the join symbols.
        PlanNode buildSide = flippedJoin.getRight();
        StreamPropertyDerivations.StreamProperties buildSideProperties =
                deriveStreamPropertiesRecursively(buildSide, metadata, context.getLookup(), context.getSession());
        // Typed as List<Symbol> rather than a raw List so the calls below need no unchecked conversion.
        List<Symbol> buildSymbols = Lists.transform(flippedJoin.getCriteria(), clause -> clause.getRight());
        if (!StreamPreferredProperties.partitionedOn(buildSymbols).isSatisfiedBy(buildSideProperties)) {
            buildSide = SimplePlanRewriter.rewriteWith(
                    new ProbeToBuildLocalExchangeRewriter(buildSymbols, context),
                    context.getLookup().resolve(buildSide));
        }

        return new JoinNode(
                flippedJoin.getId(),
                flippedJoin.getType(),
                probeSide,
                buildSide,
                flippedJoin.getCriteria(),
                flippedJoin.getLeftOutputSymbols(),
                flippedJoin.getRightOutputSymbols(),
                flippedJoin.isMaySkipOutputDuplicates(),
                flippedJoin.getFilter(),
                flippedJoin.getLeftHashSymbol(),
                flippedJoin.getRightHashSymbol(),
                flippedJoin.getDistributionType(),
                flippedJoin.isSpillable(),
                flippedJoin.getDynamicFilters(),
                flippedJoin.getReorderJoinStatsAndCost());
    }

    /**
     * Decides from size statistics whether the join should be flipped. A flip is requested when the
     * first known output size of the right side exceeds both the session's minimum size threshold
     * and the left side's size multiplied by the session's size-difference ratio.
     */
    private static boolean flipJoinBasedOnStats(JoinNode joinNode, Rule.Context context) {
        double leftSizeInBytes = PlanNodeStatsEstimateMath.getFirstKnownOutputSizeInBytes(
                joinNode.getLeft(), context.getLookup(), context.getStatsProvider());
        double rightSizeInBytes = PlanNodeStatsEstimateMath.getFirstKnownOutputSizeInBytes(
                joinNode.getRight(), context.getLookup(), context.getStatsProvider());

        double minSizeThresholdInBytes = SystemSessionProperties
                .getFaultTolerantExecutionAdaptiveJoinReorderingMinSizeThreshold(context.getSession())
                .toBytes();
        // Written as a negated '>' so that a NaN size yields false, exactly like the plain comparison.
        if (!(rightSizeInBytes > minSizeThresholdInBytes)) {
            return false;
        }
        double sizeDifferenceRatio = SystemSessionProperties
                .getFaultTolerantExecutionAdaptiveJoinReorderingSizeDifferenceRatio(context.getSession());
        return rightSizeInBytes > sizeDifferenceRatio * leftSizeInBytes;
    }

    /**
     * Derives the stream properties of {@code planNode} bottom-up. The node is resolved through
     * {@code lookup}, the properties of each of its sources are derived recursively, and the
     * results are combined with
     * {@code StreamPropertyDerivations.deriveStreamPropertiesWithoutActualProperties}.
     *
     * @param planNode node, possibly a reference that {@code lookup} must resolve, to derive for
     * @param metadata metadata forwarded to the derivation
     * @param lookup resolver applied to {@code planNode} and, through recursion, to every source
     * @param session session forwarded to the derivation
     * @return the derived stream properties of the resolved node
     */
    public static StreamPropertyDerivations.StreamProperties deriveStreamPropertiesRecursively(PlanNode planNode, Metadata metadata, Lookup lookup, Session session) {
        PlanNode resolvedNode = lookup.resolve(planNode);
        // Typed list instead of a raw List cast, so the call below is checked by the compiler.
        List<StreamPropertyDerivations.StreamProperties> sourceProperties = resolvedNode.getSources().stream()
                .map(source -> deriveStreamPropertiesRecursively(source, metadata, lookup, session))
                .collect(ImmutableList.toImmutableList());
        return StreamPropertyDerivations.deriveStreamPropertiesWithoutActualProperties(resolvedNode, sourceProperties, metadata, session);
    }

    /**
     * Rewrites every source of {@code planNode} with the given rewriter and returns a copy of the
     * node with the rewritten sources as its children.
     *
     * @param simplePlanRewriter rewriter applied to each resolved source
     * @param planNode node whose sources are rewritten
     * @param context rule context whose lookup resolves each source before it is rewritten
     * @return {@code planNode} with its children replaced by the rewritten sources, in order
     */
    private static PlanNode rewriteSources(SimplePlanRewriter<Void> simplePlanRewriter, PlanNode planNode, Rule.Context context) {
        // Map and collect into a typed list instead of filling a raw builder from forEach.
        List<PlanNode> rewrittenSources = planNode.getSources().stream()
                .map(source -> SimplePlanRewriter.rewriteWith(simplePlanRewriter, context.getLookup().resolve(source)))
                .collect(ImmutableList.toImmutableList());
        return ChildReplacer.replaceChildren(planNode, rewrittenSources);
    }
}
