From 1078fa359136a58eb6a58d1c61eac7171a20f5b6 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 15 Mar 2024 16:41:48 +0100
Subject: [PATCH] Add Learner block for >1 GPUs (only for max-in-flight=1 thus
 far!)

Avoids the error of `results` NOT having the state/weights of worker #1 in
them, due to a possible slight delay.

Signed-off-by: sven1977
---
 rllib/algorithms/algorithm.py       |  4 ++-
 rllib/core/learner/learner_group.py | 42 ++++++++++++-----------------
 2 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index ea81ec47891ed..897aa1b5fe093 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -3228,7 +3228,9 @@ def _compile_iteration_results(
         results.update(results["sampler_results"])
 
         results["num_healthy_workers"] = self.workers.num_healthy_remote_workers()
-        results["num_in_flight_async_reqs"] = self.workers.num_in_flight_async_reqs()
+        results["num_in_flight_async_sample_reqs"] = (
+            self.workers.num_in_flight_async_reqs()
+        )
         results[
             "num_remote_worker_restarts"
         ] = self.workers.num_remote_worker_restarts()
diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py
index 796069feaa1ab..ab785e052e673 100644
--- a/rllib/core/learner/learner_group.py
+++ b/rllib/core/learner/learner_group.py
@@ -69,11 +69,8 @@ class LearnerGroup:
     def __init__(
         self,
         *,
-        config: AlgorithmConfig = None,  # TODO (sven): Make this arg mandatory.
+        config: AlgorithmConfig,
         module_spec: Optional[RLModuleSpec] = None,
-        max_queue_len: int = 20,
-        # Deprecated args.
-        learner_spec=None,
     ):
         """Initializes a LearnerGroup instance.
 
@@ -90,23 +87,7 @@ def __init__(
                 the specifics for your RLModule to be used in each Learner.
             module_spec: If not already specified in `config`, a separate overriding
                 RLModuleSpec may be provided via this argument.
-            max_queue_len: The maximum number of batches to queue up if doing
-                async_update. If the queue is full it will evict the oldest batch first.
         """
-        if learner_spec is not None:
-            deprecation_warning(
-                old="LearnerGroup(learner_spec=...)",
-                new="config = AlgorithmConfig().[resources|training|rl_module](...); "
-                "LearnerGroup(config=config)",
-                error=True,
-            )
-        if config is None:
-            raise ValueError(
-                "LearnerGroup constructor must be called with a `config` arg! "
-                "Pass in a `ray.rllib.algorithms.algorithm_config::AlgorithmConfig` "
-                "object with the proper settings configured."
-            )
-        # scaling_config = learner_spec.learner_group_scaling_config
 
         self.config = config
 
@@ -180,7 +161,9 @@ def __init__(
         # in-flight. Used for keeping track of and grouping together the results of
         # requests that were sent to the workers at the same time.
         self._update_request_tags = Counter()
+        self._update_request_tag = 0
         self._additional_update_request_tags = Counter()
+        self._additional_update_request_tag = 0
 
     def get_stats(self) -> Dict[str, Any]:
         """Returns the current stats for the input queue for this learner group."""
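The two integer counters added above replace the previous `uuid.uuid4()`-based
request tags (see the `_learner_update()` hunk below): with at most one update
request round in flight, a monotonically increasing integer is enough to
identify a round, and unlike random UUIDs it preserves send order. A minimal
sketch of that tag scheme (illustrative code only, not part of the patch):

    # Integer request tags: one fresh tag per round of async update requests.
    from collections import Counter

    update_request_tags = Counter()  # tag -> number of requests still in flight
    update_request_tag = 0           # the next tag to hand out

    def next_tag(num_requests: int) -> str:
        global update_request_tag
        tag = update_request_tag
        update_request_tag += 1
        update_request_tags[tag] = num_requests  # remember what was sent under it
        return str(tag)  # the patch converts tags via str() before sending

    assert next_tag(2) == "0" and next_tag(2) == "1"  # ordered, unlike uuid4()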
@@ -417,14 +400,23 @@ def _learner_update(
         # Retrieve all ready results (kicked off by prior calls to this method).
         results = None
         if self._update_request_tags:
-            results = self._worker_manager.fetch_ready_async_reqs(
-                tags=list(self._update_request_tags)
-            )
+            # Only one in-flight request is possible right now (max-in-flight=1).
+            assert len(self._update_request_tags) == 1
+            for tag in self._update_request_tags:
+                results = self._worker_manager.fetch_ready_async_reqs(
+                    tags=[str(tag)], timeout_seconds=None
+                )
+                # TODO (sven): Handle results still missing after this fetch, e.g.:
+                # if len(results.result_or_errors) < len(self._workers):
+                #     more_results = self._worker_manager.fetch_ready_async_reqs(
+                #         tags=[str(tag)], timeout_seconds=None
+                #     )
 
-        update_tag = str(uuid.uuid4())
+        update_tag = self._update_request_tag
+        self._update_request_tag += 1
 
         num_sent_requests = self._worker_manager.foreach_actor_async(
-            partials, tag=update_tag
+            partials, tag=str(update_tag)
         )
 
         if num_sent_requests:
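With max-in-flight=1, the blocking fetch above (`timeout_seconds=None`) waits
until every Learner worker's result for the one outstanding tag has arrived
before the next round of requests goes out; that is what keeps worker #1's
state/weights from being absent from `results` when it reports back with a
slight delay. Below is a simplified, self-contained sketch of this control
flow; `FakeManager`, `FakeResults`, and `learner_update_step` are stand-ins
for illustration, not RLlib API:

    # Stand-ins so the sketch runs on its own (assumptions, not RLlib classes).
    class FakeResults:
        def __init__(self, results):
            self.result_or_errors = results

    class FakeManager:
        def __init__(self, num_workers):
            self.num_workers = num_workers
            self._pending = {}

        def foreach_actor_async(self, partials, tag):
            # Pretend every worker accepted the request under this tag.
            self._pending[tag] = [f"worker-{i}" for i in range(self.num_workers)]
            return self.num_workers

        def fetch_ready_async_reqs(self, tags, timeout_seconds=None):
            # timeout_seconds=None: block until all requests under `tags` are done.
            out = []
            for t in tags:
                out.extend(self._pending.pop(t, []))
            return FakeResults(out)

    def learner_update_step(manager, partials, state):
        results = None
        if state["in_flight"]:
            (tag,) = state["in_flight"]  # max-in-flight=1 -> exactly one tag
            # Blocking fetch: all workers' results for `tag` come back together,
            # so no worker's state/weights can be missing from `results`.
            results = manager.fetch_ready_async_reqs(
                tags=[str(tag)], timeout_seconds=None
            )
            state["in_flight"].remove(tag)
        tag = state["next_tag"]  # fresh, monotonically increasing tag
        state["next_tag"] += 1
        if manager.foreach_actor_async(partials, tag=str(tag)):
            state["in_flight"].add(tag)
        return results

    manager = FakeManager(num_workers=2)
    state = {"in_flight": set(), "next_tag": 0}
    learner_update_step(manager, None, state)      # round 0: nothing to fetch yet
    r = learner_update_step(manager, None, state)  # round 1: fetches round 0's results
    assert len(r.result_or_errors) == 2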