package io.trino.sql.planner;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.Traverser;
import io.trino.Session;
import io.trino.cost.CachingTableStatsProvider;
import io.trino.cost.RuntimeInfoProvider;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.AdaptivePlanNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.sanity.PlanSanityChecker;
import io.trino.tracing.ScopedSpan;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/* loaded from: input_file:io/trino/sql/planner/AdaptivePlanner.class */
/**
 * Re-optimizes an already fragmented plan ({@link SubPlan} tree) using runtime information.
 * <p>
 * {@link #optimize} proceeds in five steps:
 * <ol>
 *   <li>It merges the fragments back into a single plan tree by turning remote sources into remote
 *       exchanges. A remote source is left alone when one of its source fragments already reports
 *       accurate output statistics.</li>
 *   <li>It runs the configured {@link AdaptivePlanOptimizer}s over the current plan.</li>
 *   <li>It wraps every changed subtree in an {@link AdaptivePlanNode} that keeps both the initial and the
 *       re-optimized plan.</li>
 *   <li>It validates the result.</li>
 *   <li>It re-fragments the result, reusing the sub plans that were not affected by the changes.</li>
 * </ol>
 * <p>
 * Instances are stateful. The ids of changed plan nodes are accumulated across calls to
 * {@link #optimize} in a plain, unsynchronized {@link HashSet}.
 */
public class AdaptivePlanner {
    private final Session session;
    private final PlannerContext plannerContext;
    private final List<AdaptivePlanOptimizer> planOptimizers;
    private final PlanFragmenter planFragmenter;
    private final PlanSanityChecker planSanityChecker;
    private final WarningCollector warningCollector;
    private final PlanOptimizersStatsCollector planOptimizersStatsCollector;
    private final CachingTableStatsProvider tableStatsProvider;
    // Ids of every plan node changed by any call to optimize() so far. Nodes in this set are wrapped in
    // AdaptivePlanNode by addAdaptivePlanNode.
    private final Set<PlanNodeId> cumulativeChangedPlanNodes = new HashSet<>();

    public AdaptivePlanner(Session session, PlannerContext plannerContext, List<AdaptivePlanOptimizer> list, PlanFragmenter planFragmenter, PlanSanityChecker planSanityChecker, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, CachingTableStatsProvider cachingTableStatsProvider) {
        this.session = Objects.requireNonNull(session, "session is null");
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.planOptimizers = Objects.requireNonNull(list, "planOptimizers is null");
        this.planFragmenter = Objects.requireNonNull(planFragmenter, "planFragmenter is null");
        this.planSanityChecker = Objects.requireNonNull(planSanityChecker, "planSanityChecker is null");
        this.warningCollector = Objects.requireNonNull(warningCollector, "warningCollector is null");
        this.planOptimizersStatsCollector = Objects.requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null");
        this.tableStatsProvider = Objects.requireNonNull(cachingTableStatsProvider, "tableStatsProvider is null");
    }

    /**
     * Re-optimizes {@code subPlan} using runtime information and re-fragments it if any optimizer changed it.
     *
     * @param subPlan the root of the current fragmented plan
     * @param runtimeInfoProvider source of runtime output statistics and plan fragments
     * @return {@code subPlan} itself when the root fragment's output statistics are already accurate or when
     *         no optimizer changed the plan; otherwise the newly fragmented plan
     */
    public SubPlan optimize(SubPlan subPlan, RuntimeInfoProvider runtimeInfoProvider) {
        if (runtimeInfoProvider.getRuntimeOutputStats(subPlan.getFragment().getId()).isAccurate()) {
            return subPlan;
        }

        List<SubPlan> subPlans = traverse(subPlan).collect(ImmutableList.toImmutableList());
        // Start new fragment ids above every id already in use.
        PlanFragmentIdAllocator fragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(subPlans) + 1);
        SymbolAllocator symbolAllocator = createSymbolAllocator(subPlans);

        // Stitch the fragments back into a single tree (remote sources -> remote exchanges where allowed).
        PlanNode mergedPlan = SimplePlanRewriter.rewriteWith(
                new ReplaceRemoteSourcesWithExchanges(runtimeInfoProvider),
                subPlan.getFragment().getRoot(),
                subPlan.getChildren());
        PlanNode initialPlan = getInitialPlan(mergedPlan);
        PlanNode currentPlan = getCurrentPlan(mergedPlan);

        // Remember which sub plan feeds each remote exchange / remote source so unchanged ones can be reused
        // during re-fragmentation.
        ExchangeSourceIdToSubPlanCollector collector = new ExchangeSourceIdToSubPlanCollector();
        mergedPlan.accept(collector, subPlans);
        Map<ExchangeSourceId, SubPlan> exchangeSourceIdToSubPlan = collector.getExchangeSourceIdToSubPlan();

        // Start new plan node ids above every id already in use by the current plan.
        PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(getMaxPlanId(currentPlan) + 1);
        AdaptivePlanOptimizer.Result optimized = optimizePlan(currentPlan, symbolAllocator, runtimeInfoProvider, idAllocator);
        if (optimized.changedPlanNodes().isEmpty()) {
            return subPlan;
        }

        cumulativeChangedPlanNodes.addAll(optimized.changedPlanNodes());
        PlanNode adaptivePlan = addAdaptivePlanNode(idAllocator, initialPlan, optimized.plan(), cumulativeChangedPlanNodes);

        // Only validation runs inside the span. try-with-resources closes it exactly once, also on failure.
        try (ScopedSpan ignored = ScopedSpan.scopedSpan(plannerContext.getTracer(), "validate-adaptive-plan")) {
            planSanityChecker.validateAdaptivePlan(adaptivePlan, session, plannerContext, warningCollector);
        }

        return planFragmenter.createSubPlans(
                session,
                new Plan(adaptivePlan, StatsAndCosts.empty()),
                false,
                warningCollector,
                fragmentIdAllocator,
                new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()), adaptivePlan.getOutputSymbols()),
                getUnchangedSubPlans(adaptivePlan, optimized.changedPlanNodes(), exchangeSourceIdToSubPlan));
    }

    /**
     * Runs every configured optimizer in order, feeding each one the plan produced by the previous one.
     *
     * @return the final plan together with the union of the plan node ids reported as changed by all optimizers
     */
    private AdaptivePlanOptimizer.Result optimizePlan(PlanNode plan, SymbolAllocator symbolAllocator, RuntimeInfoProvider runtimeInfoProvider, PlanNodeIdAllocator idAllocator) {
        AdaptivePlanOptimizer.Result result = new AdaptivePlanOptimizer.Result(plan, Set.of());
        ImmutableSet.Builder<PlanNodeId> changedPlanNodes = ImmutableSet.builder();
        for (AdaptivePlanOptimizer optimizer : planOptimizers) {
            result = optimizer.optimizeAndMarkPlanChanges(
                    result.plan(),
                    new PlanOptimizer.Context(session, symbolAllocator, idAllocator, warningCollector, planOptimizersStatsCollector, tableStatsProvider, runtimeInfoProvider));
            changedPlanNodes.addAll(result.changedPlanNodes());
        }
        return new AdaptivePlanOptimizer.Result(result.plan(), changedPlanNodes.build());
    }

    /**
     * Walks {@code initialPlan} and {@code optimizedPlan} in lockstep. Every node of {@code optimizedPlan}
     * whose id is in {@code changedPlanNodes} is wrapped in an {@link AdaptivePlanNode} that also holds the
     * corresponding {@code initialPlan} subtree. Other nodes are taken from {@code optimizedPlan}, with their
     * children processed recursively.
     */
    private PlanNode addAdaptivePlanNode(PlanNodeIdAllocator idAllocator, PlanNode initialPlan, PlanNode optimizedPlan, Set<PlanNodeId> changedPlanNodes) {
        if (changedPlanNodes.contains(optimizedPlan.getId())) {
            return new AdaptivePlanNode(idAllocator.getNextId(), initialPlan, SymbolsExtractor.extractOutputSymbols(initialPlan), optimizedPlan);
        }
        // Above the changed nodes both trees are expected to have the same shape.
        List<PlanNode> initialSources = initialPlan.getSources();
        List<PlanNode> optimizedSources = optimizedPlan.getSources();
        Verify.verify(initialSources.size() == optimizedSources.size());
        ImmutableList.Builder<PlanNode> newSources = ImmutableList.builder();
        for (int i = 0; i < initialSources.size(); i++) {
            newSources.add(addAdaptivePlanNode(idAllocator, initialSources.get(i), optimizedSources.get(i), changedPlanNodes));
        }
        return optimizedPlan.replaceChildren(newSources.build());
    }

    /**
     * Keeps only those entries of {@code exchangeSourceIdToSubPlan} whose exchange id and source id are both
     * off every path from {@code root} to a changed plan node.
     */
    private Map<ExchangeSourceId, SubPlan> getUnchangedSubPlans(PlanNode root, Set<PlanNodeId> changedPlanNodes, Map<ExchangeSourceId, SubPlan> exchangeSourceIdToSubPlan) {
        Set<PlanNodeId> affected = new HashSet<>();
        for (PlanNodeId changed : changedPlanNodes) {
            affected.addAll(getDownstreamPlanNodeIds(root, changed));
        }
        return exchangeSourceIdToSubPlan.entrySet().stream()
                .filter(entry -> !affected.contains(entry.getKey().exchangeId()) && !affected.contains(entry.getKey().sourceId()))
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Returns {@code target} plus the ids of all nodes between {@code node} and {@code target} (inclusive),
     * that is, the nodes that transitively consume {@code target}'s output. Returns an empty set when
     * {@code target} does not occur under {@code node}.
     */
    private Set<PlanNodeId> getDownstreamPlanNodeIds(PlanNode node, PlanNodeId target) {
        if (node.getId().equals(target)) {
            return ImmutableSet.of(target);
        }
        Set<PlanNodeId> result = new HashSet<>();
        for (PlanNode source : node.getSources()) {
            result.addAll(getDownstreamPlanNodeIds(source, target));
        }
        if (!result.isEmpty()) {
            // target was found somewhere below this node, so this node is on the path too
            result.add(node.getId());
        }
        return result;
    }

    /** Replaces every {@link AdaptivePlanNode} in {@code plan} with its current plan. */
    private PlanNode getCurrentPlan(PlanNode plan) {
        return SimplePlanRewriter.rewriteWith(new CurrentPlanRewriter(), plan);
    }

    /** Replaces every {@link AdaptivePlanNode} in {@code plan} with its initial plan. */
    private PlanNode getInitialPlan(PlanNode plan) {
        return SimplePlanRewriter.rewriteWith(new InitialPlanRewriter(), plan);
    }

    /** Creates a symbol allocator seeded with every symbol used by any of {@code subPlans}' fragments. */
    private SymbolAllocator createSymbolAllocator(List<SubPlan> subPlans) {
        return new SymbolAllocator(subPlans.stream()
                .flatMap(subPlan -> subPlan.getFragment().getSymbols().stream())
                .collect(ImmutableSet.toImmutableSet()));
    }

    /**
     * Returns the largest fragment id among {@code subPlans}.
     * Fragment ids are parsed from their string form as integers.
     *
     * @throws java.util.NoSuchElementException if {@code subPlans} is empty
     */
    private int getMaxPlanFragmentId(List<SubPlan> subPlans) {
        return subPlans.stream()
                .mapToInt(subPlan -> Integer.parseInt(subPlan.getFragment().getId().toString()))
                .max()
                .orElseThrow();
    }

    /** Returns the largest plan node id in the tree rooted at {@code root}, parsed as an integer. */
    private int getMaxPlanId(PlanNode root) {
        return traverse(root)
                .mapToInt(node -> Integer.parseInt(node.getId().toString()))
                .max()
                .orElseThrow();
    }

    /** Streams {@code root} and all nodes below it in depth-first pre-order. */
    private Stream<PlanNode> traverse(PlanNode root) {
        return StreamSupport.stream(Traverser.<PlanNode>forTree(PlanNode::getSources).depthFirstPreOrder(root).spliterator(), false);
    }

    /** Streams {@code root} and all sub plans below it in depth-first pre-order. */
    private Stream<SubPlan> traverse(SubPlan root) {
        return StreamSupport.stream(Traverser.<SubPlan>forTree(SubPlan::getChildren).depthFirstPreOrder(root).spliterator(), false);
    }

    private static boolean containsAdaptivePlanNode(PlanNode plan) {
        return PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(AdaptivePlanNode.class).matches();
    }

    /**
     * One input of a remote exchange or remote source.
     *
     * @param exchangeId id of the exchange / remote source node
     * @param sourceId id of the root node of the sub plan fragment that feeds it
     */
    public record ExchangeSourceId(PlanNodeId exchangeId, PlanNodeId sourceId) {
        public ExchangeSourceId {
            Objects.requireNonNull(exchangeId, "exchangeId is null");
            Objects.requireNonNull(sourceId, "sourceId is null");
        }
    }

    /** Replaces each {@link AdaptivePlanNode} with its current plan. Nested adaptive nodes are rejected. */
    private static class CurrentPlanRewriter extends SimplePlanRewriter<List<SubPlan>> {
        @Override
        public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            Verify.verify(!containsAdaptivePlanNode(node.getCurrentPlan()), "Adaptive plan node cannot have a nested adaptive plan node");
            return node.getCurrentPlan();
        }
    }

    /** Replaces each {@link AdaptivePlanNode} with its initial plan. Nested adaptive nodes are rejected. */
    private static class InitialPlanRewriter extends SimplePlanRewriter<List<SubPlan>> {
        @Override
        public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            Verify.verify(!containsAdaptivePlanNode(node.getInitialPlan()), "Adaptive plan node cannot have a nested adaptive plan node");
            return node.getInitialPlan();
        }
    }

    /**
     * Builds a map from each input of every remote exchange and remote source to the {@link SubPlan} that
     * feeds it. The visitor context is the list of all sub plans.
     */
    private static class ExchangeSourceIdToSubPlanCollector extends SimplePlanVisitor<List<SubPlan>> {
        private final Map<ExchangeSourceId, SubPlan> exchangeSourceIdToSubPlan = new HashMap<>();

        @Override
        public Void visitExchange(ExchangeNode node, List<SubPlan> subPlans) {
            // Visit the sources first so exchanges nested below this one are collected too.
            visitPlan(node, subPlans);
            if (node.getScope() != ExchangeNode.Scope.REMOTE) {
                return null;
            }
            Set<PlanNodeId> sourceIds = node.getSources().stream()
                    .map(PlanNode::getId)
                    .collect(ImmutableSet.toImmutableSet());
            List<SubPlan> sourceSubPlans = subPlans.stream()
                    .filter(subPlan -> sourceIds.contains(subPlan.getFragment().getRoot().getId()))
                    .collect(ImmutableList.toImmutableList());
            if (sourceSubPlans.size() != sourceIds.size()) {
                throw new IllegalStateException(String.format(
                        "Source subPlans not found for exchange node %s; sourceIds: %s; filteredSubPlans: %s; allSubPlans: %s",
                        node.getId(),
                        sourceIds,
                        describe(sourceSubPlans),
                        describe(subPlans)));
            }
            for (SubPlan subPlan : sourceSubPlans) {
                exchangeSourceIdToSubPlan.put(new ExchangeSourceId(node.getId(), subPlan.getFragment().getRoot().getId()), subPlan);
            }
            return null;
        }

        @Override
        public Void visitRemoteSource(RemoteSourceNode node, List<SubPlan> subPlans) {
            for (SubPlan subPlan : subPlans) {
                if (node.getSourceFragmentIds().contains(subPlan.getFragment().getId())) {
                    exchangeSourceIdToSubPlan.put(new ExchangeSourceId(node.getId(), subPlan.getFragment().getRoot().getId()), subPlan);
                }
            }
            return null;
        }

        public Map<ExchangeSourceId, SubPlan> getExchangeSourceIdToSubPlan() {
            return ImmutableMap.copyOf(exchangeSourceIdToSubPlan);
        }

        /** Renders each sub plan as {@code fragmentId->rootNodeId} for error messages. */
        private static List<String> describe(List<SubPlan> subPlans) {
            return subPlans.stream()
                    .map(subPlan -> subPlan.getFragment().getId() + "->" + subPlan.getFragment().getRoot().getId())
                    .collect(ImmutableList.toImmutableList());
        }
    }

    /**
     * Merges a fragmented plan back into one tree. Each {@link RemoteSourceNode} is replaced by a
     * {@code REMOTE} {@link ExchangeNode} whose sources are the (recursively rewritten) roots of its source
     * fragments. The rewrite context is the list of candidate child sub plans.
     */
    private static class ReplaceRemoteSourcesWithExchanges extends SimplePlanRewriter<List<SubPlan>> {
        private final RuntimeInfoProvider runtimeInfoProvider;

        private ReplaceRemoteSourcesWithExchanges(RuntimeInfoProvider runtimeInfoProvider) {
            this.runtimeInfoProvider = Objects.requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
        }

        @Override
        public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            // Both alternatives may contain remote sources, so rewrite each of them.
            PlanNode initialPlan = context.rewrite(node.getInitialPlan(), context.get());
            PlanNode currentPlan = context.rewrite(node.getCurrentPlan(), context.get());
            return new AdaptivePlanNode(node.getId(), initialPlan, SymbolsExtractor.extractOutputSymbols(initialPlan), currentPlan);
        }

        @Override
        public PlanNode visitRemoteSource(RemoteSourceNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            // Keep the remote source when any of its source fragments already reports accurate output stats.
            if (node.getSourceFragmentIds().stream()
                    .anyMatch(fragmentId -> runtimeInfoProvider.getRuntimeOutputStats(fragmentId).isAccurate())) {
                return node;
            }

            // Inline the source fragments. Each one is rewritten against its own children.
            List<PlanNode> sources = context.get().stream()
                    .filter(subPlan -> node.getSourceFragmentIds().contains(subPlan.getFragment().getId()))
                    .map(subPlan -> context.rewrite(subPlan.getFragment().getRoot(), subPlan.getChildren()))
                    .collect(ImmutableList.toImmutableList());
            List<PartitioningScheme> partitioningSchemes = node.getSourceFragmentIds().stream()
                    .map(runtimeInfoProvider::getPlanFragment)
                    .map(fragment -> fragment.getOutputPartitioningScheme())
                    .collect(ImmutableList.toImmutableList());
            Verify.verify(partitioningSchemes.size() == sources.size(), "Output partitioning schemes size does not match source nodes size");

            return new ExchangeNode(
                    node.getId(),
                    node.getExchangeType(),
                    ExchangeNode.Scope.REMOTE,
                    partitioningSchemes.getFirst().translateOutputLayout(node.getOutputSymbols()),
                    sources,
                    partitioningSchemes.stream()
                            .map(scheme -> scheme.getOutputLayout())
                            .collect(ImmutableList.toImmutableList()),
                    node.getOrderingScheme());
        }
    }
}
