Skip to content

Commit

Permalink
Add an annotation to expose transforms to yaml. (#28208)
Browse files Browse the repository at this point in the history
We should add this to all transforms that are simply parameterized.
  • Loading branch information
robertwb authored Sep 13, 2023
1 parent 141e3e6 commit cf0cf3b
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
52 changes: 52 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ class and wrapper class that allows lambda functions to be used as

import copy
import itertools
import json
import logging
import operator
import os
import sys
import threading
import warnings
from functools import reduce
from functools import wraps
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -83,6 +85,7 @@ class and wrapper class that allows lambda functions to be used as
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.typehints import validate_composite_type_param
from apache_beam.utils import proto_utils
from apache_beam.utils import python_callable

if TYPE_CHECKING:
from apache_beam import coders
Expand All @@ -95,6 +98,7 @@ class and wrapper class that allows lambda functions to be used as
'PTransform',
'ptransform_fn',
'label_from_callable',
'annotate_yaml',
]

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -1096,3 +1100,51 @@ def __ror__(self, pvalueish, _unused=None):

def expand(self, pvalue):
  """Always fails; this transform must not be expanded directly.

  Args:
    pvalue: the input that would be expanded; it is never inspected.

  Raises:
    RuntimeError: unconditionally.
  """
  message = "Should never be expanded directly."
  raise RuntimeError(message)


# Defined here to avoid circular import issues for Beam library transforms.
def annotate_yaml(constructor):
  """Causes instances of this transform to be annotated with their yaml syntax.

  Should only be used for transforms that are fully defined by their constructor
  arguments.

  The returned wrapper calls ``constructor`` and then tries to attach three
  extra annotations to the resulting transform: ``yaml_provider``,
  ``yaml_type`` and ``yaml_args``. ``yaml_args`` is a json string holding the
  fully qualified name of ``constructor`` together with the positional and
  keyword arguments of the call.

  Annotating is best-effort. The transform is returned unmodified, and a
  warning is issued, when either

    * the fully qualified name does not resolve back to the wrapper itself
      (so the name could not be used to re-create the transform), or
    * the call arguments cannot be serialized as json.

  Args:
    constructor: the callable (e.g. a PTransform class) to wrap.

  Returns:
    A callable wrapping ``constructor`` (with its metadata preserved via
    ``functools.wraps``) that returns the possibly annotated transform.
  """
  @wraps(constructor)
  def wrapper(*args, **kwargs):
    transform = constructor(*args, **kwargs)

    fully_qualified_name = (
        f'{constructor.__module__}.{constructor.__qualname__}')
    try:
      imported_constructor = (
          python_callable.PythonCallableWithSource.
          load_from_fully_qualified_name(fully_qualified_name))
      # Identity, not equality: the name must resolve to this very wrapper,
      # and `!=` could dispatch to an arbitrary __ne__ on the imported object.
      if imported_constructor is not wrapper:
        raise ImportError('Different object.')
    except ImportError:
      warnings.warn(f'Cannot import {constructor} as {fully_qualified_name}.')
      return transform

    try:
      config = json.dumps({
          'constructor': fully_qualified_name,
          'args': args,
          'kwargs': kwargs,
      })
    except (TypeError, ValueError) as exn:
      # TypeError: an argument is not json-serializable.
      # ValueError: json.dumps detected a circular reference.
      warnings.warn(
          f'Cannot serialize arguments for {constructor} as json: {exn}')
      return transform

    original_annotations = transform.annotations
    transform.annotations = lambda: {
        **original_annotations(),
        # These override whatever may have been provided earlier.
        # The outermost call is expected to be the most specific.
        'yaml_provider': 'python',
        'yaml_type': 'PyTransform',
        'yaml_args': config,
    }
    return transform

  return wrapper
30 changes: 30 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,23 @@ def test_name_is_ambiguous(self):
output: AnotherFilter
''')

# Checks that the annotations reported by a LinearTransform instance can be
# fed back into a YAML pipeline definition: the 'yaml_type' and 'yaml_args'
# annotation values are substituted into a chain that follows a Create step.
def test_annotations(self):
t = LinearTransform(5, b=100)
# Captured once up front; only the 'yaml_type' and 'yaml_args' entries are
# used below.
annotations = t.annotations()
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
result = p | YamlTransform(
'''
type: chain
transforms:
- type: Create
config:
elements: [0, 1, 2, 3]
- type: %r
config: %s
''' % (annotations['yaml_type'], annotations['yaml_args']))
# Expected values are 5 * x + 100 for each created element x in 0..3,
# matching the LinearTransform(5, b=100) constructed above.
assert_that(result, equal_to([100, 105, 110, 115]))


class CreateTimestamped(beam.PTransform):
def __init__(self, elements):
Expand Down Expand Up @@ -631,6 +648,19 @@ def test_prefers_same_provider_class(self):
label='StartWith3')


@beam.transforms.ptransform.annotate_yaml
class LinearTransform(beam.PTransform):
  """Test-only transform mapping each element x to a * x + b.

  Used to exercise annotate_yaml.
  """
  def __init__(self, a, b):
    self._a, self._b = a, b

  def expand(self, pcoll):
    # Copy the coefficients into locals so the lambda closes over plain
    # values rather than over self.
    slope, offset = self._a, self._b
    return pcoll | beam.Map(lambda value: slope * value + offset)


if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to unittest.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()

0 comments on commit cf0cf3b

Please sign in to comment.