diff --git a/openllm-python/src/openllm/entrypoints/_openapi.py b/openllm-python/src/openllm/entrypoints/_openapi.py index 427b55ce6..e483090a8 100644 --- a/openllm-python/src/openllm/entrypoints/_openapi.py +++ b/openllm-python/src/openllm/entrypoints/_openapi.py @@ -611,11 +611,17 @@ def asdict(self) -> dict[str, t.Any]: def append_schemas( - svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend' + svc: bentoml.Service, + generated_schema: dict[str, t.Any], + tags_order: t.Literal['prepend', 'append'] = 'prepend', + inject: bool = True, ) -> bentoml.Service: # HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec. from bentoml._internal.service.openapi.specification import OpenAPISpecification + if not inject: + return svc + svc_schema = svc.openapi_spec if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): svc_schema = svc_schema.asdict() diff --git a/openllm-python/src/openllm/entrypoints/cohere.py b/openllm-python/src/openllm/entrypoints/cohere.py index 645dbc360..ebd5bb1f9 100644 --- a/openllm-python/src/openllm/entrypoints/cohere.py +++ b/openllm-python/src/openllm/entrypoints/cohere.py @@ -13,9 +13,7 @@ from openllm_core.utils import converter, gen_random_uuid -from ._openapi import append_schemas, get_generator - -# from ._openapi import add_schema_definitions +from ._openapi import add_schema_definitions, append_schemas, get_generator from ..protocol.cohere import ( Chat, ChatStreamEnd, @@ -89,29 +87,21 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic debug=True, routes=[ Route( - '/v1/generate', - endpoint=functools.partial(cohere_generate, llm=llm), - name='cohere_generate', - methods=['POST'], - include_in_schema=False, - ), - Route( - '/v1/chat', - endpoint=functools.partial(cohere_chat, llm=llm), - name='cohere_chat', - methods=['POST'], - include_in_schema=False, + '/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST'] ), + Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']), Route('/schema', endpoint=openapi_schema, include_in_schema=False), ], ) mount_path = '/cohere' svc.mount_asgi_app(app, path=mount_path) - return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append') + return append_schemas( + svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=False + ) -# @add_schema_definitions +@add_schema_definitions async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response: json_str = await req.body() try: @@ -201,7 +191,7 @@ def convert_role(role): return messages -# @add_schema_definitions +@add_schema_definitions async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response: json_str = await req.body() try: