Skip to content

Commit

Permalink
[YAML] Add Partition transform. (#30368)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Mar 29, 2024
1 parent 25805db commit e3fee51
Show file tree
Hide file tree
Showing 7 changed files with 451 additions and 15 deletions.
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/yaml/programming_guide_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,25 @@ def extract_timestamp(x):
# [END setting_timestamp]
''')

def test_partition(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
elements = p | beam.Create([
beam.Row(percentile=1),
beam.Row(percentile=20),
beam.Row(percentile=90),
])
_ = elements | YamlTransform(
'''
# [START model_multiple_pcollections_partition]
type: Partition
config:
by: str(percentile // 10)
language: python
outputs: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
# [END model_multiple_pcollections_partition]
''')


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
30 changes: 17 additions & 13 deletions sdks/python/apache_beam/yaml/readme_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,25 @@ def expand(self, pcoll):
lambda _: 1, sum, 'count')


class _Fakes:
fn = str

class SomeTransform(beam.PTransform):
def __init__(*args, **kwargs):
pass

def expand(self, pcoll):
return pcoll


RENDER_DIR = None
TEST_TRANSFORMS = {
'Sql': FakeSql,
'ReadFromPubSub': FakeReadFromPubSub,
'WriteToPubSub': FakeWriteToPubSub,
'SomeGroupingTransform': FakeAggregation,
'SomeTransform': _Fakes.SomeTransform,
'AnotherTransform': _Fakes.SomeTransform,
}


Expand All @@ -155,7 +168,7 @@ def input_file(self, name, content):
return path

def input_csv(self):
return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')
return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')

def input_tsv(self):
return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
Expand Down Expand Up @@ -250,13 +263,15 @@ def parse_test_methods(markdown_lines):
else:
if code_lines:
if code_lines[0].startswith('- type:'):
is_chain = not any('input:' in line for line in code_lines)
# Treat this as a fragment of a larger pipeline.
# pylint: disable=not-an-iterable
code_lines = [
'pipeline:',
' type: chain',
' type: chain' if is_chain else '',
' transforms:',
' - type: ReadFromCsv',
' name: input',
' config:',
' path: whatever',
] + [' ' + line for line in code_lines]
Expand All @@ -278,17 +293,6 @@ def createTestSuite(name, path):
return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))


class _Fakes:
fn = str

class SomeTransform(beam.PTransform):
def __init__(*args, **kwargs):
pass

def expand(self, pcoll):
return pcoll


# These are copied from $ROOT/website/www/site/content/en/documentation/sdks
# at build time.
YAML_DOCS_DIR = os.path.join(os.path.join(os.path.dirname(__file__), 'docs'))
Expand Down
88 changes: 87 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
Expand All @@ -42,6 +43,7 @@
from apache_beam.typehints import row_type
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints import typehints
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils import python_callable
Expand Down Expand Up @@ -569,6 +571,86 @@ def extract_expr(name, v):
return pcoll | sql_transform_constructor(query)


@beam.ptransform.ptransform_fn
def _Partition(
pcoll,
by: Union[str, Dict[str, str]],
outputs: List[str],
unknown_output: Optional[str] = None,
error_handling: Optional[Mapping[str, Any]] = None,
language: Optional[str] = 'generic'):
"""Splits an input into several distinct outputs.
Each input element will go to a distinct output based on the field or
function given in the `by` configuration parameter.
Args:
by: A field, callable, or expression giving the destination output for
this element. Should return a string that is a member of the `outputs`
parameter. If `unknown_output` is also set, other returns values are
accepted as well, otherwise an error will be raised.
outputs: The set of outputs into which this input is being partitioned.
unknown_output: (Optional) If set, indicates a destination output for any
elements that are not assigned an output listed in the `outputs`
parameter.
error_handling: (Optional) Whether and how to handle errors during
partitioning.
language: (Optional) The language of the `by` expression.
"""
split_fn = _as_callable_for_pcoll(pcoll, by, 'by', language)
try:
split_fn_output_type = trivial_inference.infer_return_type(
split_fn, [pcoll.element_type])
except (TypeError, ValueError):
pass
else:
if not typehints.is_consistent_with(split_fn_output_type,
typehints.Optional[str]):
raise ValueError(
f'Partition function "{by}" must return a string type '
f'not {split_fn_output_type}')
error_output = error_handling['output'] if error_handling else None
if error_output in outputs:
raise ValueError(
f'Error handling output "{error_output}" '
f'cannot be among the listed outputs {outputs}')
T = TypeVar('T')

def split(element):
tag = split_fn(element)
if tag is None:
tag = unknown_output
if not isinstance(tag, str):
raise ValueError(
f'Returned output name "{tag}" of type {type(tag)} '
f'from "{by}" must be a string.')
if tag not in outputs:
if unknown_output:
tag = unknown_output
else:
raise ValueError(f'Unknown output name "{tag}" from {by}')
return beam.pvalue.TaggedOutput(tag, element)

output_set = set(outputs)
if unknown_output:
output_set.add(unknown_output)
if error_output:
output_set.add(error_output)
mapping_transform = beam.Map(split)
if error_output:
mapping_transform = mapping_transform.with_exception_handling(
**exception_handling_args(error_handling))
else:
mapping_transform = mapping_transform.with_outputs(*output_set)
splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T)
result = {out: getattr(splits, out) for out in output_set}
if error_output:
result[
error_output] = result[error_output] | _map_errors_to_standard_format(
pcoll.element_type)
return result


@beam.ptransform.ptransform_fn
@maybe_with_exception_handling_transform_fn
def _AssignTimestamps(
Expand All @@ -588,7 +670,8 @@ def _AssignTimestamps(
Args:
timestamp: A field, callable, or expression giving the new timestamp.
language: The language of the timestamp expression.
error_handling: Whether and how to handle errors during iteration.
error_handling: Whether and how to handle errors during timestamp
evaluation.
"""
timestamp_fn = _as_callable_for_pcoll(pcoll, timestamp, 'timestamp', language)
T = TypeVar('T')
Expand All @@ -611,6 +694,9 @@ def create_mapping_providers():
'MapToFields-python': _PyJsMapToFields,
'MapToFields-javascript': _PyJsMapToFields,
'MapToFields-generic': _PyJsMapToFields,
'Partition-python': _Partition,
'Partition-javascript': _Partition,
'Partition-generic': _Partition,
}),
yaml_provider.SqlBackedProvider({
'Filter-sql': _SqlFilterTransform,
Expand Down
Loading

0 comments on commit e3fee51

Please sign in to comment.