Skip to content

Commit

Permalink
feat: Add Domain GQL schema and resolver (#2934)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Oct 23, 2024
1 parent e3e1d33 commit 78cd049
Show file tree
Hide file tree
Showing 10 changed files with 1,101 additions and 48 deletions.
1 change: 1 addition & 0 deletions changes/2934.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add GQL Relay domain query schema and resolver
184 changes: 162 additions & 22 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ type Queries {
agents(scaling_group: String, status: String): [Agent]
agent_summary(agent_id: String!): AgentSummary
agent_summary_list(limit: Int!, offset: Int!, filter: String, order: String, scaling_group: String, status: String): AgentSummaryList

"""Added in 24.12.0."""
domain_node(id: GlobalIDField!, permission: DomainPermissionValueField = "read_attribute"): DomainNode

"""Added in 24.12.0."""
domain_nodes(filter: String, order: String, permission: DomainPermissionValueField = "read_attribute", offset: Int, before: String, after: String, first: Int, last: Int): DomainConnection
domain(name: String): Domain
domains(is_active: Boolean): [Domain]

Expand Down Expand Up @@ -360,6 +366,115 @@ type AgentSummaryList implements PaginatedList {
total_count: Int!
}

"""Added in 24.12.0."""
type DomainNode implements Node {
"""The ID of the object"""
id: ID!
name: String
description: String
is_active: Boolean
created_at: DateTime
modified_at: DateTime
total_resource_slots: JSONString
allowed_vfolder_hosts: JSONString
allowed_docker_registries: [String]
dotfiles: Bytes
integration_id: String
scaling_groups(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ScalinGroupConnection
}

"""Added in 24.09.1."""
scalar Bytes

"""Added in 24.12.0."""
type ScalinGroupConnection {
"""Pagination data for this connection."""
pageInfo: PageInfo!

"""Contains the nodes in this connection."""
edges: [ScalinGroupEdge]!

"""Total count of the GQL nodes of the query."""
count: Int
}

"""
The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.
"""
type PageInfo {
"""When paginating forwards, are there more items?"""
hasNextPage: Boolean!

"""When paginating backwards, are there more items?"""
hasPreviousPage: Boolean!

"""When paginating backwards, the cursor to continue."""
startCursor: String

"""When paginating forwards, the cursor to continue."""
endCursor: String
}

"""
Added in 24.12.0. A Relay edge containing a `ScalinGroup` and its cursor.
"""
type ScalinGroupEdge {
"""The item at the end of the edge"""
node: ScalingGroupNode

"""A cursor for use in pagination"""
cursor: String!
}

"""Added in 24.12.0."""
type ScalingGroupNode implements Node {
"""The ID of the object"""
id: ID!
name: String
description: String
is_active: Boolean
is_public: Boolean
created_at: DateTime
wsproxy_addr: String
wsproxy_api_token: String
driver: String
driver_opts: JSONString
scheduler: String
scheduler_opts: JSONString
use_host_network: Boolean
}

"""
Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of "<node type name>:<node id>". UUID or string type values are also allowed.
"""
scalar GlobalIDField

"""
Added in 24.12.0. One of ['read_attribute', 'read_sensitive_attribute', 'update_attribute', 'create_user', 'create_project'].
"""
scalar DomainPermissionValueField

"""Added in 24.12.0"""
type DomainConnection {
"""Pagination data for this connection."""
pageInfo: PageInfo!

"""Contains the nodes in this connection."""
edges: [DomainEdge]!

"""Total count of the GQL nodes of the query."""
count: Int
}

"""Added in 24.12.0 A Relay edge containing a `Domain` and its cursor."""
type DomainEdge {
"""The item at the end of the edge"""
node: DomainNode

"""A cursor for use in pagination"""
cursor: String!
}

type Domain {
name: String
description: String
Expand Down Expand Up @@ -411,23 +526,6 @@ type UserConnection {
count: Int
}

"""
The Relay compliant `PageInfo` type, containing data necessary to paginate this connection.
"""
type PageInfo {
"""When paginating forwards, are there more items?"""
hasNextPage: Boolean!

"""When paginating backwards, are there more items?"""
hasPreviousPage: Boolean!

"""When paginating backwards, the cursor to continue."""
startCursor: String

"""When paginating forwards, the cursor to continue."""
endCursor: String
}

"""Added in 24.03.0 A Relay edge containing a `User` and its cursor."""
type UserEdge {
"""The item at the end of the edge"""
Expand Down Expand Up @@ -1041,11 +1139,6 @@ type ComputeSessionEdge {
cursor: String!
}

"""
Added in 24.09.0. Global ID of GQL relay spec. Base64 encoded version of "<node type name>:<node id>". UUID or string type values are also allowed.
"""
scalar GlobalIDField

type ComputeSessionList implements PaginatedList {
items: [ComputeSession]!
total_count: Int!
Expand Down Expand Up @@ -1404,6 +1497,12 @@ type Mutations {
To purge domain, there should be no users and groups in the target domain.
"""
purge_domain(name: String!): PurgeDomain

"""Added in 24.12.0."""
create_domain_node(input: CreateDomainNodeInput!): CreateDomainNode

"""Added in 24.12.0."""
modify_domain_node(input: ModifyDomainNodeInput!): ModifyDomainNode
create_group(name: String!, props: GroupInput!): CreateGroup
modify_group(gid: UUID!, props: ModifyGroupInput!): ModifyGroup

Expand Down Expand Up @@ -1637,6 +1736,47 @@ type PurgeDomain {
msg: String
}

"""Added in 24.12.0."""
type CreateDomainNode {
ok: Boolean
msg: String
item: DomainNode
}

"""Added in 24.12.0."""
input CreateDomainNodeInput {
name: String!
description: String
is_active: Boolean = true
total_resource_slots: JSONString = "{}"
allowed_vfolder_hosts: JSONString = "{}"
allowed_docker_registries: [String] = []
integration_id: String = null
dotfiles: Bytes = "90"
scaling_groups: [String]
}

"""Added in 24.12.0."""
type ModifyDomainNode {
item: DomainNode
client_mutation_id: String
}

"""Added in 24.12.0."""
input ModifyDomainNodeInput {
id: GlobalIDField!
description: String
is_active: Boolean
total_resource_slots: JSONString
allowed_vfolder_hosts: JSONString
allowed_docker_registries: [String]
integration_id: String
dotfiles: Bytes
sgroups_to_add: [String]
sgroups_to_remove: [String]
client_mutation_id: String
}

type CreateGroup {
ok: Boolean
msg: String
Expand Down
23 changes: 23 additions & 0 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,29 @@ async def batch_multiresult_in_session(
return [*objs_per_key.values()]


async def batch_multiresult_in_scalar_stream(
graph_ctx: GraphQueryContext,
db_sess: SASession,
query: sa.sql.Select,
obj_type: type[T_SQLBasedGQLObject],
key_list: Iterable[T_Key],
key_getter: Callable[[Row], T_Key],
) -> Sequence[Sequence[T_SQLBasedGQLObject]]:
"""
A batched query adaptor for (key -> [item]) resolving patterns.
stream the result in async session.
"""
objs_per_key: dict[T_Key, list[T_SQLBasedGQLObject]]
objs_per_key = dict()
for key in key_list:
objs_per_key[key] = list()
async for row in await db_sess.stream_scalars(query):
objs_per_key[key_getter(row)].append(
obj_type.from_row(graph_ctx, row),
)
return [*objs_per_key.values()]


def privileged_query(required_role: UserRole):
def wrap(func):
@functools.wraps(func)
Expand Down
51 changes: 33 additions & 18 deletions src/ai/backend/manager/models/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from sqlalchemy.orm import load_only, relationship

from ai.backend.common import msgpack
from ai.backend.common.types import ResourceSlot
from ai.backend.common.types import ResourceSlot, VFolderHostPermissionMap
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.group import ProjectType

Expand Down Expand Up @@ -135,11 +135,12 @@ class DomainModel(RBACModel[DomainPermission]):
modified_at: datetime

_total_resource_slots: Optional[dict]
_allowed_vfolder_hosts: dict
_allowed_vfolder_hosts: VFolderHostPermissionMap
_allowed_docker_registries: list[str]
_integration_id: Optional[str]
_dotfiles: str

orm_obj: DomainRow
_permissions: frozenset[DomainPermission] = field(default_factory=frozenset)

@property
Expand All @@ -153,7 +154,7 @@ def total_resource_slots(self) -> Optional[dict]:

@property
@required_permission(DomainPermission.READ_SENSITIVE_ATTRIBUTE)
def allowed_vfolder_hosts(self) -> dict:
def allowed_vfolder_hosts(self) -> VFolderHostPermissionMap:
return self._allowed_vfolder_hosts

@property
Expand Down Expand Up @@ -185,6 +186,7 @@ def from_row(cls, row: DomainRow, permissions: Iterable[DomainPermission]) -> Se
_integration_id=row.integration_id,
_dotfiles=row.dotfiles,
_permissions=frozenset(permissions),
orm_obj=row,
)


Expand Down Expand Up @@ -658,24 +660,37 @@ async def _permission_for_member(
return MEMBER_PERMISSIONS


async def get_permission_ctx(
target_scope: ScopeType,
requested_permission: DomainPermission,
*,
ctx: ClientContext,
db_session: SASession,
) -> DomainPermissionContext:
builder = DomainPermissionContextBuilder(db_session)
permission_ctx = await builder.build(ctx, target_scope, requested_permission)
return permission_ctx


async def get_domains(
target_scope: ScopeType,
requested_permission: DomainPermission,
domain_name: Optional[str] = None,
domain_names: Optional[Iterable[str]] = None,
*,
ctx: ClientContext,
db_conn: SAConnection,
db_session: SASession,
) -> list[DomainModel]:
async with ctx.db.begin_readonly_session(db_conn) as db_session:
builder = DomainPermissionContextBuilder(db_session)
permission_ctx = await builder.build(ctx, target_scope, requested_permission)
query_stmt = await permission_ctx.build_query()
if query_stmt is None:
return []
if domain_name is not None:
query_stmt = query_stmt.where(DomainRow.name == domain_name)
result: list[DomainModel] = []
async for row in await db_session.stream_scalars(query_stmt):
permissions = await permission_ctx.calculate_final_permission(row)
result.append(DomainModel.from_row(row, permissions))
return result
ret: list[DomainModel] = []
permission_ctx = await get_permission_ctx(
target_scope, requested_permission, ctx=ctx, db_session=db_session
)
cond = permission_ctx.query_condition
if cond is None:
return ret
query_stmt = sa.select(DomainRow).where(cond)
if domain_names is not None:
query_stmt = query_stmt.where(DomainRow.name.in_(domain_names))
async for row in await db_session.stream_scalars(query_stmt):
permissions = await permission_ctx.calculate_final_permission(row)
ret.append(DomainModel.from_row(row, permissions))
return ret
Loading

0 comments on commit 78cd049

Please sign in to comment.