Skip to content

Commit

Permalink
[Python] use get_buffer to fetch buffer when the buffer is None (apac…
Browse files Browse the repository at this point in the history
…he#27373)

* Create empty ListBuffer when buffer is none

* Replace empty buffer with a List/GroupBuffer

* Apply suggestions from code review

* Fix mypy error

* Update sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py

* Add test

* Add note to CHANGES.md

* Fix link in CHANGES.md
  • Loading branch information
AnandInguva authored Jul 11, 2023
1 parent 2348737 commit 63d5171
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@

## Bugfixes

* Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Fixed DirectRunner bug in Python SDK where GroupByKey gets empty PCollection and fails when pipeline option `direct_num_workers!=1`. ([#27373](https://github.com/apache/beam/pull/27373))

## Known Issues

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,11 @@ def _execute_bundle(self,

buffers_to_clean = set()
known_consumers = set()
for _, buffer_id in bundle_context_manager.stage_data_outputs.items():
for (consuming_stage_name, consuming_transform) in \
runner_execution_context.buffer_id_to_consumer_pairs.get(buffer_id,
[]):
for transform_id, buffer_id in (
bundle_context_manager.stage_data_outputs.items()):
for (consuming_stage_name, consuming_transform
) in runner_execution_context.buffer_id_to_consumer_pairs.get(
buffer_id, []):
buffer = runner_execution_context.pcoll_buffers.get(buffer_id, None)

if (buffer_id in runner_execution_context.pcoll_buffers and
Expand All @@ -840,6 +841,11 @@ def _execute_bundle(self,
# so we create a copy of the buffer for every new stage.
runner_execution_context.pcoll_buffers[buffer_id] = buffer.copy()
buffer = runner_execution_context.pcoll_buffers[buffer_id]
# When the buffer is not in the pcoll_buffers, it means that the
# it could be an empty PCollection. In this case, get the buffer using
# the buffer id and transform id
if buffer is None:
buffer = bundle_context_manager.get_buffer(buffer_id, transform_id)

# If the buffer has already been added to be consumed by
# (stage, transform), then we don't need to add it again. This case
Expand All @@ -854,7 +860,7 @@ def _execute_bundle(self,
# MAX_TIMESTAMP for the downstream stage.
runner_execution_context.queues.watermark_pending_inputs.enque(
((consuming_stage_name, timestamp.MAX_TIMESTAMP),
DataInput({consuming_transform: buffer}, {}))) # type: ignore
DataInput({consuming_transform: buffer}, {})))

for bid in buffers_to_clean:
if bid in runner_execution_context.pcoll_buffers:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,15 @@ def create_pipeline(self, is_drain=False):
p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
return p

def test_group_by_key_with_empty_pcoll_elements(self):
with self.create_pipeline() as p:
res = (
p
| beam.Create([('test_key', 'test_value')])
| beam.Filter(lambda x: False)
| beam.GroupByKey())
assert_that(res, equal_to([]))

def test_metrics(self):
raise unittest.SkipTest("This test is for a single worker only.")

Expand Down

0 comments on commit 63d5171

Please sign in to comment.