package org.apache.druid.sql.calcite.rule.logical;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mappings;
import org.immutables.value.Value;

@Value.Enclosing
/* loaded from: input_file:org/apache/druid/sql/calcite/rule/logical/DruidAggregateRemoveRedundancyRule.class */
public class DruidAggregateRemoveRedundancyRule extends RelOptRule implements TransformationRule {
    private static final DruidAggregateRemoveRedundancyRule INSTANCE = new DruidAggregateRemoveRedundancyRule();

    private DruidAggregateRemoveRedundancyRule() {
        super(operand(Aggregate.class, operand(Project.class, any()), new RelOptRuleOperand[0]));
    }

    public static DruidAggregateRemoveRedundancyRule instance() {
        return INSTANCE;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        Aggregate rel = relOptRuleCall.rel(0);
        RelNode apply = apply(relOptRuleCall, rel, relOptRuleCall.rel(1));
        if (apply != null) {
            relOptRuleCall.transformTo(apply);
            relOptRuleCall.getPlanner().prune(rel);
        }
    }

    public static RelNode apply(RelOptRuleCall relOptRuleCall, Aggregate aggregate, Project project) {
        Set allFields = RelOptUtil.getAllFields(aggregate);
        if (allFields.isEmpty()) {
            return null;
        }
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        ArrayList arrayList = new ArrayList();
        Iterator it = allFields.iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            RexNode rexNode = (RexNode) project.getProjects().get(intValue);
            if (!hashMap2.containsKey(rexNode)) {
                RexInputRef rexInputRef = new RexInputRef(intValue, rexNode.getType());
                hashMap2.put(rexNode, Integer.valueOf(arrayList.size()));
                arrayList.add(rexInputRef);
            }
            hashMap.put(Integer.valueOf(intValue), (Integer) hashMap2.get(rexNode));
        }
        if (arrayList.size() == project.getProjects().size()) {
            return null;
        }
        ImmutableBitSet permute = aggregate.getGroupSet().permute(hashMap);
        ImmutableList immutableSortedCopy = aggregate.getGroupType() != Aggregate.Group.SIMPLE ? ImmutableBitSet.ORDERING.immutableSortedCopy(Sets.newTreeSet(ImmutableBitSet.permute(aggregate.getGroupSets(), hashMap))) : null;
        ImmutableList.Builder builder = ImmutableList.builder();
        Mappings.TargetMapping target = Mappings.target(hashMap, aggregate.getInput().getRowType().getFieldCount(), arrayList.size());
        Iterator it2 = aggregate.getAggCallList().iterator();
        while (it2.hasNext()) {
            builder.add(((AggregateCall) it2.next()).transform(target));
        }
        RelBuilder builder2 = relOptRuleCall.builder();
        builder2.push(project);
        builder2.project(arrayList);
        Aggregate copy = aggregate.copy(aggregate.getTraitSet(), builder2.build(), permute, immutableSortedCopy, builder.build());
        builder2.push(copy);
        List transform = Util.transform(aggregate.getGroupSet().asList(), num -> {
            return (Integer) Objects.requireNonNull((Integer) hashMap.get(num), (Supplier<String>) () -> {
                return "no value found for key " + num + " in " + String.valueOf(hashMap);
            });
        });
        if (!transform.equals(permute.asList())) {
            ArrayList arrayList2 = new ArrayList();
            Iterator it3 = transform.iterator();
            while (it3.hasNext()) {
                arrayList2.add(Integer.valueOf(permute.indexOf(((Integer) it3.next()).intValue())));
            }
            for (int groupCount = copy.getGroupCount(); groupCount < copy.getRowType().getFieldCount(); groupCount++) {
                arrayList2.add(Integer.valueOf(groupCount));
            }
            builder2.project(builder2.fields(arrayList2));
        }
        return builder2.build();
    }
}
