Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use proper coders in interactive cache. #32330

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions sdks/python/apache_beam/runners/interactive/cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# pytype: skip-file

import base64
import collections
import os
import tempfile
Expand All @@ -28,6 +29,7 @@
from apache_beam.io import filesystems
from apache_beam.io import textio
from apache_beam.io import tfrecordio
from apache_beam.testing import test_stream
from apache_beam.transforms import combiners


Expand Down Expand Up @@ -154,7 +156,15 @@ class FileBasedCacheManager(CacheManager):
"""Maps PCollections to local temp files for materialization."""

# Maps a cache format name to its (reader, writer) pair.
_available_formats = {
    # Text cache records are base64-encoded so that arbitrary encoded bytes can
    # be written safely through textio, and the files are bzip2-compressed.
    # (A second, plain `textio.ReadFromText`/`WriteToText` entry under the same
    # 'text' key used to precede this one; it was dead because a later
    # duplicate key overrides an earlier one in a dict literal.)
    'text': (
        lambda path: textio.ReadFromText(
            path,
            coder=Base64Coder(),
            compression_type=filesystems.CompressionTypes.BZIP2),
        lambda path: textio.WriteToText(
            path,
            coder=Base64Coder(),
            compression_type=filesystems.CompressionTypes.BZIP2)),
    'tfrecord': (tfrecordio.ReadFromTFRecord, tfrecordio.WriteToTFRecord)
}

Expand Down Expand Up @@ -226,12 +236,13 @@ def read(self, *labels, **args):
return iter([]), -1

# Otherwise, return a generator to the cached PCollection.
source = self.source(*labels)._source
coder = self.load_pcoder('reify', *labels[1:])
source = self.raw_source(*labels)._source
range_tracker = source.get_range_tracker(None, None)
reader = source.read(range_tracker)
version = self._latest_version(*labels)

return reader, version
return (coder.decode(b) for b in reader), version

def write(self, values, *labels):
"""Imitates how a WriteCache transform works without running a pipeline.
Expand All @@ -242,16 +253,17 @@ def write(self, values, *labels):
pcoder = coders.registry.get_coder(type(values[0]))
# Save the pcoder for the actual labels.
self.save_pcoder(pcoder, *labels)
self.save_pcoder(pcoder, 'reify', *labels[-1:])
single_shard_labels = [*labels[:-1], '-00000-of-00001']
# Save the pcoder for the labels that imitates the sharded cache file name
# suffix.
self.save_pcoder(pcoder, *single_shard_labels)
# Put a '-%05d-of-%05d' suffix to the cache file.
sink = self.sink(single_shard_labels)._sink
sink = self.raw_sink(single_shard_labels)._sink
path = self._path(*labels[:-1])
writer = sink.open_writer(path, labels[-1])
for v in values:
writer.write(v)
writer.write(pcoder.encode(v))
writer.close()

def clear(self, *labels):
Expand All @@ -261,12 +273,20 @@ def clear(self, *labels):
return False

def source(self, *labels):
  """Returns a transform that reads the cached values stored under ``labels``.

  The raw records produced by ``raw_source`` are decoded with the coder saved
  under the 'reify' prefix, and every decoded value is wrapped in a
  ``test_stream.WindowedValueHolder``.

  Args:
    *labels: the labels identifying the cache entry; everything after the
      first label is used to look up the 'reify' coder.
  """
  # Load the coder once here so the lambda below only closes over the result.
  coder = self.load_pcoder('reify', *labels[1:])
  return self.raw_source(*labels) | beam.Map(
      lambda b: test_stream.WindowedValueHolder(coder.decode(b)))

def sink(self, labels, is_capture=False):
  """Returns a transform that writes elements to the cache under ``labels``.

  Each incoming element must expose a ``windowed_value`` attribute. That value
  is encoded with the coder saved under the 'reify' prefix and the resulting
  bytes are handed to ``raw_sink``.

  Args:
    labels: sequence of labels identifying the cache entry; everything after
      the first label is used to look up the 'reify' coder.
    is_capture: forwarded unchanged to ``raw_sink``.
  """
  # Load the coder once here so the lambda below only closes over the result.
  coder = self.load_pcoder('reify', *labels[1:])
  encode_step = beam.Map(lambda wvh: coder.encode(wvh.windowed_value))
  return encode_step | self.raw_sink(labels, is_capture)

def raw_sink(self, labels, is_capture=False):
  """Returns the format's writer for ``labels`` without any value coding.

  Args:
    labels: sequence of labels; they are expanded into ``self._path`` to build
      the destination path.
    is_capture: not used by this implementation.
  """
  destination = self._path(*labels)
  return self._writer_class(destination)

def raw_source(self, *labels):
  """Returns the format's reader for ``labels`` without any value decoding.

  Args:
    *labels: the labels identifying the cache entry; they are passed to
      ``self._glob_path`` to build the pattern handed to the reader.
  """
  pattern = self._glob_path(*labels)
  return self._reader_class(pattern)

def cleanup(self):
if self._cache_dir.startswith('gs://'):
Expand Down Expand Up @@ -342,6 +362,12 @@ def expand(self, pcoll):
# We save pcoder that is necessary for proper reading of
# cached PCollection. _cache_manager.sink(...) call below
# should be using this saved pcoder.
self._cache_manager.save_pcoder(
beam.coders.WindowedValueCoder(
beam.coders.registry.get_coder(pcoll.element_type),
pcoll.windowing.windowfn.get_window_coder()),
'reify',
self._label)
self._cache_manager.save_pcoder(
coders.registry.get_coder(pcoll.element_type), prefix, self._label)

Expand All @@ -365,3 +391,9 @@ def encode(self, value):

def decode(self, value):
return coders.coders.FastPrimitivesCoder().decode(unquote_to_bytes(value))


class Base64Coder(coders.Coder):
  """Used to safely encode arbitrary bytes to textio.

  ``encode`` and ``decode`` are the standard-library ``base64.b64encode`` and
  ``base64.b64decode`` functions, exposed as static methods so they are not
  bound to the coder instance when called.
  """
  encode = staticmethod(base64.b64encode)
  decode = staticmethod(base64.b64decode)
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ def test_size(self):
coder = self.cache_manager.load_pcoder(prefix, cache_label)
encoded = coder.encode(value)

# Add one to the size on disk because of the extra new-line character when
# writing to file.
self.assertEqual(
self.cache_manager.size(prefix, cache_label), len(encoded) + 1)
# We encode in a format that escapes newlines.
self.assertGreater(
self.cache_manager.size(prefix, cache_label), len(encoded))

def test_clear(self):
"""Test that CacheManager can correctly tell if the cache exists or not."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import sys
import unittest
import unittest.mock
from typing import NamedTuple

import pandas as pd
Expand Down Expand Up @@ -138,7 +139,7 @@ def process(self, element):
0: [e[0] for e in actual],
1: [e[1] for e in actual],
'event_time': [end_of_window for _ in actual],
'windows': [[GlobalWindow()] for _ in actual],
'windows': [(GlobalWindow(), ) for _ in actual],
'pane_info': [
PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0) for _ in actual
]
Expand All @@ -153,7 +154,7 @@ def process(self, element):
expected_reified = [
WindowedValue(
e,
Timestamp(micros=end_of_window), [GlobalWindow()],
Timestamp(micros=end_of_window), (GlobalWindow(), ),
PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)) for e in actual
]
self.assertEqual(actual_reified, expected_reified)
Expand Down Expand Up @@ -490,6 +491,47 @@ def enter_composite_transform(
self.assertEqual(producer, prev_producer, trace_string)
prev_producer = consumer

@staticmethod
def only_none_shall_pass(value):
if value is None:
return b'\0'
else:
raise RuntimeError("Should be using a more efficient coder.")

@unittest.mock.patch.object(
    beam.coders.coders.FastPrimitivesCoder, 'encode', only_none_shall_pass)
def test_defaults_to_efficient_cache(self):
  """Checks that caching does not go through FastPrimitivesCoder.encode.

  ``FastPrimitivesCoder.encode`` is patched to raise for any non-None value,
  so the pipeline only succeeds if another coder is used for the cached
  elements. The test then checks the element counts read back and that the
  cache size is below the number of elements.
  """
  p = beam.Pipeline(
      runner=interactive_runner.InteractiveRunner(
          direct_runner.DirectRunner()))

  inputs = [1, 10, 100, 1000, 10000]
  expected_count = sum(inputs)
  big = (
      p
      | beam.Create(inputs)
      | 'Explode' >> beam.FlatMap(lambda n: ("v_%s" % ix for ix in range(n))))

  # Watch the local scope for Interactive Beam so that counts will be cached.
  ib.watch(locals())

  # This is normally done in the interactive_utils when a transform is
  # applied but needs an IPython environment. So we manually run this here.
  ie.current_env().track_user_pipelines()

  result = p.run()
  result.wait_until_finish()

  self.assertEqual(len(result.get(big)), expected_count)
  self.assertEqual(
      len(result.get(big, include_window_info=True)), expected_count)

  cache_manager = ie.current_env().get_cache_manager(
      result._pipeline_instrument.user_pipeline)
  key = result._pipeline_instrument.cache_key(big)
  size = cache_manager.size('full', key)
  # Despite (highly redundant) windowing information, the cache is small.
  self.assertLess(size, expected_count)


@unittest.skipIf(
not ie.current_env().is_interactive_ready,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def test_basic_execution(self):
# that all elements were written to cache.
elems = list(numbers_stream.read())
expected_elems = [
WindowedValue(i, MIN_TIMESTAMP, [GlobalWindow()]) for i in range(3)
WindowedValue(i, MIN_TIMESTAMP, (GlobalWindow(), )) for i in range(3)
]
self.assertListEqual(elems, expected_elems)

Expand Down
3 changes: 3 additions & 0 deletions sdks/python/apache_beam/runners/interactive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
from apache_beam.testing.test_stream import WindowedValueHolder
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.utils.windowed_value import WindowedValue

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,6 +84,8 @@ def elements():
elif isinstance(e, WindowedValueHolder):
yield (
e.windowed_value if include_window_info else e.windowed_value.value)
elif isinstance(e, WindowedValue):
yield (e if include_window_info else e.value)
else:
yield e

Expand Down
Loading