
Commit

add learner block for >1 GPUs (only for max-in-flight=1 thus far!)
Avoids the error of the results NOT containing the state/weights of worker #1, due to a possible slight delay in that worker's async request returning.

Signed-off-by: sven1977 <svenmika1977@gmail.com>
sven1977 committed Mar 15, 2024
1 parent b02d34d commit 1078fa3
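
The failure mode described above can be illustrated with plain Ray primitives (a hypothetical sketch, not RLlib code; learner_update and the two-worker setup are made up for illustration): a non-blocking poll over async results may come back without the result of a worker that is a moment late, whereas a blocking wait only returns once all workers have reported back.

    import time

    import ray

    ray.init()

    @ray.remote
    def learner_update(worker_id: int) -> int:
        # Worker #1 is artificially a bit slower, mimicking the "possible slight delay".
        time.sleep(0.5 if worker_id == 1 else 0.0)
        return worker_id

    refs = [learner_update.remote(i) for i in range(2)]
    time.sleep(0.1)  # Give the fast worker time to finish.

    # Non-blocking poll: may return with only worker #0's result and silently miss #1.
    ready, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
    print(len(ready))  # typically 1

    # Blocking wait (analogous to fetch_ready_async_reqs(..., timeout_seconds=None)):
    # returns only once the results of all workers are available.
    ready, _ = ray.wait(refs, num_returns=len(refs), timeout=None)
    print(len(ready))  # 2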
Showing 2 changed files with 20 additions and 26 deletions.
4 changes: 3 additions & 1 deletion rllib/algorithms/algorithm.py
@@ -3228,7 +3228,9 @@ def _compile_iteration_results(
         results.update(results["sampler_results"])

         results["num_healthy_workers"] = self.workers.num_healthy_remote_workers()
-        results["num_in_flight_async_reqs"] = self.workers.num_in_flight_async_reqs()
+        results["num_in_flight_async_sample_reqs"] = (
+            self.workers.num_in_flight_async_reqs()
+        )
         results[
             "num_remote_worker_restarts"
         ] = self.workers.num_remote_worker_restarts()
42 changes: 17 additions & 25 deletions rllib/core/learner/learner_group.py
@@ -69,11 +69,8 @@ class LearnerGroup:
     def __init__(
         self,
         *,
-        config: AlgorithmConfig = None,  # TODO (sven): Make this arg mandatory.
+        config: AlgorithmConfig,
         module_spec: Optional[RLModuleSpec] = None,
-        max_queue_len: int = 20,
-        # Deprecated args.
-        learner_spec=None,
     ):
         """Initializes a LearnerGroup instance.
@@ -90,23 +87,7 @@ def __init__(
                 the specifics for your RLModule to be used in each Learner.
             module_spec: If not already specified in `config`, a separate overriding
                 RLModuleSpec may be provided via this argument.
-            max_queue_len: The maximum number of batches to queue up if doing
-                async_update. If the queue is full it will evict the oldest batch first.
         """
-        if learner_spec is not None:
-            deprecation_warning(
-                old="LearnerGroup(learner_spec=...)",
-                new="config = AlgorithmConfig().[resources|training|rl_module](...); "
-                "LearnerGroup(config=config)",
-                error=True,
-            )
-        if config is None:
-            raise ValueError(
-                "LearnerGroup constructor must be called with a `config` arg! "
-                "Pass in a `ray.rllib.algorithms.algorithm_config::AlgorithmConfig` "
-                "object with the proper settings configured."
-            )

-        # scaling_config = learner_spec.learner_group_scaling_config
         self.config = config

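With the deprecated learner_spec path removed, a LearnerGroup is now configured exclusively through an AlgorithmConfig, as the old deprecation message already suggested. A rough usage sketch, assuming the config carries everything the Learner workers need (the exact builder calls, e.g. the new-API-stack flag and the learner scaling options, may differ between RLlib versions):

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.core.learner.learner_group import LearnerGroup

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        # Assumption: new-stack flag and learner scaling settings of this commit's era.
        .experimental(_enable_new_api_stack=True)
        .resources(num_learner_workers=2, num_gpus_per_learner_worker=1)
    )

    # Old (now a hard deprecation error): LearnerGroup(learner_spec=...)
    # New: pass the AlgorithmConfig; `module_spec` can still override the config's spec.
    learner_group = LearnerGroup(config=config)
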
@@ -180,7 +161,9 @@ def __init__(
         # in-flight. Used for keeping track of and grouping together the results of
         # requests that were sent to the workers at the same time.
         self._update_request_tags = Counter()
+        self._update_request_tag = 0
         self._additional_update_request_tags = Counter()
+        self._additional_update_request_tag = 0

     def get_stats(self) -> Dict[str, Any]:
         """Returns the current stats for the input queue for this learner group."""
@@ -417,14 +400,23 @@ def _learner_update(
         # Retrieve all ready results (kicked off by prior calls to this method).
         results = None
         if self._update_request_tags:
-            results = self._worker_manager.fetch_ready_async_reqs(
-                tags=list(self._update_request_tags)
-            )
+            assert len(self._update_request_tags) == 1  # only 1 in-flight right now possible
+            for tag in self._update_request_tags:
+                results = self._worker_manager.fetch_ready_async_reqs(
+                    tags=[str(tag)], timeout_seconds=None
+                )
+                # if tag+1 not in self._update_request_tags and len(results.result_or_errors) < len(self._workers):
+                # if len(results.result_or_errors) < len(self._workers):
+                #     more_results = self._worker_manager.fetch_ready_async_reqs(
+                #         tags=[str(tag)], timeout_seconds=None
+                #     )
+                #     results.add_result()

-        update_tag = str(uuid.uuid4())
+        update_tag = self._update_request_tag
+        self._update_request_tag += 1

         num_sent_requests = self._worker_manager.foreach_actor_async(
-            partials, tag=update_tag
+            partials, tag=str(update_tag)
         )

         if num_sent_requests:
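Putting the two hunks above together, the new async-update bookkeeping reads roughly as follows (an illustrative condensation of the diff, not a verbatim copy of the LearnerGroup code; how self._update_request_tags is populated after sending is not shown in this commit and is assumed):

    # Sequential integer tags replace the previous uuid4 tags. With max-in-flight=1
    # there is at most one outstanding tag, and its results are fetched blocking
    # (timeout_seconds=None), so no learner worker's state/weights can be missing.
    results = None
    if self._update_request_tags:
        assert len(self._update_request_tags) == 1
        for tag in self._update_request_tags:
            results = self._worker_manager.fetch_ready_async_reqs(
                tags=[str(tag)], timeout_seconds=None
            )

    # Send the next round of update requests under the next sequential tag.
    update_tag = self._update_request_tag
    self._update_request_tag += 1
    num_sent_requests = self._worker_manager.foreach_actor_async(
        partials, tag=str(update_tag)
    )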
