diff --git a/CHANGES.md b/CHANGES.md index 4b977bf3790d..7686b7a92d96 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -66,6 +66,7 @@ * TextIO now supports skipping multiple header lines (Java) ([#17990](https://github.com/apache/beam/issues/17990)). * Python GCSIO is now implemented with GCP GCS Client instead of apitools ([#25676](https://github.com/apache/beam/issues/25676)) * Adding support for LowCardinality DataType in ClickHouse (Java) ([#29533](https://github.com/apache/beam/pull/29533)). +* Added support for handling bad records to KafkaIO (Java) ([#29546](https://github.com/apache/beam/pull/29546)) ## New Features / Improvements diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy index cd2875fdb512..bb08e79edd3c 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/kafka/KafkaTestUtilities.groovy @@ -40,7 +40,7 @@ class KafkaTestUtilities { '"keySizeBytes": "10",' + '"valueSizeBytes": "90"' + '}', - "--readTimeout=120", + "--readTimeout=60", "--kafkaTopic=beam", "--withTestcontainers=true", "--kafkaContainerVersion=5.5.2", @@ -56,6 +56,7 @@ class KafkaTestUtilities { excludeTestsMatching "*SDFResumesCorrectly" //Kafka SDF does not work for kafka versions <2.0.1 excludeTestsMatching "*StopReadingFunction" //Kafka SDF does not work for kafka versions <2.0.1 excludeTestsMatching "*WatermarkUpdateWithSparseMessages" //Kafka SDF does not work for kafka versions <2.0.1 + excludeTestsMatching "*KafkaIOSDFReadWithErrorHandler" } } } diff --git a/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java b/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java index 34a4b646555d..602c34d4219d 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java +++ b/examples/java/src/main/java/org/apache/beam/examples/KafkaStreaming.java @@ -49,8 +49,11 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler; import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; @@ -60,6 +63,8 @@ import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.serialization.IntegerDeserializer; import org.apache.kafka.common.serialization.IntegerSerializer; import org.apache.kafka.common.serialization.StringDeserializer; @@ -97,7 +102,7 @@ public interface KafkaStreamingOptions extends PipelineOptions { * to use your own Kafka server. 
*/ @Description("Kafka server host") - @Default.String("kafka_server:9092") + @Default.String("localhost:9092") String getKafkaHost(); void setKafkaHost(String value); @@ -208,15 +213,22 @@ public void run() { // Start reading form Kafka with the latest offset consumerConfig.put("auto.offset.reset", "latest"); - PCollection> pCollection = - pipeline.apply( - KafkaIO.read() - .withBootstrapServers(options.getKafkaHost()) - .withTopic(TOPIC_NAME) - .withKeyDeserializer(StringDeserializer.class) - .withValueDeserializer(IntegerDeserializer.class) - .withConsumerConfigUpdates(consumerConfig) - .withoutMetadata()); + // Register an error handler for any deserialization errors. + // Errors are simulated with an intentionally failing deserializer + PCollection> pCollection; + try (BadRecordErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new LogErrors())) { + pCollection = + pipeline.apply( + KafkaIO.read() + .withBootstrapServers(options.getKafkaHost()) + .withTopic(TOPIC_NAME) + .withKeyDeserializer(StringDeserializer.class) + .withValueDeserializer(IntermittentlyFailingIntegerDeserializer.class) + .withConsumerConfigUpdates(consumerConfig) + .withBadRecordErrorHandler(errorHandler) + .withoutMetadata()); + } pCollection // Apply a window and a trigger ourput repeatedly. @@ -317,4 +329,39 @@ public void processElement(ProcessContext c, IntervalWindow w) throws Exception c.output(c.element()); } } + + // Simple PTransform to log Error information + static class LogErrors extends PTransform, PCollection> { + + @Override + public PCollection expand(PCollection input) { + return input.apply("Log Errors", ParDo.of(new LogErrorFn())); + } + + static class LogErrorFn extends DoFn { + @ProcessElement + public void processElement(@Element BadRecord record, OutputReceiver receiver) { + System.out.println(record); + receiver.output(record); + } + } + } + + // Intentionally failing deserializer to simulate bad data from Kafka + public static class IntermittentlyFailingIntegerDeserializer implements Deserializer { + + public static final IntegerDeserializer INTEGER_DESERIALIZER = new IntegerDeserializer(); + public int deserializeCount = 0; + + public IntermittentlyFailingIntegerDeserializer() {} + + @Override + public Integer deserialize(String topic, byte[] data) { + deserializeCount++; + if (deserializeCount % 10 == 0) { + throw new SerializationException("Expected Serialization Exception"); + } + return INTEGER_DESERIALIZER.deserialize(topic, data); + } + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java index 9e0298d885eb..e02965b72022 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java @@ -17,6 +17,9 @@ */ package org.apache.beam.sdk.transforms.errorhandling; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -49,22 +52,24 @@ *

<p>Simple usage with one DLQ *

<pre>{@code
  * PCollection records = ...;
- * try (ErrorHandler errorHandler = pipeline.registerErrorHandler(SomeSink.write())) {
- *  PCollection results = records.apply(SomeIO.write().withDeadLetterQueue(errorHandler));
+ * try (BadRecordErrorHandler errorHandler = pipeline.registerBadRecordErrorHandler(SomeSink.write())) {
+ *  PCollection results = records.apply(SomeIO.write().withErrorHandler(errorHandler));
  * }
  * results.apply(SomeOtherTransform);
  * }</pre>
* <p>Usage with multiple DLQ stages *
<pre>{@code
  * PCollection records = ...;
- * try (ErrorHandler errorHandler = pipeline.registerErrorHandler(SomeSink.write())) {
- *  PCollection results = records.apply(SomeIO.write().withDeadLetterQueue(errorHandler))
- *                        .apply(OtherTransform.builder().withDeadLetterQueue(errorHandler));
+ * try (BadRecordErrorHandler errorHandler = pipeline.registerBadRecordErrorHandler(SomeSink.write())) {
+ *  PCollection results = records.apply(SomeIO.write().withErrorHandler(errorHandler))
+ *                        .apply(OtherTransform.builder().withErrorHandler(errorHandler));
  * }
  * results.apply(SomeOtherTransform);
  * }</pre>
+ * This is marked as serializable despite never being needed on the runner, to enable it to be a + * parameter of an Autovalue configured PTransform. */ -public interface ErrorHandler extends AutoCloseable { +public interface ErrorHandler extends AutoCloseable, Serializable { void addErrorCollection(PCollection errorCollection); @@ -79,13 +84,16 @@ class PTransformErrorHandler private static final Logger LOG = LoggerFactory.getLogger(PTransformErrorHandler.class); private final PTransform, OutputT> sinkTransform; - private final Pipeline pipeline; + // transient as Pipelines are not serializable + private final transient Pipeline pipeline; private final Coder coder; - private final List> errorCollections = new ArrayList<>(); + // transient as PCollections are not serializable + private transient List> errorCollections = new ArrayList<>(); - private @Nullable OutputT sinkOutput = null; + // transient as PCollections are not serializable + private transient @Nullable OutputT sinkOutput = null; private boolean closed = false; @@ -103,6 +111,12 @@ public PTransformErrorHandler( this.coder = coder; } + private void readObject(ObjectInputStream aInputStream) + throws ClassNotFoundException, IOException { + aInputStream.defaultReadObject(); + errorCollections = new ArrayList<>(); + } + @Override public void addErrorCollection(PCollection errorCollection) { errorCollections.add(errorCollection); diff --git a/sdks/java/io/kafka/kafka-01103/build.gradle b/sdks/java/io/kafka/kafka-01103/build.gradle index a0fa372397a2..3a74bf04ef22 100644 --- a/sdks/java/io/kafka/kafka-01103/build.gradle +++ b/sdks/java/io/kafka/kafka-01103/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="0.11.0.3" undelimited="01103" + sdfCompatible=false } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-100/build.gradle b/sdks/java/io/kafka/kafka-100/build.gradle index 15ce8c0deeff..bd5fa67b1cfc 100644 --- a/sdks/java/io/kafka/kafka-100/build.gradle +++ b/sdks/java/io/kafka/kafka-100/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="1.0.0" undelimited="100" + sdfCompatible=false } -apply from: "../kafka-integration-test.gradle" \ No newline at end of file +apply from: "../kafka-integration-test.gradle" diff --git a/sdks/java/io/kafka/kafka-111/build.gradle b/sdks/java/io/kafka/kafka-111/build.gradle index fee4c382ed41..c2b0c8f82827 100644 --- a/sdks/java/io/kafka/kafka-111/build.gradle +++ b/sdks/java/io/kafka/kafka-111/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="1.1.1" undelimited="111" + sdfCompatible=false } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-201/build.gradle b/sdks/java/io/kafka/kafka-201/build.gradle index d395d0aa6269..a26ca4ac19cf 100644 --- a/sdks/java/io/kafka/kafka-201/build.gradle +++ b/sdks/java/io/kafka/kafka-201/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="2.0.1" undelimited="201" + sdfCompatible=true } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-211/build.gradle b/sdks/java/io/kafka/kafka-211/build.gradle index 4de07193b5a2..433d6c93f361 100644 --- a/sdks/java/io/kafka/kafka-211/build.gradle +++ b/sdks/java/io/kafka/kafka-211/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="2.1.1" undelimited="211" + sdfCompatible=true } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-222/build.gradle 
b/sdks/java/io/kafka/kafka-222/build.gradle index 57de58e81895..0f037e742968 100644 --- a/sdks/java/io/kafka/kafka-222/build.gradle +++ b/sdks/java/io/kafka/kafka-222/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="2.2.2" undelimited="222" + sdfCompatible=true } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-231/build.gradle b/sdks/java/io/kafka/kafka-231/build.gradle index 3682791c5b67..712158dcd3ae 100644 --- a/sdks/java/io/kafka/kafka-231/build.gradle +++ b/sdks/java/io/kafka/kafka-231/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="2.3.1" undelimited="231" + sdfCompatible=true } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-241/build.gradle b/sdks/java/io/kafka/kafka-241/build.gradle index 358c95aeb2fe..c0ac7df674b5 100644 --- a/sdks/java/io/kafka/kafka-241/build.gradle +++ b/sdks/java/io/kafka/kafka-241/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="2.4.1" undelimited="241" + sdfCompatible=true } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-251/build.gradle b/sdks/java/io/kafka/kafka-251/build.gradle index f291ecccc36b..4de9f97a738a 100644 --- a/sdks/java/io/kafka/kafka-251/build.gradle +++ b/sdks/java/io/kafka/kafka-251/build.gradle @@ -18,6 +18,7 @@ project.ext { delimited="2.5.1" undelimited="251" + sdfCompatible=true } apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-integration-test.gradle b/sdks/java/io/kafka/kafka-integration-test.gradle index 778f8a3c456c..1aeb0c97f93b 100644 --- a/sdks/java/io/kafka/kafka-integration-test.gradle +++ b/sdks/java/io/kafka/kafka-integration-test.gradle @@ -39,4 +39,4 @@ dependencies { configurations.create("kafkaVersion$undelimited") -tasks.register("kafkaVersion${undelimited}BatchIT",KafkaTestUtilities.KafkaBatchIT, project.ext.delimited, project.ext.undelimited, false, configurations, project) \ No newline at end of file +tasks.register("kafkaVersion${undelimited}BatchIT",KafkaTestUtilities.KafkaBatchIT, project.ext.delimited, project.ext.undelimited, project.ext.sdfCompatible, configurations, project) \ No newline at end of file diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 7e4fc55c6ce9..8fd0c34cfa90 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -81,6 +81,11 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.DefaultErrorHandler; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.MonotonicallyIncreasing; @@ -89,9 +94,11 @@ import org.apache.beam.sdk.values.KV; import 
org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner; @@ -167,6 +174,10 @@ * // signal. * .withCheckStopReadingFn(new SerializedFunction() {}) * + * //If you would like to send messages that fail to be parsed from Kafka to an alternate sink, + * //use the error handler pattern as defined in {@link ErrorHandler} + * .withBadRecordErrorHandler(errorHandler) + * * // finally, if you don't need Kafka metadata, you can drop it. * .withoutMetadata() // PCollection> * ) @@ -469,6 +480,11 @@ * // or you can also set a custom timestamp with a function. * .withPublishTimestampFunction((elem, elemTs) -> ...) * + * // Optionally, records that fail to serialize can be sent to an error handler + * // See {@link ErrorHandler} for details of configuring a bad record error + * // handler + * .withBadRecordErrorHandler(errorHandler) + * * // Optionally enable exactly-once sink (on supported runners). See JavaDoc for withEOS(). * .withEOS(20, "eos-sink-group-id"); * ); @@ -592,13 +608,7 @@ public static ReadSourceDescriptors readSourceDescriptors() { */ public static Write write() { return new AutoValue_KafkaIO_Write.Builder() - .setWriteRecordsTransform( - new AutoValue_KafkaIO_WriteRecords.Builder() - .setProducerConfig(WriteRecords.DEFAULT_PRODUCER_PROPERTIES) - .setEOS(false) - .setNumShards(0) - .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN) - .build()) + .setWriteRecordsTransform(writeRecords()) .build(); } @@ -613,6 +623,8 @@ public static WriteRecords writeRecords() { .setEOS(false) .setNumShards(0) .setConsumerFactoryFn(KafkaIOUtils.KAFKA_CONSUMER_FACTORY_FN) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) .build(); } @@ -691,6 +703,9 @@ public abstract static class Read @Pure public abstract @Nullable CheckStopReadingFn getCheckStopReadingFn(); + @Pure + public abstract @Nullable ErrorHandler getBadRecordErrorHandler(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -739,6 +754,9 @@ abstract Builder setValueDeserializerProvider( abstract Builder setCheckStopReadingFn(@Nullable CheckStopReadingFn checkStopReadingFn); + abstract Builder setBadRecordErrorHandler( + @Nullable ErrorHandler badRecordErrorHandler); + Builder setCheckStopReadingFn( @Nullable SerializableFunction checkStopReadingFn) { return setCheckStopReadingFn(CheckStopReadingFnWrapper.of(checkStopReadingFn)); @@ -1312,6 +1330,10 @@ public Read withCheckStopReadingFn( .build(); } + public Read withBadRecordErrorHandler(ErrorHandler badRecordErrorHandler) { + return toBuilder().setBadRecordErrorHandler(badRecordErrorHandler).build(); + } + /** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metadata. 
*/ public PTransform>> withoutMetadata() { return new TypedWithoutMetadata<>(this); } @@ -1529,6 +1551,11 @@ static class ReadFromKafkaViaUnbounded @Override public PCollection> expand(PBegin input) { + if (kafkaRead.getBadRecordErrorHandler() != null) { + LOG.warn( + "The Legacy implementation of Kafka Read does not support writing malformed" + + " messages to an error handler. Use the SDF implementation instead."); + } // Handles unbounded source to bounded conversion if maxNumRecords or maxReadTime is set. Unbounded> unbounded = org.apache.beam.sdk.io.Read.from( @@ -1576,6 +1603,10 @@ public PCollection> expand(PBegin input) { if (kafkaRead.getStopReadTime() != null) { readTransform = readTransform.withBounded(); } + if (kafkaRead.getBadRecordErrorHandler() != null) { + readTransform = + readTransform.withBadRecordErrorHandler(kafkaRead.getBadRecordErrorHandler()); + } PCollection output; if (kafkaRead.isDynamicRead()) { Set topics = new HashSet<>(); @@ -1956,6 +1987,8 @@ public void populateDisplayData(DisplayData.Builder builder) { public abstract static class ReadSourceDescriptors extends PTransform, PCollection>> { + private final TupleTag>> records = new TupleTag<>(); + private static final Logger LOG = LoggerFactory.getLogger(ReadSourceDescriptors.class); @Pure @@ -1997,6 +2030,12 @@ public abstract static class ReadSourceDescriptors @Pure abstract @Nullable TimestampPolicyFactory getTimestampPolicyFactory(); + @Pure + abstract BadRecordRouter getBadRecordRouter(); + + @Pure + abstract ErrorHandler getBadRecordErrorHandler(); + abstract boolean isBounded(); abstract ReadSourceDescriptors.Builder toBuilder(); @@ -2041,6 +2080,12 @@ abstract ReadSourceDescriptors.Builder setCommitOffsetEnabled( abstract ReadSourceDescriptors.Builder setTimestampPolicyFactory( TimestampPolicyFactory policy); + abstract ReadSourceDescriptors.Builder setBadRecordRouter( + BadRecordRouter badRecordRouter); + + abstract ReadSourceDescriptors.Builder setBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler); + abstract ReadSourceDescriptors.Builder setBounded(boolean bounded); abstract ReadSourceDescriptors build(); @@ -2052,6 +2097,8 @@ public static ReadSourceDescriptors read() { .setConsumerConfig(KafkaIOUtils.DEFAULT_CONSUMER_PROPERTIES) .setCommitOffsetEnabled(false) .setBounded(false) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) + .setBadRecordErrorHandler(new ErrorHandler.DefaultErrorHandler<>()) .build() .withProcessingTime() .withMonotonicallyIncreasingWatermarkEstimator(); @@ -2305,6 +2352,14 @@ public ReadSourceDescriptors withConsumerConfigOverrides( return toBuilder().setConsumerConfig(consumerConfig).build(); } + public ReadSourceDescriptors withBadRecordErrorHandler( + ErrorHandler errorHandler) { + return toBuilder() + .setBadRecordRouter(BadRecordRouter.RECORDING_ROUTER) + .setBadRecordErrorHandler(errorHandler) + .build(); + } + ReadAllFromRow forExternalBuild() { return new ReadAllFromRow<>(this); } @@ -2395,9 +2450,18 @@ public PCollection> expand(PCollection Coder> recordCoder = KafkaRecordCoder.of(keyCoder, valueCoder); try { + PCollectionTuple pCollectionTuple = + input.apply( + ParDo.of(ReadFromKafkaDoFn.create(this, records)) + .withOutputTags(records, TupleTagList.of(BadRecordRouter.BAD_RECORD_TAG))); + getBadRecordErrorHandler() + .addErrorCollection( + pCollectionTuple + .get(BadRecordRouter.BAD_RECORD_TAG) + .setCoder(BadRecord.getCoder(input.getPipeline()))); PCollection>> outputWithDescriptor = - input - 
.apply(ParDo.of(ReadFromKafkaDoFn.create(this))) + pCollectionTuple + .get(records) .setCoder( KvCoder.of( input @@ -2538,6 +2602,12 @@ public abstract static class WriteRecords public abstract @Nullable SerializableFunction, ? extends Consumer> getConsumerFactoryFn(); + @Pure + public abstract BadRecordRouter getBadRecordRouter(); + + @Pure + public abstract ErrorHandler getBadRecordErrorHandler(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -2565,6 +2635,11 @@ abstract Builder setPublishTimestampFunction( abstract Builder setConsumerFactoryFn( SerializableFunction, ? extends Consumer> fn); + abstract Builder setBadRecordRouter(BadRecordRouter router); + + abstract Builder setBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler); + abstract WriteRecords build(); } @@ -2711,6 +2786,14 @@ public WriteRecords withConsumerFactoryFn( return toBuilder().setConsumerFactoryFn(consumerFactoryFn).build(); } + public WriteRecords withBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler) { + return toBuilder() + .setBadRecordRouter(BadRecordRouter.RECORDING_ROUTER) + .setBadRecordErrorHandler(badRecordErrorHandler) + .build(); + } + @Override public PDone expand(PCollection> input) { checkArgument( @@ -2722,6 +2805,9 @@ public PDone expand(PCollection> input) { if (isEOS()) { checkArgument(getTopic() != null, "withTopic() is required when isEOS() is true"); + checkArgument( + getBadRecordErrorHandler() instanceof DefaultErrorHandler, + "BadRecordErrorHandling isn't supported with Kafka Exactly Once writing"); KafkaExactlyOnceSink.ensureEOSSupport(); // TODO: Verify that the group_id does not have existing state stored on Kafka unless @@ -2732,7 +2818,19 @@ public PDone expand(PCollection> input) { input.apply(new KafkaExactlyOnceSink<>(this)); } else { - input.apply(ParDo.of(new KafkaWriter<>(this))); + // Even though the errors are the only output from writing to Kafka, we maintain a + // PCollectionTuple + // with a void tag as the 'primary' output for easy forward compatibility + PCollectionTuple pCollectionTuple = + input.apply( + ParDo.of(new KafkaWriter<>(this)) + .withOutputTags( + new TupleTag(), TupleTagList.of(BadRecordRouter.BAD_RECORD_TAG))); + getBadRecordErrorHandler() + .addErrorCollection( + pCollectionTuple + .get(BadRecordRouter.BAD_RECORD_TAG) + .setCoder(BadRecord.getCoder(input.getPipeline()))); } return PDone.in(input.getPipeline()); } @@ -2995,6 +3093,15 @@ public Write withProducerConfigUpdates(Map configUpdates) getWriteRecordsTransform().withProducerConfigUpdates(configUpdates)); } + /** + * Configure a {@link BadRecordErrorHandler} for sending records to if they fail to serialize + * when being sent to Kafka. 
+ */ + public Write withBadRecordErrorHandler(ErrorHandler badRecordErrorHandler) { + return withWriteRecordsTransform( + getWriteRecordsTransform().withBadRecordErrorHandler(badRecordErrorHandler)); + } + @Override public PDone expand(PCollection> input) { final String topic = Preconditions.checkStateNotNull(getTopic(), "withTopic() is required"); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java index b779de1d9cf1..a2cc9aaeb4d9 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIOReadImplementationCompatibility.java @@ -111,6 +111,7 @@ Object getDefaultValue() { KEY_DESERIALIZER_PROVIDER, VALUE_DESERIALIZER_PROVIDER, CHECK_STOP_READING_FN(SDF), + BAD_RECORD_ERROR_HANDLER(SDF), ; @Nonnull private final ImmutableSet supportedImplementations; diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java index c0c9772959f9..4f4663aa8cc8 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriter.java @@ -25,6 +25,7 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.SinkMetrics; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.util.Preconditions; import org.apache.kafka.clients.producer.Callback; import org.apache.kafka.clients.producer.KafkaProducer; @@ -32,6 +33,7 @@ import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.errors.SerializationException; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -57,7 +59,7 @@ public void setup() { // Suppression since errors are tracked in SendCallback(), and checked in finishBundle() @ProcessElement @SuppressWarnings("FutureReturnValueIgnored") - public void processElement(ProcessContext ctx) throws Exception { + public void processElement(ProcessContext ctx, MultiOutputReceiver receiver) throws Exception { Producer producer = Preconditions.checkStateNotNull(this.producer); checkForFailures(); @@ -75,19 +77,31 @@ public void processElement(ProcessContext ctx) throws Exception { topicName = spec.getTopic(); } - @SuppressWarnings({"nullness", "unused"}) // Kafka library not annotated - Future ignored = - producer.send( - new ProducerRecord<>( - topicName, - record.partition(), - timestampMillis, - record.key(), - record.value(), - record.headers()), - callback); - - elementsWritten.inc(); + try { + @SuppressWarnings({"nullness", "unused"}) // Kafka library not annotated + Future ignored = + producer.send( + new ProducerRecord<>( + topicName, + record.partition(), + timestampMillis, + record.key(), + record.value(), + record.headers()), + callback); + + elementsWritten.inc(); + } catch (SerializationException e) { + // This exception should only occur during the key and value deserialization when + // creating the Kafka Record. 
We can catch the exception here as producer.send serializes + // the record before starting the future. + badRecordRouter.route( + receiver, + record, + null, + e, + "Failure serializing Key or Value of Kafka record writing to Kafka"); + } } @FinishBundle @@ -110,6 +124,8 @@ public void teardown() { private final WriteRecords spec; private final Map producerConfig; + private final BadRecordRouter badRecordRouter; + private transient @Nullable Producer producer = null; // first exception and number of failures since last invocation of checkForFailures(): private transient @Nullable Exception sendException = null; @@ -122,6 +138,8 @@ public void teardown() { this.producerConfig = new HashMap<>(spec.getProducerConfig()); + this.badRecordRouter = spec.getBadRecordRouter(); + if (spec.getKeySerializer() != null) { this.producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, spec.getKeySerializer()); } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 1b6e3addce22..924833290f13 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.transforms.splittabledofn.GrowableOffsetRangeTracker; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; @@ -45,6 +46,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; @@ -60,6 +62,7 @@ import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.SerializationException; import org.apache.kafka.common.serialization.Deserializer; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @@ -144,29 +147,37 @@ abstract class ReadFromKafkaDoFn extends DoFn>> { - static ReadFromKafkaDoFn create(ReadSourceDescriptors transform) { + static ReadFromKafkaDoFn create( + ReadSourceDescriptors transform, + TupleTag>> recordTag) { if (transform.isBounded()) { - return new Bounded<>(transform); + return new Bounded<>(transform, recordTag); } else { - return new Unbounded<>(transform); + return new Unbounded<>(transform, recordTag); } } @UnboundedPerElement private static class Unbounded extends ReadFromKafkaDoFn { - Unbounded(ReadSourceDescriptors transform) { - super(transform); + Unbounded( + ReadSourceDescriptors transform, + TupleTag>> recordTag) { + super(transform, recordTag); } } @BoundedPerElement private static class Bounded extends ReadFromKafkaDoFn { - Bounded(ReadSourceDescriptors transform) { - super(transform); + Bounded( + ReadSourceDescriptors transform, + TupleTag>> 
recordTag) { + super(transform, recordTag); } } - private ReadFromKafkaDoFn(ReadSourceDescriptors transform) { + private ReadFromKafkaDoFn( + ReadSourceDescriptors transform, + TupleTag>> recordTag) { this.consumerConfig = transform.getConsumerConfig(); this.offsetConsumerConfig = transform.getOffsetConsumerConfig(); this.keyDeserializerProvider = @@ -178,6 +189,8 @@ private ReadFromKafkaDoFn(ReadSourceDescriptors transform) { this.createWatermarkEstimatorFn = transform.getCreateWatermarkEstimatorFn(); this.timestampPolicyFactory = transform.getTimestampPolicyFactory(); this.checkStopReadingFn = transform.getCheckStopReadingFn(); + this.badRecordRouter = transform.getBadRecordRouter(); + this.recordTag = recordTag; } private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class); @@ -193,6 +206,10 @@ private ReadFromKafkaDoFn(ReadSourceDescriptors transform) { createWatermarkEstimatorFn; private final @Nullable TimestampPolicyFactory timestampPolicyFactory; + private final BadRecordRouter badRecordRouter; + + private final TupleTag>> recordTag; + // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; @@ -361,7 +378,8 @@ public ProcessContinuation processElement( @Element KafkaSourceDescriptor kafkaSourceDescriptor, RestrictionTracker tracker, WatermarkEstimator watermarkEstimator, - OutputReceiver>> receiver) { + MultiOutputReceiver receiver) + throws Exception { final LoadingCache avgRecordSize = Preconditions.checkStateNotNull(this.avgRecordSize); final Deserializer keyDeserializerInstance = @@ -431,36 +449,52 @@ public ProcessContinuation processElement( if (!tracker.tryClaim(rawRecord.offset())) { return ProcessContinuation.stop(); } - KafkaRecord kafkaRecord = - new KafkaRecord<>( - rawRecord.topic(), - rawRecord.partition(), - rawRecord.offset(), - ConsumerSpEL.getRecordTimestamp(rawRecord), - ConsumerSpEL.getRecordTimestampType(rawRecord), - ConsumerSpEL.hasHeaders() ? rawRecord.headers() : null, - ConsumerSpEL.deserializeKey(keyDeserializerInstance, rawRecord), - ConsumerSpEL.deserializeValue(valueDeserializerInstance, rawRecord)); - int recordSize = - (rawRecord.key() == null ? 0 : rawRecord.key().length) - + (rawRecord.value() == null ? 0 : rawRecord.value().length); - avgRecordSize - .getUnchecked(kafkaSourceDescriptor.getTopicPartition()) - .update(recordSize, rawRecord.offset() - expectedOffset); - rawSizes.update(recordSize); - expectedOffset = rawRecord.offset() + 1; - Instant outputTimestamp; - // The outputTimestamp and watermark will be computed by timestampPolicy, where the - // WatermarkEstimator should be a manual one. - if (timestampPolicy != null) { - TimestampPolicyContext context = - updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); - outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord); - } else { - Preconditions.checkStateNotNull(this.extractOutputTimestampFn); - outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord); + try { + KafkaRecord kafkaRecord = + new KafkaRecord<>( + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset(), + ConsumerSpEL.getRecordTimestamp(rawRecord), + ConsumerSpEL.getRecordTimestampType(rawRecord), + ConsumerSpEL.hasHeaders() ? 
rawRecord.headers() : null, + ConsumerSpEL.deserializeKey(keyDeserializerInstance, rawRecord), + ConsumerSpEL.deserializeValue(valueDeserializerInstance, rawRecord)); + int recordSize = + (rawRecord.key() == null ? 0 : rawRecord.key().length) + + (rawRecord.value() == null ? 0 : rawRecord.value().length); + avgRecordSize + .getUnchecked(kafkaSourceDescriptor.getTopicPartition()) + .update(recordSize, rawRecord.offset() - expectedOffset); + rawSizes.update(recordSize); + expectedOffset = rawRecord.offset() + 1; + Instant outputTimestamp; + // The outputTimestamp and watermark will be computed by timestampPolicy, where the + // WatermarkEstimator should be a manual one. + if (timestampPolicy != null) { + TimestampPolicyContext context = + updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); + outputTimestamp = timestampPolicy.getTimestampForRecord(context, kafkaRecord); + } else { + Preconditions.checkStateNotNull(this.extractOutputTimestampFn); + outputTimestamp = extractOutputTimestampFn.apply(kafkaRecord); + } + receiver + .get(recordTag) + .outputWithTimestamp(KV.of(kafkaSourceDescriptor, kafkaRecord), outputTimestamp); + } catch (SerializationException e) { + // This exception should only occur during the key and value deserialization when + // creating the Kafka Record + badRecordRouter.route( + receiver, + rawRecord, + null, + e, + "Failure deserializing Key or Value of Kafka record reading from Kafka"); + if (timestampPolicy != null) { + updateWatermarkManually(timestampPolicy, watermarkEstimator, tracker); + } } - receiver.outputWithTimestamp(KV.of(kafkaSourceDescriptor, kafkaRecord), outputTimestamp); } } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java index 2ccf7dcc3a93..38bf723a15a9 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOExternalTest.java @@ -350,13 +350,7 @@ public void testConstructKafkaWrite() throws Exception { RunnerApi.PTransform writeComposite = result.getComponents().getTransformsOrThrow(transform.getSubtransforms(1)); RunnerApi.PTransform writeParDo = - result - .getComponents() - .getTransformsOrThrow( - result - .getComponents() - .getTransformsOrThrow(writeComposite.getSubtransforms(0)) - .getSubtransforms(0)); + result.getComponents().getTransformsOrThrow(writeComposite.getSubtransforms(0)); RunnerApi.ParDoPayload parDoPayload = RunnerApi.ParDoPayload.parseFrom(writeParDo.getSpec().getPayload()); diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java index 2c8ace9c66c1..5b976687f2c1 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java @@ -29,6 +29,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Objects; import java.util.Random; import java.util.Set; import java.util.UUID; @@ -43,6 +44,9 @@ import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.common.IOITHelper; import org.apache.beam.sdk.io.common.IOTestPipelineOptions; +import org.apache.beam.sdk.io.kafka.KafkaIOTest.ErrorSinkTransform; +import org.apache.beam.sdk.io.kafka.KafkaIOTest.FailingLongSerializer; +import 
org.apache.beam.sdk.io.kafka.ReadFromKafkaDoFnTest.FailingDeserializer; import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource; import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions; import org.apache.beam.sdk.options.Default; @@ -72,6 +76,7 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Values; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler; import org.apache.beam.sdk.transforms.windowing.CalendarWindows; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; @@ -124,8 +129,6 @@ public class KafkaIOIT { private static final String RUN_TIME_METRIC_NAME = "run_time"; - private static final String READ_ELEMENT_METRIC_NAME = "kafka_read_element_count"; - private static final String NAMESPACE = KafkaIOIT.class.getName(); private static final String TEST_ID = UUID.randomUUID().toString(); @@ -352,6 +355,68 @@ public void processElement(@Element String element, OutputReceiver outpu } } + // This test verifies that bad data from Kafka is properly sent to the error handler + @Test + public void testKafkaIOSDFReadWithErrorHandler() throws IOException { + writePipeline + .apply(Create.of(KV.of("key", "val"))) + .apply( + "Write to Kafka", + KafkaIO.write() + .withBootstrapServers(options.getKafkaBootstrapServerAddresses()) + .withKeySerializer(StringSerializer.class) + .withValueSerializer(StringSerializer.class) + .withTopic(options.getKafkaTopic() + "-failingDeserialization")); + + PipelineResult writeResult = writePipeline.run(); + PipelineResult.State writeState = writeResult.waitUntilFinish(); + assertNotEquals(PipelineResult.State.FAILED, writeState); + + BadRecordErrorHandler> eh = + sdfReadPipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + sdfReadPipeline.apply( + KafkaIO.read() + .withBootstrapServers(options.getKafkaBootstrapServerAddresses()) + .withTopic(options.getKafkaTopic() + "-failingDeserialization") + .withConsumerConfigUpdates(ImmutableMap.of("auto.offset.reset", "earliest")) + .withKeyDeserializer(FailingDeserializer.class) + .withValueDeserializer(FailingDeserializer.class) + .withBadRecordErrorHandler(eh)); + eh.close(); + + PAssert.thatSingleton(Objects.requireNonNull(eh.getOutput())).isEqualTo(1L); + + PipelineResult readResult = sdfReadPipeline.run(); + PipelineResult.State readState = + readResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout())); + cancelIfTimeouted(readResult, readState); + assertNotEquals(PipelineResult.State.FAILED, readState); + } + + @Test + public void testKafkaIOWriteWithErrorHandler() throws IOException { + + BadRecordErrorHandler> eh = + writePipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + writePipeline + .apply("Create single KV", Create.of(KV.of("key", 4L))) + .apply( + "Write to Kafka", + KafkaIO.write() + .withBootstrapServers(options.getKafkaBootstrapServerAddresses()) + .withKeySerializer(StringSerializer.class) + .withValueSerializer(FailingLongSerializer.class) + .withTopic(options.getKafkaTopic() + "-failingSerialization") + .withBadRecordErrorHandler(eh)); + eh.close(); + + PAssert.thatSingleton(Objects.requireNonNull(eh.getOutput())).isEqualTo(1L); + + PipelineResult writeResult = writePipeline.run(); + PipelineResult.State writeState = writeResult.waitUntilFinish(); + assertNotEquals(PipelineResult.State.FAILED, writeState); + } + // This test roundtrips a single KV to 
verify that externalWithMetadata // can handle null keys and values correctly. @Test @@ -484,9 +549,7 @@ public void testKafkaWithDynamicPartitions() throws IOException { public void testKafkaWithStopReadingFunction() { AlwaysStopCheckStopReadingFn checkStopReadingFn = new AlwaysStopCheckStopReadingFn(); - PipelineResult readResult = runWithStopReadingFn(checkStopReadingFn, "stop-reading"); - - assertEquals(-1, readElementMetric(readResult, NAMESPACE, READ_ELEMENT_METRIC_NAME)); + runWithStopReadingFn(checkStopReadingFn, "stop-reading", 0L); } private static class AlwaysStopCheckStopReadingFn implements CheckStopReadingFn { @@ -500,11 +563,7 @@ public Boolean apply(TopicPartition input) { public void testKafkaWithDelayedStopReadingFunction() { DelayedCheckStopReadingFn checkStopReadingFn = new DelayedCheckStopReadingFn(); - PipelineResult readResult = runWithStopReadingFn(checkStopReadingFn, "delayed-stop-reading"); - - assertEquals( - sourceOptions.numRecords, - readElementMetric(readResult, NAMESPACE, READ_ELEMENT_METRIC_NAME)); + runWithStopReadingFn(checkStopReadingFn, "delayed-stop-reading", sourceOptions.numRecords); } public static final Schema KAFKA_TOPIC_SCHEMA = @@ -644,7 +703,7 @@ private static class DelayedCheckStopReadingFn implements CheckStopReadingFn { @Override public Boolean apply(TopicPartition input) { - if (checkCount >= 5) { + if (checkCount >= 10) { return true; } checkCount++; @@ -652,7 +711,8 @@ public Boolean apply(TopicPartition input) { } } - private PipelineResult runWithStopReadingFn(CheckStopReadingFn function, String topicSuffix) { + private void runWithStopReadingFn( + CheckStopReadingFn function, String topicSuffix, Long expectedCount) { writePipeline .apply("Generate records", Read.from(new SyntheticBoundedSource(sourceOptions))) .apply("Measure write time", ParDo.of(new TimeMonitor<>(NAMESPACE, WRITE_TIME_METRIC_NAME))) @@ -661,21 +721,31 @@ private PipelineResult runWithStopReadingFn(CheckStopReadingFn function, String writeToKafka().withTopic(options.getKafkaTopic() + "-" + topicSuffix)); readPipeline.getOptions().as(Options.class).setStreaming(true); - readPipeline - .apply( - "Read from unbounded Kafka", - readFromKafka() - .withTopic(options.getKafkaTopic() + "-" + topicSuffix) - .withCheckStopReadingFn(function)) - .apply("Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME))); + PCollection count = + readPipeline + .apply( + "Read from unbounded Kafka", + readFromKafka() + .withTopic(options.getKafkaTopic() + "-" + topicSuffix) + .withCheckStopReadingFn(function)) + .apply( + "Measure read time", ParDo.of(new TimeMonitor<>(NAMESPACE, READ_TIME_METRIC_NAME))) + .apply("Window", Window.into(CalendarWindows.years(1))) + .apply( + "Counting element", + Combine.globally(Count.>combineFn()).withoutDefaults()); + + if (expectedCount == 0L) { + PAssert.that(count).empty(); + } else { + PAssert.thatSingleton(count).isEqualTo(expectedCount); + } PipelineResult writeResult = writePipeline.run(); writeResult.waitUntilFinish(); PipelineResult readResult = readPipeline.run(); readResult.waitUntilFinish(Duration.standardSeconds(options.getReadTimeout())); - - return readResult; } @Test @@ -686,7 +756,7 @@ public void testWatermarkUpdateWithSparseMessages() throws IOException, Interrup String topicName = "SparseDataTopicPartition-" + UUID.randomUUID(); Map records = new HashMap<>(); - for (int i = 0; i < 5; i++) { + for (int i = 1; i <= 5; i++) { records.put(i, String.valueOf(i)); } @@ -725,7 +795,7 @@ public void 
testWatermarkUpdateWithSparseMessages() throws IOException, Interrup PipelineResult readResult = sdfReadPipeline.run(); - Thread.sleep(options.getReadTimeout() * 1000); + Thread.sleep(options.getReadTimeout() * 1000 * 2); for (String value : records.values()) { kafkaIOITExpectedLogs.verifyError(value); @@ -753,11 +823,6 @@ public void processElement( } } - private long readElementMetric(PipelineResult result, String namespace, String name) { - MetricsReader metricsReader = new MetricsReader(result, namespace); - return metricsReader.getCounterMetric(name); - } - private Set readMetrics(PipelineResult writeResult, PipelineResult readResult) { BiFunction supplier = (reader, metricName) -> { diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index aeb5818e9134..b0df82bcdc19 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -51,6 +51,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -87,6 +88,7 @@ import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Distinct; @@ -95,11 +97,15 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.Max; import org.apache.beam.sdk.transforms.Min; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.CalendarWindows; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.CoderUtils; @@ -121,9 +127,12 @@ import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.Cluster; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.SerializationException; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.Headers; import org.apache.kafka.common.header.internals.RecordHeader; @@ -136,7 +145,10 @@ import org.apache.kafka.common.serialization.LongSerializer; import org.apache.kafka.common.serialization.Serializer; import org.apache.kafka.common.utils.Utils; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; 
+import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.hamcrest.collection.IsIterableContainingInAnyOrder; import org.hamcrest.collection.IsIterableWithSize; import org.joda.time.Duration; @@ -1379,7 +1391,7 @@ public void testSink() throws Exception { int numElements = 1000; - try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) { + try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) { ProducerSendCompletionThread completionThread = new ProducerSendCompletionThread(producerWrapper.mockProducer).start(); @@ -1404,13 +1416,81 @@ public void testSink() throws Exception { } } + public static class FailingLongSerializer implements Serializer { + // enables instantiation by registrys + public FailingLongSerializer() {} + + @Override + public byte[] serialize(String topic, Long data) { + throw new SerializationException("ExpectedSerializationException"); + } + + @Override + public void configure(Map configs, boolean isKey) { + // intentionally left blank for compatibility with older kafka versions + } + } + + @Test + public void testSinkWithSerializationErrors() throws Exception { + // Attempt to write 10 elements to Kafka, but they will all fail to serialize, and be sent to + // the DLQ + + int numElements = 10; + + try (MockProducerWrapper producerWrapper = + new MockProducerWrapper(new FailingLongSerializer())) { + + ProducerSendCompletionThread completionThread = + new ProducerSendCompletionThread(producerWrapper.mockProducer).start(); + + String topic = "test"; + + BadRecordErrorHandler> eh = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + + p.apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata()) + .apply( + KafkaIO.write() + .withBootstrapServers("none") + .withTopic(topic) + .withKeySerializer(IntegerSerializer.class) + .withValueSerializer(FailingLongSerializer.class) + .withInputTimestamp() + .withProducerFactoryFn(new ProducerFactoryFn(producerWrapper.producerKey)) + .withBadRecordErrorHandler(eh)); + + eh.close(); + + PAssert.thatSingleton(Objects.requireNonNull(eh.getOutput())).isEqualTo(10L); + + p.run(); + + completionThread.shutdown(); + + verifyProducerRecords(producerWrapper.mockProducer, topic, 0, false, true); + } + } + + public static class ErrorSinkTransform + extends PTransform, PCollection> { + + @Override + public @UnknownKeyFor @NonNull @Initialized PCollection expand( + PCollection input) { + return input + .apply("Window", Window.into(CalendarWindows.years(1))) + .apply("Combine", Combine.globally(Count.combineFn()).withoutDefaults()); + } + } + @Test public void testValuesSink() throws Exception { // similar to testSink(), but use values()' interface. 
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1442,7 +1522,7 @@ public void testRecordsSink() throws Exception {
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1474,7 +1554,7 @@ public void testSinkToMultipleTopics() throws Exception {
     // Set different output topic names
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1519,7 +1599,7 @@ public void testKafkaWriteHeaders() throws Exception {
     // Set different output topic names
     int numElements = 1;
     SimpleEntry<String, String> header = new SimpleEntry<>("header_key", "header_value");
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1562,7 +1642,7 @@ public void testKafkaWriteHeaders() throws Exception {
   public void testSinkProducerRecordsWithCustomTS() throws Exception {
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1601,7 +1681,7 @@ public void testSinkProducerRecordsWithCustomTS() throws Exception {
   public void testSinkProducerRecordsWithCustomPartition() throws Exception {
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1725,7 +1805,7 @@ public void testExactlyOnceSink() {
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -1803,7 +1883,7 @@ public void testSinkWithSendErrors() throws Throwable {
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThreadWithErrors =
           new ProducerSendCompletionThread(producerWrapper.mockProducer, 10, 100).start();
@@ -1993,7 +2073,7 @@ public void testSourceWithPatternDisplayData() {
   @Test
   public void testSinkDisplayData() {
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       KafkaIO.Write<Integer, Long> write =
           KafkaIO.<Integer, Long>write()
               .withBootstrapServers("myServerA:9092,myServerB:9092")
@@ -2017,7 +2097,7 @@ public void testSinkMetrics() throws Exception {
     int numElements = 1000;
-    try (MockProducerWrapper producerWrapper = new MockProducerWrapper()) {
+    try (MockProducerWrapper producerWrapper = new MockProducerWrapper(new LongSerializer())) {
       ProducerSendCompletionThread completionThread =
           new ProducerSendCompletionThread(producerWrapper.mockProducer).start();
@@ -2109,14 +2189,22 @@ private static class MockProducerWrapper implements AutoCloseable {
       }
     }
-    MockProducerWrapper() {
+    MockProducerWrapper(Serializer<Long> valueSerializer) {
       producerKey = String.valueOf(ThreadLocalRandom.current().nextLong());
       mockProducer =
           new MockProducer<Integer, Long>(
+              Cluster.empty()
+                  .withPartitions(
+                      ImmutableMap.of(
+                          new TopicPartition("test", 0),
+                          new PartitionInfo("test", 0, null, null, null),
+                          new TopicPartition("test", 1),
+                          new PartitionInfo("test", 1, null, null, null))),
               false, // disable synchronous completion of send. see ProducerSendCompletionThread
               // below.
+              new DefaultPartitioner(),
               new IntegerSerializer(),
-              new LongSerializer()) {
+              valueSerializer) {
             // override flush() so that it does not complete all the waiting sends, giving a chance
             // to
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
index 554c6d2fcaf1..48b5b060a295 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.kafka;
+import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -41,15 +42,20 @@ import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.DefaultErrorHandler;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
@@ -64,7 +70,9 @@ import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.SerializationException;
 import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.checkerframework.checker.initialization.qual.Initialized;
 import org.checkerframework.checker.nullness.qual.NonNull;
@@ -80,19 +88,22 @@ public class ReadFromKafkaDoFnTest {
   private final TopicPartition topicPartition = new TopicPartition("topic", 0);
+  private static final TupleTag<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> RECORDS =
+      new TupleTag<>();
+
   @Rule public ExpectedException thrown = ExpectedException.none();
   private final SimpleMockKafkaConsumer consumer =
       new SimpleMockKafkaConsumer(OffsetResetStrategy.NONE, topicPartition);
   private final ReadFromKafkaDoFn<String, String> dofnInstance =
-      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(consumer));
+      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(consumer), RECORDS);
   private final ExceptionMockKafkaConsumer exceptionConsumer =
       new ExceptionMockKafkaConsumer(OffsetResetStrategy.NONE, topicPartition);
   private final ReadFromKafkaDoFn<String, String> exceptionDofnInstance =
-      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(exceptionConsumer));
+      ReadFromKafkaDoFn.create(makeReadSourceDescriptor(exceptionConsumer), RECORDS);
   private ReadSourceDescriptors<String, String> makeReadSourceDescriptor(
       Consumer<byte[], byte[]> kafkaMockConsumer) {
@@ -109,6 +120,31 @@ public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
         .withBootstrapServers("bootstrap_server");
   }
+  private ReadSourceDescriptors<String, String> makeFailingReadSourceDescriptor(
+      Consumer<byte[], byte[]> kafkaMockConsumer) {
+    return ReadSourceDescriptors.<String, String>read()
+        .withKeyDeserializer(FailingDeserializer.class)
+        .withValueDeserializer(FailingDeserializer.class)
+        .withConsumerFactoryFn(
+            new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
+              @Override
+              public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
+                return kafkaMockConsumer;
+              }
+            })
+        .withBootstrapServers("bootstrap_server");
+  }
+
+  public static class FailingDeserializer implements Deserializer<String> {
+
+    public FailingDeserializer() {}
+
+    @Override
+    public String deserialize(String topic, byte[] data) {
+      throw new SerializationException("Intentional serialization exception");
+    }
+  }
+
   private static class ExceptionMockKafkaConsumer extends MockConsumer<byte[], byte[]> {
     private final TopicPartition topicPartition;
@@ -254,23 +290,57 @@ public synchronized long position(TopicPartition partition) {
     }
   }
-  private static class MockOutputReceiver
-      implements OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> {
+  private static class MockMultiOutputReceiver implements MultiOutputReceiver {
+
+    MockOutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> mockOutputReceiver =
+        new MockOutputReceiver<>();
+
+    MockOutputReceiver<BadRecord> badOutputReceiver = new MockOutputReceiver<>();
+
+    @Override
+    public <T> @UnknownKeyFor @NonNull @Initialized OutputReceiver<T> get(
+        @UnknownKeyFor @NonNull @Initialized TupleTag<T> tag) {
+      if (RECORDS.equals(tag)) {
+        return (OutputReceiver<T>) mockOutputReceiver;
+      } else if (BAD_RECORD_TAG.equals(tag)) {
+        return (OutputReceiver<T>) badOutputReceiver;
+      } else {
+        throw new RuntimeException("Invalid Tag");
+      }
+    }
+
+    public List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> getGoodRecords() {
+      return mockOutputReceiver.getOutputs();
+    }
-    private final List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> records =
-        new ArrayList<>();
+    public List<BadRecord> getBadRecords() {
+      return badOutputReceiver.getOutputs();
+    }
     @Override
-    public void output(KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output) {}
+    public <T> @UnknownKeyFor @NonNull @Initialized
+        OutputReceiver<@UnknownKeyFor @NonNull @Initialized Row> getRowReceiver(
+            @UnknownKeyFor @NonNull @Initialized TupleTag<T> tag) {
+      return null;
+    }
+  }
+
+  private static class MockOutputReceiver<T> implements OutputReceiver<T> {
+
+    private final List<T> records = new ArrayList<>();
+
+    @Override
+    public void output(T output) {
+      records.add(output);
+    }
     @Override
     public void outputWithTimestamp(
-        KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output,
-        @UnknownKeyFor @NonNull @Initialized Instant timestamp) {
+        T output, @UnknownKeyFor @NonNull @Initialized Instant timestamp) {
       records.add(output);
     }
-    public List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> getOutputs() {
+    public List<T> getOutputs() {
       return this.records;
     }
   }
@@ -386,7 +456,7 @@ public void testInitialRestrictionWithException() throws Exception {
   @Test
   public void testProcessElement() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setNumOfRecordsPerPoll(3L);
     long startOffset = 5L;
     OffsetRangeTracker tracker =
@@ -396,7 +466,8 @@ public void testProcessElement() throws Exception {
     ProcessContinuation result = dofnInstance.processElement(descriptor, tracker, null, receiver);
     assertEquals(ProcessContinuation.stop(), result);
     assertEquals(
-        createExpectedRecords(descriptor, startOffset, 3, "key", "value"), receiver.getOutputs());
+        createExpectedRecords(descriptor, startOffset, 3, "key", "value"),
+        receiver.getGoodRecords());
   }
   @Test
@@ -406,7 +477,7 @@ public void testRawSizeMetric() throws Exception {
     MetricsContainerImpl container = new MetricsContainerImpl("any");
     MetricsEnvironment.setCurrentContainer(container);
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setNumOfRecordsPerPoll(numElements);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0, numElements));
     KafkaSourceDescriptor descriptor =
@@ -427,7 +498,7 @@ public void testRawSizeMetric() throws Exception {
   @Test
   public void testProcessElementWithEmptyPoll() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setNumOfRecordsPerPoll(-1);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
     ProcessContinuation result =
@@ -437,12 +508,12 @@ public void testProcessElementWithEmptyPoll() throws Exception {
             null,
             receiver);
     assertEquals(ProcessContinuation.resume(), result);
-    assertTrue(receiver.getOutputs().isEmpty());
+    assertTrue(receiver.getGoodRecords().isEmpty());
   }
   @Test
   public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     consumer.setRemoved();
     consumer.setNumOfRecordsPerPoll(10);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
@@ -457,7 +528,7 @@ public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception {
   @Test
   public void testProcessElementWhenTopicPartitionIsStopped() throws Exception {
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     ReadFromKafkaDoFn<String, String> instance =
         ReadFromKafkaDoFn.create(
             makeReadSourceDescriptor(consumer)
@@ -470,7 +541,8 @@ public Boolean apply(TopicPartition input) {
                         return true;
                       }
                     })
-                .build());
+                .build(),
+            RECORDS);
     instance.setup();
     consumer.setNumOfRecordsPerPoll(10);
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
@@ -489,7 +561,7 @@ public void testProcessElementWithException() throws Exception {
     thrown.expect(KafkaException.class);
     thrown.expectMessage("SeekException");
-    MockOutputReceiver receiver = new MockOutputReceiver();
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
     OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
     exceptionDofnInstance.processElement(
@@ -499,6 +571,61 @@ public void testProcessElementWithException() throws Exception {
         receiver);
   }
+  @Test
+  public void testProcessElementWithDeserializationExceptionDefaultRecordHandler()
+      throws Exception {
+    thrown.expect(SerializationException.class);
+    thrown.expectMessage("Intentional serialization exception");
+
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE));
+
+    consumer.setNumOfRecordsPerPoll(1);
+
+    ReadFromKafkaDoFn<String, String> dofnInstance =
+        ReadFromKafkaDoFn.create(makeFailingReadSourceDescriptor(consumer), RECORDS);
+
+    dofnInstance.setup();
+
+    dofnInstance.processElement(
+        KafkaSourceDescriptor.of(topicPartition, null, null, null, null, null),
+        tracker,
+        null,
+        receiver);
+
+    Assert.assertEquals("OutputRecordSize", 0, receiver.getGoodRecords().size());
+    Assert.assertEquals("OutputErrorSize", 0, receiver.getBadRecords().size());
+  }
+
+  @Test
+  public void testProcessElementWithDeserializationExceptionRecordingRecordHandler()
+      throws Exception {
+    MockMultiOutputReceiver receiver = new MockMultiOutputReceiver();
+    OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, 1L));
+
+    consumer.setNumOfRecordsPerPoll(1);
+
+    // Because we never actually execute the pipeline, no data will actually make it to the error
+    // handler. This will just configure the ReadSourceDescriptors to route the errors to the output
+    // PCollection instead of rethrowing.
+    ReadSourceDescriptors<String, String> descriptors =
+        makeFailingReadSourceDescriptor(consumer)
+            .withBadRecordErrorHandler(new DefaultErrorHandler<>());
+
+    ReadFromKafkaDoFn<String, String> dofnInstance = ReadFromKafkaDoFn.create(descriptors, RECORDS);
+
+    dofnInstance.setup();
+
+    dofnInstance.processElement(
+        KafkaSourceDescriptor.of(topicPartition, null, null, null, null, null),
+        tracker,
+        null,
+        receiver);
+
+    Assert.assertEquals("OutputRecordSize", 0, receiver.getGoodRecords().size());
+    Assert.assertEquals("OutputErrorSize", 1, receiver.getBadRecords().size());
+  }
+
   private static final TypeDescriptor<KafkaSourceDescriptor> KAFKA_SOURCE_DESCRIPTOR_TYPE_DESCRIPTOR =
       new TypeDescriptor<KafkaSourceDescriptor>() {};
@@ -522,7 +649,8 @@ private BoundednessVisitor testBoundedness(
         .apply(
             ParDo.of(
                 ReadFromKafkaDoFn.create(
-                    readSourceDescriptorsDecorator.apply(makeReadSourceDescriptor(consumer)))))
+                    readSourceDescriptorsDecorator.apply(makeReadSourceDescriptor(consumer)),
+                    RECORDS)))
         .setCoder(
             KvCoder.of(
                 SerializableCoder.of(KafkaSourceDescriptor.class),