package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.functions;
import org.apache.spark.storage.StorageLevel;

/* loaded from: input_file:org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.class */
/**
 * Batch translator for {@link ParDo.MultiOutput}.
 *
 * <p>If at most one output of the ParDo is consumed, the translation is registered as an
 * {@link UnresolvedParDo}, which can later be fused with following ParDos and resolved into a
 * single {@code mapPartitions} call. If several outputs are consumed, the DoFn is evaluated once
 * into a persisted dataset with one column per output tag, which is then split into one dataset
 * per tag.
 */
class ParDoTranslatorBatch<InputT, OutputT> extends TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>> {

    /**
     * Deferred translation of a ParDo whose main output is the only consumed output.
     *
     * <p>It can be fused with a directly following {@code UnresolvedParDo}. It is eventually
     * resolved into one {@code mapPartitions} call over the input dataset.
     *
     * @param <InT> element type of the input collection
     * @param <T> element type produced by the (possibly fused) DoFn chain
     */
    public static class UnresolvedParDo<InT, T> implements PipelineTranslator.UnresolvedTranslation<InT, T> {
        private final PCollection<InT> input;
        private final DoFnRunnerFactory<InT, T> doFnFact;
        // Evaluated lazily in resolve(); after fusion this is the encoder of the last ParDo in the chain.
        private final Supplier<Encoder<WindowedValue<T>>> encoder;

        UnresolvedParDo(PCollection<InT> input, DoFnRunnerFactory<InT, T> doFnFact, Supplier<Encoder<WindowedValue<T>>> encoder) {
            this.input = input;
            this.doFnFact = doFnFact;
            this.encoder = encoder;
        }

        @Override
        public PCollection<InT> getInput() {
            return input;
        }

        /**
         * Fuses this ParDo with the directly following one.
         *
         * <p>The result keeps this translation's input, chains both runner factories and uses the
         * encoder of {@code next}.
         *
         * @throws ClassCastException if {@code next} is not an {@code UnresolvedParDo}
         */
        @Override
        public <T2> PipelineTranslator.UnresolvedTranslation<InT, T2> fuse(PipelineTranslator.UnresolvedTranslation<T, T2> next) {
            // NOTE(review): assumes the pipeline translator only fuses ParDos with each other.
            UnresolvedParDo<T, T2> nextParDo = (UnresolvedParDo<T, T2>) next;
            return new UnresolvedParDo<>(input, doFnFact.fuse(nextParDo.doFnFact), nextParDo.encoder);
        }

        /** Applies the (possibly fused) DoFn chain to every partition of {@code dataset}. */
        @Override
        public Dataset<WindowedValue<T>> resolve(Supplier<PipelineOptions> options, Dataset<WindowedValue<InT>> dataset) {
            MetricsAccumulator metrics = MetricsAccumulator.getInstance(dataset.sparkSession());
            return dataset.mapPartitions(DoFnPartitionIteratorFactory.singleOutput(options, metrics, doFnFact), encoder.get());
        }
    }

    /**
     * Creates the translator.
     *
     * <p>{@code 0.0f} is forwarded unchanged to the {@link TransformTranslator} constructor
     * (presumably a complexity/cost factor; confirm in TransformTranslator).
     */
    public ParDoTranslatorBatch() {
        super(0.0f);
    }

    /**
     * Checks that the DoFn only uses features this translator supports.
     *
     * <p>Side-input materializations are validated by {@link SparkSideInputReader}.
     *
     * @return always {@code true}; unsupported features fail fast instead of returning {@code false}
     * @throws IllegalStateException if the DoFn is splittable, uses state or timers, declares an
     *     {@code onWindowExpiration} method or requires time-sorted input
     */
    @Override
    public boolean canTranslate(ParDo.MultiOutput<InputT, OutputT> transform) {
        DoFn<InputT, OutputT> doFn = transform.getFn();
        DoFnSignature signature = DoFnSignatures.signatureForDoFn(doFn);
        Preconditions.checkState(!signature.processElement().isSplittable(), "Not expected to directly translate splittable DoFn, should have been overridden: %s", doFn);
        Preconditions.checkState(!signature.usesState() && !signature.usesTimers(), "States and timers are not supported for the moment.");
        Preconditions.checkState(signature.onWindowExpiration() == null, "onWindowExpiration is not supported: %s", doFn);
        Preconditions.checkState(!signature.processElement().requiresTimeSortedInput(), "@RequiresTimeSortedInput is not supported for the moment");
        SparkSideInputReader.validateMaterializations(transform.getSideInputs().values());
        return true;
    }

    /**
     * Translates the ParDo.
     *
     * <p>If at most one output is consumed, an {@link UnresolvedParDo} is registered for the main
     * output. Otherwise every consumed output is materialized from one shared, persisted dataset.
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) throws IOException {
        PCollection<InputT> input = (PCollection<InputT>) context.getInput();
        SideInputReader sideInputReader = createSideInputReader(transform.getSideInputs().values(), context);
        MetricsAccumulator metrics = MetricsAccumulator.getInstance(context.getSparkSession());
        TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag();

        // Drop unconsumed outputs so the persisting multi-output path is only taken when needed.
        Map<TupleTag<?>, PCollection<?>> outputs = skipUnconsumedOutputs(context.getOutputs(), mainOutputTag, transform.getAdditionalOutputTags(), context);

        if (outputs.size() <= 1) {
            PCollection<OutputT> output = context.getOutput(mainOutputTag);
            // NOTE(review): presumably tells the runner to drop elements emitted to the (unconsumed)
            // additional outputs; confirm against DoFnRunnerFactory.simple.
            boolean hasAdditionalOutputs = context.getOutputs().size() > 1;
            DoFnRunnerFactory<InputT, OutputT> doFnFactory = DoFnRunnerFactory.simple(context.getCurrentTransform(), input, sideInputReader, hasAdditionalOutputs);
            context.putUnresolved(output, new UnresolvedParDo<>(input, doFnFactory, () -> context.windowedEncoder(output.getCoder())));
            return;
        }

        // Several consumed outputs: give each tag its own column index.
        Map<String, Integer> tagColumns = tagsColumnIndex(outputs.keySet());
        List<Encoder<WindowedValue<Object>>> encoders = createEncoders(outputs, tagColumns, context);
        DoFnPartitionIteratorFactory doFnMapper = DoFnPartitionIteratorFactory.multiOutput(context.getOptionsSupplier(), metrics, DoFnRunnerFactory.simple(context.getCurrentTransform(), input, sideInputReader, false), tagColumns);
        SparkCommonPipelineOptions sparkOptions = context.getOptions().as(SparkCommonPipelineOptions.class);
        StorageLevel storageLevel = StorageLevel.fromString(sparkOptions.getStorageLevel());

        // Evaluate the DoFn once and persist the combined result, because it is read once per tag below.
        Dataset allTags = context.getDataset(input).mapPartitions(doFnMapper, EncoderHelpers.oneOfEncoder(encoders));
        allTags.persist(storageLevel);

        for (TupleTag<?> tag : outputs.keySet()) {
            int columnIdx = org.apache.beam.sdk.util.Preconditions.checkStateNotNull(tagColumns.get(tag.getId()), "Unknown tag");
            TypedColumn column = functions.col(Integer.toString(columnIdx)).as(encoders.get(columnIdx));
            // Keep only the rows holding a value for this tag and project its column.
            // NOTE(review): the trailing `false` presumably disables caching of the per-tag dataset,
            // since the combined dataset is already persisted; confirm in Context.putDataset.
            context.putDataset(context.getOutput(tag), allTags.filter(column.isNotNull()).select(column), false);
        }
    }

    /**
     * Returns the outputs that actually need to be materialized.
     *
     * <p>The main output is always kept. Additional outputs are kept only when
     * {@code context.isLeaf} is false for them.
     *
     * @param outputs all outputs of the transform, keyed by tag
     * @param mainTag tag of the main output
     * @param additionalTags tags of the additional outputs
     * @param context translation context used to check whether an output is a leaf
     * @return {@code outputs} itself when nothing is dropped in the one- and two-output cases,
     *     otherwise a new map
     */
    private Map<TupleTag<?>, PCollection<?>> skipUnconsumedOutputs(Map<TupleTag<?>, PCollection<?>> outputs, TupleTag<?> mainTag, TupleTagList additionalTags, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) {
        switch (outputs.size()) {
            case 1:
                // Only the main output exists.
                return outputs;
            case 2: {
                // Main output plus exactly one additional output.
                PCollection<?> additional = org.apache.beam.sdk.util.Preconditions.checkStateNotNull(outputs.get(additionalTags.get(0)));
                if (!context.isLeaf(additional)) {
                    return outputs;
                }
                PCollection<?> main = org.apache.beam.sdk.util.Preconditions.checkStateNotNull(outputs.get(mainTag));
                return Collections.<TupleTag<?>, PCollection<?>>singletonMap(mainTag, main);
            }
            default: {
                Map<TupleTag<?>, PCollection<?>> consumed = Maps.newHashMapWithExpectedSize(outputs.size());
                for (Map.Entry<TupleTag<?>, PCollection<?>> entry : outputs.entrySet()) {
                    if (entry.getKey().equals(mainTag) || !context.isLeaf(entry.getValue())) {
                        consumed.put(entry.getKey(), entry.getValue());
                    }
                }
                return consumed;
            }
        }
    }

    /**
     * Assigns each distinct tag id a dense column index {@code 0..n-1}, in the iteration order of
     * {@code tags}.
     */
    private Map<String, Integer> tagsColumnIndex(Collection<TupleTag<?>> tags) {
        Map<String, Integer> columns = Maps.newHashMapWithExpectedSize(tags.size());
        for (TupleTag<?> tag : tags) {
            // size() is the next free index. putIfAbsent keeps indices dense even if an id repeats.
            columns.putIfAbsent(tag.getId(), columns.size());
        }
        return columns;
    }

    /**
     * Builds the list of windowed encoders in which position {@code i} holds the encoder for the
     * output assigned to column {@code i} in {@code columns}.
     *
     * @throws IllegalStateException if an output's tag id has no column index
     * @throws IndexOutOfBoundsException if a column index is not below {@code outputs.size()}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private List<Encoder<WindowedValue<Object>>> createEncoders(Map<TupleTag<?>, PCollection<?>> outputs, Map<String, Integer> columns, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) {
        // Pre-fill with placeholders and set by index, so the result does not depend on entries being
        // visited in ascending column order. ArrayList.add(index, e) would throw or shift elements.
        List<Encoder<WindowedValue<Object>>> encoders = new ArrayList<>(Collections.<Encoder<WindowedValue<Object>>>nCopies(outputs.size(), null));
        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : outputs.entrySet()) {
            int column = org.apache.beam.sdk.util.Preconditions.checkStateNotNull(columns.get(entry.getKey().getId()));
            encoders.set(column, (Encoder) context.windowedEncoder(entry.getValue().getCoder()));
        }
        return encoders;
    }

    /**
     * Creates a side-input reader for the given views.
     *
     * <p>Each view is registered under its tag id with the value from
     * {@code context.getSideInputBroadcast}. An empty reader is returned when there are no views.
     *
     * @throws IllegalStateException if a view has no backing {@code PCollection}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T> SideInputReader createSideInputReader(Collection<PCollectionView<?>> views, TransformTranslator<PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>>.Context context) {
        if (views.isEmpty()) {
            return SparkSideInputReader.empty();
        }
        // Raw map: the value type is whatever getSideInputBroadcast returns for each view.
        Map broadcasts = Maps.newHashMapWithExpectedSize(views.size());
        for (PCollectionView<?> view : views) {
            PCollection<T> sideInput = (PCollection<T>) org.apache.beam.sdk.util.Preconditions.checkStateNotNull(view.getPCollection());
            broadcasts.put(view.getTagInternal().getId(), context.getSideInputBroadcast(sideInput, SideInputValues.loader(sideInput)));
        }
        return SparkSideInputReader.create(broadcasts);
    }
}
