feat: Add rate limiting for computing kernel creation #2965

Open · wants to merge 11 commits into main
1 change: 1 addition & 0 deletions changes/2965.feature.md
@@ -0,0 +1 @@
Add rate limiting for compute kernel creation. The number of kernels that can be created within a single scheduler tick is now capped by a config value (`max-num-kernel-to-create`, with the alias `creation-rate-limit`).
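
The core of the change is a bounded selection loop in the dispatcher: SCHEDULED sessions are consumed one by one, and the loop stops once the accumulated kernel count would exceed the cap, reporting how many sessions were deferred. Below is a minimal, self-contained sketch of that logic; the helper name is hypothetical and the ORM/transaction handling of the real `dispatcher.py` code is stripped out:

```python
from typing import Optional


def take_sessions_within_cap(
    scheduled: list[list[str]],   # one inner list of kernel IDs per session
    max_kernels: Optional[int],   # None means "no limit"
) -> tuple[list[list[str]], int]:
    """Pick the sessions to prepare this tick; return them plus the deferred count."""
    taken: list[list[str]] = []
    kernel_cnt = 0
    while scheduled:
        session = scheduled.pop(0)
        kernel_cnt += len(session)
        if max_kernels is not None and kernel_cnt > max_kernels:
            # This session would exceed the cap: defer it and everything
            # after it to the next tick (+1 counts `session` itself).
            return taken, len(scheduled) + 1
        taken.append(session)
    return taken, 0


# Three sessions with 2, 3, and 1 kernels and a cap of 4:
# only the first session fits, so two sessions are deferred.
assert take_sessions_within_cap([["k1", "k2"], ["k3", "k4", "k5"], ["k6"]], 4) == (
    [["k1", "k2"]],
    2,
)
```

Note that a session is deferred as a whole: a multi-kernel session never gets split across ticks, which matches the all-or-nothing handling in the actual diff.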
5 changes: 5 additions & 0 deletions src/ai/backend/manager/config.py
@@ -351,6 +351,7 @@
"hang-tolerance": {
"threshold": {},
},
"max-num-kernel-to-create": None,
},
}

@@ -436,6 +437,10 @@
t.Key(
"hang-tolerance", default=_config_defaults["session"]["hang-tolerance"]
): session_hang_tolerance_iv,
tx.AliasedKey(
["max-num-kernel-to-create", "creation-rate-limit"],
default=_config_defaults["session"]["max-num-kernel-to-create"],
): t.Null | t.ToInt,
},
).allow_extra("*"),
}).allow_extra("*")
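For reference, `t.Null | t.ToInt` builds an `Or` trafaret, so the key accepts `None` (meaning no limit) as well as anything coercible to an integer, and `tx.AliasedKey` additionally lets the same value be supplied under the `creation-rate-limit` name. A minimal sketch of the validator's behavior using the `trafaret` library directly:

```python
import trafaret as t

limit_iv = t.Null | t.ToInt  # the same validator used in the schema above

assert limit_iv.check(None) is None  # unset: kernel creation is not rate-limited
assert limit_iv.check("50") == 50    # string values (e.g. from etcd) are coerced
assert limit_iv.check(50) == 50
```

The dispatcher reads the validated value from `self.shared_config.data["session"]["max-num-kernel-to-create"]` below; the exact etcd storage path for setting it is not shown in this diff.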
109 changes: 54 additions & 55 deletions src/ai/backend/manager/scheduler/dispatcher.py
@@ -20,6 +20,7 @@
Any,
Optional,
Union,
cast,
)

import aiotools
@@ -1195,68 +1196,64 @@ def _pipeline(r: Redis) -> RedisPipeline:
self.registry,
known_slot_types,
)
max_num_kernels_to_create = cast(
Optional[int], self.shared_config.data["session"]["max-num-kernel-to-create"]
)
num_remaining_scheduled = 0

try:
async with self.lock_factory(LockID.LOCKID_PREPARE, 600):
now = datetime.now(tzutc())

async def _mark_session_preparing() -> Sequence[SessionRow]:
async def _mark_session_preparing() -> tuple[list[SessionRow], int]:
async with self.db.begin_session() as db_sess:
update_query = (
sa.update(KernelRow)
.values(
status=KernelStatus.PREPARING,
status_changed=now,
status_info="",
status_data={},
status_history=sql_json_merge(
KernelRow.status_history,
(),
{
KernelStatus.PREPARING.name: now.isoformat(),
},
),
)
.where(
(KernelRow.status == KernelStatus.SCHEDULED),
)
)
await db_sess.execute(update_query)
update_sess_query = (
sa.update(SessionRow)
.values(
status=SessionStatus.PREPARING,
# status_changed=now,
status_info="",
status_data={},
status_history=sql_json_merge(
SessionRow.status_history,
(),
{
SessionStatus.PREPARING.name: now.isoformat(),
},
),
)
.where(SessionRow.status == SessionStatus.SCHEDULED)
.returning(SessionRow.id)
)
rows = (await db_sess.execute(update_sess_query)).fetchall()
if len(rows) == 0:
return []
target_session_ids = [r["id"] for r in rows]
select_query = (
ret: list[SessionRow] = []
kernel_cnt = 0
session_query = (
sa.select(SessionRow)
.where(SessionRow.id.in_(target_session_ids))
.options(
noload("*"),
selectinload(SessionRow.kernels).noload("*"),
)
.where(SessionRow.status == SessionStatus.SCHEDULED)
.options(selectinload(SessionRow.kernels))
)
result = await db_sess.execute(select_query)
return result.scalars().all()

scheduled_sessions: Sequence[SessionRow]
scheduled_sessions = await execute_with_retry(_mark_session_preparing)
log.debug("prepare(): preparing {} session(s)", len(scheduled_sessions))
session_rows = (await db_sess.scalars(session_query)).all()
session_rows = cast(list[SessionRow], session_rows)
while session_rows:
row = session_rows.pop(0)
kernel_cnt += len(row.kernels)
if (
max_num_kernels_to_create is not None
and kernel_cnt > max_num_kernels_to_create
):
# prepare only a bounded number of kernels in one tick
# to avoid awaiting too many `create_kernels` RPC tasks at once.
# TODO: fire-and-forget the kernel creation tasks
return ret, len(session_rows) + 1
row.status = SessionStatus.PREPARING
row.status_info = ""
row.status_data = {}
row.status_history = {
**row.status_history,
SessionStatus.PREPARING.name: now.isoformat(),
}
for kern in row.kernels:
kern.status = KernelStatus.PREPARING
kern.status_changed = now
kern.status_info = ""
kern.status_data = {}
kern.status_history = {
**kern.status_history,
KernelStatus.PREPARING.name: now.isoformat(),
}
ret.append(row)
return ret, len(session_rows)

scheduled_sessions, num_remaining_scheduled = await execute_with_retry(
_mark_session_preparing
)
log.debug(
"prepare(): preparing {} session(s), {} session(s) remain scheduled",
len(scheduled_sessions),
num_remaining_scheduled,
)
async with (
async_timeout.timeout(delay=50.0),
aiotools.PersistentTaskGroup() as tg,
@@ -1281,6 +1278,8 @@ async def _mark_session_preparing() -> Sequence[SessionRow]:
redis_key, "resource_group", scheduled_session.scaling_group_name
),
)
if num_remaining_scheduled > 0:
await self.event_producer.produce_event(DoPrepareEvent())
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
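
The `DoPrepareEvent` re-emission is what makes the cap behave as a rate limit rather than a drop: sessions left in SCHEDULED are picked up by an immediately re-triggered prepare cycle instead of waiting for the periodic timer. Below is a condensed, hypothetical sketch of the resulting control flow, with the collaborators injected as plain callables so it runs stand-alone; the real dispatcher uses `execute_with_retry`, `self.event_producer.produce_event(DoPrepareEvent())`, and per-session task groups as shown in the diff:

```python
import asyncio
from typing import Awaitable, Callable


async def prepare_tick(
    mark_sessions_preparing: Callable[[], Awaitable[tuple[list[str], int]]],
    start_session: Callable[[str], Awaitable[None]],
    produce_do_prepare_event: Callable[[], Awaitable[None]],
) -> None:
    # Transition at most the configured number of kernels' worth of sessions
    # from SCHEDULED to PREPARING; the rest stay SCHEDULED in the database.
    scheduled_sessions, num_remaining_scheduled = await mark_sessions_preparing()
    for session in scheduled_sessions:
        await start_session(session)
    if num_remaining_scheduled > 0:
        # Re-arm the prepare cycle right away so deferred sessions are
        # handled on the next tick instead of piling up.
        await produce_do_prepare_event()


async def _demo() -> None:
    async def fake_mark() -> tuple[list[str], int]:
        return (["sess-1", "sess-2"], 3)  # three sessions deferred by the cap

    async def fake_start(session_id: str) -> None:
        print("starting", session_id)

    async def fake_produce() -> None:
        print("DoPrepareEvent re-emitted for the deferred sessions")

    await prepare_tick(fake_mark, fake_start, fake_produce)


asyncio.run(_demo())
```

Because each re-triggered cycle marks sessions inside the same `LOCKID_PREPARE` distributed lock, concurrent prepare cycles cannot double-process the deferred sessions.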