package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Iterables;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.ExceptNode;
import io.trino.sql.planner.plan.IntersectNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SetOperationNode;
import io.trino.sql.planner.plan.UnionNode;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/SetOperationMerge.class */
/**
 * Flattens a set operation node whose sources are set operation nodes of the same class into a single node.
 * <p>
 * An instance is bound to one parent {@link SetOperationNode} and to the rule {@link Rule.Context} whose
 * {@link Lookup} resolves the parent's sources. Every call to {@link #mergeFirstSource()} or {@link #merge()}
 * builds its result from scratch, so calling them more than once on the same instance is safe.
 */
class SetOperationMerge {
    private final Rule.Context context;
    private final SetOperationNode node;

    public SetOperationMerge(SetOperationNode setOperationNode, Rule.Context context) {
        this.node = setOperationNode;
        this.context = context;
    }

    /**
     * Merges only the first source of the parent node into the parent; the remaining sources are kept as they are.
     * <p>
     * Touching only the first source keeps the rewrite applicable to EXCEPT as well as to UNION and INTERSECT.
     *
     * @return the merged node, or {@link Optional#empty()} if the first source cannot be merged into the parent
     * @throws IllegalArgumentException if the parent is not a {@link UnionNode}, {@link IntersectNode} or {@link ExceptNode}
     */
    public Optional<SetOperationNode> mergeFirstSource() {
        List<PlanNode> sources = resolveSources();
        PlanNode firstSource = sources.get(0);

        Optional<Boolean> mergedQuantifierIsDistinct = mergedQuantifierIsDistinct(node, firstSource);
        if (mergedQuantifierIsDistinct.isEmpty()) {
            return Optional.empty();
        }

        List<PlanNode> newSources = new ArrayList<>();
        ImmutableListMultimap.Builder<Symbol, Symbol> mappings = ImmutableListMultimap.builder();
        // mergedQuantifierIsDistinct only returns a value when both nodes have the same class,
        // so the first source is a SetOperationNode here
        addMergedMappings((SetOperationNode) firstSource, 0, newSources, mappings);
        for (int i = 1; i < sources.size(); i++) {
            addOriginalMappings(sources.get(i), i, newSources, mappings);
        }

        boolean distinct = mergedQuantifierIsDistinct.get();
        if (node instanceof UnionNode) {
            return Optional.of(new UnionNode(node.getId(), newSources, mappings.build(), node.getOutputSymbols()));
        }
        if (node instanceof IntersectNode) {
            return Optional.of(new IntersectNode(node.getId(), newSources, mappings.build(), node.getOutputSymbols(), distinct));
        }
        if (node instanceof ExceptNode) {
            return Optional.of(new ExceptNode(node.getId(), newSources, mappings.build(), node.getOutputSymbols(), distinct));
        }
        throw new IllegalArgumentException("unexpected node type: " + node.getClass().getSimpleName());
    }

    /**
     * Merges every source that has the same set operation class as the parent into the parent.
     * Only UNION and INTERSECT parents are accepted.
     *
     * @return the merged node, or {@link Optional#empty()} if no source could be merged
     * @throws IllegalStateException if the parent is neither a {@link UnionNode} nor an {@link IntersectNode}
     */
    public Optional<SetOperationNode> merge() {
        Preconditions.checkState((node instanceof UnionNode) || (node instanceof IntersectNode), "unexpected node type: %s", node.getClass().getSimpleName());
        List<PlanNode> sources = resolveSources();

        List<PlanNode> newSources = new ArrayList<>();
        ImmutableListMultimap.Builder<Symbol, Symbol> mappings = ImmutableListMultimap.builder();
        boolean resultIsDistinct = false;
        boolean rewritten = false;
        for (int i = 0; i < sources.size(); i++) {
            PlanNode source = sources.get(i);
            Optional<Boolean> mergedQuantifierIsDistinct = mergedQuantifierIsDistinct(node, source);
            if (mergedQuantifierIsDistinct.isPresent()) {
                addMergedMappings((SetOperationNode) source, i, newSources, mappings);
                // the merged node is DISTINCT as soon as any single merge requires it
                resultIsDistinct |= mergedQuantifierIsDistinct.get();
                rewritten = true;
            } else {
                addOriginalMappings(source, i, newSources, mappings);
            }
        }

        if (!rewritten) {
            return Optional.empty();
        }
        if (node instanceof UnionNode) {
            return Optional.of(new UnionNode(node.getId(), newSources, mappings.build(), node.getOutputSymbols()));
        }
        return Optional.of(new IntersectNode(node.getId(), newSources, mappings.build(), node.getOutputSymbols(), resultIsDistinct));
    }

    /**
     * Resolves each source of the parent node through the rule context's {@link Lookup}, preserving source order.
     */
    private List<PlanNode> resolveSources() {
        Lookup lookup = context.getLookup();
        return node.getSources().stream()
                .map(lookup::resolve)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Decides whether {@code child} can be merged into {@code parent} and, if so, whether the merged node is DISTINCT.
     * <ul>
     *   <li>Different runtime classes: not mergeable.</li>
     *   <li>UNION: always mergeable. {@link UnionNode} carries no distinct flag, so {@code false} is reported.</li>
     *   <li>INTERSECT: always mergeable. The result is DISTINCT if either the parent or the child is DISTINCT.</li>
     *   <li>EXCEPT: not mergeable when the parent is DISTINCT and the child is not. Otherwise the result takes the
     *       child's quantifier.</li>
     * </ul>
     *
     * @return the merged quantifier ({@code true} = DISTINCT), or {@link Optional#empty()} if the nodes cannot be merged
     * @throws IllegalArgumentException if both nodes share a class other than UNION, INTERSECT or EXCEPT
     */
    private Optional<Boolean> mergedQuantifierIsDistinct(SetOperationNode parent, PlanNode child) {
        if (!parent.getClass().equals(child.getClass())) {
            return Optional.empty();
        }
        if (parent instanceof UnionNode) {
            return Optional.of(false);
        }
        if (parent instanceof IntersectNode intersectParent) {
            return Optional.of(intersectParent.isDistinct() || ((IntersectNode) child).isDistinct());
        }
        if (parent instanceof ExceptNode exceptParent) {
            boolean childIsDistinct = ((ExceptNode) child).isDistinct();
            if (exceptParent.isDistinct() && !childIsDistinct) {
                return Optional.empty();
            }
            return Optional.of(childIsDistinct);
        }
        throw new IllegalArgumentException("unexpected node type: %s".formatted(parent.getClass().getSimpleName()));
    }

    /**
     * Replaces the parent's source at {@code childIndex} with all of {@code child}'s own sources.
     * Each parent output symbol is mapped through the child's symbol mapping to the child's inputs.
     */
    private void addMergedMappings(SetOperationNode child, int childIndex, List<PlanNode> newSources, ImmutableListMultimap.Builder<Symbol, Symbol> mappings) {
        newSources.addAll(child.getSources());
        for (Map.Entry<Symbol, Collection<Symbol>> mapping : node.getSymbolMapping().asMap().entrySet()) {
            // the parent's input symbol that this child produced for the output symbol
            Symbol childOutput = Iterables.get(mapping.getValue(), childIndex);
            mappings.putAll(mapping.getKey(), child.getSymbolMapping().get(childOutput));
        }
    }

    /**
     * Keeps {@code source} as a direct source and copies the parent's existing mapping for position {@code sourceIndex}.
     */
    private void addOriginalMappings(PlanNode source, int sourceIndex, List<PlanNode> newSources, ImmutableListMultimap.Builder<Symbol, Symbol> mappings) {
        newSources.add(source);
        for (Map.Entry<Symbol, Collection<Symbol>> mapping : node.getSymbolMapping().asMap().entrySet()) {
            mappings.put(mapping.getKey(), Iterables.get(mapping.getValue(), sourceIndex));
        }
    }
}
