diff --git a/agents-api/agents_api/activities/sync_items_remote.py b/agents-api/agents_api/activities/sync_items_remote.py
new file mode 100644
index 000000000..404fac496
--- /dev/null
+++ b/agents-api/agents_api/activities/sync_items_remote.py
@@ -0,0 +1,24 @@
+from typing import Any
+
+from beartype import beartype
+from temporalio import activity
+
+from ..common.protocol.remote import RemoteObject
+
+
+@beartype
+async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
+    from ..common.storage_handler import store_in_blob_store_if_large
+
+    return [store_in_blob_store_if_large(input) for input in inputs]
+
+
+@beartype
+async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
+    from ..common.storage_handler import load_from_blob_store_if_remote
+
+    return [load_from_blob_store_if_remote(input) for input in inputs]
+
+
+save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)
+load_inputs_remote = activity.defn(name="load_inputs_remote")(load_inputs_remote_fn)
diff --git a/agents-api/agents_api/clients/s3.py b/agents-api/agents_api/clients/s3.py
index 749f53245..e3f8157c3 100644
--- a/agents-api/agents_api/clients/s3.py
+++ b/agents-api/agents_api/clients/s3.py
@@ -71,7 +71,7 @@ def add_object(key: str, body: bytes, replace: bool = False) -> None:
     client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)
 
 
-@lru_cache(maxsize=256 * 1024 // blob_store_cutoff_kb)  # 256mb in cache
+@lru_cache(maxsize=256 * 1024 // max(1, blob_store_cutoff_kb))  # 256mb in cache
 @beartype
 def get_object(key: str) -> bytes:
     client = get_s3_client()
diff --git a/agents-api/agents_api/common/exceptions/tasks.py b/agents-api/agents_api/common/exceptions/tasks.py
index 19bb5b5ae..86fbaa65d 100644
--- a/agents-api/agents_api/common/exceptions/tasks.py
+++ b/agents-api/agents_api/common/exceptions/tasks.py
@@ -121,7 +121,12 @@ def is_non_retryable_error(error: BaseException) -> bool:
 
     # Check for specific HTTP errors (status code == 429)
     if isinstance(error, httpx.HTTPStatusError):
-        if error.response.status_code in (408, 429, 503, 504):
+        if error.response.status_code in (
+            408,
+            429,
+            503,
+            504,
+        ):  # pytype: disable=attribute-error
             return False
 
     # If we don't know about the error, we should not retry
diff --git a/agents-api/agents_api/common/protocol/remote.py b/agents-api/agents_api/common/protocol/remote.py
new file mode 100644
index 000000000..6ec3e3a5a
--- /dev/null
+++ b/agents-api/agents_api/common/protocol/remote.py
@@ -0,0 +1,236 @@
+from dataclasses import dataclass
+from typing import Any, Iterator
+
+from temporalio import activity, workflow
+
+with workflow.unsafe.imports_passed_through():
+    from pydantic import BaseModel
+
+    from ...env import blob_store_bucket
+
+
+@dataclass
+class RemoteObject:
+    key: str
+    bucket: str = blob_store_bucket
+
+
+class BaseRemoteModel(BaseModel):
+    _remote_cache: dict[str, Any]
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def __init__(self, **data: Any):
+        super().__init__(**data)
+        self._remote_cache = {}
+
+    def __load_item(self, item: Any | RemoteObject) -> Any:
+        if not activity.in_activity():
+            return item
+
+        from ..storage_handler import load_from_blob_store_if_remote
+
+        return load_from_blob_store_if_remote(item)
+
+    def __save_item(self, item: Any) -> Any:
+        if not activity.in_activity():
+            return item
+
+        from ..storage_handler import store_in_blob_store_if_large
+
+        return store_in_blob_store_if_large(item)
+
+    def __getattribute__(self, name: str) -> Any:
+        if name.startswith("_"):
+            return super().__getattribute__(name)
+
+        try:
+            value = super().__getattribute__(name)
+        except AttributeError:
+            raise AttributeError(
+                f"'{type(self).__name__}' object has no attribute '{name}'"
+            )
+
+        if isinstance(value, RemoteObject):
+            cache = super().__getattribute__("_remote_cache")
+            if name in cache:
+                return cache[name]
+
+            loaded_data = self.__load_item(value)
+            cache[name] = loaded_data
+            return loaded_data
+
+        return value
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name.startswith("_"):
+            super().__setattr__(name, value)
+            return
+
+        stored_value = self.__save_item(value)
+        super().__setattr__(name, stored_value)
+
+        if isinstance(stored_value, RemoteObject):
+            cache = self.__dict__.get("_remote_cache", {})
+            cache.pop(name, None)
+
+    def unload_attribute(self, name: str) -> None:
+        if name in self._remote_cache:
+            data = self._remote_cache.pop(name)
+            remote_obj = self.__save_item(data)
+            super().__setattr__(name, remote_obj)
+
+    def unload_all(self) -> None:
+        for name in list(self._remote_cache.keys()):
+            self.unload_attribute(name)
+
+
+class RemoteList(list):
+    _remote_cache: dict[int, Any]
+
+    def __init__(self, iterable: list[Any] | None = None):
+        super().__init__()
+        self._remote_cache: dict[int, Any] = {}
+        if iterable:
+            for item in iterable:
+                self.append(item)
+
+    def __load_item(self, item: Any | RemoteObject) -> Any:
+        if not activity.in_activity():
+            return item
+
+        from ..storage_handler import load_from_blob_store_if_remote
+
+        return load_from_blob_store_if_remote(item)
+
+    def __save_item(self, item: Any) -> Any:
+        if not activity.in_activity():
+            return item
+
+        from ..storage_handler import store_in_blob_store_if_large
+
+        return store_in_blob_store_if_large(item)
+
+    def __getitem__(self, index: int | slice) -> Any:
+        if isinstance(index, slice):
+            # Obtain the slice without triggering __getitem__ recursively
+            sliced_items = super().__getitem__(
+                index
+            )  # This returns a list of items as is
+            return RemoteList._from_existing_items(sliced_items)
+        else:
+            value = super().__getitem__(index)
+
+            if isinstance(value, RemoteObject):
+                if index in self._remote_cache:
+                    return self._remote_cache[index]
+                loaded_data = self.__load_item(value)
+                self._remote_cache[index] = loaded_data
+                return loaded_data
+            return value
+
+    @classmethod
+    def _from_existing_items(cls, items: list[Any]) -> "RemoteList":
+        """
+        Create a RemoteList from existing items without processing them again.
+        This method ensures that slicing does not trigger loading of items.
+        """
+        new_remote_list = cls.__new__(
+            cls
+        )  # Create a new instance without calling __init__
+        list.__init__(new_remote_list)  # Initialize as an empty list
+        new_remote_list._remote_cache = {}
+        new_remote_list._extend_without_processing(items)
+        return new_remote_list
+
+    def _extend_without_processing(self, items: list[Any]) -> None:
+        """
+        Extend the list without processing the items (i.e., without storing them again).
+        """
+        super().extend(items)
+
+    def __setitem__(self, index: int | slice, value: Any) -> None:
+        if isinstance(index, slice):
+            # Handle slice assignment without processing existing RemoteObjects
+            processed_values = [self.__save_item(v) for v in value]
+            super().__setitem__(index, processed_values)
+            # Clear cache for affected indices
+            for i in range(*index.indices(len(self))):
+                self._remote_cache.pop(i, None)
+        else:
+            stored_value = self.__save_item(value)
+            super().__setitem__(index, stored_value)
+            self._remote_cache.pop(index, None)
+
+    def append(self, value: Any) -> None:
+        stored_value = self.__save_item(value)
+        super().append(stored_value)
+        # No need to cache immediately
+
+    def insert(self, index: int, value: Any) -> None:
+        stored_value = self.__save_item(value)
+        super().insert(index, stored_value)
+        # Adjust cache indices
+        self._shift_cache_on_insert(index)
+
+    def _shift_cache_on_insert(self, index: int) -> None:
+        new_cache = {}
+        for i, v in self._remote_cache.items():
+            if i >= index:
+                new_cache[i + 1] = v
+            else:
+                new_cache[i] = v
+        self._remote_cache = new_cache
+
+    def remove(self, value: Any) -> None:
+        # Find the index of the value to remove
+        index = self.index(value)
+        super().remove(value)
+        self._remote_cache.pop(index, None)
+        # Adjust cache indices
+        self._shift_cache_on_remove(index)
+
+    def _shift_cache_on_remove(self, index: int) -> None:
+        new_cache = {}
+        for i, v in self._remote_cache.items():
+            if i > index:
+                new_cache[i - 1] = v
+            elif i < index:
+                new_cache[i] = v
+            # Else: i == index, already removed
+        self._remote_cache = new_cache
+
+    def pop(self, index: int = -1) -> Any:
+        value = super().pop(index)
+        # Adjust negative indices
+        if index < 0:
+            index = len(self) + index
+        self._remote_cache.pop(index, None)
+        # Adjust cache indices
+        self._shift_cache_on_remove(index)
+        return value
+
+    def clear(self) -> None:
+        super().clear()
+        self._remote_cache.clear()
+
+    def extend(self, iterable: list[Any]) -> None:
+        for item in iterable:
+            self.append(item)
+
+    def __iter__(self) -> Iterator[Any]:
+        for index in range(len(self)):
+            yield self.__getitem__(index)
+
+    def unload_item(self, index: int) -> None:
+        """Unload a specific item and replace it with a RemoteObject."""
+        if index in self._remote_cache:
+            data = self._remote_cache.pop(index)
+            remote_obj = self.__save_item(data)
+            super().__setitem__(index, remote_obj)
+
+    def unload_all(self) -> None:
+        """Unload all cached items."""
+        for index in list(self._remote_cache.keys()):
+            self.unload_item(index)
diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py
index 8d46c57b7..ad7a08ada 100644
--- a/agents-api/agents_api/common/protocol/tasks.py
+++ b/agents-api/agents_api/common/protocol/tasks.py
@@ -1,31 +1,34 @@
-from dataclasses import dataclass
 from typing import Annotated, Any
 from uuid import UUID
 
-from pydantic import BaseModel, Field, computed_field
-from pydantic_partial import create_partial_model
-
-from ...autogen.openapi_model import (
-    Agent,
-    CreateTaskRequest,
-    CreateTransitionRequest,
-    Execution,
-    ExecutionStatus,
-    PartialTaskSpecDef,
-    PatchTaskRequest,
-    Session,
-    Task,
-    TaskSpec,
-    TaskSpecDef,
-    TaskToolDef,
-    Tool,
-    TransitionTarget,
-    TransitionType,
-    UpdateTaskRequest,
-    User,
-    Workflow,
-    WorkflowStep,
-)
+from temporalio import workflow
+
+with workflow.unsafe.imports_passed_through():
+    from pydantic import BaseModel, Field, computed_field
+    from pydantic_partial import create_partial_model
+
+    from ...autogen.openapi_model import (
+        Agent,
+        CreateTaskRequest,
+        CreateTransitionRequest,
+        Execution,
+        ExecutionStatus,
+        PartialTaskSpecDef,
+        PatchTaskRequest,
+        Session,
+        Task,
+        TaskSpec,
+        TaskSpecDef,
+        TaskToolDef,
+        Tool,
+        TransitionTarget,
+        TransitionType,
+        UpdateTaskRequest,
+        User,
+        Workflow,
+        WorkflowStep,
+    )
+    from .remote import BaseRemoteModel, RemoteObject
 
 
 # TODO: Maybe we should use a library for this
@@ -136,9 +139,9 @@ class ExecutionInput(BaseModel):
     session: Session | None = None
 
 
-class StepContext(BaseModel):
-    execution_input: ExecutionInput
-    inputs: list[Any]
+class StepContext(BaseRemoteModel):
+    execution_input: ExecutionInput | RemoteObject
+    inputs: list[Any] | RemoteObject
     cursor: TransitionTarget
 
     @computed_field
@@ -216,11 +219,6 @@ class StepOutcome(BaseModel):
     transition_to: tuple[TransitionType, TransitionTarget] | None = None
 
 
-@dataclass
-class RemoteObject:
-    key: str
-
-
 def task_to_spec(
     task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
 ) -> TaskSpecDef | PartialTaskSpecDef:
diff --git a/agents-api/agents_api/common/storage_handler.py b/agents-api/agents_api/common/storage_handler.py
index ca669620c..50b9e58f3 100644
--- a/agents-api/agents_api/common/storage_handler.py
+++ b/agents-api/agents_api/common/storage_handler.py
@@ -1,18 +1,22 @@
 import inspect
 import sys
+from datetime import timedelta
 from functools import wraps
 from typing import Any, Callable
 
+from temporalio import workflow
+
+from ..activities.sync_items_remote import load_inputs_remote
 from ..clients import s3
-from ..common.protocol.tasks import RemoteObject
-from ..env import blob_store_cutoff_kb, use_blob_store_for_temporal
+from ..common.protocol.remote import BaseRemoteModel, RemoteList, RemoteObject
+from ..common.retry_policies import DEFAULT_RETRY_POLICY
+from ..env import blob_store_cutoff_kb, debug, testing, use_blob_store_for_temporal
 from ..worker.codec import deserialize, serialize
 
-if use_blob_store_for_temporal:
-    s3.setup()
-
 
 def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any:
+    s3.setup()
+
     serialized = serialize(x)
     data_size = sys.getsizeof(serialized)
 
@@ -23,7 +27,9 @@ def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any:
     return x
 
 
-def load_from_blob_store_if_remote(x: Any) -> Any:
+def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any:
+    s3.setup()
+
     if isinstance(x, RemoteObject):
         fetched = s3.get_object(x.key)
         return deserialize(fetched)
@@ -45,6 +51,12 @@ def load_args(
 
         return new_args, new_kwargs
 
+    def unload_return_value(x: Any | BaseRemoteModel | RemoteList) -> Any:
+        if isinstance(x, (BaseRemoteModel, RemoteList)):
+            x.unload_all()
+
+        return store_in_blob_store_if_large(x)
+
     if inspect.iscoroutinefunction(f):
 
         @wraps(f)
@@ -52,7 +64,7 @@ async def async_wrapper(*args, **kwargs) -> Any:
             new_args, new_kwargs = load_args(args, kwargs)
             output = await f(*new_args, **new_kwargs)
 
-            return store_in_blob_store_if_large(output)
+            return unload_return_value(output)
 
         return async_wrapper if use_blob_store_for_temporal else f
 
@@ -63,6 +75,29 @@ def wrapper(*args, **kwargs) -> Any:
         new_args, new_kwargs = load_args(args, kwargs)
         output = f(*new_args, **new_kwargs)
 
-        return store_in_blob_store_if_large(output)
+        return unload_return_value(output)
 
     return wrapper if use_blob_store_for_temporal else f
+
+
+def auto_blob_store_workflow(f: Callable) -> Callable:
+    @wraps(f)
+    async def wrapper(*args, **kwargs) -> Any:
+        keys = kwargs.keys()
+        values = [kwargs[k] for k in keys]
+
+        loaded = await workflow.execute_local_activity(
+            load_inputs_remote,
+            args=[[*args, *values]],
+            schedule_to_close_timeout=timedelta(seconds=10 if debug or testing else 60),
+            retry_policy=DEFAULT_RETRY_POLICY,
+        )
+
+        loaded_args = loaded[: len(args)]
+        loaded_kwargs = dict(zip(keys, loaded[len(args) :]))
+
+        result = await f(*loaded_args, **loaded_kwargs)
+
+        return result
+
+    return wrapper
diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py
index c69eae8f2..a0470acba 100644
--- a/agents-api/agents_api/env.py
+++ b/agents-api/agents_api/env.py
@@ -33,7 +33,7 @@
 )
 
 blob_store_bucket: str = env.str("BLOB_STORE_BUCKET", default="agents-api")
-blob_store_cutoff_kb: int = env.int("BLOB_STORE_CUTOFF_KB", default=1024)
+blob_store_cutoff_kb: int = env.int("BLOB_STORE_CUTOFF_KB", default=64)
 s3_endpoint: str = env.str("S3_ENDPOINT", default="http://seaweedfs:8333")
 s3_access_key: str | None = env.str("S3_ACCESS_KEY", default=None)
 s3_secret_key: str | None = env.str("S3_SECRET_KEY", default=None)
diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py
index f63646e1c..e182de077 100644
--- a/agents-api/agents_api/models/utils.py
+++ b/agents-api/agents_api/models/utils.py
@@ -186,6 +186,7 @@ def make_cozo_json_query(fields):
 def cozo_query(
     func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
     debug: bool | None = None,
+    only_on_error: bool = False,
 ):
     def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
         """
@@ -209,8 +210,8 @@ def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame:
             query = "}\n\n{\n".join(queries)
             query = f"{{ {query} }}"
 
-            debug and print(query)
-            debug and pprint(
+            not only_on_error and debug and print(query)
+            not only_on_error and debug and pprint(
                 dict(
                     variables=variables,
                 )
@@ -224,13 +225,17 @@ def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame:
                 result = client.run(query, variables)
 
             except Exception as e:
+                if only_on_error and debug:
+                    print(query)
+                    pprint(variables)
+
                 debug and print(repr(getattr(e, "__cause__", None) or e))
                 raise
 
             # Need to fix the UUIDs in the result
             result = result.map(fix_uuid_if_present)
 
-            debug and pprint(
+            not only_on_error and debug and pprint(
                 dict(
                     result=result.to_dict(orient="records"),
                 )
diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py
index 54f2bcdd5..968a61ab9 100644
--- a/agents-api/agents_api/worker/worker.py
+++ b/agents-api/agents_api/worker/worker.py
@@ -21,6 +21,7 @@ def create_worker(client: Client) -> Any:
     from ..activities.mem_mgmt import mem_mgmt
     from ..activities.mem_rating import mem_rating
     from ..activities.summarization import summarization
+    from ..activities.sync_items_remote import load_inputs_remote, save_inputs_remote
    from ..activities.truncation import truncation
     from ..common.interceptors import CustomInterceptor
     from ..env import (
@@ -61,6 +62,8 @@ def create_worker(client: Client) -> Any:
             mem_rating,
             summarization,
             truncation,
+            save_inputs_remote,
+            load_inputs_remote,
         ],
         interceptors=[CustomInterceptor()],
     )
diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py
index 155b49397..252f48537 100644
--- a/agents-api/agents_api/workflows/task_execution/__init__.py
+++ b/agents-api/agents_api/workflows/task_execution/__init__.py
@@ -4,16 +4,18 @@
 from datetime import timedelta
 from typing import Any
 
-from pydantic import RootModel
 from temporalio import workflow
 from temporalio.exceptions import ApplicationError
 
 # Import necessary modules and types
 with workflow.unsafe.imports_passed_through():
+    from pydantic import RootModel
+
     from ...activities import task_steps
     from ...activities.excecute_api_call import execute_api_call
     from ...activities.execute_integration import execute_integration
     from ...activities.execute_system import execute_system
+    from ...activities.sync_items_remote import load_inputs_remote, save_inputs_remote
     from ...autogen.openapi_model import (
         ApiCallDef,
         ErrorWorkflowStep,
@@ -39,6 +41,7 @@
         YieldStep,
     )
     from ...autogen.Tools import SystemDef
+    from ...common.protocol.remote import RemoteList
     from ...common.protocol.tasks import (
         ExecutionInput,
         PartialTransition,
@@ -124,7 +127,7 @@ async def run(
         self,
         execution_input: ExecutionInput,
         start: TransitionTarget = TransitionTarget(workflow="main", step=0),
-        previous_inputs: list[Any] = [],
+        previous_inputs: RemoteList | None = None,
     ) -> Any:
         workflow.logger.info(
             f"TaskExecutionWorkflow for task {execution_input.task.id}"
@@ -132,7 +135,7 @@ async def run(
         )
 
         # 0. Prepare context
-        previous_inputs = previous_inputs or [execution_input.arguments]
+        previous_inputs = previous_inputs or RemoteList([execution_input.arguments])
 
         context = StepContext(
             execution_input=execution_input,
@@ -144,8 +147,10 @@ async def run(
 
         # ---
 
+        continued_as_new = workflow.info().continued_run_id is not None
+
         # 1. Transition to starting if not done yet
-        if context.is_first_step:
+        if context.is_first_step and not continued_as_new:
             await transition(
                 context,
                 type="init" if context.is_main else "init_branch",
@@ -190,6 +195,13 @@ async def run(
         # 3. Then, based on the outcome and step type, decide what to do next
         workflow.logger.info(f"Processing outcome for step {context.cursor.step}")
 
+        [outcome] = await workflow.execute_local_activity(
+            load_inputs_remote,
+            args=[[outcome]],
+            schedule_to_close_timeout=timedelta(seconds=10 if debug or testing else 60),
+            retry_policy=DEFAULT_RETRY_POLICY,
+        )
+
         match context.current_step, outcome:
             # Handle errors (activity returns None)
             case step, StepOutcome(error=error) if error is not None:
@@ -558,10 +570,20 @@ def model_dump(obj):
             f"Continuing to next step: {final_state.next.workflow}.{final_state.next.step}"
         )
 
+        # Save the final output to the blob store
+        [final_output] = await workflow.execute_local_activity(
+            save_inputs_remote,
+            args=[[final_state.output]],
+            schedule_to_close_timeout=timedelta(seconds=10 if debug or testing else 60),
+            retry_policy=DEFAULT_RETRY_POLICY,
+        )
+
+        previous_inputs.append(final_output)
+
         # Continue as a child workflow
         return await continue_as_child(
             context.execution_input,
             start=final_state.next,
-            previous_inputs=previous_inputs + [final_state.output],
+            previous_inputs=previous_inputs,
             user_state=state.user_state,
         )
diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py
index 271f33dbf..7fbc2c008 100644
--- a/agents-api/agents_api/workflows/task_execution/helpers.py
+++ b/agents-api/agents_api/workflows/task_execution/helpers.py
@@ -14,17 +14,20 @@
         Workflow,
         WorkflowStep,
     )
+    from ...common.protocol.remote import RemoteList
     from ...common.protocol.tasks import (
         ExecutionInput,
         StepContext,
     )
+    from ...common.storage_handler import auto_blob_store_workflow
     from ...env import task_max_parallelism
 
 
+@auto_blob_store_workflow
 async def continue_as_child(
     execution_input: ExecutionInput,
     start: TransitionTarget,
-    previous_inputs: list[Any],
+    previous_inputs: RemoteList | list[Any],
     user_state: dict[str, Any] = {},
 ) -> Any:
     info = workflow.info()
@@ -47,13 +50,14 @@ async def continue_as_child(
     )
 
 
+@auto_blob_store_workflow
 async def execute_switch_branch(
     *,
     context: StepContext,
     execution_input: ExecutionInput,
     switch: list,
     index: int,
-    previous_inputs: list[Any],
+    previous_inputs: RemoteList | list[Any],
     user_state: dict[str, Any] = {},
 ) -> Any:
     workflow.logger.info(f"Switch step: Chose branch {index}")
@@ -77,6 +81,7 @@ async def execute_switch_branch(
     )
 
 
+@auto_blob_store_workflow
 async def execute_if_else_branch(
     *,
     context: StepContext,
@@ -84,7 +89,7 @@ async def execute_if_else_branch(
     then_branch: WorkflowStep,
     else_branch: WorkflowStep,
     condition: bool,
-    previous_inputs: list[Any],
+    previous_inputs: RemoteList | list[Any],
     user_state: dict[str, Any] = {},
 ) -> Any:
     workflow.logger.info(f"If-Else step: Condition evaluated to {condition}")
@@ -109,13 +114,14 @@ async def execute_if_else_branch(
     )
 
 
+@auto_blob_store_workflow
 async def execute_foreach_step(
     *,
     context: StepContext,
     execution_input: ExecutionInput,
     do_step: WorkflowStep,
     items: list[Any],
-    previous_inputs: list[Any],
+    previous_inputs: RemoteList | list[Any],
     user_state: dict[str, Any] = {},
 ) -> Any:
     workflow.logger.info(f"Foreach step: Iterating over {len(items)} items")
@@ -143,13 +149,14 @@ async def execute_foreach_step(
     return results
 
 
+@auto_blob_store_workflow
 async def execute_map_reduce_step(
     *,
     context: StepContext,
     execution_input: ExecutionInput,
     map_defn: WorkflowStep,
     items: list[Any],
-    previous_inputs: list[Any],
+    previous_inputs: RemoteList | list[Any],
     user_state: dict[str, Any] = {},
     reduce: str | None = None,
     initial: Any = [],
@@ -186,13 +193,14 @@ async def execute_map_reduce_step(
     return result
 
 
+@auto_blob_store_workflow
 async def execute_map_reduce_step_parallel(
     *,
     context: StepContext,
     execution_input: ExecutionInput,
     map_defn: WorkflowStep,
     items: list[Any],
-    previous_inputs: list[Any],
+    previous_inputs: RemoteList | list[Any],
     user_state: dict[str, Any] = {},
     initial: Any = [],
     reduce: str | None = None,
diff --git a/agents-api/docker-compose.yml b/agents-api/docker-compose.yml
index d1a43d64a..40775abde 100644
--- a/agents-api/docker-compose.yml
+++ b/agents-api/docker-compose.yml
@@ -23,7 +23,7 @@ x--shared-environment: &shared-environment
   TRUNCATE_EMBED_TEXT: ${TRUNCATE_EMBED_TEXT:-True}
   WORKER_URL: ${WORKER_URL:-temporal:7233}
   USE_BLOB_STORE_FOR_TEMPORAL: ${USE_BLOB_STORE_FOR_TEMPORAL:-false}
-  BLOB_STORE_CUTOFF_KB: ${BLOB_STORE_CUTOFF_KB:-1024}
+  BLOB_STORE_CUTOFF_KB: ${BLOB_STORE_CUTOFF_KB:-128}
   BLOB_STORE_BUCKET: ${BLOB_STORE_BUCKET:-agents-api}
   S3_ENDPOINT: ${S3_ENDPOINT:-http://seaweedfs:8333}
   S3_ACCESS_KEY: ${S3_ACCESS_KEY}
@@ -40,6 +40,8 @@ x--base-agents-api: &base-agents-api
     context: .
     dockerfile: Dockerfile
 
+  restart: on-failure
+
   ports:
     - "8080:8080"