package org.apache.beam.sdk.io.kafka;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.io.kafka.AutoValue_KafkaWriteSchemaTransformProvider_KafkaWriteSchemaTransformConfiguration;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.schemas.utils.JsonUtils;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.kafka.common.serialization.ByteArraySerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link SchemaTransformProvider} that serializes Beam {@link Row}s and writes them to a Kafka topic.
 *
 * <p>Each input row is converted to bytes according to the configured format (one of {@link
 * #SUPPORTED_FORMATS_STR}) and written as the value of a Kafka record. Rows that fail to serialize
 * are rethrown, or, when error handling is configured, emitted on an error output collection.
 */
@AutoService({SchemaTransformProvider.class})
public class KafkaWriteSchemaTransformProvider
        extends TypedSchemaTransformProvider<
                KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransformConfiguration> {

    /** Comma-separated list of the formats accepted by {@link #from}. */
    public static final String SUPPORTED_FORMATS_STR = "RAW,JSON,AVRO,PROTO";

    /** The entries of {@link #SUPPORTED_FORMATS_STR} as a set. */
    public static final Set<String> SUPPORTED_FORMATS =
            Sets.newHashSet(SUPPORTED_FORMATS_STR.split(","));

    /** Tag for error records describing rows that could not be serialized. */
    public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};

    /** Tag for successfully serialized key/value pairs destined for Kafka. */
    public static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
            new TupleTag<KV<byte[], byte[]>>() {};

    private static final Logger LOG =
            LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class);

    /** The {@link SchemaTransform} that serializes the "input" collection and writes it to Kafka. */
    public static final class KafkaWriteSchemaTransform extends SchemaTransform
            implements Serializable {
        final KafkaWriteSchemaTransformConfiguration configuration;

        /**
         * Serializes each row with a user-supplied function and routes the result to {@link
         * #OUTPUT_TAG}.
         *
         * <p>If serialization throws and error handling is enabled, an error record is emitted on
         * {@link #ERROR_TAG}. The failure is also counted and added to a Beam counter at the end of
         * the bundle. If error handling is disabled, the exception is rethrown wrapped in a {@link
         * RuntimeException}.
         */
        public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
            private final SerializableFunction<Row, byte[]> toBytesFn;
            private final Counter errorCounter;
            private long errorsInBundle = 0L;
            private final boolean handleErrors;
            private final Schema errorSchema;

            /**
             * @param counterName name of the Beam counter that accumulates serialization failures
             * @param toBytesFn converts a row into the Kafka record value
             * @param errorSchema schema passed to {@link ErrorHandling#errorRecord} for error rows
             * @param handleErrors whether failures go to the error output (true) or are rethrown
             */
            public ErrorCounterFn(
                    String counterName,
                    SerializableFunction<Row, byte[]> toBytesFn,
                    Schema errorSchema,
                    boolean handleErrors) {
                this.toBytesFn = toBytesFn;
                this.errorCounter =
                        Metrics.counter(KafkaWriteSchemaTransformProvider.class, counterName);
                this.handleErrors = handleErrors;
                this.errorSchema = errorSchema;
            }

            @DoFn.ProcessElement
            public void process(@DoFn.Element Row row, DoFn.MultiOutputReceiver receiver) {
                KV<byte[], byte[]> record;
                try {
                    // NOTE(review): every record is keyed with the same one-byte array {0}. With a
                    // key-hash partitioner this would route all records to one partition — confirm
                    // this is intended.
                    record = KV.of(new byte[1], toBytesFn.apply(row));
                } catch (Exception e) {
                    if (!handleErrors) {
                        throw new RuntimeException(e);
                    }
                    errorsInBundle++;
                    LOG.warn("Error while processing the element", e);
                    receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e));
                    return;
                }
                receiver.get(OUTPUT_TAG).output(record);
            }

            @DoFn.FinishBundle
            public void finish() {
                // Flush the per-bundle tally into the metric once, rather than once per failure.
                errorCounter.inc(errorsInBundle);
                errorsInBundle = 0L;
            }
        }

        KafkaWriteSchemaTransform(KafkaWriteSchemaTransformConfiguration configuration) {
            this.configuration = configuration;
        }

        /**
         * Returns the configuration as a {@link Row}, with fields sorted and names in snake_case.
         *
         * @throws RuntimeException if no schema can be inferred for the configuration class
         */
        public Row getConfigurationRow() {
            try {
                return SchemaRegistry.createDefault()
                        .getToRowFunction(KafkaWriteSchemaTransformConfiguration.class)
                        .apply(configuration)
                        .sorted()
                        .toSnakeCase();
            } catch (NoSuchSchemaException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Serializes the "input" collection and writes it to Kafka.
         *
         * @return a tuple holding a single error collection. It is named by the configured error
         *     handling output when one is set, otherwise "errors".
         * @throws IllegalArgumentException if the input schema or the configuration does not fit
         *     the selected format
         */
        @Override
        public PCollectionRowTuple expand(PCollectionRowTuple input) {
            Schema inputSchema = input.get("input").getSchema();
            SerializableFunction<Row, byte[]> toBytesFn = createToBytesFunction(inputSchema);

            boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
            Schema errorSchema = ErrorHandling.errorSchema(inputSchema);

            PCollectionTuple outputTuple =
                    input.get("input")
                            .apply(
                                    "Map rows to Kafka messages",
                                    ParDo.of(
                                                    new ErrorCounterFn(
                                                            "Kafka-write-error-counter",
                                                            toBytesFn,
                                                            errorSchema,
                                                            handleErrors))
                                            .withOutputTags(
                                                    OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

            // Copy into a fresh, typed map so the configuration's map is never shared or mutated.
            Map<String, Object> producerConfig = new HashMap<>();
            Map<String, String> producerConfigUpdates = configuration.getProducerConfigUpdates();
            if (producerConfigUpdates != null) {
                producerConfig.putAll(producerConfigUpdates);
            }

            outputTuple
                    .get(OUTPUT_TAG)
                    .apply(
                            KafkaIO.<byte[], byte[]>write()
                                    .withTopic(configuration.getTopic())
                                    .withBootstrapServers(configuration.getBootstrapServers())
                                    .withProducerConfigUpdates(producerConfig)
                                    .withKeySerializer(ByteArraySerializer.class)
                                    .withValueSerializer(ByteArraySerializer.class));

            // Error rows are built with ErrorHandling.errorRecord(errorSchema, ...), so the
            // collection must declare exactly errorSchema. Wrapping it in errorSchema(...) again
            // would declare a schema the rows do not conform to.
            String errorOutputName =
                    handleErrors ? configuration.getErrorHandling().getOutput() : "errors";
            return PCollectionRowTuple.of(
                    errorOutputName, outputTuple.get(ERROR_TAG).setRowSchema(errorSchema));
        }

        /** Picks the row-to-bytes serializer for the configured format; anything else is AVRO. */
        private SerializableFunction<Row, byte[]> createToBytesFunction(Schema inputSchema) {
            switch (configuration.getFormat()) {
                case "RAW":
                    return createRawBytesFunction(inputSchema);
                case "JSON":
                    return JsonUtils.getRowToJsonBytesFunction(inputSchema);
                case "PROTO":
                    return createProtoBytesFunction();
                default:
                    return AvroUtils.getRowToAvroBytesFunction(inputSchema);
            }
        }

        /** Validates that the schema is a single BYTES field and returns a pass-through function. */
        private static SerializableFunction<Row, byte[]> createRawBytesFunction(
                Schema inputSchema) {
            int numFields = inputSchema.getFields().size();
            if (numFields != 1) {
                throw new IllegalArgumentException("Expecting exactly one field, found " + numFields);
            }
            Schema.Field onlyField = inputSchema.getField(0);
            if (!onlyField.getType().equals(Schema.FieldType.BYTES)) {
                throw new IllegalArgumentException(
                        "The input schema must have exactly one field of type byte.");
            }
            return getRowToRawBytesFunction(onlyField.getName());
        }

        /**
         * Builds a protobuf serializer from exactly one of fileDescriptorPath or schema, plus the
         * required messageName.
         */
        private SerializableFunction<Row, byte[]> createProtoBytesFunction() {
            String descriptorPath = configuration.getFileDescriptorPath();
            String protoSchema = configuration.getSchema();
            String messageName = configuration.getMessageName();
            if (messageName == null) {
                throw new IllegalArgumentException("Expecting messageName to be non-null.");
            }
            if (descriptorPath != null && protoSchema != null) {
                throw new IllegalArgumentException(
                        "You must include a descriptorPath or a proto Schema but not both.");
            }
            if (descriptorPath != null) {
                return ProtoByteUtils.getRowToProtoBytes(descriptorPath, messageName);
            }
            if (protoSchema == null) {
                throw new IllegalArgumentException(
                        "At least a descriptorPath or a proto Schema is required.");
            }
            return ProtoByteUtils.getRowToProtoBytesFromSchema(protoSchema, messageName);
        }
    }

    /** Configuration for {@link KafkaWriteSchemaTransform}. */
    @DefaultSchema(AutoValueSchema.class)
    @AutoValue
    public static abstract class KafkaWriteSchemaTransformConfiguration implements Serializable {

        /** Builder for {@link KafkaWriteSchemaTransformConfiguration}. */
        @AutoValue.Builder
        public static abstract class Builder {
            public abstract Builder setFormat(String str);

            public abstract Builder setTopic(String str);

            public abstract Builder setBootstrapServers(String str);

            public abstract Builder setProducerConfigUpdates(Map<String, String> map);

            public abstract Builder setErrorHandling(ErrorHandling errorHandling);

            public abstract Builder setFileDescriptorPath(String str);

            public abstract Builder setMessageName(String str);

            public abstract Builder setSchema(String str);

            public abstract KafkaWriteSchemaTransformConfiguration build();
        }

        @SchemaFieldDescription("The encoding format for the data stored in Kafka. Valid options are: RAW,JSON,AVRO,PROTO")
        public abstract String getFormat();

        /** The Kafka topic records are written to. */
        public abstract String getTopic();

        @SchemaFieldDescription("A list of host/port pairs to use for establishing the initial connection to the Kafka cluster. The client will make use of all servers irrespective of which servers are specified here for bootstrapping—this list only impacts the initial hosts used to discover the full set of servers. | Format: host1:port1,host2:port2,...")
        public abstract String getBootstrapServers();

        @SchemaFieldDescription("A list of key-value pairs that act as configuration parameters for Kafka producers. Most of these configurations will not be needed, but if you need to customize your Kafka producer, you may use this. See a detailed list: https://docs.confluent.io/platform/current/installation/configuration/producer-configs.html")
        @Nullable
        public abstract Map<String, String> getProducerConfigUpdates();

        @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
        @Nullable
        public abstract ErrorHandling getErrorHandling();

        @SchemaFieldDescription("The path to the Protocol Buffer File Descriptor Set file. This file is used for schema definition and message serialization.")
        @Nullable
        public abstract String getFileDescriptorPath();

        @SchemaFieldDescription("The name of the Protocol Buffer message to be used for schema extraction and data conversion.")
        @Nullable
        public abstract String getMessageName();

        /**
         * Protobuf schema text used by the PROTO format when no file descriptor path is given. It
         * must not be set together with {@link #getFileDescriptorPath()}.
         */
        @Nullable
        public abstract String getSchema();

        public static Builder builder() {
            return new AutoValue_KafkaWriteSchemaTransformProvider_KafkaWriteSchemaTransformConfiguration.Builder();
        }
    }

    @Override
    protected Class<KafkaWriteSchemaTransformConfiguration> configurationClass() {
        return KafkaWriteSchemaTransformConfiguration.class;
    }

    /**
     * Creates the write transform for {@code configuration}.
     *
     * @throws IllegalArgumentException if the configured format is not in {@link
     *     #SUPPORTED_FORMATS}
     */
    @Override
    public SchemaTransform from(KafkaWriteSchemaTransformConfiguration configuration) {
        if (!SUPPORTED_FORMATS.contains(configuration.getFormat())) {
            throw new IllegalArgumentException(
                    "Format "
                            + configuration.getFormat()
                            + " is not supported. Supported formats are: "
                            + String.join(", ", SUPPORTED_FORMATS));
        }
        return new KafkaWriteSchemaTransform(configuration);
    }

    /**
     * Returns a function that extracts the bytes stored in field {@code str} of a row.
     *
     * @throws NullPointerException (from the returned function) if the field value is null
     */
    public static SerializableFunction<Row, byte[]> getRowToRawBytesFunction(final String str) {
        return new SimpleFunction<Row, byte[]>() {
            @Override
            public byte[] apply(Row row) {
                byte[] bytes = row.getBytes(str);
                if (bytes == null) {
                    // The message ends up in the error record when error handling is enabled.
                    throw new NullPointerException(
                            "Field '" + str + "' is null; RAW format requires a non-null bytes value.");
                }
                return bytes;
            }
        };
    }

    @Override
    public String identifier() {
        return "beam:schematransform:org.apache.beam:kafka_write:v1";
    }

    @Override
    public List<String> inputCollectionNames() {
        return Collections.singletonList("input");
    }

    @Override
    public List<String> outputCollectionNames() {
        // NOTE(review): expand() returns an error collection (named "errors" or the configured
        // error-handling output), yet no output names are declared here — confirm this is intended.
        return Collections.emptyList();
    }
}
