diff --git a/src/bbrl/agents/__init__.py b/src/bbrl/agents/__init__.py index 20d03cf..682138c 100644 --- a/src/bbrl/agents/__init__.py +++ b/src/bbrl/agents/__init__.py @@ -1,4 +1,4 @@ -from .agent import Agent, TimeAgent, SerializableAgent +from .agent import Agent, TimeAgent, SerializableAgent, KWAgentWrapper from .dataloader import DataLoaderAgent, ShuffledDatasetAgent from .remote import NRemoteAgent, RemoteAgent from .utils import Agents, CopyTAgent, PrintAgent, TemporalAgent, EpisodesDone diff --git a/src/bbrl/agents/agent.py b/src/bbrl/agents/agent.py index e912bb7..21932c6 100644 --- a/src/bbrl/agents/agent.py +++ b/src/bbrl/agents/agent.py @@ -46,6 +46,7 @@ def set_name(self, n): n (str): The name """ self._name = n + return self def get_name(self): """Returns the name of the agent @@ -55,6 +56,14 @@ def get_name(self): """ return self._name + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, name: str): + self._name = name + def with_prefix(self, prefix: str): """Returns the prefix in environments""" self.prefix = prefix @@ -174,6 +183,25 @@ def load_model(self, filename) -> nn.Module: return torch.load(filename) +class KWAgentWrapper(Agent, ABC): + """A wrapper that calls the agent with some specific keyword parameters""" + + def __init__(self, agent: Agent, **kwargs): + """Creates a new keyword-based wrapper + + :param agent: The agent to be wrapped + """ + super().__init__() + self.wrapped = agent + self.kwargs = kwargs + + def forward(self, t: int, *args, **kwargs) -> Any: + self.wrapped.workspace = self.workspace + r = self.wrapped.forward(t, *args, **kwargs, **self.kwargs) + self.wrapped.workspace = None + return r + + class TimeAgent(Agent, ABC): """ `TimeAgent` is used as a convention to represent agents that