package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.NullIf;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/SimplifyFilterPredicate.class */
/**
 * Rewrites the conjuncts of a {@link FilterNode} predicate whose top-level form is
 * {@code NULLIF}, a searched {@code CASE} ({@link Case}) or a simple {@code CASE}
 * ({@link Switch}) into simpler boolean expressions.
 *
 * <p>The rewrites treat {@code FALSE} and a {@code NULL} constant interchangeably as
 * "not true" (see {@link #isNotTrue}), which is valid because the result is only used as
 * a filter predicate, where both reject the row.
 */
public class SimplifyFilterPredicate implements Rule<FilterNode> {
    private static final Pattern<FilterNode> PATTERN = Patterns.filter();

    private final Metadata metadata;

    /**
     * @param metadata used when building {@code NOT} expressions in {@link #isFalseOrNullPredicate}
     */
    public SimplifyFilterPredicate(Metadata metadata) {
        this.metadata = metadata;
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Simplifies each conjunct of the filter predicate independently.
     *
     * @return {@link Rule.Result#empty()} when no conjunct could be simplified; otherwise a new
     *     {@link FilterNode} with the same id and source whose predicate is the recombined
     *     conjuncts, with a predicate that collapses to a {@code NULL} constant replaced by
     *     {@code FALSE}
     */
    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        ImmutableList.Builder<Expression> newConjuncts = ImmutableList.builder();
        boolean simplified = false;
        for (Expression conjunct : IrUtils.extractConjuncts(filterNode.getPredicate())) {
            Optional<Expression> rewritten = simplifyConjunct(conjunct);
            if (rewritten.isPresent()) {
                simplified = true;
                newConjuncts.add(rewritten.get());
            } else {
                newConjuncts.add(conjunct);
            }
        }
        if (!simplified) {
            return Rule.Result.empty();
        }

        Expression predicate = IrUtils.combineConjuncts(newConjuncts.build());
        // A NULL predicate rejects every row just like FALSE; emit the canonical constant.
        if (predicate instanceof Constant constant && constant.value() == null) {
            predicate = Booleans.FALSE;
        }
        return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), filterNode.getSource(), predicate));
    }

    /** Dispatches one conjunct to the matching rewrite; an empty result means "leave unchanged". */
    private Optional<Expression> simplifyConjunct(Expression conjunct) {
        return switch (conjunct) {
            // NULLIF(a, b) is rewritten as: a AND (b IS NULL OR NOT b)
            case NullIf nullIf -> Optional.of(
                    Logical.and(nullIf.first(), isFalseOrNullPredicate(nullIf.second())));
            case Case caseExpression -> simplify(caseExpression);
            case Switch switchExpression -> simplify(switchExpression);
            default -> Optional.empty();
        };
    }

    /**
     * Simplifies an if-like conditional
     * {@code CASE WHEN condition THEN trueValue ELSE falseValue END} used as a filter.
     *
     * @return the simplified expression, or empty when none of the rewrites applies
     */
    private Optional<Expression> simplify(Expression condition, Expression trueValue, Expression falseValue) {
        if (trueValue.equals(Booleans.TRUE) && isNotTrue(falseValue)) {
            return Optional.of(condition);
        }
        if (isNotTrue(trueValue) && falseValue.equals(Booleans.TRUE)) {
            return Optional.of(isFalseOrNullPredicate(condition));
        }
        // Both branches are the same value, so the condition does not affect the result;
        // only done for deterministic values.
        if (falseValue.equals(trueValue) && DeterminismEvaluator.isDeterministic(trueValue)) {
            return Optional.of(trueValue);
        }
        if (isNotTrue(trueValue) && isNotTrue(falseValue)) {
            return Optional.of(Booleans.FALSE);
        }
        if (condition.equals(Booleans.TRUE)) {
            return Optional.of(trueValue);
        }
        if (isNotTrue(condition)) {
            return Optional.of(falseValue);
        }
        return Optional.empty();
    }

    /**
     * Simplifies a searched {@code CASE} used as a filter.
     *
     * @return the simplified expression, or empty when the expression cannot be simplified
     */
    private Optional<Expression> simplify(Case caseExpression) {
        List<WhenClause> clauses = caseExpression.whenClauses();
        Expression defaultValue = caseExpression.defaultValue();

        if (clauses.size() == 1) {
            // A single WHEN clause is just an if/else.
            WhenClause only = clauses.getFirst();
            return simplify(only.getOperand(), only.getResult(), defaultValue);
        }

        List<Expression> operands = clauses.stream()
                .map(WhenClause::getOperand)
                .collect(ImmutableList.toImmutableList());
        List<Expression> results = clauses.stream()
                .map(WhenClause::getResult)
                .collect(ImmutableList.toImmutableList());
        long trueResultCount = results.stream()
                .filter(result -> result.equals(Booleans.TRUE))
                .count();
        long notTrueResultCount = results.stream()
                .filter(SimplifyFilterPredicate::isNotTrue)
                .count();

        // Every branch, including the default, yields TRUE.
        if (trueResultCount == results.size() && defaultValue.equals(Booleans.TRUE)) {
            return Optional.of(Booleans.TRUE);
        }
        // No branch, including the default, can yield TRUE.
        if (notTrueResultCount == results.size() && isNotTrue(defaultValue)) {
            return Optional.of(Booleans.FALSE);
        }
        // Exactly one branch yields TRUE; the others and the default are not true. The row
        // passes iff every preceding operand is not true and that branch's operand is true.
        if (trueResultCount == 1 && notTrueResultCount == results.size() - 1 && isNotTrue(defaultValue)) {
            ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
            for (WhenClause clause : clauses) {
                Expression operand = clause.getOperand();
                if (!isNotTrue(clause.getResult())) {
                    conjuncts.add(operand);
                    return Optional.of(IrUtils.combineConjuncts(conjuncts.build()));
                }
                conjuncts.add(isFalseOrNullPredicate(operand));
            }
        }
        // No branch yields TRUE but the default does: the row passes iff no operand is true.
        if (notTrueResultCount == results.size() && defaultValue.equals(Booleans.TRUE)) {
            List<Expression> conjuncts = operands.stream()
                    .map(this::isFalseOrNullPredicate)
                    .collect(ImmutableList.toImmutableList());
            return Optional.of(IrUtils.combineConjuncts(conjuncts));
        }

        // Drop clauses whose operand is a not-true constant; a TRUE operand makes every later
        // clause (and the default) unreachable.
        List<WhenClause> remaining = new ArrayList<>();
        for (WhenClause clause : clauses) {
            Expression operand = clause.getOperand();
            if (operand.equals(Booleans.TRUE)) {
                if (remaining.isEmpty()) {
                    return Optional.of(clause.getResult());
                }
                return Optional.of(new Case(remaining, clause.getResult()));
            }
            if (!isNotTrue(operand)) {
                remaining.add(clause);
            }
        }
        if (remaining.isEmpty()) {
            return Optional.of(defaultValue);
        }
        if (remaining.size() < clauses.size()) {
            return Optional.of(new Case(remaining, defaultValue));
        }
        return Optional.empty();
    }

    /**
     * Simplifies a simple {@code CASE operand WHEN ...} ({@link Switch}) used as a filter.
     *
     * @return the simplified expression, or empty when the expression cannot be simplified
     */
    private static Optional<Expression> simplify(Switch switchExpression) {
        Expression defaultValue = switchExpression.defaultValue();

        // A NULL operand does not compare equal to any WHEN value, so only the default applies.
        if (switchExpression.operand() instanceof Constant constant && constant.value() == null) {
            return Optional.of(defaultValue);
        }

        List<Expression> results = switchExpression.whenClauses().stream()
                .map(WhenClause::getResult)
                .collect(ImmutableList.toImmutableList());
        if (results.stream().allMatch(result -> result.equals(Booleans.TRUE)) && defaultValue.equals(Booleans.TRUE)) {
            return Optional.of(Booleans.TRUE);
        }
        if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && isNotTrue(defaultValue)) {
            return Optional.of(Booleans.FALSE);
        }
        return Optional.empty();
    }

    /** Returns whether {@code expression} is the {@code FALSE} constant or a {@code NULL} constant. */
    private static boolean isNotTrue(Expression expression) {
        return expression.equals(Booleans.FALSE)
                || (expression instanceof Constant constant && constant.value() == null);
    }

    /** Builds {@code expression IS NULL OR NOT expression}, i.e. "expression is not true". */
    private Expression isFalseOrNullPredicate(Expression expression) {
        return Logical.or(new IsNull(expression), IrExpressions.not(this.metadata, expression));
    }
}
