[YAML] Add Partition transform. #30368

Merged · 19 commits · Mar 29, 2024
Changes from 17 commits
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/yaml/programming_guide_test.py
@@ -404,6 +404,25 @@ def extract_timestamp(x):
          # [END setting_timestamp]
          ''')

  def test_partition(self):
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      elements = p | beam.Create([
          beam.Row(percentile=1),
          beam.Row(percentile=20),
          beam.Row(percentile=90),
      ])
      _ = elements | YamlTransform(
          '''
          # [START model_multiple_pcollections_partition]
          type: Partition
          config:
            by: str(percentile // 10)
            language: python
            outputs: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
          # [END model_multiple_pcollections_partition]
          ''')


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
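The test above exercises the transform in isolation; the sketch below (not part of this PR) shows how the named partitions might be consumed downstream. The composite structure, the implicit `input` binding, and the `Split.high` output reference are assumptions modeled on how Beam YAML addresses named outputs elsewhere (for example, error outputs).

```python
# Hypothetical follow-on example: route one partition into a downstream
# transform. The names ("Split", "KeepHigh") and the <name>.<output>
# reference syntax are assumptions, not taken from this PR.
import apache_beam as beam
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
    pickle_library='cloudpickle')) as p:
  rows = p | beam.Create([beam.Row(percentile=5), beam.Row(percentile=95)])
  _ = rows | YamlTransform(
      '''
      type: composite
      transforms:
        - type: Partition
          name: Split
          input: input
          config:
            by: "'high' if percentile >= 50 else 'low'"
            language: python
            outputs: [low, high]
        - type: MapToFields
          name: KeepHigh
          input: Split.high
          config:
            language: python
            fields:
              percentile: percentile
      output: KeepHigh
      ''')
```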
30 changes: 17 additions & 13 deletions sdks/python/apache_beam/yaml/readme_test.py
@@ -128,12 +128,25 @@ def expand(self, pcoll):
lambda _: 1, sum, 'count')


class _Fakes:
  fn = str

  class SomeTransform(beam.PTransform):
    def __init__(*args, **kwargs):
      pass

    def expand(self, pcoll):
      return pcoll


RENDER_DIR = None
TEST_TRANSFORMS = {
    'Sql': FakeSql,
    'ReadFromPubSub': FakeReadFromPubSub,
    'WriteToPubSub': FakeWriteToPubSub,
    'SomeGroupingTransform': FakeAggregation,
    'SomeTransform': _Fakes.SomeTransform,
    'AnotherTransform': _Fakes.SomeTransform,
}


@@ -155,7 +168,7 @@ def input_file(self, name, content):
    return path

  def input_csv(self):
-    return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')
+    return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')

  def input_tsv(self):
    return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
@@ -250,13 +263,15 @@ def parse_test_methods(markdown_lines):
    else:
      if code_lines:
        if code_lines[0].startswith('- type:'):
          is_chain = not any('input:' in line for line in code_lines)
          # Treat this as a fragment of a larger pipeline.
          # pylint: disable=not-an-iterable
          code_lines = [
              'pipeline:',
-             '  type: chain',
+             '  type: chain' if is_chain else '',
              '  transforms:',
              '    - type: ReadFromCsv',
              '      name: input',
              '      config:',
              '        path: whatever',
          ] + ['    ' + line for line in code_lines]
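A standalone illustration (not part of the PR) of the wrapping above: this hypothetical fragment names its own `input:`, so `is_chain` is False and the chain header is omitted. `SomeTransform` and `AnotherTransform` are the fakes registered in `TEST_TRANSFORMS`.

```python
# Illustrative sketch of the wrapping logic, runnable on its own.
# Because one line contains 'input:', is_chain is False, and the empty
# string produced in place of '  type: chain' is filtered out for display.
code_lines = [
    '- type: SomeTransform',
    '- type: AnotherTransform',
    '  input: SomeTransform',
]
is_chain = not any('input:' in line for line in code_lines)  # False here
wrapped = [
    'pipeline:',
    '  type: chain' if is_chain else '',
    '  transforms:',
    '    - type: ReadFromCsv',
    '      name: input',
    '      config:',
    '        path: whatever',
] + ['    ' + line for line in code_lines]
print('\n'.join(line for line in wrapped if line))
```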
@@ -278,17 +293,6 @@ def createTestSuite(name, path):
return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))


-class _Fakes:
-  fn = str
-
-  class SomeTransform(beam.PTransform):
-    def __init__(*args, **kwargs):
-      pass
-
-    def expand(self, pcoll):
-      return pcoll

# These are copied from $ROOT/website/www/site/content/en/documentation/sdks
# at build time.
YAML_DOCS_DIR = os.path.join(os.path.join(os.path.dirname(__file__), 'docs'))
88 changes: 87 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -24,6 +24,7 @@
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
@@ -42,6 +43,7 @@
from apache_beam.typehints import row_type
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.typehints import typehints
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils import python_callable
@@ -569,6 +571,86 @@ def extract_expr(name, v):
return pcoll | sql_transform_constructor(query)


@beam.ptransform.ptransform_fn
def _Partition(
    pcoll,
    by: Union[str, Dict[str, str]],
    outputs: List[str],
    unknown_output: Optional[str] = None,
    error_handling: Optional[Mapping[str, Any]] = None,
    language: Optional[str] = 'generic'):
  """Splits an input into several distinct outputs.

  Each input element will go to a distinct output based on the field or
  function given in the `by` configuration parameter.

  Args:
    by: A field, callable, or expression giving the destination output for
      this element. Should return a string that is a member of the `outputs`
      parameter. If `unknown_output` is also set, other return values are
      accepted as well; otherwise an error will be raised.
    outputs: The set of outputs into which this input is being partitioned.
    unknown_output: (Optional) If set, indicates a destination output for any
      elements that are not assigned an output listed in the `outputs`
      parameter.
    error_handling: (Optional) Whether and how to handle errors during
      partitioning.
    language: (Optional) The language of the `by` expression.
  """
  split_fn = _as_callable_for_pcoll(pcoll, by, 'by', language)
  try:
    split_fn_output_type = trivial_inference.infer_return_type(
        split_fn, [pcoll.element_type])
  except (TypeError, ValueError):
    pass
  else:
    if not typehints.is_consistent_with(split_fn_output_type,
                                        typehints.Optional[str]):
      raise ValueError(
          f'Partition function "{by}" must return a string type '
          f'not {split_fn_output_type}')
  error_output = error_handling['output'] if error_handling else None
  if error_output in outputs:
    raise ValueError(
        f'Error handling output "{error_output}" '
        f'cannot be among the listed outputs {outputs}')
  T = TypeVar('T')

  def split(element):
    tag = split_fn(element)
    if tag is None:
      tag = unknown_output
    if not isinstance(tag, str):
      raise ValueError(
          f'Returned output name "{tag}" of type {type(tag)} '
          f'from "{by}" must be a string.')
    if tag not in outputs:
      if unknown_output:
        tag = unknown_output
      else:
        raise ValueError(f'Unknown output name "{tag}" from {by}')
    return beam.pvalue.TaggedOutput(tag, element)

  output_set = set(outputs)
  if unknown_output:
    output_set.add(unknown_output)
  if error_output:
    output_set.add(error_output)
  mapping_transform = beam.Map(split)
  if error_output:
    mapping_transform = mapping_transform.with_exception_handling(
        **exception_handling_args(error_handling))
  else:
    mapping_transform = mapping_transform.with_outputs(*output_set)
  splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T)
  result = {out: getattr(splits, out) for out in output_set}
  if error_output:
    result[error_output] = result[
        error_output] | _map_errors_to_standard_format(pcoll.element_type)
  return result
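A configuration sketch for the two optional escape hatches (the names "other" and "errors" are illustrative, not from the PR): a `by` result that is not listed in `outputs` is routed to `unknown_output` when it is set, and elements whose `by` evaluation raises go to the `error_handling` output instead of failing the pipeline.

```python
# Illustrative only: exercising unknown_output and error_handling together.
# 'huge' is deliberately not listed in outputs, so those elements should
# land in the assumed catch-all output 'other'; evaluation errors in 'errors'.
import apache_beam as beam
from apache_beam.yaml.yaml_transform import YamlTransform

with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
    pickle_library='cloudpickle')) as p:
  rows = p | beam.Create([beam.Row(size=3), beam.Row(size=300)])
  split = rows | YamlTransform(
      '''
      type: Partition
      config:
        by: "'small' if size < 100 else 'huge'"
        language: python
        outputs: [small, large]
        unknown_output: other
        error_handling:
          output: errors
      ''')
  # split is expected to expose one PCollection per declared output
  # (small, large, other, errors); the exact access style is an assumption.
```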


@beam.ptransform.ptransform_fn
@maybe_with_exception_handling_transform_fn
def _AssignTimestamps(
@@ -588,7 +670,8 @@ def _AssignTimestamps(
  Args:
    timestamp: A field, callable, or expression giving the new timestamp.
    language: The language of the timestamp expression.
-   error_handling: Whether and how to handle errors during iteration.
+   error_handling: Whether and how to handle errors during timestamp
+     evaluation.
  """
  timestamp_fn = _as_callable_for_pcoll(pcoll, timestamp, 'timestamp', language)
  T = TypeVar('T')
@@ -611,6 +694,9 @@ def create_mapping_providers():
          'MapToFields-python': _PyJsMapToFields,
          'MapToFields-javascript': _PyJsMapToFields,
          'MapToFields-generic': _PyJsMapToFields,
          'Partition-python': _Partition,
          'Partition-javascript': _Partition,
          'Partition-generic': _Partition,
      }),
      yaml_provider.SqlBackedProvider({
          'Filter-sql': _SqlFilterTransform,
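Because `create_mapping_providers()` registers `Partition-python`, `Partition-javascript`, and `Partition-generic` against the same implementation, the `language` key in the config selects the provider. A JavaScript sketch follows, assuming the `callable` UDF form used by other Beam YAML JavaScript transforms also applies to `by`:

```python
# Hypothetical JavaScript variant; the callable form and its semantics are
# assumptions based on other Beam YAML JavaScript UDFs, not on this PR.
partition_js = '''
type: Partition
config:
  by:
    callable: |
      function(row) {
        return row.size < 100 ? "small" : "large";
      }
  language: javascript
  outputs: [small, large]
'''
```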