package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.Cardinality;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import java.util.List;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithSource.class */
/**
 * Planner rule that swaps a join for one of its own sources when the other source is
 * redundant.
 *
 * <p>A source counts as redundant here when it has no output symbols and
 * {@link Cardinality#isScalar()} holds for its extracted cardinality. The rule does
 * nothing when {@link Cardinality#isEmpty()} holds for either source.
 *
 * <p>Per join type:
 * <ul>
 *   <li>INNER: either side may be the redundant one (the left side is checked first).
 *       If the join has a filter, it is re-applied as a {@link FilterNode} on top of
 *       the surviving source.</li>
 *   <li>LEFT: only a redundant right side triggers the rewrite.</li>
 *   <li>RIGHT: only a redundant left side triggers the rewrite.</li>
 *   <li>FULL: either side may be redundant, but the surviving side must also satisfy
 *       {@link Cardinality#isAtLeastScalar()}.</li>
 * </ul>
 *
 * <p>The join filter is read only in the INNER branch; the LEFT, RIGHT and FULL
 * branches do not consult it.
 * NOTE(review): confirm that ignoring the filter is intended for FULL joins.
 */
public class ReplaceRedundantJoinWithSource implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    /** Returns the shared pattern obtained from {@link Patterns#join()}. */
    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Tries to replace {@code joinNode} with one of its sources.
     *
     * @param joinNode the join under consideration
     * @param captures unused by this rule
     * @param context supplies the lookup used for cardinality extraction and the plan node id allocator
     * @return a result carrying the replacement plan, or {@code Rule.Result.empty()} when the rule does not apply
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        Cardinality leftCardinality = QueryCardinalityUtil.extractCardinality(joinNode.getLeft(), context.getLookup());
        if (leftCardinality.isEmpty()) {
            return Rule.Result.empty();
        }

        Cardinality rightCardinality = QueryCardinalityUtil.extractCardinality(joinNode.getRight(), context.getLookup());
        if (rightCardinality.isEmpty()) {
            return Rule.Result.empty();
        }

        boolean leftRedundant = joinNode.getLeft().getOutputSymbols().isEmpty() && leftCardinality.isScalar();
        boolean rightRedundant = joinNode.getRight().getOutputSymbols().isEmpty() && rightCardinality.isScalar();

        switch (joinNode.getType()) {
            case INNER: {
                PlanNode survivor;
                List<Symbol> survivorOutputs;
                if (leftRedundant) {
                    survivor = joinNode.getRight();
                    survivorOutputs = joinNode.getRightOutputSymbols();
                } else if (rightRedundant) {
                    survivor = joinNode.getLeft();
                    survivorOutputs = joinNode.getLeftOutputSymbols();
                } else {
                    return Rule.Result.empty();
                }

                // Keep the join filter by placing it directly above the surviving source.
                if (joinNode.getFilter().isPresent()) {
                    survivor = new FilterNode(context.getIdAllocator().getNextId(), survivor, joinNode.getFilter().get());
                }
                return replaceWith(context, survivor, survivorOutputs);
            }
            case LEFT: {
                if (!rightRedundant) {
                    return Rule.Result.empty();
                }
                return replaceWith(context, joinNode.getLeft(), joinNode.getLeftOutputSymbols());
            }
            case RIGHT: {
                if (!leftRedundant) {
                    return Rule.Result.empty();
                }
                return replaceWith(context, joinNode.getRight(), joinNode.getRightOutputSymbols());
            }
            case FULL: {
                // The left-redundant case takes precedence when both conditions hold.
                if (leftRedundant && rightCardinality.isAtLeastScalar()) {
                    return replaceWith(context, joinNode.getRight(), joinNode.getRightOutputSymbols());
                }
                if (rightRedundant && leftCardinality.isAtLeastScalar()) {
                    return replaceWith(context, joinNode.getLeft(), joinNode.getLeftOutputSymbols());
                }
                return Rule.Result.empty();
            }
            default:
                throw new MatchException(null, null);
        }
    }

    /**
     * Wraps {@code source} via {@code Util.restrictOutputs} for the given output symbols and
     * returns it as the rule result; falls back to {@code source} itself when
     * {@code restrictOutputs} yields an empty optional.
     */
    private static Rule.Result replaceWith(Rule.Context context, PlanNode source, List<Symbol> outputs) {
        PlanNode replacement = Util.restrictOutputs(context.getIdAllocator(), source, ImmutableSet.copyOf(outputs))
                .orElse(source);
        return Rule.Result.ofPlanNode(replacement);
    }
}
