package org.apache.beam.runners.spark.translation.streaming;

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StatefulDoFnRunner;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator;
import org.apache.beam.runners.spark.stateful.SparkStateInternals;
import org.apache.beam.runners.spark.stateful.SparkTimerInternals;
import org.apache.beam.runners.spark.stateful.StateAndTimers;
import org.apache.beam.runners.spark.translation.DoFnRunnerWithMetrics;
import org.apache.beam.runners.spark.translation.SparkInputDataProcessor;
import org.apache.beam.runners.spark.translation.SparkProcessContext;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.runners.spark.util.CachedSideInputReader;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.runners.spark.util.SideInputBroadcast;
import org.apache.beam.runners.spark.util.SparkSideInputReader;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.spark.streaming.State;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Tuple2;
import scala.runtime.AbstractFunction3;

/* loaded from: input_file:org/apache/beam/runners/spark/translation/streaming/ParDoStateUpdateFn.class */
/**
 * Runs a stateful {@link DoFn} for a single key and a single serialized input element, persisting
 * the key's Beam state and timers in a Spark Streaming {@link State}.
 *
 * <p>The function has the (key, optional value, state) shape of a Spark Streaming state-update
 * function (presumably used with {@code mapWithState} — confirm at the call site). For each
 * invocation it does the following:
 *
 * <ol>
 *   <li>Restores the key's state and timers from {@code state}.
 *   <li>Decodes the input element and re-attaches the key.
 *   <li>Runs the element through a stateful, metrics-wrapped {@link DoFnRunner}.
 *   <li>Writes the resulting state and timers back.
 *   <li>Returns every output as a {@code (tag, encoded WindowedValue)} pair.
 * </ol>
 *
 * @param <KeyT> key type of the keyed input
 * @param <ValueT> value type of the keyed input
 * @param <InputT> {@link KV} element type consumed by the {@link DoFn}
 * @param <OutputT> main output type of the {@link DoFn}
 */
public class ParDoStateUpdateFn<KeyT, ValueT, InputT extends KV<KeyT, ValueT>, OutputT>
        extends AbstractFunction3<ByteArray, Option<byte[]>, State<StateAndTimers>, List<Tuple2<TupleTag<?>, byte[]>>>
        implements Serializable {
    private static final Logger LOG = LoggerFactory.getLogger(ParDoStateUpdateFn.class);
    private final MetricsContainerStepMapAccumulator metricsAccum;
    private final String stepName;
    private final DoFn<InputT, OutputT> doFn;
    private final Coder<KeyT> keyCoder;
    // Coder for the (un-keyed) windowed input value passed to apply().
    private final WindowedValue.FullWindowedValueCoder<ValueT> wvCoder;
    // Transient, so each deserialized copy of this function invokes @Setup on its DoFn once.
    private transient boolean wasSetupCalled;
    private final SerializablePipelineOptions options;
    private final TupleTag<?> mainOutputTag;
    private final List<TupleTag<?>> additionalOutputTags;
    private final Coder<InputT> inputCoder;
    private final Map<TupleTag<?>, Coder<?>> outputCoders;
    private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs;
    private final WindowingStrategy<?, ?> windowingStrategy;
    private final DoFnSchemaInformation doFnSchemaInformation;
    private final Map<String, PCollectionView<?>> sideInputMapping;
    private final Map<Integer, GlobalWatermarkHolder.SparkWatermarks> watermarks;
    private final List<Integer> sourceIds;
    // Encodes/decodes the timers stored alongside the state in Spark's State.
    private final TimerInternals.TimerDataCoderV2 timerDataCoder;

    /**
     * Creates the update function.
     *
     * <p>The {@code doFn} is cloned via {@link SerializableUtils#clone}. The timer coder is derived
     * from the window coder of {@code windowingStrategy}.
     */
    public ParDoStateUpdateFn(
            MetricsContainerStepMapAccumulator metricsAccum,
            String stepName,
            DoFn<InputT, OutputT> doFn,
            Coder<KeyT> keyCoder,
            WindowedValue.FullWindowedValueCoder<ValueT> wvCoder,
            SerializablePipelineOptions options,
            TupleTag<?> mainOutputTag,
            List<TupleTag<?>> additionalOutputTags,
            Coder<InputT> inputCoder,
            Map<TupleTag<?>, Coder<?>> outputCoders,
            Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
            WindowingStrategy<?, ?> windowingStrategy,
            DoFnSchemaInformation doFnSchemaInformation,
            Map<String, PCollectionView<?>> sideInputMapping,
            Map<Integer, GlobalWatermarkHolder.SparkWatermarks> watermarks,
            List<Integer> sourceIds) {
        this.metricsAccum = metricsAccum;
        this.stepName = stepName;
        this.doFn = SerializableUtils.clone(doFn);
        this.options = options;
        this.mainOutputTag = mainOutputTag;
        this.additionalOutputTags = additionalOutputTags;
        this.keyCoder = keyCoder;
        this.inputCoder = inputCoder;
        this.outputCoders = outputCoders;
        this.wvCoder = wvCoder;
        this.sideInputs = sideInputs;
        this.windowingStrategy = windowingStrategy;
        this.doFnSchemaInformation = doFnSchemaInformation;
        this.sideInputMapping = sideInputMapping;
        this.watermarks = watermarks;
        this.sourceIds = sourceIds;
        this.timerDataCoder =
                TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder());
    }

    /**
     * Processes one serialized input element for {@code serializedKey}.
     *
     * @param serializedKey the key, encoded with the key coder
     * @param serializedValue the input {@link WindowedValue}, encoded with {@code wvCoder}. When
     *     empty, nothing is processed, {@code state} is left untouched and an empty list is
     *     returned.
     * @param state Spark state holding this key's Beam state and pending timers. It is updated only
     *     after all outputs of this invocation have been produced.
     * @return the produced outputs as {@code (tag, bytes)} pairs. The bytes are the output value
     *     encoded with a {@link WindowedValue.FullWindowedValueCoder} of the tag's output coder and
     *     the window coder.
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public List<Tuple2<TupleTag<?>, byte[]>> apply(
            ByteArray serializedKey, Option<byte[]> serializedValue, State<StateAndTimers> state) {
        if (serializedValue.isEmpty()) {
            return Lists.newArrayList();
        }

        final SparkTimerInternals timerInternals =
                SparkTimerInternals.forStreamFromSources(this.sourceIds, this.watermarks);
        final KeyT key = CoderHelpers.fromByteArray(serializedKey.getValue(), this.keyCoder);

        // Restore this key's state and pending timers if Spark already holds state for it.
        final SparkStateInternals<KeyT> stateInternals;
        if (state.exists()) {
            final StateAndTimers previous = state.get();
            stateInternals = SparkStateInternals.forKeyAndState(key, previous.getState());
            timerInternals.addTimers(
                    SparkTimerInternals.deserializeTimers(previous.getTimers(), this.timerDataCoder));
        } else {
            stateInternals = SparkStateInternals.forKey(key);
        }

        final WindowedValue<ValueT> windowedValue =
                CoderHelpers.fromByteArray(serializedValue.get(), this.wvCoder);
        // Re-attach the key so the DoFn receives a KV element in the original window/timestamp.
        final WindowedValue<KV<KeyT, ValueT>> keyedValue =
                windowedValue.withValue(KV.of(key, windowedValue.getValue()));

        if (!this.wasSetupCalled) {
            DoFnInvokers.tryInvokeSetupFor(this.doFn, this.options.get());
            this.wasSetupCalled = true;
        }

        final SparkInputDataProcessor processor = SparkInputDataProcessor.createUnbounded();
        final StepContext stepContext =
                new StepContext() {
                    @Override
                    public StateInternals stateInternals() {
                        return stateInternals;
                    }

                    @Override
                    public TimerInternals timerInternals() {
                        return timerInternals;
                    }
                };

        final DoFnRunner<InputT, OutputT> simpleRunner =
                DoFnRunners.simpleRunner(
                        this.options.get(),
                        this.doFn,
                        CachedSideInputReader.of(new SparkSideInputReader(this.sideInputs)),
                        processor.getOutputManager(),
                        (TupleTag<OutputT>) this.mainOutputTag,
                        this.additionalOutputTags,
                        stepContext,
                        this.inputCoder,
                        this.outputCoders,
                        this.windowingStrategy,
                        this.doFnSchemaInformation,
                        this.sideInputMapping);

        final Coder windowCoder = this.windowingStrategy.getWindowFn().windowCoder();
        final DoFnRunner statefulRunner =
                DoFnRunners.defaultStatefulDoFnRunner(
                        this.doFn,
                        this.inputCoder,
                        simpleRunner,
                        stepContext,
                        this.windowingStrategy,
                        new StatefulDoFnRunner.TimeInternalsCleanupTimer(
                                timerInternals, this.windowingStrategy),
                        new StatefulDoFnRunner.StateInternalsStateCleaner(
                                this.doFn, stateInternals, windowCoder));

        final SparkProcessContext processContext =
                new SparkProcessContext(
                        this.stepName,
                        this.doFn,
                        new DoFnRunnerWithMetrics(this.stepName, statefulRunner, this.metricsAccum),
                        key,
                        timerInternals.getTimers().iterator());

        final Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> outputIterator =
                processor.createOutputIterator(
                        Lists.newArrayList(keyedValue).iterator(), processContext);

        // Drain the output iterator BEFORE persisting state. The processor may run the element
        // (and fire timers) only while its iterator is consumed (NOTE(review): confirm in
        // SparkInputDataProcessor). serializeTimers snapshots the current timer set, so
        // persisting first could drop timers set or deleted during this call. Draining first is
        // correct whether or not the iterator is lazy.
        final List<Tuple2<TupleTag<?>, WindowedValue<?>>> outputs =
                Lists.newArrayList(outputIterator);

        state.update(
                StateAndTimers.of(
                        stateInternals.getState(),
                        SparkTimerInternals.serializeTimers(
                                timerInternals.getTimers(), this.timerDataCoder)));

        return outputs.stream()
                .map(output -> serializeOutput(output, windowCoder))
                .collect(Collectors.toList());
    }

    /**
     * Encodes one DoFn output as a {@code (tag, bytes)} pair.
     *
     * <p>The coder is a {@link WindowedValue.FullWindowedValueCoder} built from the output coder
     * registered for the output's tag and {@code windowCoder}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Tuple2<TupleTag<?>, byte[]> serializeOutput(
            Tuple2<TupleTag<?>, WindowedValue<?>> output, Coder windowCoder) {
        final TupleTag<?> tag = output._1();
        final Coder valueCoder = this.outputCoders.get(tag);
        final Coder fullCoder = WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder);
        final byte[] encoded = CoderHelpers.toByteArray(output._2(), fullCoder);
        return new Tuple2<>(tag, encoded);
    }
}
