package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.ThreadSafe;
import io.airlift.log.Logger;
import io.airlift.stats.Distribution;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.trino.execution.StateMachine;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.operator.BlockedReason;
import io.trino.operator.OperatorStats;
import io.trino.operator.PipelineStats;
import io.trino.operator.TaskStats;
import io.trino.plugin.base.metrics.TDigestHistogram;
import io.trino.spi.eventlistener.StageGcStatistics;
import io.trino.spi.metrics.Metrics;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.tracing.TrinoAttributes;
import io.trino.util.Failures;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongFunction;
import java.util.function.Supplier;
import org.joda.time.DateTime;

/**
 * Tracks the lifecycle state of a single stage and aggregates per-task statistics
 * into {@link StageInfo} / {@link BasicStageStats} snapshots.
 *
 * <p>State transitions are delegated to {@link StateMachine}; every transition is
 * logged at debug level and recorded as an event on the stage tracing span, which
 * is ended once the stage reaches a done state. Memory counters are held in atomics.
 */
@ThreadSafe
public class StageStateMachine {
    private static final Logger log = Logger.get(StageStateMachine.class);

    private final StageId stageId;
    private final PlanFragment fragment;
    private final Map<PlanNodeId, TableInfo> tables;
    private final SplitSchedulerStats scheduledStats;

    private final StateMachine<StageState> stageState;
    // Holds the frozen StageInfo once setAllTasksFinal has been called; empty until then.
    private final StateMachine<Optional<StageInfo>> finalStageInfo;
    private final Span stageSpan;
    private final AtomicReference<ExecutionFailureInfo> failureCause = new AtomicReference<>();

    // Time of the first transitionToRunning call.
    private final AtomicReference<DateTime> schedulingComplete = new AtomicReference<>();
    private final Distribution getSplitDistribution = new Distribution();

    private final AtomicLong peakUserMemory = new AtomicLong();
    private final AtomicLong peakRevocableMemory = new AtomicLong();
    private final AtomicLong currentUserMemory = new AtomicLong();
    private final AtomicLong currentRevocableMemory = new AtomicLong();
    private final AtomicLong currentTotalMemory = new AtomicLong();

    /**
     * Creates a state machine for one stage, starting in {@link StageState#PLANNED}.
     *
     * @param stageId id of the stage
     * @param fragment plan fragment executed by the stage
     * @param tables table information keyed by plan node (copied defensively)
     * @param executor executor used to deliver state-change notifications
     * @param tracer tracer used to create the stage span
     * @param querySpan parent span for the stage span
     * @param schedulerStats split scheduler statistics updated by {@link #recordGetSplitTime}
     * @throws NullPointerException if stageId, fragment, tables or schedulerStats is null
     */
    public StageStateMachine(
            StageId stageId,
            PlanFragment fragment,
            Map<PlanNodeId, TableInfo> tables,
            Executor executor,
            Tracer tracer,
            Span querySpan,
            SplitSchedulerStats schedulerStats) {
        this.stageId = Objects.requireNonNull(stageId, "stageId is null");
        this.fragment = Objects.requireNonNull(fragment, "fragment is null");
        this.tables = ImmutableMap.copyOf(Objects.requireNonNull(tables, "tables is null"));
        this.scheduledStats = Objects.requireNonNull(schedulerStats, "schedulerStats is null");

        this.stageState = new StateMachine<>("stage " + stageId, executor, StageState.PLANNED, StageState.TERMINAL_STAGE_STATES);
        this.stageState.addStateChangeListener(state -> log.debug("Stage %s is %s", stageId, state));

        this.finalStageInfo = new StateMachine<>("final stage " + stageId, executor, Optional.empty());

        this.stageSpan = tracer.spanBuilder("stage")
                .setParent(Context.current().with(querySpan))
                .setAttribute(TrinoAttributes.QUERY_ID, stageId.getQueryId().toString())
                .setAttribute(TrinoAttributes.STAGE_ID, stageId.toString())
                .startSpan();

        // Record every transition on the span and close the span once the stage is done.
        this.stageState.addStateChangeListener(state -> {
            this.stageSpan.addEvent("stage_state", Attributes.of(TrinoAttributes.EVENT_STATE, state.toString()));
            if (state.isDone()) {
                this.stageSpan.end();
            }
        });
    }

    public StageId getStageId() {
        return stageId;
    }

    public StageState getState() {
        return stageState.get();
    }

    public PlanFragment getFragment() {
        return fragment;
    }

    public Span getStageSpan() {
        return stageSpan;
    }

    /** Registers a listener notified on every stage state change. */
    public void addStateChangeListener(StateMachine.StateChangeListener<StageState> stateChangeListener) {
        stageState.addStateChangeListener(stateChangeListener);
    }

    /** Moves PLANNED to SCHEDULING; returns false if the stage was in any other state. */
    public boolean transitionToScheduling() {
        return stageState.compareAndSet(StageState.PLANNED, StageState.SCHEDULING);
    }

    /**
     * Moves the stage to RUNNING unless it is already RUNNING or done.
     * The first call also records the scheduling-complete timestamp.
     */
    public boolean transitionToRunning() {
        schedulingComplete.compareAndSet(null, DateTime.now());
        return stageState.setIf(StageState.RUNNING, current -> current != StageState.RUNNING && !current.isDone());
    }

    /** Moves the stage to PENDING unless it is already PENDING or done. */
    public boolean transitionToPending() {
        return stageState.setIf(StageState.PENDING, current -> current != StageState.PENDING && !current.isDone());
    }

    /** Moves the stage to FINISHED unless it is already done. */
    public boolean transitionToFinished() {
        return stageState.setIf(StageState.FINISHED, current -> !current.isDone());
    }

    /** Moves the stage to ABORTED unless it is already done. */
    public boolean transitionToAborted() {
        return stageState.setIf(StageState.ABORTED, current -> !current.isDone());
    }

    /**
     * Records {@code throwable} as the failure cause (first one wins) and moves the stage
     * to FAILED unless it is already done.
     *
     * @return true if this call performed the transition
     * @throws NullPointerException if throwable is null
     */
    public boolean transitionToFailed(Throwable throwable) {
        Objects.requireNonNull(throwable, "throwable is null");

        failureCause.compareAndSet(null, Failures.toFailure(throwable));
        boolean failed = stageState.setIf(StageState.FAILED, current -> !current.isDone());
        if (failed) {
            log.debug(throwable, "Stage %s failed", stageId);
        } else {
            log.debug(throwable, "Failure after stage %s finished", stageId);
        }
        return failed;
    }

    /**
     * Registers a listener that is invoked at most once, with the final {@link StageInfo},
     * after {@link #setAllTasksFinal} has stored it.
     */
    public void addFinalStageInfoListener(StateMachine.StateChangeListener<StageInfo> finalStatusListener) {
        AtomicBoolean done = new AtomicBoolean();
        finalStageInfo.addStateChangeListener(finalInfo -> {
            if (finalInfo.isPresent() && done.compareAndSet(false, true)) {
                finalStatusListener.stateChanged(finalInfo.get());
            }
        });
    }

    /**
     * Freezes the stage info using the given final task infos.
     *
     * @throws IllegalStateException if the stage is not yet done
     * @throws IllegalArgumentException if the resulting stage info is not final
     */
    public void setAllTasksFinal(Iterable<TaskInfo> finalTaskInfos) {
        Objects.requireNonNull(finalTaskInfos, "finalTaskInfos is null");
        Preconditions.checkState(stageState.get().isDone());
        StageInfo stageInfo = getStageInfo(() -> finalTaskInfos);
        Preconditions.checkArgument(stageInfo.isFinalStageInfo(), "finalTaskInfos are not all done");
        finalStageInfo.compareAndSet(Optional.empty(), Optional.of(stageInfo));
    }

    public long getUserMemoryReservation() {
        return currentUserMemory.get();
    }

    public long getTotalMemoryReservation() {
        return currentTotalMemory.get();
    }

    /**
     * Applies memory deltas (in bytes) to the current counters and raises the peak
     * user/revocable counters if the new current values exceed them.
     */
    public void updateMemoryUsage(long deltaUserMemoryInBytes, long deltaRevocableMemoryInBytes, long deltaTotalMemoryInBytes) {
        currentUserMemory.addAndGet(deltaUserMemoryInBytes);
        currentRevocableMemory.addAndGet(deltaRevocableMemoryInBytes);
        currentTotalMemory.addAndGet(deltaTotalMemoryInBytes);
        peakUserMemory.updateAndGet(currentPeak -> Math.max(currentUserMemory.get(), currentPeak));
        peakRevocableMemory.updateAndGet(currentPeak -> Math.max(currentRevocableMemory.get(), currentPeak));
    }

    /**
     * Returns basic statistics for the stage. If the final stage info has been set it is
     * used directly; otherwise statistics are aggregated from the supplied task infos.
     */
    public BasicStageStats getBasicStageStats(Supplier<Iterable<TaskInfo>> taskInfosSupplier) {
        Optional<StageInfo> finalInfo = finalStageInfo.get();
        if (finalInfo.isPresent()) {
            return finalInfo.get().getStageStats().toBasicStageStats(finalInfo.get().getState());
        }

        StageState state = stageState.get();
        boolean isScheduled = state == StageState.RUNNING || state == StageState.PENDING || state.isDone();

        List<TaskInfo> taskInfos = ImmutableList.copyOf(taskInfosSupplier.get());

        int failedTasks = 0;
        int totalDrivers = 0;
        int queuedDrivers = 0;
        int runningDrivers = 0;
        int completedDrivers = 0;
        int blockedDrivers = 0;

        double cumulativeUserMemory = 0;
        double failedCumulativeUserMemory = 0;
        long userMemoryReservation = 0;
        long totalMemoryReservation = 0;

        long totalScheduledTime = 0;
        long failedScheduledTime = 0;
        long totalCpuTime = 0;
        long failedCpuTime = 0;

        long physicalInputDataSize = 0;
        long physicalInputPositions = 0;
        long physicalInputReadTime = 0;
        long physicalWrittenDataSize = 0;
        long internalNetworkInputDataSize = 0;
        long internalNetworkInputPositions = 0;
        long rawInputDataSize = 0;
        long rawInputPositions = 0;
        long spilledDataSize = 0;

        boolean fullyBlocked = true;
        Set<BlockedReason> blockedReasons = new HashSet<>();

        for (TaskInfo taskInfo : taskInfos) {
            TaskState taskState = taskInfo.taskStatus().getState();
            TaskStats taskStats = taskInfo.stats();
            boolean taskFailed = taskState == TaskState.FAILED || taskState == TaskState.FAILING;

            totalDrivers += taskStats.getTotalDrivers();
            queuedDrivers += taskStats.getQueuedDrivers();
            runningDrivers += taskStats.getRunningDrivers();
            completedDrivers += taskStats.getCompletedDrivers();
            blockedDrivers += taskStats.getBlockedDrivers();

            cumulativeUserMemory += taskStats.getCumulativeUserMemory();
            long taskUserMemory = taskStats.getUserMemoryReservation().toBytes();
            userMemoryReservation += taskUserMemory;
            // Total reservation = user + revocable memory.
            totalMemoryReservation += taskUserMemory + taskStats.getRevocableMemoryReservation().toBytes();

            long taskScheduledTime = taskStats.getTotalScheduledTime().roundTo(TimeUnit.NANOSECONDS);
            long taskCpuTime = taskStats.getTotalCpuTime().roundTo(TimeUnit.NANOSECONDS);
            totalScheduledTime += taskScheduledTime;
            totalCpuTime += taskCpuTime;

            if (taskFailed) {
                failedTasks++;
                failedCumulativeUserMemory += taskStats.getCumulativeUserMemory();
                failedScheduledTime += taskScheduledTime;
                failedCpuTime += taskCpuTime;
            }

            // Only tasks still running contribute to the blocked state.
            if (!taskState.isDone()) {
                fullyBlocked &= taskStats.isFullyBlocked();
                blockedReasons.addAll(taskStats.getBlockedReasons());
            }

            physicalInputDataSize += taskStats.getPhysicalInputDataSize().toBytes();
            physicalInputPositions += taskStats.getPhysicalInputPositions();
            physicalInputReadTime += taskStats.getPhysicalInputReadTime().roundTo(TimeUnit.NANOSECONDS);
            physicalWrittenDataSize += taskStats.getPhysicalWrittenDataSize().toBytes();
            internalNetworkInputDataSize += taskStats.getInternalNetworkInputDataSize().toBytes();
            internalNetworkInputPositions += taskStats.getInternalNetworkInputPositions();

            // Raw input is only counted for fragments that scan tables.
            if (fragment.containsTableScanNode()) {
                rawInputDataSize += taskStats.getRawInputDataSize().toBytes();
                rawInputPositions += taskStats.getRawInputPositions();
            }

            spilledDataSize += taskStats.getPipelines().stream()
                    .flatMap(pipeline -> pipeline.getOperatorSummaries().stream())
                    .mapToLong(operator -> operator.getSpilledDataSize().toBytes())
                    .sum();
        }

        OptionalDouble progressPercentage = OptionalDouble.empty();
        OptionalDouble runningPercentage = OptionalDouble.empty();
        if (isScheduled && totalDrivers != 0) {
            progressPercentage = OptionalDouble.of(Math.min(100.0, (completedDrivers * 100.0) / totalDrivers));
            runningPercentage = OptionalDouble.of(Math.min(100.0, (runningDrivers * 100.0) / totalDrivers));
        }

        return new BasicStageStats(
                isScheduled,
                failedTasks,
                totalDrivers,
                queuedDrivers,
                runningDrivers,
                completedDrivers,
                blockedDrivers,
                DataSize.succinctBytes(physicalInputDataSize),
                physicalInputPositions,
                new Duration(physicalInputReadTime, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                DataSize.succinctBytes(physicalWrittenDataSize),
                DataSize.succinctBytes(internalNetworkInputDataSize),
                internalNetworkInputPositions,
                DataSize.succinctBytes(rawInputDataSize),
                rawInputPositions,
                DataSize.succinctBytes(spilledDataSize),
                cumulativeUserMemory,
                failedCumulativeUserMemory,
                DataSize.succinctBytes(userMemoryReservation),
                DataSize.succinctBytes(totalMemoryReservation),
                new Duration(totalCpuTime, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                new Duration(failedCpuTime, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                new Duration(totalScheduledTime, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                new Duration(failedScheduledTime, TimeUnit.NANOSECONDS).convertToMostSuccinctTimeUnit(),
                fullyBlocked,
                blockedReasons,
                progressPercentage,
                runningPercentage);
    }

    /**
     * Returns the full stage info. If the final stage info has been set it is returned
     * directly; otherwise it is aggregated from the supplied task infos.
     */
    public StageInfo getStageInfo(Supplier<Iterable<TaskInfo>> taskInfosSupplier) {
        Optional<StageInfo> finalInfo = finalStageInfo.get();
        if (finalInfo.isPresent()) {
            return finalInfo.get();
        }

        // Read the state before the task infos, matching the original ordering.
        StageState state = stageState.get();
        List<TaskInfo> taskInfos = ImmutableList.copyOf(taskInfosSupplier.get());

        int totalTasks = taskInfos.size();
        int runningTasks = 0;
        int completedTasks = 0;
        int failedTasks = 0;

        int totalDrivers = 0;
        int queuedDrivers = 0;
        int runningDrivers = 0;
        int blockedDrivers = 0;
        int completedDrivers = 0;

        double cumulativeUserMemory = 0;
        double failedCumulativeUserMemory = 0;
        long currentUserMemoryBytes = currentUserMemory.get();
        long currentRevocableMemoryBytes = currentRevocableMemory.get();
        long currentTotalMemoryBytes = currentTotalMemory.get();
        long peakUserMemoryBytes = peakUserMemory.get();
        long peakRevocableMemoryBytes = peakRevocableMemory.get();

        long totalScheduledTime = 0;
        long failedScheduledTime = 0;
        long totalCpuTime = 0;
        long failedCpuTime = 0;
        long totalBlockedTime = 0;

        long physicalInputDataSize = 0;
        long failedPhysicalInputDataSize = 0;
        long physicalInputPositions = 0;
        long failedPhysicalInputPositions = 0;
        long physicalInputReadTime = 0;
        long failedPhysicalInputReadTime = 0;

        long internalNetworkInputDataSize = 0;
        long failedInternalNetworkInputDataSize = 0;
        long internalNetworkInputPositions = 0;
        long failedInternalNetworkInputPositions = 0;

        long rawInputDataSize = 0;
        long failedRawInputDataSize = 0;
        long rawInputPositions = 0;
        long failedRawInputPositions = 0;

        long processedInputDataSize = 0;
        long failedProcessedInputDataSize = 0;
        long processedInputPositions = 0;
        long failedProcessedInputPositions = 0;

        long inputBlockedTime = 0;
        long failedInputBlockedTime = 0;

        long bufferedDataSize = 0;
        ImmutableList.Builder<TDigestHistogram> bufferUtilizationHistograms = ImmutableList.builderWithExpectedSize(taskInfos.size());
        long outputDataSize = 0;
        long failedOutputDataSize = 0;
        long outputPositions = 0;
        long failedOutputPositions = 0;
        Metrics.Accumulator outputBufferMetrics = Metrics.accumulator();
        long outputBlockedTime = 0;
        long failedOutputBlockedTime = 0;

        long physicalWrittenDataSize = 0;
        long failedPhysicalWrittenDataSize = 0;

        int fullGcCount = 0;
        int fullGcTaskCount = 0;
        // Starts at MAX_VALUE so Math.min tracks the real per-task minimum
        // (starting at 0 would pin the minimum to 0 forever).
        int minFullGcSec = Integer.MAX_VALUE;
        int maxFullGcSec = 0;
        int totalFullGcSec = 0;

        boolean fullyBlocked = true;
        Set<BlockedReason> blockedReasons = new HashSet<>();

        // Largest number of operator summaries reported by a single task; sizes the merge map below.
        int maxTaskOperatorSummaries = 0;

        for (TaskInfo taskInfo : taskInfos) {
            TaskState taskState = taskInfo.taskStatus().getState();
            if (taskState.isDone()) {
                completedTasks++;
            } else {
                runningTasks++;
            }
            boolean taskFailed = taskState == TaskState.FAILED || taskState == TaskState.FAILING;

            TaskStats taskStats = taskInfo.stats();

            totalDrivers += taskStats.getTotalDrivers();
            queuedDrivers += taskStats.getQueuedDrivers();
            runningDrivers += taskStats.getRunningDrivers();
            blockedDrivers += taskStats.getBlockedDrivers();
            completedDrivers += taskStats.getCompletedDrivers();

            cumulativeUserMemory += taskStats.getCumulativeUserMemory();

            totalScheduledTime += taskStats.getTotalScheduledTime().roundTo(TimeUnit.NANOSECONDS);
            totalCpuTime += taskStats.getTotalCpuTime().roundTo(TimeUnit.NANOSECONDS);
            totalBlockedTime += taskStats.getTotalBlockedTime().roundTo(TimeUnit.NANOSECONDS);

            // Only tasks still running contribute to the blocked state.
            if (!taskState.isDone()) {
                fullyBlocked &= taskStats.isFullyBlocked();
                blockedReasons.addAll(taskStats.getBlockedReasons());
            }

            physicalInputDataSize += taskStats.getPhysicalInputDataSize().toBytes();
            physicalInputPositions += taskStats.getPhysicalInputPositions();
            physicalInputReadTime += taskStats.getPhysicalInputReadTime().roundTo(TimeUnit.NANOSECONDS);

            internalNetworkInputDataSize += taskStats.getInternalNetworkInputDataSize().toBytes();
            internalNetworkInputPositions += taskStats.getInternalNetworkInputPositions();

            rawInputDataSize += taskStats.getRawInputDataSize().toBytes();
            rawInputPositions += taskStats.getRawInputPositions();

            processedInputDataSize += taskStats.getProcessedInputDataSize().toBytes();
            processedInputPositions += taskStats.getProcessedInputPositions();

            inputBlockedTime += taskStats.getInputBlockedTime().roundTo(TimeUnit.NANOSECONDS);

            bufferedDataSize += taskInfo.outputBuffers().getTotalBufferedBytes();
            taskInfo.outputBuffers().getUtilization().ifPresent(bufferUtilizationHistograms::add);
            taskInfo.outputBuffers().getMetrics().ifPresent(outputBufferMetrics::add);

            outputDataSize += taskStats.getOutputDataSize().toBytes();
            outputPositions += taskStats.getOutputPositions();
            outputBlockedTime += taskStats.getOutputBlockedTime().roundTo(TimeUnit.NANOSECONDS);

            physicalWrittenDataSize += taskStats.getPhysicalWrittenDataSize().toBytes();

            if (taskFailed) {
                failedTasks++;
                failedCumulativeUserMemory += taskStats.getCumulativeUserMemory();
                failedScheduledTime += taskStats.getTotalScheduledTime().roundTo(TimeUnit.NANOSECONDS);
                failedCpuTime += taskStats.getTotalCpuTime().roundTo(TimeUnit.NANOSECONDS);

                failedPhysicalInputDataSize += taskStats.getPhysicalInputDataSize().toBytes();
                failedPhysicalInputPositions += taskStats.getPhysicalInputPositions();
                failedPhysicalInputReadTime += taskStats.getPhysicalInputReadTime().roundTo(TimeUnit.NANOSECONDS);

                failedInternalNetworkInputDataSize += taskStats.getInternalNetworkInputDataSize().toBytes();
                failedInternalNetworkInputPositions += taskStats.getInternalNetworkInputPositions();

                failedRawInputDataSize += taskStats.getRawInputDataSize().toBytes();
                failedRawInputPositions += taskStats.getRawInputPositions();

                failedProcessedInputDataSize += taskStats.getProcessedInputDataSize().toBytes();
                failedProcessedInputPositions += taskStats.getProcessedInputPositions();

                failedInputBlockedTime += taskStats.getInputBlockedTime().roundTo(TimeUnit.NANOSECONDS);

                failedOutputDataSize += taskStats.getOutputDataSize().toBytes();
                failedOutputPositions += taskStats.getOutputPositions();
                failedOutputBlockedTime += taskStats.getOutputBlockedTime().roundTo(TimeUnit.NANOSECONDS);

                failedPhysicalWrittenDataSize += taskStats.getPhysicalWrittenDataSize().toBytes();
            }

            fullGcCount += taskStats.getFullGcCount();
            if (taskStats.getFullGcCount() > 0) {
                fullGcTaskCount++;
            }
            int taskFullGcSec = Math.toIntExact(taskStats.getFullGcTime().roundTo(TimeUnit.SECONDS));
            totalFullGcSec += taskFullGcSec;
            minFullGcSec = Math.min(minFullGcSec, taskFullGcSec);
            maxFullGcSec = Math.max(maxFullGcSec, taskFullGcSec);

            int taskOperatorSummaries = 0;
            for (PipelineStats pipeline : taskStats.getPipelines()) {
                taskOperatorSummaries += pipeline.getOperatorSummaries().size();
            }
            maxTaskOperatorSummaries = Math.max(maxTaskOperatorSummaries, taskOperatorSummaries);
        }

        if (totalTasks == 0) {
            // No tasks observed: report 0 rather than the MAX_VALUE sentinel.
            minFullGcSec = 0;
        }
        // Average full-GC seconds per GC event; 0 when no full GC happened.
        int averageFullGcSec = fullGcCount == 0 ? 0 : (int) ((1.0 * totalFullGcSec) / fullGcCount);

        StageGcStatistics gcStatistics = new StageGcStatistics(
                stageId.getId(),
                totalTasks,
                fullGcTaskCount,
                minFullGcSec,
                maxFullGcSec,
                totalFullGcSec,
                averageFullGcSec);

        List<OperatorStats> operatorSummaries = maxTaskOperatorSummaries == 0
                ? ImmutableList.of()
                : combineTaskOperatorSummaries(taskInfos, maxTaskOperatorSummaries);

        StageStats stageStats = new StageStats(
                schedulingComplete.get(),
                getSplitDistribution.snapshot(),

                totalTasks,
                runningTasks,
                completedTasks,
                failedTasks,

                totalDrivers,
                queuedDrivers,
                runningDrivers,
                blockedDrivers,
                completedDrivers,

                cumulativeUserMemory,
                failedCumulativeUserMemory,
                DataSize.succinctBytes(currentUserMemoryBytes),
                DataSize.succinctBytes(currentRevocableMemoryBytes),
                DataSize.succinctBytes(currentTotalMemoryBytes),
                DataSize.succinctBytes(peakUserMemoryBytes),
                DataSize.succinctBytes(peakRevocableMemoryBytes),

                succinctNanos(totalScheduledTime),
                succinctNanos(failedScheduledTime),
                succinctNanos(totalCpuTime),
                succinctNanos(failedCpuTime),
                succinctNanos(totalBlockedTime),
                // A stage is only reported fully blocked while it still has running tasks.
                fullyBlocked && runningTasks > 0,
                blockedReasons,

                DataSize.succinctBytes(physicalInputDataSize),
                DataSize.succinctBytes(failedPhysicalInputDataSize),
                physicalInputPositions,
                failedPhysicalInputPositions,
                succinctNanos(physicalInputReadTime),
                succinctNanos(failedPhysicalInputReadTime),

                DataSize.succinctBytes(internalNetworkInputDataSize),
                DataSize.succinctBytes(failedInternalNetworkInputDataSize),
                internalNetworkInputPositions,
                failedInternalNetworkInputPositions,

                DataSize.succinctBytes(rawInputDataSize),
                DataSize.succinctBytes(failedRawInputDataSize),
                rawInputPositions,
                failedRawInputPositions,

                DataSize.succinctBytes(processedInputDataSize),
                DataSize.succinctBytes(failedProcessedInputDataSize),
                processedInputPositions,
                failedProcessedInputPositions,

                succinctNanos(inputBlockedTime),
                succinctNanos(failedInputBlockedTime),

                DataSize.succinctBytes(bufferedDataSize),
                TDigestHistogram.merge(bufferUtilizationHistograms.build()).map(histogram -> new DistributionSnapshot(histogram)),
                DataSize.succinctBytes(outputDataSize),
                DataSize.succinctBytes(failedOutputDataSize),
                outputPositions,
                failedOutputPositions,
                outputBufferMetrics.get(),
                succinctNanos(outputBlockedTime),
                succinctNanos(failedOutputBlockedTime),

                DataSize.succinctBytes(physicalWrittenDataSize),
                DataSize.succinctBytes(failedPhysicalWrittenDataSize),

                gcStatistics,
                operatorSummaries);

        return new StageInfo(
                stageId,
                state,
                fragment,
                fragment.getPartitioning().isCoordinatorOnly(),
                fragment.getTypes(),
                stageStats,
                taskInfos,
                ImmutableList.of(),
                tables,
                state == StageState.FAILED ? failureCause.get() : null);
    }

    /**
     * Returns basic stage info, derived from the final stage info when available and
     * otherwise from the supplied task infos.
     */
    public BasicStageInfo getBasicStageInfo(Supplier<Iterable<TaskInfo>> taskInfosSupplier) {
        Optional<StageInfo> finalInfo = finalStageInfo.get();
        if (finalInfo.isPresent()) {
            return new BasicStageInfo(finalInfo.get());
        }
        return new BasicStageInfo(
                stageId,
                stageState.get(),
                fragment.getPartitioning().isCoordinatorOnly(),
                getBasicStageStats(taskInfosSupplier),
                ImmutableList.of(),
                ImmutableList.copyOf(taskInfosSupplier.get()));
    }

    /**
     * Merges operator summaries across tasks, grouping by (pipeline id, operator id).
     *
     * @param taskInfos tasks whose pipeline operator summaries are merged
     * @param maxOperatorSummaries expected number of distinct operators, used to presize the map
     * @return one merged {@link OperatorStats} per distinct (pipeline, operator) pair
     */
    private static List<OperatorStats> combineTaskOperatorSummaries(List<TaskInfo> taskInfos, int maxOperatorSummaries) {
        // Key packs the pipeline id into the high 32 bits and the operator id into the low 32 bits.
        Long2ObjectOpenHashMap<List<OperatorStats>> summariesByOperator = new Long2ObjectOpenHashMap<>(maxOperatorSummaries);
        int taskCount = taskInfos.size();
        LongFunction<List<OperatorStats>> newPerTaskList = ignored -> new ArrayList<>(taskCount);

        for (TaskInfo taskInfo : taskInfos) {
            for (PipelineStats pipeline : taskInfo.stats().getPipelines()) {
                long pipelineKey = Integer.toUnsignedLong(pipeline.getPipelineId()) << 32;
                for (OperatorStats operator : pipeline.getOperatorSummaries()) {
                    long key = pipelineKey | Integer.toUnsignedLong(operator.getOperatorId());
                    summariesByOperator.computeIfAbsent(key, newPerTaskList).add(operator);
                }
            }
        }

        ImmutableList.Builder<OperatorStats> combined = ImmutableList.builderWithExpectedSize(summariesByOperator.size());
        for (List<OperatorStats> perTaskStats : summariesByOperator.values()) {
            OperatorStats merged = perTaskStats.get(0);
            if (perTaskStats.size() > 1) {
                merged = merged.add(perTaskStats.subList(1, perTaskStats.size()));
            }
            combined.add(merged);
        }
        return combined.build();
    }

    /** Converts nanoseconds to the most succinct {@link Duration}. */
    private static Duration succinctNanos(long nanos) {
        return Duration.succinctDuration(nanos, TimeUnit.NANOSECONDS);
    }

    /**
     * Records the elapsed time since {@code startNanos} (a {@link System#nanoTime()} reading)
     * in both the stage-local distribution and the shared scheduler stats.
     */
    public void recordGetSplitTime(long startNanos) {
        long elapsedNanos = System.nanoTime() - startNanos;
        getSplitDistribution.add(elapsedNanos);
        scheduledStats.getGetSplitTime().add(elapsedNanos, TimeUnit.NANOSECONDS);
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this)
                .add("stageId", stageId)
                .add("stageState", stageState)
                .toString();
    }
}
