Skip to content

Commit

Permalink
Merge pull request #12779 from [BEAM-10856] Support for NestedValuePr…
Browse files Browse the repository at this point in the history
…ovider for Python SDK

* Support for NestedValueProvider for Python SDK

* Fix typo

* Update CHANGES.md

* Update value_provider_test.py

* Fix NestedValueProvider docstrings. (#1)

* Fix isort and doc errors. (#2)

* Update CHANGES.md

Co-authored-by: Eugene Nikolaiev <eugene.nikolayev@gmail.com>
  • Loading branch information
epicfaace and nikie authored Dec 1, 2020
1 parent 02da3ae commit 87f3138
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
* Added Cloud Bigtable Table with Read operation to Beam SQL ([BEAM-11173](https://issues.apache.org/jira/browse/BEAM-11173))
* Added option to disable unnecessary copying between operators in Flink Runner (Java) ([BEAM-11146](https://issues.apache.org/jira/browse/BEAM-11146))
* Added CombineFn.setup and CombineFn.teardown to Python SDK. These methods let you initialize the CombineFn's state before any of the other methods of the CombineFn is executed and clean that state up later on. If you are using Dataflow, you need to enable Dataflow Runner V2 by passing `--experiments=use_runner_v2` before using this feature. ([BEAM-3736](https://issues.apache.org/jira/browse/BEAM-3736))
* Added support for NestedValueProvider for the Python SDK ([BEAM-10856](https://issues.apache.org/jira/browse/BEAM-10856)).
* X feature added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).

## Breaking Changes
Expand Down
71 changes: 68 additions & 3 deletions sdks/python/apache_beam/options/value_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
# limitations under the License.
#

"""A ValueProvider class to implement templates with both statically
and dynamically provided values.
"""A ValueProvider abstracts the notion of fetching a value that may or
may not be currently available.
This can be used to parameterize transforms that only read values in at
runtime, for example.
"""

# pytype: skip-file
Expand All @@ -33,22 +36,38 @@
'ValueProvider',
'StaticValueProvider',
'RuntimeValueProvider',
'NestedValueProvider',
'check_accessible',
]


class ValueProvider(object):
"""Base class that all other ValueProviders must implement.
"""
def is_accessible(self):
"""Whether the contents of this ValueProvider is available to routines
that run at graph construction time.
"""
raise NotImplementedError(
'ValueProvider.is_accessible implemented in derived classes')

def get(self):
"""Return the value wrapped by this ValueProvider.
"""
raise NotImplementedError(
'ValueProvider.get implemented in derived classes')


class StaticValueProvider(ValueProvider):
"""StaticValueProvider is an implementation of ValueProvider that allows
for a static value to be provided.
"""
def __init__(self, value_type, value):
"""
Args:
value_type: Type of the static value
value: Static value
"""
self.value_type = value_type
self.value = value_type(value)

Expand Down Expand Up @@ -78,6 +97,10 @@ def __hash__(self):


class RuntimeValueProvider(ValueProvider):
"""RuntimeValueProvider is an implementation of ValueProvider that
allows for a value to be provided at execution time rather than
at graph construction time.
"""
runtime_options = None
experiments = set() # type: Set[str]

Expand Down Expand Up @@ -122,8 +145,50 @@ def __str__(self):
repr(self.default_value))


class NestedValueProvider(ValueProvider):
"""NestedValueProvider is an implementation of ValueProvider that allows
for wrapping another ValueProvider object.
"""
def __init__(self, value, translator):
"""Creates a NestedValueProvider that wraps the provided ValueProvider.
Args:
value: ValueProvider object to wrap
translator: function that is applied to the ValueProvider
Raises:
``RuntimeValueProviderError``: if any of the provided objects are not
accessible.
"""
self.value = value
self.translator = translator

def is_accessible(self):
return self.value.is_accessible()

def get(self):
try:
return self.cached_value
except AttributeError:
self.cached_value = self.translator(self.value.get())
return self.cached_value

def __str__(self):
return "%s(value: %s, translator: %s)" % (
self.__class__.__name__,
self.value,
self.translator.__name__,
)


def check_accessible(value_provider_list):
"""Check accessibility of a list of ValueProvider objects."""
"""A decorator that checks accessibility of a list of ValueProvider objects.
Args:
value_provider_list: list of ValueProvider objects
Raises:
``RuntimeValueProviderError``: if any of the provided objects are not
accessible.
"""
assert isinstance(value_provider_list, list)

def _check_accessible(fnc):
Expand Down
40 changes: 40 additions & 0 deletions sdks/python/apache_beam/options/value_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import logging
import unittest

from mock import Mock

from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import NestedValueProvider
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider

Expand Down Expand Up @@ -218,6 +221,43 @@ def test_experiments_options_setup(self):
self.assertIn('b,c', options.experiments)
self.assertNotIn('c', options.experiments)

def test_nested_value_provider_wrap_static(self):
vp = NestedValueProvider(StaticValueProvider(int, 1), lambda x: x + 1)

self.assertTrue(vp.is_accessible())
self.assertEqual(vp.get(), 2)

def test_nested_value_provider_caches_value(self):
mock_fn = Mock()

def translator(x):
mock_fn()
return x

vp = NestedValueProvider(StaticValueProvider(int, 1), translator)

vp.get()
self.assertEqual(mock_fn.call_count, 1)
vp.get()
self.assertEqual(mock_fn.call_count, 1)

def test_nested_value_provider_wrap_runtime(self):
class UserDefinedOptions(PipelineOptions):
@classmethod
def _add_argparse_args(cls, parser):
parser.add_value_provider_argument(
'--vpt_vp_arg15',
help='This keyword argument is a value provider') # set at runtime

options = UserDefinedOptions([])
vp = NestedValueProvider(options.vpt_vp_arg15, lambda x: x + x)
self.assertFalse(vp.is_accessible())

RuntimeValueProvider.set_runtime_options({'vpt_vp_arg15': 'abc'})

self.assertTrue(vp.is_accessible())
self.assertEqual(vp.get(), 'abcabc')


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down

0 comments on commit 87f3138

Please sign in to comment.