package io.trino.cost;

import com.google.common.base.Verify;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsCalculator;
import io.trino.execution.scheduler.OutputDataSizeEstimate;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
import io.trino.matching.Pattern;
import io.trino.spi.type.FixedWidthType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.util.MoreMath;
import java.util.List;
import java.util.Optional;

/**
 * Stats rule for {@link RemoteSourceNode}.
 *
 * <p>For every source fragment feeding the remote source, the rule takes the fragment's
 * estimated output statistics and, when runtime output statistics are available, overrides the
 * row count with the runtime value and re-derives the average size of variable-width columns
 * from the runtime output data size. The per-fragment results are then combined with
 * {@link PlanNodeStatsEstimateMath#addStatsAndMaxDistinctValues}.
 */
public class RemoteSourceStatsRule extends SimpleStatsRule<RemoteSourceNode> {
    private static final Pattern<RemoteSourceNode> PATTERN = Patterns.remoteSourceNode();

    public RemoteSourceStatsRule(StatsNormalizer statsNormalizer) {
        super(statsNormalizer);
    }

    @Override
    public Pattern<RemoteSourceNode> getPattern() {
        return PATTERN;
    }

    /**
     * Computes the stats of a remote source node by combining the (runtime-adjusted) stats of each
     * of its source fragments.
     *
     * @param node the remote source node
     * @param context stats calculation context; supplies the runtime info provider and stats provider
     * @return the combined estimate; always present
     * @throws com.google.common.base.VerifyException if the node has no source fragments
     */
    @Override
    public Optional<PlanNodeStatsEstimate> doCalculate(RemoteSourceNode node, StatsCalculator.Context context) {
        RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider();
        Optional<PlanNodeStatsEstimate> combined = Optional.empty();
        for (PlanFragmentId fragmentId : node.getSourceFragmentIds()) {
            PlanNodeStatsEstimate fragmentStats = adjustStats(
                    node.getOutputSymbols(),
                    runtimeInfoProvider.getRuntimeOutputStats(fragmentId),
                    getEstimatedStats(runtimeInfoProvider, context.statsProvider(), fragmentId));
            // First fragment seeds the result; later ones are added with max-distinct-values semantics.
            combined = combined
                    .map(accumulated -> PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues(accumulated, fragmentStats))
                    .or(() -> Optional.of(fragmentStats));
        }
        Verify.verify(combined.isPresent(), "remote source %s has no source fragments", node.getId());
        return combined;
    }

    /**
     * Returns the planner's estimate for the root of the given fragment.
     *
     * <p>Uses the stats recorded in the fragment's {@code StatsAndCosts} for its root node, and
     * falls back to {@code statsProvider.getStats(root)} when no entry exists or its output row
     * count is unknown.
     */
    private PlanNodeStatsEstimate getEstimatedStats(RuntimeInfoProvider runtimeInfoProvider, StatsProvider statsProvider, PlanFragmentId fragmentId) {
        PlanFragment fragment = runtimeInfoProvider.getPlanFragment(fragmentId);
        PlanNode root = fragment.getRoot();
        PlanNodeStatsEstimate recorded = fragment.getStatsAndCosts().getStats().get(root.getId());
        if (recorded == null || recorded.isOutputRowCountUnknown()) {
            return statsProvider.getStats(root);
        }
        return recorded;
    }

    /**
     * Merges runtime output statistics of a fragment into its estimated statistics.
     *
     * <p>If the runtime statistics are unknown, the estimate is returned unchanged. Otherwise the
     * result uses the runtime row count and copies each output symbol's estimated statistics.
     * Symbols whose type is not a {@link FixedWidthType} additionally get their average row size
     * replaced by: (runtime total bytes - estimated fixed-width bytes) / estimated number of
     * non-null variable-width values. That replacement is only applied when there is at least
     * one variable-width value and the runtime size exceeds the fixed-width bytes.
     *
     * @param outputs output symbols of the remote source node
     * @param runtimeStats runtime output statistics of the fragment
     * @param estimatedStats planner estimate for the fragment
     * @return the adjusted estimate
     */
    private PlanNodeStatsEstimate adjustStats(List<Symbol> outputs, OutputStatsEstimator.OutputStatsEstimateResult runtimeStats, PlanNodeStatsEstimate estimatedStats) {
        if (runtimeStats.isUnknown()) {
            return estimatedStats;
        }

        OutputDataSizeEstimate runtimeDataSize = runtimeStats.outputDataSizeEstimate();
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(runtimeStats.outputRowCountEstimate());

        // Split the runtime byte count into a fixed-width part (computable from type sizes) and a
        // remainder attributed to variable-width values. Null values are assumed to take no space.
        double fixedWidthBytes = 0.0d;
        double variableWidthValueCount = 0.0d;
        for (Symbol symbol : outputs) {
            Type type = symbol.type();
            double nullsFraction = MoreMath.firstNonNaN(estimatedStats.getSymbolStatistics(symbol).getNullsFraction(), 0.0d);
            double nonNullRows = runtimeStats.outputRowCountEstimate() * (1.0d - nullsFraction);
            if (type instanceof FixedWidthType) {
                fixedWidthBytes += nonNullRows * ((FixedWidthType) type).getFixedSize();
            } else {
                variableWidthValueCount += nonNullRows;
            }
        }

        double totalSizeInBytes = runtimeDataSize.getTotalSizeInBytes();
        // NaN means "cannot derive"; the estimated average row size is then kept as-is.
        double variableWidthAverageSize = Double.NaN;
        if (variableWidthValueCount > 0.0d && totalSizeInBytes > fixedWidthBytes) {
            variableWidthAverageSize = (totalSizeInBytes - fixedWidthBytes) / variableWidthValueCount;
        }

        for (Symbol symbol : outputs) {
            SymbolStatsEstimate symbolStats = estimatedStats.getSymbolStatistics(symbol);
            if (!Double.isNaN(variableWidthAverageSize) && !(symbol.type() instanceof FixedWidthType)) {
                symbolStats = SymbolStatsEstimate.buildFrom(symbolStats)
                        .setAverageRowSize(variableWidthAverageSize)
                        .build();
            }
            result.addSymbolStatistics(symbol, symbolStats);
        }
        return result.build();
    }
}
