Skip to content

Commit

Permalink
[ManagedIO] pass underlying transform URN as an annotation (#31398)
Browse files Browse the repository at this point in the history
* pass underlying transform URN to annotation

* move annotation keys to proto

* address comments: add descriptions for annotation enums; fail when missing transform_identifier; add unit tests for annotations
  • Loading branch information
ahmedabu98 authored May 30, 2024
1 parent 5454489 commit b50ad0f
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ message BuilderMethod {
bytes payload = 3;
}

message Annotations {
enum Enum {
// The annotation key for the encoded configuration Row used to build a transform
CONFIG_ROW_KEY = 0 [(org.apache.beam.model.pipeline.v1.beam_constant) = "config_row"];
// The annotation key for the configuration Schema used to decode the configuration Row
CONFIG_ROW_SCHEMA_KEY = 1 [(org.apache.beam.model.pipeline.v1.beam_constant) = "config_row_schema"];
// If ths transform is a SchemaTransform, this is the annotation key for the SchemaTransform's URN
SCHEMATRANSFORM_URN_KEY = 2 [(org.apache.beam.model.pipeline.v1.beam_constant) = "schematransform_urn"];
// If the transform is a ManagedSchemaTransform, this is the annotation key for the underlying SchemaTransform's URN
MANAGED_UNDERLYING_TRANSFORM_URN_KEY = 3 [(org.apache.beam.model.pipeline.v1.beam_constant) = "managed_underlying_transform_urn"];
}
}

// Payload for a Schema-aware PTransform.
// This is a transform that is aware of its input and output PCollection schemas
// and is configured using Beam Schema-compatible parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@ public class BeamUrns {
public static String getUrn(ProtocolMessageEnum value) {
return value.getValueDescriptor().getOptions().getExtension(RunnerApi.beamUrn);
}

/** Returns the constant value of a given enum annotated with [(beam_constant)]. */
public static String getConstant(ProtocolMessageEnum value) {
return value.getValueDescriptor().getOptions().getExtension(RunnerApi.beamConstant);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.util.construction;

import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations;
import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
Expand All @@ -43,6 +44,7 @@
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.display.DisplayData;
Expand Down Expand Up @@ -94,16 +96,12 @@ public class PTransformTranslation {
public static final String MAP_WINDOWS_TRANSFORM_URN = "beam:transform:map_windows:v1";
public static final String MERGE_WINDOWS_TRANSFORM_URN = "beam:transform:merge_windows:v1";
public static final String TO_STRING_TRANSFORM_URN = "beam:transform:to_string:v1";
public static final String MANAGED_TRANSFORM_URN = "beam:transform:managed:v1";

// Required runner implemented transforms. These transforms should never specify an environment.
public static final ImmutableSet<String> RUNNER_IMPLEMENTED_TRANSFORMS =
ImmutableSet.of(GROUP_BY_KEY_TRANSFORM_URN, IMPULSE_TRANSFORM_URN);

public static final String CONFIG_ROW_KEY = "config_row";

public static final String CONFIG_ROW_SCHEMA_KEY = "config_row_schema";
public static final String SCHEMATRANSFORM_URN_KEY = "schematransform_urn";

// DeprecatedPrimitives
/**
* @deprecated SDKs should move away from creating `Read` transforms and migrate to using Impulse
Expand Down Expand Up @@ -522,11 +520,28 @@ public RunnerApi.PTransform translate(
}

if (spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))) {
ExternalTransforms.SchemaTransformPayload payload =
ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload());
String identifier = payload.getIdentifier();
transformBuilder.putAnnotations(
SCHEMATRANSFORM_URN_KEY,
ByteString.copyFromUtf8(
ExternalTransforms.SchemaTransformPayload.parseFrom(spec.getPayload())
.getIdentifier()));
BeamUrns.getConstant(Annotations.Enum.SCHEMATRANSFORM_URN_KEY),
ByteString.copyFromUtf8(identifier));
if (identifier.equals(MANAGED_TRANSFORM_URN)) {
Schema configSchema =
SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
Row configRow =
RowCoder.of(configSchema).decode(payload.getConfigurationRow().newInput());
String underlyingIdentifier = configRow.getString("transform_identifier");
if (underlyingIdentifier == null) {
throw new IllegalStateException(
String.format(
"Encountered a Managed Transform that has an empty \"transform_identifier\": \n%s",
configRow));
}
transformBuilder.putAnnotations(
BeamUrns.getConstant(Annotations.Enum.MANAGED_UNDERLYING_TRANSFORM_URN_KEY),
ByteString.copyFromUtf8(underlyingIdentifier));
}
}
}

Expand All @@ -546,12 +561,12 @@ public RunnerApi.PTransform translate(
}
if (configRow != null) {
transformBuilder.putAnnotations(
CONFIG_ROW_KEY,
BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_KEY),
ByteString.copyFrom(
CoderUtils.encodeToByteArray(RowCoder.of(configRow.getSchema()), configRow)));

transformBuilder.putAnnotations(
CONFIG_ROW_SCHEMA_KEY,
BeamUrns.getConstant(Annotations.Enum.CONFIG_ROW_SCHEMA_KEY),
ByteString.copyFrom(
SchemaTranslation.schemaToProto(configRow.getSchema(), true).toByteArray()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,11 @@ RunnerApi.Pipeline updateTransformViaTransformService(
throw new IllegalArgumentException("Could not find a transform with the ID " + transformId);
}
ByteString configRowBytes =
transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_KEY);
transformToUpgrade.getAnnotationsOrThrow(
BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_KEY));
ByteString configRowSchemaBytes =
transformToUpgrade.getAnnotationsOrThrow(PTransformTranslation.CONFIG_ROW_SCHEMA_KEY);
transformToUpgrade.getAnnotationsOrThrow(
BeamUrns.getConstant(ExternalTransforms.Annotations.Enum.CONFIG_ROW_SCHEMA_KEY));
SchemaApi.Schema configRowSchemaProto =
SchemaApi.Schema.parseFrom(configRowSchemaBytes.toByteArray());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
*/
package org.apache.beam.sdk.managed;

import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.CONFIG_ROW_KEY;
import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.CONFIG_ROW_SCHEMA_KEY;
import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.MANAGED_UNDERLYING_TRANSFORM_URN_KEY;
import static org.apache.beam.model.pipeline.v1.ExternalTransforms.Annotations.Enum.SCHEMATRANSFORM_URN_KEY;
import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
import static org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
import static org.apache.beam.sdk.managed.ManagedSchemaTransformProvider.ManagedConfig;
import static org.apache.beam.sdk.managed.ManagedSchemaTransformProvider.ManagedSchemaTransform;
import static org.apache.beam.sdk.managed.ManagedSchemaTransformTranslation.ManagedSchemaTransformTranslator;
import static org.apache.beam.sdk.util.construction.PTransformTranslation.MANAGED_TRANSFORM_URN;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
Expand All @@ -41,11 +46,13 @@
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.junit.Test;
Expand Down Expand Up @@ -154,9 +161,38 @@ public void testProtoTranslation() throws Exception {
})
.collect(Collectors.toList());
assertEquals(1, managedTransformProto.size());
RunnerApi.FunctionSpec spec = managedTransformProto.get(0).getSpec();
RunnerApi.PTransform convertedTransform = managedTransformProto.get(0);

// Check that the proto contains correct values
// Check the transform proto contains the correct annotations.
// These annotations can be accessed and used by the runner to make decisions
Row managedConfigRow =
Row.withSchema(PROVIDER.configurationSchema())
.withFieldValue("transform_identifier", TestSchemaTransformProvider.IDENTIFIER)
.withFieldValue("config", yamlStringConfig)
.build();
Map<String, ByteString> expectedAnnotations =
ImmutableMap.<String, ByteString>builder()
.put(
BeamUrns.getConstant(SCHEMATRANSFORM_URN_KEY),
ByteString.copyFromUtf8(MANAGED_TRANSFORM_URN))
.put(
BeamUrns.getConstant(MANAGED_UNDERLYING_TRANSFORM_URN_KEY),
ByteString.copyFromUtf8(TestSchemaTransformProvider.IDENTIFIER))
.put(
BeamUrns.getConstant(CONFIG_ROW_KEY),
ByteString.copyFrom(
CoderUtils.encodeToByteArray(
RowCoder.of(PROVIDER.configurationSchema()), managedConfigRow)))
.put(
BeamUrns.getConstant(CONFIG_ROW_SCHEMA_KEY),
ByteString.copyFrom(
SchemaTranslation.schemaToProto(PROVIDER.configurationSchema(), true)
.toByteArray()))
.build();
assertEquals(expectedAnnotations, convertedTransform.getAnnotationsMap());

// Check that the spec proto contains correct values
RunnerApi.FunctionSpec spec = convertedTransform.getSpec();
SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
assertEquals(PROVIDER.identifier(), payload.getIdentifier());
Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
Expand Down

0 comments on commit b50ad0f

Please sign in to comment.