package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.Metadata;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.In;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.optimizer.IrExpressionOptimizer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.type.DateTimes;

import java.time.DateTimeException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.temporal.TemporalAdjusters;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/UnwrapYearInComparison.class */
public class UnwrapYearInComparison extends ExpressionRewriteRuleSet {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/sql/planner/iterative/rule/UnwrapYearInComparison$Visitor.class */
    public static class Visitor extends ExpressionRewriter<Void> {
        private final IrExpressionOptimizer optimizer;
        private final Metadata metadata;
        private final Session session;

        public Visitor(PlannerContext plannerContext, Session session) {
            this.optimizer = IrExpressionOptimizer.newOptimizer(plannerContext);
            this.metadata = plannerContext.getMetadata();
            this.session = (Session) Objects.requireNonNull(session, "session is null");
        }

        @Override // io.trino.sql.ir.ExpressionRewriter
        public Expression rewriteComparison(Comparison comparison, Void r6, ExpressionTreeRewriter<Void> expressionTreeRewriter) {
            return unwrapYear((Comparison) expressionTreeRewriter.defaultRewrite(comparison, null));
        }

        @Override // io.trino.sql.ir.ExpressionRewriter
        public Expression rewriteIn(In in, Void r8, ExpressionTreeRewriter<Void> expressionTreeRewriter) {
            In in2 = (In) expressionTreeRewriter.defaultRewrite(in, null);
            Expression value = in2.value();
            if (value instanceof Call) {
                Call call = (Call) value;
                if (call.function().name().equals(GlobalFunctionCatalog.builtinFunctionName("year")) && call.arguments().size() == 1) {
                    ImmutableList.Builder builderWithExpectedSize = ImmutableList.builderWithExpectedSize(in.valueList().size());
                    Iterator<Expression> it = in.valueList().iterator();
                    while (it.hasNext()) {
                        Comparison comparison = new Comparison(Comparison.Operator.EQUAL, value, it.next());
                        Expression unwrapYear = unwrapYear(comparison);
                        if (unwrapYear == comparison) {
                            return in2;
                        }
                        builderWithExpectedSize.add(unwrapYear);
                    }
                    return IrUtils.or((Collection<Expression>) builderWithExpectedSize.build());
                }
            }
            return in2;
        }

        private Expression unwrapYear(Comparison comparison) {
            Expression left = comparison.left();
            if (left instanceof Call) {
                Call call = (Call) left;
                if (call.function().name().equals(GlobalFunctionCatalog.builtinFunctionName("year")) && call.arguments().size() == 1) {
                    Expression expression = (Expression) Iterables.getOnlyElement(call.arguments());
                    DateType type = expression.type();
                    Expression orElse = this.optimizer.process(comparison.right(), this.session, (Map<Symbol, Expression>) ImmutableMap.of()).orElse(comparison.right());
                    if ((orElse instanceof Constant) && ((Constant) orElse).value() == null) {
                        switch (comparison.operator()) {
                            case EQUAL:
                            case NOT_EQUAL:
                            case LESS_THAN:
                            case LESS_THAN_OR_EQUAL:
                            case GREATER_THAN:
                            case GREATER_THAN_OR_EQUAL:
                                return new Constant(BooleanType.BOOLEAN, null);
                            case IDENTICAL:
                                return new IsNull(expression);
                            default:
                                throw new MatchException((String) null, (Throwable) null);
                        }
                    }
                    if (!(orElse instanceof Constant)) {
                        return comparison;
                    }
                    Constant constant = (Constant) orElse;
                    try {
                        constant.type();
                        Object value = constant.value();
                        if (type instanceof TimestampWithTimeZoneType) {
                            return comparison;
                        }
                        if (type != DateType.DATE && !(type instanceof TimestampType)) {
                            return comparison;
                        }
                        int intExact = Math.toIntExact(((Long) value).longValue());
                        switch (comparison.operator()) {
                            case EQUAL:
                                return between(expression, type, UnwrapYearInComparison.calculateRangeStartInclusive(intExact, type), UnwrapYearInComparison.calculateRangeEndInclusive(intExact, type));
                            case NOT_EQUAL:
                                return IrExpressions.not(this.metadata, between(expression, type, UnwrapYearInComparison.calculateRangeStartInclusive(intExact, type), UnwrapYearInComparison.calculateRangeEndInclusive(intExact, type)));
                            case LESS_THAN:
                                return new Comparison(Comparison.Operator.LESS_THAN, expression, new Constant(type, UnwrapYearInComparison.calculateRangeStartInclusive(intExact, type)));
                            case LESS_THAN_OR_EQUAL:
                                return new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, expression, new Constant(type, UnwrapYearInComparison.calculateRangeEndInclusive(intExact, type)));
                            case GREATER_THAN:
                                return new Comparison(Comparison.Operator.GREATER_THAN, expression, new Constant(type, UnwrapYearInComparison.calculateRangeEndInclusive(intExact, type)));
                            case GREATER_THAN_OR_EQUAL:
                                return new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, expression, new Constant(type, UnwrapYearInComparison.calculateRangeStartInclusive(intExact, type)));
                            case IDENTICAL:
                                return Logical.and(IrExpressions.not(this.metadata, new IsNull(expression)), between(expression, type, UnwrapYearInComparison.calculateRangeStartInclusive(intExact, type), UnwrapYearInComparison.calculateRangeEndInclusive(intExact, type)));
                            default:
                                throw new MatchException((String) null, (Throwable) null);
                        }
                    } catch (Throwable th) {
                        throw new MatchException(th.toString(), th);
                    }
                }
            }
            return comparison;
        }

        private Between between(Expression expression, Type type, Object obj, Object obj2) {
            return new Between(expression, new Constant(type, obj), new Constant(type, obj2));
        }
    }

    public UnwrapYearInComparison(PlannerContext plannerContext) {
        super(createRewrite(plannerContext));
    }

    private static ExpressionRewriteRuleSet.ExpressionRewriter createRewrite(PlannerContext plannerContext) {
        Objects.requireNonNull(plannerContext, "plannerContext is null");
        return (expression, context) -> {
            return unwrapYear(context.getSession(), plannerContext, expression);
        };
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static Expression unwrapYear(Session session, PlannerContext plannerContext, Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, session), expression);
    }

    private static Object calculateRangeStartInclusive(int i, Type type) {
        if (type == DateType.DATE) {
            return Long.valueOf(LocalDate.ofYearDay(i, 1).toEpochDay());
        }
        if (!(type instanceof TimestampType)) {
            throw new UnsupportedOperationException("Unsupported type: " + String.valueOf(type));
        }
        long multiplyExact = Math.multiplyExact(LocalDateTime.of(i, 1, 1, 0, 0).toEpochSecond(ZoneOffset.UTC), 1000000);
        return ((TimestampType) type).isShort() ? Long.valueOf(multiplyExact) : new LongTimestamp(multiplyExact, 0);
    }

    @VisibleForTesting
    public static Object calculateRangeEndInclusive(int i, Type type) {
        if (type == DateType.DATE) {
            return Long.valueOf(LocalDate.ofYearDay(i, 1).with(TemporalAdjusters.lastDayOfYear()).toEpochDay());
        }
        if (!(type instanceof TimestampType)) {
            throw new UnsupportedOperationException("Unsupported type: " + String.valueOf(type));
        }
        TimestampType timestampType = (TimestampType) type;
        long multiplyExact = Math.multiplyExact(LocalDateTime.of(i + 1, 1, 1, 0, 0).toEpochSecond(ZoneOffset.UTC), 1000000);
        if (timestampType.isShort()) {
            return Long.valueOf(multiplyExact - DateTimes.scaleFactor(timestampType.getPrecision(), 6));
        }
        return new LongTimestamp(multiplyExact - 1, Math.toIntExact(1000000 - DateTimes.scaleFactor(timestampType.getPrecision(), 12)));
    }
}
