Skip to content

Commit

Permalink
chore(openapi): unify inject param (#645)
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm authored Nov 14, 2023
1 parent b0ab8cc commit 00d6016
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
8 changes: 7 additions & 1 deletion openllm-python/src/openllm/entrypoints/_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,11 +611,17 @@ def asdict(self) -> dict[str, t.Any]:


def append_schemas(
svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
svc: bentoml.Service,
generated_schema: dict[str, t.Any],
tags_order: t.Literal['prepend', 'append'] = 'prepend',
inject: bool = True,
) -> bentoml.Service:
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification

if not inject:
return svc

svc_schema = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
svc_schema = svc_schema.asdict()
Expand Down
26 changes: 8 additions & 18 deletions openllm-python/src/openllm/entrypoints/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

from openllm_core.utils import converter, gen_random_uuid

from ._openapi import append_schemas, get_generator

# from ._openapi import add_schema_definitions
from ._openapi import add_schema_definitions, append_schemas, get_generator
from ..protocol.cohere import (
Chat,
ChatStreamEnd,
Expand Down Expand Up @@ -89,29 +87,21 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
debug=True,
routes=[
Route(
'/v1/generate',
endpoint=functools.partial(cohere_generate, llm=llm),
name='cohere_generate',
methods=['POST'],
include_in_schema=False,
),
Route(
'/v1/chat',
endpoint=functools.partial(cohere_chat, llm=llm),
name='cohere_chat',
methods=['POST'],
include_in_schema=False,
'/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST']
),
Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']),
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
],
)
mount_path = '/cohere'

svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
return append_schemas(
svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=False
)


# @add_schema_definitions
@add_schema_definitions
async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
json_str = await req.body()
try:
Expand Down Expand Up @@ -201,7 +191,7 @@ def convert_role(role):
return messages


# @add_schema_definitions
@add_schema_definitions
async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
json_str = await req.body()
try:
Expand Down

0 comments on commit 00d6016

Please sign in to comment.