Skip to content

Commit

Permalink
feat: add EventManager to centralize callbacks (#3434)
Browse files Browse the repository at this point in the history
* refactor: Update MessageBase text attribute based on isinstance check.

* feat: Add update_message function to update a message in the database.

* refactor(chat): Update imports and remove unnecessary config method in ChatComponent.

* refactor: Add stream_message method to ChatComponent.

* refactor: Update method call in ChatOutput component.

* feat: Add callback function to custom component and update build_results signature.

* feat: Add callback parameter to instantiate_class function.

* feat(graph): Add callback functions for sync and async operations.

* feat: Add callback function support to vertex build process.

* feat: Add handling for added message in InterfaceVertex class.

* feat: Add callback support to Graph methods.

* feat(chat): Add callback function to build_vertices function.

* refactor: Simplify update_message function and use session_scope for session management.

* fix: Call set_callback method if available on custom component.

* refactor(chat): Update chat message chunk handling and ID conversion.

* feat: Add null check before setting cache in build_vertex_stream function.

* refactor: Fix send_event_wrapper function and add callback parameter to _build_vertex function.

* refactor: Simplify conditional statement and import order in ChatOutput.

* refactor: move log method to Component class.

* refactor: Simplify CallbackFunction definition.

* feat: Initialize _current_output attribute in Component class.

* feat: store current output name in custom component during processing.

* feat: Add current output and component ID to log data.

* fix: Add condition to check current output before invoking callback.

* refactor: Update callback to log_callback in graph methods.

* feat: Add test for callback graph execution with log messages.

* update projects

* fix(chat.py): fix condition to check if message text is a string before updating message text in the database

* refactor(ChatOutput.py): update ChatOutput class to correctly store and assign the message value to ensure consistency and avoid potential bugs

* refactor(chat.py): update return type of store_message method to return a single Message object instead of a list of Messages
refactor(chat.py): update logic to correctly handle updating and returning a single stored message object instead of a list of messages

* update starter projects

* refactor(component.py): update type hint for name parameter in log method to be more explicit

* feat: Add EventManager class for managing events and event registration

* refactor: Update log_callback to event_manager in custom component classes

* refactor(component.py): rename _log_callback to _event_manager and update method call to on_log for better clarity and consistency

* refactor(chat.py): rename _log_callback method to _event_manager.on_token for clarity and consistency in method naming

* refactor: Rename log_callback to event_manager for clarity and consistency

* refactor: Update Vertex class to use EventManager instead of log_callback for better clarity and consistency

* refactor: update build_flow to use EventManager

* refactor: Update EventManager class to use Protocol for event callbacks

* if event_type is not passed, it uses the default send_event

* Add method to register event functions in EventManager

- Introduced `register_event_function` method to allow passing custom event functions.
- Updated `noop` method to accept `event_type` parameter.
- Adjusted `__getattr__` to return `EventCallback` type.

* update test_callback_graph

* Add unit tests for EventManager in test_event_manager.py

- Added tests for event registration, including default event type, empty string names, and specific event types.
- Added tests for custom event functions and unregistered event access.
- Added tests for event sending, including JSON formatting, empty data, and large payloads.
- Added tests for handling JSON serialization errors and the noop function.

* revert chatOutput change

* Add validation for event function in EventManager

- Introduced `_validate_event_function` method to ensure event functions are callable and have the correct parameters.
- Updated `register_event_function` to use the new validation method.

* Add tests for EventManager's event function validation logic

- Introduce `TestValidateEventFunction` class to test various scenarios for `_validate_event_function`.
- Add tests for valid event functions, non-callable event functions, invalid parameter counts, and parameter type validation.
- Include tests for handling unannotated parameters, flexible arguments, and keyword-only parameters.
- Ensure proper warnings and exceptions are raised for invalid event functions.

* Add type ignore comment to lambda function in test_event_manager.py

* refactor: Update EventManager class to use Protocol for event callbacks

* refactor(event_manager.py): simplify event registration and validation logic to enhance readability and maintainability
feat(event_manager.py): enforce event name conventions and improve callback handling for better error management

* refactor(chat.py): standardize event_manager method calls by using keyword arguments for better clarity and consistency
refactor(chat.py): extract message processing logic into separate methods for improved readability and maintainability
fix(chat.py): ensure proper handling of async iterators in message streaming
refactor(component.py): simplify event logging by removing unnecessary event name parameter in on_log method call

* update event manager tests

* Add callback validation and manager parameter in EventManager

- Introduced `_validate_callback` method to ensure callbacks are callable and have the correct parameters.
- Updated `register_event` to include `manager` parameter in the callback.

* Add support for passing callback through the Graph in test_callback_graph

* fix(event_manager.py): update EventCallback signature to include manager parameter for better context in event handling
  • Loading branch information
ogabrielluiz authored Sep 2, 2024
1 parent 882c35e commit 3eaad7b
Show file tree
Hide file tree
Showing 24 changed files with 534 additions and 106 deletions.
49 changes: 26 additions & 23 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
VertexBuildResponse,
VerticesOrderResponse,
)
from langflow.events.event_manager import EventManager, create_default_event_manager
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
Expand Down Expand Up @@ -204,7 +205,7 @@ async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc

async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
async def _build_vertex(vertex_id: str, graph: "Graph", event_manager: "EventManager") -> VertexBuildResponse:
flow_id_str = str(flow_id)

next_runnable_vertices = []
Expand All @@ -222,6 +223,7 @@ async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
files=files,
get_cache=chat_service.get_cache,
set_cache=chat_service.set_cache,
event_manager=event_manager,
)
result_dict = vertex_build_result.result_dict
params = vertex_build_result.params
Expand Down Expand Up @@ -316,17 +318,13 @@ async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
message = parse_exception(exc)
raise HTTPException(status_code=500, detail=message) from exc

def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
json_data = {"event": event_type, "data": value}
event_id = uuid.uuid4()
logger.debug(f"sending event {event_id}: {event_type}")
str_data = json.dumps(json_data) + "\n\n"
queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))

async def build_vertices(
vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue
vertex_id: str,
graph: "Graph",
client_consumed_queue: asyncio.Queue,
event_manager: "EventManager",
) -> None:
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph, event_manager))
try:
await build_task
except asyncio.CancelledError as exc:
Expand All @@ -341,13 +339,15 @@ async def build_vertices(
build_data = json.loads(vertex_build_response_json)
except Exception as exc:
raise ValueError(f"Error serializing vertex build response: {exc}") from exc
send_event("end_vertex", {"build_data": build_data}, queue)
event_manager.on_end_vertex(data={"build_data": build_data})
await client_consumed_queue.get()
if vertex_build_response.valid:
if vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(build_vertices(next_vertex_id, graph, queue, client_consumed_queue))
task = asyncio.create_task(
build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager)
)
tasks.append(task)
try:
await asyncio.gather(*tasks)
Expand All @@ -356,7 +356,7 @@ async def build_vertices(
task.cancel()
return

async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
if not data:
# using another thread since the DB query is I/O bound
vertices_task = asyncio.create_task(await asyncio.to_thread(build_graph_and_get_order))
Expand All @@ -367,9 +367,9 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Q
return
except Exception as e:
if isinstance(e, HTTPException):
send_event("error", {"error": str(e.detail), "statusCode": e.status_code}, queue)
event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
raise e
send_event("error", {"error": str(e)}, queue)
event_manager.on_error(data={"error": str(e)})
raise e

ids, vertices_to_run, graph = vertices_task.result()
Expand All @@ -378,16 +378,16 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Q
ids, vertices_to_run, graph = await build_graph_and_get_order()
except Exception as e:
if isinstance(e, HTTPException):
send_event("error", {"error": str(e.detail), "statusCode": e.status_code}, queue)
event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
raise e
send_event("error", {"error": str(e)}, queue)
event_manager.on_error(data={"error": str(e)})
raise e
send_event("vertices_sorted", {"ids": ids, "to_run": vertices_to_run}, queue)
event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run})
await client_consumed_queue.get()

tasks = []
for vertex_id in ids:
task = asyncio.create_task(build_vertices(vertex_id, graph, queue, client_consumed_queue))
task = asyncio.create_task(build_vertices(vertex_id, graph, client_consumed_queue, event_manager))
tasks.append(task)
try:
await asyncio.gather(*tasks)
Expand All @@ -396,8 +396,8 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Q
for task in tasks:
task.cancel()
return
send_event("end", {}, queue)
await queue.put((None, None, time.time))
event_manager.on_end(data={})
await event_manager.queue.put((None, None, time.time))

async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator:
while True:
Expand All @@ -414,7 +414,8 @@ async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio

asyncio_queue: asyncio.Queue = asyncio.Queue()
asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue()
main_task = asyncio.create_task(event_generator(asyncio_queue, asyncio_queue_client_consumed))
event_manager = create_default_event_manager(queue=asyncio_queue)
main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))

def on_disconnect():
logger.debug("Client disconnected, closing tasks")
Expand Down Expand Up @@ -640,6 +641,7 @@ async def build_vertex_stream(
flow_id_str = str(flow_id)

async def stream_vertex():
graph = None
try:
cache = await chat_service.get_cache(flow_id_str)
if not cache:
Expand Down Expand Up @@ -693,7 +695,8 @@ async def stream_vertex():
yield str(StreamData(event="error", data={"error": exc_message}))
finally:
logger.debug("Closing stream")
await chat_service.set_cache(flow_id_str, graph)
if graph:
await chat_service.set_cache(flow_id_str, graph)
yield str(StreamData(event="close", data={"message": "Stream closed"}))

return StreamingResponse(stream_vertex(), media_type="text/event-stream")
Expand Down
91 changes: 47 additions & 44 deletions src/backend/base/langflow/base/io/chat.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,70 @@
from typing import Optional, Union
from typing import AsyncIterator, Iterator, Optional, Union

from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
from langflow.custom import Component
from langflow.memory import store_message
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_USER, MESSAGE_SENDER_AI
from langflow.services.database.models.message.crud import update_message
from langflow.utils.async_helpers import run_until_complete


class ChatComponent(Component):
display_name = "Chat Component"
description = "Use as base for chat components."

def build_config(self):
return {
"input_value": {
"input_types": ["Text"],
"display_name": "Text",
"multiline": True,
},
"sender": {
"options": [MESSAGE_SENDER_AI, MESSAGE_SENDER_USER],
"display_name": "Sender Type",
"advanced": True,
},
"sender_name": {"display_name": "Sender Name", "advanced": True},
"session_id": {
"display_name": "Session ID",
"info": "If provided, the message will be stored in the memory.",
"advanced": True,
},
"return_message": {
"display_name": "Return Message",
"info": "Return the message as a Message containing the sender, sender_name, and session_id.",
"advanced": True,
},
"data_template": {
"display_name": "Data Template",
"multiline": True,
"info": "In case of Message being a Data, this template will be used to convert it to text.",
"advanced": True,
},
"files": {
"field_type": "file",
"display_name": "Files",
"file_types": TEXT_FILE_TYPES + IMG_FILE_TYPES,
"info": "Files to be sent with the message.",
"advanced": True,
},
}

# Keep this method for backward compatibility
def store_message(
self,
message: Message,
) -> list[Message]:
) -> Message:
messages = store_message(
message,
flow_id=self.graph.flow_id,
)
if len(messages) > 1:
raise ValueError("Only one message can be stored at a time.")
stored_message = messages[0]
if hasattr(self, "_event_manager") and self._event_manager and stored_message.id:
if not isinstance(message.text, str):
complete_message = self._stream_message(message, stored_message.id)
message_table = update_message(message_id=stored_message.id, message=dict(text=complete_message))
stored_message = Message(**message_table.model_dump())
self.vertex._added_message = stored_message
self.status = stored_message
return stored_message

def _process_chunk(self, chunk: str, complete_message: str, message: Message, message_id: str) -> str:
complete_message += chunk
data = {
"text": complete_message,
"chunk": chunk,
"sender": message.sender,
"sender_name": message.sender_name,
"id": str(message_id),
}
if self._event_manager:
self._event_manager.on_token(data=data)
return complete_message

async def _handle_async_iterator(self, iterator: AsyncIterator, message: Message, message_id: str) -> str:
complete_message = ""
async for chunk in iterator:
complete_message = self._process_chunk(chunk.content, complete_message, message, message_id)
return complete_message

def _stream_message(self, message: Message, message_id: str) -> str:
iterator = message.text
if not isinstance(iterator, (AsyncIterator, Iterator)):
raise ValueError("The message must be an iterator or an async iterator.")

if isinstance(iterator, AsyncIterator):
return run_until_complete(self._handle_async_iterator(iterator, message, message_id))

complete_message = ""
for chunk in iterator:
complete_message = self._process_chunk(chunk.content, complete_message, message, message_id)

self.status = messages
return messages
return complete_message

def build_with_data(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/outputs/ChatOutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langflow.io import DropdownInput, MessageTextInput, Output
from langflow.memory import store_message
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_NAME_AI, MESSAGE_SENDER_USER, MESSAGE_SENDER_AI
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI, MESSAGE_SENDER_USER


class ChatOutput(ChatComponent):
Expand Down
38 changes: 35 additions & 3 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import inspect
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
from collections.abc import Callable
from uuid import UUID

import nanoid # type: ignore
import yaml
from pydantic import BaseModel

from langflow.events.event_manager import EventManager
from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
from langflow.schema.log import LoggableType
from langflow.schema.message import Message
from langflow.services.tracing.schema import Log
from langflow.template.field.base import UNDEFINED, Input, Output
Expand All @@ -35,6 +37,7 @@ class Component(CustomComponent):
outputs: list[Output] = []
code_class_base_inheritance: ClassVar[str] = "Component"
_output_logs: dict[str, Log] = {}
_current_output: str = ""

def __init__(self, **kwargs):
# if key starts with _ it is a config
Expand All @@ -56,6 +59,8 @@ def __init__(self, **kwargs):
self._parameters = inputs or {}
self._edges: list[EdgeData] = []
self._components: list[Component] = []
self._current_output = ""
self._event_manager: EventManager | None = None
self._state_model = None
self.set_attributes(self._parameters)
self._output_logs = {}
Expand All @@ -77,6 +82,9 @@ def __init__(self, **kwargs):
self._set_output_types()
self.set_class_code()

def set_event_manager(self, event_manager: EventManager | None = None):
self._event_manager = event_manager

def _reset_all_output_values(self):
for output in self.outputs:
setattr(output, "value", UNDEFINED)
Expand Down Expand Up @@ -601,7 +609,9 @@ async def _build_with_tracing(self):
async def _build_without_tracing(self):
return await self._build_results()

async def build_results(self):
async def build_results(
self,
):
if self._tracing_service:
return await self._build_with_tracing()
return await self._build_without_tracing()
Expand All @@ -620,6 +630,7 @@ async def _build_results(self):
):
if output.method is None:
raise ValueError(f"Output {output.name} does not have a method defined.")
self._current_output = output.name
method: Callable = getattr(self, output.method)
if output.cache and output.value != UNDEFINED:
_results[output.name] = output.value
Expand All @@ -638,6 +649,7 @@ async def _build_results(self):
result.set_flow_id(self._vertex.graph.flow_id)
_results[output.name] = result
output.value = result

custom_repr = self.custom_repr()
if custom_repr is None and isinstance(result, (dict, Data, str)):
custom_repr = result
Expand Down Expand Up @@ -665,6 +677,7 @@ async def _build_results(self):
_artifacts[output.name] = artifact
self._output_logs[output.name] = self._logs
self._logs = []
self._current_output = ""
self._artifacts = _artifacts
self._results = _results
if self._tracing_service:
Expand Down Expand Up @@ -720,6 +733,25 @@ def to_tool(self):
return ComponentTool(component=self)

def get_project_name(self):
if hasattr(self, "_tracing_service"):
if hasattr(self, "_tracing_service") and self._tracing_service:
return self._tracing_service.project_name
return "Langflow"

def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
"""
Logs a message.
Args:
message (LoggableType | list[LoggableType]): The message to log.
"""
if name is None:
name = f"Log {len(self._logs) + 1}"
log = Log(message=message, type=get_artifact_type(message), name=name)
self._logs.append(log)
if self._tracing_service and self._vertex:
self._tracing_service.add_log(trace_name=self.trace_name, log=log)
if self._event_manager is not None and self._current_output:
data = log.model_dump()
data["output"] = self._current_output
data["component_id"] = self._id
self._event_manager.on_log(data=data)
Loading

0 comments on commit 3eaad7b

Please sign in to comment.