diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py
index 71e9d091..4faaf409 100644
--- a/skrl/agents/jax/base.py
+++ b/skrl/agents/jax/base.py
@@ -452,7 +452,7 @@ def pre_interaction(self, timestep: int, timesteps: int) -> None:
         """
         pass
 
-    def post_interaction(self, timestep: int, timesteps: int) -> None:
+    def post_interaction(self, timestep: int, timesteps: int, save_checkpoints: bool = True) -> None:
         """Callback called after the interaction with the environment
 
         :param timestep: Current timestep
@@ -463,7 +463,7 @@ def post_interaction(self, timestep: int, timesteps: int) -> None:
         timestep += 1
 
         # update best models and write checkpoints
-        if timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval:
+        if save_checkpoints and timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval:
             # update best models
             reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -2 ** 31))
             if reward > self.checkpoint_best_modules["reward"]:
diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py
index 237a0953..7b3a05ee 100644
--- a/skrl/agents/torch/base.py
+++ b/skrl/agents/torch/base.py
@@ -625,7 +625,7 @@ def pre_interaction(self, timestep: int, timesteps: int) -> None:
         """
         pass
 
-    def post_interaction(self, timestep: int, timesteps: int) -> None:
+    def post_interaction(self, timestep: int, timesteps: int, save_checkpoints: bool = True) -> None:
         """Callback called after the interaction with the environment
 
         :param timestep: Current timestep
@@ -636,7 +636,7 @@ def post_interaction(self, timestep: int, timesteps: int) -> None:
         timestep += 1
 
         # update best models and write checkpoints
-        if timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval:
+        if save_checkpoints and timestep > 1 and self.checkpoint_interval > 0 and not timestep % self.checkpoint_interval:
             # update best models
             reward = np.mean(self.tracking_data.get("Reward / Total reward (mean)", -2 ** 31))
             if reward > self.checkpoint_best_modules["reward"]:
diff --git a/skrl/trainers/jax/base.py b/skrl/trainers/jax/base.py
index b542c2e7..4844c167 100644
--- a/skrl/trainers/jax/base.py
+++ b/skrl/trainers/jax/base.py
@@ -243,7 +243,7 @@ def single_agent_eval(self) -> None:
                                           infos=infos,
                                           timestep=timestep,
                                           timesteps=self.timesteps)
-            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
+            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps, save_checkpoints=False)
 
             # reset environments
             if self.env.num_envs > 1:
@@ -362,7 +362,7 @@ def multi_agent_eval(self) -> None:
                                           infos=infos,
                                           timestep=timestep,
                                           timesteps=self.timesteps)
-            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
+            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps, save_checkpoints=False)
 
             # reset environments
             if not self.env.agents:
diff --git a/skrl/trainers/jax/sequential.py b/skrl/trainers/jax/sequential.py
index 6fbb261c..58acc13c 100644
--- a/skrl/trainers/jax/sequential.py
+++ b/skrl/trainers/jax/sequential.py
@@ -185,7 +185,7 @@ def eval(self) -> None:
                                         infos=infos,
                                         timestep=timestep,
                                         timesteps=self.timesteps)
-                super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)
+                super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps, save_checkpoints=False)
 
             # reset environments
             if terminated.any() or truncated.any():
diff --git a/skrl/trainers/jax/step.py b/skrl/trainers/jax/step.py
index ae7e5986..4ca776e2 100644
--- a/skrl/trainers/jax/step.py
+++ b/skrl/trainers/jax/step.py
@@ -245,7 +245,7 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                           infos=infos,
                                           timestep=timestep,
                                           timesteps=timesteps)
-            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps)
+            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps, save_checkpoints=False)
 
         else:
             # write data to TensorBoard
@@ -259,7 +259,7 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                         infos=infos,
                                         timestep=timestep,
                                         timesteps=timesteps)
-                super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)
+                super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps, save_checkpoints=False)
 
         # reset environments
         if terminated.any() or truncated.any():
diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index a13d70a4..c35f990c 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -242,7 +242,7 @@ def single_agent_eval(self) -> None:
                                               infos=infos,
                                               timestep=timestep,
                                               timesteps=self.timesteps)
-                super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
+                super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps, save_checkpoints=False)
 
             # reset environments
             if self.env.num_envs > 1:
@@ -359,7 +359,7 @@ def multi_agent_eval(self) -> None:
                                               infos=infos,
                                               timestep=timestep,
                                               timesteps=self.timesteps)
-                super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)
+                super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps, save_checkpoints=False)
 
             # reset environments
             if not self.env.agents:
diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py
index 68b9b9d8..1831edde 100644
--- a/skrl/trainers/torch/parallel.py
+++ b/skrl/trainers/torch/parallel.py
@@ -99,7 +99,7 @@ def fn_processor(process_index, *args):
                                         infos=queue.get(),
                                         timestep=msg['timestep'],
                                         timesteps=msg['timesteps'])
-                super(type(agent), agent).post_interaction(timestep=msg['timestep'], timesteps=msg['timesteps'])
+                super(type(agent), agent).post_interaction(timestep=msg['timestep'], timesteps=msg['timesteps'], save_checkpoints=False)
                 barrier.wait()
 
diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py
index 49952351..ac40c9af 100644
--- a/skrl/trainers/torch/sequential.py
+++ b/skrl/trainers/torch/sequential.py
@@ -182,7 +182,7 @@ def eval(self) -> None:
                                             infos=infos,
                                             timestep=timestep,
                                             timesteps=self.timesteps)
-                    super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)
+                    super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps, save_checkpoints=False)
 
             # reset environments
             if terminated.any() or truncated.any():
diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py
index c60476f1..f2bed431 100644
--- a/skrl/trainers/torch/step.py
+++ b/skrl/trainers/torch/step.py
@@ -240,7 +240,7 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                           infos=infos,
                                           timestep=timestep,
                                           timesteps=timesteps)
-            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps)
+            super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=timesteps, save_checkpoints=False)
 
         else:
             # write data to TensorBoard
@@ -254,7 +254,7 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
                                         infos=infos,
                                         timestep=timestep,
                                         timesteps=timesteps)
-                super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)
+                super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps, save_checkpoints=False)
 
         # reset environments
         if terminated.any() or truncated.any():
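
Taken together, these hunks thread one flag through every evaluation path: `post_interaction` gains a `save_checkpoints` parameter that defaults to `True` (so training behavior is unchanged), and each trainer's eval loop passes `save_checkpoints=False`, so evaluation rollouts still record tracking data but never update the best-model bookkeeping or write checkpoints. The following is a minimal, self-contained sketch of that control flow; `ToyAgent` and its `saved_at` list are illustrative stand-ins, not skrl's actual classes:

```python
# Sketch of the gating introduced by this diff (illustrative, not skrl code):
# checkpointing only runs when the new `save_checkpoints` flag is True.

class ToyAgent:
    def __init__(self, checkpoint_interval: int = 100) -> None:
        self.checkpoint_interval = checkpoint_interval
        self.saved_at: list = []  # timesteps at which a checkpoint was written

    def post_interaction(self, timestep: int, timesteps: int,
                         save_checkpoints: bool = True) -> None:
        timestep += 1
        # update best models and write checkpoints, gated by the new flag
        if save_checkpoints and timestep > 1 and self.checkpoint_interval > 0 \
                and not timestep % self.checkpoint_interval:
            self.saved_at.append(timestep)


agent = ToyAgent()
for t in range(1000):
    # training-style call: default save_checkpoints=True, checkpoints written
    agent.post_interaction(timestep=t, timesteps=1000)
for t in range(1000):
    # evaluation-style call (what the trainers' eval loops now do): no writes
    agent.post_interaction(timestep=t, timesteps=1000, save_checkpoints=False)

assert agent.saved_at == list(range(100, 1001, 100))  # only training saved
```

Defaulting the flag to `True` keeps the change backward compatible: existing callers of `post_interaction` (including `train` loops) are untouched, and only the eval call sites opt out, which prevents an evaluation run from overwriting the best-reward checkpoint produced during training.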