diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index c4edaa85a89d..63bd5651def0 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,3 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run" } + diff --git a/CHANGES.md b/CHANGES.md index bce9636237e3..dd66e9192b69 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -71,6 +71,7 @@ * Beam YAML now supports the jinja templating syntax. Template variables can be passed with the (json-formatted) `--jinja_variables` flag. * DataFrame API now supports pandas 2.1.x and adds 12 more string functions for Series.([#31185](https://github.com/apache/beam/pull/31185)). +* Added BigQuery handler for enrichment transform (Python) ([#31295](https://github.com/apache/beam/pull/31295)) ## Breaking Changes diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index 4458aa59c18c..d7011e5a8ff3 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -28,6 +28,8 @@ from typing import Any from typing import Dict from typing import Generic +from typing import List +from typing import Mapping from typing import Optional from typing import Tuple from typing import TypeVar @@ -42,6 +44,7 @@ from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler from apache_beam.metrics import Metrics from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC +from apache_beam.transforms.util import BatchElements from apache_beam.utils import retry RequestT = TypeVar('RequestT') @@ -143,6 +146,10 @@ def get_cache_key(self, request: RequestT) -> str: """ return "" + def batch_elements_kwargs(self) -> Mapping[str, Any]: + """Returns kwargs suitable for `beam.BatchElements`.""" + return {} + class ShouldBackOff(abc.ABC): """ @@ -476,53 +483,67 @@ def __init__( def __enter__(self): self.client = redis.Redis(self.host, self.port, **self.kwargs) - def __call__(self, element, *args, **kwargs): - if self.mode == _RedisMode.READ: - cache_request = self.source_caller.get_cache_key(element) - # check if the caller is a enrichment handler. EnrichmentHandler - # provides the request format for cache. - if cache_request: - encoded_request = self.request_coder.encode(cache_request) - else: - encoded_request = self.request_coder.encode(element) - - encoded_response = self.client.get(encoded_request) - if not encoded_response: - # no cache entry present for this request. + def _read_cache(self, element): + cache_request = self.source_caller.get_cache_key(element) + # check if the caller is an enrichment handler. An enrichment handler + # provides the request format for the cache. + if cache_request: + encoded_request = self.request_coder.encode(cache_request) + else: + encoded_request = self.request_coder.encode(element) + + encoded_response = self.client.get(encoded_request) + if not encoded_response: + # no cache entry present for this request. + return element, None + + if self.response_coder is None: + try: + response_dict = json.loads(encoded_response.decode('utf-8')) + response = beam.Row(**response_dict) + except Exception: + _LOGGER.warning( + 'cannot decode response from redis cache for %s.'
% element) return element, None + else: + response = self.response_coder.decode(encoded_response) + return element, response - if self.response_coder is None: - try: - response_dict = json.loads(encoded_response.decode('utf-8')) - response = beam.Row(**response_dict) - except Exception: - _LOGGER.warning( - 'cannot decode response from redis cache for %s.' % element) - return element, None - else: - response = self.response_coder.decode(encoded_response) - return element, response + def _write_cache(self, element): + cache_request = self.source_caller.get_cache_key(element[0]) + if cache_request: + encoded_request = self.request_coder.encode(cache_request) + else: + encoded_request = self.request_coder.encode(element[0]) + if self.response_coder is None: + try: + encoded_response = json.dumps(element[1]._asdict()).encode('utf-8') + except Exception: + _LOGGER.warning( + 'cannot encode response %s for %s to store in ' + 'redis cache.' % (element[1], element[0])) + return element else: - cache_request = self.source_caller.get_cache_key(element[0]) - if cache_request: - encoded_request = self.request_coder.encode(cache_request) + else: + encoded_response = self.response_coder.encode(element[1]) + # Write to cache with TTL. Set nx to True to prevent overwriting for the + # same key. + self.client.set( + encoded_request, encoded_response, self.time_to_live, nx=True) + return element + + def __call__(self, element, *args, **kwargs): + if self.mode == _RedisMode.READ: + if isinstance(element, List): + responses = [self._read_cache(e) for e in element] + return responses else: - encoded_request = self.request_coder.encode(element[0]) - if self.response_coder is None: - try: - encoded_response = json.dumps(element[1]._asdict()).encode('utf-8') - except Exception: - _LOGGER.warning( - 'cannot encode response %s for %s to store in ' - 'redis cache.' % (element[1], element[0])) - return element + return self._read_cache(element) + else: + if isinstance(element, List): + responses = [self._write_cache(e) for e in element] + return responses else: - encoded_response = self.response_coder.encode(element[1]) - # Write to cache with TTL. Set nx to True to prevent overwriting for the - # same key. - self.client.set( - encoded_request, encoded_response, self.time_to_live, nx=True) - return element + return self._write_cache(element) def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() @@ -708,6 +729,13 @@ def request_coder(self, request_coder: coders.Coder): self._request_coder = request_coder +class FlattenBatch(beam.DoFn): + """Flatten a batched PCollection.""" + def process(self, elements, *args, **kwargs): + for element in elements: + yield element + + class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], beam.PCollection[ResponseT]]): """A :class:`RequestResponseIO` transform to read and write to APIs. @@ -753,6 +781,7 @@ def __init__( self._repeater = NoOpsRepeater() self._cache = cache self._throttler = throttler + self._batching_kwargs = self._caller.batch_elements_kwargs() def expand( self, @@ -774,6 +803,10 @@ def expand( ).with_outputs( 'cache_misses', main='cached_responses')) + # Batch elements if batching is enabled. + if self._batching_kwargs: + inputs = inputs | BatchElements(**self._batching_kwargs) + if isinstance(self._throttler, DefaultThrottler): # DefaultThrottler applies throttling in the DoFn of # Call PTransform. @@ -796,6 +829,10 @@ def expand( should_backoff=self._should_backoff, repeater=self._repeater)) + # If batching is enabled, flatten the batched responses.
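+    # Each batched call yields a list of (request, response) tuples, so
+    # the responses PCollection holds lists at this point; FlattenBatch
+    # emits the tuples individually so that the cache write and any
+    # downstream consumers see one tuple per request.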
+ if self._batching_kwargs: + responses = responses | "FlattenBatch" >> beam.ParDo(FlattenBatch()) + if self._cache: # write to cache. _ = responses | self._cache.get_write() diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index ddfbba5337fb..5bb1e2024e79 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -153,12 +153,14 @@ def expand(self, if self._cache: self._cache.request_coder = request_coder - fetched_data = input_row | RequestResponseIO( - caller=self._source_handler, - timeout=self._timeout, - repeater=self._repeater, - cache=self._cache, - throttler=self._throttler) + fetched_data = ( + input_row + | "Enrichment-RRIO" >> RequestResponseIO( + caller=self._source_handler, + timeout=self._timeout, + repeater=self._repeater, + cache=self._cache, + throttler=self._throttler)) # EnrichmentSourceHandler returns a tuple of (request,response). return ( diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py new file mode 100644 index 000000000000..382ae123a81d --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -0,0 +1,256 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Union + +from google.api_core.exceptions import BadRequest +from google.cloud import bigquery + +import apache_beam as beam +from apache_beam.pvalue import Row +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + +QueryFn = Callable[[beam.Row], str] +ConditionValueFn = Callable[[beam.Row], List[Any]] + + +def _validate_bigquery_metadata( + table_name, row_restriction_template, fields, condition_value_fn, query_fn): + if query_fn: + if bool(table_name or row_restriction_template or fields or + condition_value_fn): + raise ValueError( + "Please provide either `query_fn` alone, or the parameters " + "`table_name`, `row_restriction_template`, and " + "`fields`/`condition_value_fn` together.") + else: + if not (table_name and row_restriction_template): + raise ValueError( + "Please provide either `query_fn` or both `table_name` and " + "`row_restriction_template`.") + if ((fields and condition_value_fn) or + (not fields and not condition_value_fn)): + raise ValueError( + "Please provide exactly one of `fields` or " + "`condition_value_fn`.") + + +class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], + Union[Row, List[Row]]]): + """Enrichment handler for Google Cloud BigQuery.
+ + Use this handler with the + :class:`apache_beam.transforms.enrichment.Enrichment` transform. + + To use this handler, provide one of the following parameter combinations: + * `table_name`, `row_restriction_template`, `fields` + * `table_name`, `row_restriction_template`, `condition_value_fn` + * `query_fn` + + By default, the handler pulls all columns from the BigQuery table. + To override this, use the `column_names` parameter to specify a list of + column names to fetch. + + This handler pulls data from BigQuery per element by default. To change this + behavior, set the `min_batch_size` and `max_batch_size` parameters. + These min and max values for batch size are sent to the + :class:`apache_beam.transforms.util.BatchElements` transform. + + NOTE: Elements cannot be batched when using the `query_fn` parameter. + """ + def __init__( + self, + project: str, + *, + table_name: str = "", + row_restriction_template: str = "", + fields: Optional[List[str]] = None, + column_names: Optional[List[str]] = None, + condition_value_fn: Optional[ConditionValueFn] = None, + query_fn: Optional[QueryFn] = None, + min_batch_size: int = 1, + max_batch_size: int = 10000, + **kwargs, + ): + """ + Example Usage: + handler = BigQueryEnrichmentHandler(project=project_name, + row_restriction_template="id='{}'", + table_name='project.dataset.table', + fields=fields, + min_batch_size=2, + max_batch_size=100) + + Args: + project: Google Cloud project ID for the BigQuery table. + table_name (str): Fully qualified BigQuery table name + in the format `project.dataset.table`. + row_restriction_template (str): A template string for the `WHERE` clause + in the BigQuery query with placeholders (`{}`) to dynamically filter + rows based on input data. + fields: (Optional[List[str]]) List of field names present in the input + `beam.Row`. These are used to construct the WHERE clause + (if `condition_value_fn` is not provided). + column_names: (Optional[List[str]]) Names of columns to select from the + BigQuery table. If not provided, all columns (`*`) are selected. + condition_value_fn: (Optional[Callable[[beam.Row], List[Any]]]) A function + that takes a `beam.Row` and returns a list of values to populate the + placeholder `{}` in the `WHERE` clause of the query. + query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a + `beam.Row` and returns a complete BigQuery SQL query string. + min_batch_size (int): Minimum number of rows to batch together when + querying BigQuery. Defaults to 1 if `query_fn` is not specified. + max_batch_size (int): Maximum number of rows to batch together. + Defaults to 10,000 if `query_fn` is not specified. + **kwargs: Additional keyword arguments to pass to `bigquery.Client`. + + Note: + * `min_batch_size` and `max_batch_size` cannot be set when + `query_fn` is provided. + * Either `fields` or `condition_value_fn` must be provided for query + construction if `query_fn` is not provided. + * Ensure appropriate permissions are granted for BigQuery access.
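+
+    Example with `query_fn` (a sketch; the table, column, and variable
+    names here are illustrative, not part of the API):
+
+      def query_fn(row: beam.Row) -> str:
+        # Build the complete query for this input row.
+        return ("SELECT * FROM `project.dataset.table` "
+                "WHERE id = '{}'".format(row.id))
+
+      handler = BigQueryEnrichmentHandler(project=project_name,
+                                          query_fn=query_fn)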
+ """ + _validate_bigquery_metadata( + table_name, + row_restriction_template, + fields, + condition_value_fn, + query_fn) + self.project = project + self.column_names = column_names + self.select_fields = ",".join(column_names) if column_names else '*' + self.row_restriction_template = row_restriction_template + self.table_name = table_name + self.fields = fields if fields else [] + self.condition_value_fn = condition_value_fn + self.query_fn = query_fn + self.query_template = ( + "SELECT %s FROM %s WHERE %s" % + (self.select_fields, self.table_name, self.row_restriction_template)) + self.kwargs = kwargs + self._batching_kwargs = {} + if not query_fn: + self._batching_kwargs['min_batch_size'] = min_batch_size + self._batching_kwargs['max_batch_size'] = max_batch_size + + def __enter__(self): + self.client = bigquery.Client(project=self.project, **self.kwargs) + + def _execute_query(self, query: str): + try: + results = self.client.query(query=query).result() + if self._batching_kwargs: + return [dict(row.items()) for row in results] + else: + return [dict(row.items()) for row in results][0] + except BadRequest as e: + raise BadRequest( + f'Could not execute the query: {query}. Please check if ' + f'the query is properly formatted and the BigQuery ' + f'table exists. {e}') + except RuntimeError as e: + raise RuntimeError(f"Could not complete the query request: {query}. {e}") + + def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): + if isinstance(request, List): + values = [] + responses = [] + requests_map: Dict[Any, Any] = {} + batch_size = len(request) + raw_query = self.query_template + if batch_size > 1: + batched_condition_template = ' or '.join( + [self.row_restriction_template] * batch_size) + raw_query = self.query_template.replace( + self.row_restriction_template, batched_condition_template) + for req in request: + request_dict = req._asdict() + try: + current_values = ( + self.condition_value_fn(req) if self.condition_value_fn else + [request_dict[field] for field in self.fields]) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `fields` are the " + "keys in the input `beam.Row`." + str(e)) + values.extend(current_values) + requests_map.update((val, req) for val in current_values) + query = raw_query.format(*values) + + responses_dict = self._execute_query(query) + for response in responses_dict: + for value in response.values(): + if value in requests_map: + responses.append((requests_map[value], beam.Row(**response))) + return responses + else: + request_dict = request._asdict() + if self.query_fn: + # if a query_fn is provided then it return a list of values + # that should be populated into the query template string. + query = self.query_fn(request) + else: + values = ( + self.condition_value_fn(request) if self.condition_value_fn else + list(map(request_dict.get, self.fields))) + # construct the query. 
+ query = self.query_template.format(*values) + response_dict = self._execute_query(query) + return request, beam.Row(**response_dict) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.close() + + def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]): + if isinstance(request, List): + cache_keys = [] + for req in request: + req_dict = req._asdict() + try: + current_values = ( + self.condition_value_fn(req) if self.condition_value_fn else + [req_dict[field] for field in self.fields]) + key = ";".join(["%s"] * len(current_values)) + cache_keys.append(key % tuple(current_values)) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `fields` are the " + "keys in the input `beam.Row`. " + str(e)) + return cache_keys + else: + req_dict = request._asdict() + try: + current_values = ( + self.condition_value_fn(request) if self.condition_value_fn else + [req_dict[field] for field in self.fields]) + key = ";".join(["%s"] * len(current_values)) + cache_key = key % tuple(current_values) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `fields` are the " + "keys in the input `beam.Row`. " + str(e)) + return cache_key + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + """Returns kwargs suitable for `beam.BatchElements`.""" + return self._batching_kwargs diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py new file mode 100644 index 000000000000..0b8a384b934d --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py @@ -0,0 +1,289 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# +import logging +import unittest +from unittest.mock import MagicMock + +import pytest + +import apache_beam as beam +from apache_beam.coders import coders +from apache_beam.testing.test_pipeline import TestPipeline + +# pylint: disable=ungrouped-imports +try: + from google.api_core.exceptions import BadRequest + from testcontainers.redis import RedisContainer + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery import \ BigQueryEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import \ ValidateResponse +except ImportError: + raise unittest.SkipTest( + 'Google Cloud BigQuery dependencies are not installed.') + +_LOGGER = logging.getLogger(__name__) + + +def query_fn(row: beam.Row): + query = ( + "SELECT * FROM " + "`apache-beam-testing.my_ecommerce.product_details`" + " WHERE id = '{}'".format(row.id)) # type: ignore[attr-defined] + return query + + +def condition_value_fn(row: beam.Row): + return [row.id] # type: ignore[attr-defined] + + +@pytest.mark.uses_testcontainer +class TestBigQueryEnrichmentIT(unittest.TestCase): + def setUp(self) -> None: + self.project = 'apache-beam-testing' + self.condition_template = "id = '{}'" + self.table_name = "`apache-beam-testing.my_ecommerce.product_details`" + self.retries = 3 + self._start_container() + + def _start_container(self): + for i in range(self.retries): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error( + 'Unable to start redis container for BigQuery ' + 'enrichment tests.') + raise e + + def tearDown(self) -> None: + self.container.stop() + self.client = None + + def test_bigquery_enrichment(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + fields = ['id'] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + fields=fields, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_with_query_fn(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler(project=self.project, query_fn=query_fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_with_condition_value_fn(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low
profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + condition_value_fn=condition_value_fn, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_with_condition_without_batch(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + condition_value_fn=condition_value_fn, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_bad_request(self): + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + column_names=['wrong_column'], + condition_value_fn=condition_value_fn, + ) + with self.assertRaises(BadRequest): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(requests) + | "Enrichment" >> Enrichment(handler)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_bigquery_enrichment_with_redis(self): + """ + In this test, we run two pipelines back to back. + + In the first pipeline, we run a simple BigQuery enrichment pipeline + with zero cache records. Therefore, it makes calls to the source + and ultimately writes to the cache with a TTL of 300 seconds. + + For the second pipeline, we mock the + `BigQueryEnrichmentHandler`'s `__call__` method to always + return a fixed, empty response. However, this change won't impact the second + pipeline because the Enrichment transform first checks the cache to fulfill + requests. Since all requests are cached, it will return from there without + making calls to the BigQuery service.
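+
+    After the first pipeline runs, this test also reads the cache directly,
+    using the keys produced by `handler.get_cache_key`, to confirm that
+    entries were actually written.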
+ """ + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + condition_value_fn=condition_value_fn, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port) + | beam.ParDo(ValidateResponse(expected_fields))) + + # manually check cache entry + c = coders.StrUtf8Coder() + for req in requests: + key = handler.get_cache_key(req) + response = self.client.get(c.encode(key)) + if not response: + raise ValueError("No cache entry found for %s" % key) + + actual = BigQueryEnrichmentHandler.__call__ + BigQueryEnrichmentHandler.__call__ = MagicMock( + return_value=( + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + beam.Row())) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port) + | beam.ParDo(ValidateResponse(expected_fields))) + BigQueryEnrichmentHandler.__call__ = actual + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py new file mode 100644 index 000000000000..98ac6244910c --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from parameterized import parameterized + +# pylint: disable=ungrouped-imports +try: + from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.bigquery_it_test import condition_value_fn + from apache_beam.transforms.enrichment_handlers.bigquery_it_test import query_fn +except ImportError: + raise unittest.SkipTest( + 'Google Cloud BigQuery dependencies are not installed.') + + +class TestBigQueryEnrichment(unittest.TestCase): + def setUp(self) -> None: + self.project = 'apache-beam-testing' + + @parameterized.expand([ + ("", "", [], None, None, 1, 2), + ("table", "", ["id"], condition_value_fn, None, 2, 10), + ("table", "id='{}'", ["id"], condition_value_fn, None, 2, 10), + ("table", "id='{}'", ["id"], None, query_fn, 2, 10), + ]) + def test_invalid_params( + self, + table_name, + row_restriction_template, + fields, + condition_value_fn, + query_fn, + min_batch_size, + max_batch_size): + """ + Each case passes an invalid parameter combination and should raise + a ValueError. + TC 1: Only batch sizes are provided. + TC 2: Neither `row_restriction_template` nor `query_fn` is provided. + TC 3: Both `fields` and `condition_value_fn` are provided. + TC 4: Query construction details are provided along with `query_fn`. + """ + with self.assertRaises(ValueError): + _ = BigQueryEnrichmentHandler( + project=self.project, + table_name=table_name, + row_restriction_template=row_restriction_template, + fields=fields, + condition_value_fn=condition_value_fn, + query_fn=query_fn, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index c0955aa183ba..5929326f75f9 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -199,6 +199,7 @@ ignore_identifiers = [ 'apache_beam.transforms.ptransform.PTransformWithSideInputs', 'apache_beam.transforms.trigger._ParallelTriggerFn', 'apache_beam.transforms.trigger.InMemoryUnmergedState', + 'apache_beam.transforms.util.BatchElements', 'apache_beam.typehints.typehints.AnyTypeConstraint', 'apache_beam.typehints.typehints.CompositeTypeHint', 'apache_beam.typehints.typehints.TypeConstraint',