-
Notifications
You must be signed in to change notification settings - Fork 44.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(platform): Add basic library functionality (#9043)
Add functionality to allow users to add agents to their library from the store page.
- Loading branch information
Showing
17 changed files
with
702 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
165 changes: 165 additions & 0 deletions
165
autogpt_platform/backend/backend/server/v2/library/db.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import logging | ||
from typing import List | ||
|
||
import prisma.errors | ||
import prisma.models | ||
import prisma.types | ||
|
||
import backend.data.graph | ||
import backend.data.includes | ||
import backend.server.v2.library.model | ||
import backend.server.v2.store.exceptions | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def get_library_agents( | ||
user_id: str, | ||
) -> List[backend.server.v2.library.model.LibraryAgent]: | ||
""" | ||
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table) | ||
""" | ||
logger.debug(f"Getting library agents for user {user_id}") | ||
|
||
try: | ||
# Get agents created by user with nodes and links | ||
user_created = await prisma.models.AgentGraph.prisma().find_many( | ||
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True), | ||
include=backend.data.includes.AGENT_GRAPH_INCLUDE, | ||
) | ||
|
||
# Get agents in user's library with nodes and links | ||
library_agents = await prisma.models.UserAgent.prisma().find_many( | ||
where=prisma.types.UserAgentWhereInput( | ||
userId=user_id, isDeleted=False, isArchived=False | ||
), | ||
include={ | ||
"Agent": { | ||
"include": { | ||
"AgentNodes": { | ||
"include": { | ||
"Input": True, | ||
"Output": True, | ||
"Webhook": True, | ||
"AgentBlock": True, | ||
} | ||
} | ||
} | ||
} | ||
}, | ||
) | ||
|
||
# Convert to Graph models first | ||
graphs = [] | ||
|
||
# Add user created agents | ||
for agent in user_created: | ||
try: | ||
graphs.append(backend.data.graph.GraphModel.from_db(agent)) | ||
except Exception as e: | ||
logger.error(f"Error processing user created agent {agent.id}: {e}") | ||
continue | ||
|
||
# Add library agents | ||
for agent in library_agents: | ||
if agent.Agent: | ||
try: | ||
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent)) | ||
except Exception as e: | ||
logger.error(f"Error processing library agent {agent.agentId}: {e}") | ||
continue | ||
|
||
# Convert Graph models to LibraryAgent models | ||
result = [] | ||
for graph in graphs: | ||
result.append( | ||
backend.server.v2.library.model.LibraryAgent( | ||
id=graph.id, | ||
version=graph.version, | ||
is_active=graph.is_active, | ||
name=graph.name, | ||
description=graph.description, | ||
isCreatedByUser=any(a.id == graph.id for a in user_created), | ||
input_schema=graph.input_schema, | ||
output_schema=graph.output_schema, | ||
) | ||
) | ||
|
||
logger.debug(f"Found {len(result)} library agents") | ||
return result | ||
|
||
except prisma.errors.PrismaError as e: | ||
logger.error(f"Database error getting library agents: {str(e)}") | ||
raise backend.server.v2.store.exceptions.DatabaseError( | ||
"Failed to fetch library agents" | ||
) from e | ||
|
||
|
||
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None: | ||
""" | ||
Finds the agent from the store listing version and adds it to the user's library (UserAgent table) | ||
if they don't already have it | ||
""" | ||
logger.debug( | ||
f"Adding agent from store listing version {store_listing_version_id} to library for user {user_id}" | ||
) | ||
|
||
try: | ||
# Get store listing version to find agent | ||
store_listing_version = ( | ||
await prisma.models.StoreListingVersion.prisma().find_unique( | ||
where={"id": store_listing_version_id}, include={"Agent": True} | ||
) | ||
) | ||
|
||
if not store_listing_version or not store_listing_version.Agent: | ||
logger.warning( | ||
f"Store listing version not found: {store_listing_version_id}" | ||
) | ||
raise backend.server.v2.store.exceptions.AgentNotFoundError( | ||
f"Store listing version {store_listing_version_id} not found" | ||
) | ||
|
||
agent = store_listing_version.Agent | ||
|
||
if agent.userId == user_id: | ||
logger.warning( | ||
f"User {user_id} cannot add their own agent to their library" | ||
) | ||
raise backend.server.v2.store.exceptions.DatabaseError( | ||
"Cannot add own agent to library" | ||
) | ||
|
||
# Check if user already has this agent | ||
existing_user_agent = await prisma.models.UserAgent.prisma().find_first( | ||
where={ | ||
"userId": user_id, | ||
"agentId": agent.id, | ||
"agentVersion": agent.version, | ||
} | ||
) | ||
|
||
if existing_user_agent: | ||
logger.debug( | ||
f"User {user_id} already has agent {agent.id} in their library" | ||
) | ||
return | ||
|
||
# Create UserAgent entry | ||
await prisma.models.UserAgent.prisma().create( | ||
data=prisma.types.UserAgentCreateInput( | ||
userId=user_id, | ||
agentId=agent.id, | ||
agentVersion=agent.version, | ||
isCreatedByUser=False, | ||
) | ||
) | ||
logger.debug(f"Added agent {agent.id} to library for user {user_id}") | ||
|
||
except backend.server.v2.store.exceptions.AgentNotFoundError: | ||
raise | ||
except prisma.errors.PrismaError as e: | ||
logger.error(f"Database error adding agent to library: {str(e)}") | ||
raise backend.server.v2.store.exceptions.DatabaseError( | ||
"Failed to add agent to library" | ||
) from e |
189 changes: 189 additions & 0 deletions
189
autogpt_platform/backend/backend/server/v2/library/db_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
from datetime import datetime | ||
|
||
import prisma.errors | ||
import prisma.models | ||
import pytest | ||
from prisma import Prisma | ||
|
||
import backend.data.includes | ||
import backend.server.v2.library.db as db | ||
import backend.server.v2.store.exceptions | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
async def setup_prisma(): | ||
# Don't register client if already registered | ||
try: | ||
Prisma() | ||
except prisma.errors.ClientAlreadyRegisteredError: | ||
pass | ||
yield | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_library_agents(mocker): | ||
# Mock data | ||
mock_user_created = [ | ||
prisma.models.AgentGraph( | ||
id="agent1", | ||
version=1, | ||
name="Test Agent 1", | ||
description="Test Description 1", | ||
userId="test-user", | ||
isActive=True, | ||
createdAt=datetime.now(), | ||
isTemplate=False, | ||
) | ||
] | ||
|
||
mock_library_agents = [ | ||
prisma.models.UserAgent( | ||
id="ua1", | ||
userId="test-user", | ||
agentId="agent2", | ||
agentVersion=1, | ||
isCreatedByUser=False, | ||
isDeleted=False, | ||
isArchived=False, | ||
createdAt=datetime.now(), | ||
updatedAt=datetime.now(), | ||
isFavorite=False, | ||
Agent=prisma.models.AgentGraph( | ||
id="agent2", | ||
version=1, | ||
name="Test Agent 2", | ||
description="Test Description 2", | ||
userId="other-user", | ||
isActive=True, | ||
createdAt=datetime.now(), | ||
isTemplate=False, | ||
), | ||
) | ||
] | ||
|
||
# Mock prisma calls | ||
mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma") | ||
mock_agent_graph.return_value.find_many = mocker.AsyncMock( | ||
return_value=mock_user_created | ||
) | ||
|
||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma") | ||
mock_user_agent.return_value.find_many = mocker.AsyncMock( | ||
return_value=mock_library_agents | ||
) | ||
|
||
# Call function | ||
result = await db.get_library_agents("test-user") | ||
|
||
# Verify results | ||
assert len(result) == 2 | ||
assert result[0].id == "agent1" | ||
assert result[0].name == "Test Agent 1" | ||
assert result[0].description == "Test Description 1" | ||
assert result[0].isCreatedByUser is True | ||
assert result[1].id == "agent2" | ||
assert result[1].name == "Test Agent 2" | ||
assert result[1].description == "Test Description 2" | ||
assert result[1].isCreatedByUser is False | ||
|
||
# Verify mocks called correctly | ||
mock_agent_graph.return_value.find_many.assert_called_once_with( | ||
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True), | ||
include=backend.data.includes.AGENT_GRAPH_INCLUDE, | ||
) | ||
mock_user_agent.return_value.find_many.assert_called_once_with( | ||
where=prisma.types.UserAgentWhereInput( | ||
userId="test-user", isDeleted=False, isArchived=False | ||
), | ||
include={ | ||
"Agent": { | ||
"include": { | ||
"AgentNodes": { | ||
"include": { | ||
"Input": True, | ||
"Output": True, | ||
"Webhook": True, | ||
"AgentBlock": True, | ||
} | ||
} | ||
} | ||
} | ||
}, | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_add_agent_to_library(mocker): | ||
# Mock data | ||
mock_store_listing = prisma.models.StoreListingVersion( | ||
id="version123", | ||
version=1, | ||
createdAt=datetime.now(), | ||
updatedAt=datetime.now(), | ||
agentId="agent1", | ||
agentVersion=1, | ||
slug="test-agent", | ||
name="Test Agent", | ||
subHeading="Test Agent Subheading", | ||
imageUrls=["https://example.com/image.jpg"], | ||
description="Test Description", | ||
categories=["test"], | ||
isFeatured=False, | ||
isDeleted=False, | ||
isAvailable=True, | ||
isApproved=True, | ||
Agent=prisma.models.AgentGraph( | ||
id="agent1", | ||
version=1, | ||
name="Test Agent", | ||
description="Test Description", | ||
userId="creator", | ||
isActive=True, | ||
createdAt=datetime.now(), | ||
isTemplate=False, | ||
), | ||
) | ||
|
||
# Mock prisma calls | ||
mock_store_listing_version = mocker.patch( | ||
"prisma.models.StoreListingVersion.prisma" | ||
) | ||
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock( | ||
return_value=mock_store_listing | ||
) | ||
|
||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma") | ||
mock_user_agent.return_value.create = mocker.AsyncMock() | ||
|
||
# Call function | ||
await db.add_agent_to_library("version123", "test-user") | ||
|
||
# Verify mocks called correctly | ||
mock_store_listing_version.return_value.find_unique.assert_called_once_with( | ||
where={"id": "version123"}, include={"Agent": True} | ||
) | ||
mock_user_agent.return_value.create.assert_called_once_with( | ||
data=prisma.types.UserAgentCreateInput( | ||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False | ||
) | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_add_agent_to_library_not_found(mocker): | ||
# Mock prisma calls | ||
mock_store_listing_version = mocker.patch( | ||
"prisma.models.StoreListingVersion.prisma" | ||
) | ||
mock_store_listing_version.return_value.find_unique = mocker.AsyncMock( | ||
return_value=None | ||
) | ||
|
||
# Call function and verify exception | ||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError): | ||
await db.add_agent_to_library("version123", "test-user") | ||
|
||
# Verify mock called correctly | ||
mock_store_listing_version.return_value.find_unique.assert_called_once_with( | ||
where={"id": "version123"}, include={"Agent": True} | ||
) |
16 changes: 16 additions & 0 deletions
16
autogpt_platform/backend/backend/server/v2/library/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import typing | ||
|
||
import pydantic | ||
|
||
|
||
class LibraryAgent(pydantic.BaseModel): | ||
id: str # Changed from agent_id to match GraphMeta | ||
version: int # Changed from agent_version to match GraphMeta | ||
is_active: bool # Added to match GraphMeta | ||
name: str | ||
description: str | ||
|
||
isCreatedByUser: bool | ||
# Made input_schema and output_schema match GraphMeta's type | ||
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend | ||
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend |
Oops, something went wrong.