package org.apache.druid.sql.calcite.rule;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.lookup.LookupExtractionFn;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.sql.calcite.expression.builtin.MultiValueStringOperatorConversions;
import org.apache.druid.sql.calcite.expression.builtin.QueryLookupOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.SearchOperatorConversion;
import org.apache.druid.sql.calcite.filtration.CollectComparisons;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.utils.CollectionUtils;

/* loaded from: input_file:org/apache/druid/sql/calcite/rule/ReverseLookupRule.class */
public class ReverseLookupRule extends RelOptRule implements SubstitutionRule {
    public static final String CTX_MAX_OPTIMIZE_COUNT = "maxOptimizeCountForDruidReverseLookupRule";
    public static final String CTX_THRESHOLD = "sqlReverseLookupThreshold";
    public static final int DEFAULT_THRESHOLD = 10000;
    private final PlannerContext plannerContext;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/druid/sql/calcite/rule/ReverseLookupRule$ReverseLookupKey.class */
    /**
     * Value object used as the collection key of {@code CollectReverseLookups}. Two keys are equal when they have
     * the same lookup argument, lookup name, {@code replaceMissingValueWith}, multi-value flag and negation flag.
     */
    public static class ReverseLookupKey {
        private final RexNode arg;
        private final String lookupName;
        private final String replaceMissingValueWith;
        private final boolean multiValue;
        private final boolean negate;

        /**
         * @param arg                     first operand of the lookup call
         * @param lookupName              name of the lookup
         * @param replaceMissingValueWith third operand of the lookup call, or null when absent
         * @param multiValue              true when the comparison uses a multi-value operator
         * @param negate                  true when the comparison is a negated one
         */
        private ReverseLookupKey(
                RexNode arg,
                String lookupName,
                String replaceMissingValueWith,
                boolean multiValue,
                boolean negate) {
            this.arg = arg;
            this.lookupName = lookupName;
            this.replaceMissingValueWith = replaceMissingValueWith;
            this.multiValue = multiValue;
            this.negate = negate;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final ReverseLookupKey that = (ReverseLookupKey) obj;
            // Compare the cheap primitive flags first, then the object fields.
            if (this.multiValue != that.multiValue || this.negate != that.negate) {
                return false;
            }
            return Objects.equals(this.arg, that.arg)
                    && Objects.equals(this.lookupName, that.lookupName)
                    && Objects.equals(this.replaceMissingValueWith, that.replaceMissingValueWith);
        }

        @Override
        public int hashCode() {
            return Objects.hash(
                    this.arg,
                    this.lookupName,
                    this.replaceMissingValueWith,
                    this.multiValue,
                    this.negate);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/druid/sql/calcite/rule/ReverseLookupRule$ReverseLookupShuttle.class */
    public static class ReverseLookupShuttle extends RexShuttle {
        private final PlannerContext plannerContext;
        private final RexBuilder rexBuilder;
        private final int maxOptimizeCount;
        private final int maxInSize;
        private final Set<RexNode> consideredAsChild = new HashSet();
        private boolean includeUnknown = false;
        private int optimizeCount = 0;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/apache/druid/sql/calcite/rule/ReverseLookupRule$ReverseLookupShuttle$CollectReverseLookups.class */
        public class CollectReverseLookups extends CollectComparisons<RexNode, RexCall, RexNode, ReverseLookupKey, String, InDimFilter.ValuesSet> {
            private final RexBuilder rexBuilder;

            /**
             * @param operands   expressions handed to the {@code CollectComparisons} base class
             * @param rexBuilder builder used to create replacement expressions
             */
            private CollectReverseLookups(List<RexNode> operands, RexBuilder rexBuilder) {
                super(operands);
                this.rexBuilder = rexBuilder;
            }

            /* JADX INFO: Access modifiers changed from: protected */
            /**
             * Returns the lookup comparison form of {@code rexNode} paired with an empty list, or null when
             * {@code rexNode} is not recognized as a lookup comparison.
             */
            @Override // org.apache.druid.sql.calcite.filtration.CollectComparisons
            @Nullable
            public Pair<RexCall, List<RexNode>> getCollectibleComparison(RexNode rexNode) {
                final RexCall comparison = getAsLookupComparison(rexNode);
                if (comparison == null) {
                    return null;
                }
                final List<RexNode> noExtraConditions = Collections.emptyList();
                return Pair.of(comparison, noExtraConditions);
            }

            /* JADX INFO: Access modifiers changed from: protected */
            /**
             * Creates a fresh, empty {@link InDimFilter.ValuesSet} to accumulate match values in.
             */
            @Override // org.apache.druid.sql.calcite.filtration.CollectComparisons
            public InDimFilter.ValuesSet makeCollection() {
                final InDimFilter.ValuesSet emptyValues = new InDimFilter.ValuesSet();
                return emptyValues;
            }

            /* JADX INFO: Access modifiers changed from: protected */
            /**
             * Builds the grouping key for a lookup comparison whose first operand is the lookup call.
             *
             * @return the key, or null when the lookup is unknown to the planner context, or when the lookup is not
             *         one-to-one and the comparison targets the value used for missing keys
             */
            @Override // org.apache.druid.sql.calcite.filtration.CollectComparisons
            @Nullable
            public ReverseLookupKey getCollectionKey(RexCall rexCall) {
                final RexCall lookupCall = (RexCall) rexCall.getOperands().get(0);
                final List<RexNode> lookupArgs = lookupCall.getOperands();
                final RexNode lookupInput = lookupArgs.get(0);
                final String lookupName = RexLiteral.stringValue(lookupArgs.get(1));

                // The third lookup operand is optional.
                final String replaceMissingValueWith;
                if (lookupArgs.size() >= 3) {
                    replaceMissingValueWith = RexLiteral.stringValue(lookupArgs.get(2));
                } else {
                    replaceMissingValueWith = null;
                }

                final LookupExtractor lookup = ReverseLookupShuttle.this.plannerContext.getLookup(lookupName);
                if (lookup == null) {
                    return null;
                }

                if (!lookup.isOneToOne()) {
                    // Give up when the comparison is against whatever a missing key maps to:
                    // null (IS NULL) without a replacement, otherwise the replacement value itself.
                    final boolean matchesMissingValue;
                    if (replaceMissingValueWith == null) {
                        matchesMissingValue = rexCall.isA(SqlKind.IS_NULL);
                    } else {
                        matchesMissingValue = getMatchValues(rexCall).contains(replaceMissingValueWith);
                    }
                    if (matchesMissingValue) {
                        return null;
                    }
                }

                final boolean multiValue =
                        rexCall.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.mo59calciteOperator())
                        || rexCall.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.mo59calciteOperator());
                final boolean negate = rexCall.getKind() == SqlKind.NOT_EQUALS;
                return new ReverseLookupKey(lookupInput, lookupName, replaceMissingValueWith, multiValue, negate);
            }

            /* JADX INFO: Access modifiers changed from: protected */
            /**
             * Returns the string values the comparison matches against: a singleton null for IS NULL, otherwise the
             * values extracted from the second operand by {@code toStringSet}.
             */
            @Override // org.apache.druid.sql.calcite.filtration.CollectComparisons
            public Set<String> getMatchValues(RexCall rexCall) {
                if (rexCall.isA(SqlKind.IS_NULL)) {
                    return Collections.singleton(null);
                }

                final RexNode matchOperand = rexCall.getOperands().get(1);
                // Passed as the boolean flag of toStringSet for MV_CONTAINS, MV_OVERLAP and SCALAR_IN_ARRAY style calls.
                final boolean arrayStyleOperator =
                        rexCall.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.mo59calciteOperator())
                        || rexCall.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.mo59calciteOperator())
                        || rexCall.getOperator().equals(ScalarInArrayOperatorConversion.SQL_FUNCTION);
                return ReverseLookupRule.toStringSet(matchOperand, arrayStyleOperator);
            }

            /* JADX INFO: Access modifiers changed from: protected */
            /**
             * Reverses the collected match values through the lookup named by the key and builds the replacement
             * condition on the lookup's argument.
             *
             * @return the replacement condition, or null when the lookup is unavailable or the reversal yields null
             */
            @Override // org.apache.druid.sql.calcite.filtration.CollectComparisons
            @Nullable
            public RexNode makeCollectedComparison(ReverseLookupKey reverseLookupKey, InDimFilter.ValuesSet valuesSet) {
                final LookupExtractor lookup =
                        ReverseLookupShuttle.this.plannerContext.getLookup(reverseLookupKey.lookupName);
                if (lookup == null) {
                    return null;
                }

                // A negated comparison flips the shuttle's current includeUnknown state.
                final boolean includeUnknownForKey = ReverseLookupShuttle.this.includeUnknown ^ reverseLookupKey.negate;
                final Set<String> reversedValues = reverseLookup(
                        lookup,
                        reverseLookupKey.replaceMissingValueWith,
                        valuesSet,
                        includeUnknownForKey);
                if (reversedValues == null) {
                    return null;
                }
                return makeMatchCondition(reverseLookupKey, reversedValues, this.rexBuilder);
            }

            /* JADX INFO: Access modifiers changed from: protected */
            /* JADX WARN: Can't rename method to resolve collision */
            /**
             * Not supported by this collector.
             *
             * @throws UnsupportedOperationException always
             */
            @Override // org.apache.druid.sql.calcite.filtration.CollectComparisons
            public RexNode makeAnd(List<RexNode> list) {
                // getCollectibleComparison always pairs comparisons with an empty list in this class.
                throw new UnsupportedOperationException();
            }

            /**
             * Normalizes {@code rexNode} into a comparison whose first operand is a lookup call and, for binary
             * comparisons, whose second operand is a literal.
             *
             * <p>Accepted shapes are {@code LOOKUP(...) IS NULL} and the binary comparisons recognized by
             * {@code isBinaryComparison}. When the literal is on the left and a call is on the right, the operands
             * are swapped first, except for the MV_CONTAINS operator, which is rejected in that orientation.
             * Nullability casts are stripped from both operands before they are checked.
             *
             * @param rexNode candidate expression
             * @return the normalized comparison, or null when {@code rexNode} is not a lookup comparison
             */
            @Nullable
            private RexCall getAsLookupComparison(RexNode rexNode) {
                if (rexNode.isA(SqlKind.IS_NULL)
                        && ReverseLookupRule.isLookupCall(((RexCall) rexNode).getOperands().get(0))) {
                    return (RexCall) rexNode;
                }
                if (!ReverseLookupRule.isBinaryComparison(rexNode)) {
                    return null;
                }

                final RexCall comparison = (RexCall) rexNode;
                RexNode lhs = comparison.getOperands().get(0);
                RexNode rhs = comparison.getOperands().get(1);

                if ((rhs instanceof RexCall) && Calcites.isLiteral(lhs, true, true)) {
                    if (comparison.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.mo59calciteOperator())) {
                        // Operand order matters for this operator, so the swapped form is not accepted.
                        return null;
                    }
                    // Swap through a temporary; assigning the two variables to each other directly would leave both
                    // pointing at the same node and the literal would be lost.
                    final RexNode swapped = lhs;
                    lhs = rhs;
                    rhs = swapped;
                }

                final RexNode lookupSide = RexUtil.removeNullabilityCast(this.rexBuilder.getTypeFactory(), lhs);
                final RexNode literalSide = RexUtil.removeNullabilityCast(this.rexBuilder.getTypeFactory(), rhs);
                if (ReverseLookupRule.isLookupCall(lookupSide) && Calcites.isLiteral(literalSide, true, true)) {
                    // makeCall is declared to return RexNode, so the result is narrowed explicitly.
                    return (RexCall) this.rexBuilder.makeCall(comparison.getOperator(), lookupSide, literalSide);
                }
                return null;
            }

            /**
             * Delegates to {@link InDimFilter#optimizeLookup} with a placeholder IN filter wrapping the lookup, after
             * counting this attempt against the shuttle's optimize budget.
             *
             * @param lookupExtractor lookup to reverse
             * @param str             value passed to the extraction function as the missing-value replacement
             * @param valuesSet       values being matched
             * @param z               flag forwarded to {@code optimizeLookup}
             * @return whatever {@code optimizeLookup} returns, which may be null
             * @throws ISE when the number of attempts exceeds {@code maxOptimizeCount}
             */
            @Nullable
            private Set<String> reverseLookup(LookupExtractor lookupExtractor, @Nullable String str, InDimFilter.ValuesSet valuesSet, boolean z) {
                final int attempts = ++ReverseLookupShuttle.this.optimizeCount;
                if (attempts > ReverseLookupShuttle.this.maxOptimizeCount) {
                    throw new ISE("Too many optimize calls[%s]", attempts);
                }

                final LookupExtractionFn extractionFn =
                        new LookupExtractionFn(lookupExtractor, false, str, (Boolean) null, true);
                // The dimension name is a placeholder; only the values and the extraction function are supplied here.
                final InDimFilter placeholderFilter = new InDimFilter("__dummy__", valuesSet, extractionFn);
                return InDimFilter.optimizeLookup(placeholderFilter, z, ReverseLookupShuttle.this.maxInSize);
            }

            /**
             * Builds the condition that matches the key's lookup argument against the reversed values.
             *
             * <ul>
             *   <li>No values: a boolean literal equal to the key's negate flag.</li>
             *   <li>Single-value key: the expression produced by {@code SearchOperatorConversion.makeIn}.</li>
             *   <li>Multi-value key: MV_CONTAINS for one value, MV_OVERLAP with an array otherwise, wrapped in NOT
             *       when the key is negated.</li>
             * </ul>
             */
            private RexNode makeMatchCondition(ReverseLookupKey reverseLookupKey, Set<String> set, RexBuilder rexBuilder) {
                if (set.isEmpty()) {
                    return rexBuilder.makeLiteral(reverseLookupKey.negate);
                }

                if (!reverseLookupKey.multiValue) {
                    return SearchOperatorConversion.makeIn(
                            reverseLookupKey.arg,
                            set,
                            // Nullable VARCHAR type for the matched values.
                            rexBuilder.getTypeFactory().createTypeWithNullability(
                                    rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
                                    true),
                            reverseLookupKey.negate,
                            set.size() >= ReverseLookupShuttle.this.plannerContext.queryContext().getInFunctionThreshold(),
                            rexBuilder);
                }

                RexNode condition;
                if (set.size() == 1) {
                    final RexNode onlyValue = Iterables.getOnlyElement(ReverseLookupRule.stringsToRexNodes(set, rexBuilder));
                    condition = rexBuilder.makeCall(
                            MultiValueStringOperatorConversions.CONTAINS.mo59calciteOperator(),
                            reverseLookupKey.arg,
                            onlyValue);
                } else {
                    final RexNode valueArray = rexBuilder.makeCall(
                            SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
                            ReverseLookupRule.stringsToRexNodes(set, rexBuilder));
                    condition = rexBuilder.makeCall(
                            MultiValueStringOperatorConversions.OVERLAP.mo59calciteOperator(),
                            reverseLookupKey.arg,
                            valueArray);
                }

                if (reverseLookupKey.negate) {
                    condition = rexBuilder.makeCall(SqlStdOperatorTable.NOT, condition);
                }
                return condition;
            }
        }

        /**
         * @param plannerContext context used to resolve lookups and read query-context settings
         * @param rexBuilder     builder used to create rewritten expressions
         * @param i              stored as {@code maxOptimizeCount}, the cap on reverse-lookup attempts
         * @param i2             stored as {@code maxInSize}, forwarded to {@code InDimFilter.optimizeLookup}
         */
        public ReverseLookupShuttle(PlannerContext plannerContext, RexBuilder rexBuilder, int i, int i2) {
            this.maxOptimizeCount = i;
            this.maxInSize = i2;
            this.rexBuilder = rexBuilder;
            this.plannerContext = plannerContext;
        }

        /**
         * Dispatches a call to the specialized visitor for its kind: NOT, AND, OR, SEARCH, or a lookup-style
         * comparison (IS NULL or a binary comparison) that has not already been considered as a child of an
         * enclosing AND/OR. Everything else falls through to the default {@link RexShuttle} traversal.
         *
         * <p>This must override {@link RexShuttle#visitCall(RexCall)}: {@code RexNode.accept} dispatches to
         * {@code visitCall}, so a differently named method is never reached during traversal.
         *
         * @param rexCall call being visited
         * @return the rewritten expression, or the result of the default traversal
         */
        @Override
        public RexNode visitCall(RexCall rexCall) {
            final SqlKind kind = rexCall.getKind();
            if (kind == SqlKind.NOT) {
                return visitNot(rexCall);
            }
            if (kind == SqlKind.AND) {
                return visitAnd(rexCall);
            }
            if (kind == SqlKind.OR) {
                return visitOr(rexCall);
            }
            if (rexCall.isA(SqlKind.SEARCH)) {
                return visitSearch(rexCall);
            }

            final boolean comparisonShaped = rexCall.isA(SqlKind.IS_NULL) || ReverseLookupRule.isBinaryComparison(rexCall);
            if (comparisonShaped && !this.consideredAsChild.contains(rexCall)) {
                return visitComparison(rexCall);
            }
            return super.visitCall(rexCall);
        }

        /**
         * Alias retained for existing callers of the decompiler-generated name; delegates to
         * {@link #visitCall(RexCall)}.
         */
        public RexNode m229visitCall(RexCall rexCall) {
            return visitCall(rexCall);
        }

        /**
         * Visits a NOT call with the {@code includeUnknown} flag inverted for the duration of the default
         * traversal, then puts the flag back.
         */
        private RexNode visitNot(RexCall rexCall) {
            final boolean outerIncludeUnknown = this.includeUnknown;
            this.includeUnknown = !outerIncludeUnknown;
            final RexNode visited = super.visitCall(rexCall);
            this.includeUnknown = outerIncludeUnknown;
            return visited;
        }

        /**
         * Marks every operand of the OR as considered, then tries to collect lookup comparisons across them.
         * If collection returns a different list, the result is the disjunction of that list; otherwise the
         * default traversal is used.
         */
        private RexNode visitOr(RexCall rexCall) {
            final List<RexNode> disjuncts = rexCall.getOperands();
            this.consideredAsChild.addAll(disjuncts);

            final List<RexNode> collected = new CollectReverseLookups(disjuncts, this.rexBuilder).collect();
            // Identity comparison: an unchanged result is signalled by getting the same list instance back.
            if (collected == disjuncts) {
                return super.visitCall(rexCall);
            }
            return RexUtil.composeDisjunction(this.rexBuilder, collected);
        }

        /**
         * Rewrites the negated operands of an AND as {@code NOT(OR(...))} so they can be collected like an OR.
         *
         * <p>Operands are split into two groups. {@code NOT(x)}, {@code a <> b} and {@code x IS NOT NULL} contribute
         * their positive forms ({@code x}, {@code a = b}, {@code x IS NULL}) to the negated group; all other
         * operands are kept as they are. The negated group is collected with {@code includeUnknown} inverted. If
         * collection changed anything, the result is {@code NOT(OR(collected))}, ANDed after the remaining
         * operands when there are any. Otherwise the default traversal is used.
         *
         * @param rexCall the AND call
         * @return the rewritten expression, or the result of the default traversal
         */
        private RexNode visitAnd(RexCall rexCall) {
            final List<RexNode> negatedOperands = new ArrayList<>();
            final List<RexNode> otherOperands = new ArrayList<>();

            // Operands are arbitrary RexNodes (input refs, literals, ...), so iterate as RexNode and narrow to
            // RexCall only inside branches whose kind check guarantees a call.
            for (RexNode operand : rexCall.getOperands()) {
                if (operand.isA(SqlKind.NOT)) {
                    final RexNode negated = Iterables.getOnlyElement(((RexCall) operand).getOperands());
                    this.consideredAsChild.add(negated);
                    negatedOperands.add(negated);
                } else if (operand.isA(SqlKind.NOT_EQUALS)) {
                    this.consideredAsChild.add(operand);
                    negatedOperands.add(
                            this.rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, ((RexCall) operand).getOperands()));
                } else if (operand.isA(SqlKind.IS_NOT_NULL)) {
                    this.consideredAsChild.add(operand);
                    negatedOperands.add(
                            this.rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, ((RexCall) operand).getOperands()));
                } else {
                    otherOperands.add(operand);
                }
            }

            if (!negatedOperands.isEmpty()) {
                // The collected operands sit under an implicit NOT, so flip the flag while collecting.
                this.includeUnknown = !this.includeUnknown;
                final List<RexNode> collected = new CollectReverseLookups(negatedOperands, this.rexBuilder).collect();
                this.includeUnknown = !this.includeUnknown;

                // Identity comparison: an unchanged result is signalled by getting the same list instance back.
                if (collected != negatedOperands) {
                    RexNode rewritten = this.rexBuilder.makeCall(
                            SqlStdOperatorTable.NOT,
                            RexUtil.composeDisjunction(this.rexBuilder, collected));
                    if (!otherOperands.isEmpty()) {
                        otherOperands.add(rewritten);
                        rewritten = this.rexBuilder.makeCall(SqlStdOperatorTable.AND, otherOperands);
                    }
                    return rewritten;
                }
            }
            return super.visitCall(rexCall);
        }

        /**
         * Expands a SEARCH call and runs the expanded form through this shuttle. The rewritten form is returned
         * only when the shuttle changed the expansion; otherwise the original SEARCH call is kept.
         */
        private RexNode visitSearch(RexCall rexCall) {
            final RexNode expanded = SearchOperatorConversion.expandSearch(
                    rexCall,
                    this.rexBuilder,
                    this.plannerContext.queryContext().getInFunctionThreshold());

            if (expanded instanceof RexCall) {
                final RexNode rewritten = m229visitCall((RexCall) expanded);
                if (rewritten != expanded) {
                    return rewritten;
                }
            }
            return rexCall;
        }

        /**
         * Runs a single comparison through {@code CollectReverseLookups}. Returns the collected node when it
         * differs from the input, otherwise falls back to the default traversal.
         *
         * @throws ISE when collection does not yield exactly one node
         */
        private RexNode visitComparison(RexCall rexCall) {
            final List<RexNode> collected =
                    new CollectReverseLookups(Collections.<RexNode>singletonList(rexCall), this.rexBuilder).collect();
            final RexNode onlyNode = CollectionUtils.getOnlyElement(
                    collected,
                    nodes -> new ISE("Expected to collect single node, got[%s]", new Object[]{nodes}));

            if (onlyNode == rexCall) {
                return super.visitCall(rexCall);
            }
            return onlyNode;
        }
    }

    /**
     * Creates a rule that matches any {@link LogicalFilter}.
     *
     * @param plannerContext context later used to resolve lookups and read query-context settings
     */
    public ReverseLookupRule(PlannerContext plannerContext) {
        super(operand(LogicalFilter.class, any()));
        this.plannerContext = plannerContext;
    }

    /**
     * Runs the filter condition through a {@link ReverseLookupShuttle}. When the condition changes, the filter is
     * replaced by a new one over the same input and the old filter is pruned from the planner.
     *
     * <p>The shuttle's optimize cap comes from the {@code CTX_MAX_OPTIMIZE_COUNT} context key (default
     * {@code Integer.MAX_VALUE}); its max IN size is the smaller of the in-subquery threshold and the
     * {@code CTX_THRESHOLD} context key (default {@code DEFAULT_THRESHOLD}).
     */
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        final Filter filter = relOptRuleCall.rel(0);

        final int maxOptimizeCount = this.plannerContext.queryContext().getInt(CTX_MAX_OPTIMIZE_COUNT, Integer.MAX_VALUE);
        final int maxInSize = Math.min(
                this.plannerContext.queryContext().getInSubQueryThreshold(),
                this.plannerContext.queryContext().getInt(CTX_THRESHOLD, DEFAULT_THRESHOLD));
        final ReverseLookupShuttle shuttle = new ReverseLookupShuttle(
                this.plannerContext,
                filter.getCluster().getRexBuilder(),
                maxOptimizeCount,
                maxInSize);

        final RexNode rewritten = (RexNode) filter.getCondition().accept(shuttle);
        if (rewritten == filter.getCondition()) {
            // Nothing was rewritten; leave the plan alone.
            return;
        }

        relOptRuleCall.transformTo(
                relOptRuleCall.builder()
                        .push(filter.getInput())
                        .filter(new RexNode[]{rewritten})
                        .build());
        relOptRuleCall.getPlanner().prune(filter);
    }

    /**
     * Converts each string to a literal, in iteration order. A null string becomes a null literal of VARCHAR type.
     *
     * @return a new mutable list of literals
     */
    private static List<RexNode> stringsToRexNodes(Iterable<String> iterable, RexBuilder rexBuilder) {
        final ArrayList<RexNode> literals = new ArrayList<>();
        for (String value : iterable) {
            if (value == null) {
                literals.add(rexBuilder.makeNullLiteral(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR)));
            } else {
                literals.add(rexBuilder.makeLiteral(value));
            }
        }
        return literals;
    }

    /**
     * Returns true when {@code rexNode} is a call of kind EQUALS or NOT_EQUALS, or a call to the MV_CONTAINS,
     * MV_OVERLAP or SCALAR_IN_ARRAY operator.
     */
    private static boolean isBinaryComparison(RexNode rexNode) {
        if (!(rexNode instanceof RexCall)) {
            return false;
        }
        final RexCall call = (RexCall) rexNode;

        final SqlKind kind = call.getKind();
        if (kind == SqlKind.EQUALS || kind == SqlKind.NOT_EQUALS) {
            return true;
        }
        if (call.getOperator().equals(MultiValueStringOperatorConversions.CONTAINS.mo59calciteOperator())) {
            return true;
        }
        if (call.getOperator().equals(MultiValueStringOperatorConversions.OVERLAP.mo59calciteOperator())) {
            return true;
        }
        return call.getOperator().equals(ScalarInArrayOperatorConversion.SQL_FUNCTION);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Returns true when {@code rexNode} is of kind OTHER_FUNCTION and its operator is the LOOKUP SQL function.
     */
    public static boolean isLookupCall(RexNode rexNode) {
        if (!rexNode.isA(SqlKind.OTHER_FUNCTION)) {
            return false;
        }
        final RexCall call = (RexCall) rexNode;
        return call.getOperator().equals(QueryLookupOperatorConversion.SQL_FUNCTION);
    }

    @Nullable
    /**
     * Extracts the string values of a literal operand.
     *
     * <ul>
     *   <li>Null literal: a singleton null when {@code z} is true, otherwise an empty set.</li>
     *   <li>String-typed node: a singleton of its string value; a null value is kept only when {@code z} is
     *       true.</li>
     *   <li>Array of strings: the string value of each operand; null values are kept only when {@code z} is
     *       true.</li>
     * </ul>
     *
     * @param rexNode literal operand
     * @param z       whether null values are retained in the result
     * @return the set of values, or null when the node is neither a string nor an array of strings
     */
    @Nullable
    private static Set<String> toStringSet(RexNode rexNode, boolean z) {
        if (RexUtil.isNullLiteral(rexNode, true)) {
            if (z) {
                return Collections.singleton(null);
            }
            return Collections.emptySet();
        }

        if (SqlTypeFamily.STRING.contains(rexNode.getType())) {
            final String single = RexLiteral.stringValue(rexNode);
            if (single == null && !z) {
                return Collections.emptySet();
            }
            return Collections.singleton(single);
        }

        final boolean stringArray = rexNode.getType().getSqlTypeName() == SqlTypeName.ARRAY
                && SqlTypeFamily.STRING.contains(rexNode.getType().getComponentType());
        if (!stringArray) {
            return null;
        }

        final Set<String> values = new HashSet<>();
        for (RexNode element : ((RexCall) rexNode).getOperands()) {
            final String elementValue = RexLiteral.stringValue(element);
            if (elementValue != null || z) {
                values.add(elementValue);
            }
        }
        return values;
    }
}
