Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YAML] Allow constants and simple comparisons in generic expressions. #31455

Merged
merged 11 commits into from
Jun 11, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}
for (Map.Entry<String, JavaRowUdf.Configuration> entry :
configuration.getFields().entrySet()) {
if (!"java".equals(configuration.getLanguage())) {
String expr = entry.getValue().getExpression();
if (expr == null || !inputSchema.hasField(expr)) {
throw new IllegalArgumentException(
"Unknown field or missing language specification for '" + entry.getKey() + "'");
}
}
try {
JavaRowUdf udf = new JavaRowUdf(entry.getValue(), inputSchema);
udfs.add(udf);
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/apache_beam/yaml/standard_providers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
'MapToFields-java': 'MapToFields-java'
'MapToFields-generic': 'MapToFields-java'
'Filter-java': 'Filter-java'
'Filter-generic': 'Filter-java'
'Explode': 'Explode'
config:
mappings:
Expand All @@ -75,6 +76,10 @@
drop: 'drop'
fields: 'fields'
error_handling: 'errorHandling'
'Filter-generic':
language: 'language'
keep: 'keep'
error_handling: 'errorHandling'
robertwb marked this conversation as resolved.
Show resolved Hide resolved
'Filter-java':
language: 'language'
keep: 'keep'
Expand Down
46 changes: 46 additions & 0 deletions sdks/python/apache_beam/yaml/tests/map.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the# Row(word='License'); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an# Row(word='AS IS' BASIS,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI all of the test files have this error, so not sure what is generating this license, but it appears to not like single quotes

Suggested change
# (the# Row(word='License'); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an# Row(word='AS IS' BASIS,
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed here and #31480 for the rest.

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

pipelines:
- pipeline:
type: chain
transforms:
- type: Create
config:
elements:
- 100
- 200
- 300

- type: MapToFields
config:
append: true
fields:
named_field: element
literal_int: 10
literal_float: 1.5
literal_str: '"abc"'

- type: Filter
config:
keep: "named_field < 250"

- type: AssertEqual
config:
elements:
- {element: 100, named_field: 100, literal_int: 10, literal_float: 1.5, literal_str: "abc"}
- {element: 200, named_field: 200, literal_int: 10, literal_float: 1.5, literal_str: "abc"}
115 changes: 92 additions & 23 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import inspect
import itertools
import re
from collections import abc
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -56,6 +57,12 @@
js2py = None
JsObjectWrapper = object

_str_expression_fields = {
'AssignTimestamps': 'timestamp',
'Filter': 'keep',
'Partition': 'by',
}


def normalize_mapping(spec):
"""
Expand All @@ -65,9 +72,85 @@ def normalize_mapping(spec):
config = spec.get('config')
if isinstance(config.get('drop'), str):
config['drop'] = [config['drop']]
for field, value in list(config.get('fields', {}).items()):
if isinstance(value, (str, int, float)):
config['fields'][field] = {'expression': str(value)}

elif spec['type'] in _str_expression_fields:
param = _str_expression_fields[spec['type']]
config = spec.get('config', {})
if isinstance(config.get(param), (str, int, float)):
config[param] = {'expression': str(config.get(param))}

return spec


def is_literal(expr: str) -> bool:
# Some languages have limited integer literal ranges.
if re.fullmatch(r'-?\d+?', expr) and -1 << 31 < int(expr) < 1 << 31:
return True
elif re.fullmatch(r'-?\d+\.\d*', expr):
return True
elif re.fullmatch(r'"[^\\"]*"', expr):
return True
else:
return False


def validate_generic_expression(
expr_dict: dict,
input_fields: Collection[str],
allow_cmp: bool,
error_field: str) -> None:
if not isinstance(expr_dict, dict):
raise ValueError(
f"Ambiguous expression type (perhaps missing quoting?): {expr_dict}")
if len(expr_dict) != 1 or 'expression' not in expr_dict:
raise ValueError(
"Missing language specification. "
"Must specify a language when using a map with custom logic for %s" %
error_field)
expr = str(expr_dict['expression'])

def is_atomic(expr: str):
return is_literal(expr) or expr in input_fields

if is_atomic(expr):
return

if allow_cmp:
maybe_cmp = re.fullmatch('(.*)([<>=!]+)(.*)', expr)
if maybe_cmp:
left, cmp, right = maybe_cmp.groups()
if (is_atomic(left.strip()) and is_atomic(right.strip()) and
cmp in {'==', '<=', '>=', '<', '>', '!='}):
return

raise ValueError(
f"Missing language specification or unknown input fields: {expr}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Should this message refer to expression rather than input fields?

Suggested change
f"Missing language specification or unknown input fields: {expr}")
f"Missing language specification or invalid generic expression: {expr}")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. I kept the "unkown input fields" as a missing/typo input name is still a quite likely error here we should point to.



def validate_generic_expressions(base_type, config, input_pcolls) -> None:
if not input_pcolls:
return
try:
input_fields = [
name for (name, _) in named_fields_from_element_type(
next(iter(input_pcolls)).element_type)
]
except (TypeError, ValueError):
input_fields = []

if base_type == 'MapToFields':
for field, value in list(config.get('fields', {}).items()):
validate_generic_expression(value, input_fields, True, field)

elif base_type in _str_expression_fields:
param = _str_expression_fields[base_type]
validate_generic_expression(
config.get(param), input_fields, base_type == 'Filter', param)


def _check_mapping_arguments(
transform_name, expression=None, callable=None, name=None, path=None):
# Argument checking
Expand Down Expand Up @@ -282,16 +365,16 @@ def _as_callable_for_pcoll(


def _as_callable(original_fields, expr, transform_name, language, input_schema):
if isinstance(expr, str):
expr = {'expression': expr}

# Extract original type from upstream pcoll when doing simple mappings
original_type = input_schema.get(str(expr), None)
original_type = input_schema.get(expr.get('expression'), None)
if expr in original_fields:
language = "python"

# TODO(yaml): support an imports parameter
# TODO(yaml): support a requirements parameter (possibly at a higher level)
if isinstance(expr, str):
expr = {'expression': expr}
if not isinstance(expr, dict):
raise ValueError(
f"Ambiguous expression type (perhaps missing quoting?): {expr}")
Expand All @@ -300,7 +383,7 @@ def _as_callable(original_fields, expr, transform_name, language, input_schema):

if language == "javascript":
func = _expand_javascript_mapping_func(original_fields, **expr)
elif language == "python":
elif language in ("python", "generic", None):
func = _expand_python_mapping_func(original_fields, **expr)
else:
raise ValueError(
Expand All @@ -323,13 +406,9 @@ def checking_func(row):
return checking_func

elif original_type:

@beam.typehints.with_output_types(convert_to_beam_type(original_type))
def checking_func(row):
result = func(row)
return result

return checking_func
return beam.typehints.with_output_types(
convert_to_beam_type(original_type))(
func)

else:
return func
Expand Down Expand Up @@ -498,7 +577,7 @@ def _PyJsFilter(
See more complete documentation on
[YAML Filtering](https://beam.apache.org/documentation/sdks/yaml-udf/#filtering).
""" # pylint: disable=line-too-long
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language)
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic')
return pcoll | beam.Filter(keep_fn)


Expand Down Expand Up @@ -530,17 +609,6 @@ def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
f'Redefinition of field "{name}". '
'Cannot append a field that already exists in original input.')

if language == 'generic':
for expr in fields.values():
if not isinstance(expr, str):
raise ValueError(
"Missing language specification. "
"Must specify a language when using a map with custom logic.")
missing = set(fields.values()) - set(input_schema.keys())
if missing:
raise ValueError(
f"Missing language specification or unknown input fields: {missing}")

if append:
return input_schema, {
**{name: f'`{name}`' if language in ['sql', 'calcite'] else name
Expand Down Expand Up @@ -720,6 +788,7 @@ def create_mapping_providers():
'Explode': _Explode,
'Filter-python': _PyJsFilter,
'Filter-javascript': _PyJsFilter,
'Filter-generic': _PyJsFilter,
'MapToFields-python': _PyJsMapToFields,
'MapToFields-javascript': _PyJsMapToFields,
'MapToFields-generic': _PyJsMapToFields,
Expand Down
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_combine import normalize_combine
from apache_beam.yaml.yaml_mapping import normalize_mapping
from apache_beam.yaml.yaml_mapping import validate_generic_expressions

__all__ = ["YamlTransform"]

Expand Down Expand Up @@ -384,6 +385,12 @@ def create_ptransform(self, spec, input_pcolls):
f'Missing inputs for transform at {identify_object(spec)}')

try:
if spec['type'].endswith('-generic'):
# Centralize the validation rather than require every implementation
# to do it.
validate_generic_expressions(
spec['type'].rsplit('-', 1)[0], config, input_pcolls)

# pylint: disable=undefined-loop-variable
ptransform = provider.create_transform(
spec['type'], config, self.create_ptransform)
Expand Down
Loading