diff --git a/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py b/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py new file mode 100644 index 00000000..9794c5aa --- /dev/null +++ b/backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py @@ -0,0 +1,38 @@ +"""Add thread table + +Revision ID: 3a8a5f819c5f +Revises: 3b4636df4c6d +Create Date: 2024-05-25 13:57:57.773217 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '3a8a5f819c5f' +down_revision = '3b4636df4c6d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('thread', + sa.Column('query', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('id', sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('team_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['team_id'], ['team.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_thread_id'), 'thread', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_thread_id'), table_name='thread') + op.drop_table('thread') + # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/3b4636df4c6d_add_interrupt_column_in_members_table.py b/backend/app/alembic/versions/3b4636df4c6d_add_interrupt_column_in_members_table.py new file mode 100644 index 00000000..a4144eae --- /dev/null +++ b/backend/app/alembic/versions/3b4636df4c6d_add_interrupt_column_in_members_table.py @@ -0,0 +1,34 @@ +"""Add interrupt column in members table + +Revision ID: 3b4636df4c6d +Revises: b922e3e826e9 +Create Date: 2024-05-25 01:04:31.805491 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '3b4636df4c6d' +down_revision = 'b922e3e826e9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # Add the new column without the NOT NULL constraint + op.add_column('member', sa.Column('interrupt', sa.Boolean(), nullable=True)) + # Set the default value for existing rows + op.execute('UPDATE member SET interrupt = FALSE WHERE interrupt IS NULL') + # Alter the column to add the NOT NULL constraint + op.alter_column('member', 'interrupt', nullable=False, server_default=sa.false()) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('member', 'interrupt') + # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py b/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py new file mode 100644 index 00000000..9c5d8a45 --- /dev/null +++ b/backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py @@ -0,0 +1,38 @@ +"""Add checkpoints table + +Revision ID: 6fa42be09dd2 +Revises: 3a8a5f819c5f +Create Date: 2024-05-26 02:42:53.431421 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '6fa42be09dd2' +down_revision = '3a8a5f819c5f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('checkpoints', + sa.Column('thread_id', sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column('thread_ts', sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column('parent_ts', sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column('checkpoint', sa.LargeBinary(), nullable=False), + sa.Column('metadata', sa.LargeBinary(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['thread_id'], ['thread.id'], ), + sa.PrimaryKeyConstraint('thread_id', 'thread_ts') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('checkpoints') + # ### end Alembic commands ### diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 0951ed47..7673faf9 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,6 +1,6 @@ from fastapi import APIRouter -from app.api.routes import login, members, skills, teams, users, utils +from app.api.routes import login, members, skills, teams, threads, users, utils api_router = APIRouter() api_router.include_router(login.router, tags=["login"]) @@ -11,3 +11,6 @@ members.router, prefix="/teams/{team_id}/members", tags=["members"] ) api_router.include_router(skills.router, prefix="/skills", tags=["skills"]) +api_router.include_router( + threads.router, prefix="/teams/{team_id}/threads", tags=["threads"] +) diff --git a/backend/app/api/routes/teams.py b/backend/app/api/routes/teams.py index b31ab7a4..624b78cf 100644 --- a/backend/app/api/routes/teams.py +++ b/backend/app/api/routes/teams.py @@ -15,6 +15,7 @@ TeamOut, TeamsOut, TeamUpdate, + Thread, ) # TODO: To remove @@ -211,9 +212,13 @@ def delete_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any: return Message(message="Team deleted successfully") -@router.post("/{id}/stream") +@router.post("/{id}/stream/{thread_id}") async def stream( - session: SessionDep, current_user: CurrentUser, id: int, team_chat: TeamChat + session: SessionDep, + current_user: CurrentUser, + id: int, + thread_id: str, + team_chat: TeamChat, ) -> StreamingResponse: """ Stream a response to a user's input. 
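+
+    The thread_id path parameter scopes the run to a conversation thread: the
+    thread must belong to the team, and graph state is checkpointed under this
+    id so that a later call can resume the same conversation.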
@@ -225,12 +230,21 @@ async def stream(
     if not current_user.is_superuser and (team.owner_id != current_user.id):
         raise HTTPException(status_code=400, detail="Not enough permissions")
 
+    # Check if thread belongs to the team
+    thread = session.get(Thread, thread_id)
+    if not thread:
+        raise HTTPException(status_code=404, detail="Thread not found")
+    if thread.team_id != id:
+        raise HTTPException(
+            status_code=400, detail="Thread does not belong to the team"
+        )
+
     # Populate the skills for each member
     members = team.members
     for member in members:
        member.skills = member.skills
 
     return StreamingResponse(
-        generator(team, members, team_chat.messages),
+        generator(team, members, team_chat.messages, thread_id),
         media_type="text/event-stream",
     )
diff --git a/backend/app/api/routes/threads.py b/backend/app/api/routes/threads.py
new file mode 100644
index 00000000..27455f35
--- /dev/null
+++ b/backend/app/api/routes/threads.py
@@ -0,0 +1,209 @@
+from typing import Any
+from uuid import UUID
+
+from fastapi import APIRouter, HTTPException
+from sqlmodel import col, func, select
+
+from app.api.deps import CurrentUser, SessionDep
+from app.models import (
+    Checkpoint,
+    CreateThreadOut,
+    Message,
+    Team,
+    Thread,
+    ThreadCreate,
+    ThreadOut,
+    ThreadsOut,
+    ThreadUpdate,
+)
+
+router = APIRouter()
+
+
+@router.get("/", response_model=ThreadsOut)
+def read_threads(
+    session: SessionDep,
+    current_user: CurrentUser,
+    team_id: int,
+    skip: int = 0,
+    limit: int = 100,
+) -> Any:
+    """
+    Retrieve threads.
+    """
+
+    if current_user.is_superuser:
+        # Scope the count to this team so it matches the filtered statement below
+        count_statement = (
+            select(func.count()).select_from(Thread).where(Thread.team_id == team_id)
+        )
+        count = session.exec(count_statement).one()
+        statement = (
+            select(Thread).where(Thread.team_id == team_id).offset(skip).limit(limit)
+        )
+        threads = session.exec(statement).all()
+    else:
+        count_statement = (
+            select(func.count())
+            .select_from(Thread)
+            .join(Team)
+            .where(Team.owner_id == current_user.id, Thread.team_id == team_id)
+        )
+        count = session.exec(count_statement).one()
+        statement = (
+            select(Thread)
+            .join(Team)
+            .where(Team.owner_id == current_user.id, Thread.team_id == team_id)
+            .offset(skip)
+            .limit(limit)
+        )
+        threads = session.exec(statement).all()
+    return ThreadsOut(data=threads, count=count)
+
+
+@router.get("/{id}", response_model=CreateThreadOut)
+def read_thread(
+    session: SessionDep, current_user: CurrentUser, team_id: int, id: UUID
+) -> Any:
+    """
+    Get thread and its last checkpoint by ID
+    """
+    if current_user.is_superuser:
+        statement = (
+            select(Thread)
+            .join(Team)
+            .where(
+                Thread.id == id,
+                Thread.team_id == team_id,
+            )
+        )
+        thread = session.exec(statement).first()
+    else:
+        statement = (
+            select(Thread)
+            .join(Team)
+            .where(
+                Thread.id == id,
+                Thread.team_id == team_id,
+                Team.owner_id == current_user.id,
+            )
+        )
+        thread = session.exec(statement).first()
+
+    if not thread:
+        raise HTTPException(status_code=404, detail="Thread not found")
+
+    checkpoint_statement = (
+        select(Checkpoint)
+        .where(Checkpoint.thread_id == thread.id)
+        .order_by(col(Checkpoint.created_at).desc())
+    )
+    checkpoint = session.exec(checkpoint_statement).first()
+
+    return CreateThreadOut(
+        id=thread.id,
+        query=thread.query,
+        last_checkpoint=checkpoint,
+        updated_at=thread.updated_at,
+    )
+
+
+@router.post("/", response_model=ThreadOut)
+def create_thread(
+    *,
+    session: SessionDep,
+    current_user: CurrentUser,
+    team_id: int,
+    thread_in: ThreadCreate,
+) -> Any:
+    """
+    Create new thread.
+    """
+    if not current_user.is_superuser:
+        team = session.get(Team, team_id)
+        if not team:
+            raise HTTPException(status_code=404, detail="Team not found.")
+        if team.owner_id != current_user.id:
+            raise HTTPException(status_code=400, detail="Not enough permissions")
+    thread = Thread.model_validate(thread_in, update={"team_id": team_id})
+    session.add(thread)
+    session.commit()
+    session.refresh(thread)
+    return thread
+
+
+@router.put("/{id}", response_model=ThreadOut)
+def update_thread(
+    *,
+    session: SessionDep,
+    current_user: CurrentUser,
+    team_id: int,
+    id: UUID,
+    thread_in: ThreadUpdate,
+) -> Any:
+    """
+    Update a thread.
+    """
+    if current_user.is_superuser:
+        statement = (
+            select(Thread).join(Team).where(Thread.id == id, Thread.team_id == team_id)
+        )
+        thread = session.exec(statement).first()
+    else:
+        statement = (
+            select(Thread)
+            .join(Team)
+            .where(
+                Thread.id == id,
+                Thread.team_id == team_id,
+                Team.owner_id == current_user.id,
+            )
+        )
+        thread = session.exec(statement).first()
+
+    if not thread:
+        raise HTTPException(status_code=404, detail="Thread not found")
+
+    update_dict = thread_in.model_dump(exclude_unset=True)
+    thread.sqlmodel_update(update_dict)
+    session.add(thread)
+    session.commit()
+    session.refresh(thread)
+    return thread
+
+
+@router.delete("/{id}")
+def delete_thread(
+    session: SessionDep, current_user: CurrentUser, team_id: int, id: UUID
+) -> Any:
+    """
+    Delete a thread.
+    """
+    if current_user.is_superuser:
+        statement = (
+            select(Thread)
+            .join(Team)
+            .where(
+                Thread.id == id,
+                Thread.team_id == team_id,
+            )
+        )
+        thread = session.exec(statement).first()
+    else:
+        statement = (
+            select(Thread)
+            .join(Team)
+            .where(
+                Thread.id == id,
+                Thread.team_id == team_id,
+                Team.owner_id == current_user.id,
+            )
+        )
+        thread = session.exec(statement).first()
+
+    if not thread:
+        raise HTTPException(status_code=404, detail="Thread not found")
+
+    for checkpoint in thread.checkpoints:
+        session.delete(checkpoint)
+
+    session.delete(thread)
+    session.commit()
+    return Message(message="Thread deleted successfully")
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 41de80f7..a1df6722 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -66,6 +66,11 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
             path=self.POSTGRES_DB,
         )
 
+    @computed_field  # type: ignore[misc]
+    @property
+    def PG_DATABASE_URI(self) -> str:
+        return f"postgres://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}@{self.POSTGRES_SERVER}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
+
     SMTP_TLS: bool = True
     SMTP_SSL: bool = False
     SMTP_PORT: int = 587
diff --git a/backend/app/core/graph/build.py b/backend/app/core/graph/build.py
index 6923abd0..310dff6d 100644
--- a/backend/app/core/graph/build.py
+++ b/backend/app/core/graph/build.py
@@ -6,12 +6,15 @@
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from langchain_core.runnables import RunnableLambda
+from langgraph.checkpoint import BaseCheckpointSaver
 from langgraph.graph import END, StateGraph
 from langgraph.graph.graph import CompiledGraph
 from langgraph.prebuilt import (
     ToolNode,
 )
 
+from app.core.config import settings
+from app.core.graph.checkpoint.aiopostgres import AsyncPostgresSaver
 from app.core.graph.members import (
     GraphLeader,
     GraphMember,
@@ -143,6 +146,7 @@ def convert_sequential_team_to_dict(team: Team) -> dict[str, GraphMember]:
             provider=memberModel.provider,
             model=memberModel.model,
             temperature=memberModel.temperature,
+            interrupt=memberModel.interrupt,
         )
         team_dict[graph_member.name] = graph_member
         for nei_id in
out_counts[member_id]: @@ -318,7 +322,9 @@ def create_hierarchical_graph( return graph -def create_sequential_graph(team: dict[str, GraphMember]) -> CompiledGraph: +def create_sequential_graph( + team: dict[str, GraphMember], memory: BaseCheckpointSaver +) -> CompiledGraph: """ Creates a sequential graph from a list of team members. @@ -333,6 +339,8 @@ def create_sequential_graph(team: dict[str, GraphMember]) -> CompiledGraph: """ members: list[GraphMember] = [] graph = StateGraph(TeamState) + # Create a list to store member names that require human intervention before tool calling + interrupt_member_names = [] for i, member in enumerate(team.values()): graph.add_node( member.name, @@ -350,6 +358,9 @@ def create_sequential_graph(team: dict[str, GraphMember]) -> CompiledGraph: ) # After tools node is called, agent node is called next. graph.add_edge(f"{member.name}_tools", member.name) + # Check if member requires human intervention before tool calling + if member.interrupt: + interrupt_member_names.append(f"{member.name}_tools") if i > 0: # if previous member has tools, then the edge should conditionally call tool node if len(members[i - 1].tools) >= 1: @@ -371,7 +382,7 @@ def create_sequential_graph(team: dict[str, GraphMember]) -> CompiledGraph: else: graph.add_edge(members[-1].name, END) graph.set_entry_point(members[0].name) - return graph.compile() + return graph.compile(checkpointer=memory, interrupt_before=interrupt_member_names) def convert_messages_and_tasks_to_dict(data: Any) -> Any: @@ -393,7 +404,7 @@ def convert_messages_and_tasks_to_dict(data: Any) -> Any: async def generator( - team: Team, members: list[Member], messages: list[ChatMessage] + team: Team, members: list[Member], messages: list[ChatMessage], thread_id: str ) -> AsyncGenerator[Any, Any]: """Create the graph and stream responses as JSON.""" formatted_messages = [ @@ -403,6 +414,8 @@ async def generator( for message in messages ] + memory = await AsyncPostgresSaver.from_conn_string(settings.PG_DATABASE_URI) + if team.workflow == "hierarchical": teams = convert_hierarchical_team_to_dict(team, members) team_leader = list(teams.keys())[0] @@ -414,7 +427,7 @@ async def generator( } else: member_dict = convert_sequential_team_to_dict(team) - root = create_sequential_graph(member_dict) + root = create_sequential_graph(member_dict, memory) state = { "messages": formatted_messages, "team_name": team.name, @@ -422,7 +435,10 @@ async def generator( "next": list(member_dict.values())[0].name, } async for output in root.astream_events( - state, version="v1", include_names=["work", "delegate", "summarise"] + state, + version="v1", + include_names=["work", "delegate", "summarise"], + config={"configurable": {"thread_id": thread_id}}, ): if output["event"] == "on_chain_end": output_data = output["data"]["output"] diff --git a/backend/app/core/graph/checkpoint/__init__.py b/backend/app/core/graph/checkpoint/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/core/graph/checkpoint/aiopostgres.py b/backend/app/core/graph/checkpoint/aiopostgres.py new file mode 100644 index 00000000..99c8eea8 --- /dev/null +++ b/backend/app/core/graph/checkpoint/aiopostgres.py @@ -0,0 +1,416 @@ +import asyncio +import functools +from collections.abc import AsyncIterator, Iterator +from contextlib import AbstractAsyncContextManager +from types import TracebackType +from typing import TypeVar + +import asyncpg +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + 
+    BaseCheckpointSaver,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+)
+from langgraph.serde.base import SerializerProtocol
+from typing_extensions import Self
+
+from app.core.graph.checkpoint.postgres import JsonPlusSerializerCompat, search_where
+
+T = TypeVar("T", bound=callable)  # type: ignore[valid-type]
+
+
+def not_implemented_sync_method(func: T) -> T:
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
+        raise NotImplementedError(
+            "The AsyncPostgresSaver does not support synchronous methods. "
+            "Consider using the PostgresSaver instead.\n"
+            "from app.core.graph.checkpoint.postgres import PostgresSaver\n"
+            "See https://langchain-ai.github.io/langgraph/reference/checkpoints/#postgressaver "
+            "for more information."
+        )
+
+    return wrapper  # type: ignore[return-value]
+
+
+class AsyncPostgresSaver(BaseCheckpointSaver, AbstractAsyncContextManager):  # type: ignore[type-arg]
+    """An asynchronous checkpoint saver that stores checkpoints in a PostgreSQL database.
+
+    Tip:
+        Requires the [asyncpg](https://pypi.org/project/asyncpg/) package.
+        Install it with `pip install asyncpg`.
+
+    Args:
+        conn (asyncpg.Connection): The asynchronous PostgreSQL database connection.
+        serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat.
+
+    Examples:
+        Usage within a StateGraph:
+        ```pycon
+        >>> from app.core.graph.checkpoint.aiopostgres import AsyncPostgresSaver
+        >>> from langgraph.graph import StateGraph
+        >>>
+        >>> builder = StateGraph(int)
+        >>> builder.add_node("add_one", lambda x: x + 1)
+        >>> builder.set_entry_point("add_one")
+        >>> builder.set_finish_point("add_one")
+        >>> # from_conn_string is a coroutine, so it must be awaited
+        >>> memory = await AsyncPostgresSaver.from_conn_string("postgresql://user:password@localhost/dbname")
+        >>> graph = builder.compile(checkpointer=memory)
+        >>> await graph.ainvoke(1, {"configurable": {"thread_id": "thread-1"}})
+        2
+        ```
+
+        Raw usage:
+        ```pycon
+        >>> import asyncio
+        >>> import asyncpg
+        >>> from app.core.graph.checkpoint.aiopostgres import AsyncPostgresSaver
+        >>>
+        >>> async def main():
+        ...     conn = await asyncpg.connect("postgresql://user:password@localhost/dbname")
+        ...     saver = AsyncPostgresSaver(conn)
+        ...     config = {"configurable": {"thread_id": "1"}}
+        ...     checkpoint = {"id": "2023-05-03T10:00:00Z", "ts": "2023-05-03T10:00:00Z", "data": {"key": "value"}}
+        ...     saved_config = await saver.aput(config, checkpoint, {})
+        ...     print(saved_config)
+        >>> asyncio.run(main())
+        {'configurable': {'thread_id': '1', 'thread_ts': '2023-05-03T10:00:00Z'}}
+        ```
+    """
+
+    serde = JsonPlusSerializerCompat()
+
+    conn: asyncpg.Connection  # type: ignore[type-arg]
+    lock: asyncio.Lock
+    is_setup: bool
+
+    def __init__(
+        self,
+        conn: asyncpg.Connection,  # type: ignore[type-arg]
+        *,
+        serde: SerializerProtocol | None = None,
+    ):
+        super().__init__(serde=serde)
+        self.conn = conn
+        self.lock = asyncio.Lock()
+        self.is_setup = False
+
+    @classmethod
+    async def from_conn_string(cls, conn_string: str) -> "AsyncPostgresSaver":
+        """Create a new AsyncPostgresSaver instance from a connection string.
+
+        Args:
+            conn_string (str): The PostgreSQL connection string.
+
+        Returns:
+            AsyncPostgresSaver: A new AsyncPostgresSaver instance.
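+
+        Example (a sketch; assumes a reachable database with these credentials):
+
+            >>> saver = await AsyncPostgresSaver.from_conn_string(
+            ...     "postgresql://user:password@localhost/dbname"
+            ... )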
+ """ + conn = await asyncpg.connect(conn_string) + return AsyncPostgresSaver(conn=conn) + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + if self.is_setup: + await self.conn.close() + return None + + @not_implemented_sync_method + def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: + """Get a checkpoint tuple from the database. + + Note: + This method is not implemented for the AsyncPostgresSaver. Use `aget` instead. + Or consider using the [PostgresSaver](#postgressaver) checkpointer. + """ + + @not_implemented_sync_method + def list( # type: ignore[empty-body] + self, + config: RunnableConfig, + *, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + Note: + This method is not implemented for the AsyncPostgresSaver. Use `alist` instead. + Or consider using the [PostgresSaver](#postgressaver) checkpointer. + """ + + @not_implemented_sync_method + def search( # type: ignore[empty-body] + self, + metadata_filter: CheckpointMetadata, + *, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> Iterator[CheckpointTuple]: + """Search for checkpoints by metadata. + + Note: + This method is not implemented for the AsyncPostgresSaver. Use `asearch` instead. + Or consider using the [PostgresSaver](#postgressaver) checkpointer. + """ + + @not_implemented_sync_method + def put( # type: ignore[empty-body] + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + ) -> RunnableConfig: + """Save a checkpoint to the database.""" + + async def setup(self) -> None: + """Set up the checkpoint database asynchronously. + + This method creates the necessary tables in the PostgreSQL database if they don't + already exist. It is called automatically when needed and should not be called + directly by the user. + """ + async with self.lock: + if self.is_setup: + return + + await self.conn.execute( + """ + CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id TEXT NOT NULL, + thread_ts TEXT NOT NULL, + parent_ts TEXT, + checkpoint BYTEA, + metadata BYTEA, + PRIMARY KEY (thread_id, thread_ts) + ); + """ + ) + + self.is_setup = True + + async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: + """Get a checkpoint tuple from the database asynchronously. + + This method retrieves a checkpoint tuple from the PostgreSQL database based on the + provided config. If the config contains a "thread_ts" key, the checkpoint with + the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint + for the given thread ID is retrieved. + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. 
+ """ + await self.setup() + if config["configurable"].get("thread_ts"): + result = await self.conn.fetchrow( + "SELECT checkpoint, parent_ts, metadata FROM checkpoints WHERE thread_id = $1 AND thread_ts = $2", + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + ) + if result: + return CheckpointTuple( + config, + self.serde.loads(result["checkpoint"]), + self.serde.loads(result["metadata"]) if result["metadata"] else {}, + ( + { + "configurable": { + "thread_id": config["configurable"]["thread_id"], + "thread_ts": result["parent_ts"], + } + } + if result["parent_ts"] + else None + ), + ) + else: + result = await self.conn.fetchrow( + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC LIMIT 1", + str(config["configurable"]["thread_id"]), + ) + if result: + return CheckpointTuple( + { + "configurable": { + "thread_id": result["thread_id"], + "thread_ts": result["thread_ts"], + } + }, + self.serde.loads(result["checkpoint"]), + self.serde.loads(result["metadata"]) if result["metadata"] else {}, + ( + { + "configurable": { + "thread_id": result["thread_id"], + "thread_ts": result["parent_ts"], + } + } + if result["parent_ts"] + else None + ), + ) + return None + + async def alist( + self, + config: RunnableConfig, + *, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> AsyncIterator[CheckpointTuple]: + """List checkpoints from the database asynchronously. + + This method retrieves a list of checkpoint tuples from the PostgreSQL database based + on the provided config. The checkpoints are ordered by timestamp in descending order. + + Args: + config (RunnableConfig): The config to use for listing the checkpoints. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None. + limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. + + Yields: + AsyncIterator[CheckpointTuple]: An asynchronous iterator of checkpoint tuples. + """ + await self.setup() + query = ( + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = $1 ORDER BY thread_ts DESC" + if before is None + else "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = $1 AND thread_ts < $2 ORDER BY thread_ts DESC" + ) + if limit: + query += f" LIMIT {limit}" + params = ( + (str(config["configurable"]["thread_id"]),) + if before is None + else ( + str(config["configurable"]["thread_id"]), + str(before["configurable"]["thread_ts"]), + ) + ) + async with self.conn.transaction(): + async for record in self.conn.cursor(query, *params): + yield CheckpointTuple( + { + "configurable": { + "thread_id": record["thread_id"], + "thread_ts": record["thread_ts"], + } + }, + self.serde.loads(record["checkpoint"]), + self.serde.loads(record["metadata"]) if record["metadata"] else {}, + ( + { + "configurable": { + "thread_id": record["thread_id"], + "thread_ts": record["parent_ts"], + } + } + if record["parent_ts"] + else None + ), + ) + + async def asearch( + self, + metadata_filter: CheckpointMetadata, + *, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> AsyncIterator[CheckpointTuple]: + """Search for checkpoints by metadata asynchronously. + + This method retrieves a list of checkpoint tuples from the PostgreSQL + database based on the provided metadata filter. 
The metadata filter does
+        not need to contain all keys defined in the CheckpointMetadata class.
+        The checkpoints are ordered by timestamp in descending order.
+
+        Args:
+            metadata_filter (CheckpointMetadata): The metadata filter to use for searching the checkpoints.
+            before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None.
+            limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.
+
+        Yields:
+            AsyncIterator[CheckpointTuple]: An asynchronous iterator of checkpoint tuples.
+        """
+        await self.setup()
+
+        # construct query
+        SELECT = "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints "
+        WHERE, params = search_where(metadata_filter, before)
+        # search_where() emits psycopg2-style "%s" placeholders; rewrite them as
+        # the numbered "$n" placeholders that asyncpg expects
+        for i in range(len(params)):
+            WHERE = WHERE.replace("%s", f"${i + 1}", 1)
+        ORDER_BY = "ORDER BY thread_ts DESC "
+        LIMIT = f"LIMIT {limit}" if limit else ""
+
+        query = f"{SELECT}{WHERE}{ORDER_BY}{LIMIT}"
+
+        # execute query
+        async with self.conn.transaction():
+            async for record in self.conn.cursor(query, *params):
+                yield CheckpointTuple(
+                    {
+                        "configurable": {
+                            "thread_id": record["thread_id"],
+                            "thread_ts": record["thread_ts"],
+                        }
+                    },
+                    self.serde.loads(record["checkpoint"]),
+                    self.serde.loads(record["metadata"]) if record["metadata"] else {},
+                    (
+                        {
+                            "configurable": {
+                                "thread_id": record["thread_id"],
+                                "thread_ts": record["parent_ts"],
+                            }
+                        }
+                        if record["parent_ts"]
+                        else None
+                    ),
+                )
+
+    async def aput(
+        self,
+        config: RunnableConfig,
+        checkpoint: Checkpoint,
+        metadata: CheckpointMetadata,
+    ) -> RunnableConfig:
+        """Save a checkpoint to the database asynchronously.
+
+        This method saves a checkpoint to the PostgreSQL database. The checkpoint is associated
+        with the provided config and its parent config (if any).
+
+        Args:
+            config (RunnableConfig): The config to associate with the checkpoint.
+            checkpoint (Checkpoint): The checkpoint to save.
+            metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
+
+        Returns:
+            RunnableConfig: The updated config containing the saved checkpoint's timestamp.
+        """
+        await self.setup()
+        await self.conn.execute(
+            "INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (thread_id, thread_ts) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata",
+            str(config["configurable"]["thread_id"]),
+            checkpoint["id"],
+            config["configurable"].get("thread_ts"),
+            self.serde.dumps(checkpoint),
+            self.serde.dumps(metadata),
+        )
+        return {
+            "configurable": {
+                "thread_id": config["configurable"]["thread_id"],
+                "thread_ts": checkpoint["id"],
+            }
+        }
diff --git a/backend/app/core/graph/checkpoint/postgres.py b/backend/app/core/graph/checkpoint/postgres.py
new file mode 100644
index 00000000..06f79dd0
--- /dev/null
+++ b/backend/app/core/graph/checkpoint/postgres.py
@@ -0,0 +1,585 @@
+import json
+import pickle
+from collections.abc import AsyncIterator, Iterator
+from contextlib import AbstractContextManager, contextmanager
+from threading import Lock
+from types import TracebackType
+from typing import Any
+
+import psycopg2
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    BaseCheckpointSaver,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+)
+from langgraph.serde.base import SerializerProtocol
+from langgraph.serde.jsonplus import JsonPlusSerializer
+from typing_extensions import Self
+
+
+class JsonPlusSerializerCompat(JsonPlusSerializer):
+    """A serializer that supports loading pickled checkpoints for backwards compatibility.
+
+    This serializer extends the JsonPlusSerializer and adds support for loading pickled
+    checkpoints. If the input data starts with b"\x80" and ends with b".", it is treated
+    as a pickled checkpoint and loaded using pickle.loads(). Otherwise, the default
+    JsonPlusSerializer behavior is used.
+
+    Examples:
+        >>> import pickle
+        >>> from app.core.graph.checkpoint.postgres import JsonPlusSerializerCompat
+        >>>
+        >>> serializer = JsonPlusSerializerCompat()
+        >>> pickled_data = pickle.dumps({"key": "value"})
+        >>> serializer.loads(pickled_data)
+        {'key': 'value'}
+        >>>
+        >>> json_data = '{"key": "value"}'.encode("utf-8")
+        >>> serializer.loads(json_data)
+        {'key': 'value'}
+    """
+
+    def loads(self, data: bytes) -> Any:
+        if data.startswith(b"\x80") and data.endswith(b"."):
+            return pickle.loads(data)
+        return super().loads(data)
+
+
+_AIO_ERROR_MSG = (
+    "The PostgresSaver does not support async methods. "
+    "Consider using AsyncPostgresSaver instead.\n"
+    "Note: AsyncPostgresSaver requires an async PostgreSQL driver to use.\n"
+    "See https://langchain-ai.github.io/langgraph/reference/checkpoints/#asyncpostgressaver "
+    "for more information."
+)
+
+
+class PostgresSaver(BaseCheckpointSaver, AbstractContextManager):  # type: ignore[type-arg]
+    """A checkpoint saver that stores checkpoints in a PostgreSQL database.
+
+    Note:
+        This class is meant for lightweight, synchronous use cases
+        (demos and small projects) and does not scale to multiple threads.
+        For a similar PostgreSQL saver with `async` support,
+        consider using AsyncPostgresSaver.
+
+    Args:
+        conn (psycopg2.extensions.connection): The PostgreSQL database connection.
+        serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat.
+
+    Examples:
+
+        >>> import psycopg2
+        >>> from app.core.graph.checkpoint.postgres import PostgresSaver
+        >>> from langgraph.graph import StateGraph
+        >>>
+        >>> builder = StateGraph(int)
+        >>> builder.add_node("add_one", lambda x: x + 1)
+        >>> builder.set_entry_point("add_one")
+        >>> builder.set_finish_point("add_one")
+        >>> conn = psycopg2.connect("dbname=test user=postgres password=secret")
+        >>> memory = PostgresSaver(conn)
+        >>> graph = builder.compile(checkpointer=memory)
+        >>> config = {"configurable": {"thread_id": "1"}}
+        >>> graph.get_state(config)
+        >>> result = graph.invoke(3, config)
+        >>> graph.get_state(config)
+        StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '2024-05-04T06:32:42.235444+00:00'}}, parent_config=None)
+    """  # noqa
+
+    serde = JsonPlusSerializerCompat()
+
+    conn: psycopg2.extensions.connection
+    is_setup: bool
+
+    def __init__(
+        self,
+        conn: psycopg2.extensions.connection,
+        *,
+        serde: SerializerProtocol | None = None,
+    ) -> None:
+        super().__init__(serde=serde)
+        self.conn = conn
+        self.is_setup = False
+        self.lock = Lock()
+
+    @classmethod
+    def from_conn_string(cls, conn_string: str) -> "PostgresSaver":
+        """Create a new PostgresSaver instance from a connection string.
+
+        Args:
+            conn_string (str): The PostgreSQL connection string.
+
+        Returns:
+            PostgresSaver: A new PostgresSaver instance.
+
+        Examples:
+
+            >>> memory = PostgresSaver.from_conn_string("dbname=test user=postgres password=secret")
+        """
+        return PostgresSaver(conn=psycopg2.connect(conn_string))
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self,
+        __exc_type: type[BaseException] | None,
+        __exc_value: BaseException | None,
+        __traceback: TracebackType | None,
+    ) -> bool | None:
+        self.conn.close()
+        return None
+
+    def setup(self) -> None:
+        """Set up the checkpoint database.
+
+        This method creates the necessary tables in the PostgreSQL database if they don't
+        already exist. It is called automatically when needed and should not be called
+        directly by the user.
+        """
+        if self.is_setup:
+            return
+
+        with self.cursor() as cur:
+            cur.execute(
+                """
+                CREATE TABLE IF NOT EXISTS checkpoints (
+                    thread_id TEXT NOT NULL,
+                    thread_ts TEXT NOT NULL,
+                    parent_ts TEXT,
+                    checkpoint BYTEA,
+                    metadata BYTEA,
+                    PRIMARY KEY (thread_id, thread_ts)
+                );
+                """
+            )
+
+        self.is_setup = True
+
+    @contextmanager
+    def cursor(self, transaction: bool = True) -> Iterator[psycopg2.extensions.cursor]:
+        """Get a cursor for the PostgreSQL database.
+
+        This method returns a cursor for the PostgreSQL database. It is used internally
+        by the PostgresSaver and should not be called directly by the user.
+
+        Args:
+            transaction (bool): Whether to commit the transaction when the cursor is closed. Defaults to True.
+
+        Yields:
+            psycopg2.extensions.cursor: A cursor for the PostgreSQL database.
+        """
+        self.setup()
+        cur = self.conn.cursor()
+        try:
+            yield cur
+        finally:
+            if transaction:
+                self.conn.commit()
+            cur.close()
+
+    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+        """Get a checkpoint tuple from the database.
+
+        This method retrieves a checkpoint tuple from the PostgreSQL database based on the
+        provided config. If the config contains a "thread_ts" key, the checkpoint with
+        the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint
+        for the given thread ID is retrieved.
+
+        Args:
+            config (RunnableConfig): The config to use for retrieving the checkpoint.
+
+        Returns:
+            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.
+
+        Examples:
+
+            Basic:
+            >>> config = {"configurable": {"thread_id": "1"}}
+            >>> checkpoint_tuple = memory.get_tuple(config)
+            >>> print(checkpoint_tuple)
+            CheckpointTuple(...)
+
+            With timestamp:
+
+            >>> config = {
+            ...     "configurable": {
+            ...         "thread_id": "1",
+            ...         "thread_ts": "2024-05-04T06:32:42.235444+00:00",
+            ...     }
+            ... }
+            >>> checkpoint_tuple = memory.get_tuple(config)
+            >>> print(checkpoint_tuple)
+            CheckpointTuple(...)
+ """ # noqa + with self.cursor(transaction=False) as cur: + if config["configurable"].get("thread_ts"): + cur.execute( + "SELECT checkpoint, parent_ts, metadata FROM checkpoints WHERE thread_id = %s AND thread_ts = %s", + ( + str(config["configurable"]["thread_id"]), + str(config["configurable"]["thread_ts"]), + ), + ) + if value := cur.fetchone(): + return CheckpointTuple( + config, + self.serde.loads(value[0]), + self.serde.loads(value[2]) if value[2] is not None else {}, + ( + { + "configurable": { + "thread_id": config["configurable"]["thread_id"], + "thread_ts": value[1], + } + } + if value[1] + else None + ), + ) + else: + cur.execute( + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = %s ORDER BY thread_ts DESC LIMIT 1", + (str(config["configurable"]["thread_id"]),), + ) + if value := cur.fetchone(): + return CheckpointTuple( + { + "configurable": { + "thread_id": value[0], + "thread_ts": value[1], + } + }, + self.serde.loads(value[3]), + self.serde.loads(value[4]) if value[4] is not None else {}, + ( + { + "configurable": { + "thread_id": value[0], + "thread_ts": value[2], + } + } + if value[2] + else None + ), + ) + return None + + def list( + self, + config: RunnableConfig, + *, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + This method retrieves a list of checkpoint tuples from the PostgreSQL database based + on the provided config. The checkpoints are ordered by timestamp in descending order. + + Args: + config (RunnableConfig): The config to use for listing the checkpoints. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None. + limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. + + Yields: + Iterator[CheckpointTuple]: An iterator of checkpoint tuples. + + Examples: + >>> from langgraph.checkpoint.postgres import PostgresSaver + >>> memory = PostgresSaver.from_conn_string("dbname=test user=postgres password=secret") + ... # Run a graph, then list the checkpoints + >>> config = {"configurable": {"thread_id": "1"}} + >>> checkpoints = list(memory.list(config, limit=2)) + >>> print(checkpoints) + [CheckpointTuple(...), CheckpointTuple(...)] + + >>> config = {"configurable": {"thread_id": "1"}} + >>> before = {"configurable": {"thread_ts": "2024-05-04T06:32:42.235444+00:00"}} + >>> checkpoints = list(memory.list(config, before=before)) + >>> print(checkpoints) + [CheckpointTuple(...), ...] 
+ """ + query = ( + "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = %s ORDER BY thread_ts DESC" + if before is None + else "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints WHERE thread_id = %s AND thread_ts < %s ORDER BY thread_ts DESC" + ) + if limit: + query += f" LIMIT {limit}" + with self.cursor(transaction=False) as cur: + cur.execute( + query, + ( + (str(config["configurable"]["thread_id"]),) + if before is None + else ( + str(config["configurable"]["thread_id"]), + before["configurable"]["thread_ts"], + ) + ), + ) + for thread_id, thread_ts, parent_ts, value, metadata in cur: + yield CheckpointTuple( + {"configurable": {"thread_id": thread_id, "thread_ts": thread_ts}}, + self.serde.loads(value), + self.serde.loads(metadata) if metadata is not None else {}, + ( + { + "configurable": { + "thread_id": thread_id, + "thread_ts": parent_ts, + } + } + if parent_ts + else None + ), + ) + + def search( + self, + metadata_filter: CheckpointMetadata, + *, + before: RunnableConfig | None = None, + limit: int | None = None, + ) -> Iterator[CheckpointTuple]: + """Search for checkpoints by metadata. + + This method retrieves a list of checkpoint tuples from the PostgreSQL + database based on the provided metadata filter. The metadata filter does + not need to contain all keys defined in the CheckpointMetadata class. + The checkpoints are ordered by timestamp in descending order. + + Args: + metadata_filter (CheckpointMetadata): The metadata filter to use for searching the checkpoints. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None. + limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None. + + Yields: + Iterator[CheckpointTuple]: An iterator of checkpoint tuples. + """ + # construct query + SELECT = "SELECT thread_id, thread_ts, parent_ts, checkpoint, metadata FROM checkpoints " + WHERE, params = search_where(metadata_filter, before) + ORDER_BY = "ORDER BY thread_ts DESC " + LIMIT = f"LIMIT {limit}" if limit else "" + + query = f"{SELECT}{WHERE}{ORDER_BY}{LIMIT}" + + # execute query + with self.cursor(transaction=False) as cur: + cur.execute(query, params) + + for thread_id, thread_ts, parent_ts, value, metadata in cur: + yield CheckpointTuple( + {"configurable": {"thread_id": thread_id, "thread_ts": thread_ts}}, + self.serde.loads(value), + self.serde.loads(metadata) if metadata is not None else {}, + ( + { + "configurable": { + "thread_id": thread_id, + "thread_ts": parent_ts, + } + } + if parent_ts + else None + ), + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + ) -> RunnableConfig: + """Save a checkpoint to the database. + + This method saves a checkpoint to the PostgreSQL database. The checkpoint is associated + with the provided config and its parent config (if any). + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (Optional[dict[str, Any]]): Additional metadata to save with the checkpoint. Defaults to None. + + Returns: + RunnableConfig: The updated config containing the saved checkpoint's timestamp. + + Examples: + + >>> from langgraph.checkpoint.postgres import PostgresSaver + >>> memory = PostgresSaver.from_conn_string("dbname=test user=postgres password=secret") + ... 
# Run a graph, then list the checkpoints
+            >>> config = {"configurable": {"thread_id": "1"}}
+            >>> checkpoint = {"id": "2024-05-04T06:32:42.235444+00:00", "ts": "2024-05-04T06:32:42.235444+00:00", "data": {"key": "value"}}
+            >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}})
+            >>> print(saved_config)
+            {"configurable": {"thread_id": "1", "thread_ts": "2024-05-04T06:32:42.235444+00:00"}}
+        """
+        with self.lock, self.cursor() as cur:
+            cur.execute(
+                "INSERT INTO checkpoints (thread_id, thread_ts, parent_ts, checkpoint, metadata) VALUES (%s, %s, %s, %s, %s) ON CONFLICT (thread_id, thread_ts) DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata",
+                (
+                    str(config["configurable"]["thread_id"]),
+                    checkpoint["id"],
+                    config["configurable"].get("thread_ts"),
+                    self.serde.dumps(checkpoint),
+                    # Serialize metadata with serde so the BYTEA column round-trips
+                    # through self.serde.loads() in get_tuple()/list()/search()
+                    self.serde.dumps(metadata),
+                ),
+            )
+        return {
+            "configurable": {
+                "thread_id": config["configurable"]["thread_id"],
+                "thread_ts": checkpoint["id"],
+            }
+        }
+
+
+def search_where(
+    metadata_filter: CheckpointMetadata,
+    before: RunnableConfig | None = None,
+) -> tuple[str, tuple[Any, ...]]:
+    """Return WHERE clause predicates for (a)search() given metadata filter
+    and `before` config.
+
+    This method returns a tuple of a string and a tuple of values. The string
+    is the parameterized WHERE clause predicate (including the WHERE keyword):
+    "WHERE column1 = %s AND column2 IS %s". The tuple of values contains the
+    values for each of the corresponding parameters.
+    """
+    where = "WHERE "
+    param_values = ()
+
+    # construct predicate for metadata filter
+    metadata_predicate, metadata_values = _metadata_predicate(metadata_filter)
+    if metadata_predicate != "":
+        where += metadata_predicate
+        param_values += metadata_values
+
+    # construct predicate for `before`
+    if before is not None:
+        if metadata_predicate != "":
+            where += "AND thread_ts < %s "
+        else:
+            where += "thread_ts < %s "
+
+        param_values += (before["configurable"]["thread_ts"],)  # type: ignore[assignment]
+
+    if where == "WHERE ":
+        # no predicates, return an empty WHERE clause string
+        return ("", ())
+    else:
+        return (where, param_values)
+
+
+def _metadata_predicate(
+    metadata_filter: CheckpointMetadata,
+) -> tuple[str, tuple[Any, ...]]:
+    """Return WHERE clause predicates for (a)search() given metadata filter.
+
+    This method returns a tuple of a string and a tuple of values. The string
+    is the parameterized WHERE clause predicate (excluding the WHERE keyword):
+    "column1 = %s AND column2 IS %s". The tuple of values contains the values
+    for each of the corresponding parameters.
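+
+    Example (hypothetical filter):
+
+        {"source": "input"} -> ("metadata->>'source' = %s ", ("input",))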
+ """ + + def _where_value(query_value: Any) -> tuple[str, Any]: + """Return tuple of operator and value for WHERE clause predicate.""" + if query_value is None: + return ("IS %s", None) + elif ( + isinstance(query_value, str) + or isinstance(query_value, int) + or isinstance(query_value, float) + ): + return ("= %s", query_value) + elif isinstance(query_value, bool): + return ("= %s", 1 if query_value else 0) + elif isinstance(query_value, dict) or isinstance(query_value, list): + # query value for JSON object cannot have trailing space after separators (, :) + # PostgreSQL jsonb fields are stored without whitespace + return ("= %s", json.dumps(query_value, separators=(",", ":"))) + else: + return ("= %s", str(query_value)) + + predicate = "" + param_values = () + + # process metadata query + for query_key, query_value in metadata_filter.items(): + operator, param_value = _where_value(query_value) + predicate += f"metadata->>'{query_key}' {operator} AND " + param_values += (param_value,) # type: ignore[assignment] + + if predicate != "": + # remove trailing AND + predicate = predicate[:-4] + + # predicate contains an extra trailing space + return (predicate, param_values) + + +async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: # type: ignore[no-untyped-def] + """Get a checkpoint tuple from the database asynchronously. + + Note: + This async method is not supported by the PostgresSaver class. + Use get_tuple() instead, or consider using [AsyncPostgresSaver](#asyncpostgressaver). + """ + raise NotImplementedError(_AIO_ERROR_MSG) + + +def alist( # type: ignore[misc, no-untyped-def] + self, + config: RunnableConfig, + *, + before: RunnableConfig | None = None, + limit: int | None = None, +) -> AsyncIterator[CheckpointTuple]: + """List checkpoints from the database asynchronously. + + Note: + This async method is not supported by the PostgresSaver class. + Use list() instead, or consider using [AsyncPostgresSaver](#asyncpostgressaver). + """ + raise NotImplementedError(_AIO_ERROR_MSG) + yield + + +def asearch( # type: ignore[misc, no-untyped-def] + self, + metadata_filter: CheckpointMetadata, + *, + before: RunnableConfig | None = None, + limit: int | None = None, +) -> AsyncIterator[CheckpointTuple]: + """Search for checkpoints by metadata asynchronously. + + Note: + This async method is not supported by the PostgresSaver class. + Use search() instead, or consider using [AsyncPostgresSaver](#asyncpostgressaver). + """ + raise NotImplementedError(_AIO_ERROR_MSG) + yield + + +async def aput( # type: ignore[no-untyped-def] + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, +) -> RunnableConfig: + """Save a checkpoint to the database asynchronously. + + Note: + This async method is not supported by the PostgresSaver class. + Use put() instead, or consider using [AsyncPostgresSaver](#asyncpostgressaver). + """ + raise NotImplementedError(_AIO_ERROR_MSG) diff --git a/backend/app/core/graph/members.py b/backend/app/core/graph/members.py index 4b7db324..a03be8f4 100644 --- a/backend/app/core/graph/members.py +++ b/backend/app/core/graph/members.py @@ -27,6 +27,10 @@ class GraphMember(GraphPerson): description="Description of the person's experience, motives and concerns." 
    )
     tools: list[str] = Field(description="The list of tools that the person can use.")
+    interrupt: bool = Field(
+        default=False,
+        description="Whether to require human approval before the member uses a skill.",
+    )
 
     @property
     def persona(self) -> str:
diff --git a/backend/app/models.py b/backend/app/models.py
index 19f5ad0c..e9656385 100644
--- a/backend/app/models.py
+++ b/backend/app/models.py
@@ -1,8 +1,10 @@
+from datetime import datetime
 from enum import Enum
+from uuid import UUID, uuid4
 
 from pydantic import BaseModel
 from pydantic import Field as PydanticField
-from sqlalchemy import UniqueConstraint
+from sqlalchemy import Column, DateTime, PrimaryKeyConstraint, UniqueConstraint, func
 from sqlmodel import Field, Relationship, SQLModel
@@ -122,6 +124,9 @@ class Team(TeamBase, table=True):
         back_populates="belongs", sa_relationship_kwargs={"cascade": "delete"}
     )
     workflow: str  # TODO: This should be an enum 'sequential' and 'hierarchical'
+    threads: list["Thread"] = Relationship(
+        back_populates="team", sa_relationship_kwargs={"cascade": "delete"}
+    )
 
 
 # Properties to return via API, id is always required
@@ -136,6 +141,60 @@ class TeamsOut(SQLModel):
     count: int
 
 
+# =============Threads===================
+
+
+class ThreadBase(SQLModel):
+    query: str
+
+
+class ThreadCreate(ThreadBase):
+    pass
+
+
+class ThreadUpdate(ThreadBase):
+    query: str | None = None  # type: ignore[assignment]
+    updated_at: datetime | None = None
+
+
+class Thread(ThreadBase, table=True):
+    id: UUID | None = Field(
+        default_factory=uuid4,
+        primary_key=True,
+        index=True,
+        nullable=False,
+    )
+    updated_at: datetime | None = Field(
+        sa_column=Column(
+            DateTime(timezone=True),
+            nullable=False,
+            default=func.now(),
+            onupdate=func.now(),
+            server_default=func.now(),
+        )
+    )
+    team_id: int | None = Field(default=None, foreign_key="team.id", nullable=False)
+    team: Team | None = Relationship(back_populates="threads")
+    checkpoints: list["Checkpoint"] = Relationship(
+        back_populates="thread", sa_relationship_kwargs={"cascade": "delete"}
+    )
+
+
+class ThreadOut(SQLModel):
+    id: UUID
+    query: str
+    updated_at: datetime
+
+
+class CreateThreadOut(ThreadOut):
+    # Optional: a thread that has never been streamed has no checkpoints yet
+    last_checkpoint: "CheckpointOut | None"
+
+
+class ThreadsOut(SQLModel):
+    data: list[ThreadOut]
+    count: int
+
+
 # ==============MEMBER=========================
@@ -158,6 +217,7 @@ class MemberBase(SQLModel):
     provider: str = "ChatOpenAI"
     model: str = "gpt-3.5-turbo"
     temperature: float = 0.7
+    interrupt: bool = False
 
 
 class MemberCreate(MemberBase):
@@ -176,6 +236,7 @@ class MemberUpdate(MemberBase):
     provider: str | None = None  # type: ignore[assignment]
     model: str | None = None  # type: ignore[assignment]
     temperature: float | None = None  # type: ignore[assignment]
+    interrupt: bool | None = None  # type: ignore[assignment]
 
 
 class Member(MemberBase, table=True):
@@ -225,3 +286,32 @@ class SkillsOut(SQLModel):
 class SkillOut(SkillBase):
     id: int
     description: str | None
+
+
+# ==============CHECKPOINT=====================
+
+
+class Checkpoint(SQLModel, table=True):
+    __tablename__ = "checkpoints"
+    __table_args__ = (PrimaryKeyConstraint("thread_id", "thread_ts"),)
+    thread_id: UUID = Field(foreign_key="thread.id", primary_key=True)
+    thread_ts: UUID = Field(primary_key=True)
+    parent_ts: UUID | None
+    checkpoint: bytes
+    metadata_: bytes = Field(sa_column_kwargs={"name": "metadata"})
+    thread: Thread = Relationship(back_populates="checkpoints")
+    created_at: datetime | None = Field(
+        sa_column=Column(
+            DateTime(timezone=True),
+            nullable=False,
+            default=func.now(),
server_default=func.now(), + ) + ) + + +class CheckpointOut(SQLModel): + thread_id: UUID + thread_ts: UUID + checkpoint: bytes + created_at: datetime diff --git a/backend/poetry.lock b/backend/poetry.lock index da788fac..b01fcac1 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -209,6 +209,78 @@ files = [ {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] +[[package]] +name = "asyncpg" +version = "0.29.0" +description = "An asyncio PostgreSQL driver" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "asyncpg-0.29.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72fd0ef9f00aeed37179c62282a3d14262dbbafb74ec0ba16e1b1864d8a12169"}, + {file = "asyncpg-0.29.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:52e8f8f9ff6e21f9b39ca9f8e3e33a5fcdceaf5667a8c5c32bee158e313be385"}, + {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e6823a7012be8b68301342ba33b4740e5a166f6bbda0aee32bc01638491a22"}, + {file = "asyncpg-0.29.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:746e80d83ad5d5464cfbf94315eb6744222ab00aa4e522b704322fb182b83610"}, + {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ff8e8109cd6a46ff852a5e6bab8b0a047d7ea42fcb7ca5ae6eaae97d8eacf397"}, + {file = "asyncpg-0.29.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:97eb024685b1d7e72b1972863de527c11ff87960837919dac6e34754768098eb"}, + {file = "asyncpg-0.29.0-cp310-cp310-win32.whl", hash = "sha256:5bbb7f2cafd8d1fa3e65431833de2642f4b2124be61a449fa064e1a08d27e449"}, + {file = "asyncpg-0.29.0-cp310-cp310-win_amd64.whl", hash = "sha256:76c3ac6530904838a4b650b2880f8e7af938ee049e769ec2fba7cd66469d7772"}, + {file = "asyncpg-0.29.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4900ee08e85af01adb207519bb4e14b1cae8fd21e0ccf80fac6aa60b6da37b4"}, + {file = "asyncpg-0.29.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a65c1dcd820d5aea7c7d82a3fdcb70e096f8f70d1a8bf93eb458e49bfad036ac"}, + {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b52e46f165585fd6af4863f268566668407c76b2c72d366bb8b522fa66f1870"}, + {file = "asyncpg-0.29.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc600ee8ef3dd38b8d67421359779f8ccec30b463e7aec7ed481c8346decf99f"}, + {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:039a261af4f38f949095e1e780bae84a25ffe3e370175193174eb08d3cecab23"}, + {file = "asyncpg-0.29.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6feaf2d8f9138d190e5ec4390c1715c3e87b37715cd69b2c3dfca616134efd2b"}, + {file = "asyncpg-0.29.0-cp311-cp311-win32.whl", hash = "sha256:1e186427c88225ef730555f5fdda6c1812daa884064bfe6bc462fd3a71c4b675"}, + {file = "asyncpg-0.29.0-cp311-cp311-win_amd64.whl", hash = "sha256:cfe73ffae35f518cfd6e4e5f5abb2618ceb5ef02a2365ce64f132601000587d3"}, + {file = "asyncpg-0.29.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6011b0dc29886ab424dc042bf9eeb507670a3b40aece3439944006aafe023178"}, + {file = "asyncpg-0.29.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b544ffc66b039d5ec5a7454667f855f7fec08e0dfaf5a5490dfafbb7abbd2cfb"}, + {file = "asyncpg-0.29.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d84156d5fb530b06c493f9e7635aa18f518fa1d1395ef240d211cb563c4e2364"}, + {file = 
"asyncpg-0.29.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54858bc25b49d1114178d65a88e48ad50cb2b6f3e475caa0f0c092d5f527c106"}, + {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bde17a1861cf10d5afce80a36fca736a86769ab3579532c03e45f83ba8a09c59"}, + {file = "asyncpg-0.29.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:37a2ec1b9ff88d8773d3eb6d3784dc7e3fee7756a5317b67f923172a4748a175"}, + {file = "asyncpg-0.29.0-cp312-cp312-win32.whl", hash = "sha256:bb1292d9fad43112a85e98ecdc2e051602bce97c199920586be83254d9dafc02"}, + {file = "asyncpg-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:2245be8ec5047a605e0b454c894e54bf2ec787ac04b1cb7e0d3c67aa1e32f0fe"}, + {file = "asyncpg-0.29.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0009a300cae37b8c525e5b449233d59cd9868fd35431abc470a3e364d2b85cb9"}, + {file = "asyncpg-0.29.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cad1324dbb33f3ca0cd2074d5114354ed3be2b94d48ddfd88af75ebda7c43cc"}, + {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:012d01df61e009015944ac7543d6ee30c2dc1eb2f6b10b62a3f598beb6531548"}, + {file = "asyncpg-0.29.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000c996c53c04770798053e1730d34e30cb645ad95a63265aec82da9093d88e7"}, + {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e0bfe9c4d3429706cf70d3249089de14d6a01192d617e9093a8e941fea8ee775"}, + {file = "asyncpg-0.29.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:642a36eb41b6313ffa328e8a5c5c2b5bea6ee138546c9c3cf1bffaad8ee36dd9"}, + {file = "asyncpg-0.29.0-cp38-cp38-win32.whl", hash = "sha256:a921372bbd0aa3a5822dd0409da61b4cd50df89ae85150149f8c119f23e8c408"}, + {file = "asyncpg-0.29.0-cp38-cp38-win_amd64.whl", hash = "sha256:103aad2b92d1506700cbf51cd8bb5441e7e72e87a7b3a2ca4e32c840f051a6a3"}, + {file = "asyncpg-0.29.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5340dd515d7e52f4c11ada32171d87c05570479dc01dc66d03ee3e150fb695da"}, + {file = "asyncpg-0.29.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e17b52c6cf83e170d3d865571ba574577ab8e533e7361a2b8ce6157d02c665d3"}, + {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f100d23f273555f4b19b74a96840aa27b85e99ba4b1f18d4ebff0734e78dc090"}, + {file = "asyncpg-0.29.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48e7c58b516057126b363cec8ca02b804644fd012ef8e6c7e23386b7d5e6ce83"}, + {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f9ea3f24eb4c49a615573724d88a48bd1b7821c890c2effe04f05382ed9e8810"}, + {file = "asyncpg-0.29.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8d36c7f14a22ec9e928f15f92a48207546ffe68bc412f3be718eedccdf10dc5c"}, + {file = "asyncpg-0.29.0-cp39-cp39-win32.whl", hash = "sha256:797ab8123ebaed304a1fad4d7576d5376c3a006a4100380fb9d517f0b59c1ab2"}, + {file = "asyncpg-0.29.0-cp39-cp39-win_amd64.whl", hash = "sha256:cce08a178858b426ae1aa8409b5cc171def45d4293626e7aa6510696d46decd8"}, + {file = "asyncpg-0.29.0.tar.gz", hash = "sha256:d1c49e1f44fffafd9a55e1a9b101590859d881d639ea2922516f5d9c512d354e"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.12.0\""} + +[package.extras] +docs = ["Sphinx (>=5.3.0,<5.4.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["flake8 (>=6.1,<7.0)", "uvloop (>=0.15.3)"] + +[[package]] +name = "asyncpg-stubs" +version = 
"0.29.1" +description = "asyncpg stubs" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "asyncpg_stubs-0.29.1-py3-none-any.whl", hash = "sha256:cce994d5a19394249e74ae8d252bde3c77cee0ddfc776cc708b724fdb4adebb6"}, + {file = "asyncpg_stubs-0.29.1.tar.gz", hash = "sha256:686afcc0af3a2f3c8e393cd850e0de430e5a139ce82b2f28ef8f693ecdf918bf"}, +] + +[package.dependencies] +asyncpg = ">=0.29,<0.30" +typing-extensions = ">=4.7.0,<5.0.0" + [[package]] name = "attrs" version = "23.2.0" @@ -2802,6 +2874,28 @@ files = [ {file = "psycopg_binary-3.1.18-cp39-cp39-win_amd64.whl", hash = "sha256:d4422af5232699f14b7266a754da49dc9bcd45eba244cf3812307934cd5d6679"}, ] +[[package]] +name = "psycopg2" +version = "2.9.9" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "psycopg2-2.9.9-cp310-cp310-win32.whl", hash = "sha256:38a8dcc6856f569068b47de286b472b7c473ac7977243593a288ebce0dc89516"}, + {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, + {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, + {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, + {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, + {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, + {file = "psycopg2-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:bac58c024c9922c23550af2a581998624d6e02350f4ae9c5f0bc642c633a2d5e"}, + {file = "psycopg2-2.9.9-cp39-cp39-win32.whl", hash = "sha256:c92811b2d4c9b6ea0285942b2e7cac98a59e166d59c588fe5cfe1eda58e72d59"}, + {file = "psycopg2-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:de80739447af31525feddeb8effd640782cf5998e1a4e9192ebdf829717e3913"}, + {file = "psycopg2-2.9.9.tar.gz", hash = "sha256:d1454bde93fb1e224166811694d600e746430c006fbb031ea06ecc2ea41bf156"}, +] + [[package]] name = "pyasn1" version = "0.6.0" @@ -3780,6 +3874,17 @@ files = [ {file = "types_passlib-1.7.7.20240327-py3-none-any.whl", hash = "sha256:3a3b7f4258b71034d2e2f4f307d6810f9904f906cdf375514c8bdbdb28a4ad23"}, ] +[[package]] +name = "types-psycopg2" +version = "2.9.21.20240417" +description = "Typing stubs for psycopg2" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-psycopg2-2.9.21.20240417.tar.gz", hash = "sha256:05db256f4a459fb21a426b8e7fca0656c3539105ff0208eaf6bdaf406a387087"}, + {file = "types_psycopg2-2.9.21.20240417-py3-none-any.whl", hash = "sha256:644d6644d64ebbe37203229b00771012fb3b3bddd507a129a2e136485990e4f8"}, +] + [[package]] name = "types-pyasn1" version = "0.6.0.20240402" @@ -4312,4 +4417,4 @@ repair = ["scipy (>=1.6.3)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "7adbd6b6ee6bcd4a029be1d0ca2af53a04e0acfb56ec4948847d781429cfd00e" +content-hash = 
"63a45fcb57d8e078a8b7aae6cf8718b6e127312bbf9c0e591964249eaa330f81" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index c9d5636e..6c127c9a 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -40,6 +40,10 @@ google-search-results = "^2.4.2" yfinance = "^0.2.38" langchain-core = "0.2.1" pyjwt = "^2.8.0" +psycopg2 = "^2.9.9" +asyncpg = "^0.29.0" +types-psycopg2 = "^2.9.21.20240417" +asyncpg-stubs = "^0.29.1" [tool.poetry.group.dev.dependencies] pytest = "^7.4.3" diff --git a/frontend/src/client/index.ts b/frontend/src/client/index.ts index 0ca61db2..73b04c5d 100644 --- a/frontend/src/client/index.ts +++ b/frontend/src/client/index.ts @@ -10,6 +10,8 @@ export type { OpenAPIConfig } from './core/OpenAPI'; export type { Body_login_login_access_token } from './models/Body_login_login_access_token'; export type { ChatMessage } from './models/ChatMessage'; export type { ChatMessageType } from './models/ChatMessageType'; +export type { CheckpointOut } from './models/CheckpointOut'; +export type { CreateThreadOut } from './models/CreateThreadOut'; export type { HTTPValidationError } from './models/HTTPValidationError'; export type { MemberCreate } from './models/MemberCreate'; export type { MemberOut } from './models/MemberOut'; @@ -25,6 +27,10 @@ export type { TeamCreate } from './models/TeamCreate'; export type { TeamOut } from './models/TeamOut'; export type { TeamsOut } from './models/TeamsOut'; export type { TeamUpdate } from './models/TeamUpdate'; +export type { ThreadCreate } from './models/ThreadCreate'; +export type { ThreadOut } from './models/ThreadOut'; +export type { ThreadsOut } from './models/ThreadsOut'; +export type { ThreadUpdate } from './models/ThreadUpdate'; export type { Token } from './models/Token'; export type { UpdatePassword } from './models/UpdatePassword'; export type { UserCreate } from './models/UserCreate'; @@ -38,6 +44,8 @@ export type { ValidationError } from './models/ValidationError'; export { $Body_login_login_access_token } from './schemas/$Body_login_login_access_token'; export { $ChatMessage } from './schemas/$ChatMessage'; export { $ChatMessageType } from './schemas/$ChatMessageType'; +export { $CheckpointOut } from './schemas/$CheckpointOut'; +export { $CreateThreadOut } from './schemas/$CreateThreadOut'; export { $HTTPValidationError } from './schemas/$HTTPValidationError'; export { $MemberCreate } from './schemas/$MemberCreate'; export { $MemberOut } from './schemas/$MemberOut'; @@ -53,6 +61,10 @@ export { $TeamCreate } from './schemas/$TeamCreate'; export { $TeamOut } from './schemas/$TeamOut'; export { $TeamsOut } from './schemas/$TeamsOut'; export { $TeamUpdate } from './schemas/$TeamUpdate'; +export { $ThreadCreate } from './schemas/$ThreadCreate'; +export { $ThreadOut } from './schemas/$ThreadOut'; +export { $ThreadsOut } from './schemas/$ThreadsOut'; +export { $ThreadUpdate } from './schemas/$ThreadUpdate'; export { $Token } from './schemas/$Token'; export { $UpdatePassword } from './schemas/$UpdatePassword'; export { $UserCreate } from './schemas/$UserCreate'; @@ -67,5 +79,6 @@ export { LoginService } from './services/LoginService'; export { MembersService } from './services/MembersService'; export { SkillsService } from './services/SkillsService'; export { TeamsService } from './services/TeamsService'; +export { ThreadsService } from './services/ThreadsService'; export { UsersService } from './services/UsersService'; export { UtilsService } from './services/UtilsService'; diff --git 
a/frontend/src/client/models/CheckpointOut.ts b/frontend/src/client/models/CheckpointOut.ts new file mode 100644 index 00000000..122c8ce0 --- /dev/null +++ b/frontend/src/client/models/CheckpointOut.ts @@ -0,0 +1,12 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type CheckpointOut = { + thread_id: string; + thread_ts: string; + checkpoint: Blob; + created_at: string; +}; + diff --git a/frontend/src/client/models/CreateThreadOut.ts b/frontend/src/client/models/CreateThreadOut.ts new file mode 100644 index 00000000..620b49b1 --- /dev/null +++ b/frontend/src/client/models/CreateThreadOut.ts @@ -0,0 +1,14 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { CheckpointOut } from './CheckpointOut'; + +export type CreateThreadOut = { + id: string; + query: string; + updated_at: string; + last_checkpoint: CheckpointOut; +}; + diff --git a/frontend/src/client/models/MemberCreate.ts b/frontend/src/client/models/MemberCreate.ts index 1b243581..191a3a59 100644 --- a/frontend/src/client/models/MemberCreate.ts +++ b/frontend/src/client/models/MemberCreate.ts @@ -15,5 +15,6 @@ export type MemberCreate = { provider?: string; model?: string; temperature?: number; + interrupt?: boolean; }; diff --git a/frontend/src/client/models/MemberOut.ts b/frontend/src/client/models/MemberOut.ts index 0904085e..c4ce7a22 100644 --- a/frontend/src/client/models/MemberOut.ts +++ b/frontend/src/client/models/MemberOut.ts @@ -17,6 +17,7 @@ export type MemberOut = { provider?: string; model?: string; temperature?: number; + interrupt?: boolean; id: number; belongs_to: number; skills: Array; diff --git a/frontend/src/client/models/MemberUpdate.ts b/frontend/src/client/models/MemberUpdate.ts index bd89aafc..18207cbf 100644 --- a/frontend/src/client/models/MemberUpdate.ts +++ b/frontend/src/client/models/MemberUpdate.ts @@ -17,6 +17,7 @@ export type MemberUpdate = { provider?: (string | null); model?: (string | null); temperature?: (number | null); + interrupt?: (boolean | null); belongs_to?: (number | null); skills?: (Array | null); }; diff --git a/frontend/src/client/models/ThreadCreate.ts b/frontend/src/client/models/ThreadCreate.ts new file mode 100644 index 00000000..2a653da4 --- /dev/null +++ b/frontend/src/client/models/ThreadCreate.ts @@ -0,0 +1,9 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type ThreadCreate = { + query: string; +}; + diff --git a/frontend/src/client/models/ThreadOut.ts b/frontend/src/client/models/ThreadOut.ts new file mode 100644 index 00000000..b6896214 --- /dev/null +++ b/frontend/src/client/models/ThreadOut.ts @@ -0,0 +1,11 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type ThreadOut = { + id: string; + query: string; + updated_at: string; +}; + diff --git a/frontend/src/client/models/ThreadUpdate.ts b/frontend/src/client/models/ThreadUpdate.ts new file mode 100644 index 00000000..6b3a7e0e --- /dev/null +++ b/frontend/src/client/models/ThreadUpdate.ts @@ -0,0 +1,10 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type ThreadUpdate = { + query?: (string | null); + updated_at?: (string | null); +}; 
+ diff --git a/frontend/src/client/models/ThreadsOut.ts b/frontend/src/client/models/ThreadsOut.ts new file mode 100644 index 00000000..bbefe487 --- /dev/null +++ b/frontend/src/client/models/ThreadsOut.ts @@ -0,0 +1,12 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { ThreadOut } from './ThreadOut'; + +export type ThreadsOut = { + data: Array<ThreadOut>; + count: number; +}; + diff --git a/frontend/src/client/schemas/$CheckpointOut.ts b/frontend/src/client/schemas/$CheckpointOut.ts new file mode 100644 index 00000000..a5f706d5 --- /dev/null +++ b/frontend/src/client/schemas/$CheckpointOut.ts @@ -0,0 +1,28 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $CheckpointOut = { + properties: { + thread_id: { + type: 'string', + isRequired: true, + format: 'uuid', + }, + thread_ts: { + type: 'string', + isRequired: true, + format: 'uuid', + }, + checkpoint: { + type: 'binary', + isRequired: true, + format: 'binary', + }, + created_at: { + type: 'string', + isRequired: true, + format: 'date-time', + }, + }, +} as const; diff --git a/frontend/src/client/schemas/$CreateThreadOut.ts b/frontend/src/client/schemas/$CreateThreadOut.ts new file mode 100644 index 00000000..78fd938e --- /dev/null +++ b/frontend/src/client/schemas/$CreateThreadOut.ts @@ -0,0 +1,26 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $CreateThreadOut = { + properties: { + id: { + type: 'string', + isRequired: true, + format: 'uuid', + }, + query: { + type: 'string', + isRequired: true, + }, + updated_at: { + type: 'string', + isRequired: true, + format: 'date-time', + }, + last_checkpoint: { + type: 'CheckpointOut', + isRequired: true, + }, + }, +} as const; diff --git a/frontend/src/client/schemas/$MemberCreate.ts b/frontend/src/client/schemas/$MemberCreate.ts index 58ed195b..934d0ed8 100644 --- a/frontend/src/client/schemas/$MemberCreate.ts +++ b/frontend/src/client/schemas/$MemberCreate.ts @@ -58,5 +58,8 @@ export const $MemberCreate = { temperature: { type: 'number', }, + interrupt: { + type: 'boolean', + }, }, } as const; diff --git a/frontend/src/client/schemas/$MemberOut.ts b/frontend/src/client/schemas/$MemberOut.ts index c7d8034d..c2d39470 100644 --- a/frontend/src/client/schemas/$MemberOut.ts +++ b/frontend/src/client/schemas/$MemberOut.ts @@ -59,6 +59,9 @@ export const $MemberOut = { temperature: { type: 'number', }, + interrupt: { + type: 'boolean', + }, id: { type: 'number', isRequired: true, diff --git a/frontend/src/client/schemas/$MemberUpdate.ts b/frontend/src/client/schemas/$MemberUpdate.ts index c506a98b..f75bf3e6 100644 --- a/frontend/src/client/schemas/$MemberUpdate.ts +++ b/frontend/src/client/schemas/$MemberUpdate.ts @@ -93,6 +93,14 @@ export const $MemberUpdate = { type: 'null', }], }, + interrupt: { + type: 'any-of', + contains: [{ + type: 'boolean', + }, { + type: 'null', + }], + }, belongs_to: { type: 'any-of', contains: [{ diff --git a/frontend/src/client/schemas/$ThreadCreate.ts b/frontend/src/client/schemas/$ThreadCreate.ts new file mode 100644 index 00000000..ef024e0d --- /dev/null +++ b/frontend/src/client/schemas/$ThreadCreate.ts @@ -0,0 +1,12 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const
$ThreadCreate = { + properties: { + query: { + type: 'string', + isRequired: true, + }, + }, +} as const; diff --git a/frontend/src/client/schemas/$ThreadOut.ts b/frontend/src/client/schemas/$ThreadOut.ts new file mode 100644 index 00000000..545bf7e5 --- /dev/null +++ b/frontend/src/client/schemas/$ThreadOut.ts @@ -0,0 +1,22 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ThreadOut = { + properties: { + id: { + type: 'string', + isRequired: true, + format: 'uuid', + }, + query: { + type: 'string', + isRequired: true, + }, + updated_at: { + type: 'string', + isRequired: true, + format: 'date-time', + }, + }, +} as const; diff --git a/frontend/src/client/schemas/$ThreadUpdate.ts b/frontend/src/client/schemas/$ThreadUpdate.ts new file mode 100644 index 00000000..cd8b27fb --- /dev/null +++ b/frontend/src/client/schemas/$ThreadUpdate.ts @@ -0,0 +1,25 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ThreadUpdate = { + properties: { + query: { + type: 'any-of', + contains: [{ + type: 'string', + }, { + type: 'null', + }], + }, + updated_at: { + type: 'any-of', + contains: [{ + type: 'string', + format: 'date-time', + }, { + type: 'null', + }], + }, + }, +} as const; diff --git a/frontend/src/client/schemas/$ThreadsOut.ts b/frontend/src/client/schemas/$ThreadsOut.ts new file mode 100644 index 00000000..d69f69c8 --- /dev/null +++ b/frontend/src/client/schemas/$ThreadsOut.ts @@ -0,0 +1,19 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export const $ThreadsOut = { + properties: { + data: { + type: 'array', + contains: { + type: 'ThreadOut', + }, + isRequired: true, + }, + count: { + type: 'number', + isRequired: true, + }, + }, +} as const; diff --git a/frontend/src/client/services/TeamsService.ts b/frontend/src/client/services/TeamsService.ts index b3a6f349..c3a26543 100644 --- a/frontend/src/client/services/TeamsService.ts +++ b/frontend/src/client/services/TeamsService.ts @@ -143,16 +143,19 @@ export class TeamsService { */ public static stream({ id, + threadId, requestBody, }: { id: number, + threadId: string, requestBody: TeamChat, }): CancelablePromise { return __request(OpenAPI, { method: 'POST', - url: '/api/v1/teams/{id}/stream', + url: '/api/v1/teams/{id}/stream/{thread_id}', path: { 'id': id, + 'thread_id': threadId, }, body: requestBody, mediaType: 'application/json', diff --git a/frontend/src/client/services/ThreadsService.ts b/frontend/src/client/services/ThreadsService.ts new file mode 100644 index 00000000..7c4f6108 --- /dev/null +++ b/frontend/src/client/services/ThreadsService.ts @@ -0,0 +1,157 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { CreateThreadOut } from '../models/CreateThreadOut'; +import type { ThreadCreate } from '../models/ThreadCreate'; +import type { ThreadOut } from '../models/ThreadOut'; +import type { ThreadsOut } from '../models/ThreadsOut'; +import type { ThreadUpdate } from '../models/ThreadUpdate'; + +import type { CancelablePromise } from '../core/CancelablePromise'; +import { OpenAPI } from '../core/OpenAPI'; +import { request as __request } from '../core/request'; + +export class ThreadsService { + + /** + * Read Threads + * Retrieve threads + * @returns 
ThreadsOut Successful Response + * @throws ApiError + */ + public static readThreads({ + teamId, + skip, + limit = 100, + }: { + teamId: number, + skip?: number, + limit?: number, + }): CancelablePromise<ThreadsOut> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/teams/{team_id}/threads/', + path: { + 'team_id': teamId, + }, + query: { + 'skip': skip, + 'limit': limit, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Create Thread + * Create new thread + * @returns ThreadOut Successful Response + * @throws ApiError + */ + public static createThread({ + teamId, + requestBody, + }: { + teamId: number, + requestBody: ThreadCreate, + }): CancelablePromise<ThreadOut> { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/teams/{team_id}/threads/', + path: { + 'team_id': teamId, + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Read Thread + * Get thread and its last checkpoint by ID + * @returns CreateThreadOut Successful Response + * @throws ApiError + */ + public static readThread({ + teamId, + id, + }: { + teamId: number, + id: string, + }): CancelablePromise<CreateThreadOut> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/teams/{team_id}/threads/{id}', + path: { + 'team_id': teamId, + 'id': id, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Update Thread + * Update a thread. + * @returns ThreadOut Successful Response + * @throws ApiError + */ + public static updateThread({ + teamId, + id, + requestBody, + }: { + teamId: number, + id: string, + requestBody: ThreadUpdate, + }): CancelablePromise<ThreadOut> { + return __request(OpenAPI, { + method: 'PUT', + url: '/api/v1/teams/{team_id}/threads/{id}', + path: { + 'team_id': teamId, + 'id': id, + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Delete Thread + * Delete a thread.
+ * @returns any Successful Response + * @throws ApiError + */ + public static deleteThread({ + teamId, + id, + }: { + teamId: number, + id: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'DELETE', + url: '/api/v1/teams/{team_id}/threads/{id}', + path: { + 'team_id': teamId, + 'id': id, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + +} diff --git a/frontend/src/components/Teams/ChatTeam.tsx b/frontend/src/components/Teams/ChatTeam.tsx index 9d3ffd0d..5514dc61 100644 --- a/frontend/src/components/Teams/ChatTeam.tsx +++ b/frontend/src/components/Teams/ChatTeam.tsx @@ -18,11 +18,14 @@ import { OpenAPI, type OpenAPIConfig, type ChatMessage, + type ThreadUpdate, + ThreadsService, + type ThreadCreate, } from "../../client" -import { useMutation } from "react-query" +import { useMutation, useQuery, useQueryClient } from "react-query" import useCustomToast from "../../hooks/useCustomToast" -import { useParams } from "@tanstack/react-router" -import { useState } from "react" +import { getRouteApi, useNavigate, useParams } from "@tanstack/react-router" +import { useEffect, useState } from "react" import { getQueryString, getRequestBody, @@ -31,13 +34,14 @@ import { import type { ApiRequestOptions } from "../../client/core/ApiRequestOptions" import Markdown from "react-markdown" import { GrFormNextLink } from "react-icons/gr" +import { convertCheckpointToMessages } from "../../utils" -interface ToolInput { +export interface ToolInput { name: string args: { [x: string]: any } } -interface Message extends ChatMessage { +export interface Message extends ChatMessage { toolCalls?: ToolInput[] member: string next?: string @@ -63,12 +67,13 @@ const getUrl = (config: OpenAPIConfig, options: ApiRequestOptions): string => { return url } -const stream = async (id: number, data: TeamChat) => { +const stream = async (id: number, threadId: string, data: TeamChat) => { const requestOptions = { method: "POST" as const, - url: "/api/v1/teams/{id}/stream", + url: "/api/v1/teams/{id}/stream/{threadId}", path: { id, + threadId, }, body: data, mediaType: "application/json", @@ -95,10 +100,18 @@ const stream = async (id: number, data: TeamChat) => { const MessageBox = ({ message }: { message: Message }) => { const { member, next, content, toolCalls } = message const hasTools = (toolCalls && toolCalls.length > 0) || false + const memberComp = + member === "You." ? 
( + + You + + ) : ( + member + ) return ( - {member} + {memberComp} {next && } {next && next} {hasTools && Tool} @@ -119,15 +132,112 @@ const MessageBox = ({ message }: { message: Message }) => { } const ChatTeam = () => { + const queryClient = useQueryClient() + const navigate = useNavigate() + const { threadId } = getRouteApi("/_layout/teams/$teamId").useSearch() const { teamId } = useParams({ strict: false }) as { teamId: string } const showToast = useCustomToast() const [input, setInput] = useState("") const [messages, setMessages] = useState([]) const [isStreaming, setIsStreaming] = useState(false) + const threadData = useQuery( + ["thread", threadId], + () => + ThreadsService.readThread({ + teamId: Number.parseInt(teamId), + id: threadId!, + }), + { + enabled: !!threadId, // only runs the query if threadId is not null or undefined + onError: (err: ApiError) => { + const errDetail = err.body?.detail + showToast("Something went wrong.", `${errDetail}`, "error") + // if fail, then remove it from search params and delete existing messages + navigate({ search: {} }) + setMessages([]) + }, + }, + ) + + useEffect(() => { + if (threadData.data?.last_checkpoint?.checkpoint) { + const checkpoint = JSON.parse( + threadData.data.last_checkpoint.checkpoint as unknown as string, + ) + const messages = convertCheckpointToMessages(checkpoint) + setMessages(messages) + } + }, [threadData.data]) + + const createThread = async (data: ThreadCreate) => { + const thread = await ThreadsService.createThread({ + teamId: Number.parseInt(teamId), + requestBody: data, + }) + return thread.id + } + const createThreadMutation = useMutation(createThread, { + onSuccess: (threadId) => { + navigate({ search: { threadId } }) + }, + onError: (err: ApiError) => { + const errDetail = err.body?.detail + showToast("Unable to create thread", `${errDetail}`, "error") + }, + onSettled: () => { + queryClient.invalidateQueries(["threads", teamId]) + }, + }) + + const updateThread = async (data: ThreadUpdate) => { + if (!threadId) return + const thread = await ThreadsService.updateThread({ + teamId: Number.parseInt(teamId), + id: threadId, + requestBody: data, + }) + return thread.id + } + const updateThreadMutation = useMutation(updateThread, { + onError: (err: ApiError) => { + const errDetail = err.body?.detail + showToast("Unable to update thread.", `${errDetail}`, "error") + }, + onSettled: () => { + queryClient.invalidateQueries(["threads", teamId]) + }, + }) const chatTeam = async (data: TeamChat) => { - setMessages([]) - const res = await stream(Number.parseInt(teamId), data) + setMessages((prev) => [ + ...prev, + { + type: "human", + content: data.messages[0].content, + member: "You.", + }, + ]) + // Create a new thread or update current thread with most recent user query + const query = data.messages + let currentThreadId: string | undefined | null = threadId + if (!threadId) { + currentThreadId = await createThreadMutation.mutateAsync({ + query: query[0].content, + }) + } else { + currentThreadId = await updateThreadMutation.mutateAsync({ + query: query[0].content, + }) + } + + if (!currentThreadId) + return showToast( + "Something went wrong.", + "Unable to obtain thread id", + "error", + ) + + const res = await stream(Number.parseInt(teamId), currentThreadId, data) if (res.body) { const reader = res.body.getReader() @@ -174,7 +284,12 @@ const ChatTeam = () => { setMessages((prev) => [...prev, ...newMessages]) } catch (error) { - console.error("Failed to parse JSON:", error) + console.error("Failed to parse 
messages:", error) + return showToast( + "Something went wrong.", + "Unable to parse messages", + "error", + ) } } boundary = buffer.indexOf("\n\n") } } } - setIsStreaming(false) } const mutation = useMutation(chatTeam, { + onMutate: () => { + setIsStreaming(true) + }, onError: (err: ApiError) => { const errDetail = err.body?.detail showToast("Something went wrong.", `${errDetail}`, "error") @@ -193,12 +310,15 @@ onSuccess: () => { showToast("Success!", "Streaming completed.", "success") }, + onSettled: () => { + setIsStreaming(false) + }, }) const onSubmit = async (e: React.FormEvent) => { e.preventDefault() - setIsStreaming(true) mutation.mutate({ messages: [{ type: "human", content: input }] }) + setInput("") } return ( diff --git a/frontend/src/components/Teams/ViewThreads.tsx b/frontend/src/components/Teams/ViewThreads.tsx new file mode 100644 index 00000000..a1eefbb3 --- /dev/null +++ b/frontend/src/components/Teams/ViewThreads.tsx @@ -0,0 +1,129 @@ +import { + Flex, + Spinner, + Container, + TableContainer, + Table, + Thead, + Tr, + Th, + Td, + Tbody, + useColorModeValue, + IconButton, +} from "@chakra-ui/react" +import { useMutation, useQuery, useQueryClient } from "react-query" +import { ThreadsService, type ApiError } from "../../client" +import useCustomToast from "../../hooks/useCustomToast" +import { getRouteApi, useNavigate } from "@tanstack/react-router" +import { DeleteIcon } from "@chakra-ui/icons" + +interface ChatHistoryProps { + teamId: string + updateTabIndex: (index: number) => void +} + +const ChatHistory = ({ teamId, updateTabIndex }: ChatHistoryProps) => { + const queryClient = useQueryClient() + const { threadId } = getRouteApi("/_layout/teams/$teamId").useSearch() + const navigate = useNavigate() + const showToast = useCustomToast() + const rowTint = useColorModeValue("blackAlpha.50", "whiteAlpha.50") + const { + data: threads, + isLoading, + isError, + error, + } = useQuery(["threads", teamId], () => + ThreadsService.readThreads({ teamId: Number.parseInt(teamId) }), + ) + const deleteThread = async (threadId: string) => { + await ThreadsService.deleteThread({ + teamId: Number.parseInt(teamId), + id: threadId, + }) + } + const deleteThreadMutation = useMutation(deleteThread, { + onError: (err: ApiError) => { + const errDetail = err.body?.detail + showToast("Unable to delete thread.", `${errDetail}`, "error") + }, + onSettled: () => { + queryClient.invalidateQueries(["threads", teamId]) + queryClient.invalidateQueries(["thread", threadId]) + }, + }) + + /** + * Set the threadId in the search params and navigate to 'Chat' tab + */ + const onClickRowHandler = (threadId: string) => { + navigate({ search: { threadId } }) + updateTabIndex(1) + } + + const onDeleteHandler = ( + e: React.MouseEvent<HTMLButtonElement>, + threadId: string, + ) => { + e.stopPropagation() + deleteThreadMutation.mutate(threadId) + } + + if (isError) { + const errDetail = (error as ApiError).body?.detail + showToast("Something went wrong.", `${errDetail}`, "error") + } + + return ( + <> + {isLoading ? ( + // TODO: Add skeleton + <Flex justify="center" align="center" height="100vh" width="full"> + <Spinner size="xl" color="ui.main" /> + </Flex> + ) : ( + threads && ( + <Container maxW={"full"}> + <TableContainer> + <Table size={"sm"}> + <Thead> + <Tr> + <Th>Start Time</Th> + <Th>Recent Query</Th> + <Th>Thread ID</Th> + <Th>Actions</Th> + </Tr> + </Thead> + <Tbody> + {threads.data.map((thread) => ( + <Tr + key={thread.id} + onClick={() => onClickRowHandler(thread.id)} + _hover={{ backgroundColor: rowTint }} + cursor={"pointer"} + > + <Td>{new Date(thread.updated_at).toLocaleString()}</Td> + <Td>{thread.query}</Td> + <Td>{thread.id}</Td> + <Td> + <IconButton + size={"sm"} + aria-label={"Delete thread"} + icon={<DeleteIcon />} + onClick={(e) => onDeleteHandler(e, thread.id)} + /> + </Td> + </Tr> + ))} + </Tbody> + </Table>
+ </TableContainer> + </Container> + ) + )} + </> + ) +} + +export default ChatHistory diff --git a/frontend/src/routes/_layout/teams.$teamId.tsx b/frontend/src/routes/_layout/teams.$teamId.tsx index 8642eeb0..fcc2c7f5 100644 --- a/frontend/src/routes/_layout/teams.$teamId.tsx +++ b/frontend/src/routes/_layout/teams.$teamId.tsx @@ -19,14 +19,27 @@ import useCustomToast from "../../hooks/useCustomToast" import { ChevronRightIcon } from "@chakra-ui/icons" import Flow from "../../components/ReactFlow/Flow" import ChatTeam from "../../components/Teams/ChatTeam" +import ViewThreads from "../../components/Teams/ViewThreads" +import { useState } from "react" + +type SearchSchema = { + threadId?: string +} export const Route = createFileRoute("/_layout/teams/$teamId")({ component: Team, + validateSearch: (search: Record<string, unknown>): SearchSchema => { + return { + threadId: + typeof search?.threadId === "string" ? search.threadId : undefined, + } + }, }) function Team() { const showToast = useCustomToast() const { teamId } = Route.useParams() + const [tabIndex, setTabIndex] = useState(0) const { data: team, isLoading, @@ -73,10 +86,16 @@ > {team.name} - <Tabs variant="enclosed"> + <Tabs variant="enclosed" index={tabIndex} onChange={(index) => setTabIndex(index)}> <TabList> <Tab>Build</Tab> <Tab>Chat</Tab> + <Tab>Threads</Tab> </TabList> @@ -85,6 +104,9 @@ <TabPanel> <ChatTeam /> </TabPanel> + <TabPanel> + <ViewThreads teamId={teamId} updateTabIndex={setTabIndex} /> + </TabPanel>
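A note on how the pieces above compose: ChatTeam now always resolves a thread before opening the stream, so every run is checkpointed under a thread_id that ViewThreads can later surface again. The sketch below retraces that sequence against the generated client; the startThreadAndStream helper is hypothetical, and it calls fetch directly, so it omits the Authorization header that the client's request helper would normally attach.

import { OpenAPI, ThreadsService } from "../client"

// Hypothetical helper mirroring the order of operations in ChatTeam.
async function startThreadAndStream(teamId: number, query: string) {
  // 1. Persist the thread first so the backend has a thread_id to checkpoint under.
  const thread = await ThreadsService.createThread({
    teamId,
    requestBody: { query },
  })

  // 2. Open the stream against that thread; the path carries both ids.
  const res = await fetch(
    `${OpenAPI.BASE}/api/v1/teams/${teamId}/stream/${thread.id}`,
    {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ messages: [{ type: "human", content: query }] }),
    },
  )

  // 3. The caller then reads the body as "\n\n"-delimited chunks, as ChatTeam does.
  return res.body?.getReader()
}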
diff --git a/frontend/src/utils.ts b/frontend/src/utils.ts index 1bf5167a..0acb65c9 100644 --- a/frontend/src/utils.ts +++ b/frontend/src/utils.ts @@ -1,4 +1,39 @@ +import type { Message, ToolInput } from "./components/Teams/ChatTeam" + export const emailPattern = { value: /^[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,4}$/i, message: "Invalid email address", } + +interface CheckpointMessage { + kwargs: { + tool_calls?: ToolInput[] + name: string + type: string + content: string + next?: string + } +} + +/** + * Convert langgraph's checkpoint data to messages + * @param checkpoint Checkpoint + */ +export const convertCheckpointToMessages = (checkpoint: any): Message[] => { + const messages: Message[] = [] + + for (const message of checkpoint.channel_values + .messages as CheckpointMessage[]) { + if (message.kwargs.type === "tool") continue + const { type, content, next, tool_calls, name } = message.kwargs + messages.push({ + toolCalls: tool_calls || [], + member: type === "human" ? "You." : name, + type, + content, + next, + } as Message) + } + + return messages +}
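For reference, convertCheckpointToMessages consumes the JSON-parsed LangGraph checkpoint that readThread returns in last_checkpoint.checkpoint. A small illustration with a hand-written fragment follows; real checkpoints come from langgraph's serializer and carry many more fields, so this shape is trimmed to the keys the function actually reads, and the member/tool names are made up.

import { convertCheckpointToMessages } from "./utils"

// Trimmed, illustrative checkpoint fragment.
const checkpoint = {
  channel_values: {
    messages: [
      { kwargs: { type: "human", name: "user", content: "What is AAPL trading at?" } },
      { kwargs: { type: "tool", name: "ask_stock_prices", content: "189.98" } },
      { kwargs: { type: "ai", name: "Analyst", content: "AAPL last traded at $189.98.", tool_calls: [] } },
    ],
  },
}

console.log(convertCheckpointToMessages(checkpoint))
// [
//   { toolCalls: [], member: "You.", type: "human", content: "What is AAPL trading at?", next: undefined },
//   { toolCalls: [], member: "Analyst", type: "ai", content: "AAPL last traded at $189.98.", next: undefined },
// ]
// Tool messages are skipped, and human messages are attributed to "You.".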