package org.apache.beam.fn.harness.state;

import com.google.auto.value.AutoValue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.WeightedList;
import org.apache.beam.sdk.fn.stream.DataStreams;
import org.apache.beam.sdk.fn.stream.PrefetchableIterables;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
import org.checkerframework.dataflow.qual.Pure;

/**
 * Iterators and iterables for reading Beam Fn API state that is returned in chunks linked by
 * continuation tokens, optionally caching the decoded values.
 */
public class StateFetchingIterators {

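    /**
     * A {@link PrefetchableIterables.Default} whose iterators decode state values with {@code
     * valueCoder} and share fetched blocks through {@code cache} (keyed by {@link
     * IterableCacheKey#INSTANCE}), so later iterations can reuse cached blocks instead of fetching
     * them again. The cached values can be updated via {@code append}, {@code remove} and {@code
     * clearAndAppend}.
     *
     * <p>NOTE: the decompiler widened this class's access from package-private.
     */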
    public static class CachingStateIterable<T> extends PrefetchableIterables.Default<T> {
        private final Cache<IterableCacheKey, Blocks<T>> cache;
        private final BeamFnStateClient beamFnStateClient;
        private final BeamFnApi.StateRequest stateRequestForFirstChunk;
        private final Coder<T> valueCoder;

        /**
         * An immutable chunk of decoded values together with the continuation token used to fetch
         * the chunk that follows it.
         *
         * <p>A {@code null} next token means no further chunk follows this block. Blocks built from
         * locally mutated values ({@link #mutatedBlock}) always have a {@code null} next token.
         */
        @AutoValue
        public static abstract class Block<T> implements Weighted {
            /** Creates a block of locally mutated values carrying the given precomputed weight. */
            public static <T> Block<T> mutatedBlock(List<T> values, long weight) {
                return mutatedBlock(new WeightedList<T>(values, weight));
            }

            /** Creates a block of locally mutated values; such a block has no next token. */
            public static <T> Block<T> mutatedBlock(WeightedList<T> values) {
                return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
                        values.getBacking(), null, values.getWeight());
            }

            /** Creates a block from fetched values, weighing them with {@link Caches#weigh}. */
            public static <T> Block<T> fromValues(List<T> values, @Nullable ByteString nextToken) {
                return fromValues(new WeightedList<T>(values, Caches.weigh(values)), nextToken);
            }

            /**
             * Creates a block from fetched values. The block's weight is the weight of the values plus
             * the weight of {@code nextToken}.
             */
            public static <T> Block<T> fromValues(WeightedList<T> values, @Nullable ByteString nextToken) {
                return new AutoValue_StateFetchingIterators_CachingStateIterable_Block<>(
                        values.getBacking(), nextToken, values.getWeight() + Caches.weigh(nextToken));
            }

            /** The decoded values held by this block, in iteration order. */
            public abstract List<T> getValues();

            /** The token of the following chunk, or {@code null} if this is the last block. */
            @Nullable
            public abstract ByteString getNextToken();

            public abstract long getWeight();
        }

        /**
         * The cached representation of an iterable: an ordered list of {@link Block}s whose combined
         * weight is reported through {@link Weighted}.
         */
        public static abstract class Blocks<T> implements Weighted {
            /** Returns the cached blocks in iteration order. */
            public abstract List<Block<T>> getBlocks();

            Blocks() {}
        }

        /**
         * A cached prefix of the blocks fetched for an iterable. The last block's next token is
         * {@code null} when the whole iterable is cached; otherwise more blocks remain to be fetched.
         */
        public static class BlocksPrefix<T> extends Blocks<T> implements Cache.Shrinkable<BlocksPrefix<T>> {
            private final List<Block<T>> blocks;

            BlocksPrefix(List<Block<T>> list) {
                this.blocks = list;
            }

            @Override
            public List<Block<T>> getBlocks() {
                return blocks;
            }

            /** Returns the sum of the block weights, saturating at {@link Long#MAX_VALUE}. */
            @Override
            public long getWeight() {
                return CachingStateIterable.sumWeight(blocks);
            }

            /**
             * Returns a new prefix holding the first half (rounded down) of the blocks, or {@code null}
             * when that half would be empty.
             */
            @Override
            public BlocksPrefix<T> shrink() {
                int retained = getBlocks().size() / 2;
                if (retained == 0) {
                    return null;
                }
                List<Block<T>> firstHalf = new ArrayList<>(getBlocks().subList(0, retained));
                return new BlocksPrefix<>(firstHalf);
            }
        }

        /**
         * Iterator over the values of the enclosing iterable.
         *
         * <p>Blocks are taken from {@code cache} when they are present there. Otherwise they are
         * fetched and decoded through a {@link LazyBlockingStateFetchingIterator}. A newly fetched
         * block that directly extends the cached {@link BlocksPrefix} is appended to the cached prefix.
         */
        class CachingStateIterator implements PrefetchableIterator<T> {
            private final LazyBlockingStateFetchingIterator underlyingStateFetchingIterator;
            private final DataStreams.DataStreamDecoder<T> dataStreamDecoder;
            // The block being consumed. It starts as an empty placeholder whose next token is the
            // first request's continuation token, so the first hasNext()/isReady() call moves to the
            // first real block.
            private Block<T> currentBlock;
            // Index of the next value of currentBlock to return.
            private int currentCachedBlockValueIndex = 0;

            public CachingStateIterator() {
                this.underlyingStateFetchingIterator =
                        new LazyBlockingStateFetchingIterator(beamFnStateClient, stateRequestForFirstChunk);
                this.dataStreamDecoder =
                        new DataStreams.DataStreamDecoder<>(valueCoder, underlyingStateFetchingIterator);
                this.currentBlock = Block.fromValues(
                        new WeightedList<T>(Collections.emptyList(), 0L),
                        stateRequestForFirstChunk.getGet().getContinuationToken());
            }

            /**
             * Returns the index of the first block in {@code cachedBlocks} whose next token equals
             * {@code token}, or {@code cachedBlocks.size()} if there is none.
             */
            private int indexOfBlockWithNextToken(List<Block<T>> cachedBlocks, ByteString token) {
                int index = 0;
                while (index < cachedBlocks.size() && !token.equals(cachedBlocks.get(index).getNextToken())) {
                    index++;
                }
                return index;
            }

            /**
             * Returns whether the next value can be produced without a blocking fetch. Along the way,
             * this advances through blocks that are already cached.
             */
            @Override
            public boolean isReady() {
                while (currentBlock.getValues().size() <= currentCachedBlockValueIndex
                        && currentBlock.getNextToken() != null) {
                    ByteString nextToken = currentBlock.getNextToken();
                    Blocks<T> cached = cache.peek(IterableCacheKey.INSTANCE);
                    if (cached == null) {
                        return false;
                    }
                    if (ByteString.EMPTY.equals(nextToken)) {
                        // An empty token denotes the first block.
                        currentBlock = cached.getBlocks().get(0);
                    } else {
                        List<Block<T>> cachedBlocks = cached.getBlocks();
                        int index = indexOfBlockWithNextToken(cachedBlocks, nextToken);
                        if (index + 1 >= cachedBlocks.size()) {
                            return false;
                        }
                        currentBlock = cachedBlocks.get(index + 1);
                    }
                    currentCachedBlockValueIndex = 0;
                }
                return true;
            }

            /** Starts fetching the next chunk asynchronously if it is not already available. */
            @Override
            public void prefetch() {
                if (isReady()) {
                    return;
                }
                underlyingStateFetchingIterator.seekToContinuationToken(currentBlock.getNextToken());
                underlyingStateFetchingIterator.prefetch();
            }

            /**
             * Returns whether more values remain. Moves to the next block from the cache when possible;
             * otherwise fetches it (blocking) and records it in the cache.
             *
             * @throws IllegalStateException if the cache holds blocks other than a {@link BlocksPrefix}
             *     while iteration is past the first block
             */
            // NOTE(review): annotated @Pure, yet this method may fetch a block and update the cache.
            @Override
            @Pure
            public boolean hasNext() {
                while (currentBlock.getValues().size() <= currentCachedBlockValueIndex) {
                    ByteString nextToken = currentBlock.getNextToken();
                    if (nextToken == null) {
                        return false;
                    }
                    Blocks<T> cached = cache.peek(IterableCacheKey.INSTANCE);
                    boolean isFirstBlock = ByteString.EMPTY.equals(nextToken);
                    if (cached == null) {
                        currentBlock = loadNextBlock(nextToken);
                        if (isFirstBlock) {
                            // Nothing cached yet: start a cached prefix with the first block.
                            cache.put(IterableCacheKey.INSTANCE,
                                    new BlocksPrefix<>(Collections.singletonList(currentBlock)));
                        }
                    } else if (isFirstBlock) {
                        currentBlock = cached.getBlocks().get(0);
                    } else {
                        Preconditions.checkState(cached instanceof BlocksPrefix,
                                "Unexpected blocks type %s, expected a %s.", cached.getClass(), BlocksPrefix.class);
                        List<Block<T>> cachedBlocks = cached.getBlocks();
                        int index = indexOfBlockWithNextToken(cachedBlocks, nextToken);
                        if (index + 1 < cachedBlocks.size()) {
                            currentBlock = cachedBlocks.get(index + 1);
                        } else {
                            currentBlock = loadNextBlock(nextToken);
                            if (index == cachedBlocks.size() - 1) {
                                // The fetched block directly follows the cached prefix: cache the
                                // extended prefix (existing blocks plus the new one).
                                List<Block<T>> extended = new ArrayList<>(cachedBlocks.size() + 1);
                                extended.addAll(cachedBlocks);
                                extended.add(currentBlock);
                                cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(extended));
                            }
                        }
                    }
                    currentCachedBlockValueIndex = 0;
                }
                return true;
            }

            /**
             * Fetches and decodes the chunk identified by {@code continuationToken}. The returned
             * block's next token is {@code null} when the state client reports no further chunk.
             */
            @VisibleForTesting
            Block<T> loadNextBlock(ByteString continuationToken) {
                underlyingStateFetchingIterator.seekToContinuationToken(continuationToken);
                WeightedList<T> values = dataStreamDecoder.decodeFromChunkBoundaryToChunkBoundary();
                ByteString followingToken = underlyingStateFetchingIterator.getContinuationToken();
                if (ByteString.EMPTY.equals(followingToken)) {
                    followingToken = null;
                }
                return Block.fromValues(values, followingToken);
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                return currentBlock.getValues().get(currentCachedBlockValueIndex++);
            }
        }

        /**
         * Cached blocks consisting of exactly one block. Used when local mutations ({@code append},
         * {@code remove}, {@code clearAndAppend}) replace the cached values.
         */
        public static class MutatedBlocks<T> extends Blocks<T> {
            private final Block<T> wholeBlock;

            MutatedBlocks(Block<T> block) {
                this.wholeBlock = block;
            }

            /** Returns the weight of the single wrapped block. */
            @Override
            public long getWeight() {
                return wholeBlock.getWeight();
            }

            /** Returns an immutable singleton list holding the wrapped block. */
            @Override
            public List<Block<T>> getBlocks() {
                List<Block<T>> only = Collections.singletonList(wholeBlock);
                return only;
            }
        }

        /**
         * Sums the weights of {@code list}, returning {@link Long#MAX_VALUE} if the sum would
         * overflow a {@code long}.
         */
        public static <T> long sumWeight(List<Block<T>> list) {
            long total = 0L;
            try {
                for (Block<T> block : list) {
                    total = Math.addExact(total, block.getWeight());
                }
            } catch (ArithmeticException overflow) {
                return Long.MAX_VALUE;
            }
            return total;
        }

        /**
         * @param cache cache in which fetched or mutated blocks are stored under
         *     {@link IterableCacheKey#INSTANCE}
         * @param beamFnStateClient client used to fetch state chunks
         * @param stateRequestForFirstChunk request whose continuation token identifies the first chunk
         * @param valueCoder coder used to decode values and compute their structural values
         */
        public CachingStateIterable(Cache<IterableCacheKey, Blocks<T>> cache, BeamFnStateClient beamFnStateClient, BeamFnApi.StateRequest stateRequestForFirstChunk, Coder<T> valueCoder) {
            this.valueCoder = valueCoder;
            this.stateRequestForFirstChunk = stateRequestForFirstChunk;
            this.beamFnStateClient = beamFnStateClient;
            this.cache = cache;
        }

        /**
         * Removes values from the cached iterable, if one is cached.
         *
         * <p>Each cached value is matched by its {@link Coder#structuralValue}, so {@code
         * toRemoveStructuralValues} must contain structural values rather than the original values.
         *
         * <p>The cache entry is rewritten only when the whole iterable is cached (the last cached block
         * has no next token). If only a prefix is cached, the entry is dropped instead. Rewriting it
         * would record the prefix as the complete iterable and lose every value not yet fetched.
         *
         * @param toRemoveStructuralValues structural values to remove; no-op if empty
         */
        public void remove(Set<Object> toRemoveStructuralValues) {
            if (toRemoveStructuralValues.isEmpty()) {
                return;
            }
            Blocks<T> existing = this.cache.peek(IterableCacheKey.INSTANCE);
            if (existing == null) {
                return;
            }
            List<Block<T>> blocks = existing.getBlocks();
            if (blocks.get(blocks.size() - 1).getNextToken() != null) {
                // Only a prefix is cached; invalidate it rather than caching a truncated iterable.
                this.cache.remove(IterableCacheKey.INSTANCE);
                return;
            }

            int totalValues = 0;
            for (Block<T> block : blocks) {
                totalValues += block.getValues().size();
            }
            // Merge every block into a single block stored as one MutatedBlocks entry.
            WeightedList<T> remaining = new WeightedList<T>(new ArrayList<>(totalValues), 0L);
            for (Block<T> block : blocks) {
                List<T> kept = new ArrayList<>();
                boolean removedAny = false;
                for (T value : block.getValues()) {
                    if (toRemoveStructuralValues.contains(this.valueCoder.structuralValue(value))) {
                        removedAny = true;
                    } else {
                        kept.add(value);
                    }
                }
                if (removedAny) {
                    // The block's contents changed: weigh only the values that survive.
                    remaining.addAll(kept, Caches.weigh(kept));
                } else {
                    // Unchanged block: reuse its already-computed weight.
                    remaining.addAll(block.getValues(), block.getWeight());
                }
            }
            this.cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(remaining)));
        }

        /** Replaces the cached iterable with {@code list}, weighing its values with {@link Caches#weigh}. */
        public void clearAndAppend(List<T> list) {
            long listWeight = Caches.weigh(list);
            WeightedList<T> replacement = new WeightedList<>(list, listWeight);
            clearAndAppend(replacement);
        }

        /**
         * Replaces whatever is cached for this iterable with a single block holding exactly the
         * values of {@code weightedList}.
         */
        public void clearAndAppend(WeightedList<T> weightedList) {
            Block<T> replacement = Block.mutatedBlock(weightedList);
            this.cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(replacement));
        }

        /** Returns a fresh {@link CachingStateIterator} positioned before the first value. */
        @Override
        public PrefetchableIterator<T> createIterator() {
            PrefetchableIterator<T> iterator = new CachingStateIterator();
            return iterator;
        }

        /** Appends {@code list} to the cached iterable, weighing its values with {@link Caches#weigh}. */
        public void append(List<T> list) {
            long listWeight = Caches.weigh(list);
            WeightedList<T> additions = new WeightedList<>(list, listWeight);
            append(additions);
        }

        /**
         * Appends {@code newValues} to the cached iterable, if one is cached.
         *
         * <p>The cache entry is extended only when the whole iterable is cached (the last cached block
         * has no next token). If only a prefix is cached, the entry is dropped instead. Appending to
         * it would record "prefix + new values" as the complete iterable and lose every value not
         * yet fetched.
         *
         * @param newValues values to append; no-op if empty
         */
        public void append(WeightedList<T> newValues) {
            if (newValues.isEmpty()) {
                return;
            }
            Blocks<T> existing = this.cache.peek(IterableCacheKey.INSTANCE);
            if (existing == null) {
                return;
            }
            List<Block<T>> blocks = existing.getBlocks();
            if (blocks.get(blocks.size() - 1).getNextToken() != null) {
                // Only a prefix is cached; invalidate it rather than caching a truncated iterable.
                this.cache.remove(IterableCacheKey.INSTANCE);
                return;
            }

            int totalValues = newValues.size();
            for (Block<T> block : blocks) {
                totalValues += block.getValues().size();
            }
            // Merge the cached blocks and the new values into a single MutatedBlocks entry.
            WeightedList<T> combined = new WeightedList<T>(new ArrayList<>(totalValues), 0L);
            for (Block<T> block : blocks) {
                combined.addAll(block.getValues(), block.getWeight());
            }
            combined.addAll(newValues);
            this.cache.put(IterableCacheKey.INSTANCE, new MutatedBlocks<>(Block.mutatedBlock(combined)));
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    @VisibleForTesting
    /* loaded from: input_file:org/apache/beam/fn/harness/state/StateFetchingIterators$IterableCacheKey.class */
    public static class IterableCacheKey implements Weighted {
        static final IterableCacheKey INSTANCE = new IterableCacheKey();

        private IterableCacheKey() {
        }

        public long getWeight() {
            return 0L;
        }
    }

    /**
     * Iterator over the raw data chunks of a state request.
     *
     * <p>Each {@link #next()} call:
     * <ul>
     *   <li>waits for the response for the current continuation token;
     *   <li>returns that response's data;
     *   <li>advances to the continuation token carried by the response, prefetching the following
     *       chunk if there is one.
     * </ul>
     *
     * <p>A {@code null} continuation token means there is nothing left to fetch.
     */
    @VisibleForTesting
    public static class LazyBlockingStateFetchingIterator implements PrefetchableIterator<ByteString> {
        private final BeamFnStateClient beamFnStateClient;
        private final BeamFnApi.StateRequest stateRequestForFirstChunk;
        // Token of the next chunk to fetch; null once the final chunk has been returned.
        private ByteString continuationToken;
        // Request issued for continuationToken, or null if none is outstanding.
        private CompletableFuture<BeamFnApi.StateResponse> prefetchedResponse;

        LazyBlockingStateFetchingIterator(BeamFnStateClient beamFnStateClient, BeamFnApi.StateRequest stateRequestForFirstChunk) {
            this.beamFnStateClient = beamFnStateClient;
            this.stateRequestForFirstChunk = stateRequestForFirstChunk;
            this.continuationToken = stateRequestForFirstChunk.getGet().getContinuationToken();
        }

        /** Returns the token of the next chunk to fetch, or {@code null} if there is none. */
        @Nullable
        public ByteString getContinuationToken() {
            return continuationToken;
        }

        /**
         * Repositions this iterator at {@code token}. Any outstanding prefetch is discarded unless
         * the token is unchanged.
         */
        public void seekToContinuationToken(@Nullable ByteString token) {
            if (!Objects.equals(continuationToken, token)) {
                continuationToken = token;
                prefetchedResponse = null;
            }
        }

        @Override
        public boolean isReady() {
            if (prefetchedResponse != null) {
                return prefetchedResponse.isDone();
            }
            return continuationToken == null;
        }

        @Override
        public void prefetch() {
            if (continuationToken != null && prefetchedResponse == null) {
                prefetchedResponse = loadPrefetchedResponse(continuationToken);
            }
        }

        /** Issues a state request for the chunk identified by {@code token}. */
        public CompletableFuture<BeamFnApi.StateResponse> loadPrefetchedResponse(ByteString token) {
            return beamFnStateClient.handle(
                    stateRequestForFirstChunk.toBuilder()
                            .setGet(BeamFnApi.StateGetRequest.newBuilder().setContinuationToken(token)));
        }

        @Override
        @Pure
        public boolean hasNext() {
            return continuationToken != null;
        }

        @Override
        public ByteString next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            prefetch();
            BeamFnApi.StateResponse response = awaitPrefetchedResponse();
            prefetchedResponse = null;
            ByteString responseToken = response.getGet().getContinuationToken();
            if (ByteString.EMPTY.equals(responseToken)) {
                continuationToken = null;
            } else {
                continuationToken = responseToken;
                prefetch();
            }
            return response.getGet().getData();
        }

        /**
         * Blocks until the outstanding response completes.
         *
         * <p>Failures are rethrown as follows:
         * <ul>
         *   <li>interruption: restores the interrupt flag and throws {@link IllegalStateException};
         *   <li>unchecked causes: rethrown as-is;
         *   <li>other causes: wrapped in {@link IllegalStateException}.
         * </ul>
         */
        private BeamFnApi.StateResponse awaitPrefetchedResponse() {
            try {
                return prefetchedResponse.get();
            } catch (InterruptedException interrupted) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(interrupted);
            } catch (ExecutionException failed) {
                Throwable cause = failed.getCause();
                if (cause == null) {
                    throw new IllegalStateException(failed);
                }
                Throwables.throwIfUnchecked(cause);
                throw new IllegalStateException(cause);
            }
        }
    }

    /** Not instantiable: this class only hosts static nested types and a static factory method. */
    private StateFetchingIterators() {}

    /**
     * Creates a {@link CachingStateIterable} that works as follows:
     * <ul>
     *   <li>fetches the chunks of {@code stateRequestForFirstChunk} through {@code beamFnStateClient};
     *   <li>decodes them with {@code valueCoder};
     *   <li>caches the decoded blocks in {@code cache} under {@link IterableCacheKey#INSTANCE}.
     * </ul>
     *
     * <p>{@code cache} is accepted with wildcard types, and the conversion below is an unchecked
     * cast. The caller is responsible for passing a cache whose entries map {@link
     * IterableCacheKey} to {@code CachingStateIterable.Blocks<T>}; otherwise a {@link
     * ClassCastException} may surface later, when entries are read.
     */
    @SuppressWarnings("unchecked")
    public static <T> CachingStateIterable<T> readAllAndDecodeStartingFrom(Cache<?, ?> cache, BeamFnStateClient beamFnStateClient, BeamFnApi.StateRequest stateRequestForFirstChunk, Coder<T> valueCoder) {
        Cache<IterableCacheKey, CachingStateIterable.Blocks<T>> typedCache =
                (Cache<IterableCacheKey, CachingStateIterable.Blocks<T>>) cache;
        return new CachingStateIterable<>(typedCache, beamFnStateClient, stateRequestForFirstChunk, valueCoder);
    }
}
