package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.matching.Property;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.TrinoException;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.IrExpressionOptimizer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Rewrites a group of {@code CASE}-based aggregations so that aggregations sharing the same function and
 * the same {@code THEN} expression are computed once by a pre-aggregation and then combined.
 *
 * <p>The rule matches a {@link AggregationNode.Step#SINGLE} aggregation with exactly one grouping set whose
 * source is a {@link ProjectNode} that is not itself placed directly on top of an {@link AggregationNode}
 * (the rewritten plan has exactly that shape, so the rule does not match its own output). Every aggregation
 * must be {@code max}, {@code min} or {@code sum} over a single projected symbol defined as
 * {@code CASE WHEN operand THEN result ELSE default END}, optionally wrapped in one top-level {@code CAST}.
 * The {@code ELSE} value must fold to a constant that is either {@code NULL} or, for {@code sum} only, a
 * zero of a supported numeric type. All operands together must reference exactly one symbol.
 *
 * <p>Schematically, the plan
 * <pre>
 * Aggregation[keys]          agg_i := f_i(p_i)
 *   Project                  p_i   := CASE WHEN operand_i THEN result_i ELSE default_i END
 * </pre>
 * becomes
 * <pre>
 * Aggregation[keys]          agg_i := f_i'(c_i)
 *   Project                  c_i   := CASE WHEN operand_i THEN pre_k ELSE CAST(default_i) END
 *     Aggregation[keys, s]   pre_k := f_k(x_k)
 *       Project              x_k   := result_k (cast and/or guarded when needed)
 * </pre>
 * where {@code s} is the single symbol referenced by the operands, {@code k} ranges over distinct
 * (function, result) pairs, and {@code f_i'} is the built-in function of the same name resolved for the
 * return type of {@code f_i}. The rewrite is only applied when there are at least
 * {@code MIN_AGGREGATION_COUNT} aggregations and at least two of them share a pre-aggregation.
 */
public class PreAggregateCaseAggregations implements Rule<AggregationNode> {
    /** The rewrite is only attempted when the aggregation computes at least this many CASE aggregations. */
    private static final int MIN_AGGREGATION_COUNT = 4;
    private static final CatalogSchemaFunctionName MAX = GlobalFunctionCatalog.builtinFunctionName("max");
    private static final CatalogSchemaFunctionName MIN = GlobalFunctionCatalog.builtinFunctionName("min");
    private static final CatalogSchemaFunctionName SUM = GlobalFunctionCatalog.builtinFunctionName("sum");
    /** Functions whose result is recomputed by applying the same-named function to the partial results. */
    private static final Set<CatalogSchemaFunctionName> ALLOWED_FUNCTIONS = ImmutableSet.of(MAX, MIN, SUM);
    private static final Capture<ProjectNode> PROJECT_CAPTURE = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(aggregation -> aggregation.getStep() == AggregationNode.Step.SINGLE && aggregation.getGroupingSetCount() == 1)
            .with(Patterns.source().matching(Patterns.project().capturedAs(PROJECT_CAPTURE)
                    // the rewrite produces Project -> Aggregation; excluding that shape keeps the rule from re-firing
                    .with(Patterns.source().matching(Predicate.not(AggregationNode.class::isInstance)))));

    private final PlannerContext plannerContext;

    /** One aggregation of the form {@code f(CASE WHEN operand THEN result ELSE default END)}. */
    public static class CaseAggregation {
        // output symbol of the original aggregation
        private final Symbol aggregationSymbol;
        // the original aggregation function
        private final ResolvedFunction function;
        // same-named function resolved for the original return type; aggregates the pre-aggregated values
        private final ResolvedFunction cumulativeFunction;
        private final CatalogSchemaFunctionName name;
        // WHEN condition of the CASE expression
        private final Expression operand;
        // THEN expression of the CASE expression
        private final Expression result;
        // ELSE expression cast to the aggregation's return type, used above the pre-aggregation
        private final Expression cumulativeAggregationDefaultValue;

        public CaseAggregation(
                Symbol aggregationSymbol,
                ResolvedFunction function,
                ResolvedFunction cumulativeFunction,
                CatalogSchemaFunctionName name,
                Expression operand,
                Expression result,
                Expression cumulativeAggregationDefaultValue) {
            this.aggregationSymbol = Objects.requireNonNull(aggregationSymbol, "aggregationSymbol is null");
            this.function = Objects.requireNonNull(function, "function is null");
            this.cumulativeFunction = Objects.requireNonNull(cumulativeFunction, "cumulativeFunction is null");
            this.name = Objects.requireNonNull(name, "name is null");
            this.operand = Objects.requireNonNull(operand, "operand is null");
            this.result = Objects.requireNonNull(result, "result is null");
            this.cumulativeAggregationDefaultValue = Objects.requireNonNull(cumulativeAggregationDefaultValue, "cumulativeAggregationDefaultValue is null");
        }

        public Symbol getAggregationSymbol() {
            return aggregationSymbol;
        }

        public ResolvedFunction getFunction() {
            return function;
        }

        public ResolvedFunction getCumulativeFunction() {
            return cumulativeFunction;
        }

        public CatalogSchemaFunctionName getName() {
            return name;
        }

        public Expression getOperand() {
            return operand;
        }

        public Expression getResult() {
            return result;
        }

        public Expression getCumulativeAggregationDefaultValue() {
            return cumulativeAggregationDefaultValue;
        }
    }

    /** A partial aggregation computed below the rewritten aggregation for one {@link PreAggregationKey}. */
    public static class PreAggregation {
        // output symbol of the pre-aggregation
        private final Symbol aggregationSymbol;
        // expression projected as the pre-aggregation's input
        private final Expression projection;
        // symbol the projection is assigned to
        private final Symbol projectionSymbol;

        public PreAggregation(Symbol aggregationSymbol, Expression projection, Symbol projectionSymbol) {
            this.aggregationSymbol = Objects.requireNonNull(aggregationSymbol, "aggregationSymbol is null");
            this.projection = Objects.requireNonNull(projection, "projection is null");
            this.projectionSymbol = Objects.requireNonNull(projectionSymbol, "projectionSymbol is null");
        }

        public Symbol getAggregationSymbol() {
            return aggregationSymbol;
        }

        public Expression getProjection() {
            return projection;
        }

        public Symbol getProjectionSymbol() {
            return projectionSymbol;
        }
    }

    /** Identifies CASE aggregations that can share one pre-aggregation: same function and same THEN expression. */
    public static class PreAggregationKey {
        private final ResolvedFunction function;
        private final Expression projection;

        private PreAggregationKey(CaseAggregation aggregation) {
            this.function = aggregation.getFunction();
            this.projection = aggregation.getResult();
        }

        public ResolvedFunction getFunction() {
            return function;
        }

        public Expression getProjection() {
            return projection;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            PreAggregationKey other = (PreAggregationKey) obj;
            return Objects.equals(function, other.function) && Objects.equals(projection, other.projection);
        }

        @Override
        public int hashCode() {
            return Objects.hash(function, projection);
        }
    }

    public PreAggregateCaseAggregations(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPreAggregateCaseAggregationsEnabled(session);
    }

    /**
     * Applies the rewrite described on the class, or returns an empty result when any precondition fails.
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        ProjectNode projectNode = captures.get(PROJECT_CAPTURE);

        Optional<List<CaseAggregation>> extracted = extractCaseAggregations(aggregationNode, projectNode, context);
        if (extracted.isEmpty()) {
            return Rule.Result.empty();
        }
        List<CaseAggregation> aggregations = extracted.get();
        if (aggregations.size() < MIN_AGGREGATION_COUNT) {
            return Rule.Result.empty();
        }

        // The operand symbol becomes an extra grouping key of the pre-aggregation, so all CASE
        // conditions together must depend on exactly one symbol.
        Set<Symbol> operandSymbols = aggregations.stream()
                .flatMap(aggregation -> SymbolsExtractor.extractUnique(aggregation.getOperand()).stream())
                .collect(ImmutableSet.toImmutableSet());
        if (operandSymbols.size() != 1) {
            return Rule.Result.empty();
        }

        Map<PreAggregationKey, PreAggregation> preAggregations = getPreAggregations(aggregations, context);
        if (preAggregations.size() == aggregations.size()) {
            // no two aggregations share a pre-aggregation, so the rewrite would not reduce the number of aggregations
            return Rule.Result.empty();
        }

        // Pre-aggregation grouping keys: the operand symbol plus the original grouping keys, the latter
        // computed with the same expressions the original projection used.
        Assignments.Builder groupingAssignmentsBuilder = Assignments.builder();
        groupingAssignmentsBuilder.putIdentities(operandSymbols);
        for (Symbol groupingKey : aggregationNode.getGroupingKeys()) {
            groupingAssignmentsBuilder.put(groupingKey, projectNode.getAssignments().get(groupingKey));
        }
        Assignments groupingAssignments = groupingAssignmentsBuilder.build();

        ProjectNode preProjection = createPreProjection(projectNode.getSource(), groupingAssignments, preAggregations, context);
        AggregationNode preAggregation = createPreAggregation(preProjection, groupingAssignments.getOutputs(), preAggregations, context);
        Map<CaseAggregation, Symbol> newProjectionSymbols = getNewProjectionSymbols(aggregations, context);
        ProjectNode newProjection = createNewProjection(preAggregation, aggregationNode, projectNode, newProjectionSymbols, preAggregations);
        return Rule.Result.ofPlanNode(createNewAggregation(newProjection, aggregationNode, newProjectionSymbols));
    }

    /**
     * Builds the top aggregation: each original output symbol is computed by the cumulative function over
     * the re-projected CASE value, keeping the original node's id, grouping, step and symbols.
     */
    private AggregationNode createNewAggregation(PlanNode source, AggregationNode aggregationNode, Map<CaseAggregation, Symbol> newProjectionSymbols) {
        Map<Symbol, AggregationNode.Aggregation> aggregations = newProjectionSymbols.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        entry -> entry.getKey().getAggregationSymbol(),
                        entry -> new AggregationNode.Aggregation(
                                entry.getKey().getCumulativeFunction(),
                                ImmutableList.of(entry.getValue().toSymbolReference()),
                                false,
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty())));
        return new AggregationNode(
                aggregationNode.getId(),
                source,
                aggregations,
                aggregationNode.getGroupingSets(),
                aggregationNode.getPreGroupedSymbols(),
                aggregationNode.getStep(),
                aggregationNode.getHashSymbol(),
                aggregationNode.getGroupIdSymbol());
    }

    /**
     * Builds the projection above the pre-aggregation, re-applying each CASE condition to the matching
     * pre-aggregated value, with the cast default value in the ELSE branch.
     */
    private ProjectNode createNewProjection(
            PlanNode source,
            AggregationNode aggregationNode,
            ProjectNode projectNode,
            Map<CaseAggregation, Symbol> newProjectionSymbols,
            Map<PreAggregationKey, PreAggregation> preAggregations) {
        Assignments.Builder assignments = Assignments.builder();
        assignments.putIdentities(aggregationNode.getGroupingKeys());
        newProjectionSymbols.forEach((caseAggregation, projectionSymbol) -> {
            Symbol preAggregationSymbol = preAggregations.get(new PreAggregationKey(caseAggregation)).getAggregationSymbol();
            assignments.put(projectionSymbol, new Case(
                    ImmutableList.of(new WhenClause(caseAggregation.getOperand(), preAggregationSymbol.toSymbolReference())),
                    caseAggregation.getCumulativeAggregationDefaultValue()));
        });
        return new ProjectNode(projectNode.getId(), source, assignments.build());
    }

    /** Allocates one fresh symbol per CASE aggregation for the projection above the pre-aggregation. */
    private Map<CaseAggregation, Symbol> getNewProjectionSymbols(List<CaseAggregation> aggregations, Rule.Context context) {
        return aggregations.stream()
                .collect(ImmutableMap.toImmutableMap(
                        Function.identity(),
                        aggregation -> context.getSymbolAllocator().newSymbol(aggregation.getAggregationSymbol())));
    }

    /** Builds the single-step pre-aggregation, one aggregation per distinct {@link PreAggregationKey}. */
    private AggregationNode createPreAggregation(PlanNode source, List<Symbol> groupingKeys, Map<PreAggregationKey, PreAggregation> preAggregations, Rule.Context context) {
        Map<Symbol, AggregationNode.Aggregation> aggregations = preAggregations.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        entry -> entry.getValue().getAggregationSymbol(),
                        entry -> new AggregationNode.Aggregation(
                                entry.getKey().getFunction(),
                                ImmutableList.of(entry.getValue().getProjectionSymbol().toSymbolReference()),
                                false,
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty())));
        return new AggregationNode(
                context.getIdAllocator().getNextId(),
                source,
                aggregations,
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }

    /** Builds the projection feeding the pre-aggregation: grouping assignments plus one input per pre-aggregation. */
    private ProjectNode createPreProjection(PlanNode source, Assignments groupingAssignments, Map<PreAggregationKey, PreAggregation> preAggregations, Rule.Context context) {
        Assignments.Builder assignments = Assignments.builder();
        assignments.putAll(groupingAssignments);
        for (PreAggregation preAggregation : preAggregations.values()) {
            assignments.put(preAggregation.getProjectionSymbol(), preAggregation.getProjection());
        }
        return new ProjectNode(context.getIdAllocator().getNextId(), source, assignments.build());
    }

    /**
     * Groups the CASE aggregations by {@link PreAggregationKey} and creates one {@link PreAggregation} per
     * group, preserving the encounter order of the aggregations.
     */
    private Map<PreAggregationKey, PreAggregation> getPreAggregations(List<CaseAggregation> aggregations, Rule.Context context) {
        ImmutableSetMultimap<PreAggregationKey, CaseAggregation> aggregationsByKey = aggregations.stream()
                .collect(ImmutableSetMultimap.toImmutableSetMultimap(PreAggregationKey::new, Function.identity()));
        return aggregationsByKey.asMap().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        entry -> buildPreAggregation(entry.getKey(), entry.getValue(), context)));
    }

    /** Builds the pre-aggregation input expression and symbols for the aggregations sharing {@code key}. */
    private PreAggregation buildPreAggregation(PreAggregationKey key, Collection<CaseAggregation> aggregations, Rule.Context context) {
        Expression projection = key.getProjection();

        // The CASE may have been wrapped in a CAST, so the THEN expression's type can differ from the
        // aggregation's argument type; cast explicitly in that case.
        Type aggregationInputType = Iterables.getOnlyElement(key.getFunction().signature().getArgumentTypes());
        if (!getType(projection).equals(aggregationInputType)) {
            projection = new Cast(projection, aggregationInputType);
        }

        // The pre-projection is evaluated for every input row, not only the rows selected by the CASE
        // conditions; guard a possibly-failing expression with the OR of those conditions.
        if (IrExpressions.mayFail(plannerContext, projection)) {
            Collection<Expression> operands = aggregations.stream()
                    .map(CaseAggregation::getOperand)
                    .collect(ImmutableSet.toImmutableSet());
            projection = IrExpressions.ifExpression(IrUtils.or(operands), projection);
        }

        // allocate the aggregation symbol before the projection symbol (same order as before)
        Symbol aggregationSymbol = context.getSymbolAllocator().newSymbol(aggregations.iterator().next().getAggregationSymbol());
        Symbol projectionSymbol = context.getSymbolAllocator().newSymbol(projection);
        return new PreAggregation(aggregationSymbol, projection, projectionSymbol);
    }

    /**
     * Extracts a {@link CaseAggregation} for every aggregation of {@code aggregationNode}; returns empty if
     * any one of them is not supported.
     */
    private Optional<List<CaseAggregation>> extractCaseAggregations(AggregationNode aggregationNode, ProjectNode projectNode, Rule.Context context) {
        ImmutableList.Builder<CaseAggregation> caseAggregations = ImmutableList.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            Optional<CaseAggregation> caseAggregation = extractCaseAggregation(entry.getKey(), entry.getValue(), projectNode, context);
            if (caseAggregation.isEmpty()) {
                return Optional.empty();
            }
            caseAggregations.add(caseAggregation.get());
        }
        return Optional.of(caseAggregations.build());
    }

    /**
     * Recognizes {@code f([CAST(] CASE WHEN operand THEN result ELSE default END [)])} for an allowed
     * function {@code f}, and returns its parts together with the resolved cumulative function.
     *
     * @return the extracted aggregation, or empty if the aggregation does not have the supported shape
     */
    private Optional<CaseAggregation> extractCaseAggregation(Symbol aggregationSymbol, AggregationNode.Aggregation aggregation, ProjectNode projectNode, Rule.Context context) {
        if (aggregation.getArguments().size() != 1
                || !(aggregation.getArguments().get(0) instanceof Reference)
                || aggregation.isDistinct()
                || aggregation.getFilter().isPresent()
                || aggregation.getMask().isPresent()
                || aggregation.getOrderingScheme().isPresent()) {
            // only plain single-argument aggregations over a symbol are supported
            return Optional.empty();
        }

        ResolvedFunction function = aggregation.getResolvedFunction();
        CatalogSchemaFunctionName name = function.signature().getName();
        if (!ALLOWED_FUNCTIONS.contains(name)) {
            return Optional.empty();
        }

        Expression projection = projectNode.getAssignments().get(Symbol.from(aggregation.getArguments().get(0)));
        // look through a single top-level CAST around the CASE expression
        Expression unwrapped = projection instanceof Cast cast ? cast.expression() : projection;
        if (!(unwrapped instanceof Case caseExpression)) {
            return Optional.empty();
        }
        if (caseExpression.whenClauses().size() != 1) {
            return Optional.empty();
        }

        Type aggregationType = function.signature().getReturnType();
        ResolvedFunction cumulativeFunction;
        try {
            cumulativeFunction = plannerContext.getMetadata().resolveBuiltinFunction(name.getFunctionName(), TypeSignatureProvider.fromTypes(aggregationType));
        } catch (TrinoException ignored) {
            // no same-named built-in function accepts the aggregation's return type
            return Optional.empty();
        }
        if (!cumulativeFunction.signature().getReturnType().equals(aggregationType)) {
            // the rewrite must not change the aggregation's output type
            return Optional.empty();
        }

        Expression defaultValue = optimizeExpression(caseExpression.defaultValue(), context);
        if (!(defaultValue instanceof Constant constant)) {
            return Optional.empty();
        }
        if (constant.value() != null && !isSupportedNonNullDefault(name, constant.type(), constant.value())) {
            return Optional.empty();
        }

        WhenClause whenClause = caseExpression.whenClauses().get(0);
        return Optional.of(new CaseAggregation(
                aggregationSymbol,
                function,
                cumulativeFunction,
                name,
                whenClause.getOperand(),
                whenClause.getResult(),
                new Cast(caseExpression.defaultValue(), aggregationType)));
    }

    /**
     * Returns whether a non-null constant {@code ELSE} value still permits the rewrite. Only {@code sum}
     * qualifies, and only when the value is a zero (one of {@code 0L}, {@code 0.0d} or {@code Int128.ZERO})
     * of a bigint, integer, smallint, tinyint, double, real or decimal type; a zero ELSE adds nothing to a sum.
     */
    private static boolean isSupportedNonNullDefault(CatalogSchemaFunctionName name, Type type, Object value) {
        if (!name.equals(SUM)) {
            return false;
        }
        boolean isZero = value.equals(0L) || value.equals(0.0d) || value.equals(Int128.ZERO);
        boolean isSupportedType = type instanceof BigintType
                || type == IntegerType.INTEGER
                || type == SmallintType.SMALLINT
                || type == TinyintType.TINYINT
                || type == DoubleType.DOUBLE
                || type == RealType.REAL
                || type instanceof DecimalType;
        return isZero && isSupportedType;
    }

    private Type getType(Expression expression) {
        return expression.type();
    }

    /** Runs the IR optimizer on {@code expression} without bindings; returns the input if nothing changed. */
    private Expression optimizeExpression(Expression expression, Rule.Context context) {
        return IrExpressionOptimizer.newOptimizer(plannerContext)
                .process(expression, context.getSession(), ImmutableMap.<Symbol, Expression>of())
                .orElse(expression);
    }
}
