Skip to content

Commit

Permalink
feat: Add session status checker GQL mutation (#2836)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Oct 24, 2024
1 parent 1bc1925 commit d3bd2e8
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 9 deletions.
1 change: 1 addition & 0 deletions changes/2836.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add session status checker GQL mutation.
15 changes: 15 additions & 0 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,9 @@ type Mutations {
modify_container_registry(hostname: String!, props: ModifyContainerRegistryInput!): ModifyContainerRegistry
delete_container_registry(hostname: String!): DeleteContainerRegistry
modify_endpoint(endpoint_id: UUID!, props: ModifyEndpointInput!): ModifyEndpoint

"""Added in 24.09.0."""
check_and_transit_session_status(input: CheckAndTransitStatusInput!): CheckAndTransitStatus
}

type ModifyAgent {
Expand Down Expand Up @@ -2473,4 +2476,16 @@ input ExtraMountInput {
Added in 24.03.4. Set permission of this mount. Should be one of (ro,rw,wd). Default is null
"""
permission: String
}

"""Added in 24.12.0"""
type CheckAndTransitStatus {
  # Payload of the `check_and_transit_session_status` mutation.
  # `item` is a nullable list of nullable ComputeSessionNode entries.
  item: [ComputeSessionNode]
  # Relay-style client mutation identifier; nullable.
  client_mutation_id: String
}

"""Added in 24.12.0."""
input CheckAndTransitStatusInput {
  # Input of the `check_and_transit_session_status` mutation.
  # `ids` is a required list (the list itself is non-null, its elements are nullable)
  # of global IDs.
  ids: [GlobalIDField]!
  # Relay-style client mutation identifier; optional.
  client_mutation_id: String
}
3 changes: 2 additions & 1 deletion src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,8 +1030,9 @@ async def check_and_transit_status(
log.warning(
f"You are not allowed to transit others's sessions status, skip (s:{sid})"
)

now = datetime.now(tzutc())
if accessible_session_ids:
now = datetime.now(tzutc())
session_rows = await root_ctx.registry.session_lifecycle_mgr.transit_session_status(
accessible_session_ids, now
)
Expand Down
3 changes: 3 additions & 0 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
UntagImageFromRegistry,
)
from .gql_models.session import (
CheckAndTransitStatus,
ComputeSessionConnection,
ComputeSessionNode,
ModifyComputeSession,
Expand Down Expand Up @@ -329,6 +330,8 @@ class Mutations(graphene.ObjectType):

modify_endpoint = ModifyEndpoint.Field()

check_and_transit_session_status = CheckAndTransitStatus.Field(description="Added in 24.09.0.")


class Queries(graphene.ObjectType):
"""
Expand Down
63 changes: 62 additions & 1 deletion src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
from collections.abc import Iterable, Sequence
from datetime import datetime
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -582,3 +582,64 @@ async def _update(db_session: AsyncSession) -> Optional[SessionRow]:
ComputeSessionNode.from_row(graph_ctx, session_row),
input.get("client_mutation_id"),
)


class CheckAndTransitStatusInput(graphene.InputObjectType):
    # Input object consumed by the `CheckAndTransitStatus` mutation.
    # The schema description comes from `Meta.description`, so plain comments
    # are used here instead of a class docstring.

    class Meta:
        description = "Added in 24.12.0."

    # Required list of global IDs; each entry is unpacked as a
    # (type name, id) pair by the mutation resolver.
    ids = graphene.List(lambda: GlobalIDField, required=True)

    # Relay input field; optional (graphene fields are non-required by default).
    client_mutation_id = graphene.String()


class CheckAndTransitStatus(graphene.Mutation):
    # Mutation that checks the given sessions and transits their status.
    # The schema description comes from `Meta.description`, so plain comments
    # are used here instead of a class docstring.

    # Roles listed for this mutation.
    allowed_roles = (UserRole.USER, UserRole.ADMIN, UserRole.SUPERADMIN)

    class Meta:
        description = "Added in 24.12.0"

    class Arguments:
        input = CheckAndTransitStatusInput(required=True)

    # Output fields
    item = graphene.List(lambda: ComputeSessionNode)
    client_mutation_id = graphene.String()  # Relay output

    @classmethod
    async def mutate(
        cls,
        root,
        info: graphene.ResolveInfo,
        input: CheckAndTransitStatusInput,
    ) -> CheckAndTransitStatus:
        """Transit the status of every requested session the caller may touch.

        A session is considered accessible when it is owned by the requesting
        user, or when the requester's role is ADMIN or SUPERADMIN.  Sessions
        that are not accessible are silently left out.

        Returns a payload whose ``item`` holds one ``ComputeSessionNode`` per
        accessible session (an empty list when there is none) together with
        the ``client_mutation_id`` echoed back from the input.
        """
        ctx: GraphQueryContext = info.context
        # Each global ID is a (type name, id) pair; only the id part is used.
        requested_ids = [SessionId(raw_id) for _, raw_id in input.ids]

        requester_role = cast(UserRole, ctx.user["role"])
        requester_id = cast(uuid.UUID, ctx.user["uuid"])
        is_privileged = requester_role in (UserRole.ADMIN, UserRole.SUPERADMIN)
        permitted_ids: list[SessionId] = []
        # Timestamp is taken once, before any DB work, and shared by all transitions.
        now = datetime.now(timezone.utc)

        nodes = []
        async with ctx.db.connect() as db_conn:
            # Ownership checks only read, so they run in a read-only session.
            async with ctx.db.begin_readonly_session(db_conn) as db_session:
                for session_id in requested_ids:
                    candidate = await SessionRow.get_session_to_determine_status(
                        db_session, session_id
                    )
                    if candidate.user_uuid == requester_id or is_privileged:
                        permitted_ids.append(session_id)

            # NOTE(review): the original indentation was lost in SOURCE; this step is
            # placed after the read-only session but inside the connection scope
            # because it reuses `db_conn` -- confirm against the repository.
            if permitted_ids:
                lifecycle_mgr = ctx.registry.session_lifecycle_mgr
                transitions = await lifecycle_mgr.transit_session_status(
                    permitted_ids, now, db_conn=db_conn
                )
                # Only sessions whose status actually changed are deregistered.
                transited_ids = [row.id for row, is_transited in transitions if is_transited]
                await lifecycle_mgr.deregister_status_updatable_session(transited_ids)
                nodes = [ComputeSessionNode.from_row(ctx, row) for row, _ in transitions]
        return CheckAndTransitStatus(nodes, input.get("client_mutation_id"))
23 changes: 16 additions & 7 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,19 +1321,28 @@ async def transit_session_status(
self,
session_ids: Iterable[SessionId],
status_changed_at: datetime | None = None,
*,
db_conn: Optional[SAConnection] = None,
) -> list[tuple[SessionRow, bool]]:
if not session_ids:
return []
now = status_changed_at or datetime.now(tzutc())
result: list[tuple[SessionRow, bool]] = []
async with self.db.connect() as db_conn:

async def _transit(_db_conn: SAConnection) -> list[tuple[SessionRow, bool]]:
result: list[tuple[SessionRow, bool]] = []
for sid in session_ids:
row, is_transited = await self._transit_session_status(db_conn, sid, now)
row, is_transited = await self._transit_session_status(_db_conn, sid, now)
result.append((row, is_transited))
for row, is_transited in result:
if is_transited:
await self._post_status_transition(row)
return result
for row, is_transited in result:
if is_transited:
await self._post_status_transition(row)
return result

if db_conn is not None:
return await _transit(db_conn)
else:
async with self.db.connect() as db_conn:
return await _transit(db_conn)

async def register_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None:
if not session_ids:
Expand Down

0 comments on commit d3bd2e8

Please sign in to comment.