Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add EventManager to centralize callbacks #3434

Merged
merged 55 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
17fff8c
refactor: Update MessageBase text attribute based on isinstance check.
ogabrielluiz Aug 19, 2024
1893536
feat: Add update_message function to update a message in the database.
ogabrielluiz Aug 19, 2024
20a1727
refactor(chat): Update imports and remove unnecessary config method i…
ogabrielluiz Aug 19, 2024
76bb506
refactor: Add stream_message method to ChatComponent.
ogabrielluiz Aug 19, 2024
5c9566a
refactor: Update method call in ChatOutput component.
ogabrielluiz Aug 19, 2024
3e274fe
feat: Add callback function to custom component and update build_resu…
ogabrielluiz Aug 19, 2024
2a9eb09
feat: Add callback parameter to instantiate_class function.
ogabrielluiz Aug 19, 2024
a2a88db
feat(graph): Add callback functions for sync and async operations.
ogabrielluiz Aug 19, 2024
6b62ad4
feat: Add callback function support to vertex build process.
ogabrielluiz Aug 19, 2024
6a25c3e
feat: Add handling for added message in InterfaceVertex class.
ogabrielluiz Aug 19, 2024
2f8abb7
feat: Add callback support to Graph methods.
ogabrielluiz Aug 19, 2024
750810c
feat(chat): Add callback function to build_vertices function.
ogabrielluiz Aug 19, 2024
86366dc
refactor: Simplify update_message function and use session_scope for …
ogabrielluiz Aug 19, 2024
12c0aa6
fix: Call set_callback method if available on custom component.
ogabrielluiz Aug 19, 2024
c67df62
refactor(chat): Update chat message chunk handling and ID conversion.
ogabrielluiz Aug 19, 2024
d0846bb
feat: Add null check before setting cache in build_vertex_stream func…
ogabrielluiz Aug 19, 2024
5cb936a
refactor: Fix send_event_wrapper function and add callback parameter …
ogabrielluiz Aug 19, 2024
c9eb4a0
refactor: Simplify conditional statement and import order in ChatOutput.
ogabrielluiz Aug 19, 2024
73f6e7b
refactor: move log method to Component class.
ogabrielluiz Aug 20, 2024
bf611df
refactor: Simplify CallbackFunction definition.
ogabrielluiz Aug 20, 2024
5e46b5f
feat: Initialize _current_output attribute in Component class.
ogabrielluiz Aug 20, 2024
6535d36
feat: store current output name in custom component during processing.
ogabrielluiz Aug 20, 2024
5d97cf5
feat: Add current output and component ID to log data.
ogabrielluiz Aug 20, 2024
f2a6027
fix: Add condition to check current output before invoking callback.
ogabrielluiz Aug 20, 2024
920ec58
refactor: Update callback to log_callback in graph methods.
ogabrielluiz Aug 20, 2024
afdf9bb
feat: Add test for callback graph execution with log messages.
ogabrielluiz Aug 20, 2024
1a69ad2
update projects
ogabrielluiz Aug 23, 2024
f58156a
fix(chat.py): fix condition to check if message text is a string befo…
ogabrielluiz Aug 27, 2024
b6ce4af
refactor(ChatOutput.py): update ChatOutput class to correctly store a…
ogabrielluiz Aug 27, 2024
37dcf99
refactor(chat.py): update return type of store_message method to retu…
ogabrielluiz Aug 27, 2024
d4e2116
update starter projects
ogabrielluiz Aug 27, 2024
10410dc
refactor(component.py): update type hint for name parameter in log me…
ogabrielluiz Aug 28, 2024
3f8bc98
feat: Add EventManager class for managing events and event registration
ogabrielluiz Aug 28, 2024
410f1fc
refactor: Update log_callback to event_manager in custom component cl…
ogabrielluiz Aug 28, 2024
0807193
refactor(component.py): rename _log_callback to _event_manager and up…
ogabrielluiz Aug 28, 2024
34b1cb3
refactor(chat.py): rename _log_callback method to _event_manager.on_t…
ogabrielluiz Aug 28, 2024
1309eac
refactor: Rename log_callback to event_manager for clarity and consis…
ogabrielluiz Aug 28, 2024
39ee39d
refactor: Update Vertex class to use EventManager instead of log_call…
ogabrielluiz Aug 28, 2024
4c1661b
refactor: update build_flow to use EventManager
ogabrielluiz Aug 28, 2024
a89cdf6
refactor: Update EventManager class to use Protocol for event callbacks
ogabrielluiz Aug 28, 2024
8066f25
if event_type is not passed, it uses the default send_event
ogabrielluiz Aug 28, 2024
a129a01
Add method to register event functions in EventManager
ogabrielluiz Aug 28, 2024
56d4ea6
update test_callback_graph
ogabrielluiz Aug 28, 2024
f81244b
Add unit tests for EventManager in test_event_manager.py
ogabrielluiz Aug 28, 2024
830de83
revert chatOutput change
ogabrielluiz Aug 29, 2024
a31d2bd
Add validation for event function in EventManager
ogabrielluiz Aug 29, 2024
17926da
Add tests for EventManager's event function validation logic
ogabrielluiz Aug 29, 2024
086c430
Add type ignore comment to lambda function in test_event_manager.py
ogabrielluiz Aug 29, 2024
ef55e58
refactor: Update EventManager class to use Protocol for event callbacks
ogabrielluiz Aug 29, 2024
9dfc45d
refactor(event_manager.py): simplify event registration and validatio…
ogabrielluiz Aug 29, 2024
39866d1
refactor(chat.py): standardize event_manager method calls by using ke…
ogabrielluiz Aug 29, 2024
f16e6f5
update event manager tests
ogabrielluiz Aug 29, 2024
1dc155a
Add callback validation and manager parameter in EventManager
ogabrielluiz Aug 29, 2024
6484131
Add support for passing callback through the Graph in test_callback_g…
ogabrielluiz Aug 29, 2024
2439eb3
fix(event_manager.py): update EventCallback signature to include mana…
ogabrielluiz Aug 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 26 additions & 23 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
VertexBuildResponse,
VerticesOrderResponse,
)
from langflow.events.event_manager import EventManager, create_default_event_manager
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
Expand Down Expand Up @@ -204,7 +205,7 @@ async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc

async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
async def _build_vertex(vertex_id: str, graph: "Graph", event_manager: "EventManager") -> VertexBuildResponse:
flow_id_str = str(flow_id)

next_runnable_vertices = []
Expand All @@ -222,6 +223,7 @@ async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
files=files,
get_cache=chat_service.get_cache,
set_cache=chat_service.set_cache,
event_manager=event_manager,
)
result_dict = vertex_build_result.result_dict
params = vertex_build_result.params
Expand Down Expand Up @@ -316,17 +318,13 @@ async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
message = parse_exception(exc)
raise HTTPException(status_code=500, detail=message) from exc

def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
json_data = {"event": event_type, "data": value}
event_id = uuid.uuid4()
logger.debug(f"sending event {event_id}: {event_type}")
str_data = json.dumps(json_data) + "\n\n"
queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))

async def build_vertices(
vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue
vertex_id: str,
graph: "Graph",
client_consumed_queue: asyncio.Queue,
event_manager: "EventManager",
) -> None:
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph, event_manager))
try:
await build_task
except asyncio.CancelledError:
Expand All @@ -340,13 +338,15 @@ async def build_vertices(
build_data = json.loads(vertex_build_response_json)
except Exception as exc:
raise ValueError(f"Error serializing vertex build response: {exc}") from exc
send_event("end_vertex", {"build_data": build_data}, queue)
event_manager.on_end_vertex(data={"build_data": build_data})
await client_consumed_queue.get()
if vertex_build_response.valid:
if vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(build_vertices(next_vertex_id, graph, queue, client_consumed_queue))
task = asyncio.create_task(
build_vertices(next_vertex_id, graph, client_consumed_queue, event_manager)
)
tasks.append(task)
try:
await asyncio.gather(*tasks)
Expand All @@ -355,7 +355,7 @@ async def build_vertices(
task.cancel()
return

async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
async def event_generator(event_manager: EventManager, client_consumed_queue: asyncio.Queue) -> None:
if not data:
# using another thread since the DB query is I/O bound
vertices_task = asyncio.create_task(await asyncio.to_thread(build_graph_and_get_order))
Expand All @@ -366,9 +366,9 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Q
return
except Exception as e:
if isinstance(e, HTTPException):
send_event("error", {"error": str(e.detail), "statusCode": e.status_code}, queue)
event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
raise e
send_event("error", {"error": str(e)}, queue)
event_manager.on_error(data={"error": str(e)})
raise e

ids, vertices_to_run, graph = vertices_task.result()
Expand All @@ -377,16 +377,16 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Q
ids, vertices_to_run, graph = await build_graph_and_get_order()
except Exception as e:
if isinstance(e, HTTPException):
send_event("error", {"error": str(e.detail), "statusCode": e.status_code}, queue)
event_manager.on_error(data={"error": str(e.detail), "statusCode": e.status_code})
raise e
send_event("error", {"error": str(e)}, queue)
event_manager.on_error(data={"error": str(e)})
raise e
send_event("vertices_sorted", {"ids": ids, "to_run": vertices_to_run}, queue)
event_manager.on_vertices_sorted(data={"ids": ids, "to_run": vertices_to_run})
await client_consumed_queue.get()

tasks = []
for vertex_id in ids:
task = asyncio.create_task(build_vertices(vertex_id, graph, queue, client_consumed_queue))
task = asyncio.create_task(build_vertices(vertex_id, graph, client_consumed_queue, event_manager))
tasks.append(task)
try:
await asyncio.gather(*tasks)
Expand All @@ -395,8 +395,8 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Q
for task in tasks:
task.cancel()
return
send_event("end", {}, queue)
await queue.put((None, None, time.time))
event_manager.on_end(data={})
await event_manager.queue.put((None, None, time.time))

async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator:
while True:
Expand All @@ -413,7 +413,8 @@ async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio

asyncio_queue: asyncio.Queue = asyncio.Queue()
asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue()
main_task = asyncio.create_task(event_generator(asyncio_queue, asyncio_queue_client_consumed))
event_manager = create_default_event_manager(queue=asyncio_queue)
main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))

def on_disconnect():
logger.debug("Client disconnected, closing tasks")
Expand Down Expand Up @@ -639,6 +640,7 @@ async def build_vertex_stream(
flow_id_str = str(flow_id)

async def stream_vertex():
graph = None
try:
cache = await chat_service.get_cache(flow_id_str)
if not cache:
Expand Down Expand Up @@ -692,7 +694,8 @@ async def stream_vertex():
yield str(StreamData(event="error", data={"error": exc_message}))
finally:
logger.debug("Closing stream")
await chat_service.set_cache(flow_id_str, graph)
if graph:
await chat_service.set_cache(flow_id_str, graph)
yield str(StreamData(event="close", data={"message": "Stream closed"}))

return StreamingResponse(stream_vertex(), media_type="text/event-stream")
Expand Down
91 changes: 47 additions & 44 deletions src/backend/base/langflow/base/io/chat.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,70 @@
from typing import Optional, Union
from typing import AsyncIterator, Iterator, Optional, Union

from langflow.base.data.utils import IMG_FILE_TYPES, TEXT_FILE_TYPES
from langflow.custom import Component
from langflow.memory import store_message
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_USER, MESSAGE_SENDER_AI
from langflow.services.database.models.message.crud import update_message
from langflow.utils.async_helpers import run_until_complete


class ChatComponent(Component):
display_name = "Chat Component"
description = "Use as base for chat components."

def build_config(self):
return {
"input_value": {
"input_types": ["Text"],
"display_name": "Text",
"multiline": True,
},
"sender": {
"options": [MESSAGE_SENDER_AI, MESSAGE_SENDER_USER],
"display_name": "Sender Type",
"advanced": True,
},
"sender_name": {"display_name": "Sender Name", "advanced": True},
"session_id": {
"display_name": "Session ID",
"info": "If provided, the message will be stored in the memory.",
"advanced": True,
},
"return_message": {
"display_name": "Return Message",
"info": "Return the message as a Message containing the sender, sender_name, and session_id.",
"advanced": True,
},
"data_template": {
"display_name": "Data Template",
"multiline": True,
"info": "In case of Message being a Data, this template will be used to convert it to text.",
"advanced": True,
},
"files": {
"field_type": "file",
"display_name": "Files",
"file_types": TEXT_FILE_TYPES + IMG_FILE_TYPES,
"info": "Files to be sent with the message.",
"advanced": True,
},
}

# Keep this method for backward compatibility
def store_message(
self,
message: Message,
) -> list[Message]:
) -> Message:
messages = store_message(
message,
flow_id=self.graph.flow_id,
)
if len(messages) > 1:
raise ValueError("Only one message can be stored at a time.")
stored_message = messages[0]
if hasattr(self, "_event_manager") and self._event_manager and stored_message.id:
if not isinstance(message.text, str):
complete_message = self._stream_message(message, stored_message.id)
message_table = update_message(message_id=stored_message.id, message=dict(text=complete_message))
stored_message = Message(**message_table.model_dump())
self.vertex._added_message = stored_message
self.status = stored_message
return stored_message

def _process_chunk(self, chunk: str, complete_message: str, message: Message, message_id: str) -> str:
complete_message += chunk
data = {
"text": complete_message,
"chunk": chunk,
"sender": message.sender,
"sender_name": message.sender_name,
"id": str(message_id),
}
if self._event_manager:
self._event_manager.on_token(data=data)
return complete_message

async def _handle_async_iterator(self, iterator: AsyncIterator, message: Message, message_id: str) -> str:
complete_message = ""
async for chunk in iterator:
complete_message = self._process_chunk(chunk.content, complete_message, message, message_id)
return complete_message

def _stream_message(self, message: Message, message_id: str) -> str:
iterator = message.text
if not isinstance(iterator, (AsyncIterator, Iterator)):
raise ValueError("The message must be an iterator or an async iterator.")

if isinstance(iterator, AsyncIterator):
return run_until_complete(self._handle_async_iterator(iterator, message, message_id))

complete_message = ""
for chunk in iterator:
complete_message = self._process_chunk(chunk.content, complete_message, message, message_id)

self.status = messages
return messages
return complete_message

def build_with_data(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/outputs/ChatOutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langflow.io import DropdownInput, MessageTextInput, Output
from langflow.memory import store_message
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_NAME_AI, MESSAGE_SENDER_USER, MESSAGE_SENDER_AI
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI, MESSAGE_SENDER_USER


class ChatOutput(ChatComponent):
Expand Down
38 changes: 35 additions & 3 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import inspect
from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
from collections.abc import Callable
from uuid import UUID

import nanoid # type: ignore
import yaml
from pydantic import BaseModel

from langflow.events.event_manager import EventManager
from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
from langflow.schema.log import LoggableType
from langflow.schema.message import Message
from langflow.services.tracing.schema import Log
from langflow.template.field.base import UNDEFINED, Input, Output
Expand All @@ -35,6 +37,7 @@ class Component(CustomComponent):
outputs: list[Output] = []
code_class_base_inheritance: ClassVar[str] = "Component"
_output_logs: dict[str, Log] = {}
_current_output: str = ""

def __init__(self, **kwargs):
# if key starts with _ it is a config
Expand All @@ -56,6 +59,8 @@ def __init__(self, **kwargs):
self._parameters = inputs or {}
self._edges: list[EdgeData] = []
self._components: list[Component] = []
self._current_output = ""
self._event_manager: EventManager | None = None
self._state_model = None
self.set_attributes(self._parameters)
self._output_logs = {}
Expand All @@ -77,6 +82,9 @@ def __init__(self, **kwargs):
self._set_output_types()
self.set_class_code()

def set_event_manager(self, event_manager: EventManager | None = None):
self._event_manager = event_manager

def _reset_all_output_values(self):
for output in self.outputs:
setattr(output, "value", UNDEFINED)
Expand Down Expand Up @@ -597,7 +605,9 @@ async def _build_with_tracing(self):
async def _build_without_tracing(self):
return await self._build_results()

async def build_results(self):
async def build_results(
self,
):
if self._tracing_service:
return await self._build_with_tracing()
return await self._build_without_tracing()
Expand All @@ -616,6 +626,7 @@ async def _build_results(self):
):
if output.method is None:
raise ValueError(f"Output {output.name} does not have a method defined.")
self._current_output = output.name
method: Callable = getattr(self, output.method)
if output.cache and output.value != UNDEFINED:
_results[output.name] = output.value
Expand All @@ -634,6 +645,7 @@ async def _build_results(self):
result.set_flow_id(self._vertex.graph.flow_id)
_results[output.name] = result
output.value = result

custom_repr = self.custom_repr()
if custom_repr is None and isinstance(result, (dict, Data, str)):
custom_repr = result
Expand Down Expand Up @@ -661,6 +673,7 @@ async def _build_results(self):
_artifacts[output.name] = artifact
self._output_logs[output.name] = self._logs
self._logs = []
self._current_output = ""
self._artifacts = _artifacts
self._results = _results
if self._tracing_service:
Expand Down Expand Up @@ -716,6 +729,25 @@ def to_tool(self):
return ComponentTool(component=self)

def get_project_name(self):
if hasattr(self, "_tracing_service"):
if hasattr(self, "_tracing_service") and self._tracing_service:
return self._tracing_service.project_name
return "Langflow"

def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
"""
Logs a message.

Args:
message (LoggableType | list[LoggableType]): The message to log.
"""
if name is None:
name = f"Log {len(self._logs) + 1}"
log = Log(message=message, type=get_artifact_type(message), name=name)
self._logs.append(log)
if self._tracing_service and self._vertex:
self._tracing_service.add_log(trace_name=self.trace_name, log=log)
if self._event_manager is not None and self._current_output:
data = log.model_dump()
data["output"] = self._current_output
data["component_id"] = self._id
self._event_manager.on_log(data=data)
Loading
Loading