Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YAML] Add Partition transform. #30368

Merged
merged 19 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions sdks/python/apache_beam/yaml/readme_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,25 @@ def expand(self, pcoll):
lambda _: 1, sum, 'count')


class _Fakes:
  """Namespace of minimal stand-ins used to satisfy transform lookups in tests."""

  # Cheap callable stand-in for any user-supplied function reference.
  fn = str

  class SomeTransform(beam.PTransform):
    """No-op PTransform that accepts (and ignores) any constructor arguments."""

    def __init__(self, *args, **kwargs):
      # Explicit `self` added (it was previously swallowed by *args).
      # Arguments are deliberately ignored; deliberately skips
      # super().__init__() since this fake only needs to be constructible.
      pass

    def expand(self, pcoll):
      # Identity expansion: return the input PCollection unchanged.
      return pcoll


# Output directory for rendered pipeline graphs; None disables rendering.
# NOTE(review): presumably overridden by the test harness — confirm.
RENDER_DIR = None
# Transform types appearing in the README examples are resolved to these
# fakes so the documentation snippets can run hermetically (no real
# Pub/Sub, SQL engine, or user transforms required).
TEST_TRANSFORMS = {
    'Sql': FakeSql,
    'ReadFromPubSub': FakeReadFromPubSub,
    'WriteToPubSub': FakeWriteToPubSub,
    'SomeGroupingTransform': FakeAggregation,
    'SomeTransform': _Fakes.SomeTransform,
    'AnotherTransform': _Fakes.SomeTransform,
}


Expand All @@ -155,7 +168,7 @@ def input_file(self, name, content):
return path

def input_csv(self):
  """Write a small three-column CSV fixture and return its path."""
  # The diff artifact left two alternative return statements (the second was
  # unreachable); the updated literal with the single-character first column
  # is the one kept.
  return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')

def input_tsv(self):
  # Tab-separated counterpart of input_csv; same three columns.
  return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
Expand Down Expand Up @@ -248,13 +261,15 @@ def parse_test_methods(markdown_lines):
else:
if code_lines:
if code_lines[0].startswith('- type:'):
is_chain = not any('input:' in line for line in code_lines)
# Treat this as a fragment of a larger pipeline.
# pylint: disable=not-an-iterable
code_lines = [
'pipeline:',
' type: chain',
' type: chain' if is_chain else '',
' transforms:',
' - type: ReadFromCsv',
' name: input',
' config:',
' path: whatever',
] + [' ' + line for line in code_lines]
Expand All @@ -276,17 +291,6 @@ def createTestSuite(name, path):
return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))


class _Fakes:
fn = str

class SomeTransform(beam.PTransform):
def __init__(*args, **kwargs):
pass

def expand(self, pcoll):
return pcoll


# These are copied from $ROOT/website/www/site/content/en/documentation/sdks
# at build time.
YAML_DOCS_DIR = os.path.join(os.path.join(os.path.dirname(__file__), 'docs'))
Expand Down
48 changes: 48 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
Expand Down Expand Up @@ -550,6 +551,50 @@ def extract_expr(name, v):
return pcoll | sql_transform_constructor(query)


@beam.ptransform.ptransform_fn
def _Split(
    pcoll,
    outputs: List[str],
    destination: Union[str, Dict[str, str]],
    unknown_output: Optional[str] = None,
    error_handling: Optional[Mapping[str, Any]] = None,
    language: Optional[str] = 'generic'):
  """Routes each element of pcoll to one of several named outputs.

  Args:
    pcoll: the input PCollection.
    outputs: the declared output names.
    destination: an expression (in `language`) computing, per element, the
      name of the output it should be routed to.
    unknown_output: optional catch-all output for destination values that are
      not among `outputs`.
    error_handling: optional dead-letter configuration; its 'output' entry
      names the output receiving failed elements.
    language: the language of the `destination` expression.

  Returns:
    A dict mapping each output name to its PCollection.
  """
  route = _as_callable_for_pcoll(pcoll, destination, 'destination', language)
  error_output = error_handling['output'] if error_handling else None
  # The error output must be distinct from the declared outputs.
  if error_output in outputs:
    raise ValueError(
        f'Error handling output "{error_output}" '
        f'cannot be among the listed outputs {outputs}')
  T = TypeVar('T')

  def to_tagged(element):
    # Undeclared destinations either fall through to the catch-all output or
    # raise (which error_handling, if configured, will capture).
    tag = route(element)
    if tag not in outputs:
      if unknown_output:
        tag = unknown_output
      else:
        raise ValueError(f'Unknown output name for destination "{tag}"')
    return beam.pvalue.TaggedOutput(tag, element)

  all_tags = set(outputs)
  if unknown_output:
    all_tags.add(unknown_output)
  if error_output:
    all_tags.add(error_output)

  mapping_transform = beam.Map(to_tagged)
  if error_output:
    mapping_transform = mapping_transform.with_exception_handling(
        **exception_handling_args(error_handling))
  else:
    mapping_transform = mapping_transform.with_outputs(*all_tags)
  # The element type passes through unchanged (routing only re-tags).
  splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T)

  result = {tag: getattr(splits, tag) for tag in all_tags}
  if error_output:
    # Normalize captured failures into the standard error-record schema.
    result[error_output] = (
        result[error_output] | _map_errors_to_standard_format())
  return result


@beam.ptransform.ptransform_fn
def _AssignTimestamps(
pcoll,
Expand All @@ -576,6 +621,9 @@ def create_mapping_providers():
'MapToFields-python': _PyJsMapToFields,
'MapToFields-javascript': _PyJsMapToFields,
'MapToFields-generic': _PyJsMapToFields,
'Split-python': _Split,
'Split-javascript': _Split,
'Split-generic': _Split,
}),
yaml_provider.SqlBackedProvider({
'Filter-sql': _SqlFilterTransform,
Expand Down
171 changes: 171 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,177 @@ def test_validate_explicit_types(self):
''')
self.assertEqual(result.element_type._fields[0][1], str)

def test_split(self):
  """Split routes each element to the output named by `destination`."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    fruits = p | beam.Create([
        beam.Row(element='apple'),
        beam.Row(element='banana'),
        beam.Row(element='orange'),
    ])
    result = fruits | YamlTransform(
        '''
        type: Split
        input: input
        config:
          language: python
          outputs: [even, odd]
          destination: "'even' if len(element) % 2 == 0 else 'odd'"
        ''')
    # Only 'apple' has an odd character count.
    assert_that(
        result['odd'] | beam.Map(lambda row: row.element),
        equal_to(['apple']),
        label='Odd')
    assert_that(
        result['even'] | beam.Map(lambda row: row.element),
        equal_to(['banana', 'orange']),
        label='Even')

def test_split_with_unknown(self):
  """Destinations not in `outputs` are routed to `unknown_output`."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    elements = p | beam.Create([
        beam.Row(element='apple'),
        beam.Row(element='banana'),
        beam.Row(element='orange'),
    ])
    result = elements | YamlTransform(
        '''
        type: Split
        input: input
        config:
          language: python
          outputs: [a, b, c]
          unknown_output: other
          destination: "element.lower()[0]"
        ''')
    assert_that(
        result['a'] | beam.Map(lambda x: x.element),
        equal_to(['apple']),
        label='A')
    assert_that(
        result['b'] | beam.Map(lambda x: x.element),
        equal_to(['banana']),
        label='B')
    # No element starts with 'c', so that declared output stays empty.
    assert_that(
        result['c'] | beam.Map(lambda x: x.element), equal_to([]), label='C')
    # 'orange' starts with 'o', which is undeclared; it lands in 'other'.
    assert_that(
        result['other'] | beam.Map(lambda x: x.element),
        equal_to(['orange']),
        label='Other')

def test_split_without_unknown(self):
  """Without unknown_output, an undeclared destination fails the pipeline."""
  # 'orange' maps to 'o', which is not among [a, b, c]; with no catch-all
  # configured the transform raises the "Unknown output name" ValueError.
  with self.assertRaisesRegex(ValueError, r'.*Unknown output name.*"o".*'):
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      elements = p | beam.Create([
          beam.Row(element='apple'),
          beam.Row(element='banana'),
          beam.Row(element='orange'),
      ])
      _ = elements | YamlTransform(
          '''
          type: Split
          input: input
          config:
            language: python
            outputs: [a, b, c]
            destination: "element.lower()[0]"
          ''')

def test_split_without_unknown_with_error(self):
  """With error_handling set, undeclared destinations become error records."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    elements = p | beam.Create([
        beam.Row(element='apple'),
        beam.Row(element='banana'),
        beam.Row(element='orange'),
    ])
    result = elements | YamlTransform(
        '''
        type: Split
        input: input
        config:
          language: python
          outputs: [a, b, c]
          destination: "element.lower()[0]"
          error_handling:
            output: unknown
        ''')
    assert_that(
        result['a'] | beam.Map(lambda x: x.element),
        equal_to(['apple']),
        label='A')
    assert_that(
        result['b'] | beam.Map(lambda x: x.element),
        equal_to(['banana']),
        label='B')
    # No element starts with 'c'.
    assert_that(
        result['c'] | beam.Map(lambda x: x.element), equal_to([]), label='C')
    # 'orange' has no matching output; the raised ValueError is captured by
    # error handling, and the error record wraps the original row — hence
    # the double .element access.
    assert_that(
        result['unknown'] | beam.Map(lambda x: x.element.element),
        equal_to(['orange']),
        label='Errors')

def test_split_with_actual_error(self):
  """Exceptions raised computing the destination go to the error output."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    elements = p | beam.Create([
        beam.Row(element='apple'),
        beam.Row(element='banana'),
        beam.Row(element='orange'),
    ])
    result = elements | YamlTransform(
        '''
        type: Split
        input: input
        config:
          language: python
          outputs: [a, b, c]
          destination: "element.lower()[5]"
          unknown_output: other
          error_handling:
            output: errors
        ''')
    # 'banana'[5] == 'a', so it is routed to the declared output 'a'.
    # (Label fixed: this assertion checks output 'a', not 'b'.)
    assert_that(
        result['a'] | beam.Map(lambda x: x.element),
        equal_to(['banana']),
        label='A')
    # 'orange'[5] == 'e' is not a declared output, so it goes to 'other'.
    assert_that(
        result['other'] | beam.Map(lambda x: x.element),
        equal_to(['orange']),
        label='Other')
    # Apple only has 5 letters, resulting in an index error.
    assert_that(
        result['errors'] | beam.Map(lambda x: x.element.element),
        equal_to(['apple']),
        label='Errors')

def test_split_no_language(self):
  """With no language given, `destination` is a generic field reference."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    rows = p | beam.Create([
        beam.Row(element='apple', texture='smooth'),
        beam.Row(element='banana', texture='smooth'),
        beam.Row(element='orange', texture='bumpy'),
    ])
    result = rows | YamlTransform(
        '''
        type: Split
        input: input
        config:
          outputs: [bumpy, smooth]
          destination: texture
        ''')
    # Each row is routed by the value of its 'texture' field.
    assert_that(
        result['smooth'] | beam.Map(lambda r: r.element),
        equal_to(['apple', 'banana']),
        label='Smooth')
    assert_that(
        result['bumpy'] | beam.Map(lambda r: r.element),
        equal_to(['orange']),
        label='Bumpy')


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
6 changes: 5 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,11 @@ def ensure_transforms_have_providers(spec):
return spec

def preprocess_langauges(spec):
if spec['type'] in ('Filter', 'MapToFields', 'Combine', 'AssignTimestamps'):
if spec['type'] in ('AssignTimestamps',
'Combine',
'Filter',
'MapToFields',
'Split'):
language = spec.get('config', {}).get('language', 'generic')
new_type = spec['type'] + '-' + language
if known_transforms and new_type not in known_transforms:
Expand Down
Loading
Loading