package org.apache.druid.sql.calcite.rule;

import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.DruidCorrelateUnnestRel;
import org.apache.druid.sql.calcite.rel.DruidRel;
import org.apache.druid.sql.calcite.rel.DruidUnnestRel;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;

/* loaded from: input_file:org/apache/druid/sql/calcite/rule/DruidCorrelateUnnestRule.class */
public class DruidCorrelateUnnestRule extends RelOptRule {
    private final PlannerContext plannerContext;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/druid/sql/calcite/rule/DruidCorrelateUnnestRule$PushCorrelatedFieldAccessPastProject.class */
    public static class PushCorrelatedFieldAccessPastProject extends RexShuttle {
        private final CorrelationId correlationId;
        private final CorrelationId newCorrelationId;
        private final Project project;
        private final IntSet requiredColumns = new IntAVLTreeSet();

        public PushCorrelatedFieldAccessPastProject(CorrelationId correlationId, CorrelationId correlationId2, Project project) {
            this.correlationId = correlationId;
            this.newCorrelationId = correlationId2;
            this.project = project;
        }

        public IntSet getRequiredColumns() {
            return this.requiredColumns;
        }

        /* JADX WARN: Type inference failed for: r0v22, types: [org.apache.druid.sql.calcite.rule.DruidCorrelateUnnestRule$PushCorrelatedFieldAccessPastProject$1] */
        /* renamed from: visitFieldAccess, reason: merged with bridge method [inline-methods] */
        public RexNode m213visitFieldAccess(RexFieldAccess rexFieldAccess) {
            if (!(rexFieldAccess.getReferenceExpr() instanceof RexCorrelVariable) || !rexFieldAccess.getReferenceExpr().id.equals(this.correlationId)) {
                return super.visitFieldAccess(rexFieldAccess);
            }
            RexNode rexNode = (RexNode) this.project.getProjects().get(rexFieldAccess.getField().getIndex());
            final RexNode makeCorrel = this.project.getCluster().getRexBuilder().makeCorrel(this.project.getInput().getRowType(), this.newCorrelationId);
            return new RexShuttle() { // from class: org.apache.druid.sql.calcite.rule.DruidCorrelateUnnestRule.PushCorrelatedFieldAccessPastProject.1
                /* renamed from: visitInputRef, reason: merged with bridge method [inline-methods] */
                public RexNode m214visitInputRef(RexInputRef rexInputRef) {
                    PushCorrelatedFieldAccessPastProject.this.requiredColumns.add(rexInputRef.getIndex());
                    return PushCorrelatedFieldAccessPastProject.this.project.getCluster().getRexBuilder().makeFieldAccess(makeCorrel, rexInputRef.getIndex());
                }
            }.apply(rexNode);
        }
    }

    public DruidCorrelateUnnestRule(PlannerContext plannerContext) {
        super(operand(Correlate.class, operand(DruidRel.class, any()), new RelOptRuleOperand[]{operand(DruidUnnestRel.class, any())}));
        this.plannerContext = plannerContext;
    }

    public boolean matches(RelOptRuleCall relOptRuleCall) {
        return relOptRuleCall.rel(1).getPartialDruidQuery() != null;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        DruidRel druidRel;
        RexNode inputRexNode;
        ImmutableBitSet requiredColumns;
        CorrelationId correlationId;
        Correlate rel = relOptRuleCall.rel(0);
        DruidRel rel2 = relOptRuleCall.rel(1);
        DruidUnnestRel rel3 = relOptRuleCall.rel(2);
        RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        ArrayList arrayList = new ArrayList();
        if (rel2.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT) {
            RelNode scan = rel2.getPartialDruidQuery().getScan();
            Project selectProject = rel2.getPartialDruidQuery().getSelectProject();
            arrayList.addAll(selectProject.getProjects());
            druidRel = rel2.withPartialQuery(PartialDruidQuery.create(scan).withWhereFilter(rel2.getPartialDruidQuery().getWhereFilter()));
            correlationId = rel.getCluster().createCorrel();
            PushCorrelatedFieldAccessPastProject pushCorrelatedFieldAccessPastProject = new PushCorrelatedFieldAccessPastProject(rel.getCorrelationId(), correlationId, selectProject);
            inputRexNode = pushCorrelatedFieldAccessPastProject.apply(rel3.getInputRexNode());
            requiredColumns = ImmutableBitSet.of(pushCorrelatedFieldAccessPastProject.getRequiredColumns());
        } else {
            for (int i = 0; i < rel2.getRowType().getFieldCount(); i++) {
                arrayList.add(rexBuilder.makeInputRef(((RelDataTypeField) rel.getRowType().getFieldList().get(i)).getType(), i));
            }
            druidRel = rel2;
            inputRexNode = rel3.getInputRexNode();
            requiredColumns = rel.getRequiredColumns();
            correlationId = rel.getCorrelationId();
        }
        for (int i2 = 0; i2 < rel3.getRowType().getFieldCount(); i2++) {
            arrayList.add(rexBuilder.makeInputRef(((RelDataTypeField) rel.getRowType().getFieldList().get(rel2.getRowType().getFieldCount() + i2)).getType(), druidRel.getRowType().getFieldCount() + i2));
        }
        DruidCorrelateUnnestRel create = DruidCorrelateUnnestRel.create(rel.copy(rel.getTraitSet(), druidRel, rel3.withUnnestRexNode(inputRexNode), correlationId, requiredColumns, rel.getJoinType()), this.plannerContext);
        RelBuilder project = relOptRuleCall.builder().push(create).project(RexUtil.fixUp(rexBuilder, arrayList, RelOptUtil.getFieldTypeList(create.getRowType())));
        project.convert(rel.getRowType(), false);
        relOptRuleCall.transformTo(project.build());
    }
}
