From ac59edc65a2868082c4966fab18f2c8d7a133129 Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Tue, 14 May 2024 17:13:47 -0400 Subject: [PATCH 1/8] bigquery enrichment with it --- sdks/python/apache_beam/io/requestresponse.py | 56 +++-- .../apache_beam/transforms/enrichment.py | 24 ++- .../enrichment_handlers/bigquery.py | 192 ++++++++++++++++++ .../enrichment_handlers/bigquery_it_test.py | 160 +++++++++++++++ 4 files changed, 409 insertions(+), 23 deletions(-) create mode 100644 sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py create mode 100644 sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index 4458aa59c18cd..a4289190b6afe 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -28,6 +28,7 @@ from typing import Any from typing import Dict from typing import Generic +from typing import List from typing import Optional from typing import Tuple from typing import TypeVar @@ -477,55 +478,68 @@ def __enter__(self): self.client = redis.Redis(self.host, self.port, **self.kwargs) def __call__(self, element, *args, **kwargs): + are_batched_requests = isinstance(element, List) if self.mode == _RedisMode.READ: - cache_request = self.source_caller.get_cache_key(element) - # check if the caller is a enrichment handler. EnrichmentHandler - # provides the request format for cache. - if cache_request: - encoded_request = self.request_coder.encode(cache_request) + return self._read_cache(element, are_batched_requests) + else: + return self._write_cache(element, are_batched_requests) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.close() + + def _read_cache(self, request, are_batched_requests: bool): + requests = [request] if not are_batched_requests else request + responses = [] + for req in requests: + cache_key = self.source_caller.get_cache_key(req) + if cache_key: + encoded_request = self.request_coder.encode(cache_key) else: - encoded_request = self.request_coder.encode(element) + encoded_request = self.request_coder.encode(req) encoded_response = self.client.get(encoded_request) if not encoded_response: # no cache entry present for this request. - return element, None + responses.append((req, None)) + continue if self.response_coder is None: try: response_dict = json.loads(encoded_response.decode('utf-8')) response = beam.Row(**response_dict) + responses.append((req, response)) except Exception: _LOGGER.warning( - 'cannot decode response from redis cache for %s.' % element) - return element, None + 'cannot decode response from redis cache for %s.' 
% req) + responses.append((req, None)) else: response = self.response_coder.decode(encoded_response) - return element, response - else: - cache_request = self.source_caller.get_cache_key(element[0]) + responses.append((req, response)) + return responses if are_batched_requests else responses[0] + + def _write_cache(self, request, are_batched_requests): + requests = [request] if not are_batched_requests else request + for req in requests: + cache_request = self.source_caller.get_cache_key(req[0]) if cache_request: encoded_request = self.request_coder.encode(cache_request) else: - encoded_request = self.request_coder.encode(element[0]) + encoded_request = self.request_coder.encode(req[0]) if self.response_coder is None: try: - encoded_response = json.dumps(element[1]._asdict()).encode('utf-8') + encoded_response = json.dumps(req[1]._asdict()).encode('utf-8') except Exception: _LOGGER.warning( 'cannot encode response %s for %s to store in ' - 'redis cache.' % (element[1], element[0])) - return element + 'redis cache.' % (req[1], req[0])) + continue else: - encoded_response = self.response_coder.encode(element[1]) + encoded_response = self.response_coder.encode(req[1]) # Write to cache with TTL. Set nx to True to prevent overwriting for the # same key. self.client.set( encoded_request, encoded_response, self.time_to_live, nx=True) - return element - - def __exit__(self, exc_type, exc_val, exc_tb): - self.client.close() + return request class _ReadFromRedis(beam.PTransform[beam.PCollection[RequestT], diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index ddfbba5337fb4..71291f564fa3d 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -16,7 +16,7 @@ # import logging from datetime import timedelta -from typing import Any +from typing import Any, Mapping from typing import Callable from typing import Dict from typing import Optional @@ -24,6 +24,7 @@ from typing import Union import apache_beam as beam +from apache_beam import BatchElements from apache_beam.coders import coders from apache_beam.io.requestresponse import DEFAULT_CACHE_ENTRY_TTL_SEC from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS @@ -82,6 +83,12 @@ def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: return beam.Row(**left) +class FlattenBatch(beam.DoFn): + def process(self, elements, *args, **kwargs): + for element in elements: + yield element + + class EnrichmentSourceHandler(Caller[InputT, OutputT]): """Wrapper class for `apache_beam.io.requestresponse.Caller`. 
@@ -100,6 +107,10 @@ def get_cache_key(self, request: InputT) -> str: """ return "request: %s" % request + def batch_elements_kwargs(self) -> Mapping[str, Any]: + """Returns a kwargs suitable for `beam.BatchElements`.""" + return {} + class Enrichment(beam.PTransform[beam.PCollection[InputT], beam.PCollection[OutputT]]): @@ -143,6 +154,7 @@ def __init__( self._timeout = timeout self._repeater = repeater self._throttler = throttler + self._batching_kwargs = self._source_handler.batch_elements_kwargs() def expand(self, input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]: @@ -153,7 +165,11 @@ def expand(self, if self._cache: self._cache.request_coder = request_coder - fetched_data = input_row | RequestResponseIO( + input_data = input_row + if self._batching_kwargs: + input_data = input_row | BatchElements(**self._batching_kwargs) + + fetched_data = input_data | RequestResponseIO( caller=self._source_handler, timeout=self._timeout, repeater=self._repeater, @@ -161,6 +177,10 @@ def expand(self, throttler=self._throttler) # EnrichmentSourceHandler returns a tuple of (request,response). + # if batching is enabled then handle accordingly + if self._batching_kwargs: + fetched_data = fetched_data | "Flatten" >> beam.ParDo(FlattenBatch()) + return ( fetched_data | "enrichment_join" >> diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py new file mode 100644 index 0000000000000..8e9ad74408ec6 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -0,0 +1,192 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any +from typing import Callable +from typing import List +from typing import Mapping +from typing import Optional +from typing import Union + +from apache_beam.pvalue import Row +from google.cloud import bigquery + +import apache_beam as beam +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + +QueryFn = Callable[[beam.Row], str] +ConditionValueFn = Callable[[beam.Row], List[Any]] + + +class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], + Union[Row, List[Row]]]): + def __init__( + self, + project: str, + *, + table_name: Optional[str] = "", + row_restriction_template: Optional[str] = "", + fields: Optional[List[str]] = None, + column_names: Optional[List[str]] = None, + condition_value_fn: Optional[ConditionValueFn] = None, + query_fn: Optional[QueryFn] = None, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + **kwargs, + ): + """BigQuery handler for + :class:`apache_beam.transforms.enrichment.Enrichment` transform. 
+ + Example Usage:: + handler = BigQueryEnrichmentHandler(project=project_name, + row_restriction="id='{}'", + + table_name='project.dataset.table', + + fields=fields, + + min_batch_size=2, + + max_batch_size=100, + ) + + Args: + project: Google Cloud project ID for the BigQuery table. + table_name (str): Fully qualified BigQuery table name + in the format `project.dataset.table`. + row_restriction_template (str): A template string for the `WHERE` clause + in the BigQuery query with placeholders to dynamically filter rows + based on input data. + fields: (Optional[List[str]]) List of field names present in the input + `beam.Row`. These are used to construct the WHERE clause + (if `condition_value_fn` is not provided). + column_names: (Optional[List[str]]) Names of columns to select from the + BigQuery table. If not provided, all columns (`*`) are selected. + condition_value_fn: (Optional[Callable[[beam.Row], str]]) A function + that takes a `beam.Row` and returns a string value to be used in the + `WHERE` clause of the BigQuery query. + query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a + `beam.Row` and returns a complete BigQuery SQL query string. + If provided, it overrides the default query construction that use + `fields`, `column_names`, and `condition_value_fn`. + min_batch_size: (Optional[int]) Minimum number of rows to batch together + when querying BigQuery. + max_batch_size: (Optional[int]) Maximum number of rows to batch together. + **kwargs: Additional keyword arguments to pass to `bigquery.Client`. + + Note: + * `min_batch_size` and `max_batch_size` won't have any effect if the + `query_fn` is provided. + * Either `fields` or `condition_value_fn` must be provided for query + construction if `query_fn` is not provided. + * If `query_fn` is provided, it overrides the default query construction. + * Ensure appropriate permissions are granted for BigQuery access. 
+ """ + self.project = project + self.column_names = column_names + self.select_fields = ",".join(column_names) if column_names else '*' + self.row_restriction_template = row_restriction_template + self.table_name = table_name + self.fields = fields + self.condition_value_fn = condition_value_fn + self.query_fn = query_fn + self.query_template = ( + "SELECT %s FROM %s WHERE %s" % + (self.select_fields, self.table_name, self.row_restriction_template)) + self.kwargs = kwargs + self._batching_kwargs = {} + if not query_fn: + if min_batch_size is not None: + self._batching_kwargs['min_batch_size'] = min_batch_size + if max_batch_size is not None: + self._batching_kwargs['max_batch_size'] = max_batch_size + + def __enter__(self): + self.client = bigquery.Client(project=self.project, **self.kwargs) + + def _execute_query(self, query: str): + try: + results = self.client.query(query=query).result() + if self._batching_kwargs: + return [dict(row.items()) for row in results] + else: + return [dict(row.items()) for row in results][0] + except RuntimeError: + raise RuntimeError("Could not complete the query request: %s" % query) + + def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): + # raise ValueError(type(request)) + if isinstance(request, List): + values = [] + responses = [] + requests_map = {} + batch_size = len(request) + raw_query = self.query_template + if batch_size > 1: + batched_condition_template = ' or '.join( + [self.row_restriction_template] * batch_size) + raw_query = self.query_template.replace( + self.row_restriction_template, batched_condition_template) + for req in request: + request_dict = req._asdict() + current_values = ( + self.condition_value_fn(req) if self.condition_value_fn else + [request_dict.get(field) for field in self.fields]) + values.extend(current_values) + requests_map.update((val, req) for val in current_values) + query = raw_query.format(*values) + + responses_dict = self._execute_query(query) + for response in responses_dict: + for value in response.values(): + if value in requests_map: + responses.append((requests_map[value], beam.Row(**response))) + return responses + else: + request_dict = request._asdict() + if self.query_fn: + # if a query_fn is provided then it return a list of values + # that should be populated into the query template string. + query = self.query_fn(request) + else: + values = ( + self.condition_value_fn(request) if self.condition_value_fn else + list(map(request_dict.get, self.fields))) + # construct the query. 
+      query = self.query_template.format(*values)
+      response_dict = self._execute_query(query)
+      return request, beam.Row(**response_dict)
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    self.client.close()
+
+  def get_cache_key(
+      self, request: Union[beam.Row, List[beam.Row]]) -> Union[str, List[str]]:
+    key = ";".join(["%s"] * len(self.fields))
+    if self._batching_kwargs:
+      cache_keys = []
+      for req in request:
+        req_dict = req._asdict()
+        key = ";".join(["%s"] * len(self.fields))
+        cache_keys.extend([key % req_dict[field] for field in self.fields])
+      return cache_keys
+    else:
+      req_dict = request._asdict()
+      return key % (req_dict.get(field) for field in self.fields)
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    """Returns kwargs suitable for `beam.BatchElements`."""
+    return self._batching_kwargs
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
new file mode 100644
index 0000000000000..3b5967dcad049
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
@@ -0,0 +1,160 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +import unittest + +from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import ValidateResponse + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms.enrichment import Enrichment +from apache_beam.transforms.enrichment_handlers.bigquery import \ + BigQueryEnrichmentHandler + + +def query_fn(row: beam.Row): + query = ( + "SELECT * FROM " + "`google.com:clouddfe.my_ecommerce.product_details`" + " WHERE id = '{}'".format(row.id)) + return query + + +def condition_value_fn(row: beam.Row): + return [row.id] + + +class TestBigQueryEnrichment(unittest.TestCase): + def setUp(self) -> None: + self.project = 'google.com:clouddfe' + self.query_template = ( + "SELECT * FROM " + "`google.com:clouddfe.my_ecommerce.product_details`" + " WHERE id = '{}'") + self.condition_template = "id = '{}'" + self.table_name = "`google.com:clouddfe.my_ecommerce.product_details`" + + def test_bigquery_enrichment(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + fields = ['id'] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + fields=fields, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_with_query_fn(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler(project=self.project, query_fn=query_fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_with_condition_value_fn(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + condition_value_fn=condition_value_fn, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_bigquery_enrichment_with_condition_without_batch(self): + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + 
quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + condition_value_fn=condition_value_fn, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + +if __name__ == '__main__': + unittest.main() From 4738dadb2343a5727e4a100c038ed94a3c1c4b66 Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Tue, 14 May 2024 20:28:00 -0400 Subject: [PATCH 2/8] unit tests --- .../apache_beam/transforms/enrichment.py | 3 +- .../enrichment_handlers/bigquery.py | 94 +++++++++++++------ .../enrichment_handlers/bigquery_it_test.py | 22 +++-- .../enrichment_handlers/bigquery_test.py | 72 ++++++++++++++ sdks/python/scripts/generate_pydoc.sh | 1 + 5 files changed, 155 insertions(+), 37 deletions(-) create mode 100644 sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index 71291f564fa3d..b85b43d69ce5c 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -16,9 +16,10 @@ # import logging from datetime import timedelta -from typing import Any, Mapping +from typing import Any from typing import Callable from typing import Dict +from typing import Mapping from typing import Optional from typing import TypeVar from typing import Union diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 8e9ad74408ec6..7e34aefcc78ee 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -16,29 +16,69 @@ # from typing import Any from typing import Callable +from typing import Dict from typing import List from typing import Mapping from typing import Optional from typing import Union -from apache_beam.pvalue import Row from google.cloud import bigquery import apache_beam as beam +from apache_beam.pvalue import Row from apache_beam.transforms.enrichment import EnrichmentSourceHandler QueryFn = Callable[[beam.Row], str] ConditionValueFn = Callable[[beam.Row], List[Any]] +def _validate_batch_query_fn(query_fn, min_batch_size, max_batch_size): + if query_fn and min_batch_size and max_batch_size: + raise ValueError( + "Please provide exactly one of `query_fn` or " + "(`min_batch_size` and `max_batch_size`)") + + +def _validate_bigquery_metadata( + table_name, row_restriction_template, fields, condition_value_fn, query_fn): + if query_fn and bool(table_name or row_restriction_template or fields or + condition_value_fn): + raise ValueError( + "Please provide either `query_fn` or the parameters " + "`table_name`, `row_restriction_template`, and `fields` " + "together.") + elif not query_fn and not (table_name and row_restriction_template and + (fields or condition_value_fn)): + raise ValueError( + "Please provide either `query_fn` or the parameters " + "`table_name`, `row_restriction_template`, and" + "`fields/condition_value_fn` together.") + if (fields and condition_value_fn) or (not fields and not condition_value_fn): + raise ValueError( + "Please provide exactly one of `fields` or " + "`condition_value_fn`") + + class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], Union[Row, 
List[Row]]]):
+  """Enrichment handler for Google Cloud BigQuery.
+
+  Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment`
+  transform.
+
+  This handler pulls data from BigQuery per element by default. To change this
+  behavior, set the `min_batch_size` and `max_batch_size` parameters.
+  These min and max values for batch size are sent to the
+  :class:`apache_beam.transforms.utils.BatchElements` transform.
+
+  NOTE: Elements cannot be batched when using the `query_fn` parameter.
+  """
   def __init__(
       self,
       project: str,
       *,
-      table_name: Optional[str] = "",
-      row_restriction_template: Optional[str] = "",
+      table_name: str = "",
+      row_restriction_template: str = "",
       fields: Optional[List[str]] = None,
       column_names: Optional[List[str]] = None,
       condition_value_fn: Optional[ConditionValueFn] = None,
@@ -47,21 +87,14 @@ def __init__(
       max_batch_size: Optional[int] = None,
       **kwargs,
   ):
-    """BigQuery handler for
-    :class:`apache_beam.transforms.enrichment.Enrichment` transform.
-
-    Example Usage::
+    """
+    Example Usage:
       handler = BigQueryEnrichmentHandler(project=project_name,
-                        row_restriction="id='{}'",
-
-                        table_name='project.dataset.table',
-
-                        fields=fields,
-
-                        min_batch_size=2,
-
-                        max_batch_size=100,
-                        )
+                        row_restriction_template="id='{}'",
+                        table_name='project.dataset.table',
+                        fields=fields,
+                        min_batch_size=2,
+                        max_batch_size=100)

     Args:
       project: Google Cloud project ID for the BigQuery table.
@@ -95,12 +128,19 @@ def __init__(
       * If `query_fn` is provided, it overrides the default query construction.
       * Ensure appropriate permissions are granted for BigQuery access.
     """
+    _validate_bigquery_metadata(
+        table_name,
+        row_restriction_template,
+        fields,
+        condition_value_fn,
+        query_fn)
+    _validate_batch_query_fn(query_fn, min_batch_size, max_batch_size)
     self.project = project
     self.column_names = column_names
     self.select_fields = ",".join(column_names) if column_names else '*'
     self.row_restriction_template = row_restriction_template
     self.table_name = table_name
-    self.fields = fields
+    self.fields = fields if fields else []
     self.condition_value_fn = condition_value_fn
     self.query_fn = query_fn
     self.query_template = (
@@ -108,11 +148,10 @@ def __init__(
         (self.select_fields, self.table_name, self.row_restriction_template))
     self.kwargs = kwargs
     self._batching_kwargs = {}
-    if not query_fn:
-      if min_batch_size is not None:
-        self._batching_kwargs['min_batch_size'] = min_batch_size
-      if max_batch_size is not None:
-        self._batching_kwargs['max_batch_size'] = max_batch_size
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size

   def __enter__(self):
     self.client = bigquery.Client(project=self.project, **self.kwargs)
@@ -132,7 +171,7 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
     if isinstance(request, List):
       values = []
       responses = []
-      requests_map = {}
+      requests_map: Dict[Any, Any] = {}
       batch_size = len(request)
       raw_query = self.query_template
       if batch_size > 1:
@@ -173,16 +212,15 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
   def __exit__(self, exc_type, exc_val, exc_tb):
     self.client.close()

-  def get_cache_key(
-      self, request: Union[beam.Row, List[beam.Row]]) -> Union[str, List[str]]:
+  def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]):
     key = ";".join(["%s"] * len(self.fields))
-    if self._batching_kwargs:
+    if isinstance(request, List):
       cache_keys = []
       for req in
request: req_dict = req._asdict() key = ";".join(["%s"] * len(self.fields)) cache_keys.extend([key % req_dict[field] for field in self.fields]) - return cache_keys + return cache_keys else: req_dict = request._asdict() return key % (req_dict.get(field) for field in self.fields) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py index 3b5967dcad049..1a7325d4d22ef 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py @@ -16,28 +16,34 @@ # import unittest -from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import ValidateResponse - import apache_beam as beam from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.transforms.enrichment import Enrichment -from apache_beam.transforms.enrichment_handlers.bigquery import \ - BigQueryEnrichmentHandler + +# pylint: disable=ungrouped-imports +try: + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery import \ + BigQueryEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import \ + ValidateResponse +except ImportError: + raise unittest.SkipTest( + 'Google Cloud BigQuery dependencies are not installed.') def query_fn(row: beam.Row): query = ( "SELECT * FROM " "`google.com:clouddfe.my_ecommerce.product_details`" - " WHERE id = '{}'".format(row.id)) + " WHERE id = '{}'".format(row.id)) # type: ignore[attr-defined] return query def condition_value_fn(row: beam.Row): - return [row.id] + return [row.id] # type: ignore[attr-defined] -class TestBigQueryEnrichment(unittest.TestCase): +class TestBigQueryEnrichmentIT(unittest.TestCase): def setUp(self) -> None: self.project = 'google.com:clouddfe' self.query_template = ( diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py new file mode 100644 index 0000000000000..7a57ad6629944 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from parameterized import parameterized + +# pylint: disable=ungrouped-imports +try: + from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.bigquery_it_test import condition_value_fn + from apache_beam.transforms.enrichment_handlers.bigquery_it_test import query_fn +except ImportError: + raise unittest.SkipTest( + 'Google Cloud BigQuery dependencies are not installed.') + + +class TestBigQueryEnrichment(unittest.TestCase): + def setUp(self) -> None: + self.project = 'apache-beam-testing' + + @parameterized.expand([ + ("", "", [], None, None, 1, 2), + ("table", "", ["id"], condition_value_fn, None, 2, 10), + ("table", "id='{}'", ["id"], condition_value_fn, None, 2, 10), + ("table", "id='{}'", ["id"], None, query_fn, 2, 10), + ("", "", None, None, query_fn, 2, 10), + ]) + def test_valid_params( + self, + table_name, + row_restriction_template, + fields, + condition_value_fn, + query_fn, + min_batch_size, + max_batch_size): + """ + TC 1: Only batch size are provided. It should raise an error. + TC 2: Either of `row_restriction template` or `query_fn` is not provided. + TC 3: Both `fields` and `condition_value_fn` are provided. + TC 4: Query construction details are provided along with `query_fn`. + TC 5: Batch size is provided with `query_fn`. + """ + with self.assertRaises(ValueError): + _ = BigQueryEnrichmentHandler( + project=self.project, + table_name=table_name, + row_restriction_template=row_restriction_template, + fields=fields, + condition_value_fn=condition_value_fn, + query_fn=query_fn, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index c0955aa183bac..5929326f75f9b 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -199,6 +199,7 @@ ignore_identifiers = [ 'apache_beam.transforms.ptransform.PTransformWithSideInputs', 'apache_beam.transforms.trigger._ParallelTriggerFn', 'apache_beam.transforms.trigger.InMemoryUnmergedState', + 'apache_beam.transforms.utils.BatchElements', 'apache_beam.typehints.typehints.AnyTypeConstraint', 'apache_beam.typehints.typehints.CompositeTypeHint', 'apache_beam.typehints.typehints.TypeConstraint', From a1ae5d169944aa948206426584009a50fcf33122 Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Wed, 15 May 2024 11:01:19 -0400 Subject: [PATCH 3/8] more integration tests, move BE to RRIO --- sdks/python/apache_beam/io/requestresponse.py | 137 ++++++++++-------- .../apache_beam/transforms/enrichment.py | 35 +---- .../enrichment_handlers/bigquery.py | 50 +++++-- .../enrichment_handlers/bigquery_it_test.py | 127 ++++++++++++++++ 4 files changed, 254 insertions(+), 95 deletions(-) diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index a4289190b6afe..d7011e5a8ff35 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -29,6 +29,7 @@ from typing import Dict from typing import Generic from typing import List +from typing import Mapping from typing import Optional from typing import Tuple from typing import TypeVar @@ -43,6 +44,7 @@ from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler from apache_beam.metrics import Metrics from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC +from 
apache_beam.transforms.util import BatchElements
 from apache_beam.utils import retry

 RequestT = TypeVar('RequestT')
@@ -144,6 +146,10 @@ def get_cache_key(self, request: RequestT) -> str:
     """
     return ""

+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    """Returns kwargs suitable for `beam.BatchElements`."""
+    return {}
+

 class ShouldBackOff(abc.ABC):
   """
@@ -477,70 +483,71 @@ def __init__(
   def __enter__(self):
     self.client = redis.Redis(self.host, self.port, **self.kwargs)

+  def _read_cache(self, element):
+    cache_request = self.source_caller.get_cache_key(element)
+    # check if the caller is an enrichment handler. EnrichmentHandler
+    # provides the request format for cache.
+    if cache_request:
+      encoded_request = self.request_coder.encode(cache_request)
+    else:
+      encoded_request = self.request_coder.encode(element)
+
+    encoded_response = self.client.get(encoded_request)
+    if not encoded_response:
+      # no cache entry present for this request.
+      return element, None
+
+    if self.response_coder is None:
+      try:
+        response_dict = json.loads(encoded_response.decode('utf-8'))
+        response = beam.Row(**response_dict)
+      except Exception:
+        _LOGGER.warning(
+            'cannot decode response from redis cache for %s.' % element)
+        return element, None
+    else:
+      response = self.response_coder.decode(encoded_response)
+    return element, response
+
+  def _write_cache(self, element):
+    cache_request = self.source_caller.get_cache_key(element[0])
+    if cache_request:
+      encoded_request = self.request_coder.encode(cache_request)
+    else:
+      encoded_request = self.request_coder.encode(element[0])
+    if self.response_coder is None:
+      try:
+        encoded_response = json.dumps(element[1]._asdict()).encode('utf-8')
+      except Exception:
+        _LOGGER.warning(
+            'cannot encode response %s for %s to store in '
+            'redis cache.' % (element[1], element[0]))
+        return element
+    else:
+      encoded_response = self.response_coder.encode(element[1])
+    # Write to cache with TTL. Set nx to True to prevent overwriting for the
+    # same key.
+    self.client.set(
+        encoded_request, encoded_response, self.time_to_live, nx=True)
+    return element
+
   def __call__(self, element, *args, **kwargs):
-    are_batched_requests = isinstance(element, List)
     if self.mode == _RedisMode.READ:
-      return self._read_cache(element, are_batched_requests)
+      if isinstance(element, List):
+        responses = [self._read_cache(e) for e in element]
+        return responses
+      else:
+        return self._read_cache(element)
     else:
-      return self._write_cache(element, are_batched_requests)
+      if isinstance(element, List):
+        responses = [self._write_cache(e) for e in element]
+        return responses
+      else:
+        return self._write_cache(element)

   def __exit__(self, exc_type, exc_val, exc_tb):
     self.client.close()

-  def _read_cache(self, request, are_batched_requests: bool):
-    requests = [request] if not are_batched_requests else request
-    responses = []
-    for req in requests:
-      cache_key = self.source_caller.get_cache_key(req)
-      if cache_key:
-        encoded_request = self.request_coder.encode(cache_key)
-      else:
-        encoded_request = self.request_coder.encode(req)
-
-      encoded_response = self.client.get(encoded_request)
-      if not encoded_response:
-        # no cache entry present for this request.
-        responses.append((req, None))
-        continue
-
-      if self.response_coder is None:
-        try:
-          response_dict = json.loads(encoded_response.decode('utf-8'))
-          response = beam.Row(**response_dict)
-          responses.append((req, response))
-        except Exception:
-          _LOGGER.warning(
-              'cannot decode response from redis cache for %s.'
% req) - responses.append((req, None)) - else: - response = self.response_coder.decode(encoded_response) - responses.append((req, response)) - return responses if are_batched_requests else responses[0] - - def _write_cache(self, request, are_batched_requests): - requests = [request] if not are_batched_requests else request - for req in requests: - cache_request = self.source_caller.get_cache_key(req[0]) - if cache_request: - encoded_request = self.request_coder.encode(cache_request) - else: - encoded_request = self.request_coder.encode(req[0]) - if self.response_coder is None: - try: - encoded_response = json.dumps(req[1]._asdict()).encode('utf-8') - except Exception: - _LOGGER.warning( - 'cannot encode response %s for %s to store in ' - 'redis cache.' % (req[1], req[0])) - continue - else: - encoded_response = self.response_coder.encode(req[1]) - # Write to cache with TTL. Set nx to True to prevent overwriting for the - # same key. - self.client.set( - encoded_request, encoded_response, self.time_to_live, nx=True) - return request - class _ReadFromRedis(beam.PTransform[beam.PCollection[RequestT], beam.PCollection[ResponseT]]): @@ -722,6 +729,13 @@ def request_coder(self, request_coder: coders.Coder): self._request_coder = request_coder +class FlattenBatch(beam.DoFn): + """Flatten a batched PCollection.""" + def process(self, elements, *args, **kwargs): + for element in elements: + yield element + + class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], beam.PCollection[ResponseT]]): """A :class:`RequestResponseIO` transform to read and write to APIs. @@ -767,6 +781,7 @@ def __init__( self._repeater = NoOpsRepeater() self._cache = cache self._throttler = throttler + self._batching_kwargs = self._caller.batch_elements_kwargs() def expand( self, @@ -788,6 +803,10 @@ def expand( ).with_outputs( 'cache_misses', main='cached_responses')) + # Batch elements if batching is enabled. + if self._batching_kwargs: + inputs = inputs | BatchElements(**self._batching_kwargs) + if isinstance(self._throttler, DefaultThrottler): # DefaultThrottler applies throttling in the DoFn of # Call PTransform. @@ -810,6 +829,10 @@ def expand( should_backoff=self._should_backoff, repeater=self._repeater)) + # if batching is enabled then handle accordingly. + if self._batching_kwargs: + responses = responses | "FlattenBatch" >> beam.ParDo(FlattenBatch()) + if self._cache: # write to cache. _ = responses | self._cache.get_write() diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index b85b43d69ce5c..5bb1e2024e793 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -19,13 +19,11 @@ from typing import Any from typing import Callable from typing import Dict -from typing import Mapping from typing import Optional from typing import TypeVar from typing import Union import apache_beam as beam -from apache_beam import BatchElements from apache_beam.coders import coders from apache_beam.io.requestresponse import DEFAULT_CACHE_ENTRY_TTL_SEC from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS @@ -84,12 +82,6 @@ def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: return beam.Row(**left) -class FlattenBatch(beam.DoFn): - def process(self, elements, *args, **kwargs): - for element in elements: - yield element - - class EnrichmentSourceHandler(Caller[InputT, OutputT]): """Wrapper class for `apache_beam.io.requestresponse.Caller`. 
@@ -108,10 +100,6 @@ def get_cache_key(self, request: InputT) -> str: """ return "request: %s" % request - def batch_elements_kwargs(self) -> Mapping[str, Any]: - """Returns a kwargs suitable for `beam.BatchElements`.""" - return {} - class Enrichment(beam.PTransform[beam.PCollection[InputT], beam.PCollection[OutputT]]): @@ -155,7 +143,6 @@ def __init__( self._timeout = timeout self._repeater = repeater self._throttler = throttler - self._batching_kwargs = self._source_handler.batch_elements_kwargs() def expand(self, input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]: @@ -166,22 +153,16 @@ def expand(self, if self._cache: self._cache.request_coder = request_coder - input_data = input_row - if self._batching_kwargs: - input_data = input_row | BatchElements(**self._batching_kwargs) - - fetched_data = input_data | RequestResponseIO( - caller=self._source_handler, - timeout=self._timeout, - repeater=self._repeater, - cache=self._cache, - throttler=self._throttler) + fetched_data = ( + input_row + | "Enrichment-RRIO" >> RequestResponseIO( + caller=self._source_handler, + timeout=self._timeout, + repeater=self._repeater, + cache=self._cache, + throttler=self._throttler)) # EnrichmentSourceHandler returns a tuple of (request,response). - # if batching is enabled then handle accordingly - if self._batching_kwargs: - fetched_data = fetched_data | "Flatten" >> beam.ParDo(FlattenBatch()) - return ( fetched_data | "enrichment_join" >> diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 7e34aefcc78ee..1e23c4de419dd 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -22,6 +22,7 @@ from typing import Optional from typing import Union +from google.api_core.exceptions import BadRequest from google.cloud import bigquery import apache_beam as beam @@ -53,7 +54,8 @@ def _validate_bigquery_metadata( "Please provide either `query_fn` or the parameters " "`table_name`, `row_restriction_template`, and" "`fields/condition_value_fn` together.") - if (fields and condition_value_fn) or (not fields and not condition_value_fn): + if not query_fn and ((fields and condition_value_fn) or + (not fields and not condition_value_fn)): raise ValueError( "Please provide exactly one of `fields` or " "`condition_value_fn`") @@ -163,11 +165,15 @@ def _execute_query(self, query: str): return [dict(row.items()) for row in results] else: return [dict(row.items()) for row in results][0] - except RuntimeError: - raise RuntimeError("Could not complete the query request: %s" % query) + except BadRequest as e: + raise BadRequest( + f'Could not execute the query: {query}. Please check if ' + f'the query is properly formatted and the BigQuery ' + f'table exists. {e}') + except RuntimeError as e: + raise RuntimeError(f"Could not complete the query request: {query}. 
{e}") def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): - # raise ValueError(type(request)) if isinstance(request, List): values = [] responses = [] @@ -181,9 +187,14 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): self.row_restriction_template, batched_condition_template) for req in request: request_dict = req._asdict() - current_values = ( - self.condition_value_fn(req) if self.condition_value_fn else - [request_dict.get(field) for field in self.fields]) + try: + current_values = ( + self.condition_value_fn(req) if self.condition_value_fn else + [request_dict[field] for field in self.fields]) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `fields` are the " + "keys in the input `beam.Row`." + str(e)) values.extend(current_values) requests_map.update((val, req) for val in current_values) query = raw_query.format(*values) @@ -213,17 +224,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]): - key = ";".join(["%s"] * len(self.fields)) if isinstance(request, List): cache_keys = [] for req in request: req_dict = req._asdict() - key = ";".join(["%s"] * len(self.fields)) - cache_keys.extend([key % req_dict[field] for field in self.fields]) + current_values = ( + self.condition_value_fn(req) if self.condition_value_fn else + [req_dict[field] for field in self.fields]) + key = ";".join(["%s"] * len(current_values)) + try: + cache_keys.extend([key % tuple(current_values)]) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `fields` are the " + "keys in the input `beam.Row`." + str(e)) return cache_keys else: req_dict = request._asdict() - return key % (req_dict.get(field) for field in self.fields) + try: + current_values = ( + self.condition_value_fn(request) if self.condition_value_fn else + [req_dict[field] for field in self.fields]) + key = ";".join(["%s"] * len(current_values)) + cache_key = key % tuple(current_values) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `fields` are the " + "keys in the input `beam.Row`." + str(e)) + return cache_key def batch_elements_kwargs(self) -> Mapping[str, Any]: """Returns a kwargs suitable for `beam.BatchElements`.""" diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py index 1a7325d4d22ef..fd8688bf9871b 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py @@ -14,13 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import logging import unittest +from unittest.mock import MagicMock + +import pytest import apache_beam as beam +from apache_beam.coders import coders from apache_beam.testing.test_pipeline import TestPipeline # pylint: disable=ungrouped-imports try: + from google.api_core.exceptions import BadRequest + from testcontainers.redis import RedisContainer from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigquery import \ BigQueryEnrichmentHandler @@ -30,6 +37,8 @@ raise unittest.SkipTest( 'Google Cloud BigQuery dependencies are not installed.') +_LOGGER = logging.getLogger(__name__) + def query_fn(row: beam.Row): query = ( @@ -43,6 +52,7 @@ def condition_value_fn(row: beam.Row): return [row.id] # type: ignore[attr-defined] +@pytest.mark.uses_testcontainer class TestBigQueryEnrichmentIT(unittest.TestCase): def setUp(self) -> None: self.project = 'google.com:clouddfe' @@ -52,6 +62,28 @@ def setUp(self) -> None: " WHERE id = '{}'") self.condition_template = "id = '{}'" self.table_name = "`google.com:clouddfe.my_ecommerce.product_details`" + self.retries = 3 + self._start_container() + + def _start_container(self): + for i in range(self.retries): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error( + 'Unable to start redis container for BigQuery ' + ' enrichment tests.') + raise e + + def tearDown(self) -> None: + self.container.stop() + self.client = None def test_bigquery_enrichment(self): expected_fields = [ @@ -161,6 +193,101 @@ def test_bigquery_enrichment_with_condition_without_batch(self): | Enrichment(handler) | beam.ParDo(ValidateResponse(expected_fields))) + def test_bigquery_enrichment_bad_request(self): + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + column_names=['wrong_column'], + condition_value_fn=condition_value_fn, + ) + with self.assertRaises(BadRequest): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(requests) + | "Enrichment" >> Enrichment(handler)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_bigquery_enrichment_with_redis(self): + """ + In this test, we run two pipelines back to back. + + In the first pipeline, we run a simple BigQuery enrichment pipeline + with zero cache records. Therefore, it makes call to the source + and ultimately writes to the cache with a TTL of 300 seconds. + + For the second pipeline, we mock the + `BigQueryEnrichmentHandler`'s `__call__` method to always + return a `None` response. However, this change won't impact the second + pipeline because the Enrichment transform first checks the cache to fulfill + requests. Since all requests are cached, it will return from there without + making calls to the BigQuery service. 
+ """ + expected_fields = [ + 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' + ] + requests = [ + beam.Row( + id='13842', + name='low profile dyed cotton twill cap - navy w39s55d', + quantity=2), + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + ] + handler = BigQueryEnrichmentHandler( + project=self.project, + row_restriction_template=self.condition_template, + table_name=self.table_name, + condition_value_fn=condition_value_fn, + min_batch_size=2, + max_batch_size=100, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port) + | beam.ParDo(ValidateResponse(expected_fields))) + + # manually check cache entry + c = coders.StrUtf8Coder() + for req in requests: + key = handler.get_cache_key(req) + response = self.client.get(c.encode(key)) + if not response: + raise ValueError("No cache entry found for %s" % key) + + actual = BigQueryEnrichmentHandler.__call__ + BigQueryEnrichmentHandler.__call__ = MagicMock( + return_value=( + beam.Row( + id='15816', + name='low profile dyed cotton twill cap - putty w39s55d', + quantity=1), + beam.Row())) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port) + | beam.ParDo(ValidateResponse(expected_fields))) + BigQueryEnrichmentHandler.__call__ = actual + if __name__ == '__main__': unittest.main() From 6d4507a08dde7e091cac4da37da59ce430f44f60 Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Wed, 15 May 2024 11:58:43 -0400 Subject: [PATCH 4/8] trigger postcommit --- .github/trigger_files/beam_PostCommit_Python.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index c4edaa85a89d1..63bd5651def01 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,3 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run" } + From 9cf94d6f152889056bbc5f298a092c56827dab64 Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Wed, 15 May 2024 12:29:36 -0400 Subject: [PATCH 5/8] fix doc, project name --- .../enrichment_handlers/bigquery.py | 21 ++++++++++++------- .../enrichment_handlers/bigquery_it_test.py | 10 +++------ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 1e23c4de419dd..8ce12b7a4899d 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -68,6 +68,15 @@ class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` transform. + To use this handler you need either of the following combinations: + * `table_name`, `row_restriction_template`, `fields` + * `table_name`, `row_restriction_template`, `condition_value_fn` + * `query_fn` + + By default, the handler pulls all columns from the BigQuery table. + To override this, use the `column_name` parameter to specify a list of column + names to fetch. + This handler pulls data from BigQuery per element by default. 
   behavior, set the `min_batch_size` and `max_batch_size` parameters.
   These min and max values for batch size are sent to the
@@ -103,20 +112,18 @@ def __init__(
       table_name (str): Fully qualified BigQuery table name
         in the format `project.dataset.table`.
       row_restriction_template (str): A template string for the `WHERE` clause
-        in the BigQuery query with placeholders to dynamically filter rows
-        based on input data.
+        in the BigQuery query with placeholders (`{}`) to dynamically filter
+        rows based on input data.
       fields: (Optional[List[str]]) List of field names present in the input
         `beam.Row`. These are used to construct the WHERE clause
         (if `condition_value_fn` is not provided).
       column_names: (Optional[List[str]]) Names of columns to select from the
         BigQuery table. If not provided, all columns (`*`) are selected.
-      condition_value_fn: (Optional[Callable[[beam.Row], str]]) A function
-        that takes a `beam.Row` and returns a string value to be used in the
-        `WHERE` clause of the BigQuery query.
+      condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function
+        that takes a `beam.Row` and returns a list of values to populate the
+        placeholder `{}` of the `WHERE` clause in the query.
       query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a
         `beam.Row` and returns a complete BigQuery SQL query string.
-        If provided, it overrides the default query construction that use
-        `fields`, `column_names`, and `condition_value_fn`.
       min_batch_size: (Optional[int]) Minimum number of rows to batch together
         when querying BigQuery.
       max_batch_size: (Optional[int]) Maximum number of rows to batch together.
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
index fd8688bf9871b..0b8a384b934da 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
@@ -43,7 +43,7 @@ def query_fn(row: beam.Row):
   query = (
       "SELECT * FROM "
-      "`google.com:clouddfe.my_ecommerce.product_details`"
+      "`apache-beam-testing.my_ecommerce.product_details`"
       " WHERE id = '{}'".format(row.id))  # type: ignore[attr-defined]
   return query

@@ -55,13 +55,9 @@ def condition_value_fn(row: beam.Row):
 class TestBigQueryEnrichmentIT(unittest.TestCase):
   def setUp(self) -> None:
-    self.project = 'google.com:clouddfe'
-    self.query_template = (
-        "SELECT * FROM "
-        "`google.com:clouddfe.my_ecommerce.product_details`"
-        " WHERE id = '{}'")
+    self.project = 'apache-beam-testing'
     self.condition_template = "id = '{}'"
-    self.table_name = "`google.com:clouddfe.my_ecommerce.product_details`"
+    self.table_name = "`apache-beam-testing.my_ecommerce.product_details`"
     self.retries = 3
     self._start_container()

From 31cbd2681c83872cd0131483256dc098ed5060bd Mon Sep 17 00:00:00 2001
From: riteshghorse
Date: Fri, 17 May 2024 09:56:39 -0400
Subject: [PATCH 6/8] refactor, improve doc comments

---
 CHANGES.md                                    |  1 +
 .../enrichment_handlers/bigquery.py           | 59 +++++++++----------
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index bce9636237e36..dd66e9192b699 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -71,6 +71,7 @@
 * Beam YAML now supports the jinja templating syntax. Template variables can
   be passed with the (json-formatted) `--jinja_variables` flag.
 * DataFrame API now supports pandas 2.1.x and adds 12 more string functions
   for Series.([#31185](https://github.com/apache/beam/pull/31185)).
+* Added BigQuery handler for enrichment transform (Python) ([#31295](https://github.com/apache/beam/pull/31295)).

 ## Breaking Changes

diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
index 8ce12b7a4899d..c205a01bed3c3 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
@@ -42,23 +42,23 @@ def _validate_batch_query_fn(query_fn, min_batch_size, max_batch_size):

 def _validate_bigquery_metadata(
     table_name, row_restriction_template, fields, condition_value_fn, query_fn):
-  if query_fn and bool(table_name or row_restriction_template or fields or
-                       condition_value_fn):
-    raise ValueError(
-        "Please provide either `query_fn` or the parameters "
-        "`table_name`, `row_restriction_template`, and `fields` "
-        "together.")
-  elif not query_fn and not (table_name and row_restriction_template and
-                             (fields or condition_value_fn)):
-    raise ValueError(
-        "Please provide either `query_fn` or the parameters "
-        "`table_name`, `row_restriction_template`, and"
-        "`fields/condition_value_fn` together.")
-  if not query_fn and ((fields and condition_value_fn) or
-                       (not fields and not condition_value_fn)):
-    raise ValueError(
-        "Please provide exactly one of `fields` or "
-        "`condition_value_fn`")
+  if query_fn:
+    if bool(table_name or row_restriction_template or fields or
+            condition_value_fn):
+      raise ValueError(
+          "Please provide either `query_fn` or the parameters `table_name`, "
+          "`row_restriction_template`, and `fields/condition_value_fn` "
+          "together.")
+  else:
+    if not (table_name and row_restriction_template):
+      raise ValueError(
+          "Please provide either `query_fn` or both `table_name` and "
+          "`row_restriction_template`.")
+    if ((fields and condition_value_fn) or
+        (not fields and not condition_value_fn)):
+      raise ValueError(
+          "Please provide exactly one of `fields` or "
+          "`condition_value_fn`")


 class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]],
                                                         Union[Row,
                                                               List[Row]]]):
@@ -94,8 +94,8 @@ def __init__(
       column_names: Optional[List[str]] = None,
       condition_value_fn: Optional[ConditionValueFn] = None,
       query_fn: Optional[QueryFn] = None,
-      min_batch_size: Optional[int] = None,
-      max_batch_size: Optional[int] = None,
+      min_batch_size: int = 1,
+      max_batch_size: int = 10000,
       **kwargs,
   ):
@@ -124,17 +124,17 @@ def __init__(
         placeholder `{}` of the `WHERE` clause in the query.
       query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a
         `beam.Row` and returns a complete BigQuery SQL query string.
-      min_batch_size: (Optional[int]) Minimum number of rows to batch together
-        when querying BigQuery.
-      max_batch_size: (Optional[int]) Maximum number of rows to batch together.
+      min_batch_size (int): Minimum number of rows to batch together when
+        querying BigQuery. Defaults to 1 if `query_fn` is not specified.
+      max_batch_size (int): Maximum number of rows to batch together.
+        Defaults to 10,000 if `query_fn` is not specified.
       **kwargs: Additional keyword arguments to pass to `bigquery.Client`.

     Note:
-      * `min_batch_size` and `max_batch_size` won't have any effect if the
+      * `min_batch_size` and `max_batch_size` cannot be defined if the
        `query_fn` is provided.
       * Either `fields` or `condition_value_fn` must be provided for query
         construction if `query_fn` is not provided.
-      * If `query_fn` is provided, it overrides the default query
-        construction.
       * Ensure appropriate permissions are granted for BigQuery access.
     """
     _validate_bigquery_metadata(
@@ -157,9 +157,8 @@ def __init__(
         (self.select_fields, self.table_name, self.row_restriction_template))
     self.kwargs = kwargs
     self._batching_kwargs = {}
-    if min_batch_size is not None:
+    if not query_fn:
       self._batching_kwargs['min_batch_size'] = min_batch_size
-    if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
 
   def __enter__(self):
@@ -235,11 +234,11 @@ def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]):
       cache_keys = []
       for req in request:
         req_dict = req._asdict()
-        current_values = (
-            self.condition_value_fn(req) if self.condition_value_fn else
-            [req_dict[field] for field in self.fields])
-        key = ";".join(["%s"] * len(current_values))
         try:
+          current_values = (
+              self.condition_value_fn(req) if self.condition_value_fn else
+              [req_dict[field] for field in self.fields])
+          key = ";".join(["%s"] * len(current_values))
           cache_keys.extend([key % tuple(current_values)])
         except KeyError as e:
           raise KeyError(

From 33fc0be18f8556439e356d18ddc5f1d15a4d8dea Mon Sep 17 00:00:00 2001
From: riteshghorse
Date: Fri, 17 May 2024 11:01:02 -0400
Subject: [PATCH 7/8] refactor

---
 .../transforms/enrichment_handlers/bigquery.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
index c205a01bed3c3..f2ed9b566e740 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
@@ -94,8 +94,8 @@ def __init__(
       column_names: Optional[List[str]] = None,
       condition_value_fn: Optional[ConditionValueFn] = None,
       query_fn: Optional[QueryFn] = None,
-      min_batch_size: int = 1,
-      max_batch_size: int = 10000,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
       **kwargs,
   ):
     """
@@ -158,8 +158,10 @@ def __init__(
     self.kwargs = kwargs
     self._batching_kwargs = {}
     if not query_fn:
-      self._batching_kwargs['min_batch_size'] = min_batch_size
-      self._batching_kwargs['max_batch_size'] = max_batch_size
+      self._batching_kwargs['min_batch_size'] = (
+          min_batch_size if min_batch_size else 1)
+      self._batching_kwargs['max_batch_size'] = (
+          max_batch_size if max_batch_size else 10000)
 
   def __enter__(self):
     self.client = bigquery.Client(project=self.project, **self.kwargs)

From 8e19573b4e1c73269dba43809a0945066ad86306 Mon Sep 17 00:00:00 2001
From: riteshghorse
Date: Fri, 17 May 2024 11:55:52 -0400
Subject: [PATCH 8/8] remove query_fn check

---
 .../transforms/enrichment_handlers/bigquery.py | 18 ++++--------------
 .../enrichment_handlers/bigquery_test.py       |  2 --
 2 files changed, 4 insertions(+), 16 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
index f2ed9b566e740..382ae123a81d6 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
@@ -33,13 +33,6 @@
 ConditionValueFn = Callable[[beam.Row], List[Any]]
 
 
-def _validate_batch_query_fn(query_fn, min_batch_size, max_batch_size):
-  if query_fn and min_batch_size and max_batch_size:
-    raise ValueError(
-        "Please provide exactly one of `query_fn` or "
-        "(`min_batch_size` and `max_batch_size`)")
-
-
 def _validate_bigquery_metadata(
     table_name, row_restriction_template, fields, condition_value_fn,
     query_fn):
   if query_fn:
@@ -94,8 +87,8 @@ def __init__(
       column_names: Optional[List[str]] = None,
       condition_value_fn: Optional[ConditionValueFn] = None,
       query_fn: Optional[QueryFn] = None,
-      min_batch_size: Optional[int] = None,
-      max_batch_size: Optional[int] = None,
+      min_batch_size: int = 1,
+      max_batch_size: int = 10000,
       **kwargs,
   ):
     """
@@ -143,7 +136,6 @@ def __init__(
         fields,
         condition_value_fn,
         query_fn)
-    _validate_batch_query_fn(query_fn, min_batch_size, max_batch_size)
     self.project = project
     self.column_names = column_names
     self.select_fields = ",".join(column_names) if column_names else '*'
@@ -158,10 +150,8 @@ def __init__(
     self.kwargs = kwargs
     self._batching_kwargs = {}
     if not query_fn:
-      self._batching_kwargs['min_batch_size'] = (
-          min_batch_size if min_batch_size else 1)
-      self._batching_kwargs['max_batch_size'] = (
-          max_batch_size if max_batch_size else 10000)
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+      self._batching_kwargs['max_batch_size'] = max_batch_size
 
   def __enter__(self):
     self.client = bigquery.Client(project=self.project, **self.kwargs)
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py
index 7a57ad6629944..98ac6244910c9 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_test.py
@@ -37,7 +37,6 @@ def setUp(self) -> None:
       ("table", "", ["id"], condition_value_fn, None, 2, 10),
       ("table", "id='{}'", ["id"], condition_value_fn, None, 2, 10),
       ("table", "id='{}'", ["id"], None, query_fn, 2, 10),
-      ("", "", None, None, query_fn, 2, 10),
   ])
   def test_valid_params(
       self,
@@ -53,7 +52,6 @@ def test_valid_params(
     TC 2: Either `row_restriction_template` or `query_fn` is not provided.
     TC 3: Both `fields` and `condition_value_fn` are provided.
     TC 4: Query construction details are provided along with `query_fn`.
-    TC 5: Batch size is provided with `query_fn`.
     """
     with self.assertRaises(ValueError):
      _ = BigQueryEnrichmentHandler(
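
For reference, a minimal usage sketch of the handler this series adds, in its
two configuration modes. It is illustrative only: the project, dataset, table,
and field names below are hypothetical placeholders, not values from the patch.

    import apache_beam as beam
    from apache_beam.transforms.enrichment import Enrichment
    from apache_beam.transforms.enrichment_handlers.bigquery import (
        BigQueryEnrichmentHandler)

    # Batched mode: BatchElements groups input rows (1..10000 rows per batch
    # by default, per the final state of this series) and the handler fills
    # the `{}` placeholder in `row_restriction_template` with each row's
    # `fields` values.
    handler = BigQueryEnrichmentHandler(
        project='my-project',  # hypothetical project id
        table_name='my-project.my_dataset.products',  # hypothetical table
        row_restriction_template="id = '{}'",
        fields=['id'],
        min_batch_size=2,
        max_batch_size=100)

    # query_fn mode: the handler runs one caller-supplied query per row, and
    # the batching kwargs are not set in this mode.
    def my_query_fn(row: beam.Row) -> str:
      return (
          "SELECT * FROM `my-project.my_dataset.products` "
          "WHERE id = '{}'".format(row.id))

    per_row_handler = BigQueryEnrichmentHandler(
        project='my-project', query_fn=my_query_fn)

    with beam.Pipeline() as p:
      _ = (
          p
          | beam.Create([beam.Row(id='1'), beam.Row(id='2')])
          | Enrichment(handler))

As `_validate_bigquery_metadata` enforces above, `query_fn` is mutually
exclusive with `table_name`, `row_restriction_template`, and
`fields`/`condition_value_fn`, so a handler is configured in exactly one of
these two modes.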