Cleanup handle_worker()
in preparation for #2815 (Stop generation early)
#3573
Conversation
async def add_worker_connect_event(
    session: database.AsyncSession,
    worker_id: str,
    worker_info: inference.WorkerInfo,
):
    event = models.DbWorkerEvent(
        worker_id=worker_id,
        event_type=models.WorkerEventType.connect,
        worker_info=worker_info,
    )
    session.add(event)
    await session.commit()
Moved to the SessionManager without changes.

❌ pre-commit failed.
try:
    worker_utils.get_protocol_version(protocol_version)
    api_key = worker_utils.get_api_key(api_key)
    worker_id = await worker_utils.get_worker_id(api_key=api_key, protocol_version=protocol_version)
except fastapi.HTTPException as e:
    logger.warning(f"handle_worker: {e.status_code=} {e.detail=}")
    if e.status_code == fastapi.status.HTTP_426_UPGRADE_REQUIRED:
        await worker_utils.send_worker_request(websocket=websocket, request=inference.UpgradeProtocolRequest())
    elif e.status_code == fastapi.status.HTTP_401_UNAUTHORIZED:
        await worker_utils.send_worker_request(websocket=websocket, request=inference.WrongApiKeyRequest())
    try:
        await websocket.close(code=e.status_code, reason=e.detail)
    except Exception:
        pass
    raise fastapi.WebSocketException(e.status_code, e.detail)
logger.info(f"handle_worker: {worker_id=}")
worker_info = await worker_utils.receive_worker_info(websocket)
logger.info(f"handle_worker: {worker_info=}")
worker_config = worker_info.config
worker_compat_hash = worker_config.compat_hash
work_queue = queueing.work_queue(deps.redis_client, worker_compat_hash)
redis_client = deps.make_redis_client()
blocking_work_queue = queueing.work_queue(redis_client, worker_compat_hash)
worker_session = worker_utils.WorkerSession(
    worker_id=worker_id,
    worker_info=worker_info,
)
work_request_map: dict[str, WorkRequestContainer] = {}
Moved to HandleWorkerContext.build() without changes.
try:
    async with deps.manual_create_session() as session:
        await add_worker_connect_event(session=session, worker_id=worker_id, worker_info=worker_info)
    await worker_utils.store_worker_session(worker_session)
Moved to SessionManager.init()
async def _update_session(metrics: inference.WorkerMetricsInfo):
    worker_session.requests_in_flight = len(work_request_map)
    if metrics:
        worker_session.metrics = metrics
    await worker_utils.store_worker_session(worker_session)
Moved to SessionManager.update()
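As a rough illustration of what such a SessionManager.update() could look like (the class shape and `_store` helper here are hypothetical stand-ins, not the PR's actual code):

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class SessionManager:
    # Hypothetical minimal shape; the real class in the PR wraps
    # worker_utils.WorkerSession and persists it via store_worker_session().
    worker_session: Any
    work_request_map: dict = field(default_factory=dict)

    async def update(self, metrics: Optional[Any] = None) -> None:
        # Same logic as the old _update_session() closure: refresh the
        # in-flight request count, optionally record new metrics, persist.
        self.worker_session.requests_in_flight = len(self.work_request_map)
        if metrics:
            self.worker_session.metrics = metrics
        await self._store()

    async def _store(self) -> None:
        # Stand-in for worker_utils.store_worker_session(self.worker_session);
        # a no-op here so the sketch runs without the inference server.
        await asyncio.sleep(0)
```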
def _add_dequeue(ftrs: set):
    requests_in_progress = len(work_request_map)
    if requests_in_progress < worker_config.max_parallel_requests:
        ftrs.add(asyncio.ensure_future(blocking_work_queue.dequeue(timeout=0)))
Moved to FuturesManager.ensure_listening_for_work_requests(). The future is now wrapped such that it returns a tuple[FutureType, X], where X is the return type of whatever the future returned before.
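The wrapping described here can be sketched as follows (the enum members and `tag_result` name are illustrative assumptions; the PR defines the actual enum inside FuturesManager):

```python
import asyncio
import enum
from typing import Any, Awaitable


class FutureType(enum.Enum):
    # Illustrative member names; the real enum lives in the PR's FuturesManager.
    WORK_REQUEST = enum.auto()
    WORKER_RESPONSE = enum.auto()


async def tag_result(future_type: FutureType, aw: Awaitable[Any]) -> tuple[FutureType, Any]:
    # Wrap an awaitable so awaiting it yields (FutureType, original result),
    # letting the main loop dispatch on the tag instead of the result's type.
    return future_type, await aw


async def demo() -> tuple[FutureType, Any]:
    async def dequeue() -> str:
        # Stand-in for blocking_work_queue.dequeue(timeout=0).
        return "message-123"

    ftr = asyncio.ensure_future(tag_result(FutureType.WORK_REQUEST, dequeue()))
    return await ftr
```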
def _add_receive(ftrs: set):
    ftrs.add(asyncio.ensure_future(worker_utils.receive_worker_response(websocket=websocket)))
Moved to FuturesManager.ensure_listening_for_worker_response(), with similar changes to return the FutureType in addition to the response.
(done, pending_futures) = await asyncio.wait(
    pending_futures, timeout=settings.worker_ping_interval, return_when=asyncio.FIRST_COMPLETED
)
ftr: asyncio.Future
for ftr in done:
    result = ftr.result()
Moved to FuturesManager.wait_for_event_or_timeout()
if result is None:
    logger.error(f"handle_worker: {worker_id=} received None from queue. This should never happen.")
    raise RuntimeError("Received None from queue. This should never happen.")
elif isinstance(result, tuple):
This is the test that was previously used to identify a dequeued work request; it has now been replaced by a check for type_ == FutureType.WORK_REQUEST.
try:
    _, message_id = result
    work_request = await initiate_work_for_message(
        websocket=websocket,
        work_queue=work_queue,
        message_id=message_id,
        worker_id=worker_id,
        worker_config=worker_config,
    )
    work_request_map[work_request.id] = WorkRequestContainer(
        work_request=work_request, message_id=message_id
    )
except chat_schema.MessageCancelledException as e:
    logger.warning(f"Message was cancelled before work could be initiated: {e.message_id=}")
except chat_schema.MessageTimeoutException as e:
    logger.warning(f"Message timed out before work could be initiated: {e.message.id=}")
    await handle_timeout(message=e.message)
finally:
    _add_dequeue(pending_futures)
Moved to _handle_work_request()
try:
    worker_response: inference.WorkerResponse = result
    match worker_response.response_type:
        case "pong":
            worker_response = cast(inference.PongResponse, worker_response)
            await _update_session(worker_response.metrics)
        case "token":
            worker_response = cast(inference.TokenResponse, worker_response)
            await handle_token_response(
                work_request_map=work_request_map,
                response=worker_response,
            )
        case "generated_text":
            worker_response = cast(inference.GeneratedTextResponse, worker_response)
            await handle_generated_text_response(
                work_request_map=work_request_map,
                response=worker_response,
            )
            await _update_session(worker_response.metrics)
        case "error":
            worker_response = cast(inference.ErrorResponse, worker_response)
            await handle_error_response(
                work_request_map=work_request_map,
                response=worker_response,
            )
            await _update_session(worker_response.metrics)
        case "general_error":
            worker_response = cast(inference.GeneralErrorResponse, worker_response)
            await handle_general_error_response(
                response=worker_response,
            )
            await _update_session(worker_response.metrics)
        case "safe_prompt":
            logger.info("Received safe prompt response")
Moved to _handle_worker_response()
@@ -53,7 +53,7 @@ def get_protocol_version(protocol_version: str = protocol_version_header) -> str
 async def get_worker_id(
     api_key: str = Depends(get_api_key),
     protocol_version: str = Depends(get_protocol_version),
-) -> models.DbWorker:
+) -> str:
Noticed that the return type annotation was wrong.
if len(futures._futures) == 0:
    futures.ensure_listening_for_work_requests()
    futures.ensure_listening_to_worker_responses()
In order to avoid logical changes as much as possible, I kept this piece of code here as it was, but that now means it accesses the internal structure of the FuturesManager.

I'm considering whether it might be nicer not to throw all the futures into a single set. Instead, we could keep one future in a specific variable for listening to worker responses, and a set for the potentially multiple futures with which we listen for work requests. The FuturesManager could then have a single method that ensures there is at least one pending future of each type, and only that method would be called from the outside at the end of each loop.
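A rough sketch of that alternative layout (all names here are hypothetical; this is not code from the PR):

```python
import asyncio
from typing import Awaitable, Callable


class AltFuturesManager:
    # Sketch of the alternative floated above: keep the single worker-response
    # future in its own slot and the dequeue futures in a set, with one method
    # that tops both up at the end of each loop iteration.
    def __init__(self, max_parallel_requests: int):
        self.max_parallel_requests = max_parallel_requests
        self.response_future: asyncio.Future | None = None
        self.dequeue_futures: set[asyncio.Future] = set()

    def ensure_pending(
        self,
        requests_in_progress: int,
        dequeue: Callable[[], Awaitable],
        receive: Callable[[], Awaitable],
    ) -> None:
        # Keep exactly one future listening for worker responses...
        if self.response_future is None or self.response_future.done():
            self.response_future = asyncio.ensure_future(receive())
        # ...and add a dequeue future while below the parallelism limit,
        # mirroring the old _add_dequeue() check.
        if requests_in_progress < self.max_parallel_requests:
            self.dequeue_futures.add(asyncio.ensure_future(dequeue()))

    def all_pending(self) -> set[asyncio.Future]:
        response = {self.response_future} if self.response_future else set()
        return response | self.dequeue_futures
```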
Sorry for the delayed review - thanks, these changes make sense to me.

I saw you have a checklist in the original comment, so I have not merged yet, but if this is ready to go as-is, let me know and I'll merge.

@0xfacade Do you still want to complete work on the todo points (unticked checkboxes), or can it be merged as it is?

@olliestanley @andreaskoepf This can be merged! Will work on the extension for the stop generating feature now.
In this PR, I clean up the handle_worker() method a bit so that I can later extend it (in a future PR). There are no functional changes in this PR.

Changes:
- HandleWorkerContext, a class that also features methods for initialization and destruction
- SessionManager
- FuturesManager

The last change is the most important one for my future changes. In the main loop of handle_worker(), we were already waiting for two different types of futures: newly dequeued work requests from the Redis work queue, and responses from the worker received over the websocket. I'll need to add a third type of future next that allows us to listen to requests to stop generating text (#2815). The results of the different futures used to be differentiated based on their return type, which was very hard to read. I've created a decorator in FuturesManager that wraps the awaitable in another awaitable that returns a tuple, where the first entry is a FutureType enum value, and the second value is the result of awaiting the passed-in awaitable. This makes it easy to distinguish what type of result was received.

I tested my changes by spinning up the inference server + worker with docker compose. Then I used the text-client to interface with the server.

Open issues:
- pre-commit hook - working on it
- testing? There don't seem to be any tests