Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement persistent chats and enable continuation of past conversations #36

Merged
merged 27 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6f1ed09
Install aiosqlite
StreetLamb May 25, 2024
dd84e53
Create migration to add 'interrupt' column in members table. Update m…
StreetLamb May 25, 2024
6709127
Install psycopg2 and asyncpg
StreetLamb May 25, 2024
4f23993
Create AsyncPostgresSaver and PostgresSaver classes for graph checkpo…
StreetLamb May 25, 2024
aad2f3f
Add persistence to graph. Update stream route to require thread_id in…
StreetLamb May 25, 2024
47f9ea4
Create threads table
StreetLamb May 25, 2024
4cbbcb5
Create thread routes
StreetLamb May 25, 2024
732c605
Fix ThreadsOut and ThreadOut pydantic class
StreetLamb May 25, 2024
dadaf6a
Generated frontend ThreadsService code
StreetLamb May 25, 2024
1c989ad
Fix thread migration, model and route
StreetLamb May 26, 2024
b9f12d5
Update stream route to check if user has permission to run thread
StreetLamb May 26, 2024
c191bc3
Fix bug by removing useless await
StreetLamb May 26, 2024
2c63251
Update TheadsService request and response
StreetLamb May 26, 2024
0641e72
Fix metadata type in PostgresSaver setup method
StreetLamb May 26, 2024
527152a
Add checkpoint model and migration. Update read_thread to return last…
StreetLamb May 26, 2024
d784e27
Update ThreadsService types
StreetLamb May 26, 2024
e94c8bf
Add util functions to convert checkpoint data to messages
StreetLamb May 26, 2024
5d2c108
Create component to show team's threads
StreetLamb May 26, 2024
3d883d4
Update chat tab to create/update thread and store threadId in query p…
StreetLamb May 26, 2024
c8816e0
Fix read_thread and delete_thread queries to only return associated c…
StreetLamb May 27, 2024
79c328d
Delete search params and remove messages if readThread fails
StreetLamb May 27, 2024
071174e
Fix bug in read_thread and delete_thread
StreetLamb May 27, 2024
20d0870
Move toggling isStreaming to onMutate and onSettled
StreetLamb May 27, 2024
e0ae149
deleteThread should invalidate read_thread
StreetLamb May 27, 2024
1ebaef6
Remove unused aiosqlite dependency
StreetLamb May 27, 2024
515c623
Add stubs for asyncpg and psycopg2
StreetLamb May 27, 2024
9c3351c
Fix mypy errors
StreetLamb May 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions backend/app/alembic/versions/3a8a5f819c5f_add_thread_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Add thread table

Revision ID: 3a8a5f819c5f
Revises: 3b4636df4c6d
Create Date: 2024-05-25 13:57:57.773217

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '3a8a5f819c5f'
down_revision = '3b4636df4c6d'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('thread',
sa.Column('query', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('team_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['team_id'], ['team.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_thread_id'), 'thread', ['id'], unique=False)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_thread_id'), table_name='thread')
op.drop_table('thread')
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Add interrupt column in members table

Revision ID: 3b4636df4c6d
Revises: b922e3e826e9
Create Date: 2024-05-25 01:04:31.805491

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '3b4636df4c6d'
down_revision = 'b922e3e826e9'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
# Add the new column without the NOT NULL constraint
op.add_column('member', sa.Column('interrupt', sa.Boolean(), nullable=True))
# Set the default value for existing rows
op.execute('UPDATE member SET interrupt = FALSE WHERE interrupt IS NULL')
# Alter the column to add the NOT NULL constraint
op.alter_column('member', 'interrupt', nullable=False, server_default=sa.false())
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('member', 'interrupt')
# ### end Alembic commands ###
38 changes: 38 additions & 0 deletions backend/app/alembic/versions/6fa42be09dd2_add_checkpoints_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Add checkpoints table

Revision ID: 6fa42be09dd2
Revises: 3a8a5f819c5f
Create Date: 2024-05-26 02:42:53.431421

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '6fa42be09dd2'
down_revision = '3a8a5f819c5f'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('checkpoints',
sa.Column('thread_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('thread_ts', sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column('parent_ts', sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column('checkpoint', sa.LargeBinary(), nullable=False),
sa.Column('metadata', sa.LargeBinary(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['thread_id'], ['thread.id'], ),
sa.PrimaryKeyConstraint('thread_id', 'thread_ts')
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('checkpoints')
# ### end Alembic commands ###
5 changes: 4 additions & 1 deletion backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from app.api.routes import login, members, skills, teams, users, utils
from app.api.routes import login, members, skills, teams, threads, users, utils

api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
Expand All @@ -11,3 +11,6 @@
members.router, prefix="/teams/{team_id}/members", tags=["members"]
)
api_router.include_router(skills.router, prefix="/skills", tags=["skills"])
api_router.include_router(
threads.router, prefix="/teams/{team_id}/threads", tags=["threads"]
)
20 changes: 17 additions & 3 deletions backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TeamOut,
TeamsOut,
TeamUpdate,
Thread,
)

# TODO: To remove
Expand Down Expand Up @@ -211,9 +212,13 @@ def delete_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any:
return Message(message="Team deleted successfully")


@router.post("/{id}/stream")
@router.post("/{id}/stream/{thread_id}")
async def stream(
session: SessionDep, current_user: CurrentUser, id: int, team_chat: TeamChat
session: SessionDep,
current_user: CurrentUser,
id: int,
thread_id: str,
team_chat: TeamChat,
) -> StreamingResponse:
"""
Stream a response to a user's input.
Expand All @@ -225,12 +230,21 @@ async def stream(
if not current_user.is_superuser and (team.owner_id != current_user.id):
raise HTTPException(status_code=400, detail="Not enough permissions")

# Check if thread belongs to the team
thread = session.get(Thread, thread_id)
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
if thread.team_id != id:
raise HTTPException(
status_code=400, detail="Thread does not belong to the team"
)

# Populate the skills for each member
members = team.members
for member in members:
member.skills = member.skills

return StreamingResponse(
generator(team, members, team_chat.messages),
generator(team, members, team_chat.messages, thread_id),
media_type="text/event-stream",
)
209 changes: 209 additions & 0 deletions backend/app/api/routes/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from typing import Any
from uuid import UUID

from fastapi import APIRouter, HTTPException
from sqlmodel import col, func, select

from app.api.deps import CurrentUser, SessionDep
from app.models import (
Checkpoint,
CreateThreadOut,
Message,
Team,
Thread,
ThreadCreate,
ThreadOut,
ThreadsOut,
ThreadUpdate,
)

router = APIRouter()


@router.get("/", response_model=ThreadsOut)
def read_threads(
session: SessionDep,
current_user: CurrentUser,
team_id: int,
skip: int = 0,
limit: int = 100,
) -> Any:
"""
Retrieve threads
"""

if current_user.is_superuser:
count_statement = select(func.count()).select_from(Thread)
count = session.exec(count_statement).one()
statement = (
select(Thread).where(Thread.team_id == team_id).offset(skip).limit(limit)
)
threads = session.exec(statement).all()
else:
count_statement = (
select(func.count())
.select_from(Thread)
.join(Team)
.where(Team.owner_id == current_user.id, Thread.team_id == team_id)
)
count = session.exec(count_statement).one()
statement = (
select(Thread)
.join(Team)
.where(Team.owner_id == current_user.id, Thread.team_id == team_id)
.offset(skip)
.limit(limit)
)
threads = session.exec(statement).all()
return ThreadsOut(data=threads, count=count)


@router.get("/{id}", response_model=CreateThreadOut)
def read_thread(
session: SessionDep, current_user: CurrentUser, team_id: int, id: UUID
) -> Any:
"""
Get thread and its last checkpoint by ID
"""
if current_user.is_superuser:
statement = (
select(Thread)
.join(Team)
.where(
Thread.id == id,
Thread.team_id == team_id,
)
)
thread = session.exec(statement).first()
else:
statement = (
select(Thread)
.join(Team)
.where(
Thread.id == id,
Thread.team_id == team_id,
Team.owner_id == current_user.id,
)
)
thread = session.exec(statement).first()

if not thread:
raise HTTPException(status_code=404, detail="Thread not found")

checkpoint_statement = (
select(Checkpoint)
.where(Checkpoint.thread_id == thread.id)
.order_by(col(Checkpoint.created_at).desc())
)
checkpoint = session.exec(checkpoint_statement).first()

return CreateThreadOut(
id=thread.id,
query=thread.query,
last_checkpoint=checkpoint,
updated_at=thread.updated_at,
)


@router.post("/", response_model=ThreadOut)
def create_thread(
*,
session: SessionDep,
current_user: CurrentUser,
team_id: int,
thread_in: ThreadCreate,
) -> Any:
"""
Create new thread
"""
if not current_user.is_superuser:
team = session.get(Team, team_id)
if not team:
raise HTTPException(status_code=404, detail="Team not found.")
if team.owner_id != current_user.id:
raise HTTPException(status_code=400, detail="Not enough permissions")
thread = Thread.model_validate(thread_in, update={"team_id": team_id})
session.add(thread)
session.commit()
session.refresh(thread)
return thread


@router.put("/{id}", response_model=ThreadOut)
def update_thread(
*,
session: SessionDep,
current_user: CurrentUser,
team_id: int,
id: UUID,
thread_in: ThreadUpdate,
) -> Any:
"""
Update a thread.
"""
if current_user.is_superuser:
statement = (
select(Thread).join(Team).where(Thread.id == id, Thread.team_id == team_id)
)
thread = session.exec(statement).first()
else:
statement = (
select(Thread)
.join(Team)
.where(
Thread.id == id,
Thread.team_id == team_id,
Team.owner_id == current_user.id,
)
)
thread = session.exec(statement).first()

if not thread:
raise HTTPException(status_code=404, detail="Member not found")

update_dict = thread_in.model_dump(exclude_unset=True)
thread.sqlmodel_update(update_dict)
session.add(thread)
session.commit()
session.refresh(thread)
return thread


@router.delete("/{id}")
def delete_thread(
session: SessionDep, current_user: CurrentUser, team_id: int, id: UUID
) -> Any:
"""
Delete a thread.
"""
if current_user.is_superuser:
statement = (
select(Thread)
.join(Team)
.where(
Thread.id == id,
Thread.team_id == team_id,
)
)
thread = session.exec(statement).first()
else:
statement = (
select(Thread)
.join(Team)
.where(
Thread.id == id,
Thread.team_id == team_id,
Team.owner_id == current_user.id,
)
)
thread = session.exec(statement).first()

if not thread:
raise HTTPException(status_code=404, detail="Thread not found")

for checkpoint in thread.checkpoints:
session.delete(checkpoint)

session.delete(thread)
session.commit()
return Message(message="Thread deleted successfully")
5 changes: 5 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
path=self.POSTGRES_DB,
)

@computed_field # type: ignore[misc]
@property
def PG_DATABASE_URI(self) -> str:
return f"postgres://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}@{self.POSTGRES_SERVER}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"

SMTP_TLS: bool = True
SMTP_SSL: bool = False
SMTP_PORT: int = 587
Expand Down
Loading