diff --git a/CHANGES.md b/CHANGES.md index 40c061e07240..367d6f7ec98d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -98,6 +98,7 @@ * Added Cloud Bigtable Table with Read operation to Beam SQL ([BEAM-11173](https://issues.apache.org/jira/browse/BEAM-11173)) * Added option to disable unnecessary copying between operators in Flink Runner (Java) ([BEAM-11146](https://issues.apache.org/jira/browse/BEAM-11146)) * Added CombineFn.setup and CombineFn.teardown to Python SDK. These methods let you initialize the CombineFn's state before any of the other methods of the CombineFn is executed and clean that state up later on. If you are using Dataflow, you need to enable Dataflow Runner V2 by passing `--experiments=use_runner_v2` before using this feature. ([BEAM-3736](https://issues.apache.org/jira/browse/BEAM-3736)) +* Added support for NestedValueProvider for the Python SDK ([BEAM-10856](https://issues.apache.org/jira/browse/BEAM-10856)). * X feature added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)). ## Breaking Changes diff --git a/sdks/python/apache_beam/options/value_provider.py b/sdks/python/apache_beam/options/value_provider.py index fde61c85db77..0fa5f2b5f157 100644 --- a/sdks/python/apache_beam/options/value_provider.py +++ b/sdks/python/apache_beam/options/value_provider.py @@ -15,8 +15,11 @@ # limitations under the License. # -"""A ValueProvider class to implement templates with both statically -and dynamically provided values. +"""A ValueProvider abstracts the notion of fetching a value that may or +may not be currently available. + +This can be used to parameterize transforms that only read values in at +runtime, for example. """ # pytype: skip-file @@ -33,22 +36,38 @@ 'ValueProvider', 'StaticValueProvider', 'RuntimeValueProvider', + 'NestedValueProvider', 'check_accessible', ] class ValueProvider(object): + """Base class that all other ValueProviders must implement. + """ def is_accessible(self): + """Whether the contents of this ValueProvider is available to routines + that run at graph construction time. + """ raise NotImplementedError( 'ValueProvider.is_accessible implemented in derived classes') def get(self): + """Return the value wrapped by this ValueProvider. + """ raise NotImplementedError( 'ValueProvider.get implemented in derived classes') class StaticValueProvider(ValueProvider): + """StaticValueProvider is an implementation of ValueProvider that allows + for a static value to be provided. + """ def __init__(self, value_type, value): + """ + Args: + value_type: Type of the static value + value: Static value + """ self.value_type = value_type self.value = value_type(value) @@ -78,6 +97,10 @@ def __hash__(self): class RuntimeValueProvider(ValueProvider): + """RuntimeValueProvider is an implementation of ValueProvider that + allows for a value to be provided at execution time rather than + at graph construction time. + """ runtime_options = None experiments = set() # type: Set[str] @@ -122,8 +145,50 @@ def __str__(self): repr(self.default_value)) +class NestedValueProvider(ValueProvider): + """NestedValueProvider is an implementation of ValueProvider that allows + for wrapping another ValueProvider object. + """ + def __init__(self, value, translator): + """Creates a NestedValueProvider that wraps the provided ValueProvider. + + Args: + value: ValueProvider object to wrap + translator: function that is applied to the ValueProvider + Raises: + ``RuntimeValueProviderError``: if any of the provided objects are not + accessible. + """ + self.value = value + self.translator = translator + + def is_accessible(self): + return self.value.is_accessible() + + def get(self): + try: + return self.cached_value + except AttributeError: + self.cached_value = self.translator(self.value.get()) + return self.cached_value + + def __str__(self): + return "%s(value: %s, translator: %s)" % ( + self.__class__.__name__, + self.value, + self.translator.__name__, + ) + + def check_accessible(value_provider_list): - """Check accessibility of a list of ValueProvider objects.""" + """A decorator that checks accessibility of a list of ValueProvider objects. + + Args: + value_provider_list: list of ValueProvider objects + Raises: + ``RuntimeValueProviderError``: if any of the provided objects are not + accessible. + """ assert isinstance(value_provider_list, list) def _check_accessible(fnc): diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py index 52525607cb8e..189501bb9eed 100644 --- a/sdks/python/apache_beam/options/value_provider_test.py +++ b/sdks/python/apache_beam/options/value_provider_test.py @@ -24,8 +24,11 @@ import logging import unittest +from mock import Mock + from apache_beam.options.pipeline_options import DebugOptions from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.value_provider import NestedValueProvider from apache_beam.options.value_provider import RuntimeValueProvider from apache_beam.options.value_provider import StaticValueProvider @@ -218,6 +221,43 @@ def test_experiments_options_setup(self): self.assertIn('b,c', options.experiments) self.assertNotIn('c', options.experiments) + def test_nested_value_provider_wrap_static(self): + vp = NestedValueProvider(StaticValueProvider(int, 1), lambda x: x + 1) + + self.assertTrue(vp.is_accessible()) + self.assertEqual(vp.get(), 2) + + def test_nested_value_provider_caches_value(self): + mock_fn = Mock() + + def translator(x): + mock_fn() + return x + + vp = NestedValueProvider(StaticValueProvider(int, 1), translator) + + vp.get() + self.assertEqual(mock_fn.call_count, 1) + vp.get() + self.assertEqual(mock_fn.call_count, 1) + + def test_nested_value_provider_wrap_runtime(self): + class UserDefinedOptions(PipelineOptions): + @classmethod + def _add_argparse_args(cls, parser): + parser.add_value_provider_argument( + '--vpt_vp_arg15', + help='This keyword argument is a value provider') # set at runtime + + options = UserDefinedOptions([]) + vp = NestedValueProvider(options.vpt_vp_arg15, lambda x: x + x) + self.assertFalse(vp.is_accessible()) + + RuntimeValueProvider.set_runtime_options({'vpt_vp_arg15': 'abc'}) + + self.assertTrue(vp.is_accessible()) + self.assertEqual(vp.get(), 'abcabc') + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)