From e3fee5156b3515f96dc5ba90ea2fbc6f6be2bd34 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Thu, 28 Mar 2024 17:25:00 -0700
Subject: [PATCH] [YAML] Add Partition transform. (#30368)

---
 .../yaml/programming_guide_test.py            |  19 ++
 sdks/python/apache_beam/yaml/readme_test.py   |  30 ++-
 sdks/python/apache_beam/yaml/yaml_mapping.py  |  88 ++++++-
 .../apache_beam/yaml/yaml_mapping_test.py     | 239 ++++++++++++++++++
 .../python/apache_beam/yaml/yaml_transform.py |   6 +-
 .../en/documentation/programming-guide.md     |  16 ++
 .../content/en/documentation/sdks/yaml-udf.md |  68 +++++
 7 files changed, 451 insertions(+), 15 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/programming_guide_test.py b/sdks/python/apache_beam/yaml/programming_guide_test.py
index cd7bf6a88149..fe5e242f7f5b 100644
--- a/sdks/python/apache_beam/yaml/programming_guide_test.py
+++ b/sdks/python/apache_beam/yaml/programming_guide_test.py
@@ -404,6 +404,25 @@ def extract_timestamp(x):
           # [END setting_timestamp]
           ''')
 
+  def test_partition(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(percentile=1),
+          beam.Row(percentile=20),
+          beam.Row(percentile=90),
+      ])
+      _ = elements | YamlTransform(
+          '''
+          # [START model_multiple_pcollections_partition]
+          type: Partition
+          config:
+            by: str(percentile // 10)
+            language: python
+            outputs: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
+          # [END model_multiple_pcollections_partition]
+          ''')
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/yaml/readme_test.py b/sdks/python/apache_beam/yaml/readme_test.py
index 4ca60e6176bd..ea7a015dab5d 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -128,12 +128,25 @@ def expand(self, pcoll):
         lambda _: 1, sum, 'count')
 
 
+class _Fakes:
+  fn = str
+
+  class SomeTransform(beam.PTransform):
+    def __init__(*args, **kwargs):
+      pass
+
+    def expand(self, pcoll):
+      return pcoll
+
+
 RENDER_DIR = None
 TEST_TRANSFORMS = {
     'Sql': FakeSql,
     'ReadFromPubSub': FakeReadFromPubSub,
     'WriteToPubSub': FakeWriteToPubSub,
     'SomeGroupingTransform': FakeAggregation,
+    'SomeTransform': _Fakes.SomeTransform,
+    'AnotherTransform': _Fakes.SomeTransform,
 }
 
@@ -155,7 +168,7 @@ def input_file(self, name, content):
     return path
 
   def input_csv(self):
-    return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')
+    return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')
 
   def input_tsv(self):
     return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
@@ -250,13 +263,15 @@ def parse_test_methods(markdown_lines):
       else:
         if code_lines:
           if code_lines[0].startswith('- type:'):
+            is_chain = not any('input:' in line for line in code_lines)
             # Treat this as a fragment of a larger pipeline.
             # pylint: disable=not-an-iterable
             code_lines = [
                 'pipeline:',
-                '  type: chain',
+                '  type: chain' if is_chain else '',
                 '  transforms:',
                 '  - type: ReadFromCsv',
+                '    name: input',
                 '    config:',
                 '      path: whatever',
             ] + ['  ' + line for line in code_lines]
@@ -278,17 +293,6 @@ def createTestSuite(name, path):
     return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
 
 
-class _Fakes:
-  fn = str
-
-  class SomeTransform(beam.PTransform):
-    def __init__(*args, **kwargs):
-      pass
-
-    def expand(self, pcoll):
-      return pcoll
-
-
 # These are copied from $ROOT/website/www/site/content/en/documentation/sdks
 # at build time.
 YAML_DOCS_DIR = os.path.join(os.path.join(os.path.dirname(__file__), 'docs'))
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py
index 954e32cdf7b1..4839728dd886 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -24,6 +24,7 @@
 from typing import Callable
 from typing import Collection
 from typing import Dict
+from typing import List
 from typing import Mapping
 from typing import NamedTuple
 from typing import Optional
@@ -42,6 +43,7 @@
 from apache_beam.typehints import row_type
 from apache_beam.typehints import schemas
 from apache_beam.typehints import trivial_inference
+from apache_beam.typehints import typehints
 from apache_beam.typehints.row_type import RowTypeConstraint
 from apache_beam.typehints.schemas import named_fields_from_element_type
 from apache_beam.utils import python_callable
@@ -569,6 +571,86 @@ def extract_expr(name, v):
     return pcoll | sql_transform_constructor(query)
 
 
+@beam.ptransform.ptransform_fn
+def _Partition(
+    pcoll,
+    by: Union[str, Dict[str, str]],
+    outputs: List[str],
+    unknown_output: Optional[str] = None,
+    error_handling: Optional[Mapping[str, Any]] = None,
+    language: Optional[str] = 'generic'):
+  """Splits an input into several distinct outputs.
+
+  Each input element will go to a distinct output based on the field or
+  function given in the `by` configuration parameter.
+
+  Args:
+    by: A field, callable, or expression giving the destination output for
+      this element. Should return a string that is a member of the `outputs`
+      parameter. If `unknown_output` is also set, other return values are
+      accepted as well; otherwise an error will be raised.
+    outputs: The set of outputs into which this input is being partitioned.
+    unknown_output: (Optional) If set, indicates a destination output for any
+      elements that are not assigned an output listed in the `outputs`
+      parameter.
+    error_handling: (Optional) Whether and how to handle errors during
+      partitioning.
+    language: (Optional) The language of the `by` expression.
+ """ + split_fn = _as_callable_for_pcoll(pcoll, by, 'by', language) + try: + split_fn_output_type = trivial_inference.infer_return_type( + split_fn, [pcoll.element_type]) + except (TypeError, ValueError): + pass + else: + if not typehints.is_consistent_with(split_fn_output_type, + typehints.Optional[str]): + raise ValueError( + f'Partition function "{by}" must return a string type ' + f'not {split_fn_output_type}') + error_output = error_handling['output'] if error_handling else None + if error_output in outputs: + raise ValueError( + f'Error handling output "{error_output}" ' + f'cannot be among the listed outputs {outputs}') + T = TypeVar('T') + + def split(element): + tag = split_fn(element) + if tag is None: + tag = unknown_output + if not isinstance(tag, str): + raise ValueError( + f'Returned output name "{tag}" of type {type(tag)} ' + f'from "{by}" must be a string.') + if tag not in outputs: + if unknown_output: + tag = unknown_output + else: + raise ValueError(f'Unknown output name "{tag}" from {by}') + return beam.pvalue.TaggedOutput(tag, element) + + output_set = set(outputs) + if unknown_output: + output_set.add(unknown_output) + if error_output: + output_set.add(error_output) + mapping_transform = beam.Map(split) + if error_output: + mapping_transform = mapping_transform.with_exception_handling( + **exception_handling_args(error_handling)) + else: + mapping_transform = mapping_transform.with_outputs(*output_set) + splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T) + result = {out: getattr(splits, out) for out in output_set} + if error_output: + result[ + error_output] = result[error_output] | _map_errors_to_standard_format( + pcoll.element_type) + return result + + @beam.ptransform.ptransform_fn @maybe_with_exception_handling_transform_fn def _AssignTimestamps( @@ -588,7 +670,8 @@ def _AssignTimestamps( Args: timestamp: A field, callable, or expression giving the new timestamp. language: The language of the timestamp expression. - error_handling: Whether and how to handle errors during iteration. + error_handling: Whether and how to handle errors during timestamp + evaluation. 
""" timestamp_fn = _as_callable_for_pcoll(pcoll, timestamp, 'timestamp', language) T = TypeVar('T') @@ -611,6 +694,9 @@ def create_mapping_providers(): 'MapToFields-python': _PyJsMapToFields, 'MapToFields-javascript': _PyJsMapToFields, 'MapToFields-generic': _PyJsMapToFields, + 'Partition-python': _Partition, + 'Partition-javascript': _Partition, + 'Partition-generic': _Partition, }), yaml_provider.SqlBackedProvider({ 'Filter-sql': _SqlFilterTransform, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 9dca107dca51..d5aa4038ef7a 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -151,6 +151,245 @@ def test_validate_explicit_types(self): ''') self.assertEqual(result.element_type._fields[0][1], str) + def test_partition(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple'), + beam.Row(element='banana'), + beam.Row(element='orange'), + ]) + result = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: "'even' if len(element) % 2 == 0 else 'odd'" + language: python + outputs: [even, odd] + ''') + assert_that( + result['even'] | beam.Map(lambda x: x.element), + equal_to(['banana', 'orange']), + label='Even') + assert_that( + result['odd'] | beam.Map(lambda x: x.element), + equal_to(['apple']), + label='Odd') + + def test_partition_callable(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple'), + beam.Row(element='banana'), + beam.Row(element='orange'), + ]) + result = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: + callable: + "lambda row: 'even' if len(row.element) % 2 == 0 else 'odd'" + language: python + outputs: [even, odd] + ''') + assert_that( + result['even'] | beam.Map(lambda x: x.element), + equal_to(['banana', 'orange']), + label='Even') + assert_that( + result['odd'] | beam.Map(lambda x: x.element), + equal_to(['apple']), + label='Odd') + + def test_partition_with_unknown(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple'), + beam.Row(element='banana'), + beam.Row(element='orange'), + ]) + result = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: "element.lower()[0]" + language: python + outputs: [a, b, c] + unknown_output: other + ''') + assert_that( + result['a'] | beam.Map(lambda x: x.element), + equal_to(['apple']), + label='A') + assert_that( + result['b'] | beam.Map(lambda x: x.element), + equal_to(['banana']), + label='B') + assert_that( + result['c'] | beam.Map(lambda x: x.element), equal_to([]), label='C') + assert_that( + result['other'] | beam.Map(lambda x: x.element), + equal_to(['orange']), + label='Other') + + def test_partition_without_unknown(self): + with self.assertRaisesRegex(ValueError, r'.*Unknown output name.*"o".*'): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple'), + beam.Row(element='banana'), + beam.Row(element='orange'), + ]) + _ = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: 
"element.lower()[0]" + language: python + outputs: [a, b, c] + ''') + + def test_partition_without_unknown_with_error(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple'), + beam.Row(element='banana'), + beam.Row(element='orange'), + ]) + result = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: "element.lower()[0]" + language: python + outputs: [a, b, c] + error_handling: + output: unknown + ''') + assert_that( + result['a'] | beam.Map(lambda x: x.element), + equal_to(['apple']), + label='A') + assert_that( + result['b'] | beam.Map(lambda x: x.element), + equal_to(['banana']), + label='B') + assert_that( + result['c'] | beam.Map(lambda x: x.element), equal_to([]), label='C') + assert_that( + result['unknown'] | beam.Map(lambda x: x.element.element), + equal_to(['orange']), + label='Errors') + + def test_partition_with_actual_error(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple'), + beam.Row(element='banana'), + beam.Row(element='orange'), + ]) + result = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: "element.lower()[5]" + language: python + outputs: [a, b, c] + unknown_output: other + error_handling: + output: errors + ''') + assert_that( + result['a'] | beam.Map(lambda x: x.element), + equal_to(['banana']), + label='B') + assert_that( + result['other'] | beam.Map(lambda x: x.element), + equal_to(['orange']), + label='Other') + # Apple only has 5 letters, resulting in an index error. + assert_that( + result['errors'] | beam.Map(lambda x: x.element.element), + equal_to(['apple']), + label='Errors') + + def test_partition_no_language(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple', texture='smooth'), + beam.Row(element='banana', texture='smooth'), + beam.Row(element='orange', texture='bumpy'), + ]) + result = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: texture + outputs: [bumpy, smooth] + ''') + assert_that( + result['bumpy'] | beam.Map(lambda x: x.element), + equal_to(['orange']), + label='Bumpy') + assert_that( + result['smooth'] | beam.Map(lambda x: x.element), + equal_to(['apple', 'banana']), + label='Smooth') + + def test_partition_bad_static_type(self): + with self.assertRaisesRegex( + ValueError, r'.*Partition function .*must return a string.*'): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple', texture='smooth'), + beam.Row(element='banana', texture='smooth'), + beam.Row(element='orange', texture='bumpy'), + ]) + _ = elements | YamlTransform( + ''' + type: Partition + input: input + config: + by: len(texture) + outputs: [bumpy, smooth] + language: python + ''') + + def test_partition_bad_runtime_type(self): + with self.assertRaisesRegex(ValueError, + r'.*Returned output name.*must be a string.*'): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(element='apple', texture='smooth'), + beam.Row(element='banana', texture='smooth'), + beam.Row(element='orange', 
+        ])
+        _ = elements | YamlTransform(
+            '''
+            type: Partition
+            input: input
+            config:
+              by: print(texture)
+              outputs: [bumpy, smooth]
+              language: python
+            ''')
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 03574b5f98ff..df2fdbf6aaa1 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -918,7 +918,11 @@ def ensure_transforms_have_providers(spec):
     return spec
 
   def preprocess_langauges(spec):
-    if spec['type'] in ('Filter', 'MapToFields', 'Combine', 'AssignTimestamps'):
+    if spec['type'] in ('AssignTimestamps',
+                        'Combine',
+                        'Filter',
+                        'MapToFields',
+                        'Partition'):
       language = spec.get('config', {}).get('language', 'generic')
       new_type = spec['type'] + '-' + language
       if known_transforms and new_type not in known_transforms:
diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md
index 4c51d99ce657..b228dac1909f 100644
--- a/website/www/site/content/en/documentation/programming-guide.md
+++ b/website/www/site/content/en/documentation/programming-guide.md
@@ -2153,6 +2153,14 @@ students = ...
 {{< code_sample "sdks/typescript/test/docs/programming_guide.ts" model_multiple_pcollections_partition >}}
 {{< /highlight >}}
 
+{{< highlight yaml >}}
+{{< code_sample "sdks/python/apache_beam/yaml/programming_guide_test.py" model_multiple_pcollections_partition >}}
+{{< /highlight >}}
+
+{{< paragraph class="language-yaml">}}
+Note that in Beam YAML, `PCollection`s are partitioned via string rather than integer values.
+{{< /paragraph >}}
+
 ### 4.3. Requirements for writing user code for Beam transforms {#requirements-for-writing-user-code-for-beam-transforms}
 
 When you build user code for a Beam transform, you should keep in mind the
@@ -2415,6 +2423,14 @@ properties in your `ParDo` operation and follow this operation with a `Split`
 to break it into multiple `PCollection`s.
 {{< /paragraph >}}
 
+{{< paragraph class="language-yaml">}}
+In Beam YAML, one obtains multiple outputs by first emitting all elements to
+a single `PCollection`, possibly with an extra field indicating their
+destination, and then using `Partition` to split this single `PCollection`
+into multiple distinct `PCollection`s.
+{{< /paragraph >}}
+
+
 #### 4.5.1. Tags for multiple outputs {#output-tags}
 
 {{< paragraph class="language-typescript">}}
diff --git a/website/www/site/content/en/documentation/sdks/yaml-udf.md b/website/www/site/content/en/documentation/sdks/yaml-udf.md
index c2ab3eb64604..5a51f1af1a15 100644
--- a/website/www/site/content/en/documentation/sdks/yaml-udf.md
+++ b/website/www/site/content/en/documentation/sdks/yaml-udf.md
@@ -207,6 +207,74 @@ criteria. This can be accomplished with a `Filter` transform, e.g.
       keep: "col2 > 0"
 ```
 
+## Partitioning
+
+It can also be useful to send different elements to different places
+(similar to what is done with side outputs in other SDKs).
+While this can be done with a set of `Filter` operations, if every
+element has a single destination it can be more natural to use a `Partition`
+transform instead, which sends each element to exactly one output.
+For example, this will send all elements where `col1` is equal to `"a"` to the
+output `Partition.a`.
+
+```
+- type: Partition
+  input: input
+  config:
+    by: col1
+    outputs: ['a', 'b', 'c']
+
+- type: SomeTransform
+  input: Partition.a
+  config:
+    param: ...
+
+- type: AnotherTransform
+  input: Partition.b
+  config:
+    param: ...
+```
+
+One can also specify the destination as a function, e.g.
+
+```
+- type: Partition
+  input: input
+  config:
+    by: "'even' if col2 % 2 == 0 else 'odd'"
+    language: python
+    outputs: ['even', 'odd']
+```
+
+One can optionally provide a catch-all output that captures all elements
+that are not in the named outputs (which would otherwise be an error):
+
+```
+- type: Partition
+  input: input
+  config:
+    by: col1
+    outputs: ['a', 'b', 'c']
+    unknown_output: 'other'
+```
+
+Sometimes one wants to split a `PCollection` into multiple `PCollection`s
+that aren't necessarily disjoint. To send elements to multiple (or no)
+outputs, one could use an iterable column and precede the `Partition` with an
+`Explode`.
+
+```
+- type: Explode
+  input: input
+  config:
+    fields: col1
+
+- type: Partition
+  input: Explode
+  config:
+    by: col1
+    outputs: ['a', 'b', 'c']
+```
+
 ## Types
 
 Beam will try to infer the types involved in the mappings, but sometimes this
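For anyone who wants to exercise the new transform end to end, here is a minimal sketch (not part of the patch) that drives the YAML `Partition` transform from Python, mirroring the tests above; the fruit rows and the `print` sinks are illustrative assumptions, not anything the patch itself defines.

```python
# Minimal sketch (not part of the patch): running the new YAML Partition
# transform from Python, as the tests in yaml_mapping_test.py above do.
# The rows and the print sinks are illustrative assumptions.
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline(options=PipelineOptions(
    pickle_library='cloudpickle')) as p:
  elements = p | beam.Create([
      beam.Row(element='apple', texture='smooth'),
      beam.Row(element='banana', texture='smooth'),
      beam.Row(element='orange', texture='bumpy'),
  ])
  # Partition by a field; the result can be indexed by output name.
  result = elements | YamlTransform(
      '''
      type: Partition
      input: input
      config:
        by: texture
        outputs: [bumpy, smooth]
      ''')
  _ = result['smooth'] | 'PrintSmooth' >> beam.Map(print)
  _ = result['bumpy'] | 'PrintBumpy' >> beam.Map(print)
```

As in the patch's own tests, each named output comes back as a separate `PCollection`, so downstream transforms attach to `result['smooth']` and `result['bumpy']` independently.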