Skip to content

Commit

Permalink
[YAML] Add MLTransform. (#30002)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Jun 11, 2024
1 parent 462c833 commit 5dd2d3f
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 9 deletions.
18 changes: 12 additions & 6 deletions sdks/python/apache_beam/ml/transforms/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,7 @@ def expand(
Returns:
A PCollection of dictionaries.
"""
if isinstance(pcoll.element_type, RowTypeConstraint):
# Row instance
return pcoll | beam.Map(lambda x: x.as_dict())
else:
# named tuple
return pcoll | beam.Map(lambda x: x._asdict())
return pcoll | beam.Map(lambda x: x._asdict())


class TFTProcessHandler(ProcessHandler[tft_process_handler_input_type,
Expand Down Expand Up @@ -404,6 +399,17 @@ def expand(
raw_data_metadata = metadata_io.read_metadata(
os.path.join(self.artifact_location, RAW_DATA_METADATA_DIR))

element_type = raw_data.element_type
if (isinstance(element_type, RowTypeConstraint) or
native_type_compatibility.match_is_named_tuple(element_type)):
# convert Row or NamedTuple to Dict
column_type_mapping = self._map_column_names_to_types(
row_type=element_type)
raw_data = (
raw_data
| _ConvertNamedTupleToDict().with_output_types(
Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore

feature_set = [feature.name for feature in raw_data_metadata.schema.feature]

# TFT ignores columns in the input data that aren't explicitly defined
Expand Down
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/ml/transforms/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ def __init__(
This function applies a tf-idf transformation on the given columns
of incoming data.
TFIDF outputs two artifacts for each column: the vocabu index and
the tfidf weight. The vocabu index is a mapping from the original
TFIDF outputs two artifacts for each column: the vocabulary index and
the tfidf weight. The vocabulary index is a mapping from the original
vocabulary to the new vocabulary. The tfidf weight is a mapping
from the original vocabulary to the tfidf score.
Expand Down Expand Up @@ -636,7 +636,7 @@ def __init__(
compute_word_count: A boolean that specifies whether to compute
the unique word count over the entire dataset. Defaults to False.
key_vocab_filename: The file name for the key vocabulary file when
compute_word_count is True. If empty, a file name
compute_word_count is True. If empty, a file name
will be chosen based on the current scope. If provided, the vocab
file will be suffixed with the column name.
name: A name for the operation (optional).
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/apache_beam/yaml/standard_providers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
Flatten: 'beam:schematransform:org.apache.beam:yaml:flatten:v1'
LogForTesting: 'beam:schematransform:org.apache.beam:yaml:log_for_testing:v1'

- type: 'python'
config: {}
transforms:
MLTransform: 'apache_beam.yaml.yaml_ml.ml_transform'

- type: renaming
transforms:
'MapToFields-java': 'MapToFields-java'
Expand Down
66 changes: 66 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This module defines yaml wrappings for some ML transforms."""

from typing import Any
from typing import List
from typing import Optional

import apache_beam as beam
from apache_beam.yaml import options

try:
from apache_beam.ml.transforms import tft
from apache_beam.ml.transforms.base import MLTransform
# TODO(robertwb): Is this all of them?
_transform_constructors = tft.__dict__
except ImportError:
tft = None # type: ignore


def _config_to_obj(spec):
if 'type' not in spec:
raise ValueError(r"Missing type in ML transform spec {spec}")
if 'config' not in spec:
raise ValueError(r"Missing config in ML transform spec {spec}")
constructor = _transform_constructors.get(spec['type'])
if constructor is None:
raise ValueError("Unknown ML transform type: %r" % spec['type'])
return constructor(**spec['config'])


@beam.ptransform.ptransform_fn
def ml_transform(
    pcoll,
    write_artifact_location: Optional[str] = None,
    read_artifact_location: Optional[str] = None,
    transforms: Optional[List[Any]] = None):
    """Apply MLTransform to pcoll using YAML-style transform specs.

    Args:
      pcoll: The input PCollection.
      write_artifact_location: Forwarded unchanged to MLTransform.
      read_artifact_location: Forwarded unchanged to MLTransform.
      transforms: Optional list of specs, each converted with _config_to_obj
        before being handed to MLTransform. None or an empty list results in
        an empty transform list.

    Returns:
      The result of applying the configured MLTransform to pcoll.

    Raises:
      ValueError: If the tft module could not be imported.
    """
    if tft is None:
        raise ValueError(
            'tensorflow-transform must be installed to use this MLTransform')
    # Checked against the pipeline's YAML options under the 'ML' feature name.
    options.YamlOptions.check_enabled(pcoll.pipeline, 'ML')
    # TODO(robertwb): Perhaps _config_to_obj could be pushed into MLTransform
    # itself for better cross-language support?
    if transforms:
        transform_objs = [_config_to_obj(t) for t in transforms]
    else:
        transform_objs = []
    wrapped = MLTransform(
        write_artifact_location=write_artifact_location,
        read_artifact_location=read_artifact_location,
        transforms=transform_objs)
    return pcoll | wrapped


# MLTransform is only bound when the optional ML imports succeeded, so only
# then can its docstring be reused for the YAML wrapper.
if tft is not None:
    ml_transform.__doc__ = MLTransform.__doc__
92 changes: 92 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_ml_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import tempfile
import unittest

import apache_beam as beam
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.yaml.yaml_transform import YamlTransform

# The import is used purely as an availability probe; when it fails, the
# module raises SkipTest at import time instead of failing with ImportError.
try:
    # pylint: disable=wrong-import-order, wrong-import-position, unused-import
    from apache_beam.ml.transforms import tft
except ImportError:
    raise unittest.SkipTest('tensorflow_transform is not installed.')

# Rows fed to the first pipeline in the test, which runs MLTransform with
# write_artifact_location set.
TRAIN_DATA = [
    beam.Row(num=0, text='And God said, Let there be light,'),
    beam.Row(num=2, text='And there was light'),
    beam.Row(num=8, text='And God saw the light, that it was good'),
]

# Row fed to the second pipeline in the test, which runs MLTransform with only
# read_artifact_location set.
TEST_DATA = [
    beam.Row(num=6, text='And God divided the light from the darkness.'),
]


class MLTransformTest(unittest.TestCase):
    """Exercises the YAML MLTransform wrapper in write and then read mode."""
    def test_ml_transform(self):
        # The YAML 'ML' feature is passed as an experimental feature here.
        ml_opts = beam.options.pipeline_options.PipelineOptions(
            pickle_library='cloudpickle', yaml_experimental_features=['ML'])
        # Both pipelines share one temporary directory: the first writes
        # artifacts to it, the second reads from it.
        with tempfile.TemporaryDirectory() as tempdir:
            with beam.Pipeline(options=ml_opts) as p:
                elements = p | beam.Create(TRAIN_DATA)
                result = elements | YamlTransform(
                    f'''
                    type: MLTransform
                    config:
                      write_artifact_location: {tempdir}
                      transforms:
                        - type: ScaleTo01
                          config:
                            columns: [num]
                        - type: ComputeAndApplyVocabulary
                          config:
                            columns: [text]
                            split_string_by_delimiter: ' ,.'
                    ''')
                # Expected: the num values 0, 2, 8 come out as 0, .25, 1.
                assert_that(
                    # Why is this an array, not a scalar?
                    result | beam.Map(lambda x: x.num[0]),
                    equal_to([0, .25, 1]))
                # Expected: the union of all text values across the rows is
                # exactly the integers 0..12.
                assert_that(
                    result | beam.Map(lambda x: set(x.text))
                    | beam.CombineGlobally(lambda xs: set.union(*xs)),
                    equal_to([set(range(13))]),
                    label='CheckVocab')

            # Second pipeline: no transforms are listed, only the artifact
            # location to read from.
            with beam.Pipeline(options=ml_opts) as p:
                elements = p | beam.Create(TEST_DATA)
                result = elements | YamlTransform(
                    f'''
                    type: MLTransform
                    config:
                      read_artifact_location: {tempdir}
                    ''')
                # Expected: num=6 comes out as .75 -- presumably scaled with
                # the 0..8 range seen in TRAIN_DATA; confirm against
                # MLTransform's artifact handling.
                assert_that(
                    result | beam.Map(lambda x: x.num[0]), equal_to([.75]))
                # Expected: the single test row yields 5 distinct text values.
                assert_that(
                    result | beam.Map(lambda x: len(set(x.text))),
                    equal_to([5]),
                    label='CheckVocab')


# When run as a script: set the root logger to INFO, then run the tests.
if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    unittest.main()

0 comments on commit 5dd2d3f

Please sign in to comment.