Skip to content

Commit

Permalink
Accept runner and options in ib.collect. (#32434)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Sep 12, 2024
1 parent ed84acb commit 4ee2606
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def setUp(self):
ie.current_env().track_user_pipelines()

recording_manager = RecordingManager(self._p)
recording = recording_manager.record([self._pcoll], 5, 5)
recording = recording_manager.record([self._pcoll], max_n=5, max_duration=5)
self._stream = recording.stream(self._pcoll)

def test_pcoll_visualization_generate_unique_display_id(self):
Expand Down
12 changes: 11 additions & 1 deletion sdks/python/apache_beam/runners/interactive/interactive_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,8 @@ def collect(
n='inf',
duration='inf',
include_window_info=False,
runner=None,
options=None,
force_compute=False,
force_tuple=False):
"""Materializes the elements from a PCollection into a Dataframe.
Expand All @@ -896,6 +898,9 @@ def collect(
a string duration. Default 'inf'.
include_window_info: (optional) if True, appends the windowing information
to each row. Default False.
runner: (optional) the runner with which to compute the results
options: (optional) any additional pipeline options to use to compute the
results
force_compute: (optional) if True, forces recomputation rather than using
cached PCollections
force_tuple: (optional) if True, return a 1-tuple or results rather than
Expand Down Expand Up @@ -969,7 +974,12 @@ def as_pcollection(pcoll_or_df):
uncomputed = set(pcolls) - set(computed.keys())
if uncomputed:
recording = recording_manager.record(
uncomputed, max_n=n, max_duration=duration, force_compute=force_compute)
uncomputed,
max_n=n,
max_duration=duration,
runner=runner,
options=options,
force_compute=force_compute)

try:
for pcoll in uncomputed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,36 @@ def test_dataframes_same_cell_twice(self):
df_expected['cube'],
ib.collect(df['cube'], n=10).reset_index(drop=True))

@unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
def test_new_runner_and_options(self):
class MyRunner(beam.runners.PipelineRunner):
run_count = 0

@classmethod
def run_pipeline(cls, pipeline, options):
assert options._all_options['my_option'] == 123
cls.run_count += 1
return direct_runner.DirectRunner().run_pipeline(pipeline, options)

clear_side_effect()
p = beam.Pipeline(direct_runner.DirectRunner())

# Initial collection runs the pipeline.
pcoll1 = p | beam.Create(['a', 'b', 'c']) | beam.Map(cause_side_effect)
collected1 = ib.collect(pcoll1)
self.assertEqual(set(collected1[0]), set(['a', 'b', 'c']))
self.assertEqual(count_side_effects('a'), 1)

# Using the PCollection uses the cache with a different runner and options.
pcoll2 = pcoll1 | beam.Map(str.upper)
collected2 = ib.collect(
pcoll2,
runner=MyRunner(),
options=beam.options.pipeline_options.PipelineOptions(my_option=123))
self.assertEqual(set(collected2[0]), set(['A', 'B', 'C']))
self.assertEqual(count_side_effects('a'), 1)
self.assertEqual(MyRunner.run_count, 1)


if __name__ == '__main__':
unittest.main()
12 changes: 7 additions & 5 deletions sdks/python/apache_beam/runners/interactive/pipeline_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ class PipelineFragment(object):
A pipeline fragment is built from the original pipeline definition to include
only PTransforms that are necessary to produce the given PCollections.
"""
def __init__(self, pcolls, options=None):
def __init__(self, pcolls, options=None, runner=None):
"""Constructor of PipelineFragment.
Args:
pcolls: (List[PCollection]) a list of PCollections to build pipeline
fragment for.
options: (PipelineOptions) the pipeline options for the implicit
pipeline run.
runner: (Runner) the pipeline runner for the implicit
pipeline run.
"""
assert len(pcolls) > 0, (
'Need at least 1 PCollection as the target data to build a pipeline '
Expand All @@ -61,6 +63,7 @@ def __init__(self, pcolls, options=None):
'given and cannot be used to build a pipeline fragment that produces '
'the given PCollections.'.format(pcoll))
self._options = options
self._runner = runner

# A copied pipeline instance for modification without changing the user
# pipeline instance held by the end user. This instance can be processed
Expand Down Expand Up @@ -98,7 +101,7 @@ def deduce_fragment(self):
"""Deduce the pipeline fragment as an apache_beam.Pipeline instance."""
fragment = beam.pipeline.Pipeline.from_runner_api(
self._runner_pipeline.to_runner_api(),
self._runner_pipeline.runner,
self._runner or self._runner_pipeline.runner,
self._options)
ie.current_env().add_derived_pipeline(self._runner_pipeline, fragment)
return fragment
Expand Down Expand Up @@ -129,9 +132,8 @@ def run(self, display_pipeline_graph=False, use_cache=True, blocking=False):
'the Beam pipeline to use this function '
'on unbouded PCollections.')
result = beam.pipeline.Pipeline.from_runner_api(
pipeline_instrument_proto,
self._runner_pipeline.runner,
self._runner_pipeline._options).run()
pipeline_instrument_proto, fragment.runner,
fragment._options).run()
result.wait_until_finish()
ie.current_env().mark_pcollection_computed(
pipeline_instrument.cached_pcolls)
Expand Down
17 changes: 15 additions & 2 deletions sdks/python/apache_beam/runners/interactive/recording_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@

import apache_beam as beam
from apache_beam.dataframe.frame_base import DeferredBase
from apache_beam.options import pipeline_options
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import runner
from apache_beam.runners.interactive import background_caching_job as bcj
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import interactive_runner as ir
Expand Down Expand Up @@ -384,8 +386,11 @@ def record_pipeline(self) -> bool:
def record(
self,
pcolls: List[beam.pvalue.PCollection],
*,
max_n: int,
max_duration: Union[int, str],
runner: runner.PipelineRunner = None,
options: pipeline_options.PipelineOptions = None,
force_compute: bool = False) -> Recording:
# noqa: F821

Expand Down Expand Up @@ -427,12 +432,20 @@ def record(
# incomplete.
self._clear()

merged_options = pipeline_options.PipelineOptions(
**{
**self.user_pipeline.options.get_all_options(
drop_default=True, retain_unknown_options=True),
**options.get_all_options(
drop_default=True, retain_unknown_options=True)
}) if options else self.user_pipeline.options

cache_path = ie.current_env().options.cache_root
is_remote_run = cache_path and ie.current_env(
).options.cache_root.startswith('gs://')
pf.PipelineFragment(
list(uncomputed_pcolls),
self.user_pipeline.options).run(blocking=is_remote_run)
list(uncomputed_pcolls), merged_options,
runner=runner).run(blocking=is_remote_run)
result = ie.current_env().pipeline_result(self.user_pipeline)
else:
result = None
Expand Down

0 comments on commit 4ee2606

Please sign in to comment.