package io.trino.cost;

import com.google.common.base.Preconditions;
import io.trino.Session;
import io.trino.cost.ComposableStatsCalculator;
import io.trino.cost.StatsCalculator;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.QueryRunner;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;

/**
 * Fluent test helper that computes statistics for a single {@link PlanNode} and hands the result
 * to a {@link PlanNodeStatsAssertion}.
 *
 * <p>Statistics of the node's sources are not computed. They are served from an internal map
 * that starts with {@link PlanNodeStatsEstimate#unknown()} for every direct source and can be
 * overridden through the {@code withSourceStats} methods. Instances are mutable (backed by a
 * plain {@link HashMap}) and are not synchronized.
 */
public class StatsCalculatorAssertion {
    private final QueryRunner queryRunner;
    private final Session session;
    private final PlanNode planNode;
    private RuntimeInfoProvider runtimeInfoProvider = RuntimeInfoProvider.noImplementation();
    private Optional<TableStatsProvider> tableStatsProvider = Optional.empty();
    // Stats handed back to the calculator whenever it asks for the stats of a source node.
    private final Map<PlanNode, PlanNodeStatsEstimate> sourcesStats = new HashMap<>();

    /**
     * Creates an assertion for {@code planNode}, seeding every direct source with unknown stats.
     *
     * <p>NOTE(review): decompiler metadata indicated this constructor was originally
     * package-private. It is left {@code public} so existing callers keep compiling.
     *
     * @param queryRunner supplies the stats calculator and planner metadata
     * @param session session passed to the stats calculation context
     * @param planNode node whose statistics are computed
     * @throws NullPointerException if any argument is {@code null}
     */
    public StatsCalculatorAssertion(QueryRunner queryRunner, Session session, PlanNode planNode) {
        this.queryRunner = Objects.requireNonNull(queryRunner, "queryRunner is null");
        this.session = Objects.requireNonNull(session, "session cannot be null");
        this.planNode = Objects.requireNonNull(planNode, "planNode is null");
        // Pre-populate so a calculator asking for any direct source never hits a missing entry.
        planNode.getSources().forEach(source -> sourcesStats.put(source, PlanNodeStatsEstimate.unknown()));
    }

    /**
     * Sets the stats of the only source of the node under test.
     *
     * @throws IllegalStateException if the node does not have exactly one source
     */
    public StatsCalculatorAssertion withSourceStats(PlanNodeStatsEstimate planNodeStatsEstimate) {
        Preconditions.checkState(planNode.getSources().size() == 1, "expected single source");
        return withSourceStats(0, planNodeStatsEstimate);
    }

    /**
     * Sets the stats of the direct source at position {@code i} in {@code planNode.getSources()}.
     *
     * @throws IllegalArgumentException if {@code i} is negative or not less than the number of sources
     * @throws NullPointerException if {@code planNodeStatsEstimate} is {@code null}
     */
    public StatsCalculatorAssertion withSourceStats(int i, PlanNodeStatsEstimate planNodeStatsEstimate) {
        int sourceCount = planNode.getSources().size();
        // Reject negative indices here too, instead of letting List.get throw an unrelated exception.
        Preconditions.checkArgument(i >= 0 && i < sourceCount, "invalid sourceIndex %s; planNode has %s sources", i, sourceCount);
        Objects.requireNonNull(planNodeStatsEstimate, "planNodeStatsEstimate is null");
        sourcesStats.put(planNode.getSources().get(i), planNodeStatsEstimate);
        return this;
    }

    /**
     * Sets the stats of the node with id {@code planNodeId}.
     *
     * <p>The search starts at the node under test itself and uses
     * {@code PlanNodeSearcher.findOnlyElement()}, so the id is expected to match exactly one node.
     *
     * @throws NullPointerException if {@code planNodeStatsEstimate} is {@code null}
     */
    public StatsCalculatorAssertion withSourceStats(PlanNodeId planNodeId, PlanNodeStatsEstimate planNodeStatsEstimate) {
        Objects.requireNonNull(planNodeStatsEstimate, "planNodeStatsEstimate is null");
        PlanNode source = PlanNodeSearcher.searchFrom(planNode)
                .where(node -> node.getId().equals(planNodeId))
                .findOnlyElement();
        sourcesStats.put(source, planNodeStatsEstimate);
        return this;
    }

    /**
     * Adds (or overrides) stats for every node in {@code map}.
     */
    public StatsCalculatorAssertion withSourceStats(Map<PlanNode, PlanNodeStatsEstimate> map) {
        sourcesStats.putAll(map);
        return this;
    }

    /**
     * Uses {@code tableStatsProvider} instead of a freshly created {@code CachingTableStatsProvider}.
     *
     * @throws NullPointerException if {@code tableStatsProvider} is {@code null}
     */
    public StatsCalculatorAssertion withTableStatisticsProvider(TableStatsProvider tableStatsProvider) {
        this.tableStatsProvider = Optional.of(tableStatsProvider);
        return this;
    }

    /**
     * Uses {@code runtimeInfoProvider} instead of {@code RuntimeInfoProvider.noImplementation()}.
     *
     * @throws NullPointerException if {@code runtimeInfoProvider} is {@code null}
     */
    public StatsCalculatorAssertion withRuntimeInfoProvider(RuntimeInfoProvider runtimeInfoProvider) {
        this.runtimeInfoProvider = Objects.requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
        return this;
    }

    /**
     * Computes stats with the query runner's full stats calculator and passes them to {@code consumer}.
     */
    public StatsCalculatorAssertion check(Consumer<PlanNodeStatsAssertion> consumer) {
        PlanNodeStatsEstimate stats = queryRunner.getStatsCalculator().calculateStats(planNode, createContext());
        consumer.accept(PlanNodeStatsAssertion.assertThat(stats));
        return this;
    }

    /**
     * Computes stats by applying only {@code rule} to the node under test and passes them to {@code consumer}.
     *
     * @throws IllegalStateException if the rule produced no estimate
     */
    public StatsCalculatorAssertion check(ComposableStatsCalculator.Rule<?> rule, Consumer<PlanNodeStatsAssertion> consumer) {
        Optional<PlanNodeStatsEstimate> calculatedStats = calculatedStats(rule, planNode, createContext());
        Preconditions.checkState(calculatedStats.isPresent(), "Expected stats estimates to be present");
        consumer.accept(PlanNodeStatsAssertion.assertThat(calculatedStats.get()));
        return this;
    }

    /**
     * Builds the calculation context. Source stats come from {@link #getSourceStats}. If no table
     * stats provider was configured, a new {@code CachingTableStatsProvider} is created per call.
     */
    private StatsCalculator.Context createContext() {
        TableStatsProvider effectiveTableStatsProvider = tableStatsProvider.orElseGet(
                () -> new CachingTableStatsProvider(queryRunner.getPlannerContext().getMetadata(), session));
        return new StatsCalculator.Context(this::getSourceStats, Lookup.noLookup(), session, effectiveTableStatsProvider, runtimeInfoProvider);
    }

    /**
     * Captures the rule's wildcard type so it can be applied to an untyped {@link PlanNode}.
     *
     * <p>The cast is unchecked because of erasure. It is the caller's responsibility to pass a rule
     * that handles the node's actual type; a mismatch surfaces inside the rule's {@code calculate}.
     */
    @SuppressWarnings("unchecked")
    private static <T extends PlanNode> Optional<PlanNodeStatsEstimate> calculatedStats(ComposableStatsCalculator.Rule<T> rule, PlanNode planNode, StatsCalculator.Context context) {
        return rule.calculate((T) planNode, context);
    }

    /**
     * Returns the registered stats for {@code planNode}.
     *
     * @throws IllegalArgumentException if no stats were registered for that node
     */
    private PlanNodeStatsEstimate getSourceStats(PlanNode planNode) {
        Preconditions.checkArgument(sourcesStats.containsKey(planNode), "stats not found for source %s", planNode);
        return sourcesStats.get(planNode);
    }
}
