Skip to content

Commit

Permalink
Keyword-based agent wrapper and name
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed May 21, 2024
1 parent e7a7d9f commit 8b81169
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/bbrl/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .agent import Agent, TimeAgent, SerializableAgent
from .agent import Agent, TimeAgent, SerializableAgent, KWAgentWrapper
from .dataloader import DataLoaderAgent, ShuffledDatasetAgent
from .remote import NRemoteAgent, RemoteAgent
from .utils import Agents, CopyTAgent, PrintAgent, TemporalAgent, EpisodesDone
Expand Down
28 changes: 28 additions & 0 deletions src/bbrl/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def set_name(self, n):
n (str): The name
"""
self._name = n
return self

def get_name(self):
"""Returns the name of the agent
Expand All @@ -55,6 +56,14 @@ def get_name(self):
"""
return self._name

@property
def name(self) -> str:
return self._name

@name.setter
def name(self, name: str):
self._name = name

def with_prefix(self, prefix: str):
"""Returns the prefix in environments"""
self.prefix = prefix
Expand Down Expand Up @@ -174,6 +183,25 @@ def load_model(self, filename) -> nn.Module:
return torch.load(filename)


class KWAgentWrapper(Agent, ABC):
"""A wrapper that calls the agent with some specific keyword parameters"""

def __init__(self, agent: Agent, **kwargs):
"""Creates a new keyword-based wrapper
:param agent: The agent to be wrapped
"""
super().__init__()
self.wrapped = agent
self.kwargs = kwargs

def forward(self, t: int, *args, **kwargs) -> Any:
self.wrapped.workspace = self.workspace
r = self.wrapped.forward(t, *args, **kwargs, **self.kwargs)
self.wrapped.workspace = None
return r


class TimeAgent(Agent, ABC):
"""
`TimeAgent` is used as a convention to represent agents that
Expand Down

0 comments on commit 8b81169

Please sign in to comment.