Skip to content

Commit

Permalink
[YAML] Allow constants and simple comparisons in generic expressions. (
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Jun 11, 2024
1 parent 5dd2d3f commit c2207d8
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}
for (Map.Entry<String, JavaRowUdf.Configuration> entry :
configuration.getFields().entrySet()) {
if (!"java".equals(configuration.getLanguage())) {
String expr = entry.getValue().getExpression();
if (expr == null || !inputSchema.hasField(expr)) {
throw new IllegalArgumentException(
"Unknown field or missing language specification for '" + entry.getKey() + "'");
}
}
try {
JavaRowUdf udf = new JavaRowUdf(entry.getValue(), inputSchema);
udfs.add(udf);
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/apache_beam/yaml/standard_providers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
'MapToFields-java': 'MapToFields-java'
'MapToFields-generic': 'MapToFields-java'
'Filter-java': 'Filter-java'
'Filter-generic': 'Filter-java'
'Explode': 'Explode'
config:
mappings:
Expand All @@ -80,6 +81,10 @@
drop: 'drop'
fields: 'fields'
error_handling: 'error_handling'
'Filter-generic':
language: 'language'
keep: 'keep'
error_handling: 'error_handling'
'Filter-java':
language: 'language'
keep: 'keep'
Expand Down
46 changes: 46 additions & 0 deletions sdks/python/apache_beam/yaml/tests/map.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

pipelines:
- pipeline:
type: chain
transforms:
- type: Create
config:
elements:
- 100
- 200
- 300

- type: MapToFields
config:
append: true
fields:
named_field: element
literal_int: 10
literal_float: 1.5
literal_str: '"abc"'

- type: Filter
config:
keep: "named_field < 250"

- type: AssertEqual
config:
elements:
- {element: 100, named_field: 100, literal_int: 10, literal_float: 1.5, literal_str: "abc"}
- {element: 200, named_field: 200, literal_int: 10, literal_float: 1.5, literal_str: "abc"}
117 changes: 94 additions & 23 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import inspect
import itertools
import re
from collections import abc
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -56,6 +57,12 @@
js2py = None
JsObjectWrapper = object

_str_expression_fields = {
'AssignTimestamps': 'timestamp',
'Filter': 'keep',
'Partition': 'by',
}


def normalize_mapping(spec):
"""
Expand All @@ -65,9 +72,87 @@ def normalize_mapping(spec):
config = spec.get('config')
if isinstance(config.get('drop'), str):
config['drop'] = [config['drop']]
for field, value in list(config.get('fields', {}).items()):
if isinstance(value, (str, int, float)):
config['fields'][field] = {'expression': str(value)}

elif spec['type'] in _str_expression_fields:
param = _str_expression_fields[spec['type']]
config = spec.get('config', {})
if isinstance(config.get(param), (str, int, float)):
config[param] = {'expression': str(config.get(param))}

return spec


def is_literal(expr: str) -> bool:
# Some languages have limited integer literal ranges.
if re.fullmatch(r'-?\d+?', expr) and -1 << 31 < int(expr) < 1 << 31:
return True
elif re.fullmatch(r'-?\d+\.\d*', expr):
return True
elif re.fullmatch(r'"[^\\"]*"', expr):
return True
else:
return False


def validate_generic_expression(
expr_dict: dict,
input_fields: Collection[str],
allow_cmp: bool,
error_field: str) -> None:
if not isinstance(expr_dict, dict):
raise ValueError(
f"Ambiguous expression type (perhaps missing quoting?): {expr_dict}")
if len(expr_dict) != 1 or 'expression' not in expr_dict:
raise ValueError(
"Missing language specification. "
"Must specify a language when using a map with custom logic for %s" %
error_field)
expr = str(expr_dict['expression'])

def is_atomic(expr: str):
return is_literal(expr) or expr in input_fields

if is_atomic(expr):
return

if allow_cmp:
maybe_cmp = re.fullmatch('(.*)([<>=!]+)(.*)', expr)
if maybe_cmp:
left, cmp, right = maybe_cmp.groups()
if (is_atomic(left.strip()) and is_atomic(right.strip()) and
cmp in {'==', '<=', '>=', '<', '>', '!='}):
return

raise ValueError(
"Missing language specification, unknown input fields, "
f"or invalid generic expression: {expr}. "
"See https://beam.apache.org/documentation/sdks/yaml-udf/#generic")


def validate_generic_expressions(base_type, config, input_pcolls) -> None:
if not input_pcolls:
return
try:
input_fields = [
name for (name, _) in named_fields_from_element_type(
next(iter(input_pcolls)).element_type)
]
except (TypeError, ValueError):
input_fields = []

if base_type == 'MapToFields':
for field, value in list(config.get('fields', {}).items()):
validate_generic_expression(value, input_fields, True, field)

elif base_type in _str_expression_fields:
param = _str_expression_fields[base_type]
validate_generic_expression(
config.get(param), input_fields, base_type == 'Filter', param)


def _check_mapping_arguments(
transform_name, expression=None, callable=None, name=None, path=None):
# Argument checking
Expand Down Expand Up @@ -282,16 +367,16 @@ def _as_callable_for_pcoll(


def _as_callable(original_fields, expr, transform_name, language, input_schema):
if isinstance(expr, str):
expr = {'expression': expr}

# Extract original type from upstream pcoll when doing simple mappings
original_type = input_schema.get(str(expr), None)
original_type = input_schema.get(expr.get('expression'), None)
if expr in original_fields:
language = "python"

# TODO(yaml): support an imports parameter
# TODO(yaml): support a requirements parameter (possibly at a higher level)
if isinstance(expr, str):
expr = {'expression': expr}
if not isinstance(expr, dict):
raise ValueError(
f"Ambiguous expression type (perhaps missing quoting?): {expr}")
Expand All @@ -300,7 +385,7 @@ def _as_callable(original_fields, expr, transform_name, language, input_schema):

if language == "javascript":
func = _expand_javascript_mapping_func(original_fields, **expr)
elif language == "python":
elif language in ("python", "generic", None):
func = _expand_python_mapping_func(original_fields, **expr)
else:
raise ValueError(
Expand All @@ -323,13 +408,9 @@ def checking_func(row):
return checking_func

elif original_type:

@beam.typehints.with_output_types(convert_to_beam_type(original_type))
def checking_func(row):
result = func(row)
return result

return checking_func
return beam.typehints.with_output_types(
convert_to_beam_type(original_type))(
func)

else:
return func
Expand Down Expand Up @@ -498,7 +579,7 @@ def _PyJsFilter(
See more complete documentation on
[YAML Filtering](https://beam.apache.org/documentation/sdks/yaml-udf/#filtering).
""" # pylint: disable=line-too-long
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language)
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic')
return pcoll | beam.Filter(keep_fn)


Expand Down Expand Up @@ -530,17 +611,6 @@ def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
f'Redefinition of field "{name}". '
'Cannot append a field that already exists in original input.')

if language == 'generic':
for expr in fields.values():
if not isinstance(expr, str):
raise ValueError(
"Missing language specification. "
"Must specify a language when using a map with custom logic.")
missing = set(fields.values()) - set(input_schema.keys())
if missing:
raise ValueError(
f"Missing language specification or unknown input fields: {missing}")

if append:
return input_schema, {
**{name: f'`{name}`' if language in ['sql', 'calcite'] else name
Expand Down Expand Up @@ -720,6 +790,7 @@ def create_mapping_providers():
'Explode': _Explode,
'Filter-python': _PyJsFilter,
'Filter-javascript': _PyJsFilter,
'Filter-generic': _PyJsFilter,
'MapToFields-python': _PyJsMapToFields,
'MapToFields-javascript': _PyJsMapToFields,
'MapToFields-generic': _PyJsMapToFields,
Expand Down
7 changes: 7 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_combine import normalize_combine
from apache_beam.yaml.yaml_mapping import normalize_mapping
from apache_beam.yaml.yaml_mapping import validate_generic_expressions

__all__ = ["YamlTransform"]

Expand Down Expand Up @@ -384,6 +385,12 @@ def create_ptransform(self, spec, input_pcolls):
f'Missing inputs for transform at {identify_object(spec)}')

try:
if spec['type'].endswith('-generic'):
# Centralize the validation rather than require every implementation
# to do it.
validate_generic_expressions(
spec['type'].rsplit('-', 1)[0], config, input_pcolls)

# pylint: disable=undefined-loop-variable
ptransform = provider.create_transform(
spec['type'], config, self.create_ptransform)
Expand Down
27 changes: 26 additions & 1 deletion website/www/site/content/en/documentation/sdks/yaml-udf.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,22 @@ If one wanted to select a field that collides with a [reserved SQL keyword](http
**Note**: the field mapping tags and fields defined in `drop` do not need to be escaped. Only the UDF itself
needs to be a valid SQL statement.


### Generic

If a language is not specified the set of expressions is limited to pre-existing
fields and integer, floating point, or string literals. For example

```
- type: MapToFields
config:
fields:
new_col: col1
int_literal: 389
float_litera: 1.90216
str_literal: '"example"' # note the double quoting
```

## FlatMap

Sometimes it may be desirable to emit more (or less) than one record for each
Expand Down Expand Up @@ -269,10 +285,19 @@ criteria. This can be accomplished with a `Filter` transform, e.g.
```
- type: Filter
config:
language: python
keep: "col2 > 0"
```

For anything more complicated than a simple comparison between existing
fields and numeric literals a `language` parameter must be provided, e.g.

```
- type: Filter
config:
language: python
keep: "col2 + col3 > 0"
```

For more complicated filtering functions, one can provide a full Python callable that takes the row as an
argument to do more complex mappings
(see [PythonCallableSource](https://beam.apache.org/releases/pydoc/current/apache_beam.utils.python_callable.html#apache_beam.utils.python_callable.PythonCallableWithSource)
Expand Down

0 comments on commit c2207d8

Please sign in to comment.