package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils;
import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.catalyst.expressions.CreateArray;
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.immutable.List;

/* loaded from: input_file:org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.class */
/**
 * Translates Beam's {@link GroupByKey} into Spark {@link Dataset} operations for batch pipelines.
 *
 * <p>{@link #translate} picks the first applicable strategy:
 *
 * <ol>
 *   <li>{@code collect_list} per key in the global window, if collecting values in memory is
 *       allowed ({@code preferGroupByKeyToHandleHugeValues} is false).
 *   <li>{@code groupByKey}/{@code mapGroups} per key in the global window, producing a value
 *       {@link Iterable} that can be traversed only once.
 *   <li>{@code collect_list} per (key, window), if collecting is allowed and elements are assigned
 *       to a single window or the transform declares few keys.
 *   <li>{@code groupByKey}/{@code mapGroups} per (window, key), under the same window/key
 *       condition, again producing a once-only {@link Iterable}.
 *   <li>Otherwise, group per key and re-group by window with {@link
 *       GroupAlsoByWindowViaOutputBufferFn} backed by in-memory state.
 * </ol>
 *
 * <p>Whether strategies 1-4 apply is additionally decided by {@code GroupByKeyHelpers}.
 */
class GroupByKeyTranslatorBatch<K, V>
    extends TransformTranslator<
        PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {

    /** Literal holding the binary-encoded {@link PaneInfo#NO_FIRING} pane used for all outputs. */
    // Must be declared before GLOBAL_WINDOW_DETAILS: windowDetails() reads it during static init.
    private static final Expression PANE_NO_FIRING =
        lit(CoderHelpers.toByteArray(PaneInfo.NO_FIRING, PaneInfo.PaneInfoCoder.of()));

    /**
     * Struct fields ({@code windows}, {@code pane}) that place a value into the global window.
     *
     * <p>{@code windows} is an array column (see {@link #inSingleWindow}, which builds it with
     * {@link CreateArray}), so the literal is a one-element {@code byte[][]} holding an empty byte
     * array as the encoded global window.
     */
    private static final List<Expression> GLOBAL_WINDOW_DETAILS =
        windowDetails(lit(new byte[][] {ArrayUtils.EMPTY_BYTE_ARRAY}));

    /**
     * {@link StateInternalsFactory} that is also {@link Serializable}. A lambda or method reference
     * whose target type is this interface is serializable, so it can be captured by the function
     * handed to Spark's {@code flatMapGroups}.
     */
    private interface SerStateInternalsFactory<K> extends StateInternalsFactory<K>, Serializable {}

    GroupByKeyTranslatorBatch() {
        // 0.2f is forwarded to TransformTranslator; see that class for its meaning.
        super(0.2f);
    }

    /**
     * Groups the input {@code KV<K, V>} dataset by key (and window, where needed) and registers the
     * resulting {@code KV<K, Iterable<V>>} dataset as the transform's output.
     *
     * @param transform the {@link GroupByKey} being translated; {@code fewKeys()} influences the
     *     strategy choice
     * @param cxt translation context providing the input dataset, coders, encoders and options
     */
    @Override
    public void translate(GroupByKey<K, V> transform, Context cxt) {
        WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
        TimestampCombiner tsCombiner = windowing.getTimestampCombiner();

        Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());

        KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder();
        KvCoder<K, Iterable<V>> outputCoder =
            (KvCoder<K, Iterable<V>>) cxt.getOutput().getCoder();

        Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder);
        Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);

        // collect_list gathers all values of a group into one array; only use it when the user
        // did not ask to prefer the GroupByKey path for huge values.
        boolean useCollectList =
            !cxt.getOptions()
                .as(SparkCommonPipelineOptions.class)
                .getPreferGroupByKeyToHandleHugeValues();

        final Dataset<WindowedValue<KV<K, Iterable<V>>>> result;

        if (useCollectList && GroupByKeyHelpers.eligibleForGlobalGroupBy(windowing, false)) {
            // Strategy 1: global window, values per key collected with collect_list.
            result =
                input
                    .groupBy(functions.col("value.key").as("key"))
                    .agg(
                        functions.collect_list(functions.col("value.value")).as("values"),
                        timestampAggregator(tsCombiner))
                    .select(
                        inGlobalWindow(
                            keyValue(
                                functions.col("key").as(keyEnc),
                                functions.col("values").as(iterableEnc(valueEnc))),
                            windowTimestamp(tsCombiner)));

        } else if (GroupByKeyHelpers.eligibleForGlobalGroupBy(windowing, true)) {
            // Strategy 2: global window, values exposed through a once-only iterable instead of
            // being collected into an array first.
            result =
                input
                    .groupByKey(GroupByKeyHelpers.valueKey(), keyEnc)
                    .mapValues(GroupByKeyHelpers.valueValue(), valueEnc)
                    .mapGroups(
                        ScalaInterop.fun2((key, values) -> KV.of(key, iterableOnce(values))),
                        cxt.kvEncoderOf(outputCoder))
                    .map(
                        ScalaInterop.fun1(WindowedValue::valueInGlobalWindow),
                        cxt.windowedEncoder(outputCoder));

        } else if (useCollectList
            && GroupByKeyHelpers.eligibleForGroupByWindow(windowing, false)
            && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
            // Strategy 3: the window becomes part of the grouping key (one row per window after
            // explode); values per (key, window) collected with collect_list.
            result =
                input
                    .select(
                        functions.explode(functions.col("windows")).as("window"),
                        functions.col("value"),
                        functions.col("timestamp"))
                    .groupBy(functions.col("value.key").as("key"), functions.col("window"))
                    .agg(
                        functions.collect_list(functions.col("value.value")).as("values"),
                        timestampAggregator(tsCombiner))
                    .select(
                        inSingleWindow(
                            keyValue(
                                functions.col("key").as(keyEnc),
                                functions.col("values").as(iterableEnc(valueEnc))),
                            functions.col("window").as(cxt.windowEncoder()),
                            windowTimestamp(tsCombiner)));

        } else if (GroupByKeyHelpers.eligibleForGroupByWindow(windowing, true)
            && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
            // Strategy 4: group by (window, key) tuples; values exposed through a once-only
            // iterable.
            Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
                cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
            result =
                input
                    .flatMap(
                        GroupByKeyHelpers.explodeWindowedKey(GroupByKeyHelpers.valueValue()),
                        cxt.tupleEncoder(windowedKeyEnc, valueEnc))
                    .groupByKey(ScalaInterop.fun1(Tuple2::_1), windowedKeyEnc)
                    .mapValues(ScalaInterop.fun1(Tuple2::_2), valueEnc)
                    .mapGroups(
                        ScalaInterop.fun2(
                            (windowedKey, values) ->
                                GroupByKeyHelpers.windowedKV(windowedKey, iterableOnce(values))),
                        cxt.windowedEncoder(outputCoder));

        } else {
            // Strategy 5 (fallback): group by key only, then apply the windowing strategy with
            // an in-memory buffering reduce function.
            // Typed as SerStateInternalsFactory so the method reference is serializable; a plain
            // StateInternalsFactory target would not be.
            SerStateInternalsFactory<K> stateFactory = InMemoryStateInternals::forKey;
            result =
                input
                    .groupByKey(GroupByKeyHelpers.valueKey(), keyEnc)
                    .flatMapGroups(
                        new GroupAlsoByWindowViaOutputBufferFn<>(
                            windowing,
                            stateFactory,
                            SystemReduceFn.buffering(inputCoder.getValueCoder()),
                            cxt.getOptionsSupplier()),
                        cxt.windowedEncoder(outputCoder));
        }

        cxt.putDataset(cxt.getOutput(), result);
    }

    /**
     * Returns an encoder for the {@code values} column produced by {@code collect_list}.
     *
     * @param enc encoder of a single value
     * @return a collection encoder for {@code enc}, viewed as an {@code Iterable<V>} encoder
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Encoder<Iterable<V>> iterableEnc(Encoder<V> enc) {
        // Raw cast: the helper's declared element container type is not Iterable<V>.
        return (Encoder) EncoderHelpers.collectionEncoder(enc);
    }

    /**
     * Builds the extra aggregation column for the output timestamp.
     *
     * @param tsCombiner timestamp combiner of the windowing strategy
     * @return no columns for {@code END_OF_WINDOW}; otherwise {@code min} ({@code EARLIEST}) or
     *     {@code max} (any other combiner) of {@code timestamp}, aliased as {@code timestamp}
     */
    private static Column[] timestampAggregator(TimestampCombiner tsCombiner) {
        if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
            // Nothing to aggregate: windowTimestamp() emits a null literal for this combiner.
            return new Column[0];
        }
        Column timestamp = functions.col("timestamp");
        Column agg =
            tsCombiner.equals(TimestampCombiner.EARLIEST)
                ? functions.min(timestamp)
                : functions.max(timestamp);
        return new Column[] {agg.as("timestamp")};
    }

    /**
     * Expression for the output timestamp field.
     *
     * @param tsCombiner timestamp combiner of the windowing strategy
     * @return a null {@code LongType} literal for {@code END_OF_WINDOW}, otherwise the aggregated
     *     {@code timestamp} column
     */
    private static Expression windowTimestamp(TimestampCombiner tsCombiner) {
        if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
            // NOTE(review): presumably the null is replaced by the window end when the windowed
            // value is decoded - confirm in EncoderHelpers.windowedValueEncoder.
            return litNull(DataTypes.LongType);
        }
        return functions.col("timestamp").expr();
    }

    /**
     * Wraps a Scala {@link Iterator} as a Java {@link Iterable} without materializing it.
     *
     * <p>Every {@code iterator()} call returns a view of the same underlying iterator. Calling it
     * when that iterator is already empty (e.g. a second traversal after the first consumed it)
     * fails with {@link IllegalStateException}.
     */
    private static <T> Iterable<T> iterableOnce(Iterator<T> it) {
        return () -> {
            Preconditions.checkState(!it.isEmpty(), "Iterator on values can only be consumed once!");
            return ScalaInterop.javaIterator(it);
        };
    }

    /** Combines a key and a value column into a typed {@code struct(key, value)} {@link KV} column. */
    private <T> TypedColumn<?, KV<K, T>> keyValue(TypedColumn<?, K> key, TypedColumn<?, T> value) {
        return functions
            .struct(key.as("key"), value.as("value"))
            .as(EncoderHelpers.kvEncoder(key.encoder(), value.encoder()));
    }

    /**
     * Builds a {@link WindowedValue} column placing {@code value} into the global window with the
     * {@code NO_FIRING} pane.
     *
     * @param value the value column
     * @param ts expression for the timestamp field
     */
    @SuppressWarnings("unchecked")
    private static <InT, T> TypedColumn<InT, WindowedValue<T>> inGlobalWindow(
        TypedColumn<?, T> value, Expression ts) {
        List<Expression> fields = ScalaInterop.concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS);
        Encoder<WindowedValue<T>> enc =
            EncoderHelpers.windowedValueEncoder(
                value.encoder(), EncoderHelpers.encoderOf(GlobalWindow.class));
        // Column.as(Encoder) yields TypedColumn<Object, ...>; the input type is chosen by caller.
        return (TypedColumn<InT, WindowedValue<T>>) new Column(new CreateNamedStruct(fields)).as(enc);
    }

    /**
     * Builds a {@link WindowedValue} column placing {@code value} into exactly one window with the
     * {@code NO_FIRING} pane.
     *
     * @param value the value column
     * @param window the (single) window column; wrapped into a one-element array
     * @param ts expression for the timestamp field
     */
    @SuppressWarnings("unchecked")
    public static <InT, T> TypedColumn<InT, WindowedValue<T>> inSingleWindow(
        TypedColumn<?, T> value, TypedColumn<?, ? extends BoundedWindow> window, Expression ts) {
        Expression windows = new CreateArray(ScalaInterop.listOf(window.expr()));
        List<Expression> fields = ScalaInterop.concat(timestampedValue(value, ts), windowDetails(windows));
        Encoder<WindowedValue<T>> enc =
            EncoderHelpers.windowedValueEncoder(value.encoder(), window.encoder());
        // Column.as(Encoder) yields TypedColumn<Object, ...>; the input type is chosen by caller.
        return (TypedColumn<InT, WindowedValue<T>>) new Column(new CreateNamedStruct(fields)).as(enc);
    }

    /** Name/value pairs for the {@code value} and {@code timestamp} struct fields. */
    private static List<Expression> timestampedValue(Column value, Expression ts) {
        return ScalaInterop.seqOf(lit("value"), value.expr(), lit("timestamp"), ts).toList();
    }

    /** Name/value pairs for the {@code windows} and {@code pane} (always NO_FIRING) struct fields. */
    private static List<Expression> windowDetails(Expression windows) {
        return ScalaInterop.seqOf(lit("windows"), windows, lit("pane"), PANE_NO_FIRING).toList();
    }

    /** Catalyst literal for {@code t}. */
    private static <T> Expression lit(T t) {
        return Literal$.MODULE$.apply(t);
    }

    /** Catalyst null literal of the given type. */
    private static Expression litNull(DataType dataType) {
        return new Literal(null, dataType);
    }
}
