package org.apache.beam.fn.harness.state;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.transforms.Materializations;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;

/* loaded from: input_file:org/apache/beam/fn/harness/state/MultimapSideInput.class */
public class MultimapSideInput<K, V> implements Materializations.MultimapView<K, V> {
    private static final int BULK_READ_SIZE = 100;
    private final Cache<?, ?> cache;
    private final BeamFnStateClient beamFnStateClient;
    private final BeamFnApi.StateRequest keysRequest;
    private final Coder<K> keyCoder;
    private final Coder<V> valueCoder;
    private volatile Function<ByteString, Iterable<V>> bulkReadResult;
    private final boolean useBulkRead;

    public MultimapSideInput(Cache<?, ?> cache, BeamFnStateClient beamFnStateClient, String str, BeamFnApi.StateKey stateKey, Coder<K> coder, Coder<V> coder2, boolean z) {
        Preconditions.checkArgument(stateKey.hasMultimapKeysSideInput(), "Expected MultimapKeysSideInput StateKey but received %s.", stateKey);
        this.cache = cache;
        this.beamFnStateClient = beamFnStateClient;
        this.keysRequest = BeamFnApi.StateRequest.newBuilder().setInstructionId(str).setStateKey(stateKey).build();
        this.keyCoder = coder;
        this.valueCoder = coder2;
        this.useBulkRead = z;
    }

    public Iterable<K> get() {
        return (Iterable<K>) StateFetchingIterators.readAllAndDecodeStartingFrom(this.cache, this.beamFnStateClient, this.keysRequest, this.keyCoder);
    }

    /* JADX WARN: Multi-variable type inference failed */
    public Iterable<V> get(K k) {
        ByteString encodeKey = encodeKey(k);
        if (this.useBulkRead) {
            if (this.bulkReadResult == null) {
                synchronized (this) {
                    if (this.bulkReadResult == null) {
                        HashMap hashMap = new HashMap();
                        try {
                            PrefetchableIterator it = StateFetchingIterators.readAllAndDecodeStartingFrom(Caches.noop(), this.beamFnStateClient, this.keysRequest.toBuilder().setStateKey(BeamFnApi.StateKey.newBuilder().setMultimapKeysValuesSideInput(BeamFnApi.StateKey.MultimapKeysValuesSideInput.newBuilder().setTransformId(this.keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId()).setSideInputId(this.keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()).setWindow(this.keysRequest.getStateKey().getMultimapKeysSideInput().getWindow())).build()).build(), KvCoder.of(this.keyCoder, IterableCoder.of(this.valueCoder))).iterator();
                            while (hashMap.size() < BULK_READ_SIZE && it.hasNext()) {
                                KV kv = (KV) it.next();
                                hashMap.put(encodeKey(kv.getKey()), (Iterable) kv.getValue());
                            }
                            if (it.hasNext()) {
                                Objects.requireNonNull(hashMap);
                                this.bulkReadResult = (v1) -> {
                                    return r1.get(v1);
                                };
                            } else {
                                this.bulkReadResult = byteString -> {
                                    Iterable iterable = (Iterable) hashMap.get(byteString);
                                    return iterable == null ? Collections.emptyList() : iterable;
                                };
                            }
                        } catch (Exception e) {
                            Objects.requireNonNull(hashMap);
                            this.bulkReadResult = (v1) -> {
                                return r1.get(v1);
                            };
                        }
                    }
                }
            }
            Iterable<V> apply = this.bulkReadResult.apply(encodeKey);
            if (apply != null) {
                return apply;
            }
        }
        return (Iterable<V>) StateFetchingIterators.readAllAndDecodeStartingFrom(Caches.subCache(this.cache, "ValuesForKey", encodeKey), this.beamFnStateClient, this.keysRequest.toBuilder().setStateKey(BeamFnApi.StateKey.newBuilder().setMultimapSideInput(BeamFnApi.StateKey.MultimapSideInput.newBuilder().setTransformId(this.keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId()).setSideInputId(this.keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()).setWindow(this.keysRequest.getStateKey().getMultimapKeysSideInput().getWindow()).setKey(encodeKey)).build()).build(), this.valueCoder);
    }

    private ByteString encodeKey(K k) {
        ByteStringOutputStream byteStringOutputStream = new ByteStringOutputStream();
        try {
            this.keyCoder.encode(k, byteStringOutputStream);
            return byteStringOutputStream.toByteString();
        } catch (IOException e) {
            throw new IllegalStateException(String.format("Failed to encode key %s for side input id %s.", k, this.keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId()), e);
        }
    }
}
