package io.trino.sql.ir;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Rewrites an IR {@link Expression} tree, giving a caller-supplied {@link ExpressionRewriter}
 * the first chance to replace each node.
 *
 * <p>For each visited node, the {@link ExpressionRewriter} hook for that node type is consulted
 * first. The one exception is the root node of a {@link #defaultRewrite} call, for which the hook
 * is skipped. A non-null hook result replaces the node verbatim, and its children are not visited
 * by this class. A {@code null} result falls back to the default behavior: every child is
 * rewritten (through {@link #rewrite}, so hooks apply to children again), and a new node is built
 * only if at least one child differs by identity. Unchanged subtrees are therefore returned as the
 * very same instances.
 *
 * @param <C> type of the caller-defined context value handed to every rewriter hook
 */
public final class ExpressionTreeRewriter<C> {
    private final ExpressionRewriter<C> rewriter;
    private final IrVisitor<Expression, Context<C>> visitor = new RewritingVisitor();

    /**
     * Creates a tree rewriter that delegates node-level decisions to {@code rewriter}.
     *
     * @param rewriter hooks consulted for each node before the default child-by-child rewrite
     */
    public ExpressionTreeRewriter(ExpressionRewriter<C> rewriter) {
        this.rewriter = rewriter;
    }

    /**
     * Convenience entry point for rewriters that need no context; passes {@code null} as the context.
     *
     * @param rewriter hooks to apply
     * @param node root of the tree to rewrite
     * @return the rewritten tree, or {@code node} itself if nothing changed
     */
    public static <T extends Expression> T rewriteWith(ExpressionRewriter<Void> rewriter, T node) {
        return new ExpressionTreeRewriter<>(rewriter).rewrite(node, null);
    }

    /**
     * Convenience entry point that builds a one-off tree rewriter and applies it to {@code node}.
     *
     * @param rewriter hooks to apply
     * @param node root of the tree to rewrite
     * @param context caller-defined value forwarded to every hook invocation
     * @return the rewritten tree, or {@code node} itself if nothing changed
     */
    public static <C, T extends Expression> T rewriteWith(ExpressionRewriter<C> rewriter, T node, C context) {
        return new ExpressionTreeRewriter<>(rewriter).rewrite(node, context);
    }

    /**
     * Rewrites {@code node}, consulting the rewriter hook for the node itself and for all descendants.
     *
     * <p>The result is cast back to {@code T}. If a hook returns a node of a different type, the
     * caller may observe a {@link ClassCastException} at the use site.
     */
    @SuppressWarnings("unchecked")
    public <T extends Expression> T rewrite(T node, C context) {
        return (T) visitor.process(node, new Context<>(context, false));
    }

    /**
     * Rewrites {@code node} using only the built-in child-by-child logic for the node itself.
     *
     * <p>The rewriter hook is skipped for {@code node}, but its children are still processed via
     * {@link #rewrite}, so hooks do apply to them. This lets a hook delegate to the default
     * behavior without recursing into itself.
     */
    @SuppressWarnings("unchecked")
    public <T extends Expression> T defaultRewrite(T node, C context) {
        return (T) visitor.process(node, new Context<>(context, true));
    }

    /** Rewrites each element of {@code items} in order, returning the results as an immutable list. */
    private List<Expression> rewrite(List<Expression> items, Context<C> context) {
        ImmutableList.Builder<Expression> rewritten = ImmutableList.builder();
        for (Expression item : items) {
            rewritten.add(rewrite(item, context.get()));
        }
        return rewritten.build();
    }

    /** Returns true when both optionals are empty, or both hold the identical (==) instance. */
    private static <T> boolean sameElements(Optional<T> left, Optional<T> right) {
        if (left.isPresent() != right.isPresent()) {
            return false;
        }
        return left.isEmpty() || left.get() == right.get();
    }

    /** Returns true when both iterables have the same size and pairwise-identical (==) elements. */
    private static <T> boolean sameElements(Iterable<? extends T> left, Iterable<? extends T> right) {
        if (Iterables.size(left) != Iterables.size(right)) {
            return false;
        }
        Iterator<? extends T> leftIterator = left.iterator();
        Iterator<? extends T> rightIterator = right.iterator();
        while (leftIterator.hasNext() && rightIterator.hasNext()) {
            if (leftIterator.next() != rightIterator.next()) {
                return false;
            }
        }
        return true;
    }

    /**
     * State threaded through a single {@code process} call.
     *
     * <p>It holds the caller's context value and whether the rewriter hook must be bypassed for
     * the node currently being visited.
     */
    public static class Context<C> {
        private final boolean defaultRewrite;
        private final C context;

        private Context(C context, boolean defaultRewrite) {
            this.context = context;
            this.defaultRewrite = defaultRewrite;
        }

        /** Returns the caller-defined context value (may be {@code null}). */
        public C get() {
            return context;
        }

        /** Returns true when the rewriter hook should be skipped for the current node. */
        public boolean isDefaultRewrite() {
            return defaultRewrite;
        }
    }

    /**
     * Visitor implementing the per-node logic.
     *
     * <p>Each visit method tries the matching hook, then falls back to rewriting the children
     * and rebuilding the node only when a child changed by identity.
     */
    private final class RewritingVisitor extends IrVisitor<Expression, Context<C>> {
        private RewritingVisitor() {
        }

        @Override
        public Expression visitExpression(Expression node, Context<C> context) {
            throw new UnsupportedOperationException("visit() not implemented for " + node.getClass().getName());
        }

        @Override
        public Expression visitArray(Array node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteArray(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            List<Expression> elements = rewrite(node.elements(), context);
            if (!sameElements(node.elements(), elements)) {
                return new Array(node.elementType(), elements);
            }
            return node;
        }

        @Override
        public Expression visitRow(Row node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteRow(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            List<Expression> items = rewrite(node.items(), context);
            if (!sameElements(node.items(), items)) {
                return new Row(items);
            }
            return node;
        }

        @Override
        public Expression visitFieldReference(FieldReference node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteSubscript(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression base = rewrite(node.base(), context.get());
            if (base != node.base()) {
                return new FieldReference(base, node.field());
            }
            return node;
        }

        @Override
        public Expression visitComparison(Comparison node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteComparison(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression left = rewrite(node.left(), context.get());
            Expression right = rewrite(node.right(), context.get());
            if (left != node.left() || right != node.right()) {
                return new Comparison(node.operator(), left, right);
            }
            return node;
        }

        @Override
        public Expression visitBetween(Between node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteBetween(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression value = rewrite(node.value(), context.get());
            Expression min = rewrite(node.min(), context.get());
            Expression max = rewrite(node.max(), context.get());
            if (value != node.value() || min != node.min() || max != node.max()) {
                return new Between(value, min, max);
            }
            return node;
        }

        @Override
        public Expression visitLogical(Logical node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteLogical(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            List<Expression> terms = rewrite(node.terms(), context);
            if (!sameElements(node.terms(), terms)) {
                return new Logical(node.operator(), terms);
            }
            return node;
        }

        @Override
        public Expression visitIsNull(IsNull node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteIsNull(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression value = rewrite(node.value(), context.get());
            if (value != node.value()) {
                return new IsNull(value);
            }
            return node;
        }

        @Override
        public Expression visitNullIf(NullIf node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteNullIf(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression first = rewrite(node.first(), context.get());
            Expression second = rewrite(node.second(), context.get());
            if (first != node.first() || second != node.second()) {
                return new NullIf(first, second);
            }
            return node;
        }

        @Override
        public Expression visitCase(Case node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCase(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            // WHEN clauses are rewritten before the default value, matching the original visit order.
            List<WhenClause> whenClauses = rewriteWhenClauses(node.whenClauses(), context);
            Expression defaultValue = rewrite(node.defaultValue(), context.get());
            if (defaultValue == node.defaultValue() && sameElements(node.whenClauses(), whenClauses)) {
                return node;
            }
            return new Case(whenClauses, defaultValue);
        }

        @Override
        public Expression visitSwitch(Switch node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteSwitch(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            // Visit order: operand, WHEN clauses, then default value.
            Expression operand = rewrite(node.operand(), context.get());
            List<WhenClause> whenClauses = rewriteWhenClauses(node.whenClauses(), context);
            Expression defaultValue = rewrite(node.defaultValue(), context.get());
            if (operand == node.operand()
                    && defaultValue == node.defaultValue()
                    && sameElements(node.whenClauses(), whenClauses)) {
                return node;
            }
            return new Switch(operand, whenClauses, defaultValue);
        }

        /** Rewrites every clause in order via {@link #rewriteWhenClause}. */
        private List<WhenClause> rewriteWhenClauses(List<WhenClause> clauses, Context<C> context) {
            ImmutableList.Builder<WhenClause> rewritten = ImmutableList.builder();
            for (WhenClause clause : clauses) {
                rewritten.add(rewriteWhenClause(clause, context));
            }
            return rewritten.build();
        }

        /** Rewrites a clause's operand and result, returning the same clause if neither changed. */
        protected WhenClause rewriteWhenClause(WhenClause node, Context<C> context) {
            Expression operand = rewrite(node.getOperand(), context.get());
            Expression result = rewrite(node.getResult(), context.get());
            if (operand == node.getOperand() && result == node.getResult()) {
                return node;
            }
            return new WhenClause(operand, result);
        }

        @Override
        public Expression visitCoalesce(Coalesce node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCoalesce(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            List<Expression> operands = rewrite(node.operands(), context);
            if (!sameElements(node.operands(), operands)) {
                return new Coalesce(operands);
            }
            return node;
        }

        @Override
        public Expression visitCall(Call node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCall(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            List<Expression> arguments = rewrite(node.arguments(), context);
            if (!sameElements(node.arguments(), arguments)) {
                return new Call(node.function(), arguments);
            }
            return node;
        }

        @Override
        public Expression visitLambda(Lambda node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteLambda(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            // Only the body is rewritten; the argument list is carried over unchanged.
            Expression body = rewrite(node.body(), context.get());
            if (body != node.body()) {
                return new Lambda(node.arguments(), body);
            }
            return node;
        }

        @Override
        public Expression visitBind(Bind node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteBind(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            List<Expression> values = rewrite(node.values(), context);
            Expression function = rewrite(node.function(), context.get());
            if (sameElements(values, node.values()) && function == node.function()) {
                return node;
            }
            // Bind requires a Lambda; a hook that replaces the function with anything else fails here.
            return new Bind(values, (Lambda) function);
        }

        @Override
        public Expression visitIn(In node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteIn(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression value = rewrite(node.value(), context.get());
            List<Expression> valueList = rewrite(node.valueList(), context);
            if (value == node.value() && sameElements(valueList, node.valueList())) {
                return node;
            }
            return new In(value, valueList);
        }

        @Override
        public Expression visitConstant(Constant node, Context<C> context) {
            // Leaf node: there are no children, so only the hook can change it.
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteConstant(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            return node;
        }

        @Override
        public Expression visitCast(Cast node, Context<C> context) {
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteCast(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            Expression expression = rewrite(node.expression(), context.get());
            if (expression != node.expression()) {
                return new Cast(expression, node.type());
            }
            return node;
        }

        @Override
        public Expression visitReference(Reference node, Context<C> context) {
            // Leaf node: there are no children, so only the hook can change it.
            if (!context.isDefaultRewrite()) {
                Expression result = rewriter.rewriteReference(node, context.get(), ExpressionTreeRewriter.this);
                if (result != null) {
                    return result;
                }
            }
            return node;
        }
    }
}
