From d3bd2e8f9c14716536140f3bef7fb486cab47c80 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Thu, 24 Oct 2024 18:50:17 +0900 Subject: [PATCH] feat: Add session status checker GQL mutation (#2836) --- changes/2836.feature.md | 1 + src/ai/backend/manager/api/schema.graphql | 15 +++++ src/ai/backend/manager/api/session.py | 3 +- src/ai/backend/manager/models/gql.py | 3 + .../manager/models/gql_models/session.py | 63 ++++++++++++++++++- src/ai/backend/manager/models/session.py | 23 ++++--- 6 files changed, 99 insertions(+), 9 deletions(-) create mode 100644 changes/2836.feature.md diff --git a/changes/2836.feature.md b/changes/2836.feature.md new file mode 100644 index 0000000000..f5f23245f9 --- /dev/null +++ b/changes/2836.feature.md @@ -0,0 +1 @@ +Add session status checker GQL mutation. diff --git a/src/ai/backend/manager/api/schema.graphql b/src/ai/backend/manager/api/schema.graphql index 628b65c7b5..171b4ca785 100644 --- a/src/ai/backend/manager/api/schema.graphql +++ b/src/ai/backend/manager/api/schema.graphql @@ -1674,6 +1674,9 @@ type Mutations { modify_container_registry(hostname: String!, props: ModifyContainerRegistryInput!): ModifyContainerRegistry delete_container_registry(hostname: String!): DeleteContainerRegistry modify_endpoint(endpoint_id: UUID!, props: ModifyEndpointInput!): ModifyEndpoint + + """Added in 24.09.0.""" + check_and_transit_session_status(input: CheckAndTransitStatusInput!): CheckAndTransitStatus } type ModifyAgent { @@ -2473,4 +2476,16 @@ input ExtraMountInput { Added in 24.03.4. Set permission of this mount. Should be one of (ro,rw,wd). Default is null """ permission: String +} + +"""Added in 24.12.0""" +type CheckAndTransitStatus { + item: [ComputeSessionNode] + client_mutation_id: String +} + +"""Added in 24.12.0.""" +input CheckAndTransitStatusInput { + ids: [GlobalIDField]! + client_mutation_id: String } \ No newline at end of file diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index f37d98ee94..a4892f263f 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -1030,8 +1030,9 @@ async def check_and_transit_status( log.warning( f"You are not allowed to transit others's sessions status, skip (s:{sid})" ) + + now = datetime.now(tzutc()) if accessible_session_ids: - now = datetime.now(tzutc()) session_rows = await root_ctx.registry.session_lifecycle_mgr.transit_session_status( accessible_session_ids, now ) diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index c6298ca585..d8c87944aa 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -88,6 +88,7 @@ UntagImageFromRegistry, ) from .gql_models.session import ( + CheckAndTransitStatus, ComputeSessionConnection, ComputeSessionNode, ModifyComputeSession, @@ -329,6 +330,8 @@ class Mutations(graphene.ObjectType): modify_endpoint = ModifyEndpoint.Field() + check_and_transit_session_status = CheckAndTransitStatus.Field(description="Added in 24.09.0.") + class Queries(graphene.ObjectType): """ diff --git a/src/ai/backend/manager/models/gql_models/session.py b/src/ai/backend/manager/models/gql_models/session.py index ceeb3cf55b..9649a19434 100644 --- a/src/ai/backend/manager/models/gql_models/session.py +++ b/src/ai/backend/manager/models/gql_models/session.py @@ -2,7 +2,7 @@ import uuid from collections.abc import Iterable, Sequence -from datetime import datetime +from datetime import datetime, timezone from typing import ( TYPE_CHECKING, Any, @@ -582,3 +582,64 @@ async def _update(db_session: AsyncSession) -> Optional[SessionRow]: ComputeSessionNode.from_row(graph_ctx, session_row), input.get("client_mutation_id"), ) + + +class CheckAndTransitStatusInput(graphene.InputObjectType): + class Meta: + description = "Added in 24.12.0." + + ids = graphene.List(lambda: GlobalIDField, required=True) + client_mutation_id = graphene.String(required=False) # input for relay + + +class CheckAndTransitStatus(graphene.Mutation): + allowed_roles = (UserRole.USER, UserRole.ADMIN, UserRole.SUPERADMIN) + + class Meta: + description = "Added in 24.12.0" + + class Arguments: + input = CheckAndTransitStatusInput(required=True) + + # Output fields + item = graphene.List(lambda: ComputeSessionNode) + client_mutation_id = graphene.String() # Relay output + + @classmethod + async def mutate( + cls, + root, + info: graphene.ResolveInfo, + input: CheckAndTransitStatusInput, + ) -> CheckAndTransitStatus: + graph_ctx: GraphQueryContext = info.context + session_ids = [SessionId(sid) for _, sid in input.ids] + + user_role = cast(UserRole, graph_ctx.user["role"]) + user_id = cast(uuid.UUID, graph_ctx.user["uuid"]) + accessible_session_ids: list[SessionId] = [] + now = datetime.now(timezone.utc) + + async with graph_ctx.db.connect() as db_conn: + async with graph_ctx.db.begin_readonly_session(db_conn) as db_session: + for sid in session_ids: + session_row = await SessionRow.get_session_to_determine_status(db_session, sid) + if session_row.user_uuid == user_id or user_role in ( + UserRole.ADMIN, + UserRole.SUPERADMIN, + ): + accessible_session_ids.append(sid) + + if accessible_session_ids: + session_rows = ( + await graph_ctx.registry.session_lifecycle_mgr.transit_session_status( + accessible_session_ids, now, db_conn=db_conn + ) + ) + await graph_ctx.registry.session_lifecycle_mgr.deregister_status_updatable_session([ + row.id for row, is_transited in session_rows if is_transited + ]) + result = [ComputeSessionNode.from_row(graph_ctx, row) for row, _ in session_rows] + else: + result = [] + return CheckAndTransitStatus(result, input.get("client_mutation_id")) diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index 7b8def6d9e..7164a2664e 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -1321,19 +1321,28 @@ async def transit_session_status( self, session_ids: Iterable[SessionId], status_changed_at: datetime | None = None, + *, + db_conn: Optional[SAConnection] = None, ) -> list[tuple[SessionRow, bool]]: if not session_ids: return [] now = status_changed_at or datetime.now(tzutc()) - result: list[tuple[SessionRow, bool]] = [] - async with self.db.connect() as db_conn: + + async def _transit(_db_conn: SAConnection) -> list[tuple[SessionRow, bool]]: + result: list[tuple[SessionRow, bool]] = [] for sid in session_ids: - row, is_transited = await self._transit_session_status(db_conn, sid, now) + row, is_transited = await self._transit_session_status(_db_conn, sid, now) result.append((row, is_transited)) - for row, is_transited in result: - if is_transited: - await self._post_status_transition(row) - return result + for row, is_transited in result: + if is_transited: + await self._post_status_transition(row) + return result + + if db_conn is not None: + return await _transit(db_conn) + else: + async with self.db.connect() as db_conn: + return await _transit(db_conn) async def register_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None: if not session_ids: