fixing walk order to resolve priority in multi-sink pipelines #120

Merged 3 commits on Nov 13, 2024
21 changes: 20 additions & 1 deletion dplutils/pipeline/graph.py
@@ -1,3 +1,4 @@
from collections import defaultdict
from enum import Enum

from networkx import DiGraph, all_simple_paths, bfs_edges, is_directed_acyclic_graph, path_graph
@@ -69,6 +70,7 @@
return 0 if isinstance(x, TRM) else sort_key(x)

sorter = (lambda x: sorted(x, key=_sort_key)) if sort_key else None

for _, node in bfs_edges(graph, source, reverse=back, sort_neighbors=sorter):
if not isinstance(node, TRM):
yield node
@@ -101,4 +103,21 @@
tasks in order of callable `sort_key`, which should return a
sortable object given :class:`PipelineTask` as input.
"""
return self._walk(source or TRM.sink, back=True, sort_key=sort_key)
paths = all_simple_paths(self.with_terminals().reverse(), source or TRM.sink, TRM.source)
depths = defaultdict(int)
layers = defaultdict(list)

# unlike bfs_edges/bfs_layers, we order by maximum depth from source, to
# try and ensure we prioritize outputs while also preferring tasks
# further along.
for path in paths:
for i, node in enumerate(path):
if isinstance(node, TRM) or node == source:
continue
depths[node] = max(depths[node], i)
for node, depth in depths.items():
layers[depth].append(node)

# layers will be keyed by maximum distance from source, containing a
# list of nodes at that distance. Yield based on sort key secondarily.
for i in sorted(layers.keys()):
for node in sorted(layers[i], key=sort_key or (lambda x: 0)):
yield node

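For reference, the layering idea behind the new walk_back can be sketched standalone. The snippet below assumes a plain networkx DiGraph and a hypothetical virtual "SINK" node standing in for the TRM terminals; the function name and interface are illustrative only, not part of dplutils. Each node is placed at its maximum simple-path distance from the sink, and layers are yielded nearest-the-sink first, breaking ties with sort_key.

from collections import defaultdict
from networkx import DiGraph, all_simple_paths

def walk_back_by_max_depth(graph, sort_key=str):
    # Reverse the graph and attach a virtual "SINK" node so every real sink is
    # exactly one step from the walk's starting point.
    rev = graph.reverse()
    rev.add_node("SINK")
    for node, out_deg in graph.out_degree():
        if out_deg == 0:
            rev.add_edge("SINK", node)
    sources = [node for node, in_deg in graph.in_degree() if in_deg == 0]
    depths = defaultdict(int)
    for src in sources:
        for path in all_simple_paths(rev, "SINK", src):
            for i, node in enumerate(path):
                if node == "SINK":
                    continue
                # keep the maximum distance from the sink seen on any path
                depths[node] = max(depths[node], i)
    layers = defaultdict(list)
    for node, depth in depths.items():
        layers[depth].append(node)
    # yield layers nearest the sink first, breaking ties with sort_key
    for depth in sorted(layers):
        yield from sorted(layers[depth], key=sort_key)

# The "branchmultiout" shape from the tests: a->b, b->c, b->d, d->e, d->f, f->g.
g = DiGraph([("a", "b"), ("b", "c"), ("b", "d"), ("d", "e"), ("d", "f"), ("f", "g")])
print(list(walk_back_by_max_depth(g)))  # ['c', 'e', 'g', 'f', 'd', 'b', 'a']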
66 changes: 38 additions & 28 deletions dplutils/pipeline/stream.py
@@ -165,9 +165,11 @@
total_length += len(next_df)

def enqueue_tasks(self):
# Work through the graph in reverse order, submitting any tasks as
# needed. Reverse order ensures we prefer to send tasks that are closer
# to the end of the pipeline and only feed as necessary.
# helper to make submission decision of a single task based on the batch
# size, exhaustion conditions, and whether the implementation deems it
# submittable. Returns flags (eligible, submitted) to indicate whether
# it was eligible to be submitted based on input queue and batch size,
# and whether it was actually submitted.
def _handle_one_task(task, rank):
eligible = submitted = False
if len(task.data_in) == 0:
Expand All @@ -179,34 +181,42 @@
self.logger.debug(f"Enqueueing split for <{task.name}>[bs={batch_size}]")
task.split_pending.appendleft(self.split_batch_submit(batch, batch_size))

while len(task.data_in) > 0:
num_to_merge = deque_num_merge(task.data_in, batch_size)
if num_to_merge == 0:
# If the feed is terminated and there are no more tasks that
# will feed to this one, submit everything
if self.source_exhausted and self.task_exhausted(task):
num_to_merge = len(task.data_in)
else:
break
eligible = True
if not self.task_submittable(task.task, rank):
break
merged = [task.data_in.pop().data for i in range(num_to_merge)]
self.logger.debug(f"Enqueueing merged batches <{task.name}>[n={len(merged)};bs={batch_size}]")
task.pending.appendleft(self.task_submit(task.task, merged))
task.counter += 1
submitted = True
num_to_merge = deque_num_merge(task.data_in, batch_size)
Comment from the PR author (Member):

fyi: the change in this block simply removes the for loop, submitting (if eligible) only a single invocation of the task at a time (to enable the fairness re-evaluation). The break -> return change makes the diff look like more than just whitespace.

if num_to_merge == 0:

# If the feed is terminated and there are no more tasks that
# will feed to this one, submit everything
if self.source_exhausted and self.task_exhausted(task):
num_to_merge = len(task.data_in)

else:
return (eligible, submitted)
eligible = True
if not self.task_submittable(task.task, rank):
return (eligible, submitted)


merged = [task.data_in.pop().data for _ in range(num_to_merge)]
self.logger.debug(f"Enqueueing merged batches <{task.name}>[n={len(merged)};bs={batch_size}]")
task.pending.appendleft(self.task_submit(task.task, merged))
task.counter += 1
submitted = True

return (eligible, submitted)

# proceed through all non-source tasks, which will be handled separately
# below due to the need to feed from generator.
rank = 0
for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
if task in self.stream_graph.source_tasks:
continue
eligible, _ = _handle_one_task(task, rank)
if eligible: # update rank of this task if it _could_ be done, whether or not it was
rank += 1
# below due to the need to feed from generator. We walk backwards,
# re-evaluating the sort order of tasks of same depth after each single
# submission, implementing a kind of "fair" submission, while still
# prioritizing tasks closer to the sink.
submitted = True
while submitted:
rank = 0
submitted = False
Comment from the PR author (Member) on lines +208 to +211:

Below this is unchanged. This loop is here to implement fairness: so long as anything has been submitted, we keep walking the graph from the top down but with a new sort order. Once nothing is submitted, either there is no room or we need to source more inputs.

for task in self.stream_graph.walk_back(sort_key=lambda x: x.counter):
if task in self.stream_graph.source_tasks:
continue
eligible, submitted = _handle_one_task(task, rank)
if eligible: # update rank of this task if it _could_ be done, whether or not it was
rank += 1
if submitted:
break


# Source as many inputs as can fit on source tasks. We prioritize flushing the
# input queue and secondarily on number of invocations in case batch sizes differ.
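To make the fairness pattern described in the comments above concrete, here is a small self-contained sketch. The Task class, its ready/counter fields, and enqueue_fairly are hypothetical stand-ins rather than the dplutils scheduler API; only the control flow mirrors the PR: submit at most one batch per pass over the tasks, then re-sort by submission count so tasks at the same depth share capacity instead of one task draining it.

from dataclasses import dataclass

@dataclass
class Task:
    name: str
    depth: int        # distance from the sink; smaller submits first
    ready: int = 0    # batches waiting in the input queue
    counter: int = 0  # submissions so far

def enqueue_fairly(tasks, capacity):
    slots = capacity
    order = []
    submitted = True
    while submitted:
        submitted = False
        # re-sort each pass: nearest the sink first, then fewest submissions
        for task in sorted(tasks, key=lambda t: (t.depth, t.counter)):
            if task.ready == 0 or slots == 0:
                continue
            task.ready -= 1
            task.counter += 1
            slots -= 1
            order.append(task.name)
            submitted = True
            break  # one submission per pass, then re-evaluate the ordering
    return order

# Two sink-adjacent tasks share the 4 slots evenly instead of one taking them all.
tasks = [Task("d", depth=1, ready=4), Task("e", depth=1, ready=4), Task("b", depth=2, ready=4)]
print(enqueue_fairly(tasks, capacity=4))  # ['d', 'e', 'd', 'e']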
19 changes: 19 additions & 0 deletions tests/pipeline/test_pipeline_graph.py
@@ -31,6 +31,7 @@ def graph_suite():
"multisource": make_graph_struct([(a, c), (b, c), (c, d)], [a, b], [d]),
"multisink": make_graph_struct([(a, b), (b, c), (b, d)], [a], [c, d]),
"branchmulti": make_graph_struct([(a, c), (b, c), (c, d), (c, e), (d, f), (e, g)], [a, b], [f, g]),
"branchmultiout": make_graph_struct([(a, b), (b, c), (b, d), (d, e), (d, f), (f, g)], [a], [c, e, g]),
}


@@ -79,6 +80,14 @@ def test_graph_walk_returns_node_list(self, graph_info):
assert walked[-1] in graph_info.sources
assert len(walked) == len(p)

def test_graph_walk_excludes_starting_node(self, graph_info):
p = PipelineGraph(graph_info.edges)
source = graph_info.sinks[0]
walked = list(p.walk_back(source))
assert source not in walked
walked = list(p.walk_fwd(source))
assert source not in walked


def test_graph_walk_with_priority():
test = graph_suite()["branched"]
@@ -92,6 +101,16 @@
assert walked == [p.task_map[i] for i in ["e", "d", "c", "b", "a"]]
walked = list(p.walk_fwd(sort_key=lambda x: -x.func))
assert walked == [p.task_map[i] for i in ["a", "b", "d", "c", "e"]]
# make sure to test with multi output, which can make priority in BFS more
# challenging, specifically in the back direction. Critically below, nodes
# "b" and "d" are both 2 away from the sink at minimum, but "f" is farther
# along so it should take priority, while all sinks should still be
# prioritized
p = PipelineGraph(graph_suite()["branchmultiout"].edges)
walked = list(p.walk_back(sort_key=lambda x: x.func))
assert walked == [p.task_map[i] for i in ["c", "e", "g", "f", "d", "b", "a"]]
walked = list(p.walk_back(sort_key=lambda x: -x.func))
assert walked == [p.task_map[i] for i in ["g", "e", "c", "f", "d", "b", "a"]]


def test_single_node_graph_to_list():
46 changes: 46 additions & 0 deletions tests/pipeline/test_stream_executor.py
@@ -1,3 +1,5 @@
from collections import defaultdict

import pandas as pd
import pytest
from test_suite import PipelineExecutorTestSuite
@@ -77,3 +79,47 @@ def generator():
pl = LocalSerialExecutor(dummy_steps, generator=generator).set_config("task1.batch_size", 1)
res = [i.data for i in pl.run()]
assert len(res) == 8


def test_stream_submission_ordering_evaluation_priority():
# tracking class adds call counts and a parallel submission limit, which allows
# us to locally test that the re-prioritization during submission is working. If
# so, we should expect terminal tasks to have even numbers of calls and to be
# preferred (as opposed to submitting n parallel all to one task, or
# submitting some to upstream tasks).
class MyExec(LocalSerialExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counts = defaultdict(int)
self.parallel_submissions = 4
self.n_parallel = 0

def task_submit(self, task, data):
self.counts[task.name] += 1
self.n_parallel += 1
return super().task_submit(task, data)

def task_submittable(self, t, rank):
if t in self.graph.source_tasks:
return True
return self.n_parallel < self.parallel_submissions

def poll_tasks(self, pending):
self.n_parallel = 0

a = PipelineTask("a", lambda x: x, batch_size=16)
b = a("b", batch_size=1)
(c, d, e) = (b("c"), b("d"), b("e"))
# graph with multiple terminals. The large input batch size ensures we have
# work to submit in parallel to exercise the re-sorting logic.
p = MyExec([(a, b), (b, c), (c, d), (c, e)], max_batches=16)
p_run = p.run()
_ = [next(p_run) for _ in range(4)] # pop number based on parallel submissions
assert p.counts["d"] == p.counts["e"] == 2 # terminals should have even counts
assert p.counts["b"] == p.counts["c"] == 4 # only just enough to submit 4
_ = [next(p_run) for _ in range(4)]
assert p.counts["d"] == p.counts["e"] == 4 # we finish the 4 batch size
assert p.counts["b"] == p.counts["c"] == 4 # but do no more upstream work
_ = [next(p_run) for _ in range(4)]
assert p.counts["d"] == p.counts["e"] == 6 # need to get more work, as above
assert p.counts["b"] == p.counts["c"] == 8 # more upstream for that work