package io.trino.sql.planner.iterative.rule;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.SwitchBootstraps;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/RemoveRedundantDistinctAggregation.class */
/**
 * Removes an aggregation that only de-duplicates rows when its input is already known
 * to be distinct on the aggregation's grouping keys.
 *
 * <p>The rule matches {@link AggregationNode}s for which
 * {@link AggregationNode#producesDistinctRows()} holds. It then walks down the
 * aggregation's source, looking through {@link FilterNode}s and {@link ProjectNode}s,
 * until it reaches another {@link AggregationNode}. If that inner aggregation has
 * exactly one grouping set and all of its grouping keys are among the outer grouping
 * keys (translated through any intervening projections), the inner rows are already
 * unique on the outer keys. In that case the outer aggregation is replaced by its
 * source. Any other node shape stops the walk and leaves the plan unchanged.
 */
public class RemoveRedundantDistinctAggregation implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN =
            Patterns.aggregation().matching(AggregationNode::producesDistinctRows);

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Replaces {@code aggregationNode} with its source when the source is provably
     * already distinct on the aggregation's grouping keys.
     *
     * @param aggregationNode the matched aggregation
     * @param captures unused; {@link #PATTERN} captures nothing
     * @param context supplies the {@link Lookup} used to resolve source nodes
     * @return the aggregation's (unresolved) source as the replacement, or an empty
     *         result when the aggregation cannot be shown to be redundant
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();
        PlanNode source = lookup.resolve(aggregationNode.getSource());
        Set<Symbol> groupingKeys = new HashSet<>(aggregationNode.getGroupingKeys());
        if (isDistinctOverGroupingKeys(source, lookup, groupingKeys)) {
            return Rule.Result.ofPlanNode(aggregationNode.getSource());
        }
        return Rule.Result.empty();
    }

    /**
     * Returns whether the rows produced by {@code planNode} are known to be unique on
     * {@code keys}.
     *
     * <p>Filters are looked through with {@code keys} unchanged, because removing rows
     * cannot introduce duplicates. Projections are looked through after mapping
     * {@code keys} onto the projection's source symbols. An aggregation qualifies when
     * it has a single grouping set whose keys are all contained in {@code keys}: rows
     * unique on a subset of {@code keys} are unique on {@code keys}. Every other node
     * kind, and {@code null}, yields {@code false}.
     *
     * @param planNode the (already resolved) node to inspect
     * @param lookup used to resolve the sources of filters and projections
     * @param keys the symbols, in {@code planNode}'s output, the rows must be distinct on
     * @return {@code true} only if distinctness on {@code keys} is proven
     */
    private static boolean isDistinctOverGroupingKeys(PlanNode planNode, Lookup lookup, Set<Symbol> keys) {
        return switch (planNode) {
            case AggregationNode aggregation ->
                    aggregation.getGroupingSets().getGroupingSetCount() == 1
                            && keys.containsAll(aggregation.getGroupingSets().getGroupingKeys());
            case ProjectNode project ->
                    isDistinctOverGroupingKeys(
                            lookup.resolve(project.getSource()),
                            lookup,
                            translateProjectReferences(project, keys));
            case FilterNode filter ->
                    isDistinctOverGroupingKeys(lookup.resolve(filter.getSource()), lookup, keys);
            // Matches the original type-switch fall-through (index -1 / no match): not provable.
            case null, default -> false;
        };
    }

    /**
     * Maps {@code keys}, expressed as output symbols of {@code projectNode}, onto the
     * symbols of the projection's source.
     *
     * <p>Only keys whose assignment is a plain {@link Reference} are translated. Keys
     * whose assignment is any other expression are dropped. Dropping keys shrinks the
     * set that later has to contain the inner grouping keys, so it can only make the
     * redundancy check stricter, never unsound.
     *
     * @param projectNode the projection being looked through
     * @param keys output symbols of {@code projectNode}
     * @return a new mutable set of source symbols that the kept keys refer to
     */
    private static Set<Symbol> translateProjectReferences(ProjectNode projectNode, Set<Symbol> keys) {
        Assignments assignments = projectNode.getAssignments();
        Set<Symbol> translated = new HashSet<>();
        for (Symbol key : keys) {
            if (assignments.get(key) instanceof Reference reference) {
                translated.add(Symbol.from(reference));
            }
        }
        return translated;
    }
}
