Skip to content
This repository has been archived by the owner on Oct 19, 2023. It is now read-only.

Commit

Permalink
fix: OpenAI handling
Browse files Browse the repository at this point in the history
  • Loading branch information
zac-li committed Jun 14, 2023
1 parent 23efc6b commit a814ad5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 18 deletions.
9 changes: 8 additions & 1 deletion lcserve/backend/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from typing import Callable


def serving(_func=None, *, websocket: bool = False, auth: Callable = None):
def serving(
_func=None,
*,
websocket: bool = False,
trace_for_openai=False,
auth: Callable = None
):
def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
Expand All @@ -23,6 +29,7 @@ def sync_wrapper(*args, **kwargs):
'doc': func.__doc__,
'params': {
'include_ws_callback_handlers': websocket,
'trace_for_openai': trace_for_openai,
# If websocket is True, pass the callback handlers to the client.
'auth': auth,
},
Expand Down
57 changes: 45 additions & 12 deletions lcserve/backend/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .langchain_helper import (
AsyncStreamingWebsocketCallbackHandler,
BuiltinsWrapper,
OpenAITracingCallbackHandler,
StreamingWebsocketCallbackHandler,
TracingCallbackHandler,
)
Expand Down Expand Up @@ -456,7 +457,10 @@ def _get_decorator_params(func):
_decorator_params = _get_decorator_params(func)
if hasattr(func, '__serving__'):
self._register_http_route(
func, dirname=dirname, auth=_decorator_params.get('auth', None)
func,
dirname=dirname,
auth=_decorator_params.get('auth', None),
trace_for_openai=_decorator_params.get('trace_for_openai', False),
)
elif hasattr(func, '__ws_serving__'):
self._register_ws_route(
Expand All @@ -466,16 +470,23 @@ def _get_decorator_params(func):
include_ws_callback_handlers=_decorator_params.get(
'include_ws_callback_handlers', False
),
trace_for_openai=_decorator_params.get('trace_for_openai', False),
)

def _register_http_route(
self, func: Callable, dirname: str = None, auth: Callable = None, **kwargs
self,
func: Callable,
dirname: str = None,
auth: Callable = None,
trace_for_openai: bool = False,
**kwargs,
):
return self._register_route(
func,
dirname=dirname,
auth=auth,
route_type=RouteType.HTTP,
trace_for_openai=trace_for_openai,
**kwargs,
)

Expand All @@ -485,6 +496,7 @@ def _register_ws_route(
dirname: str = None,
auth: Callable = None,
include_ws_callback_handlers: bool = False,
trace_for_openai: bool = False,
**kwargs,
):
return self._register_route(
Expand All @@ -493,6 +505,7 @@ def _register_ws_route(
auth=auth,
route_type=RouteType.WEBSOCKET,
include_ws_callback_handlers=include_ws_callback_handlers,
trace_for_openai=trace_for_openai,
**kwargs,
)

Expand All @@ -503,6 +516,7 @@ def _register_route(
auth: Callable = None,
route_type: RouteType = RouteType.HTTP,
include_ws_callback_handlers: bool = False,
trace_for_openai: bool = False,
**kwargs,
):
_name = func.__name__.title().replace('_', '')
Expand Down Expand Up @@ -542,6 +556,7 @@ class Config:
file_params=file_params,
input_model=input_model,
output_model=output_model,
trace_for_openai=trace_for_openai,
post_kwargs={
'path': f'/{func.__name__}',
'name': _name,
Expand Down Expand Up @@ -569,6 +584,7 @@ class Config:
'name': _name,
},
include_ws_callback_handlers=include_ws_callback_handlers,
trace_for_openai=trace_for_openai,
workspace=self.workspace,
logger=self.logger,
tracer=self.tracer,
Expand Down Expand Up @@ -662,6 +678,7 @@ def create_http_route(
file_params: List,
input_model: BaseModel,
output_model: BaseModel,
trace_for_openai: bool,
post_kwargs: Dict,
workspace: str,
logger: JinaLogger,
Expand Down Expand Up @@ -698,11 +715,19 @@ async def _the_route(
) -> output_model:
_output, _error = '', ''
# Tracing handler provided if kwargs is present
to_support_in_kwargs = {
'tracing_handler': TracingCallbackHandler(
tracer=tracer, parent_span=get_current_span()
)
}
if trace_for_openai:
to_support_in_kwargs = {
'tracing_handler': OpenAITracingCallbackHandler(
tracer=tracer, parent_span=get_current_span()
)
}
else:
to_support_in_kwargs = {
'tracing_handler': TracingCallbackHandler(
tracer=tracer, parent_span=get_current_span()
)
}

_func_data, _envs = _get_func_data(
func=func,
input_data=input_data,
Expand Down Expand Up @@ -815,6 +840,7 @@ def create_websocket_route(
input_model: BaseModel,
output_model: BaseModel,
include_ws_callback_handlers: bool,
trace_for_openai: bool,
ws_kwargs: Dict,
workspace: str,
logger: JinaLogger,
Expand Down Expand Up @@ -902,11 +928,18 @@ def _get_error_msg(e: Union[WebSocketDisconnect, ConnectionClosed]) -> str:
continue

# Tracing handler provided if kwargs is present
to_support_in_kwargs = {
'tracing_handler': TracingCallbackHandler(
tracer=tracer, parent_span=get_current_span()
)
}
if trace_for_openai:
to_support_in_kwargs = {
'tracing_handler': OpenAITracingCallbackHandler(
tracer=tracer, parent_span=get_current_span()
)
}
else:
to_support_in_kwargs = {
'tracing_handler': TracingCallbackHandler(
tracer=tracer, parent_span=get_current_span()
)
}

# If the function is a streaming response, we pass the websocket callback handler,
# so that stream data can be sent back to the client.
Expand Down
26 changes: 21 additions & 5 deletions lcserve/backend/langchain_helper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import asyncio
import json
import logging
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from uuid import UUID

from fastapi import WebSocket
from langchain.callbacks import OpenAICallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import AgentAction, LLMResult
from opentelemetry.trace import Span, Tracer, get_current_span, set_span_in_context
from opentelemetry.trace import Span, Tracer, set_span_in_context
from pydantic import BaseModel, ValidationError


Expand All @@ -35,17 +36,19 @@ def get_tracing_logger():
class TraceInfo:
trace: str
span: str
action: str
prompts: Optional[List[str]] = None
outputs: str = ""
cost: float = 0
cost: Optional[float] = None


class TracingCallbackHandler(BaseCallbackHandler):
class TracingCallbackHandlerMixin(BaseCallbackHandler):
def __init__(self, tracer: Tracer, parent_span: Span):
super().__init__()
self.tracer = tracer
self.parent_span = parent_span
self.logger = get_tracing_logger()
self.total_cost = 0

def _register_span(self, run_id, span):
_span_map[run_id] = span
Expand Down Expand Up @@ -87,6 +90,7 @@ def on_llm_start(
trace_info = TraceInfo(
trace=span_context.trace_id,
span=span_context.span_id,
action="on_llm_start",
prompts=prompts,
)
self.logger.info(json.dumps(trace_info.__dict__))
Expand All @@ -112,8 +116,9 @@ def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Non
trace_info = TraceInfo(
trace=span_context.trace_id,
span=span_context.span_id,
action="on_llm_end",
outputs=texts,
cost=20,
cost=round(self.total_cost, 3) if self.total_cost else None,
)
self.logger.info(json.dumps(trace_info.__dict__))
span.add_event("outputs", {"data": texts})
Expand Down Expand Up @@ -227,6 +232,17 @@ def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
self._end_span(run_id)


class TracingCallbackHandler(TracingCallbackHandlerMixin):
    """Default tracing callback handler.

    Inherits all span-management and logging behavior unchanged from
    ``TracingCallbackHandlerMixin``; exists as a concrete, importable name
    for the non-OpenAI tracing path (see ``OpenAITracingCallbackHandler``
    for the cost-aware variant).
    """

    pass


class OpenAITracingCallbackHandler(TracingCallbackHandlerMixin, OpenAICallbackHandler):
    """Tracing handler that also accumulates OpenAI token/cost accounting.

    Combines ``TracingCallbackHandlerMixin`` (span creation + structured
    logging) with LangChain's ``OpenAICallbackHandler`` (computes
    ``total_cost`` from the LLM response).
    """

    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
        # Deliberately call each base class explicitly instead of super():
        # OpenAICallbackHandler must run FIRST so that total_cost is updated
        # before the mixin reads it when emitting the trace record. Under the
        # MRO, super() would invoke the mixin first and log a stale cost.
        OpenAICallbackHandler.on_llm_end(self, response, run_id=run_id, **kwargs)
        TracingCallbackHandlerMixin.on_llm_end(self, response, run_id=run_id, **kwargs)


class AsyncStreamingWebsocketCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(self, websocket: "WebSocket", output_model: "BaseModel"):
super().__init__()
Expand Down

0 comments on commit a814ad5

Please sign in to comment.