Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): consistently generated session ids #531

Merged
merged 12 commits into from
Oct 8, 2024
11 changes: 11 additions & 0 deletions python/docs/api/uagents/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ An agent that interacts within a communication environment.
corresponding protocols.
- `_ctx` _Context_ - The context for agent interactions.
- `_test` _bool_ - True if the agent will register and transact on the testnet.
- `_enable_agent_inspector` _bool_ - Enable the agent inspector REST endpoints.

Properties:
- `name` _str_ - The name of the agent.
Expand Down Expand Up @@ -732,6 +733,16 @@ def start_message_receivers()

Start message receiving tasks for the agent.

<a id="src.uagents.agent.Agent.start_server"></a>

#### start`_`server

```python
async def start_server()
```

Start the agent's server.

<a id="src.uagents.agent.Agent.run_async"></a>

#### run`_`async
Expand Down
8 changes: 3 additions & 5 deletions python/docs/api/uagents/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ agent (AgentRepresentation): The agent representation associated with the contex
storage (KeyValueStore): The key-value store for storage operations.
ledger (LedgerClient): The client for interacting with the blockchain ledger.
logger (logging.Logger): The logger instance.
session (uuid.UUID): The session UUID associated with the context.

**Methods**:

Expand Down Expand Up @@ -101,7 +102,7 @@ Get the logger instance associated with the context.
```python
@property
@abstractmethod
def session() -> Union[uuid.UUID, None]
def session() -> uuid.UUID
```

Get the session UUID associated with the context.
Expand Down Expand Up @@ -267,7 +268,7 @@ Represents the agent internal context for proactive behaviour.

```python
@property
def session() -> Union[uuid.UUID, None]
def session() -> uuid.UUID
```

Get the session UUID associated with the context.
Expand Down Expand Up @@ -339,7 +340,6 @@ Represents the reactive context in which messages are handled and processed.

- `_queries` _Dict[str, asyncio.Future]_ - Dictionary mapping query senders to their
response Futures.
- `_session` _Optional[uuid.UUID]_ - The session UUID.
- `_replies` _Optional[Dict[str, Dict[str, Type[Model]]]]_ - Dictionary of allowed reply digests
for each type of incoming message.
- `_message_received` _Optional[MsgDigest]_ - The message digest received.
Expand All @@ -352,7 +352,6 @@ Represents the reactive context in which messages are handled and processed.

```python
def __init__(message_received: MsgDigest,
session: Optional[uuid.UUID] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None,
protocol: Optional[Tuple[str, Protocol]] = None,
Expand All @@ -366,7 +365,6 @@ Initialize the ExternalContext instance and attributes needed from the InternalC
- `message_received` _MsgDigest_ - The optional message digest received.
- `queries` _Dict[str, asyncio.Future]_ - Dictionary mapping query senders to their
response Futures.
- `session` _Optional[uuid.UUID]_ - The optional session UUID.
- `replies` _Optional[Dict[str, Dict[str, Type[Model]]]]_ - Dictionary of allowed replies
for each type of incoming message.
- `protocol` _Optional[Tuple[str, Protocol]]_ - The optional Tuple of protocols.
Expand Down
86 changes: 59 additions & 27 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
parse_agentverse_config,
parse_endpoint_config,
)
from uagents.context import Context, ExternalContext, InternalContext
from uagents.context import Context, ContextFactory, ExternalContext, InternalContext
from uagents.crypto import Identity, derive_key_from_seed, is_user_address
from uagents.dispatch import Sink, dispatcher
from uagents.envelope import EnvelopeHistory, EnvelopeHistoryEntry
Expand Down Expand Up @@ -71,24 +71,31 @@
from uagents.utils import get_logger


async def _run_interval(func: IntervalCallback, ctx: Context, period: float):
async def _run_interval(
func: IntervalCallback,
logger: logging.Logger,
context_factory: ContextFactory,
period: float,
):
"""
Run the provided interval callback function at a specified period.

Args:
func (IntervalCallback): The interval callback function to run.
ctx (Context): The context for the agent.
logger (logging.Logger): The logger instance for logging interval handler activities.
context_factory (ContextFactory): The factory function for creating the context.
period (float): The time period at which to run the callback function.
"""
while True:
try:
ctx = context_factory()
await func(ctx)
except OSError as ex:
ctx.logger.exception(f"OS Error in interval handler: {ex}")
logger.exception(f"OS Error in interval handler: {ex}")
except RuntimeError as ex:
ctx.logger.exception(f"Runtime Error in interval handler: {ex}")
logger.exception(f"Runtime Error in interval handler: {ex}")
except Exception as ex:
ctx.logger.exception(f"Exception in interval handler: {ex}")
logger.exception(f"Exception in interval handler: {ex}")

await asyncio.sleep(period)

Expand Down Expand Up @@ -377,21 +384,6 @@ def __init__(
# keep track of supported protocols
self.protocols: Dict[str, Protocol] = {}

self._ctx = InternalContext(
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
dispenser=self._dispenser,
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
)

# register with the dispatcher
self._dispatcher.register(self.address, self)

Expand Down Expand Up @@ -426,6 +418,28 @@ async def _handle_get_messages(_ctx: Context):

self._init_done = True

def _build_context(self) -> InternalContext:
"""
An internal method to build the context for the agent.

Returns:
InternalContext: The internal context for the agent.
"""
return InternalContext(
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
dispenser=self._dispenser,
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
)

def _initialize_wallet_and_identity(self, seed, name, wallet_key_derivation_index):
"""
Initialize the wallet and identity for the agent.
Expand Down Expand Up @@ -997,7 +1011,10 @@ async def handle_rest(
if not handler:
return None

args = (self._ctx, message) if message else (self._ctx,)
args = []
args.append(self._build_context())
if message:
args.append(message)

return await handler(*args) # type: ignore

Expand All @@ -1015,7 +1032,8 @@ async def _startup(self):
)
for handler in self._on_startup:
try:
await handler(self._ctx)
ctx = self._build_context()
await handler(ctx)
except OSError as ex:
self._logger.exception(f"OS Error in startup handler: {ex}")
except RuntimeError as ex:
Expand All @@ -1030,7 +1048,8 @@ async def _shutdown(self):
"""
for handler in self._on_shutdown:
try:
await handler(self._ctx)
ctx = self._build_context()
await handler(ctx)
except OSError as ex:
self._logger.exception(f"OS Error in shutdown handler: {ex}")
except RuntimeError as ex:
Expand Down Expand Up @@ -1061,7 +1080,9 @@ def start_interval_tasks(self):

"""
for func, period in self._interval_handlers:
self._loop.create_task(_run_interval(func, self._ctx, period))
self._loop.create_task(
_run_interval(func, self._logger, self._build_context, period)
)

def start_message_receivers(self):
"""
Expand All @@ -1075,7 +1096,9 @@ def start_message_receivers(self):
if self._wallet_messaging_client is not None:
for task in [
self._wallet_messaging_client.poll_server(),
self._wallet_messaging_client.process_message_queue(self._ctx),
self._wallet_messaging_client.process_message_queue(
self._build_context
),
]:
self._loop.create_task(task)

Expand Down Expand Up @@ -1163,7 +1186,11 @@ async def _process_message_queue(self):
)

context = ExternalContext(
agent=self._ctx.agent,
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
Expand All @@ -1179,6 +1206,11 @@ async def _process_message_queue(self):
protocol=protocol_info,
)

# sanity check
assert (
context.session == session
), "Context object should always have message session"

# parse the received message
try:
recovered = model_class.parse_raw(message)
Expand Down
19 changes: 9 additions & 10 deletions python/src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Expand Down Expand Up @@ -59,6 +60,7 @@ class Context(ABC):
storage (KeyValueStore): The key-value store for storage operations.
ledger (LedgerClient): The client for interacting with the blockchain ledger.
logger (logging.Logger): The logger instance.
session (uuid.UUID): The session UUID associated with the context.

Methods:
get_agents_by_protocol(protocol_digest, limit, logger): Retrieve a list of agent addresses
Expand Down Expand Up @@ -116,7 +118,7 @@ def logger(self) -> logging.Logger:

@property
@abstractmethod
def session(self) -> Union[uuid.UUID, None]:
def session(self) -> uuid.UUID:
"""
Get the session UUID associated with the context.

Expand Down Expand Up @@ -256,6 +258,7 @@ def __init__(
ledger: LedgerClient,
resolver: Resolver,
dispenser: Dispenser,
session: Optional[uuid.UUID] = None,
interval_messages: Optional[Set[str]] = None,
wallet_messaging_client: Optional[Any] = None,
logger: Optional[logging.Logger] = None,
Expand All @@ -266,7 +269,7 @@ def __init__(
self._resolver = resolver
self._dispenser = dispenser
self._logger = logger
self._session: Optional[uuid.UUID] = None
self._session = session or uuid.uuid4()
self._interval_messages = interval_messages
self._wallet_messaging_client = wallet_messaging_client
self._outbound_messages: Dict[str, Tuple[JsonStr, str]] = {}
Expand All @@ -288,7 +291,7 @@ def logger(self) -> Union[logging.Logger, None]:
return self._logger

@property
def session(self) -> Union[uuid.UUID, None]:
def session(self) -> uuid.UUID:
"""
Get the session UUID associated with the context.

Expand Down Expand Up @@ -408,7 +411,6 @@ async def send(
we don't have access properties that are only necessary in re-active
contexts, like 'replies', 'message_received', or 'protocol'.
"""
self._session = None
schema_digest = Model.build_schema_digest(message)
message_body = message.model_dump_json()

Expand Down Expand Up @@ -440,8 +442,6 @@ async def send_raw(
protocol_digest: Optional[str] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
) -> MsgStatus:
self._session = self._session or uuid.uuid4()

# Extract address from destination agent identifier if present
_, parsed_name, parsed_address = parse_identifier(destination)

Expand Down Expand Up @@ -564,7 +564,6 @@ class ExternalContext(InternalContext):
Attributes:
_queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their
response Futures.
_session (Optional[uuid.UUID]): The session UUID.
_replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed reply digests
for each type of incoming message.
_message_received (Optional[MsgDigest]): The message digest received.
Expand All @@ -575,7 +574,6 @@ class ExternalContext(InternalContext):
def __init__(
self,
message_received: MsgDigest,
session: Optional[uuid.UUID] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None,
protocol: Optional[Tuple[str, Protocol]] = None,
Expand All @@ -588,13 +586,11 @@ def __init__(
message_received (MsgDigest): The optional message digest received.
queries (Dict[str, asyncio.Future]): Dictionary mapping query senders to their
response Futures.
session (Optional[uuid.UUID]): The optional session UUID.
replies (Optional[Dict[str, Dict[str, Type[Model]]]]): Dictionary of allowed replies
for each type of incoming message.
protocol (Optional[Tuple[str, Protocol]]): The optional Tuple of protocols.
"""
super().__init__(**kwargs)
self._session = session or None
self._queries = queries or {}
self._replies = replies
self._message_received = message_received
Expand Down Expand Up @@ -674,3 +670,6 @@ async def send(
protocol_digest=self._protocol[0],
queries=self._queries,
)


ContextFactory = Callable[[], Context]
11 changes: 1 addition & 10 deletions python/src/uagents/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from uagents.crypto import Identity
from uagents.network import AlmanacContract, InsufficientFundsError, add_testnet_funds
from uagents.types import AgentEndpoint
from uagents.types import AgentEndpoint, AgentGeoLocation


class AgentRegistrationPolicy(ABC):
Expand All @@ -32,15 +32,6 @@ async def register(
pass


class AgentGeoLocation(BaseModel):
# Latitude and longitude of the agent
latitude: float
longitude: float

# Radius around the agent location, expressed in meters
radius: float


class AgentRegistrationAttestation(BaseModel):
agent_address: str
protocols: List[str]
Expand Down
9 changes: 9 additions & 0 deletions python/src/uagents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ class AgentEndpoint(BaseModel):
weight: int


class AgentGeoLocation(BaseModel):
# Latitude and longitude of the agent
latitude: float
longitude: float

# Radius around the agent location, expressed in meters
radius: float


class AgentInfo(BaseModel):
agent_address: str
endpoints: List[AgentEndpoint]
Expand Down
Loading
Loading