Skip to content

Commit

Permalink
include models in context to build the response model to a sync call
Browse files Browse the repository at this point in the history
  • Loading branch information
Dacksus committed Dec 13, 2024
1 parent 006f530 commit f617980
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ def _build_context(self) -> InternalContext:
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
models=self._models,
)

def _initialize_wallet_and_identity(self, seed, name, wallet_key_derivation_index):
Expand Down
18 changes: 15 additions & 3 deletions python/src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def __init__(
interval_messages: Optional[Set[str]] = None,
wallet_messaging_client: Optional[Any] = None,
logger: Optional[logging.Logger] = None,
models: Optional[Dict[str, Type[Model]]] = None,
):
self._agent = agent
self._storage = storage
Expand All @@ -272,6 +273,7 @@ def __init__(
self._interval_messages = interval_messages
self._wallet_messaging_client = wallet_messaging_client
self._outbound_messages: Dict[str, Tuple[JsonStr, str]] = {}
self._models = models

@property
def agent(self) -> AgentRepresentation:
Expand Down Expand Up @@ -404,7 +406,7 @@ async def send(
message: Model,
sync: bool = False,
timeout: int = DEFAULT_ENVELOPE_TIMEOUT_SECONDS,
) -> Union[MsgStatus, Envelope]:
) -> Union[MsgStatus, Model]:
"""
This is the pro-active send method which is used in on_event and
on_interval methods. In these methods, interval messages are set but
Expand Down Expand Up @@ -441,7 +443,7 @@ async def send_raw(
timeout: int = DEFAULT_ENVELOPE_TIMEOUT_SECONDS,
protocol_digest: Optional[str] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
) -> Union[MsgStatus, Envelope]:
) -> Union[MsgStatus, Model]:
# Extract address from destination agent identifier if present
_, parsed_name, parsed_address = parse_identifier(destination)

Expand Down Expand Up @@ -521,6 +523,16 @@ async def send_raw(
session=self._session,
)

if isinstance(result, Envelope):
model_class: Optional[Type[Model]] = self._models.get(result.schema_digest)
if model_class is None:
log(self.logger, logging.DEBUG, "unexpected sync reply")
else:
try:
result = model_class.parse_raw(result.decode_payload())
except Exception as ex:
log(self.logger, logging.ERROR, f"Unable to parse message: {ex}")

return result

def _queue_envelope(
Expand Down Expand Up @@ -623,7 +635,7 @@ async def send(
message: Model,
sync: bool = False,
timeout: int = DEFAULT_ENVELOPE_TIMEOUT_SECONDS,
) -> Union[MsgStatus, Envelope]:
) -> Union[MsgStatus, Model]:
"""
Send a message to the specified destination.
Expand Down

0 comments on commit f617980

Please sign in to comment.