Merge pull request #32330 Use proper coders in interactive cache.
Formerly, the coder used was always a URL-escaped pickling[1] of windowed values, which is quite inefficient in both time and space.

The default (text) sink is modified to use base64 encoding to avoid embedded newlines.

Compression is also enabled by default, both to offset the base64 expansion and because the (often highly repetitive) windowing and timestamp metadata compresses extremely well.

[1] The FastPrimitivesCoder is targeted at efficiently coding elements, not windows or WindowedValues.
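
To make the tradeoff concrete, here is a rough sketch contrasting the two encodings. It is not part of the change: the element values, the element count, and the coders chosen below are illustrative assumptions, and the sink's per-file bzip2 compression is approximated by compressing the concatenated lines.

import base64
import bz2
from urllib.parse import quote

from apache_beam import coders
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.windowed_value import WindowedValue

# Many small elements sharing identical windowing and timestamp metadata.
values = [WindowedValue('v_%d' % i, 0, (GlobalWindow(), )) for i in range(1000)]

# Old scheme: FastPrimitivesCoder falls back to pickling each entire
# WindowedValue, and the pickled bytes are then URL-escaped, one per line.
old_coder = coders.coders.FastPrimitivesCoder()
old_size = sum(len(quote(old_coder.encode(v))) + 1 for v in values)

# New scheme: a proper WindowedValueCoder for the payload, base64 so the
# text sink's records contain no newlines, then bzip2 over the whole file.
new_coder = coders.WindowedValueCoder(
    coders.StrUtf8Coder(), coders.coders.GlobalWindowCoder())
lines = b'\n'.join(base64.b64encode(new_coder.encode(v)) for v in values)
new_size = len(bz2.compress(lines))

print(old_size, new_size)  # The compressed form is typically far smaller.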
robertwb authored Sep 3, 2024
2 parents 6baba92 + f8bda18 commit f06df5d
Showing 5 changed files with 92 additions and 16 deletions.
50 changes: 41 additions & 9 deletions sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import base64
 import collections
 import os
 import tempfile
@@ -28,6 +29,7 @@
 from apache_beam.io import filesystems
 from apache_beam.io import textio
 from apache_beam.io import tfrecordio
+from apache_beam.testing import test_stream
 from apache_beam.transforms import combiners


@@ -154,7 +156,15 @@ class FileBasedCacheManager(CacheManager):
   """Maps PCollections to local temp files for materialization."""
 
   _available_formats = {
-      'text': (textio.ReadFromText, textio.WriteToText),
+      'text': (
+          lambda path: textio.ReadFromText(
+              path,
+              coder=Base64Coder(),
+              compression_type=filesystems.CompressionTypes.BZIP2),
+          lambda path: textio.WriteToText(
+              path,
+              coder=Base64Coder(),
+              compression_type=filesystems.CompressionTypes.BZIP2)),
       'tfrecord': (tfrecordio.ReadFromTFRecord, tfrecordio.WriteToTFRecord)
   }

@@ -226,12 +236,13 @@ def read(self, *labels, **args):
       return iter([]), -1
 
     # Otherwise, return a generator to the cached PCollection.
-    source = self.source(*labels)._source
+    coder = self.load_pcoder('reify', *labels[1:])
+    source = self.raw_source(*labels)._source
     range_tracker = source.get_range_tracker(None, None)
     reader = source.read(range_tracker)
     version = self._latest_version(*labels)
 
-    return reader, version
+    return (coder.decode(b) for b in reader), version
 
   def write(self, values, *labels):
     """Imitates how a WriteCache transform works without running a pipeline.
@@ -242,16 +253,17 @@ def write(self, values, *labels):
     pcoder = coders.registry.get_coder(type(values[0]))
     # Save the pcoder for the actual labels.
     self.save_pcoder(pcoder, *labels)
+    self.save_pcoder(pcoder, 'reify', *labels[-1:])
     single_shard_labels = [*labels[:-1], '-00000-of-00001']
     # Save the pcoder for the labels that imitates the sharded cache file name
     # suffix.
     self.save_pcoder(pcoder, *single_shard_labels)
     # Put a '-%05d-of-%05d' suffix to the cache file.
-    sink = self.sink(single_shard_labels)._sink
+    sink = self.raw_sink(single_shard_labels)._sink
     path = self._path(*labels[:-1])
     writer = sink.open_writer(path, labels[-1])
     for v in values:
-      writer.write(v)
+      writer.write(pcoder.encode(v))
     writer.close()
 
   def clear(self, *labels):
@@ -261,12 +273,20 @@ def clear(self, *labels):
     return False
 
   def source(self, *labels):
-    return self._reader_class(
-        self._glob_path(*labels), coder=self.load_pcoder(*labels))
+    coder = self.load_pcoder('reify', *labels[1:])
+    return self.raw_source(*labels) | beam.Map(
+        lambda b: test_stream.WindowedValueHolder(coder.decode(b)))
 
   def sink(self, labels, is_capture=False):
-    return self._writer_class(
-        self._path(*labels), coder=self.load_pcoder(*labels))
+    coder = self.load_pcoder('reify', *labels[1:])
+    return beam.Map(lambda wvh: coder.encode(wvh.windowed_value)
+                    ) | self.raw_sink(labels, is_capture)
+
+  def raw_sink(self, labels, is_capture=False):
+    return self._writer_class(self._path(*labels))
+
+  def raw_source(self, *labels):
+    return self._reader_class(self._glob_path(*labels))
 
   def cleanup(self):
     if self._cache_dir.startswith('gs://'):
@@ -342,6 +362,12 @@ def expand(self, pcoll):
     # We save pcoder that is necessary for proper reading of
     # cached PCollection. _cache_manager.sink(...) call below
     # should be using this saved pcoder.
+    self._cache_manager.save_pcoder(
+        beam.coders.WindowedValueCoder(
+            beam.coders.registry.get_coder(pcoll.element_type),
+            pcoll.windowing.windowfn.get_window_coder()),
+        'reify',
+        self._label)
     self._cache_manager.save_pcoder(
         coders.registry.get_coder(pcoll.element_type), prefix, self._label)

@@ -365,3 +391,9 @@ def encode(self, value):
 
   def decode(self, value):
     return coders.coders.FastPrimitivesCoder().decode(unquote_to_bytes(value))
+
+
+class Base64Coder(coders.Coder):
+  """Used to safely encode arbitrary bytes to textio."""
+  encode = staticmethod(base64.b64encode)
+  decode = staticmethod(base64.b64decode)
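
A quick illustration (not part of the diff) of why the Base64Coder above is safe for the line-oriented text sink: raw pickled bytes may contain b'\n', which would split one record across lines, whereas base64 output never does.

import base64

payload = b'arbitrary bytes with an embedded\nnewline'
encoded = base64.b64encode(payload)
assert b'\n' not in encoded  # safe to write as a single text line
assert base64.b64decode(encoded) == payload  # lossless round trip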
7 changes: 3 additions & 4 deletions sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -102,10 +102,9 @@ def test_size(self):
     coder = self.cache_manager.load_pcoder(prefix, cache_label)
     encoded = coder.encode(value)
 
-    # Add one to the size on disk because of the extra new-line character when
-    # writing to file.
-    self.assertEqual(
-        self.cache_manager.size(prefix, cache_label), len(encoded) + 1)
+    # We encode in a format that escapes newlines.
+    self.assertGreater(
+        self.cache_manager.size(prefix, cache_label), len(encoded))
 
   def test_clear(self):
     """Test that CacheManager can correctly tell if the cache exists or not."""
46 changes: 44 additions & 2 deletions sdks/python/apache_beam/runners/interactive/interactive_runner_test.py
@@ -24,6 +24,7 @@
 
 import sys
 import unittest
+import unittest.mock
 from typing import NamedTuple
 
 import pandas as pd
@@ -138,7 +139,7 @@ def process(self, element):
         0: [e[0] for e in actual],
         1: [e[1] for e in actual],
         'event_time': [end_of_window for _ in actual],
-        'windows': [[GlobalWindow()] for _ in actual],
+        'windows': [(GlobalWindow(), ) for _ in actual],
         'pane_info': [
             PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0) for _ in actual
         ]
@@ -153,7 +154,7 @@ def process(self, element):
     expected_reified = [
         WindowedValue(
             e,
-            Timestamp(micros=end_of_window), [GlobalWindow()],
+            Timestamp(micros=end_of_window), (GlobalWindow(), ),
             PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)) for e in actual
     ]
     self.assertEqual(actual_reified, expected_reified)
@@ -490,6 +491,47 @@ def enter_composite_transform(
       self.assertEqual(producer, prev_producer, trace_string)
       prev_producer = consumer
 
+  @staticmethod
+  def only_none_shall_pass(value):
+    if value is None:
+      return b'\0'
+    else:
+      raise RuntimeError("Should be using a more efficient coder.")
+
+  @unittest.mock.patch.object(
+      beam.coders.coders.FastPrimitivesCoder, 'encode', only_none_shall_pass)
+  def test_defaults_to_efficient_cache(self):
+    p = beam.Pipeline(
+        runner=interactive_runner.InteractiveRunner(
+            direct_runner.DirectRunner()))
+
+    inputs = [1, 10, 100, 1000, 10000]
+    big = (
+        p
+        | beam.Create(inputs)
+        | 'Explode' >> beam.FlatMap(lambda n: ("v_%s" % ix for ix in range(n))))
+
+    # Watch the local scope for Interactive Beam so that counts will be cached.
+    ib.watch(locals())
+
+    # This is normally done in the interactive_utils when a transform is
+    # applied but needs an IPython environment. So we manually run this here.
+    ie.current_env().track_user_pipelines()
+
+    result = p.run()
+    result.wait_until_finish()
+
+    self.assertEqual(len(result.get(big)), sum(inputs))
+    self.assertEqual(
+        len(result.get(big, include_window_info=True)), sum(inputs))
+
+    cache_manager = ie.current_env().get_cache_manager(
+        result._pipeline_instrument.user_pipeline)
+    key = result._pipeline_instrument.cache_key(big)
+    size = cache_manager.size('full', key)
+    # Despite (highly redundant) windowing information, the cache is small.
+    self.assertLess(size, sum(inputs))
+
 
 @unittest.skipIf(
     not ie.current_env().is_interactive_ready,
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/runners/interactive/recording_manager_test.py
@@ -308,7 +308,7 @@ def test_basic_execution(self):
     # that all elements were written to cache.
     elems = list(numbers_stream.read())
     expected_elems = [
-        WindowedValue(i, MIN_TIMESTAMP, [GlobalWindow()]) for i in range(3)
+        WindowedValue(i, MIN_TIMESTAMP, (GlobalWindow(), )) for i in range(3)
     ]
     self.assertListEqual(elems, expected_elems)
3 changes: 3 additions & 0 deletions sdks/python/apache_beam/runners/interactive/utils.py
@@ -40,6 +40,7 @@
 from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
 from apache_beam.testing.test_stream import WindowedValueHolder
 from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils.windowed_value import WindowedValue
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -83,6 +84,8 @@ def elements():
       elif isinstance(e, WindowedValueHolder):
         yield (
             e.windowed_value if include_window_info else e.windowed_value.value)
+      elif isinstance(e, WindowedValue):
+        yield (e if include_window_info else e.value)
       else:
         yield e
 
