package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BooleanType;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushFilterThroughBoolOrAggregation.class */
/**
 * Optimizer rules that push a filter on the result of a grouped {@code bool_or} aggregation
 * below that aggregation.
 *
 * <p>The rules match a plan of the shape
 * <pre>
 * Filter(predicate)
 *   [Project (identity assignments only)]
 *     Aggregation(GROUP BY k; s := bool_or(x))   -- single grouping set, SINGLE step
 * </pre>
 * where {@code predicate} constrains {@code s} either through a tuple domain of exactly
 * {@code s = true} or through a conjunct {@code coalesce(s, false)}. Such a plan is rewritten to
 * <pre>
 * [Filter(predicate without the constraint on s)]
 *   [Project (identity assignments only)]
 *     Project(k..., s := TRUE)
 *       Aggregation(GROUP BY k)                  -- no aggregate functions left
 *         Filter(x)
 * </pre>
 * The outer filter is dropped when the leftover predicate is {@code TRUE}. A predicate whose
 * tuple domain is "none" is replaced by an empty {@link ValuesNode}.
 */
public class PushFilterThroughBoolOrAggregation {
    private static final CatalogSchemaFunctionName BOOL_OR = GlobalFunctionCatalog.builtinFunctionName("bool_or");

    private final PlannerContext plannerContext;

    public PushFilterThroughBoolOrAggregation(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /**
     * Returns both variants of the rule: with and without an identity projection between the
     * filter and the aggregation.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new PushFilterThroughBoolOrAggregationWithoutProject(plannerContext),
                new PushFilterThroughBoolOrAggregationWithProject(plannerContext));
    }

    /**
     * Matches {@code Filter -> identity Project -> grouped bool_or Aggregation}.
     */
    @VisibleForTesting
    public static final class PushFilterThroughBoolOrAggregationWithProject implements Rule<FilterNode> {
        private static final Capture<ProjectNode> PROJECT = Capture.newCapture();
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();

        private final PlannerContext plannerContext;
        private final Pattern<FilterNode> pattern = Patterns.filter()
                .with(Patterns.source().matching(Patterns.project()
                        .matching(ProjectNode::isIdentity)
                        .capturedAs(PROJECT)
                        .with(Patterns.source().matching(Patterns.aggregation()
                                .matching(PushFilterThroughBoolOrAggregation::isGroupedBoolOr)
                                .capturedAs(AGGREGATION)))));

        public PushFilterThroughBoolOrAggregationWithProject(PlannerContext plannerContext) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return pattern;
        }

        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            return pushFilter(filterNode, captures.get(AGGREGATION), Optional.of(captures.get(PROJECT)), plannerContext, context);
        }
    }

    /**
     * Matches {@code Filter -> grouped bool_or Aggregation} directly.
     */
    @VisibleForTesting
    public static final class PushFilterThroughBoolOrAggregationWithoutProject implements Rule<FilterNode> {
        private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();

        private final PlannerContext plannerContext;
        private final Pattern<FilterNode> pattern = Patterns.filter()
                .with(Patterns.source().matching(Patterns.aggregation()
                        .matching(PushFilterThroughBoolOrAggregation::isGroupedBoolOr)
                        .capturedAs(AGGREGATION)));

        public PushFilterThroughBoolOrAggregationWithoutProject(PlannerContext plannerContext) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return pattern;
        }

        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            return pushFilter(filterNode, captures.get(AGGREGATION), Optional.empty(), plannerContext, context);
        }
    }

    /**
     * Shared implementation of both rule variants.
     *
     * @param filterNode the matched filter above the aggregation
     * @param aggregationNode the matched grouped {@code bool_or} aggregation
     * @param projectNode the identity projection between filter and aggregation, if any
     * @param plannerContext planner context used for domain extraction and translation
     * @param context rule context providing the session and the plan-node id allocator
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the filter does not
     *         constrain the {@code bool_or} output in a supported way
     */
    private static Rule.Result pushFilter(FilterNode filterNode, AggregationNode aggregationNode, Optional<ProjectNode> projectNode, PlannerContext plannerContext, Rule.Context context) {
        Symbol boolOrSymbol = Iterables.getOnlyElement(aggregationNode.getAggregations().keySet());
        AggregationNode.Aggregation boolOrAggregation = Iterables.getOnlyElement(aggregationNode.getAggregations().values());

        DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, context.getSession(), filterNode.getPredicate());
        TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();
        Expression remainingExpression = extractionResult.getRemainingExpression();

        if (tupleDomain.isNone()) {
            // The predicate can never be satisfied: replace the subtree with an empty relation.
            return Rule.Result.ofPlanNode(new ValuesNode(filterNode.getId(), filterNode.getOutputSymbols()));
        }

        // Among the non-domain conjuncts that mention the bool_or output, only
        // coalesce(boolOrSymbol, false) is supported; any other reference blocks the rewrite.
        List<Expression> remainingConjuncts = IrUtils.extractConjuncts(remainingExpression);
        Map<Boolean, List<Expression>> conjunctsOnBoolOr = remainingConjuncts.stream()
                .filter(conjunct -> SymbolsExtractor.extractUnique(conjunct).contains(boolOrSymbol))
                .collect(Collectors.partitioningBy(conjunct -> isSupportedCoalesce(conjunct, boolOrSymbol)));
        // partitioningBy always populates both keys, so neither lookup below returns null.
        if (!conjunctsOnBoolOr.get(Boolean.FALSE).isEmpty()) {
            return Rule.Result.empty();
        }
        Optional<Expression> coalesceConjunct = conjunctsOnBoolOr.get(Boolean.TRUE).stream().findFirst();

        // getDomains() is present here because the "none" case returned above.
        Optional<Domain> boolOrDomain = Optional.ofNullable(tupleDomain.getDomains().orElseThrow().get(boolOrSymbol));
        if (boolOrDomain.isPresent() && !boolOrDomain.get().equals(Domain.singleValue(BooleanType.BOOLEAN, true))) {
            // Only an exact "bool_or = true" constraint can be pushed below the aggregation.
            return Rule.Result.empty();
        }
        if (coalesceConjunct.isEmpty() && boolOrDomain.isEmpty()) {
            // The filter does not constrain the bool_or output at all.
            return Rule.Result.empty();
        }

        // Keep only rows whose bool_or argument is true, then group without aggregates.
        // Every surviving group therefore has bool_or = TRUE.
        AggregationNode distinctAggregation = AggregationNode.builderFrom(aggregationNode)
                .setSource(new FilterNode(context.getIdAllocator().getNextId(), aggregationNode.getSource(), boolOrAggregation.getArguments().getFirst()))
                .setAggregations(ImmutableMap.of())
                .build();
        ProjectNode boolOrProjection = new ProjectNode(
                context.getIdAllocator().getNextId(),
                distinctAggregation,
                Assignments.builder()
                        .putIdentities(distinctAggregation.getOutputSymbols())
                        .put(boolOrSymbol, Booleans.TRUE)
                        .build());
        PlanNode rewrittenSource = projectNode
                .map(project -> project.replaceChildren(ImmutableList.of(boolOrProjection)))
                .orElse(boolOrProjection);

        // Drop every coalesce(boolOrSymbol, false) conjunct; the rewrite now enforces it.
        Expression unpushedExpression = remainingExpression;
        if (coalesceConjunct.isPresent()) {
            Expression pushedCoalesce = coalesceConjunct.get();
            unpushedExpression = IrUtils.combineConjuncts(remainingConjuncts.stream()
                    .filter(conjunct -> !conjunct.equals(pushedCoalesce))
                    .toList());
        }

        // Rebuild the leftover predicate from the other symbols' domains plus the remaining conjuncts.
        Expression newPredicate = IrUtils.combineConjuncts(
                new DomainTranslator(plannerContext.getMetadata()).toPredicate(tupleDomain.filter((symbol, domain) -> !symbol.equals(boolOrSymbol))),
                unpushedExpression);
        if (newPredicate.equals(Booleans.TRUE)) {
            return Rule.Result.ofPlanNode(rewrittenSource);
        }
        return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), rewrittenSource, newPredicate));
    }

    /**
     * Returns whether {@code expression} is exactly {@code coalesce(symbol, false)}.
     */
    public static boolean isSupportedCoalesce(Expression expression, Symbol symbol) {
        if (!(expression instanceof Coalesce coalesce) || coalesce.operands().size() != 2) {
            return false;
        }
        return coalesce.operands().getFirst().equals(symbol.toSymbolReference())
                && coalesce.operands().getLast().equals(Booleans.FALSE);
    }

    /**
     * Returns whether the aggregation is a single-step, single-grouping-set, non-empty grouped
     * aggregation whose only aggregate is an unfiltered, unmasked builtin {@code bool_or} over a
     * plain symbol reference.
     */
    public static boolean isGroupedBoolOr(AggregationNode aggregationNode) {
        if (!isGroupedAggregation(aggregationNode) || aggregationNode.getAggregations().size() != 1) {
            return false;
        }
        AggregationNode.Aggregation aggregation = Iterables.getOnlyElement(aggregationNode.getAggregations().values());
        return aggregation.getFilter().isEmpty()
                && aggregation.getMask().isEmpty()
                && aggregation.getResolvedFunction().name().equals(BOOL_OR)
                && aggregation.getArguments().getFirst() instanceof Reference;
    }

    /**
     * Returns whether the aggregation has exactly one, non-empty grouping set and runs in a
     * single step.
     */
    private static boolean isGroupedAggregation(AggregationNode aggregationNode) {
        return aggregationNode.hasNonEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }
}
