From 207bcaeb10a130bb1881348ee2352b52f1412022 Mon Sep 17 00:00:00 2001 From: Ming Zhou Date: Fri, 24 Nov 2023 20:20:41 +0800 Subject: [PATCH] tmp save --- malib/learner/manager.py | 3 +++ malib/models/config.py | 1 - malib/rl/config.py | 1 - malib/rollout/config.py | 1 - malib/rollout/envs/mdp/env.py | 1 - malib/scenarios/sarl_scenario.py | 2 +- tests/rollout/test_pb_rollout_worker.py | 1 - 7 files changed, 4 insertions(+), 6 deletions(-) diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 86e784f..79ccbbd 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -158,6 +158,9 @@ def __init__( self._thread_pool = ThreadPoolExecutor(max_workers=len(learners)) self._stopping_conditions = stopping_conditions + # init strategy spec + self.add_policies() + Logger.info( f"training manager launched, {len(self._learners)} learner(s) created" ) diff --git a/malib/models/config.py b/malib/models/config.py index bb1a210..0683b93 100644 --- a/malib/models/config.py +++ b/malib/models/config.py @@ -5,7 +5,6 @@ @dataclass class ModelConfig: - model_cls: Type model_args: Dict[str, Any] diff --git a/malib/rl/config.py b/malib/rl/config.py index 5935b99..552543d 100644 --- a/malib/rl/config.py +++ b/malib/rl/config.py @@ -8,7 +8,6 @@ @dataclass class Algorithm: - policy: Type[Policy] trainer: Type[Trainer] diff --git a/malib/rollout/config.py b/malib/rollout/config.py index 4e462d6..f576c52 100644 --- a/malib/rollout/config.py +++ b/malib/rollout/config.py @@ -5,7 +5,6 @@ @dataclass class RolloutConfig: - num_workers: int = 1 """Defines how many workers will be used for executing one rollout task, default is 1""" diff --git a/malib/rollout/envs/mdp/env.py b/malib/rollout/envs/mdp/env.py index e97c99e..ce96515 100644 --- a/malib/rollout/envs/mdp/env.py +++ b/malib/rollout/envs/mdp/env.py @@ -9,7 +9,6 @@ class MDPEnvironment(Environment): def __init__(self, **configs): - try: from blackhc import mdp from blackhc.mdp import example as mdp_examples diff 
--git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 0cc3659..07bc281 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -67,7 +67,7 @@ def create_global_stopper(self) -> StoppingCondition: def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): - # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input + # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input learner_manager = LearnerManager( stopping_conditions=scenario.stopping_conditions, algorithm=scenario.algorithm, diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index 9d55103..baf1c33 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -93,7 +93,6 @@ def multiagent_post_process( class FakeFeatureHandler(BaseFeature): - pass