package org.apache.beam.runners.spark.translation;

import com.google.auto.service.AutoService;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.construction.NativeTransforms;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.sdk.util.construction.graph.PipelineNode;
import org.apache.beam.sdk.util.construction.graph.QueryablePipeline;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.BiMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;

/**
 * Translates a portable (Runner API) pipeline into Spark operations over bounded datasets.
 *
 * <p>Translation is table-driven: {@link #knownUrns()} reports the transform URNs handled here, and
 * {@link #translate} walks the pipeline in topological order, dispatching each node to the
 * translator registered for its URN. A node whose URN has no registered translator fails with
 * {@link IllegalArgumentException}.
 */
public class SparkBatchPortablePipelineTranslator implements SparkPortablePipelineTranslator<SparkTranslationContext> {

    private static final String IMPULSE_URN = "beam:transform:impulse:v1";
    private static final String GROUP_BY_KEY_URN = "beam:transform:group_by_key:v1";
    private static final String EXECUTABLE_STAGE_URN = "beam:runner:executable_stage:v1";
    private static final String FLATTEN_URN = "beam:transform:flatten:v1";
    private static final String RESHUFFLE_URN = "beam:transform:reshuffle:v1";

    /** Maps each supported transform URN to the static method that translates it. */
    private final ImmutableMap<String, PTransformTranslator> urnToTransformTranslator;

    /**
     * Reports the Reshuffle transform as native to this runner. Discovered through {@link AutoService}
     * as a {@link NativeTransforms.IsNativeTransform} implementation.
     */
    @AutoService(NativeTransforms.IsNativeTransform.class)
    public static class IsSparkNativeTransform implements NativeTransforms.IsNativeTransform {
        @Override
        public boolean test(RunnerApi.PTransform pTransform) {
            // urnForTransformOrNull may return null; invoking equals on the constant stays null-safe.
            return RESHUFFLE_URN.equals(PTransformTranslation.urnForTransformOrNull(pTransform));
        }
    }

    /** Translates a single pipeline node into datasets registered on the translation context. */
    @FunctionalInterface
    interface PTransformTranslator {
        void translate(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context);
    }

    /** Returns the set of transform URNs this translator can handle. */
    @Override
    public Set<String> knownUrns() {
        return urnToTransformTranslator.keySet();
    }

    /** Builds the URN-to-translator dispatch table. */
    public SparkBatchPortablePipelineTranslator() {
        // A typed builder is required: method references need a functional-interface target type.
        ImmutableMap.Builder<String, PTransformTranslator> translators = ImmutableMap.builder();
        translators.put(IMPULSE_URN, SparkBatchPortablePipelineTranslator::translateImpulse);
        translators.put(GROUP_BY_KEY_URN, SparkBatchPortablePipelineTranslator::translateGroupByKey);
        translators.put(EXECUTABLE_STAGE_URN, SparkBatchPortablePipelineTranslator::translateExecutableStage);
        translators.put(FLATTEN_URN, SparkBatchPortablePipelineTranslator::translateFlatten);
        translators.put(RESHUFFLE_URN, SparkBatchPortablePipelineTranslator::translateReshuffle);
        this.urnToTransformTranslator = translators.build();
    }

    /**
     * Translates every transform of {@code pipeline} into {@code context}.
     *
     * <p>A first pass records, per PCollection, how many times it is consumed and registers each
     * output's windowed-value coder. A second pass dispatches each node to its translator.
     *
     * @throws IllegalArgumentException if a transform's URN has no registered translator
     */
    @Override
    public void translate(RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        QueryablePipeline queryable = QueryablePipeline.forTransforms(pipeline.getRootTransformIdsList(), pipeline.getComponents());

        // Pass 1: consumption counts and coders. The context presumably uses the counts to decide
        // which datasets are worth caching (Dataset exposes cache()) -- confirm in SparkTranslationContext.
        for (PipelineNode.PTransformNode node : queryable.getTopologicallyOrderedTransforms()) {
            RunnerApi.PTransform transform = node.getTransform();
            for (String inputId : transform.getInputsMap().values()) {
                context.incrementConsumptionCountBy(inputId, 1);
            }
            if (EXECUTABLE_STAGE_URN.equals(transform.getSpec().getUrn())) {
                // Each stage output is extracted from one shared intermediate RDD, which is therefore
                // consumed once per output.
                context.incrementConsumptionCountBy(
                        PipelineTranslatorUtils.getExecutableStageIntermediateId(node), transform.getOutputsMap().size());
            }
            for (String outputId : transform.getOutputsMap().values()) {
                context.putCoder(outputId, PipelineTranslatorUtils.getWindowedValueCoder(outputId, pipeline.getComponents()));
            }
        }

        // Pass 2: translation, in topological order so inputs are pushed before they are popped.
        for (PipelineNode.PTransformNode node : queryable.getTopologicallyOrderedTransforms()) {
            urnToTransformTranslator
                    .getOrDefault(node.getTransform().getSpec().getUrn(), SparkBatchPortablePipelineTranslator::urnNotFound)
                    .translate(node, pipeline, context);
        }
    }

    /** Fallback translator for URNs with no registered translator; always throws. */
    private static void urnNotFound(PipelineNode.PTransformNode pTransformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext sparkTranslationContext) {
        throw new IllegalArgumentException(String.format("Transform %s has unknown URN %s", pTransformNode.getId(), pTransformNode.getTransform().getSpec().getUrn()));
    }

    /** Translates Impulse into a dataset holding a single empty byte array. */
    private static void translateImpulse(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        BoundedDataset<byte[]> impulse = new BoundedDataset<>(Collections.singletonList(new byte[0]), context.getSparkContext(), ByteArrayCoder.of());
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), impulse);
    }

    /**
     * Translates GroupByKey, choosing among three grouping strategies based on the input's windowing:
     * key-only grouping in the global window, a direct group-by-key-and-window for eligible
     * non-merging windows, or a general group-by-key followed by buffered GroupAlsoByWindow.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <K, V> void translateGroupByKey(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        RunnerApi.Components components = pipeline.getComponents();
        String inputId = PipelineTranslatorUtils.getInputId(transformNode);
        JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) context.popDataset(inputId)).getRDD();

        WindowedValue.WindowedValueCoder<KV<K, V>> inputCoder = PipelineTranslatorUtils.getWindowedValueCoder(inputId, components);
        KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
        Coder<K> keyCoder = inputKvCoder.getKeyCoder();
        Coder<V> valueCoder = inputKvCoder.getValueCoder();
        WindowingStrategy windowingStrategy = PipelineTranslatorUtils.getWindowingStrategy(inputId, components);
        Partitioner partitioner = getPartitioner(context);

        JavaRDD<WindowedValue<KV<K, Iterable<V>>>> grouped;
        if (windowingStrategy.getWindowFn().equals(new GlobalWindows())
                && windowingStrategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW)) {
            grouped = GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(inputRdd, keyCoder, valueCoder, partitioner);
        } else if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
            grouped = GroupNonMergingWindowsFunctions.groupByKeyAndWindow(inputRdd, keyCoder, valueCoder, windowingStrategy, partitioner);
        } else {
            // General path: group on the key only, then assign windows with in-memory buffered state.
            WindowedValue.FullWindowedValueCoder<V> windowedValueCoder =
                    WindowedValue.FullWindowedValueCoder.of(valueCoder, windowingStrategy.getWindowFn().windowCoder());
            grouped = GroupCombineFunctions.groupByKeyOnly(inputRdd, keyCoder, windowedValueCoder, partitioner)
                    .flatMap(new SparkGroupAlsoByWindowViaOutputBufferFn(
                            windowingStrategy,
                            new TranslationUtils.InMemoryStateInternalsFactory(),
                            SystemReduceFn.buffering(valueCoder),
                            context.serializablePipelineOptions));
        }
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), new BoundedDataset<>(grouped));
    }

    /**
     * Translates an executable stage into one intermediate RDD plus one extracted dataset per output.
     *
     * <p>Stages with user state or timers first group their KV input by key; all other stages run
     * per partition. A stage with no outputs gets a no-op sink dataset named {@code
     * EmptyOutputSink_<n>}.
     *
     * @throws IllegalStateException if a stateful stage's element coder is not a {@link KvCoder}
     * @throws RuntimeException if the stage payload cannot be parsed
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <InputT, OutputT, SideInputT> void translateExecutableStage(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        RunnerApi.ExecutableStagePayload stagePayload;
        try {
            stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transformNode.getTransform().getSpec().getPayload());
        } catch (IOException e) {
            throw new RuntimeException("Failed to parse ExecutableStagePayload of transform " + transformNode.getId(), e);
        }

        String inputId = stagePayload.getInput();
        Dataset inputDataset = context.popDataset(inputId);
        Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
        BiMap<String, Integer> outputMap = PipelineTranslatorUtils.createOutputMap(outputs.values());
        RunnerApi.Components components = pipeline.getComponents();
        Coder windowCoder = PipelineTranslatorUtils.getWindowingStrategy(inputId, components).getWindowFn().windowCoder();
        ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> sideInputs =
                broadcastSideInputs(stagePayload, context);

        // Element type is the stage's union output value (RawUnionValue); left raw here.
        final JavaRDD staged;
        if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
            Coder<WindowedValue<InputT>> windowedInputCoder = PipelineTranslatorUtils.instantiateCoder(inputId, components);
            Coder<?> valueCoder = ((WindowedValue.WindowedValueCoder<InputT>) windowedInputCoder).getValueCoder();
            // State and timers are scoped per key, so the input must be keyed.
            if (!(valueCoder instanceof KvCoder)) {
                throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputId, valueCoder.getClass().getSimpleName()));
            }
            KvCoder kvCoder = (KvCoder) valueCoder;
            WindowedValue.WindowedValueCoder groupedValueCoder = WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), windowCoder);
            JavaPairRDD groupedByKey = groupByKeyPair(inputDataset, kvCoder.getKeyCoder(), groupedValueCoder);
            staged = groupedByKey.flatMap(newStageFunction(context, stagePayload, outputMap, sideInputs, windowCoder).forPair());
        } else {
            JavaRDD inputRdd = ((BoundedDataset) inputDataset).getRDD();
            staged = inputRdd.mapPartitions(newStageFunction(context, stagePayload, outputMap, sideInputs, windowCoder));
        }

        // Register the intermediate RDD and pop it right away. Presumably this lets the context track
        // it (e.g. for caching) without leaving it available to downstream transforms -- confirm.
        String intermediateId = PipelineTranslatorUtils.getExecutableStageIntermediateId(transformNode);
        context.pushDataset(intermediateId, new Dataset() {
            @Override
            public void cache(String storageLevel, Coder<?> coder) {
                staged.persist(StorageLevel.fromString(storageLevel));
            }

            @Override
            public void action() {
                // A no-op foreach forces Spark to evaluate the RDD.
                staged.foreach(TranslationUtils.emptyVoidFunction());
            }

            @Override
            public void setName(String name) {
                staged.setName(name);
            }
        });
        context.popDataset(intermediateId);

        for (String outputId : outputs.values()) {
            int unionTag = outputMap.get(outputId);
            context.pushDataset(outputId, new BoundedDataset(staged.flatMap(new SparkExecutableStageExtractionFunction(unionTag))));
        }

        if (outputs.isEmpty()) {
            // With no runner-visible outputs, nothing downstream would trigger this stage, so push a
            // dataset that discards every element.
            String sinkName = String.format("EmptyOutputSink_%d", context.nextSinkId());
            context.pushDataset(sinkName, new BoundedDataset(staged.flatMap(rawUnionValue -> Collections.emptyIterator())));
        }
    }

    /**
     * Creates the {@link SparkExecutableStageFunction} for {@code stagePayload}. Options and job info
     * come from {@code context}; the context factory and metrics accumulator come from their getInstance().
     */
    @SuppressWarnings("rawtypes")
    private static SparkExecutableStageFunction newStageFunction(SparkTranslationContext context, RunnerApi.ExecutableStagePayload stagePayload, BiMap<String, Integer> outputMap, ImmutableMap sideInputs, Coder windowCoder) {
        return new SparkExecutableStageFunction(context.getSerializableOptions(), stagePayload, context.jobInfo, outputMap, SparkExecutableStageContextFactory.getInstance(), sideInputs, MetricsAccumulator.getInstance(), windowCoder);
    }

    /** Groups a bounded KV dataset by encoded key, keeping each value's windowing information. */
    @SuppressWarnings("unchecked")
    private static <K, V> JavaPairRDD<ByteArray, Iterable<WindowedValue<KV<K, V>>>> groupByKeyPair(Dataset dataset, Coder<K> coder, WindowedValue.WindowedValueCoder<V> windowedValueCoder) {
        JavaRDD<WindowedValue<KV<K, V>>> rdd = ((BoundedDataset<KV<K, V>>) dataset).getRDD();
        return GroupCombineFunctions.groupByKeyPair(rdd, coder, windowedValueCoder);
    }

    /**
     * Broadcasts every PCollection read as a side input by the stage, keyed by PCollection id. A
     * PCollection referenced by several side inputs is broadcast only once.
     */
    private static <SideInputT> ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> broadcastSideInputs(RunnerApi.ExecutableStagePayload executableStagePayload, SparkTranslationContext sparkTranslationContext) {
        RunnerApi.Components stageComponents = executableStagePayload.getComponents();
        Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>> broadcasts = new HashMap<>();
        for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId : executableStagePayload.getSideInputsList()) {
            String collectionId = stageComponents
                    .getTransformsOrThrow(sideInputId.getTransformId())
                    .getInputsOrThrow(sideInputId.getLocalName());
            broadcasts.computeIfAbsent(collectionId, id -> broadcastSideInput(id, stageComponents, sparkTranslationContext));
        }
        return ImmutableMap.copyOf(broadcasts);
    }

    /**
     * Pops the dataset for PCollection {@code str}, encodes it with its windowed-value coder and
     * broadcasts the bytes.
     *
     * @return the broadcast bytes paired with the coder needed to decode them
     */
    @SuppressWarnings("unchecked")
    private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<T>> broadcastSideInput(String str, RunnerApi.Components components, SparkTranslationContext sparkTranslationContext) {
        BoundedDataset<T> dataset = (BoundedDataset<T>) sparkTranslationContext.popDataset(str);
        WindowedValue.WindowedValueCoder<T> coder = PipelineTranslatorUtils.getWindowedValueCoder(str, components);
        List<byte[]> encoded = dataset.getBytes(coder);
        return new Tuple2<>(sparkTranslationContext.getSparkContext().broadcast(encoded), coder);
    }

    /** Translates Flatten into a union of its inputs; with no inputs the result is an empty RDD. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static <T> void translateFlatten(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        Map<String, String> inputsMap = transformNode.getTransform().getInputsMap();
        JavaRDD<WindowedValue<T>> unionRdd;
        if (inputsMap.isEmpty()) {
            unionRdd = context.getSparkContext().emptyRDD();
        } else {
            JavaRDD<WindowedValue<T>>[] inputRdds = new JavaRDD[inputsMap.size()];
            int index = 0;
            for (String inputId : inputsMap.values()) {
                inputRdds[index++] = ((BoundedDataset<T>) context.popDataset(inputId)).getRDD();
            }
            unionRdd = context.getSparkContext().union(inputRdds);
        }
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), new BoundedDataset<>(unionRdd));
    }

    /** Translates Reshuffle by redistributing the input RDD with its windowed-value coder. */
    @SuppressWarnings("unchecked")
    private static <T> void translateReshuffle(PipelineNode.PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
        String inputId = PipelineTranslatorUtils.getInputId(transformNode);
        JavaRDD<WindowedValue<T>> inputRdd = ((BoundedDataset<T>) context.popDataset(inputId)).getRDD();
        WindowedValue.WindowedValueCoder<T> coder = PipelineTranslatorUtils.getWindowedValueCoder(inputId, pipeline.getComponents());
        context.pushDataset(PipelineTranslatorUtils.getOutputId(transformNode), new BoundedDataset<>(GroupCombineFunctions.reshuffle(inputRdd, coder)));
    }

    /**
     * Chooses the partitioner for GroupByKey.
     *
     * @return {@code null} when the pipeline's bundle size option is positive (the grouping functions
     *     then receive no explicit partitioner), otherwise a {@link HashPartitioner} sized to the Spark
     *     context's default parallelism
     */
    private static Partitioner getPartitioner(SparkTranslationContext sparkTranslationContext) {
        long bundleSize = sparkTranslationContext.serializablePipelineOptions.get().as(SparkPipelineOptions.class).getBundleSize();
        return bundleSize > 0
                ? null
                : new HashPartitioner(sparkTranslationContext.getSparkContext().defaultParallelism());
    }

    /** Creates the translation context consumed by {@link #translate}. */
    @Override
    public SparkTranslationContext createTranslationContext(JavaSparkContext javaSparkContext, SparkPipelineOptions sparkPipelineOptions, JobInfo jobInfo) {
        return new SparkTranslationContext(javaSparkContext, sparkPipelineOptions, jobInfo);
    }
}
