package org.apache.druid.sql.calcite.planner;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql2rel.RelFieldTrimmer;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.IntPair;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.druid.sql.calcite.rule.logical.LogicalUnnest;

/**
 * Druid-specific {@link RelFieldTrimmer}.
 *
 * <p>This subclass:
 * <ul>
 *   <li>returns its input unchanged with an identity mapping from
 *       {@link #dummyProject(int, RelNode)};</li>
 *   <li>handles a requested field count of zero in
 *       {@link #dummyProject(int, RelNode, RelNode)} by projecting no columns;</li>
 *   <li>adds {@code trimFields} overloads for {@link LogicalCorrelate} and {@link LogicalUnnest}.</li>
 * </ul>
 *
 * <p>NOTE(review): the public {@code trimFields} overloads are presumably discovered by
 * {@link RelFieldTrimmer}'s dispatch mechanism by name and argument types; keep their names and
 * signatures stable.
 */
public class DruidRelFieldTrimmer extends RelFieldTrimmer {
    private final RelBuilder relBuilder;

    /**
     * Rewrites field accesses on one correlation variable so that they refer to a new (trimmed)
     * row type.
     *
     * <p>Every {@code $corN.field} whose correlation id equals {@link #correlationId} is rebuilt
     * against a correlation variable of type {@link #newCorrelRowType}, and its field index is
     * translated through {@link #mapping}. All other expressions are visited by the default
     * {@link RexShuttle} logic.
     */
    static class RexCorrelVariableMapShuttle extends RexShuttle {
        private final CorrelationId correlationId;
        private final Mapping mapping;
        private final RelDataType newCorrelRowType;
        private final RexBuilder rexBuilder;

        /**
         * @param correlationId    id of the correlation variable whose accesses are rewritten
         * @param newCorrelRowType row type the rewritten correlation variable is created with
         * @param mapping          old-to-new field index mapping; must have a target for every
         *                         field read through the correlation variable
         * @param rexBuilder       builder used to create the replacement expressions
         */
        public RexCorrelVariableMapShuttle(
            CorrelationId correlationId,
            RelDataType newCorrelRowType,
            Mapping mapping,
            RexBuilder rexBuilder
        ) {
            this.correlationId = correlationId;
            this.newCorrelRowType = newCorrelRowType;
            this.mapping = mapping;
            this.rexBuilder = rexBuilder;
        }

        // This must override RexShuttle#visitFieldAccess. Under any other name the shuttle
        // would never see field accesses, and correlated references would keep the stale
        // (pre-trim) row type and field indices.
        @Override
        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
            final RexNode referenceExpr = fieldAccess.getReferenceExpr();
            if (referenceExpr instanceof RexCorrelVariable) {
                final RexCorrelVariable correlVariable = (RexCorrelVariable) referenceExpr;
                if (correlVariable.id.equals(correlationId)) {
                    return rexBuilder.makeFieldAccess(
                        map(correlVariable),
                        mapping.getTarget(fieldAccess.getField().getIndex())
                    );
                }
            }
            return super.visitFieldAccess(fieldAccess);
        }

        /** Creates a correlation variable with the same id but the new row type. */
        private RexNode map(RexCorrelVariable correlVariable) {
            return rexBuilder.makeCorrel(newCorrelRowType, correlVariable.id);
        }
    }

    /**
     * Walks a relational tree and applies a {@link RexShuttle} to each node.
     *
     * <p>{@link #visit(RelNode)} first delegates to the superclass and then applies
     * {@link #rexVisitor} to the node that the superclass returns.
     */
    static class RexRewritingRelShuttle extends RelHomogeneousShuttle {
        private final RexShuttle rexVisitor;

        RexRewritingRelShuttle(RexShuttle rexVisitor) {
            this.rexVisitor = rexVisitor;
        }

        @Override
        public RelNode visit(RelNode other) {
            return super.visit(other).accept(rexVisitor);
        }
    }

    /**
     * @param sqlValidator validator passed through to {@link RelFieldTrimmer}
     * @param relBuilder   builder passed to {@link RelFieldTrimmer}; also kept here to build the
     *                     zero-column projection in {@link #dummyProject(int, RelNode, RelNode)}
     */
    public DruidRelFieldTrimmer(SqlValidator sqlValidator, RelBuilder relBuilder) {
        super(sqlValidator, relBuilder);
        this.relBuilder = relBuilder;
    }

    /**
     * Returns {@code input} unchanged, paired with an identity mapping over its fields.
     * No dummy projection is created.
     */
    @Override
    protected RelFieldTrimmer.TrimResult dummyProject(int fieldCount, RelNode input) {
        return makeIdentityMapping(input);
    }

    /**
     * Handles {@code fieldCount == 0} by projecting zero columns out of {@code input}.
     * Any other field count is delegated to the superclass.
     *
     * @param fieldCount      used as the source count of the returned mapping (presumably the
     *                        field count of the row being trimmed)
     * @param input           relational expression to trim
     * @param originalRelNode if non-null, its hints are propagated onto the new projection
     * @return the (possibly new) node and a mapping with no sources and no targets when
     *         {@code fieldCount} is 0
     */
    @Override
    protected RelFieldTrimmer.TrimResult dummyProject(int fieldCount, RelNode input, RelNode originalRelNode) {
        if (fieldCount != 0) {
            return super.dummyProject(fieldCount, input, originalRelNode);
        }
        // fieldCount is 0 here, so the mapping has neither sources nor targets.
        final Mapping emptyMapping = Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, 0);
        if (input.getRowType().getFieldCount() == 0) {
            // Input is already zero-width; nothing to project away.
            return result(input, emptyMapping);
        }
        relBuilder.push(input);
        relBuilder.project(Collections.emptyList(), Collections.emptyList());
        RelNode emptyProject = relBuilder.build();
        if (originalRelNode != null) {
            emptyProject = RelOptUtil.propagateRelHints(originalRelNode, emptyProject);
        }
        return result(emptyProject, emptyMapping);
    }

    /** Pairs {@code relNode} with an identity mapping over all of its output fields. */
    private RelFieldTrimmer.TrimResult makeIdentityMapping(RelNode relNode) {
        return result(relNode, Mappings.createIdentity(relNode.getRowType().getFieldCount()));
    }

    /**
     * Trims the inputs of a {@link LogicalCorrelate}.
     *
     * <p>The columns bound to the correlation variable ({@code getRequiredColumns()}) are always
     * kept. If either input shrinks, correlated field accesses inside the right input are
     * rewritten against the trimmed left input's row type.
     *
     * @param correlate   correlate to trim
     * @param fieldsUsed  output fields of {@code correlate} that are needed by its consumers
     * @param extraFields extra fields; if non-empty, trimming falls back to the generic
     *                    {@link RelNode} overload
     * @return the (possibly new) correlate and the old-to-new mapping of its output fields
     */
    public RelFieldTrimmer.TrimResult trimFields(
        LogicalCorrelate correlate,
        ImmutableBitSet fieldsUsed,
        Set<RelDataTypeField> extraFields
    ) {
        if (!extraFields.isEmpty()) {
            return trimFields((RelNode) correlate, fieldsUsed, extraFields);
        }

        // Columns bound to the correlation variable must survive trimming of the left input.
        final ImmutableBitSet allFieldsUsed = fieldsUsed.union(correlate.getRequiredColumns());

        final List<RelNode> newInputs = new ArrayList<>();
        final List<Mapping> inputMappings = new ArrayList<>();
        int changeCount = 0;
        int offset = 0;
        for (RelNode input : correlate.getInputs()) {
            final int inputFieldCount = input.getRowType().getFieldCount();
            // Keep only this input's slice of the combined row, then rebase it to start at 0.
            final ImmutableBitSet inputFieldsUsed = allFieldsUsed
                .intersect(ImmutableBitSet.range(offset, offset + inputFieldCount))
                .shift(-offset);
            final RelFieldTrimmer.TrimResult trimResult =
                dispatchTrimFields(input, inputFieldsUsed, extraFields);
            newInputs.add(trimResult.left);
            if (trimResult.left != input) {
                changeCount++;
            }
            inputMappings.add(trimResult.right);
            offset += inputFieldCount;
        }

        if (changeCount == 0) {
            return makeIdentityMapping(correlate);
        }

        // The combined mapping starts with the left input's slice, so it can translate
        // left-side indices: required columns and correlation-variable field accesses.
        final Mapping mapping = makeMapping(inputMappings);
        final RelNode newLeft = newInputs.get(0);
        final RelNode newRight = newInputs.get(1).accept(
            new RexRewritingRelShuttle(
                new RexCorrelVariableMapShuttle(
                    correlate.getCorrelationId(),
                    newLeft.getRowType(),
                    mapping,
                    correlate.getCluster().getRexBuilder()
                )
            )
        );
        final RelNode newCorrelate = correlate.copy(
            correlate.getTraitSet(),
            newLeft,
            newRight,
            correlate.getCorrelationId(),
            correlate.getRequiredColumns().permute(mapping),
            correlate.getJoinType()
        );
        return result(newCorrelate, mapping);
    }

    /**
     * Trims the input of a {@link LogicalUnnest}.
     *
     * <p>The last output field is treated as the column produced by the unnest itself. The
     * returned mapping appends one identity column after the input mapping. The fields
     * referenced by the unnest expression are always kept.
     *
     * @param unnest      unnest to trim
     * @param fieldsUsed  output fields of {@code unnest} that are needed by its consumers
     * @param extraFields extra fields; if non-empty, trimming falls back to the generic
     *                    {@link RelNode} overload
     * @return the (possibly new) unnest and the old-to-new mapping of its output fields
     */
    public RelFieldTrimmer.TrimResult trimFields(
        LogicalUnnest unnest,
        ImmutableBitSet fieldsUsed,
        Set<RelDataTypeField> extraFields
    ) {
        if (!extraFields.isEmpty()) {
            return trimFields((RelNode) unnest, fieldsUsed, extraFields);
        }

        // Collect the input fields referenced by the unnest expression.
        final RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(extraFields);
        unnest.getUnnestExpr().accept(inputFinder);

        // The trailing output column does not come from the input, so drop it before asking
        // the input for fields.
        final int unnestedFieldIndex = unnest.getRowType().getFieldCount() - 1;
        final ImmutableBitSet inputFieldsUsed = ImmutableBitSet.builder()
            .addAll(fieldsUsed.clear(unnestedFieldIndex))
            .addAll(inputFinder.build())
            .build();

        final RelNode input = unnest.getInput();
        final RelFieldTrimmer.TrimResult trimResult = trimChild(unnest, input, inputFieldsUsed, extraFields);
        final RelNode newInput = trimResult.left;
        final Mapping inputMapping = trimResult.right;
        if (newInput == input) {
            return makeIdentityMapping(unnest);
        }

        final RexNode newUnnestExpr =
            unnest.getUnnestExpr().accept(new RexPermuteInputsShuttle(inputMapping, newInput));
        // Input columns are remapped; the single trailing unnested column maps to itself.
        return result(
            unnest.copy(unnest.getTraitSet(), newInput, newUnnestExpr, unnest.getFilter()),
            makeMapping(ImmutableList.of(inputMapping, Mappings.createIdentity(1)))
        );
    }

    /**
     * Concatenates per-input mappings into one mapping over the combined row.
     *
     * <p>Each mapping's sources and targets are offset by the total source and target counts
     * of the mappings before it.
     *
     * @param inputMappings mappings in input order
     * @return an {@link MappingType#INVERSE_SURJECTION} mapping whose source and target counts
     *         are the sums of the inputs' counts
     */
    private Mapping makeMapping(List<Mapping> inputMappings) {
        int sourceCount = 0;
        int targetCount = 0;
        for (Mapping inputMapping : inputMappings) {
            sourceCount += inputMapping.getSourceCount();
            targetCount += inputMapping.getTargetCount();
        }

        final Mapping combined = Mappings.create(MappingType.INVERSE_SURJECTION, sourceCount, targetCount);
        int sourceOffset = 0;
        int targetOffset = 0;
        for (Mapping inputMapping : inputMappings) {
            // Mapping is Iterable<IntPair>; it is not itself a generic type.
            for (IntPair pair : inputMapping) {
                combined.set(pair.source + sourceOffset, pair.target + targetOffset);
            }
            sourceOffset += inputMapping.getSourceCount();
            targetOffset += inputMapping.getTargetCount();
        }
        return combined;
    }
}
