Simplify Managed API to avoid dealing with PCollectionRowTuple (apache#31470)

* Managed accepts PInput type

* add unit test

* spotless

* spotless

* rename to getSinglePCollection
ahmedabu98 authored Jun 4, 2024
1 parent 195dc3f commit 7ea8cd2
Showing 11 changed files with 104 additions and 40 deletions.
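In practice, the change means a Managed read can be applied straight to a Pipeline and its single output unwrapped with getSinglePCollection(), rather than going through PCollectionRowTuple. A minimal sketch of the new usage follows; the class name and the Iceberg configuration values are illustrative only, not a working catalog setup.

import java.util.HashMap;
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

public class ManagedReadSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Illustrative configuration only; a real Iceberg read needs a valid
    // table name and catalog properties.
    Map<String, Object> config = new HashMap<>();
    config.put("table", "db.my_table");

    // Previously, callers had to wrap the pipeline in a PCollectionRowTuple and
    // fetch the output by tag:
    //   PCollectionRowTuple.empty(pipeline)
    //       .apply(Managed.read(Managed.ICEBERG).withConfig(config))
    //       .get("output");
    // With this change, the transform applies directly to the pipeline and the
    // single output is unwrapped with getSinglePCollection():
    PCollection<Row> rows =
        pipeline.apply(Managed.read(Managed.ICEBERG).withConfig(config)).getSinglePCollection();
  }
}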
@@ -23,6 +23,7 @@
import java.util.Objects;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.checkerframework.checker.nullness.qual.Nullable;

@@ -180,6 +181,22 @@ public PCollection<Row> get(String tag) {
return pcollection;
}

/**
* Like {@link #get(String)}, but is a convenience method to get a single PCollection without
* providing a tag for that output. Use only when there is a single collection in this tuple.
*
* <p>Throws {@link IllegalStateException} if more than one output exists in the {@link
* PCollectionRowTuple}.
*/
public PCollection<Row> getSinglePCollection() {
Preconditions.checkState(
pcollectionMap.size() == 1,
"Expected exactly one output PCollection<Row>, but found %s. "
+ "Please try retrieving a specified output using get(<tag>) instead.",
pcollectionMap.size());
return get(pcollectionMap.entrySet().iterator().next().getKey());
}

/**
* Returns an immutable Map from tag to corresponding {@link PCollection}, for all the members of
* this {@link PCollectionRowTuple}.
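The new getSinglePCollection() helper above only succeeds when the tuple has exactly one member; otherwise its checkState() throws. A small hedged sketch of both cases (the class name, schema, and tags below are made up for illustration):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;

public class GetSinglePCollectionSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();
    Schema schema = Schema.builder().addStringField("name").build();
    PCollection<Row> rows =
        p.apply(Create.of(Row.withSchema(schema).addValue("a").build())).setRowSchema(schema);

    // Exactly one member: the new helper returns it without needing its tag.
    PCollection<Row> single = PCollectionRowTuple.of("only", rows).getSinglePCollection();

    // Two members: the checkState() in getSinglePCollection() would throw
    // IllegalStateException, so get(<tag>) must be used instead.
    // PCollectionRowTuple.of("a", rows).and("b", rows).getSinglePCollection();
  }
}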
@@ -38,7 +38,6 @@
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.hadoop.conf.Configuration;
@@ -216,11 +215,10 @@ public void testRead() throws Exception {
.build())
.build();

PCollectionRowTuple output =
PCollectionRowTuple.empty(readPipeline)
.apply(Managed.read(Managed.ICEBERG).withConfig(config));
PCollection<Row> rows =
readPipeline.apply(Managed.read(Managed.ICEBERG).withConfig(config)).getSinglePCollection();

PAssert.that(output.get("output")).containsInAnyOrder(expectedRows);
PAssert.that(rows).containsInAnyOrder(expectedRows);
readPipeline.run().waitUntilFinish();
}

@@ -258,7 +256,7 @@ public void testWrite() {
.build();

PCollection<Row> input = writePipeline.apply(Create.of(inputRows)).setRowSchema(BEAM_SCHEMA);
PCollectionRowTuple.of("input", input).apply(Managed.write(Managed.ICEBERG).withConfig(config));
input.apply(Managed.write(Managed.ICEBERG).withConfig(config));

writePipeline.run().waitUntilFinish();

@@ -166,9 +166,9 @@ public void testReadUsingManagedTransform() throws Exception {
Map<String, Object> configMap = new Yaml().load(yamlConfig);

PCollection<Row> output =
PCollectionRowTuple.empty(testPipeline)
testPipeline
.apply(Managed.read(Managed.ICEBERG).withConfig(configMap))
.get(OUTPUT_TAG);
.getSinglePCollection();

PAssert.that(output)
.satisfies(
@@ -134,16 +134,12 @@ public void testWriteUsingManagedTransform() {
identifier, CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP, warehouse.location);
Map<String, Object> configMap = new Yaml().load(yamlConfig);

PCollectionRowTuple input =
PCollectionRowTuple.of(
INPUT_TAG,
testPipeline
.apply(
"Records To Add", Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1)))
.setRowSchema(
SchemaAndRowConversions.icebergSchemaToBeamSchema(TestFixtures.SCHEMA)));
PCollection<Row> inputRows =
testPipeline
.apply("Records To Add", Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1)))
.setRowSchema(SchemaAndRowConversions.icebergSchemaToBeamSchema(TestFixtures.SCHEMA));
PCollection<Row> result =
input.apply(Managed.write(Managed.ICEBERG).withConfig(configMap)).get(OUTPUT_TAG);
inputRows.apply(Managed.write(Managed.ICEBERG).withConfig(configMap)).get(OUTPUT_TAG);

PAssert.that(result).satisfies(new VerifyOutputs(identifier, "append"));

@@ -2663,7 +2663,7 @@ abstract static class Builder<K, V> {
abstract Builder<K, V> setProducerConfig(Map<String, Object> producerConfig);

abstract Builder<K, V> setProducerFactoryFn(
SerializableFunction<Map<String, Object>, Producer<K, V>> fn);
@Nullable SerializableFunction<Map<String, Object>, Producer<K, V>> fn);

abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);

@@ -36,7 +36,7 @@
import org.apache.beam.sdk.managed.ManagedTransformConstants;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -319,7 +319,7 @@ public void testBuildTransformWithManaged() {
// Kafka Read SchemaTransform gets built in ManagedSchemaTransformProvider's expand
Managed.read(Managed.KAFKA)
.withConfig(YamlUtils.yamlStringToMap(config))
.expand(PCollectionRowTuple.empty(Pipeline.create()));
.expand(PBegin.in(Pipeline.create()));
}
}

@@ -43,7 +43,6 @@
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
@@ -225,10 +224,8 @@ public void testBuildTransformWithManaged() {
Managed.write(Managed.KAFKA)
.withConfig(YamlUtils.yamlStringToMap(config))
.expand(
PCollectionRowTuple.of(
"input",
Pipeline.create()
.apply(Create.empty(Schema.builder().addByteArrayField("bytes").build()))));
Pipeline.create()
.apply(Create.empty(Schema.builder().addByteArrayField("bytes").build())));
}
}

@@ -22,11 +22,16 @@
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
@@ -47,12 +52,13 @@
* specifies arguments like so:
*
* <pre>{@code
* PCollectionRowTuple output = PCollectionRowTuple.empty(pipeline).apply(
* PCollection<Row> rows = pipeline.apply(
* Managed.read(ICEBERG)
* .withConfig(ImmutableMap.<String, Object>.builder()
* .put("foo", "abc")
* .put("bar", 123)
* .build()));
* .build()))
* .getSinglePCollection();
* }</pre>
*
* <p>Instead of specifying configuration arguments directly in the code, one can provide the
@@ -66,11 +72,9 @@
* <p>The file's path can be passed in to the Managed API like so:
*
* <pre>{@code
* PCollectionRowTuple input = PCollectionRowTuple.of("input", pipeline.apply(Create.of(...)))
* PCollection<Row> inputRows = pipeline.apply(Create.of(...));
*
* PCollectionRowTuple output = input.apply(
* Managed.write(ICEBERG)
* .withConfigUrl(<config path>));
* inputRows.apply(Managed.write(ICEBERG).withConfigUrl(<config path>));
* }</pre>
*/
public class Managed {
@@ -132,8 +136,7 @@ public static ManagedTransform write(String sink) {
}

@AutoValue
public abstract static class ManagedTransform
extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {
public abstract static class ManagedTransform extends PTransform<PInput, PCollectionRowTuple> {
abstract String getIdentifier();

abstract @Nullable Map<String, Object> getConfig();
@@ -183,7 +186,9 @@ ManagedTransform withSupportedIdentifiers(List<String> supportedIdentifiers) {
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
public PCollectionRowTuple expand(PInput input) {
PCollectionRowTuple inputTuple = resolveInput(input);

ManagedSchemaTransformProvider.ManagedConfig managedConfig =
ManagedSchemaTransformProvider.ManagedConfig.builder()
.setTransformIdentifier(getIdentifier())
@@ -194,7 +199,28 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
SchemaTransform underlyingTransform =
new ManagedSchemaTransformProvider(getSupportedIdentifiers()).from(managedConfig);

return input.apply(underlyingTransform);
return inputTuple.apply(underlyingTransform);
}

@VisibleForTesting
static PCollectionRowTuple resolveInput(PInput input) {
if (input instanceof PBegin) {
return PCollectionRowTuple.empty(input.getPipeline());
} else if (input instanceof PCollection) {
PCollection<?> inputCollection = (PCollection<?>) input;
Preconditions.checkArgument(
inputCollection.getCoder() instanceof RowCoder,
"Input PCollection must contain Row elements with a set Schema "
+ "(using .setRowSchema()). Instead, found collection %s with coder: %s.",
inputCollection.getName(),
inputCollection.getCoder());
return PCollectionRowTuple.of(
ManagedTransformConstants.INPUT, (PCollection<Row>) inputCollection);
} else if (input instanceof PCollectionRowTuple) {
return (PCollectionRowTuple) input;
}

throw new IllegalArgumentException("Unsupported input type: " + input.getClass());
}
}
}
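The resolveInput() method above is what lets ManagedTransform accept a PBegin, a schema'd PCollection<Row>, or an existing PCollectionRowTuple. A hedged sketch of the most common case, a write applied directly to a PCollection<Row>; the class name and Iceberg configuration values are illustrative, not a working catalog setup.

import java.util.HashMap;
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

public class ManagedWriteSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();
    Schema schema = Schema.builder().addStringField("name").build();

    // setRowSchema() gives the collection a RowCoder, which resolveInput()
    // requires before wrapping it under the standard "input" tag.
    PCollection<Row> rows =
        pipeline
            .apply(Create.of(Row.withSchema(schema).addValue("a").build()))
            .setRowSchema(schema);

    // Illustrative configuration only; a real Iceberg write needs a valid
    // table name and catalog properties.
    Map<String, Object> config = new HashMap<>();
    config.put("table", "db.my_table");

    // Apply the Managed write directly to the PCollection<Row>; no
    // PCollectionRowTuple wrapping is needed anymore.
    rows.apply(Managed.write(Managed.ICEBERG).withConfig(config));
  }
}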
@@ -38,6 +38,9 @@
* every single parameter through the Managed interface.
*/
public class ManagedTransformConstants {
// Standard input PCollection tag
public static final String INPUT = "input";

public static final String ICEBERG_READ = "beam:schematransform:org.apache.beam:iceberg_read:v1";
public static final String ICEBERG_WRITE =
"beam:schematransform:org.apache.beam:iceberg_write:v1";
@@ -50,7 +50,6 @@
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
@@ -141,7 +140,7 @@ public void testProtoTranslation() throws Exception {
.setIdentifier(TestSchemaTransformProvider.IDENTIFIER)
.build()
.withConfig(underlyingConfig);
PCollectionRowTuple.of("input", input).apply(transform).get("output");
input.apply(transform);

// Then translate the pipeline to a proto and extract the ManagedSchemaTransform's proto
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
@@ -17,17 +17,23 @@
*/
package org.apache.beam.sdk.managed;

import static org.junit.Assert.assertThrows;

import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.testing.TestSchemaTransformProvider;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.junit.Rule;
@@ -61,11 +67,33 @@ public void testInvalidTransform() {
Row.withSchema(SCHEMA).withFieldValue("str", "b").withFieldValue("int", 2).build(),
Row.withSchema(SCHEMA).withFieldValue("str", "c").withFieldValue("int", 3).build());

@Test
public void testResolveInputToPCollectionRowTuple() {
Pipeline p = Pipeline.create();
List<PInput> inputTypes =
Arrays.asList(
PBegin.in(p),
p.apply(Create.of(ROWS).withRowSchema(SCHEMA)),
PCollectionRowTuple.of("pcoll", p.apply(Create.of(ROWS).withRowSchema(SCHEMA))));

List<PInput> badInputTypes =
Arrays.asList(
p.apply(Create.of(1, 2, 3)),
p.apply(Create.of(ROWS)),
PCollectionTuple.of("pcoll", p.apply(Create.of(ROWS))));

for (PInput input : inputTypes) {
Managed.ManagedTransform.resolveInput(input);
}
for (PInput badInput : badInputTypes) {
assertThrows(
IllegalArgumentException.class, () -> Managed.ManagedTransform.resolveInput(badInput));
}
}

public void runTestProviderTest(Managed.ManagedTransform writeOp) {
PCollection<Row> rows =
PCollectionRowTuple.of("input", pipeline.apply(Create.of(ROWS)).setRowSchema(SCHEMA))
.apply(writeOp)
.get("output");
pipeline.apply(Create.of(ROWS)).setRowSchema(SCHEMA).apply(writeOp).getSinglePCollection();

Schema outputSchema = rows.getSchema();
PAssert.that(rows)