Skip to content

Commit

Permalink
fix(agents-api): Fix cozo queries and make tests pass
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <diwank@julep.ai>
  • Loading branch information
Diwank Tomer committed Jun 6, 2024
1 parent c39ca0a commit 3c241b2
Show file tree
Hide file tree
Showing 11 changed files with 351 additions and 11 deletions.
9 changes: 6 additions & 3 deletions agents-api/agents_api/models/execution/create_execution.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from typing import Literal, Dict, Any
from uuid import UUID

from beartype import beartype

from ..utils import cozo_query
from typing import Literal, Dict, Any


@cozo_query
@beartype
def create_execution_query(
developer_id: UUID,
agent_id: UUID,
task_id: UUID,
execution_id: UUID,
developer_id: UUID,
status: Literal[
"queued", "starting", "running", "waiting-for-input", "success", "failed"
"queued", "starting", "running", "waiting_for_input", "success", "failed"
] = "queued",
arguments: Dict[str, Any] = {},
) -> tuple[str, dict]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Literal, Dict, Any
from uuid import UUID

from beartype import beartype

from ..utils import cozo_query


@cozo_query
@beartype
def create_execution_transition_query(
developer_id: UUID,
execution_id: UUID,
transition_id: UUID,
type_: Literal["finished", "waiting", "error", "step"],
from_: tuple[str, int],
to: tuple[str, int] | None,
output: Dict[str, Any],
) -> tuple[str, dict]:
# TODO: Check for agent in developer ID; Assert whether dev can access agent and by relation the task
# TODO: Check for task and execution

query = """
{
?[execution_id, transition_id, type, from, to, output] <- [[
to_uuid($execution_id),
to_uuid($transition_id),
$type,
$from,
$to,
$output,
]]
:insert transitions {
execution_id, transition_id, type, from, to, output
}
}
"""
return (
query,
{
"execution_id": str(execution_id),
"transition_id": str(transition_id),
"type": type_,
"from": from_,
"to": to,
"output": output,
},
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from uuid import UUID

from beartype import beartype

from ..utils import cozo_query


@cozo_query
@beartype
def get_execution_transition_query(
execution_id: UUID, transition_id: UUID
) -> tuple[str, dict]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from uuid import UUID

from beartype import beartype

from ..utils import cozo_query


@cozo_query
def list_execution_transition_query(execution_id: UUID) -> tuple[str, dict]:
@beartype
def list_execution_transitions_query(execution_id: UUID) -> tuple[str, dict]:

query = """
{
Expand Down
257 changes: 257 additions & 0 deletions agents-api/agents_api/models/execution/test_execution_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Tests for execution queries
from uuid import uuid4

from cozo_migrate.api import init, apply
from pycozo import Client
from ward import test

from .create_execution import create_execution_query
from .get_execution_status import get_execution_status_query
from .get_execution import get_execution_query
from .list_executions import list_task_executions_query
from .update_execution_status import update_execution_status_query
from .create_execution_transition import create_execution_transition_query
from .get_execution_transition import get_execution_transition_query
from .list_execution_transitions import list_execution_transitions_query
from .update_execution_transition import update_execution_transition_query


def cozo_client(migrations_dir: str = "./migrations"):
# Create a new client for each test
# and initialize the schema.
client = Client()

init(client)
apply(client, migrations_dir=migrations_dir, all_=True)

return client


@test("model: create execution")
def _():
client = cozo_client()
developer_id = uuid4()
agent_id = uuid4()
task_id = uuid4()
execution_id = uuid4()

create_execution_query(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
execution_id=execution_id,
arguments={"input": "test"},
client=client,
)


@test("model: get execution")
def _():
client = cozo_client()
developer_id = uuid4()
agent_id = uuid4()
task_id = uuid4()
execution_id = uuid4()

create_execution_query(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
execution_id=execution_id,
arguments={"input": "test"},
client=client,
)

result = get_execution_query(
task_id=task_id, execution_id=execution_id, client=client
)

assert len(result["status"]) == 1
assert result["status"][0] == "queued"


@test("model: get execution status")
def _():
client = cozo_client()
developer_id = uuid4()
agent_id = uuid4()
task_id = uuid4()
execution_id = uuid4()

create_execution_query(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
execution_id=execution_id,
arguments={"input": "test"},
client=client,
)

result = get_execution_status_query(
task_id=task_id, execution_id=execution_id, client=client
)

assert len(result["status"]) == 1
assert result["status"][0] == "queued"


@test("model: list executions empty")
def _():
client = cozo_client()
developer_id = uuid4()
agent_id = uuid4()
task_id = uuid4()

result = list_task_executions_query(
task_id=task_id, agent_id=agent_id, developer_id=developer_id, client=client
)

assert len(result) == 0


@test("model: list executions")
def _():
client = cozo_client()
developer_id = uuid4()
agent_id = uuid4()
task_id = uuid4()
execution_id = uuid4()

create_execution_query(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
execution_id=execution_id,
arguments={"input": "test"},
client=client,
)

result = list_task_executions_query(
task_id=task_id, agent_id=agent_id, developer_id=developer_id, client=client
)

assert len(result["status"]) == 1
assert result["status"][0] == "queued"


@test("model: update execution status")
def _():
client = cozo_client()
developer_id = uuid4()
agent_id = uuid4()
task_id = uuid4()
execution_id = uuid4()

create_execution_query(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
execution_id=execution_id,
arguments={"input": "test"},
client=client,
)

result = update_execution_status_query(
task_id=task_id, execution_id=execution_id, status="running", client=client
)

updated_rows = result[result["_kind"] == "inserted"].reset_index()
assert len(updated_rows) == 1
assert updated_rows["status"][0] == "running"


@test("model: create execution transition")
def _():
client = cozo_client()
developer_id = uuid4()
execution_id = uuid4()
transition_id = uuid4()

create_execution_transition_query(
developer_id=developer_id,
execution_id=execution_id,
transition_id=transition_id,
type_="step",
from_=("test", 1),
to=("test", 2),
output={"input": "test"},
client=client,
)


@test("model: get execution transition")
def _():
client = cozo_client()
developer_id = uuid4()
execution_id = uuid4()
transition_id = uuid4()

create_execution_transition_query(
developer_id=developer_id,
execution_id=execution_id,
transition_id=transition_id,
type_="step",
from_=("test", 1),
to=("test", 2),
output={"input": "test"},
client=client,
)

result = get_execution_transition_query(
execution_id=execution_id, transition_id=transition_id, client=client
)

assert len(result["type"]) == 1


@test("model: list execution transitions")
def _():
client = cozo_client()
developer_id = uuid4()
execution_id = uuid4()
transition_id = uuid4()

create_execution_transition_query(
developer_id=developer_id,
execution_id=execution_id,
transition_id=transition_id,
type_="step",
from_=("test", 1),
to=("test", 2),
output={"input": "test"},
client=client,
)

result = list_execution_transitions_query(execution_id=execution_id, client=client)

assert len(result["type"]) == 1


@test("model: update execution transitions")
def _():
client = cozo_client()
developer_id = uuid4()
execution_id = uuid4()
transition_id = uuid4()

create_execution_transition_query(
developer_id=developer_id,
execution_id=execution_id,
transition_id=transition_id,
type_="step",
from_=("test", 1),
to=("test", 2),
output={"input": "test"},
client=client,
)

result = update_execution_transition_query(
execution_id=execution_id,
transition_id=transition_id,
type="finished",
client=client,
)

updated_rows = result[result["_kind"] == "inserted"].reset_index()
assert len(updated_rows) == 1
assert updated_rows["type"][0] == "finished"
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import Literal, Dict, Any
from uuid import UUID

from beartype import beartype

from ..utils import cozo_query
from typing import Literal, Dict, Any


@cozo_query
@beartype
def update_execution_status_query(
task_id: UUID,
execution_id: UUID,
status: Literal[
"queued", "starting", "running", "waiting-for-input", "success", "failed"
"queued", "starting", "running", "waiting_for_input", "success", "failed"
],
arguments: Dict[str, Any] = {},
) -> tuple[str, dict]:
Expand All @@ -28,6 +31,8 @@ def update_execution_status_query(
status,
updated_at
}
:returning
}
"""
Expand Down
Loading

0 comments on commit 3c241b2

Please sign in to comment.