Skip to content

Commit

Permalink
Add some simple annotations to Python transforms. (#28191)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Aug 31, 2023
1 parent 51f2542 commit f91bb68
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,15 @@ def __init__(
# allow us to effectively disambiguate in multi-model settings.
self._model_tag = uuid.uuid4().hex

def annotations(self):
return {
'model_handler': str(self._model_handler),
'model_handler_type': (
f'{self._model_handler.__class__.__module__}'
f'.{self._model_handler.__class__.__qualname__}'),
**super().annotations()
}

def _get_model_metadata_pcoll(self, pipeline):
# avoid circular imports.
# pylint: disable=wrong-import-position
Expand Down
5 changes: 4 additions & 1 deletion sdks/python/apache_beam/transforms/ptransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ def default_label(self):
return self.__class__.__name__

def annotations(self) -> Dict[str, Union[bytes, str, message.Message]]:
return {}
return {
'python_type': #
f'{self.__class__.__module__}.{self.__class__.__qualname__}'
}

def default_type_hints(self):
fn_type_hints = IOTypeHints.from_callable(self.expand)
Expand Down
10 changes: 7 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_get_pcollection_output(self):
- type: PyMap
name: Square
input: Create
config:
config:
fn: "lambda x: x*x"
'''

Expand Down Expand Up @@ -123,7 +123,11 @@ def test_create_ptransform_with_inputs(self):
self.assertIsInstance(result, beam.transforms.ParDo)
self.assertEqual(result.label, 'Map(lambda x: x*x)')

result_annotations = {**result.annotations()}
result_annotations = {
key: value
for (key, value) in result.annotations().items()
if key.startswith('yaml')
}
target_annotations = {
'yaml_type': 'PyMap',
'yaml_args': '{"fn": "lambda x: x*x"}',
Expand All @@ -146,7 +150,7 @@ def get_spec():
fn: "lambda x: x * x * x"
- type: Filter
name: FilterOutBigNumbers
input: PyMap
input: PyMap
keep: "lambda x: x<100"
'''
return yaml.load(pipeline_yaml, Loader=SafeLineLoader)
Expand Down

0 comments on commit f91bb68

Please sign in to comment.