From f303d6ae7ebe4c5b2e1af5e29840317a6ce0d691 Mon Sep 17 00:00:00 2001
From: johnjcasey <95318300+johnjcasey@users.noreply.github.com>
Date: Thu, 28 Dec 2023 18:26:32 -0500
Subject: [PATCH] Add Error Handlers to File IO and related IOs (TextIO,
AvroIO) (#29670)
* first pass of wiring error handling into write files and adding tests
* fix error handling to solve constant filenaming policy returning a null destination
* fix tests, add a safety check to the error handler
* spotless
* add documentation
* add textio error handler pass-through
* add avroio error handler pass-through
* add documentation to avroio
* add documentation to WriteFiles
* remove function to check if the exception is bad, because that isn't portable
* spotless
* spotless
* clean up documentation
* clean up documentation, remove unnecessary unwritten records tag
* spotless
* spotless
---
.../java/org/apache/beam/sdk/io/FileIO.java | 43 +++
.../java/org/apache/beam/sdk/io/TextIO.java | 20 ++
.../org/apache/beam/sdk/io/WriteFiles.java | 297 ++++++++++++++----
.../errorhandling/ErrorHandler.java | 4 +
.../apache/beam/sdk/io/WriteFilesTest.java | 131 ++++++++
.../errorhandling/ErrorHandlingTestUtils.java | 48 +++
.../beam/sdk/extensions/avro/io/AvroIO.java | 20 ++
.../apache/beam/sdk/io/kafka/KafkaIOIT.java | 2 +-
.../apache/beam/sdk/io/kafka/KafkaIOTest.java | 20 +-
9 files changed, 511 insertions(+), 74 deletions(-)
create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
index 76fc1a70b78c..0bc984877217 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
@@ -61,6 +61,8 @@
import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -236,6 +238,27 @@
* destination-dependent: every window/pane for every destination will use the same number of shards
* specified via {@link Write#withNumShards} or {@link Write#withSharding}.
*
+ * <h3>Handling Errors</h3>
+ *
+ * <p>When using dynamic destinations, or when using a formatting function to format a record for
+ * writing, it's possible for an individual record to be malformed, causing an exception. By
+ * default, these exceptions are propagated to the runner causing the bundle to fail. These are
+ * usually retried, though this depends on the runner. Alternately, these errors can be routed to
+ * another {@link PTransform} by using {@link Write#withBadRecordErrorHandler(ErrorHandler)}. The
+ * ErrorHandler is registered with the pipeline (see below). See {@link ErrorHandler} for more
+ * documentation. Of note, this error handling only handles errors related to specific records. It
+ * does not handle errors related to connectivity, authorization, etc. as those should be retried by
+ * the runner.
+ *
+ * <pre>{@code
+ * PCollection<UserT> records = ...;
+ * PTransform<PCollection<BadRecord>, ?> alternateSink = ...;
+ * try (BadRecordErrorHandler<?> handler = pipeline.registerBadRecordErrorHandler(alternateSink)) {
+ *   records.apply("Write", FileIO.writeDynamic().otherConfigs()
+ *       .withBadRecordErrorHandler(handler));
+ * }
+ * }</pre>
+ *
* <h3>Writing custom types to sinks</h3>
*
* <p>Normally, when writing a collection of a custom type using a {@link Sink} that takes a
@@ -1016,6 +1039,8 @@ public static FileNaming relativeFileNaming(
abstract boolean getNoSpilling();
+ abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
abstract Builder toBuilder();
@AutoValue.Builder
@@ -1062,6 +1087,9 @@ abstract Builder setSharding(
abstract Builder setNoSpilling(boolean noSpilling);
+ abstract Builder<DestinationT, UserT> setBadRecordErrorHandler(
+ @Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
abstract Write build();
}
@@ -1288,6 +1316,18 @@ public Write withNoSpilling() {
return toBuilder().setNoSpilling(true).build();
}
+ /**
+ * Configures a new {@link Write} with an ErrorHandler. For configuring an ErrorHandler, see
+ * {@link ErrorHandler}. Whenever a record is formatted, or a lookup for a dynamic destination
+ * is performed, and that operation fails, the exception is passed to the error handler. This is
+ * intended to handle any errors related to the data of a record, but not any connectivity or IO
+ * errors related to the literal writing of a record.
+ */
+ public Write<DestinationT, UserT> withBadRecordErrorHandler(
+ ErrorHandler<BadRecord, ?> errorHandler) {
+ return toBuilder().setBadRecordErrorHandler(errorHandler).build();
+ }
+
@VisibleForTesting
Contextful<Fn<DestinationT, FileNaming>> resolveFileNamingFn() {
if (getDynamic()) {
@@ -1391,6 +1431,9 @@ public WriteFilesResult expand(PCollection input) {
if (getNoSpilling()) {
writeFiles = writeFiles.withNoSpilling();
}
+ if (getBadRecordErrorHandler() != null) {
+ writeFiles = writeFiles.withBadRecordErrorHandler(getBadRecordErrorHandler());
+ }
return input.apply(writeFiles);
}
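
A minimal end-to-end sketch of the new FileIO hook (not part of the patch; the sample data, the output path, and the counting dead-letter sink are illustrative assumptions):

```java
// Sketch only: wires FileIO.writeDynamic() to a hypothetical dead-letter sink.
// The destination function throws for empty strings, so those records are
// routed to the handler instead of failing the bundle.
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
import org.apache.beam.sdk.values.PCollection;

public class FileIOErrorHandlingSketch {
  // Hypothetical sink: just counts the records that could not be written.
  static class CountBadRecords
      extends PTransform<PCollection<BadRecord>, PCollection<Long>> {
    @Override
    public PCollection<Long> expand(PCollection<BadRecord> input) {
      return input.apply(Combine.globally(Count.<BadRecord>combineFn()));
    }
  }

  public static void main(String[] args) throws Exception {
    Pipeline pipeline = Pipeline.create();
    PCollection<String> records = pipeline.apply(Create.of("alpha", "", "beta"));
    try (BadRecordErrorHandler<PCollection<Long>> handler =
        pipeline.registerBadRecordErrorHandler(new CountBadRecords())) {
      records.apply(
          "WriteByFirstLetter",
          FileIO.<String, String>writeDynamic()
              .by(record -> record.substring(0, 1)) // throws for ""
              .via(TextIO.sink())
              .withDestinationCoder(StringUtf8Coder.of())
              .withNaming(dest -> FileIO.Write.defaultNaming("out-" + dest, ".txt"))
              .to("/tmp/file-io-example") // illustrative path
              .withBadRecordErrorHandler(handler));
    }
    pipeline.run().waitUntilFinish();
  }
}
```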
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
index 2c7a4fc5d4f5..96635a37fac1 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
@@ -51,6 +51,8 @@
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
@@ -176,6 +178,10 @@
*
* For backwards compatibility, {@link TextIO} also supports the legacy {@link
* DynamicDestinations} interface for advanced features via {@link Write#to(DynamicDestinations)}.
+ *
+ * <p>Error handling for malformed records is supported via {@link
+ * TypedWrite#withBadRecordErrorHandler(ErrorHandler)}. See documentation in {@link FileIO} for
+ * details on usage.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -708,6 +714,8 @@ public abstract static class TypedWrite
*/
abstract WritableByteChannelFactory getWritableByteChannelFactory();
+ abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
abstract Builder toBuilder();
@AutoValue.Builder
@@ -754,6 +762,9 @@ abstract Builder setNumShards(
abstract Builder setWritableByteChannelFactory(
WritableByteChannelFactory writableByteChannelFactory);
+ abstract Builder<UserT, DestinationT> setBadRecordErrorHandler(
+ @Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
abstract TypedWrite build();
}
@@ -993,6 +1004,12 @@ public TypedWrite withNoSpilling() {
return toBuilder().setNoSpilling(true).build();
}
+ /** See {@link FileIO.Write#withBadRecordErrorHandler(ErrorHandler)} for details on usage. */
+ public TypedWrite<UserT, DestinationT> withBadRecordErrorHandler(
+ ErrorHandler<BadRecord, ?> errorHandler) {
+ return toBuilder().setBadRecordErrorHandler(errorHandler).build();
+ }
+
/** Don't write any output files if the PCollection is empty. */
public TypedWrite skipIfEmpty() {
return toBuilder().setSkipIfEmpty(true).build();
@@ -1083,6 +1100,9 @@ public WriteFilesResult expand(PCollection input) {
if (getNoSpilling()) {
write = write.withNoSpilling();
}
+ if (getBadRecordErrorHandler() != null) {
+ write = write.withBadRecordErrorHandler(getBadRecordErrorHandler());
+ }
if (getSkipIfEmpty()) {
write = write.withSkipIfEmpty();
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java
index 91d6082eede4..7359141c5b87 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io;
+import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
@@ -49,6 +50,7 @@
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
@@ -62,6 +64,10 @@
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.DefaultErrorHandler;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -166,6 +172,8 @@ public static <UserT, DestinationT, OutputT> WriteFiles<UserT, DestinationT, OutputT> to(
+ .setBadRecordErrorHandler(new DefaultErrorHandler<>())
+ .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER)
.build();
}
@@ -189,6 +197,10 @@ public abstract @Nullable ShardingFunction<UserT, DestinationT> getShardingFunction();
+ public abstract ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
+ public abstract BadRecordRouter getBadRecordRouter();
+
abstract Builder toBuilder();
@AutoValue.Builder
@@ -215,6 +227,12 @@ abstract Builder setSideInputs(
abstract Builder setShardingFunction(
@Nullable ShardingFunction shardingFunction);
+ abstract Builder<UserT, DestinationT, OutputT> setBadRecordErrorHandler(
+ ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
+ abstract Builder<UserT, DestinationT, OutputT> setBadRecordRouter(
+ BadRecordRouter badRecordRouter);
+
abstract WriteFiles build();
}
@@ -330,6 +348,15 @@ public WriteFiles withSkipIfEmpty() {
return toBuilder().setSkipIfEmpty(true).build();
}
+ /** See {@link FileIO.Write#withBadRecordErrorHandler(ErrorHandler)} for details on usage. */
+ public WriteFiles<UserT, DestinationT, OutputT> withBadRecordErrorHandler(
+ ErrorHandler<BadRecord, ?> errorHandler) {
+ return toBuilder()
+ .setBadRecordErrorHandler(errorHandler)
+ .setBadRecordRouter(BadRecordRouter.RECORDING_ROUTER)
+ .build();
+ }
+
@Override
public void validate(PipelineOptions options) {
getSink().validate(options);
@@ -495,28 +522,39 @@ private WriteUnshardedBundlesToTempFiles(
@Override
public PCollection<FileResult<DestinationT>> expand(PCollection<UserT> input) {
- if (getMaxNumWritersPerBundle() < 0) {
- return input
- .apply(
- "WritedUnshardedBundles",
- ParDo.of(new WriteUnshardedTempFilesFn(null, destinationCoder))
- .withSideInputs(getSideInputs()))
- .setCoder(fileResultCoder);
- }
TupleTag<FileResult<DestinationT>> writtenRecordsTag = new TupleTag<>("writtenRecords");
TupleTag<KV<ShardedKey<Integer>, UserT>> unwrittenRecordsTag =
new TupleTag<>("unwrittenRecords");
+ Coder<UserT> inputCoder = input.getCoder();
+ if (getMaxNumWritersPerBundle() < 0) {
+ PCollectionTuple writeTuple =
+ input.apply(
+ "WritedUnshardedBundles",
+ ParDo.of(new WriteUnshardedTempFilesFn(null, destinationCoder, inputCoder))
+ .withSideInputs(getSideInputs())
+ .withOutputTags(
+ writtenRecordsTag, TupleTagList.of(ImmutableList.of(BAD_RECORD_TAG))));
+ addErrorCollection(writeTuple);
+ return writeTuple.get(writtenRecordsTag).setCoder(fileResultCoder);
+ }
+
PCollectionTuple writeTuple =
input.apply(
"WriteUnshardedBundles",
- ParDo.of(new WriteUnshardedTempFilesFn(unwrittenRecordsTag, destinationCoder))
+ ParDo.of(
+ new WriteUnshardedTempFilesFn(
+ unwrittenRecordsTag, destinationCoder, inputCoder))
.withSideInputs(getSideInputs())
- .withOutputTags(writtenRecordsTag, TupleTagList.of(unwrittenRecordsTag)));
+ .withOutputTags(
+ writtenRecordsTag,
+ TupleTagList.of(ImmutableList.of(unwrittenRecordsTag, BAD_RECORD_TAG))));
+ addErrorCollection(writeTuple);
+
PCollection<FileResult<DestinationT>> writtenBundleFiles =
writeTuple.get(writtenRecordsTag).setCoder(fileResultCoder);
// Any "spilled" elements are written using WriteShardedBundles. Assign shard numbers in
// finalize to stay consistent with what WriteWindowedBundles does.
- PCollection<FileResult<DestinationT>> writtenSpilledFiles =
+ PCollectionTuple spilledWriteTuple =
writeTuple
.get(unwrittenRecordsTag)
.setCoder(KvCoder.of(ShardedKeyCoder.of(VarIntCoder.of()), input.getCoder()))
@@ -529,7 +567,15 @@ public PCollection> expand(PCollection input) {
.apply("GroupUnwritten", GroupByKey.create())
.apply(
"WriteUnwritten",
- ParDo.of(new WriteShardsIntoTempFilesFn()).withSideInputs(getSideInputs()))
+ ParDo.of(new WriteShardsIntoTempFilesFn(input.getCoder()))
+ .withSideInputs(getSideInputs())
+ .withOutputTags(writtenRecordsTag, TupleTagList.of(BAD_RECORD_TAG)));
+
+ addErrorCollection(spilledWriteTuple);
+
+ PCollection<FileResult<DestinationT>> writtenSpilledFiles =
+ spilledWriteTuple
+ .get(writtenRecordsTag)
.setCoder(fileResultCoder)
.apply(
"DropShardNum",
@@ -556,6 +602,8 @@ private class WriteUnshardedTempFilesFn extends DoFn<UserT, FileResult<DestinationT>> {
private final @Nullable TupleTag<KV<ShardedKey<Integer>, UserT>> unwrittenRecordsTag;
private final Coder<DestinationT> destinationCoder;
+ private final Coder<UserT> inputCoder;
+
// Initialized in startBundle()
private @Nullable Map<WriterKey<DestinationT>, Writer<DestinationT, OutputT>> writers;
@@ -563,9 +611,11 @@ private class WriteUnshardedTempFilesFn extends DoFn<UserT, FileResult<DestinationT>> {
@Nullable TupleTag<KV<ShardedKey<Integer>, UserT>> unwrittenRecordsTag,
- Coder<DestinationT> destinationCoder) {
+ Coder<DestinationT> destinationCoder,
+ Coder<UserT> inputCoder) {
this.unwrittenRecordsTag = unwrittenRecordsTag;
this.destinationCoder = destinationCoder;
+ this.inputCoder = inputCoder;
}
@StartBundle
@@ -575,7 +625,9 @@ public void startBundle(StartBundleContext c) {
}
@ProcessElement
- public void processElement(ProcessContext c, BoundedWindow window) throws Exception {
+ public void processElement(
+ ProcessContext c, BoundedWindow window, MultiOutputReceiver outputReceiver)
+ throws Exception {
getDynamicDestinations().setSideInputAccessorFromProcessContext(c);
PaneInfo paneInfo = c.pane();
// If we are doing windowed writes, we need to ensure that we have separate files for
@@ -583,7 +635,12 @@ public void processElement(ProcessContext c, BoundedWindow window) throws Except
// destinations go to different writers.
// In the case of unwindowed writes, the window and the pane will always be the same, and
// the map will only have a single element.
- DestinationT destination = getDynamicDestinations().getDestination(c.element());
+ MaybeDestination<DestinationT> maybeDestination =
+ getDestinationWithErrorHandling(c.element(), outputReceiver, inputCoder);
+ if (!maybeDestination.isValid) {
+ return;
+ }
+ DestinationT destination = maybeDestination.destination;
WriterKey<DestinationT> key = new WriterKey<>(window, c.pane(), destination);
Writer<DestinationT, OutputT> writer = writers.get(key);
if (writer == null) {
@@ -607,15 +664,22 @@ public void processElement(ProcessContext c, BoundedWindow window) throws Except
} else {
spilledShardNum = (spilledShardNum + 1) % SPILLED_RECORD_SHARDING_FACTOR;
}
- c.output(
- unwrittenRecordsTag,
- KV.of(
- ShardedKey.of(hashDestination(destination, destinationCoder), spilledShardNum),
- c.element()));
+ outputReceiver
+ .get(unwrittenRecordsTag)
+ .output(
+ KV.of(
+ ShardedKey.of(
+ hashDestination(destination, destinationCoder), spilledShardNum),
+ c.element()));
return;
}
}
- writeOrClose(writer, getDynamicDestinations().formatRecord(c.element()));
+ OutputT formattedRecord =
+ formatRecordWithErrorHandling(c.element(), outputReceiver, inputCoder);
+ if (formattedRecord == null) {
+ return;
+ }
+ writeOrClose(writer, formattedRecord);
}
@FinishBundle
@@ -701,6 +765,56 @@ private static int hashDestination(
.asInt();
}
+ private static class MaybeDestination<DestinationT> {
+ final DestinationT destination;
+ final boolean isValid;
+
+ MaybeDestination(DestinationT destination, boolean isValid) {
+ this.destination = destination;
+ this.isValid = isValid;
+ }
+ }
+ // Utility method to get the dynamic destination based on a record. Returns a MaybeDestination
+ // because some implementations of dynamic destinations return null, despite this being prohibited
+ // by the interface.
+ private MaybeDestination<DestinationT> getDestinationWithErrorHandling(
+ UserT input, MultiOutputReceiver outputReceiver, Coder<UserT> inputCoder) throws Exception {
+ try {
+ return new MaybeDestination<>(getDynamicDestinations().getDestination(input), true);
+ } catch (Exception e) {
+ getBadRecordRouter()
+ .route(
+ outputReceiver, input, inputCoder, e, "Unable to get dynamic destination for record");
+ return new MaybeDestination<>(null, false);
+ }
+ }
+
+ // Utility method to format a record based on the dynamic destination. If the operation fails,
+ // the record is routed to the bad record router and this method returns null.
+ private @Nullable OutputT formatRecordWithErrorHandling(
+ UserT input, MultiOutputReceiver outputReceiver, Coder<UserT> inputCoder) throws Exception {
+ try {
+ return getDynamicDestinations().formatRecord(input);
+ } catch (Exception e) {
+ getBadRecordRouter()
+ .route(
+ outputReceiver,
+ input,
+ inputCoder,
+ e,
+ "Unable to format record for Dynamic Destination");
+ return null;
+ }
+ }
+
+ private void addErrorCollection(PCollectionTuple sourceTuple) {
+ getBadRecordErrorHandler()
+ .addErrorCollection(
+ sourceTuple
+ .get(BAD_RECORD_TAG)
+ .setCoder(BadRecord.getCoder(sourceTuple.getPipeline())));
+ }
+
private class WriteShardedBundlesToTempFiles
extends PTransform<PCollection<UserT>, PCollection<FileResult<DestinationT>>> {
private final Coder<DestinationT> destinationCoder;
@@ -728,17 +842,32 @@ public PCollection> expand(PCollection input) {
? new RandomShardingFunction(destinationCoder)
: getShardingFunction();
- return input
- .apply(
+ TupleTag<KV<ShardedKey<Integer>, UserT>> shardedRecords = new TupleTag<>("shardedRecords");
+ TupleTag<FileResult<DestinationT>> writtenRecordsTag = new TupleTag<>("writtenRecords");
+
+ PCollectionTuple shardedFiles =
+ input.apply(
"ApplyShardingKey",
- ParDo.of(new ApplyShardingFunctionFn(shardingFunction, numShardsView))
- .withSideInputs(shardingSideInputs))
- .setCoder(KvCoder.of(ShardedKeyCoder.of(VarIntCoder.of()), input.getCoder()))
- .apply("GroupIntoShards", GroupByKey.create())
- .apply(
- "WriteShardsIntoTempFiles",
- ParDo.of(new WriteShardsIntoTempFilesFn()).withSideInputs(getSideInputs()))
- .setCoder(fileResultCoder);
+ ParDo.of(
+ new ApplyShardingFunctionFn(
+ shardingFunction, numShardsView, input.getCoder()))
+ .withSideInputs(shardingSideInputs)
+ .withOutputTags(shardedRecords, TupleTagList.of(BAD_RECORD_TAG)));
+ addErrorCollection(shardedFiles);
+
+ PCollectionTuple writtenFiles =
+ shardedFiles
+ .get(shardedRecords)
+ .setCoder(KvCoder.of(ShardedKeyCoder.of(VarIntCoder.of()), input.getCoder()))
+ .apply("GroupIntoShards", GroupByKey.create())
+ .apply(
+ "WriteShardsIntoTempFiles",
+ ParDo.of(new WriteShardsIntoTempFilesFn(input.getCoder()))
+ .withSideInputs(getSideInputs())
+ .withOutputTags(writtenRecordsTag, TupleTagList.of(BAD_RECORD_TAG)));
+ addErrorCollection(writtenFiles);
+
+ return writtenFiles.get(writtenRecordsTag).setCoder(fileResultCoder);
}
}
@@ -763,22 +892,18 @@ public PCollection>> expand(PCollection inp
//
// TODO(https://github.com/apache/beam/issues/20928): The implementation doesn't currently
// work with merging windows.
+ TupleTag<KV<Integer, UserT>> shardTag = new TupleTag<>("shardTag");
+
+ PCollectionTuple shardedElements =
+ input.apply(
+ "KeyedByDestinationHash",
+ ParDo.of(new KeyByDestinationHash(input.getCoder(), destinationCoder))
+ .withOutputTags(shardTag, TupleTagList.of(BAD_RECORD_TAG)));
+ addErrorCollection(shardedElements);
+
PCollection<KV<org.apache.beam.sdk.util.ShardedKey<Integer>, Iterable<UserT>>> shardedInput =
- input
- .apply(
- "KeyedByDestinationHash",
- ParDo.of(
- new DoFn<UserT, KV<Integer, UserT>>() {
- @ProcessElement
- public void processElement(@Element UserT element, ProcessContext context)
- throws Exception {
- getDynamicDestinations().setSideInputAccessorFromProcessContext(context);
- DestinationT destination =
- getDynamicDestinations().getDestination(context.element());
- context.output(
- KV.of(hashDestination(destination, destinationCoder), element));
- }
- }))
+ shardedElements
+ .get(shardTag)
.setCoder(KvCoder.of(VarIntCoder.of(), input.getCoder()))
.apply(
"ShardAndBatch",
@@ -791,8 +916,9 @@ public void processElement(@Element UserT element, ProcessContext context)
org.apache.beam.sdk.util.ShardedKey.Coder.of(VarIntCoder.of()),
IterableCoder.of(input.getCoder())));
+ TupleTag<FileResult<DestinationT>> writtenRecordsTag = new TupleTag<>("writtenRecords");
// Write grouped elements to temp files.
- PCollection<FileResult<DestinationT>> tempFiles =
+ PCollectionTuple writtenFiles =
shardedInput
.apply(
"AddDummyShard",
@@ -816,7 +942,15 @@ public KV, Iterable> apply(
ShardedKeyCoder.of(VarIntCoder.of()), IterableCoder.of(input.getCoder())))
.apply(
"WriteShardsIntoTempFiles",
- ParDo.of(new WriteShardsIntoTempFilesFn()).withSideInputs(getSideInputs()))
+ ParDo.of(new WriteShardsIntoTempFilesFn(input.getCoder()))
+ .withSideInputs(getSideInputs())
+ .withOutputTags(writtenRecordsTag, TupleTagList.of(BAD_RECORD_TAG)));
+
+ addErrorCollection(writtenFiles);
+
+ PCollection<FileResult<DestinationT>> tempFiles =
+ writtenFiles
+ .get(writtenRecordsTag)
.setCoder(fileResultCoder)
.apply(
"DropShardNum",
@@ -865,6 +999,32 @@ public void processElement(
}
}
+ private class KeyByDestinationHash extends DoFn<UserT, KV<Integer, UserT>> {
+
+ private final Coder<UserT> inputCoder;
+
+ private final Coder<DestinationT> destinationCoder;
+
+ public KeyByDestinationHash(Coder<UserT> inputCoder, Coder<DestinationT> destinationCoder) {
+ this.inputCoder = inputCoder;
+ this.destinationCoder = destinationCoder;
+ }
+
+ @ProcessElement
+ public void processElement(
+ @Element UserT element, ProcessContext context, MultiOutputReceiver outputReceiver)
+ throws Exception {
+ getDynamicDestinations().setSideInputAccessorFromProcessContext(context);
+ MaybeDestination<DestinationT> maybeDestination =
+ getDestinationWithErrorHandling(context.element(), outputReceiver, inputCoder);
+ if (!maybeDestination.isValid) {
+ return;
+ }
+ DestinationT destination = maybeDestination.destination;
+ context.output(KV.of(hashDestination(destination, destinationCoder), element));
+ }
+ }
+
private class RandomShardingFunction implements ShardingFunction<UserT, DestinationT> {
private final Coder<DestinationT> destinationCoder;
@@ -903,15 +1063,20 @@ private class ApplyShardingFunctionFn extends DoFn
private final ShardingFunction<UserT, DestinationT> shardingFn;
private final @Nullable PCollectionView<Integer> numShardsView;
+ private final Coder<UserT> inputCoder;
+
ApplyShardingFunctionFn(
ShardingFunction<UserT, DestinationT> shardingFn,
- @Nullable PCollectionView<Integer> numShardsView) {
+ @Nullable PCollectionView<Integer> numShardsView,
+ Coder<UserT> inputCoder) {
this.numShardsView = numShardsView;
this.shardingFn = shardingFn;
+ this.inputCoder = inputCoder;
}
@ProcessElement
- public void processElement(ProcessContext context) throws Exception {
+ public void processElement(ProcessContext context, MultiOutputReceiver outputReceiver)
+ throws Exception {
getDynamicDestinations().setSideInputAccessorFromProcessContext(context);
final int shardCount;
if (numShardsView != null) {
@@ -927,7 +1092,12 @@ public void processElement(ProcessContext context) throws Exception {
+ " Got %s",
shardCount);
- DestinationT destination = getDynamicDestinations().getDestination(context.element());
+ MaybeDestination<DestinationT> maybeDestination =
+ getDestinationWithErrorHandling(context.element(), outputReceiver, inputCoder);
+ if (!maybeDestination.isValid) {
+ return;
+ }
+ DestinationT destination = maybeDestination.destination;
ShardedKey<Integer> shardKey =
shardingFn.assignShardKey(destination, context.element(), shardCount);
context.output(KV.of(shardKey, context.element()));
@@ -936,6 +1106,13 @@ public void processElement(ProcessContext context) throws Exception {
private class WriteShardsIntoTempFilesFn
extends DoFn<KV<ShardedKey<Integer>, Iterable<UserT>>, FileResult<DestinationT>> {
+
+ private final Coder<UserT> inputCoder;
+
+ public WriteShardsIntoTempFilesFn(Coder<UserT> inputCoder) {
+ this.inputCoder = inputCoder;
+ }
+
private transient List> closeFutures = new ArrayList<>();
private transient List>> deferredOutput =
new ArrayList<>();
@@ -949,14 +1126,21 @@ private void readObject(java.io.ObjectInputStream in)
}
@ProcessElement
- public void processElement(ProcessContext c, BoundedWindow window) throws Exception {
+ public void processElement(
+ ProcessContext c, BoundedWindow window, MultiOutputReceiver outputReceiver)
+ throws Exception {
getDynamicDestinations().setSideInputAccessorFromProcessContext(c);
// Since we key by a 32-bit hash of the destination, there might be multiple destinations
// in this iterable. The number of destinations is generally very small (1000s or less), so
// there will rarely be hash collisions.
Map<DestinationT, Writer<DestinationT, OutputT>> writers = Maps.newHashMap();
for (UserT input : c.element().getValue()) {
- DestinationT destination = getDynamicDestinations().getDestination(input);
+ MaybeDestination<DestinationT> maybeDestination =
+ getDestinationWithErrorHandling(input, outputReceiver, inputCoder);
+ if (!maybeDestination.isValid) {
+ continue;
+ }
+ DestinationT destination = maybeDestination.destination;
Writer<DestinationT, OutputT> writer = writers.get(destination);
if (writer == null) {
String uuid = UUID.randomUUID().toString();
@@ -971,7 +1155,12 @@ public void processElement(ProcessContext c, BoundedWindow window) throws Except
writer.open(uuid);
writers.put(destination, writer);
}
- writeOrClose(writer, getDynamicDestinations().formatRecord(input));
+
+ OutputT formattedRecord = formatRecordWithErrorHandling(input, outputReceiver, inputCoder);
+ if (formattedRecord == null) {
+ continue;
+ }
+ writeOrClose(writer, formattedRecord);
}
// Ensure that we clean-up any prior writers that were being closed as part of this bundle
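
The pattern underlying these WriteFiles changes: per-record work runs inside try/catch, and the caught exception is handed to the BadRecordRouter together with the record and its coder. THROWING_ROUTER (the default) rethrows, preserving the old fail-the-bundle behavior, while RECORDING_ROUTER (installed by withBadRecordErrorHandler) encodes the record into a BadRecord and emits it on BAD_RECORD_TAG. A condensed, illustrative sketch of that pattern (the DoFn below is not from the patch; imports as in WriteFiles above):

```java
// Condensed illustration of the routing pattern used by WriteFiles above.
class FormatWithRoutingFn extends DoFn<String, String> {
  static final TupleTag<String> FORMATTED_TAG = new TupleTag<String>() {};

  private final BadRecordRouter router; // e.g. BadRecordRouter.RECORDING_ROUTER
  private final Coder<String> inputCoder;

  FormatWithRoutingFn(BadRecordRouter router, Coder<String> inputCoder) {
    this.router = router;
    this.inputCoder = inputCoder;
  }

  @ProcessElement
  public void processElement(@Element String element, MultiOutputReceiver receiver)
      throws Exception {
    try {
      // Any failure here is tied to the record's data...
      receiver.get(FORMATTED_TAG).output(Long.toString(Long.parseLong(element) * 2));
    } catch (Exception e) {
      // ...so route it with the encoded record instead of rethrowing directly.
      router.route(receiver, element, inputCoder, e, "Unable to format record");
    }
  }
}
```

Applied with `.withOutputTags(FORMATTED_TAG, TupleTagList.of(BAD_RECORD_TAG))`, the BAD_RECORD_TAG output is then handed to the handler, exactly as `addErrorCollection(...)` does above.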
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java
index e02965b72022..cf040470d608 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandler.java
@@ -119,6 +119,10 @@ private void readObject(ObjectInputStream aInputStream)
@Override
public void addErrorCollection(PCollection errorCollection) {
+ if (isClosed()) {
+ throw new IllegalStateException(
+ "Error collections cannot be added after Error Handler is closed");
+ }
errorCollections.add(errorCollection);
}
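
With the new guard, handler lifecycle mistakes fail loudly: once close() has sealed the handler, a late addErrorCollection call throws instead of silently dropping errors. A test-style sketch (assumes JUnit 4.13's assertThrows and the ErrorSinkTransform utility added below):

```java
// Sketch: a sealed handler now rejects late error collections.
BadRecordErrorHandler<PCollection<Long>> handler =
    pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform());
PCollection<BadRecord> lateErrors =
    pipeline.apply(Create.empty(BadRecord.getCoder(pipeline)));
handler.close(); // seals the handler; getOutput() is now available
assertThrows(IllegalStateException.class, () -> handler.addErrorCollection(lateErrors));
```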
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java
index 39cb612f2d89..2db20b92f27f 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteFilesTest.java
@@ -61,6 +61,7 @@
import org.apache.beam.sdk.options.PipelineOptionsFactoryTest.TestPipelineOptions;
import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.testing.UsesTestStream;
@@ -78,6 +79,8 @@
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -634,6 +637,134 @@ private void testDynamicDestinationsHelper(boolean bounded, boolean emptyShards)
}
}
+ // Test DynamicDestinations implementation. Expects user values to be string-encoded integers.
+ // Deterministically throws when formatting a record (value % 3 == 0) or when looking up a
+ // destination (value % 7 == 0).
+ static class FailingTestDestinations extends DynamicDestinations<String, Integer, String> {
+ private ResourceId baseOutputDirectory;
+
+ FailingTestDestinations(ResourceId baseOutputDirectory) {
+ this.baseOutputDirectory = baseOutputDirectory;
+ }
+
+ @Override
+ public String formatRecord(String record) {
+ int value = Integer.valueOf(record);
+ // deterministically fail to format 1/3rd of records
+ if (value % 3 == 0) {
+ throw new RuntimeException("Failed To Format Record");
+ }
+ return "record_" + record;
+ }
+
+ @Override
+ public Integer getDestination(String element) {
+ int value = Integer.valueOf(element);
+ // deterministically fail to find the destination for 1/7th of records
+ if (value % 7 == 0) {
+ throw new RuntimeException("Failed To Get Destination");
+ }
+ return value % 5;
+ }
+
+ @Override
+ public Integer getDefaultDestination() {
+ return 0;
+ }
+
+ @Override
+ public FilenamePolicy getFilenamePolicy(Integer destination) {
+ return new PerWindowFiles(
+ baseOutputDirectory.resolve("file_" + destination, StandardResolveOptions.RESOLVE_FILE),
+ "simple");
+ }
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testFailingDynamicDestinationsBounded() throws Exception {
+ testFailingDynamicDestinationsHelper(true, false);
+ }
+
+ @Test
+ @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+ public void testFailingDynamicDestinationsUnbounded() throws Exception {
+ testFailingDynamicDestinationsHelper(false, false);
+ }
+
+ @Test
+ @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+ public void testFailingDynamicDestinationsAutosharding() throws Exception {
+ testFailingDynamicDestinationsHelper(false, true);
+ }
+
+ private void testFailingDynamicDestinationsHelper(boolean bounded, boolean autosharding)
+ throws IOException {
+ FailingTestDestinations dynamicDestinations =
+ new FailingTestDestinations(getBaseOutputDirectory());
+ SimpleSink<Integer> sink =
+ new SimpleSink<>(getBaseOutputDirectory(), dynamicDestinations, Compression.UNCOMPRESSED);
+
+ // Flag to validate that the pipeline options are passed to the Sink.
+ WriteOptions options = TestPipeline.testingPipelineOptions().as(WriteOptions.class);
+ options.setTestFlag("test_value");
+ Pipeline p = TestPipeline.create(options);
+
+ final int numInputs = 100;
+ long expectedFailures = 0;
+ List<String> inputs = Lists.newArrayList();
+ for (int i = 0; i < numInputs; ++i) {
+ inputs.add(Integer.toString(i));
+ if (i % 7 == 0 || i % 3 == 0) {
+ expectedFailures++;
+ }
+ }
+ // Prepare timestamps for the elements.
+ List<Long> timestamps = new ArrayList<>();
+ for (long i = 0; i < inputs.size(); i++) {
+ timestamps.add(i + 1);
+ }
+
+ BadRecordErrorHandler<PCollection<Long>> errorHandler =
+ p.registerBadRecordErrorHandler(new ErrorSinkTransform());
+ int numShards = autosharding ? 0 : 2;
+ WriteFiles<String, Integer, String> writeFiles =
+ WriteFiles.to(sink).withNumShards(numShards).withBadRecordErrorHandler(errorHandler);
+
+ PCollection<String> input = p.apply(Create.timestamped(inputs, timestamps));
+ WriteFilesResult<Integer> res;
+ if (!bounded) {
+ input.setIsBoundedInternal(IsBounded.UNBOUNDED);
+ input = input.apply(Window.into(FixedWindows.of(Duration.standardDays(1))));
+ res = input.apply(writeFiles.withWindowedWrites());
+ } else {
+ res = input.apply(writeFiles);
+ }
+
+ errorHandler.close();
+
+ PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(expectedFailures);
+
+ res.getPerDestinationOutputFilenames().apply(new VerifyFilesExist<>());
+ p.run();
+
+ for (int i = 0; i < 5; ++i) {
+ ResourceId base =
+ getBaseOutputDirectory().resolve("file_" + i, StandardResolveOptions.RESOLVE_FILE);
+ List<String> expected = Lists.newArrayList();
+ for (int j = i; j < numInputs; j += 5) {
+ if (j % 3 != 0 && j % 7 != 0) {
+ expected.add("record_" + j);
+ }
+ }
+ checkFileContents(
+ base.toString(),
+ expected,
+ Optional.fromNullable(autosharding ? null : numShards),
+ bounded /* expectRemovedTempDirectory */);
+ }
+ }
+
@Test
public void testShardedDisplayData() {
DynamicDestinations dynamicDestinations =
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java
new file mode 100644
index 000000000000..41367765b920
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms.errorhandling;
+
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.checkerframework.checker.initialization.qual.Initialized;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
+import org.joda.time.Duration;
+
+public class ErrorHandlingTestUtils {
+ public static class ErrorSinkTransform
+ extends PTransform<PCollection<BadRecord>, PCollection<Long>> {
+
+ @Override
+ public @UnknownKeyFor @NonNull @Initialized PCollection<Long> expand(
+ PCollection<BadRecord> input) {
+ if (input.isBounded() == IsBounded.BOUNDED) {
+ return input.apply("Combine", Combine.globally(Count.combineFn()));
+ } else {
+ return input
+ .apply("Window", Window.into(FixedWindows.of(Duration.standardDays(1))))
+ .apply("Combine", Combine.globally(Count.combineFn()).withoutDefaults());
+ }
+ }
+ }
+}
diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java
index a65db5a90bad..2e4939560ad1 100644
--- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java
+++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java
@@ -69,6 +69,8 @@
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition;
import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
@@ -337,6 +339,10 @@
* events.apply("WriteAvros", AvroIO.writeCustomTypeToGenericRecords()
* .to(new UserDynamicAvroDestinations(userToSchemaMap)));
* }
+ *
+ * <p>Error handling for malformed records is supported via {@link
+ * TypedWrite#withBadRecordErrorHandler(ErrorHandler)}. See documentation in {@link FileIO} for
+ * details on usage.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -1427,6 +1433,8 @@ public abstract static class TypedWrite
abstract AvroSink.@Nullable DatumWriterFactory getDatumWriterFactory();
+ abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
+
/**
* The codec used to encode the blocks in the Avro file. String value drawn from those in
* https://avro.apache.org/docs/1.7.7/api/java/org/apache/avro/file/CodecFactory.html
@@ -1489,6 +1497,9 @@ abstract Builder setDynamicDestinations(
abstract Builder setDatumWriterFactory(
AvroSink.DatumWriterFactory datumWriterFactory);
+ abstract Builder<UserT, DestinationT, OutputT> setBadRecordErrorHandler(
+ @Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
+
abstract TypedWrite build();
}
@@ -1713,6 +1724,12 @@ public TypedWrite withMetadata(Map
return toBuilder().setMetadata(ImmutableMap.copyOf(metadata)).build();
}
+ /** See {@link FileIO.Write#withBadRecordErrorHandler(ErrorHandler)} for details on usage. */
+ public TypedWrite<UserT, DestinationT, OutputT> withBadRecordErrorHandler(
+ ErrorHandler<BadRecord, ?> errorHandler) {
+ return toBuilder().setBadRecordErrorHandler(errorHandler).build();
+ }
+
DynamicAvroDestinations resolveDynamicDestinations() {
DynamicAvroDestinations dynamicDestinations =
getDynamicDestinations();
@@ -1782,6 +1799,9 @@ public WriteFilesResult expand(PCollection input) {
if (getNoSpilling()) {
write = write.withNoSpilling();
}
+ if (getBadRecordErrorHandler() != null) {
+ write = write.withBadRecordErrorHandler(getBadRecordErrorHandler());
+ }
return input.apply("Write", write);
}
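
As with TextIO, the AvroIO pass-through matters mainly for custom-type writes, where the destination lookup or the record formatter can throw. A sketch reusing the UserDynamicAvroDestinations example from the Javadoc above (events, userToSchemaMap, and handler are assumed from that example and the earlier sketches):

```java
// Sketch: per-record failures in UserDynamicAvroDestinations (destination
// lookup or GenericRecord formatting) are routed to the handler's sink
// instead of failing the bundle.
events.apply(
    "WriteAvros",
    AvroIO.writeCustomTypeToGenericRecords()
        .to(new UserDynamicAvroDestinations(userToSchemaMap))
        .withBadRecordErrorHandler(handler));
```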
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
index 5b976687f2c1..ab6ac52e318d 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
@@ -44,7 +44,6 @@
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.common.IOITHelper;
import org.apache.beam.sdk.io.common.IOTestPipelineOptions;
-import org.apache.beam.sdk.io.kafka.KafkaIOTest.ErrorSinkTransform;
import org.apache.beam.sdk.io.kafka.KafkaIOTest.FailingLongSerializer;
import org.apache.beam.sdk.io.kafka.ReadFromKafkaDoFnTest.FailingDeserializer;
import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource;
@@ -77,6 +76,7 @@
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform;
import org.apache.beam.sdk.transforms.windowing.CalendarWindows;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index b0df82bcdc19..9b15b86051f5 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -88,7 +88,6 @@
import org.apache.beam.sdk.testing.ExpectedLogs;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Distinct;
@@ -97,15 +96,13 @@
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Min;
-import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.display.DisplayData;
-import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
+import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.CalendarWindows;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.CoderUtils;
@@ -145,10 +142,7 @@
import org.apache.kafka.common.serialization.LongSerializer;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.common.utils.Utils;
-import org.checkerframework.checker.initialization.qual.Initialized;
-import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
-import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.hamcrest.collection.IsIterableContainingInAnyOrder;
import org.hamcrest.collection.IsIterableWithSize;
import org.joda.time.Duration;
@@ -1472,18 +1466,6 @@ public void testSinkWithSerializationErrors() throws Exception {
}
}
- public static class ErrorSinkTransform
- extends PTransform<PCollection<BadRecord>, PCollection<Long>> {
-
- @Override
- public @UnknownKeyFor @NonNull @Initialized PCollection<Long> expand(
- PCollection<BadRecord> input) {
- return input
- .apply("Window", Window.into(CalendarWindows.years(1)))
- .apply("Combine", Combine.globally(Count.combineFn()).withoutDefaults());
- }
- }
-
@Test
public void testValuesSink() throws Exception {
// similar to testSink(), but use values()' interface.