
Cleanup handle_worker() in preparation for #2815 (Stop generation early) #3573

Merged — 6 commits merged into LAION-AI:main on Jul 25, 2023

Conversation

@0xfacade (Contributor) commented on Jul 15, 2023

In this PR, I clean up the handle_worker() method a bit so that I can later extend it (in a future PR). There are no functional changes in this PR.

Changes:

  • collect the many local variables into a new class HandleWorkerContext, which also provides methods for initialization and destruction
  • collect the methods that update the session into a new class SessionManager
  • move the management of futures into a new class FuturesManager
  • extract the logic for handling a work request and a worker response from the main loop into their own respective functions

The last change is the most important one for my future changes. In the main loop of handle_worker(), we were already waiting for two different types of futures: newly dequeued work requests from the Redis work queue, and responses from the worker received over the websocket. I'll need to add a third type of future next that allows us to listen for requests to stop generating text (#2815).

The results of the different futures used to be differentiated by their return type, which was very hard to read. I've created a decorator in FuturesManager that wraps the awaitable in another awaitable returning a tuple, where the first entry is a FutureType enum value and the second is the result of awaiting the passed-in awaitable. This makes it easy to distinguish what type of result was received.
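The tagging approach described above can be sketched roughly as follows. This is an illustrative reconstruction, not the PR's actual code: the names FutureType and tag_future are assumptions.

```python
import asyncio
import enum
from typing import Any, Awaitable


class FutureType(enum.Enum):
    WORK_REQUEST = "work_request"
    WORKER_RESPONSE = "worker_response"


async def tag_future(future_type: FutureType, awaitable: Awaitable[Any]) -> tuple[FutureType, Any]:
    # Await the wrapped awaitable and tag its result with the future's type,
    # so callers no longer need to inspect the result's runtime type.
    return future_type, await awaitable


async def main() -> None:
    async def dequeue() -> str:
        # Stand-in for the Redis work-queue dequeue.
        return "message-123"

    tagged = asyncio.ensure_future(tag_future(FutureType.WORK_REQUEST, dequeue()))
    future_type, result = await tagged
    print(future_type.name, result)


asyncio.run(main())
```

With every pending future wrapped this way, the results coming out of asyncio.wait() all share the shape tuple[FutureType, X], whatever the underlying awaitable returns.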

I tested my changes by spinning up the inference server + worker with docker compose. Then I used the text-client to interface with the server.

Open issues:

  • haven't yet been able to run the pre-commit hook - working on it
  • need to add more docstrings
  • testing? There don't seem to be any tests

Comment on lines -46 to -57
async def add_worker_connect_event(
    session: database.AsyncSession,
    worker_id: str,
    worker_info: inference.WorkerInfo,
):
    event = models.DbWorkerEvent(
        worker_id=worker_id,
        event_type=models.WorkerEventType.connect,
        worker_info=worker_info,
    )
    session.add(event)
    await session.commit()
@0xfacade (Contributor, Author) commented:

Moved to the SessionManager without changes.

@github-actions commented:

pre-commit failed.
Please run pre-commit run --all-files locally and commit the changes.
Find more information in the repository's CONTRIBUTING.md

Comment on lines -96 to -124
try:
    worker_utils.get_protocol_version(protocol_version)
    api_key = worker_utils.get_api_key(api_key)
    worker_id = await worker_utils.get_worker_id(api_key=api_key, protocol_version=protocol_version)
except fastapi.HTTPException as e:
    logger.warning(f"handle_worker: {e.status_code=} {e.detail=}")
    if e.status_code == fastapi.status.HTTP_426_UPGRADE_REQUIRED:
        await worker_utils.send_worker_request(websocket=websocket, request=inference.UpgradeProtocolRequest())
    elif e.status_code == fastapi.status.HTTP_401_UNAUTHORIZED:
        await worker_utils.send_worker_request(websocket=websocket, request=inference.WrongApiKeyRequest())
    try:
        await websocket.close(code=e.status_code, reason=e.detail)
    except Exception:
        pass
    raise fastapi.WebSocketException(e.status_code, e.detail)

logger.info(f"handle_worker: {worker_id=}")
worker_info = await worker_utils.receive_worker_info(websocket)
logger.info(f"handle_worker: {worker_info=}")
worker_config = worker_info.config
worker_compat_hash = worker_config.compat_hash
work_queue = queueing.work_queue(deps.redis_client, worker_compat_hash)
redis_client = deps.make_redis_client()
blocking_work_queue = queueing.work_queue(redis_client, worker_compat_hash)
worker_session = worker_utils.WorkerSession(
    worker_id=worker_id,
    worker_info=worker_info,
)
work_request_map: dict[str, WorkRequestContainer] = {}
@0xfacade (Contributor, Author) commented:

Moved to HandleWorkerContext.build() without changes.

Comment on lines -126 to -129
try:
    async with deps.manual_create_session() as session:
        await add_worker_connect_event(session=session, worker_id=worker_id, worker_info=worker_info)
    await worker_utils.store_worker_session(worker_session)
@0xfacade (Contributor, Author) commented:

Moved to SessionManager.init()

Comment on lines -131 to -135
async def _update_session(metrics: inference.WorkerMetricsInfo):
    worker_session.requests_in_flight = len(work_request_map)
    if metrics:
        worker_session.metrics = metrics
    await worker_utils.store_worker_session(worker_session)
@0xfacade (Contributor, Author) commented:

Moved to SessionManager.update()

Comment on lines -137 to -140
def _add_dequeue(ftrs: set):
    requests_in_progress = len(work_request_map)
    if requests_in_progress < worker_config.max_parallel_requests:
        ftrs.add(asyncio.ensure_future(blocking_work_queue.dequeue(timeout=0)))
@0xfacade (Contributor, Author) commented:

Moved to FuturesManager.ensure_listening_for_work_requests(). The future is now wrapped such that it returns a tuple[FutureType, X], where X is the return type of whatever the future returned before.

Comment on lines -142 to -143
def _add_receive(ftrs: set):
    ftrs.add(asyncio.ensure_future(worker_utils.receive_worker_response(websocket=websocket)))
@0xfacade (Contributor, Author) commented:

Moved to FuturesManager.ensure_listening_for_worker_response() with similar changes to return the FutureType in addition to the response.

Comment on lines -152 to -157
(done, pending_futures) = await asyncio.wait(
    pending_futures, timeout=settings.worker_ping_interval, return_when=asyncio.FIRST_COMPLETED
)
ftr: asyncio.Future
for ftr in done:
    result = ftr.result()
@0xfacade (Contributor, Author) commented:

Moved to FuturesManager.wait_for_event_or_timeout()

if result is None:
    logger.error(f"handle_worker: {worker_id=} received None from queue. This should never happen.")
    raise RuntimeError("Received None from queue. This should never happen.")
elif isinstance(result, tuple):
@0xfacade (Contributor, Author) commented on Jul 15, 2023:

This is the test that has so far been used to identify a dequeued work request, which has now been replaced by a check for type_ == FutureType.WORK_REQUEST
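The replacement check can be illustrated with a small sketch. The names below (FutureType, dispatch) are hypothetical stand-ins for the PR's code; the point is that a tagged tuple replaces the old isinstance(result, tuple) test.

```python
import enum


class FutureType(enum.Enum):
    WORK_REQUEST = "work_request"
    WORKER_RESPONSE = "worker_response"


def dispatch(result: tuple) -> str:
    # The first tuple entry identifies the future type; the second is the payload.
    future_type, payload = result
    if future_type is FutureType.WORK_REQUEST:
        return f"work request for {payload}"
    return f"worker response: {payload}"


print(dispatch((FutureType.WORK_REQUEST, "msg-1")))
print(dispatch((FutureType.WORKER_RESPONSE, "pong")))
```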

Comment on lines -162 to -180
try:
    _, message_id = result
    work_request = await initiate_work_for_message(
        websocket=websocket,
        work_queue=work_queue,
        message_id=message_id,
        worker_id=worker_id,
        worker_config=worker_config,
    )
    work_request_map[work_request.id] = WorkRequestContainer(
        work_request=work_request, message_id=message_id
    )
except chat_schema.MessageCancelledException as e:
    logger.warning(f"Message was cancelled before work could be initiated: {e.message_id=}")
except chat_schema.MessageTimeoutException as e:
    logger.warning(f"Message timed out before work could be initiated: {e.message.id=}")
    await handle_timeout(message=e.message)
finally:
    _add_dequeue(pending_futures)
@0xfacade (Contributor, Author) commented:

Moved to _handle_work_request

Comment on lines -182 to -215
try:
    worker_response: inference.WorkerResponse = result
    match worker_response.response_type:
        case "pong":
            worker_response = cast(inference.PongResponse, worker_response)
            await _update_session(worker_response.metrics)
        case "token":
            worker_response = cast(inference.TokenResponse, worker_response)
            await handle_token_response(
                work_request_map=work_request_map,
                response=worker_response,
            )
        case "generated_text":
            worker_response = cast(inference.GeneratedTextResponse, worker_response)
            await handle_generated_text_response(
                work_request_map=work_request_map,
                response=worker_response,
            )
            await _update_session(worker_response.metrics)
        case "error":
            worker_response = cast(inference.ErrorResponse, worker_response)
            await handle_error_response(
                work_request_map=work_request_map,
                response=worker_response,
            )
            await _update_session(worker_response.metrics)
        case "general_error":
            worker_response = cast(inference.GeneralErrorResponse, worker_response)
            await handle_general_error_response(
                response=worker_response,
            )
            await _update_session(worker_response.metrics)
        case "safe_prompt":
            logger.info("Received safe prompt response")
@0xfacade (Contributor, Author) commented:

Moved to _handle_worker_response()

@@ -53,7 +53,7 @@ def get_protocol_version(protocol_version: str = protocol_version_header) -> str
 async def get_worker_id(
     api_key: str = Depends(get_api_key),
     protocol_version: str = Depends(get_protocol_version),
-) -> models.DbWorker:
+) -> str:
@0xfacade (Contributor, Author) commented:

Noticed that the return type annotation was wrong.

Comment on lines +325 to +327
if len(futures._futures) == 0:
    futures.ensure_listening_for_work_requests()
    futures.ensure_listening_to_worker_responses()
@0xfacade (Contributor, Author) commented:
In order to avoid logical changes as much as possible, I kept this piece of code here as it was, but that now means that it accesses the internal structure of the FuturesManager.

I'm considering whether it might be nicer not to throw all the futures into a single set. Instead, the FuturesManager could keep the future listening for worker responses in its own variable, plus a set for the potentially multiple futures listening for work requests. It could then expose a single method that ensures there is at least one pending future of each type, and only that method would be called from the outside at the end of each loop iteration.
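That alternative design could look something like the sketch below. This is purely illustrative, under the assumption that the manager owns placeholder coroutines where the real code would use the websocket receive and the Redis dequeue; none of these names come from the PR.

```python
import asyncio


class FuturesManager:
    def __init__(self, max_parallel_requests: int) -> None:
        self.max_parallel_requests = max_parallel_requests
        # One dedicated slot for the worker-response listener...
        self.response_future: asyncio.Future | None = None
        # ...and a bounded set of work-request listeners.
        self.work_request_futures: set[asyncio.Future] = set()

    def ensure_pending_futures(self, requests_in_progress: int) -> None:
        # Single entry point called at the end of each loop iteration:
        # guarantees at least one pending future of each type.
        if self.response_future is None or self.response_future.done():
            self.response_future = asyncio.ensure_future(self._receive_response())
        if requests_in_progress + len(self.work_request_futures) < self.max_parallel_requests:
            self.work_request_futures.add(asyncio.ensure_future(self._dequeue()))

    async def _receive_response(self) -> str:
        # Placeholder for worker_utils.receive_worker_response(websocket=...).
        return "pong"

    async def _dequeue(self) -> str:
        # Placeholder for blocking_work_queue.dequeue(timeout=0).
        return "message-id"


async def main() -> None:
    manager = FuturesManager(max_parallel_requests=2)
    manager.ensure_pending_futures(requests_in_progress=0)
    print(await manager.response_future)
    for ftr in list(manager.work_request_futures):
        print(await ftr)


asyncio.run(main())
```

This keeps the type of each future explicit in the data structure itself, so no tagging or isinstance checks are needed at the call site.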

@olliestanley (Collaborator) left a comment:

Sorry for delayed review - thanks, these changes make sense to me

I saw you have a checklist in the original comment so have not merged yet but if this is ready to go as-is let me know and I'll merge

@andreaskoepf (Collaborator) commented:

@0xfacade Do you still want to complete work on the todo points (unticked checkboxes) or can it be merged as it is?

@0xfacade (Contributor, Author) commented:

@olliestanley @andreaskoepf This can be merged! Will work on the extension for the stop generating feature now

@olliestanley olliestanley merged commit 0c2fa4c into LAION-AI:main Jul 25, 2023
yk added a commit that referenced this pull request Aug 8, 2023
andreaskoepf pushed a commit that referenced this pull request Aug 8, 2023
…ration early) (#3573)" (#3644)

This PR reverts commit 0c2fa4c.

We'll investigate why this caused problems