diff --git a/CHANGES.md b/CHANGES.md index b791dd806d4a1..ed40ffcb04af3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -66,7 +66,7 @@ ## New Features / Improvements -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* [Enrichment Transform](https://s.apache.org/enrichment-transform) along with GCP BigTable handler added to Python SDK ([#30001](https://github.com/apache/beam/pull/30001)). ## Breaking Changes diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py new file mode 100644 index 0000000000000..63ec7061d3e5a --- /dev/null +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -0,0 +1,413 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""``PTransform`` for reading from and writing to Web APIs.""" +import abc +import concurrent.futures +import contextlib +import logging +import sys +import time +from typing import Generic +from typing import Optional +from typing import TypeVar + +from google.api_core.exceptions import TooManyRequests + +import apache_beam as beam +from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler +from apache_beam.metrics import Metrics +from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC +from apache_beam.utils import retry + +RequestT = TypeVar('RequestT') +ResponseT = TypeVar('ResponseT') + +DEFAULT_TIMEOUT_SECS = 30 # seconds + +_LOGGER = logging.getLogger(__name__) + + +class UserCodeExecutionException(Exception): + """Base class for errors related to calling Web APIs.""" + + +class UserCodeQuotaException(UserCodeExecutionException): + """Extends ``UserCodeExecutionException`` to signal specifically that + the Web API client encountered a Quota or API overuse related error. + """ + + +class UserCodeTimeoutException(UserCodeExecutionException): + """Extends ``UserCodeExecutionException`` to signal a user code timeout.""" + + +def retry_on_exception(exception: Exception): + """retry on exceptions caused by unavailability of the remote server.""" + return isinstance( + exception, + (TooManyRequests, UserCodeTimeoutException, UserCodeQuotaException)) + + +class _MetricsCollector: + """A metrics collector that tracks RequestResponseIO related usage.""" + def __init__(self, namespace: str): + """ + Args: + namespace: Namespace for the metrics. + """ + self.requests = Metrics.counter(namespace, 'requests') + self.responses = Metrics.counter(namespace, 'responses') + self.failures = Metrics.counter(namespace, 'failures') + self.throttled_requests = Metrics.counter(namespace, 'throttled_requests') + self.throttled_secs = Metrics.counter( + namespace, 'cumulativeThrottlingSeconds') + self.timeout_requests = Metrics.counter(namespace, 'requests_timed_out') + self.call_counter = Metrics.counter(namespace, 'call_invocations') + self.setup_counter = Metrics.counter(namespace, 'setup_counter') + self.teardown_counter = Metrics.counter(namespace, 'teardown_counter') + self.backoff_counter = Metrics.counter(namespace, 'backoff_counter') + self.sleeper_counter = Metrics.counter(namespace, 'sleeper_counter') + self.should_backoff_counter = Metrics.counter( + namespace, 'should_backoff_counter') + + +class Caller(contextlib.AbstractContextManager, + abc.ABC, + Generic[RequestT, ResponseT]): + """Interface for user custom code intended for API calls. + For setup and teardown of clients when applicable, implement the + ``__enter__`` and ``__exit__`` methods respectively.""" + @abc.abstractmethod + def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT: + """Calls a Web API with the ``RequestT`` and returns a + ``ResponseT``. ``RequestResponseIO`` expects implementations of the + ``__call__`` method to throw either a ``UserCodeExecutionException``, + ``UserCodeQuotaException``, or ``UserCodeTimeoutException``. + """ + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return None + + +class ShouldBackOff(abc.ABC): + """ + ShouldBackOff provides mechanism to apply adaptive throttling. + """ + pass + + +class Repeater(abc.ABC): + """Repeater provides mechanism to repeat requests for a + configurable condition.""" + @abc.abstractmethod + def repeat( + self, + caller: Caller[RequestT, ResponseT], + request: RequestT, + timeout: float, + metrics_collector: Optional[_MetricsCollector]) -> ResponseT: + """repeat method is called from the RequestResponseIO when + a repeater is enabled. + + Args: + caller: :class:`apache_beam.io.requestresponse.Caller` object that calls + the API. + request: input request to repeat. + timeout: time to wait for the request to complete. + metrics_collector: (Optional) a + ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to + collect the metrics for RequestResponseIO. + """ + pass + + +def _execute_request( + caller: Caller[RequestT, ResponseT], + request: RequestT, + timeout: float, + metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT: + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(caller, request) + try: + return future.result(timeout=timeout) + except TooManyRequests as e: + _LOGGER.info( + 'request could not be completed. got code %i from the service.', + e.code) + raise e + except concurrent.futures.TimeoutError: + if metrics_collector: + metrics_collector.timeout_requests.inc(1) + raise UserCodeTimeoutException( + f'Timeout {timeout} exceeded ' + f'while completing request: {request}') + except RuntimeError: + if metrics_collector: + metrics_collector.failures.inc(1) + raise UserCodeExecutionException('could not complete request') + + +class ExponentialBackOffRepeater(Repeater): + """Exponential BackOff Repeater uses exponential backoff retry strategy for + exceptions due to the remote service such as TooManyRequests (HTTP 429), + UserCodeTimeoutException, UserCodeQuotaException. + + It utilizes the decorator + :func:`apache_beam.utils.retry.with_exponential_backoff`. + """ + def __init__(self): + pass + + @retry.with_exponential_backoff( + num_retries=2, retry_filter=retry_on_exception) + def repeat( + self, + caller: Caller[RequestT, ResponseT], + request: RequestT, + timeout: float, + metrics_collector: Optional[_MetricsCollector] = None) -> ResponseT: + """repeat method is called from the RequestResponseIO when + a repeater is enabled. + + Args: + caller: :class:`apache_beam.io.requestresponse.Caller` object that + calls the API. + request: input request to repeat. + timeout: time to wait for the request to complete. + metrics_collector: (Optional) a + ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to + collect the metrics for RequestResponseIO. + """ + return _execute_request(caller, request, timeout, metrics_collector) + + +class NoOpsRepeater(Repeater): + """ + NoOpsRepeater executes a request just once irrespective of any exception. + """ + def repeat( + self, + caller: Caller[RequestT, ResponseT], + request: RequestT, + timeout: float, + metrics_collector: Optional[_MetricsCollector]) -> ResponseT: + return _execute_request(caller, request, timeout, metrics_collector) + + +class CacheReader(abc.ABC): + """CacheReader provides mechanism to read from the cache.""" + pass + + +class CacheWriter(abc.ABC): + """CacheWriter provides mechanism to write to the cache.""" + pass + + +class PreCallThrottler(abc.ABC): + """PreCallThrottler provides a throttle mechanism before sending request.""" + pass + + +class DefaultThrottler(PreCallThrottler): + """Default throttler that uses + :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler` + + Args: + window_ms (int): length of history to consider, in ms, to set throttling. + bucket_ms (int): granularity of time buckets that we store data in, in ms. + overload_ratio (float): the target ratio between requests sent and + successful requests. This is "K" in the formula in + https://landing.google.com/sre/book/chapters/handling-overload.html. + delay_secs (int): minimum number of seconds to throttle a request. + """ + def __init__( + self, + window_ms: int = 1, + bucket_ms: int = 1, + overload_ratio: float = 2, + delay_secs: int = 5): + self.throttler = AdaptiveThrottler( + window_ms=window_ms, bucket_ms=bucket_ms, overload_ratio=overload_ratio) + self.delay_secs = delay_secs + + +class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """A :class:`RequestResponseIO` transform to read and write to APIs. + + Processes an input :class:`~apache_beam.pvalue.PCollection` of requests + by making a call to the API as defined in :class:`Caller`'s `__call__` + and returns a :class:`~apache_beam.pvalue.PCollection` of responses. + """ + def __init__( + self, + caller: Caller[RequestT, ResponseT], + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + should_backoff: Optional[ShouldBackOff] = None, + repeater: Repeater = ExponentialBackOffRepeater(), + cache_reader: Optional[CacheReader] = None, + cache_writer: Optional[CacheWriter] = None, + throttler: PreCallThrottler = DefaultThrottler(), + ): + """ + Instantiates a RequestResponseIO transform. + + Args: + caller (~apache_beam.io.requestresponse.Caller): an implementation of + `Caller` object that makes call to the API. + timeout (float): timeout value in seconds to wait for response from API. + should_backoff (~apache_beam.io.requestresponse.ShouldBackOff): + (Optional) provides methods for backoff. + repeater (~apache_beam.io.requestresponse.Repeater): provides method to + repeat failed requests to API due to service errors. Defaults to + :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to + repeat requests with exponential backoff. + cache_reader (~apache_beam.io.requestresponse.CacheReader): (Optional) + provides methods to read external cache. + cache_writer (~apache_beam.io.requestresponse.CacheWriter): (Optional) + provides methods to write to external cache. + throttler (~apache_beam.io.requestresponse.PreCallThrottler): + provides methods to pre-throttle a request. Defaults to + :class:`apache_beam.io.requestresponse.DefaultThrottler` for + client-side adaptive throttling using + :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler` + """ + self._caller = caller + self._timeout = timeout + self._should_backoff = should_backoff + if repeater: + self._repeater = repeater + else: + self._repeater = NoOpsRepeater() + self._cache_reader = cache_reader + self._cache_writer = cache_writer + self._throttler = throttler + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + # TODO(riteshghorse): handle Cache and Throttle PTransforms when available. + if isinstance(self._throttler, DefaultThrottler): + return requests | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater, + throttler=self._throttler) + else: + return requests | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater) + + +class _Call(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """(Internal-only) PTransform that invokes a remote function on each element + of the input PCollection. + + This PTransform uses a `Caller` object to invoke the actual API calls, + and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of + clients when applicable. Additionally, a timeout value is specified to + regulate the duration of each call, defaults to 30 seconds. + + Args: + caller (:class:`apache_beam.io.requestresponse.Caller`): a callable + object that invokes API call. + timeout (float): timeout value in seconds to wait for response from API. + should_backoff (~apache_beam.io.requestresponse.ShouldBackOff): + (Optional) provides methods for backoff. + repeater (~apache_beam.io.requestresponse.Repeater): (Optional) provides + methods to repeat requests to API. + throttler (~apache_beam.io.requestresponse.PreCallThrottler): + (Optional) provides methods to pre-throttle a request. + """ + def __init__( + self, + caller: Caller[RequestT, ResponseT], + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + should_backoff: Optional[ShouldBackOff] = None, + repeater: Repeater = None, + throttler: PreCallThrottler = None, + ): + self._caller = caller + self._timeout = timeout + self._should_backoff = should_backoff + self._repeater = repeater + self._throttler = throttler + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + return requests | beam.ParDo( + _CallDoFn(self._caller, self._timeout, self._repeater, self._throttler)) + + +class _CallDoFn(beam.DoFn): + def setup(self): + self._caller.__enter__() + self._metrics_collector = _MetricsCollector(self._caller.__str__()) + self._metrics_collector.setup_counter.inc(1) + + def __init__( + self, + caller: Caller[RequestT, ResponseT], + timeout: float, + repeater: Repeater, + throttler: PreCallThrottler): + self._metrics_collector = None + self._caller = caller + self._timeout = timeout + self._repeater = repeater + self._throttler = throttler + + def process(self, request: RequestT, *args, **kwargs): + self._metrics_collector.requests.inc(1) + + is_throttled_request = False + if self._throttler: + while self._throttler.throttler.throttle_request(time.time() * + MSEC_TO_SEC): + _LOGGER.info( + "Delaying request for %d seconds" % self._throttler.delay_secs) + time.sleep(self._throttler.delay_secs) + self._metrics_collector.throttled_secs.inc(self._throttler.delay_secs) + is_throttled_request = True + + if is_throttled_request: + self._metrics_collector.throttled_requests.inc(1) + + try: + req_time = time.time() + response = self._repeater.repeat( + self._caller, request, self._timeout, self._metrics_collector) + self._metrics_collector.responses.inc(1) + self._throttler.throttler.successful_request(req_time * MSEC_TO_SEC) + yield response + except Exception as e: + raise e + + def teardown(self): + self._metrics_collector.teardown_counter.inc(1) + self._caller.__exit__(*sys.exc_info()) diff --git a/sdks/python/apache_beam/io/requestresponseio_it_test.py b/sdks/python/apache_beam/io/requestresponse_it_test.py similarity index 86% rename from sdks/python/apache_beam/io/requestresponseio_it_test.py rename to sdks/python/apache_beam/io/requestresponse_it_test.py index aae6b4e6ef2c7..396347c58d163 100644 --- a/sdks/python/apache_beam/io/requestresponseio_it_test.py +++ b/sdks/python/apache_beam/io/requestresponse_it_test.py @@ -16,6 +16,7 @@ # import base64 import sys +import typing import unittest from dataclasses import dataclass from typing import Tuple @@ -24,13 +25,18 @@ import urllib3 import apache_beam as beam -from apache_beam.io.requestresponseio import Caller -from apache_beam.io.requestresponseio import RequestResponseIO -from apache_beam.io.requestresponseio import UserCodeExecutionException -from apache_beam.io.requestresponseio import UserCodeQuotaException from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline +# pylint: disable=ungrouped-imports +try: + from apache_beam.io.requestresponse import Caller + from apache_beam.io.requestresponse import RequestResponseIO + from apache_beam.io.requestresponse import UserCodeExecutionException + from apache_beam.io.requestresponse import UserCodeQuotaException +except ImportError: + raise unittest.SkipTest('RequestResponseIO dependencies are not installed.') + _HTTP_PATH = '/v1/echo' _PAYLOAD = base64.b64encode(bytes('payload', 'utf-8')) _HTTP_ENDPOINT_ADDRESS_FLAG = '--httpEndpointAddress' @@ -61,28 +67,27 @@ def _add_argparse_args(cls, parser) -> None: help='The ID for an allocated quota that should exceed.') -# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto -# generated classes from .test-infra/mock-apis: @dataclass -class EchoRequest: +class EchoResponse: id: str payload: bytes -@dataclass -class EchoResponse: +# TODO(riteshghorse,damondouglas) replace Echo(Request|Response) with proto +# generated classes from .test-infra/mock-apis: +class Request(typing.NamedTuple): id: str payload: bytes -class EchoHTTPCaller(Caller): +class EchoHTTPCaller(Caller[Request, EchoResponse]): """Implements ``Caller`` to call the ``EchoServiceGrpc``'s HTTP handler. The purpose of ``EchoHTTPCaller`` is to support integration tests. """ def __init__(self, url: str): self.url = url + _HTTP_PATH - def __call__(self, request: EchoRequest, *args, **kwargs) -> EchoResponse: + def __call__(self, request: Request, *args, **kwargs) -> EchoResponse: """Overrides ``Caller``'s call method invoking the ``EchoServiceGrpc``'s HTTP handler with an ``EchoRequest``, returning either a successful ``EchoResponse`` or throwing either a @@ -129,7 +134,7 @@ def setUpClass(cls) -> None: def setUp(self) -> None: client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD) + req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD) try: # The following is needed to exceed the API client(req) @@ -148,7 +153,7 @@ def _get_client_and_options(cls) -> Tuple[EchoHTTPCaller, EchoITOptions]: def test_given_valid_request_receives_response(self): client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD) + req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) response: EchoResponse = client(req) @@ -158,20 +163,20 @@ def test_given_valid_request_receives_response(self): def test_given_exceeded_quota_should_raise(self): client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = EchoRequest(id=options.should_exceed_quota_id, payload=_PAYLOAD) + req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD) self.assertRaises(UserCodeQuotaException, lambda: client(req)) def test_not_found_should_raise(self): client, _ = EchoHTTPCallerTestIT._get_client_and_options() - req = EchoRequest(id='i-dont-exist-quota-id', payload=_PAYLOAD) + req = Request(id='i-dont-exist-quota-id', payload=_PAYLOAD) self.assertRaisesRegex( UserCodeExecutionException, "Not Found", lambda: client(req)) def test_request_response_io(self): client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = EchoRequest(id=options.never_exceed_quota_id, payload=_PAYLOAD) + req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) with TestPipeline(is_integration_test=True) as test_pipeline: output = ( test_pipeline diff --git a/sdks/python/apache_beam/io/requestresponse_test.py b/sdks/python/apache_beam/io/requestresponse_test.py new file mode 100644 index 0000000000000..6d807c2a8eb83 --- /dev/null +++ b/sdks/python/apache_beam/io/requestresponse_test.py @@ -0,0 +1,156 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import time +import unittest + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline + +# pylint: disable=ungrouped-imports +try: + from google.api_core.exceptions import TooManyRequests + from apache_beam.io.requestresponse import Caller, DefaultThrottler + from apache_beam.io.requestresponse import RequestResponseIO + from apache_beam.io.requestresponse import UserCodeExecutionException + from apache_beam.io.requestresponse import UserCodeTimeoutException + from apache_beam.io.requestresponse import retry_on_exception +except ImportError: + raise unittest.SkipTest('RequestResponseIO dependencies are not installed.') + + +class AckCaller(Caller[str, str]): + """AckCaller acknowledges the incoming request by returning a + request with ACK.""" + def __enter__(self): + pass + + def __call__(self, request: str): + return f"ACK: {request}" + + def __exit__(self, exc_type, exc_val, exc_tb): + return None + + +class CallerWithTimeout(AckCaller): + """CallerWithTimeout sleeps for 2 seconds before responding. + Used to test timeout in RequestResponseIO.""" + def __call__(self, request: str, *args, **kwargs): + time.sleep(2) + return f"ACK: {request}" + + +class CallerWithRuntimeError(AckCaller): + """CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO + to raise a UserCodeExecutionException.""" + def __call__(self, request: str, *args, **kwargs): + if not request: + raise RuntimeError("Exception expected, not an error.") + + +class CallerThatRetries(AckCaller): + def __init__(self): + self.count = -1 + + def __call__(self, request: str, *args, **kwargs): + try: + pass + except Exception as e: + raise e + finally: + self.count += 1 + raise TooManyRequests('retries = %d' % self.count) + + +class TestCaller(unittest.TestCase): + def test_valid_call(self): + caller = AckCaller() + with TestPipeline() as test_pipeline: + output = ( + test_pipeline + | beam.Create(["sample_request"]) + | RequestResponseIO(caller=caller)) + + self.assertIsNotNone(output) + + def test_call_timeout(self): + caller = CallerWithTimeout() + with self.assertRaises(UserCodeTimeoutException): + with TestPipeline() as test_pipeline: + _ = ( + test_pipeline + | beam.Create(["timeout_request"]) + | RequestResponseIO(caller=caller, timeout=1)) + + def test_call_runtime_error(self): + caller = CallerWithRuntimeError() + with self.assertRaises(UserCodeExecutionException): + with TestPipeline() as test_pipeline: + _ = ( + test_pipeline + | beam.Create([""]) + | RequestResponseIO(caller=caller)) + + def test_retry_on_exception(self): + self.assertFalse(retry_on_exception(RuntimeError())) + self.assertTrue(retry_on_exception(TooManyRequests("HTTP 429"))) + + def test_caller_backoff_retry_strategy(self): + caller = CallerThatRetries() + with self.assertRaises(TooManyRequests) as cm: + with TestPipeline() as test_pipeline: + _ = ( + test_pipeline + | beam.Create(["sample_request"]) + | RequestResponseIO(caller=caller)) + self.assertRegex(cm.exception.message, 'retries = 2') + + def test_caller_no_retry_strategy(self): + caller = CallerThatRetries() + with self.assertRaises(TooManyRequests) as cm: + with TestPipeline() as test_pipeline: + _ = ( + test_pipeline + | beam.Create(["sample_request"]) + | RequestResponseIO(caller=caller, repeater=None)) + self.assertRegex(cm.exception.message, 'retries = 0') + + def test_default_throttler(self): + caller = CallerWithTimeout() + throttler = DefaultThrottler( + window_ms=10000, bucket_ms=5000, overload_ratio=1) + # manually override the number of received requests for testing. + throttler.throttler._all_requests.add(time.time() * 1000, 100) + test_pipeline = TestPipeline() + _ = ( + test_pipeline + | beam.Create(['sample_request']) + | RequestResponseIO(caller=caller, throttler=throttler)) + result = test_pipeline.run() + result.wait_until_finish() + metrics = result.metrics().query( + beam.metrics.MetricsFilter().with_name('throttled_requests')) + self.assertEqual(metrics['counters'][0].committed, 1) + metrics = result.metrics().query( + beam.metrics.MetricsFilter().with_name('cumulativeThrottlingSeconds')) + self.assertGreater(metrics['counters'][0].committed, 0) + metrics = result.metrics().query( + beam.metrics.MetricsFilter().with_name('responses')) + self.assertEqual(metrics['counters'][0].committed, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/io/requestresponseio.py b/sdks/python/apache_beam/io/requestresponseio.py deleted file mode 100644 index 0ec586e640184..0000000000000 --- a/sdks/python/apache_beam/io/requestresponseio.py +++ /dev/null @@ -1,218 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -"""``PTransform`` for reading from and writing to Web APIs.""" -import abc -import concurrent.futures -import contextlib -import logging -import sys -from typing import Generic -from typing import Optional -from typing import TypeVar - -import apache_beam as beam -from apache_beam.pvalue import PCollection - -RequestT = TypeVar('RequestT') -ResponseT = TypeVar('ResponseT') - -DEFAULT_TIMEOUT_SECS = 30 # seconds - -_LOGGER = logging.getLogger(__name__) - - -class UserCodeExecutionException(Exception): - """Base class for errors related to calling Web APIs.""" - - -class UserCodeQuotaException(UserCodeExecutionException): - """Extends ``UserCodeExecutionException`` to signal specifically that - the Web API client encountered a Quota or API overuse related error. - """ - - -class UserCodeTimeoutException(UserCodeExecutionException): - """Extends ``UserCodeExecutionException`` to signal a user code timeout.""" - - -class Caller(contextlib.AbstractContextManager, abc.ABC): - """Interface for user custom code intended for API calls. - For setup and teardown of clients when applicable, implement the - ``__enter__`` and ``__exit__`` methods respectively.""" - @abc.abstractmethod - def __call__(self, request: RequestT, *args, **kwargs) -> ResponseT: - """Calls a Web API with the ``RequestT`` and returns a - ``ResponseT``. ``RequestResponseIO`` expects implementations of the - ``__call__`` method to throw either a ``UserCodeExecutionException``, - ``UserCodeQuotaException``, or ``UserCodeTimeoutException``. - """ - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return None - - -class ShouldBackOff(abc.ABC): - """ - ShouldBackOff provides mechanism to apply adaptive throttling. - """ - pass - - -class Repeater(abc.ABC): - """Repeater provides mechanism to repeat requests for a - configurable condition.""" - pass - - -class CacheReader(abc.ABC): - """CacheReader provides mechanism to read from the cache.""" - pass - - -class CacheWriter(abc.ABC): - """CacheWriter provides mechanism to write to the cache.""" - pass - - -class PreCallThrottler(abc.ABC): - """PreCallThrottler provides a throttle mechanism before sending request.""" - pass - - -class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], - beam.PCollection[ResponseT]]): - """A :class:`RequestResponseIO` transform to read and write to APIs. - - Processes an input :class:`~apache_beam.pvalue.PCollection` of requests - by making a call to the API as defined in :class:`Caller`'s `__call__` - and returns a :class:`~apache_beam.pvalue.PCollection` of responses. - """ - def __init__( - self, - caller: [Caller], - timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, - should_backoff: Optional[ShouldBackOff] = None, - repeater: Optional[Repeater] = None, - cache_reader: Optional[CacheReader] = None, - cache_writer: Optional[CacheWriter] = None, - throttler: Optional[PreCallThrottler] = None, - ): - """ - Instantiates a RequestResponseIO transform. - - Args: - caller (~apache_beam.io.requestresponseio.Caller): an implementation of - `Caller` object that makes call to the API. - timeout (float): timeout value in seconds to wait for response from API. - should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff): - (Optional) provides methods for backoff. - repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) - provides methods to repeat requests to API. - cache_reader (~apache_beam.io.requestresponseio.CacheReader): (Optional) - provides methods to read external cache. - cache_writer (~apache_beam.io.requestresponseio.CacheWriter): (Optional) - provides methods to write to external cache. - throttler (~apache_beam.io.requestresponseio.PreCallThrottler): - (Optional) provides methods to pre-throttle a request. - """ - self._caller = caller - self._timeout = timeout - self._should_backoff = should_backoff - self._repeater = repeater - self._cache_reader = cache_reader - self._cache_writer = cache_writer - self._throttler = throttler - - def expand(self, requests: PCollection[RequestT]) -> PCollection[ResponseT]: - # TODO(riteshghorse): add Cache and Throttle PTransforms. - return requests | _Call( - caller=self._caller, - timeout=self._timeout, - should_backoff=self._should_backoff, - repeater=self._repeater) - - -class _Call(beam.PTransform[beam.PCollection[RequestT], - beam.PCollection[ResponseT]]): - """(Internal-only) PTransform that invokes a remote function on each element - of the input PCollection. - - This PTransform uses a `Caller` object to invoke the actual API calls, - and uses ``__enter__`` and ``__exit__`` to manage setup and teardown of - clients when applicable. Additionally, a timeout value is specified to - regulate the duration of each call, defaults to 30 seconds. - - Args: - caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable - object that invokes API call. - timeout (float): timeout value in seconds to wait for response from API. - """ - def __init__( - self, - caller: Caller, - timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, - should_backoff: Optional[ShouldBackOff] = None, - repeater: Optional[Repeater] = None, - ): - """Initialize the _Call transform. - Args: - caller (:class:`apache_beam.io.requestresponseio.Caller`): a callable - object that invokes API call. - timeout (float): timeout value in seconds to wait for response from API. - should_backoff (~apache_beam.io.requestresponseio.ShouldBackOff): - (Optional) provides methods for backoff. - repeater (~apache_beam.io.requestresponseio.Repeater): (Optional) provides - methods to repeat requests to API. - """ - self._caller = caller - self._timeout = timeout - self._should_backoff = should_backoff - self._repeater = repeater - - def expand( - self, - requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: - return requests | beam.ParDo(_CallDoFn(self._caller, self._timeout)) - - -class _CallDoFn(beam.DoFn, Generic[RequestT, ResponseT]): - def setup(self): - self._caller.__enter__() - - def __init__(self, caller: Caller, timeout: float): - self._caller = caller - self._timeout = timeout - - def process(self, request, *args, **kwargs): - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(self._caller, request) - try: - yield future.result(timeout=self._timeout) - except concurrent.futures.TimeoutError: - raise UserCodeTimeoutException( - f'Timeout {self._timeout} exceeded ' - f'while completing request: {request}') - except RuntimeError: - raise UserCodeExecutionException('could not complete request') - - def teardown(self): - self._caller.__exit__(*sys.exc_info()) diff --git a/sdks/python/apache_beam/io/requestresponseio_test.py b/sdks/python/apache_beam/io/requestresponseio_test.py deleted file mode 100644 index 2828a3578871d..0000000000000 --- a/sdks/python/apache_beam/io/requestresponseio_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import time -import unittest - -import apache_beam as beam -from apache_beam.io.requestresponseio import Caller -from apache_beam.io.requestresponseio import RequestResponseIO -from apache_beam.io.requestresponseio import UserCodeExecutionException -from apache_beam.io.requestresponseio import UserCodeTimeoutException -from apache_beam.testing.test_pipeline import TestPipeline - - -class AckCaller(Caller): - """AckCaller acknowledges the incoming request by returning a - request with ACK.""" - def __enter__(self): - pass - - def __call__(self, request: str): - return f"ACK: {request}" - - def __exit__(self, exc_type, exc_val, exc_tb): - return None - - -class CallerWithTimeout(AckCaller): - """CallerWithTimeout sleeps for 2 seconds before responding. - Used to test timeout in RequestResponseIO.""" - def __call__(self, request: str, *args, **kwargs): - time.sleep(2) - return f"ACK: {request}" - - -class CallerWithRuntimeError(AckCaller): - """CallerWithRuntimeError raises a `RuntimeError` for RequestResponseIO - to raise a UserCodeExecutionException.""" - def __call__(self, request: str, *args, **kwargs): - if not request: - raise RuntimeError("Exception expected, not an error.") - - -class TestCaller(unittest.TestCase): - def test_valid_call(self): - caller = AckCaller() - with TestPipeline() as test_pipeline: - output = ( - test_pipeline - | beam.Create(["sample_request"]) - | RequestResponseIO(caller=caller)) - - self.assertIsNotNone(output) - - def test_call_timeout(self): - caller = CallerWithTimeout() - with self.assertRaises(UserCodeTimeoutException): - with TestPipeline() as test_pipeline: - _ = ( - test_pipeline - | beam.Create(["timeout_request"]) - | RequestResponseIO(caller=caller, timeout=1)) - - def test_call_runtime_error(self): - caller = CallerWithRuntimeError() - with self.assertRaises(UserCodeExecutionException): - with TestPipeline() as test_pipeline: - _ = ( - test_pipeline - | beam.Create([""]) - | RequestResponseIO(caller=caller)) - - -if __name__ == '__main__': - unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py new file mode 100644 index 0000000000000..a2f961be64373 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import TypeVar + +import apache_beam as beam +from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS +from apache_beam.io.requestresponse import Caller +from apache_beam.io.requestresponse import DefaultThrottler +from apache_beam.io.requestresponse import ExponentialBackOffRepeater +from apache_beam.io.requestresponse import PreCallThrottler +from apache_beam.io.requestresponse import Repeater +from apache_beam.io.requestresponse import RequestResponseIO + +__all__ = [ + "EnrichmentSourceHandler", + "Enrichment", + "cross_join", +] + +InputT = TypeVar('InputT') +OutputT = TypeVar('OutputT') + +JoinFn = Callable[[Dict[str, Any], Dict[str, Any]], beam.Row] + +_LOGGER = logging.getLogger(__name__) + + +def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: + """cross_join performs a cross join on two `dict` objects. + + Joins the columns of the right row onto the left row. + + Args: + left (Dict[str, Any]): input request dictionary. + right (Dict[str, Any]): response dictionary from the API. + + Returns: + `beam.Row` containing the merged columns. + """ + for k, v in right.items(): + if k not in left: + # Don't override the values in left. + left[k] = v + elif left[k] != v: + _LOGGER.warning( + '%s exists in the input row as well the row fetched ' + 'from API but have different values - %s and %s. Using the input ' + 'value (%s) for the enriched row. You can override this behavior by ' + 'passing a custom `join_fn` to Enrichment transform.' % + (k, left[k], v, left[k])) + return beam.Row(**left) + + +class EnrichmentSourceHandler(Caller[InputT, OutputT]): + """Wrapper class for :class:`apache_beam.io.requestresponse.Caller`. + + Ensure that the implementation of ``__call__`` method returns a tuple + of `beam.Row` objects. + """ + pass + + +class Enrichment(beam.PTransform[beam.PCollection[InputT], + beam.PCollection[OutputT]]): + """A :class:`apache_beam.transforms.enrichment.Enrichment` transform to + enrich elements in a PCollection. + **NOTE:** This transform and its implementation are under development and + do not provide backward compatibility guarantees. + Uses the :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` + to enrich elements by joining the metadata from external source. + + Processes an input :class:`~apache_beam.pvalue.PCollection` of `beam.Row` by + applying a :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` + to each element and returning the enriched + :class:`~apache_beam.pvalue.PCollection`. + + Args: + source_handler: Handles source lookup and metadata retrieval. + Implements the + :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` + join_fn: A lambda function to join original element with lookup metadata. + Defaults to `CROSS_JOIN`. + timeout: (Optional) timeout for source requests. Defaults to 30 seconds. + repeater (~apache_beam.io.requestresponse.Repeater): provides method to + repeat failed requests to API due to service errors. Defaults to + :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to + repeat requests with exponential backoff. + throttler (~apache_beam.io.requestresponse.PreCallThrottler): + provides methods to pre-throttle a request. Defaults to + :class:`apache_beam.io.requestresponse.DefaultThrottler` for + client-side adaptive throttling using + :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`. + """ + def __init__( + self, + source_handler: EnrichmentSourceHandler, + join_fn: JoinFn = cross_join, + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + repeater: Repeater = ExponentialBackOffRepeater(), + throttler: PreCallThrottler = DefaultThrottler(), + ): + self._source_handler = source_handler + self._join_fn = join_fn + self._timeout = timeout + self._repeater = repeater + self._throttler = throttler + + def expand(self, + input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]: + fetched_data = input_row | RequestResponseIO( + caller=self._source_handler, + timeout=self._timeout, + repeater=self._repeater, + throttler=self._throttler) + + # EnrichmentSourceHandler returns a tuple of (request,response). + return fetched_data | beam.Map( + lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict())) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py b/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py new file mode 100644 index 0000000000000..86ff2f3b8e7f6 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -0,0 +1,151 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from enum import Enum +from typing import Any +from typing import Dict +from typing import Optional + +from google.api_core.exceptions import NotFound +from google.cloud import bigtable +from google.cloud.bigtable import Client +from google.cloud.bigtable.row_filters import CellsColumnLimitFilter +from google.cloud.bigtable.row_filters import RowFilter + +import apache_beam as beam +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + +__all__ = [ + 'EnrichWithBigTable', + 'ExceptionLevel', +] + +_LOGGER = logging.getLogger(__name__) + + +class ExceptionLevel(Enum): + """ExceptionLevel defines the exception level options to either + log a warning, or raise an exception, or do nothing when a BigTable query + returns an empty row. + + Members: + - RAISE: Raise the exception. + - WARN: Log a warning for exception without raising it. + - QUIET: Neither log nor raise the exception. + """ + RAISE = 0 + WARN = 1 + QUIET = 2 + + +class EnrichWithBigTable(EnrichmentSourceHandler[beam.Row, beam.Row]): + """EnrichWithBigTable is a handler for + :class:`apache_beam.transforms.enrichment.Enrichment` transform to interact + with GCP BigTable. + + Args: + project_id (str): GCP project-id of the BigTable cluster. + instance_id (str): GCP instance-id of the BigTable cluster. + table_id (str): GCP table-id of the BigTable. + row_key (str): unique row-key field name from the input `beam.Row` object + to use as `row_key` for BigTable querying. + row_filter: a ``:class:`google.cloud.bigtable.row_filters.RowFilter``` to + filter data read with ``read_row()``. + Defaults to `CellsColumnLimitFilter(1)`. + app_profile_id (str): App profile ID to use for BigTable. + See https://cloud.google.com/bigtable/docs/app-profiles for more details. + encoding (str): encoding type to convert the string to bytes and vice-versa + from BigTable. Default is `utf-8`. + exception_level: a `enum.Enum` value from + ``apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel`` + to set the level when an empty row is returned from the BigTable query. + Defaults to ``ExceptionLevel.WARN``. + """ + def __init__( + self, + project_id: str, + instance_id: str, + table_id: str, + row_key: str, + row_filter: Optional[RowFilter] = CellsColumnLimitFilter(1), + app_profile_id: str = None, # type: ignore[assignment] + encoding: str = 'utf-8', + exception_level: ExceptionLevel = ExceptionLevel.WARN, + ): + self._project_id = project_id + self._instance_id = instance_id + self._table_id = table_id + self._row_key = row_key + self._row_filter = row_filter + self._app_profile_id = app_profile_id + self._encoding = encoding + self._exception_level = exception_level + + def __enter__(self): + """connect to the Google BigTable cluster.""" + self.client = Client(project=self._project_id) + self.instance = self.client.instance(self._instance_id) + self._table = bigtable.table.Table( + table_id=self._table_id, + instance=self.instance, + app_profile_id=self._app_profile_id) + + def __call__(self, request: beam.Row, *args, **kwargs): + """ + Reads a row from the GCP BigTable and returns + a `Tuple` of request and response. + + Args: + request: the input `beam.Row` to enrich. + """ + response_dict: Dict[str, Any] = {} + row_key_str: str = "" + try: + request_dict = request._asdict() + row_key_str = str(request_dict[self._row_key]) + row_key = row_key_str.encode(self._encoding) + row = self._table.read_row(row_key, filter_=self._row_filter) + if row: + for cf_id, cf_v in row.cells.items(): + response_dict[cf_id] = {} + for k, v in cf_v.items(): + response_dict[cf_id][k.decode(self._encoding)] = \ + v[0].value.decode(self._encoding) + elif self._exception_level == ExceptionLevel.WARN: + _LOGGER.warning( + 'no matching row found for row_key: %s ' + 'with row_filter: %s' % (row_key_str, self._row_filter)) + elif self._exception_level == ExceptionLevel.RAISE: + raise ValueError( + 'no matching row found for row_key: %s ' + 'with row_filter=%s' % (row_key_str, self._row_filter)) + except KeyError: + raise KeyError('row_key %s not found in input PCollection.' % row_key_str) + except NotFound: + raise NotFound( + 'GCP BigTable cluster `%s:%s:%s` not found.' % + (self._project_id, self._instance_id, self._table_id)) + except Exception as e: + raise e + + return request, beam.Row(**response_dict) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Clean the instantiated BigTable client.""" + self.client = None + self.instance = None + self._table = None diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py new file mode 100644 index 0000000000000..dd48c8e5ef4d0 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -0,0 +1,300 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import unittest +from typing import Dict +from typing import List +from typing import NamedTuple + +import pytest + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import BeamAssertException + +# pylint: disable=ungrouped-imports +try: + from google.api_core.exceptions import NotFound + from google.cloud.bigtable import Client + from google.cloud.bigtable.row_filters import ColumnRangeFilter + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigtable import EnrichWithBigTable + from apache_beam.transforms.enrichment_handlers.bigtable import ExceptionLevel +except ImportError: + raise unittest.SkipTest('GCP BigTable dependencies are not installed.') + + +class ValidateResponse(beam.DoFn): + """ValidateResponse validates if a PCollection of `beam.Row` + has the required fields.""" + def __init__( + self, + n_fields: int, + fields: List[str], + enriched_fields: Dict[str, List[str]]): + self.n_fields = n_fields + self._fields = fields + self._enriched_fields = enriched_fields + + def process(self, element: beam.Row, *args, **kwargs): + element_dict = element.as_dict() + if len(element_dict.keys()) != self.n_fields: + raise BeamAssertException( + "Expected %d fields in enriched PCollection:" % self.n_fields) + + for field in self._fields: + if field not in element_dict or element_dict[field] is None: + raise BeamAssertException(f"Expected a not None field: {field}") + + for column_family, columns in self._enriched_fields.items(): + if (len(element_dict[column_family]) != len(columns) or + not all(key in element_dict[column_family] for key in columns)): + raise BeamAssertException( + "Response from bigtable should contain a %s column_family with " + "%s keys." % (column_family, columns)) + + +class _Currency(NamedTuple): + s_id: int + id: str + + +def create_rows(table): + product_id = 'product_id' + product_name = 'product_name' + product_stock = 'product_stock' + + column_family_id = "product" + products = [ + { + 'product_id': 1, 'product_name': 'pixel 5', 'product_stock': 2 + }, + { + 'product_id': 2, 'product_name': 'pixel 6', 'product_stock': 4 + }, + { + 'product_id': 3, 'product_name': 'pixel 7', 'product_stock': 20 + }, + { + 'product_id': 4, 'product_name': 'pixel 8', 'product_stock': 10 + }, + { + 'product_id': 5, 'product_name': 'iphone 11', 'product_stock': 3 + }, + { + 'product_id': 6, 'product_name': 'iphone 12', 'product_stock': 7 + }, + { + 'product_id': 7, 'product_name': 'iphone 13', 'product_stock': 8 + }, + { + 'product_id': 8, 'product_name': 'iphone 14', 'product_stock': 3 + }, + ] + + for item in products: + row_key = str(item[product_id]).encode() + row = table.direct_row(row_key) + row.set_cell( + column_family_id, + product_id.encode(), + str(item[product_id]), + timestamp=datetime.datetime.utcnow()) + row.set_cell( + column_family_id, + product_name.encode(), + item[product_name], + timestamp=datetime.datetime.utcnow()) + row.set_cell( + column_family_id, + product_stock.encode(), + str(item[product_stock]), + timestamp=datetime.datetime.utcnow()) + row.commit() + + +@pytest.mark.it_postcommit +class TestBigTableEnrichment(unittest.TestCase): + def setUp(self): + self.project_id = 'apache-beam-testing' + self.instance_id = 'beam-test' + self.table_id = 'bigtable-enrichment-test' + self.req = [ + beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), + beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3), + beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2), + beam.Row(sale_id=7, customer_id=7, product_id=1, quantity=1), + ] + self.row_key = 'product_id' + self.column_family_id = 'product' + client = Client(project=self.project_id) + instance = client.instance(self.instance_id) + self.table = instance.table(self.table_id) + create_rows(self.table) + + def tearDown(self) -> None: + self.table = None + + def test_enrichment_with_bigtable(self): + expected_fields = [ + 'sale_id', 'customer_id', 'product_id', 'quantity', 'product' + ] + expected_enriched_fields = { + 'product': ['product_id', 'product_name', 'product_stock'], + } + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + + def test_enrichment_with_bigtable_row_filter(self): + expected_fields = [ + 'sale_id', 'customer_id', 'product_id', 'quantity', 'product' + ] + expected_enriched_fields = { + 'product': ['product_name', 'product_stock'], + } + start_column = 'product_name'.encode() + column_filter = ColumnRangeFilter(self.column_family_id, start_column) + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key, + row_filter=column_filter) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + + def test_enrichment_with_bigtable_no_enrichment(self): + # row_key which is product_id=11 doesn't exist, so the enriched field + # won't be added. Hence, the response is same as the request. + expected_fields = ['sale_id', 'customer_id', 'product_id', 'quantity'] + expected_enriched_fields = {} + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key) + req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)] + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(req) + | "Enrich W/ BigTable" >> Enrichment(bigtable) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + + def test_enrichment_with_bigtable_bad_row_filter(self): + # in case of a bad column filter, that is, incorrect column_family_id and + # columns, no enrichment is done. If the column_family is correct but not + # column names then all columns in that column_family are returned. + start_column = 'car_name'.encode() + column_filter = ColumnRangeFilter('car_name', start_column) + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key, + row_filter=column_filter) + with self.assertRaises(NotFound): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_enrichment_with_bigtable_raises_key_error(self): + """raises a `KeyError` when the row_key doesn't exist in + the input PCollection.""" + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key='car_name') + with self.assertRaises(KeyError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_enrichment_with_bigtable_raises_not_found(self): + """raises a `NotFound` exception when the GCP BigTable Cluster + doesn't exist.""" + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id='invalid_table', + row_key=self.row_key) + with self.assertRaises(NotFound): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ BigTable" >> Enrichment(bigtable)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_enrichment_with_bigtable_exception_level(self): + """raises a `ValueError` exception when the GCP BigTable query returns + an empty row.""" + bigtable = EnrichWithBigTable( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key, + exception_level=ExceptionLevel.RAISE) + req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)] + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(req) + | "Enrich W/ BigTable" >> Enrichment(bigtable)) + res = test_pipeline.run() + res.wait_until_finish() + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_it_test.py b/sdks/python/apache_beam/transforms/enrichment_it_test.py new file mode 100644 index 0000000000000..89842cb18be02 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_it_test.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import time +import unittest +from typing import List +from typing import NamedTuple +from typing import Tuple +from typing import Union + +import pytest +import urllib3 + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import BeamAssertException + +# pylint: disable=ungrouped-imports +try: + from apache_beam.io.requestresponse import UserCodeExecutionException + from apache_beam.io.requestresponse import UserCodeQuotaException + from apache_beam.io.requestresponse_it_test import _PAYLOAD + from apache_beam.io.requestresponse_it_test import EchoITOptions + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment import EnrichmentSourceHandler +except ImportError: + raise unittest.SkipTest('RequestResponseIO dependencies are not installed.') + + +class Request(NamedTuple): + id: str + payload: bytes + + +def _custom_join(left, right): + """custom_join returns the id and resp_payload along with a timestamp""" + right['timestamp'] = time.time() + return beam.Row(**right) + + +class SampleHTTPEnrichment(EnrichmentSourceHandler[Request, beam.Row]): + """Implements ``EnrichmentSourceHandler`` to call the ``EchoServiceGrpc``'s + HTTP handler. + """ + def __init__(self, url: str): + self.url = url + '/v1/echo' # append path to the mock API. + + def __call__(self, request: Request, *args, **kwargs): + """Overrides ``Caller``'s call method invoking the + ``EchoServiceGrpc``'s HTTP handler with an `dict`, returning + either a successful ``Tuple[dict,dict]`` or throwing either a + ``UserCodeExecutionException``, ``UserCodeTimeoutException``, + or a ``UserCodeQuotaException``. + """ + try: + resp = urllib3.request( + "POST", + self.url, + json={ + "id": request.id, "payload": str(request.payload, 'utf-8') + }, + retries=False) + + if resp.status < 300: + resp_body = resp.json() + resp_id = resp_body['id'] + payload = resp_body['payload'] + return ( + request, beam.Row(id=resp_id, resp_payload=bytes(payload, 'utf-8'))) + + if resp.status == 429: # Too Many Requests + raise UserCodeQuotaException(resp.reason) + elif resp.status != 200: + raise UserCodeExecutionException(resp.status, resp.reason, request) + + except urllib3.exceptions.HTTPError as e: + raise UserCodeExecutionException(e) + + +class ValidateFields(beam.DoFn): + """ValidateFields validates if a PCollection of `beam.Row` + has certain fields.""" + def __init__(self, n_fields: int, fields: List[str]): + self.n_fields = n_fields + self._fields = fields + + def process(self, element: beam.Row, *args, **kwargs): + element_dict = element.as_dict() + if len(element_dict.keys()) != self.n_fields: + raise BeamAssertException( + "Expected %d fields in enriched PCollection:" + " id, payload and resp_payload" % self.n_fields) + + for field in self._fields: + if field not in element_dict or element_dict[field] is None: + raise BeamAssertException(f"Expected a not None field: {field}") + + +@pytest.mark.it_postcommit +class TestEnrichment(unittest.TestCase): + options: Union[EchoITOptions, None] = None + client: Union[SampleHTTPEnrichment, None] = None + + @classmethod + def setUpClass(cls) -> None: + cls.options = EchoITOptions() + http_endpoint_address = cls.options.http_endpoint_address + if not http_endpoint_address or http_endpoint_address == '': + raise unittest.SkipTest('HTTP_ENDPOINT_ADDRESS is required.') + cls.client = SampleHTTPEnrichment(http_endpoint_address) + + @classmethod + def _get_client_and_options( + cls) -> Tuple[SampleHTTPEnrichment, EchoITOptions]: + assert cls.options is not None + assert cls.client is not None + return cls.client, cls.options + + def test_http_enrichment(self): + """Tests Enrichment Transform against the Mock-API HTTP endpoint + with the default cross join.""" + client, options = TestEnrichment._get_client_and_options() + req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) + fields = ['id', 'payload', 'resp_payload'] + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | 'Create PCollection' >> beam.Create([req]) + | 'Enrichment Transform' >> Enrichment(client) + | 'Assert Fields' >> beam.ParDo( + ValidateFields(len(fields), fields=fields))) + + def test_http_enrichment_custom_join(self): + """Tests Enrichment Transform against the Mock-API HTTP endpoint + with a custom join function.""" + client, options = TestEnrichment._get_client_and_options() + req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) + fields = ['id', 'resp_payload', 'timestamp'] + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | 'Create PCollection' >> beam.Create([req]) + | 'Enrichment Transform' >> Enrichment(client, join_fn=_custom_join) + | 'Assert Fields' >> beam.ParDo( + ValidateFields(len(fields), fields=fields))) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_test.py b/sdks/python/apache_beam/transforms/enrichment_test.py new file mode 100644 index 0000000000000..23b5f1828c15c --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_test.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import unittest + +import apache_beam as beam + +# pylint: disable=ungrouped-imports +try: + from apache_beam.transforms.enrichment import cross_join +except ImportError: + raise unittest.SkipTest('RequestResponseIO dependencies are not installed.') + + +class TestEnrichmentTransform(unittest.TestCase): + def test_cross_join(self): + left = {'id': 1, 'key': 'city'} + right = {'id': 1, 'value': 'durham'} + expected = beam.Row(id=1, key='city', value='durham') + output = cross_join(left, right) + self.assertEqual(expected, output) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()