Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dynamic state model creation and update #3271

Merged
merged 27 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4e69493
feat: add initial implementation of dynamic state model creation and …
ogabrielluiz Aug 9, 2024
d479ccb
feat: implement _reset_all_output_values method to initialize compone…
ogabrielluiz Aug 9, 2024
292a42c
feat: add state model management with lazy initialization and dynamic…
ogabrielluiz Aug 9, 2024
4594e34
feat: Refactor Component class to use public method get_output_by_method
ogabrielluiz Aug 9, 2024
35ccaae
feat: add output setter utility to manage output values in state mode…
ogabrielluiz Aug 9, 2024
caa6391
feat: implement validation for methods' classes in output getter/sett…
ogabrielluiz Aug 9, 2024
4b8daff
feat: add state model creation from graph in state_model.py
ogabrielluiz Aug 9, 2024
da06d07
feat: enhance Graph class with lazy loading for state model creation …
ogabrielluiz Aug 9, 2024
03df5a7
feat: add unit tests for state model creation and validation in test_…
ogabrielluiz Aug 9, 2024
8dea891
feat: add unit tests for state model creation and validation in test_…
ogabrielluiz Aug 9, 2024
4603340
feat: add functional test for graph state update and validation in te…
ogabrielluiz Aug 9, 2024
6e35d1b
fix: update _instance_getter function to accept a parameter in compon…
ogabrielluiz Aug 9, 2024
3eeedc9
refactor: rename test to clarify purpose in test_state_model.py for f…
ogabrielluiz Aug 9, 2024
10fac16
chore: import Finish constant in test_graph_state_model.py for improv…
ogabrielluiz Aug 9, 2024
3f89717
refactor: add optional validation in output getter/setter methods for…
ogabrielluiz Aug 9, 2024
aafc563
refactor: enhance state model creation with optional validation and e…
ogabrielluiz Aug 9, 2024
023e673
refactor: serialize and deserialize GraphStateModel in test_graph_sta…
ogabrielluiz Aug 9, 2024
ce24763
refactor: improve error message and add verbose mode for graph start …
ogabrielluiz Aug 10, 2024
8b35ec9
refactor: remove verbose flag from graph.start in TestCreateStateMode…
ogabrielluiz Aug 10, 2024
38a3da2
refactor: disable validation when creating GraphStateModel in state_m…
ogabrielluiz Aug 10, 2024
7a3d989
refactor: add validation documentation for method attributes in model…
ogabrielluiz Aug 12, 2024
1c1518c
refactor: expand docstring for build_output_getter in model.py to cla…
ogabrielluiz Aug 12, 2024
f51a708
refactor: add detailed docstring for build_output_setter in model.py …
ogabrielluiz Aug 12, 2024
b4f2513
refactor: add comprehensive docstring for create_state_model in model…
ogabrielluiz Aug 12, 2024
e4d194a
refactor: enhance docstring for create_state_model_from_graph in stat…
ogabrielluiz Aug 12, 2024
182fea4
test: add JSON schema validation in graph state model tests for impro…
ogabrielluiz Aug 12, 2024
61aef47
refactor: Improve graph_state_model.json_schema unit test readability…
ogabrielluiz Aug 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import yaml
from pydantic import BaseModel

from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
Expand Down Expand Up @@ -35,7 +36,7 @@ class Component(CustomComponent):
def __init__(self, **kwargs):
# if key starts with _ it is a config
# else it is an input

self._reset_all_output_values()
inputs = {}
config = {}
for key, value in kwargs.items():
Expand All @@ -50,6 +51,7 @@ def __init__(self, **kwargs):
self._parameters = inputs or {}
self._edges: list[EdgeData] = []
self._components: list[Component] = []
self._state_model = None
self.set_attributes(self._parameters)
self._output_logs = {}
config = config or {}
Expand All @@ -70,6 +72,30 @@ def __init__(self, **kwargs):
self._set_output_types()
self.set_class_code()

def _reset_all_output_values(self):
for output in self.outputs:
setattr(output, "value", UNDEFINED)

def _build_state_model(self):
if self._state_model:
return self._state_model
name = self.name or self.__class__.__name__
model_name = f"{name}StateModel"
fields = {}
for output in self.outputs:
fields[output.name] = getattr(self, output.method)
self._state_model = create_state_model(model_name=model_name, **fields)
return self._state_model

def get_state_model_instance_getter(self):
state_model = self._build_state_model()

def _instance_getter(_):
return state_model()

_instance_getter.__annotations__["return"] = state_model
return _instance_getter

def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
Expand Down Expand Up @@ -247,7 +273,7 @@ def _set_output_types(self):
output.add_types(return_types)
output.set_selected()

def _get_output_by_method(self, method: Callable):
def get_output_by_method(self, method: Callable):
# method is a callable and output.method is a string
# we need to find the output that has the same method
output = next((output for output in self.outputs if output.method == method.__name__), None)
Expand All @@ -268,7 +294,7 @@ def _method_is_valid_output(self, method: Callable):
method_is_output = (
hasattr(method, "__self__")
and isinstance(method.__self__, Component)
and method.__self__._get_output_by_method(method)
and method.__self__.get_output_by_method(method)
)
return method_is_output

Expand Down Expand Up @@ -298,7 +324,7 @@ def _get_or_create_input(self, key):
def _connect_to_component(self, key, value, _input):
component = value.__self__
self._components.append(component)
output = component._get_output_by_method(value)
output = component.get_output_by_method(value)
self._add_edge(component, key, output, _input)

def _add_edge(self, component, key, output, _input):
Expand Down
8 changes: 8 additions & 0 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
log_config = {"disable": False}
configure(**log_config)
self._start = start
self._state_model = None
self._end = end
self._prepared = False
self._runs = 0
Expand Down Expand Up @@ -109,6 +111,12 @@ def __init__(
if (start is not None and end is None) or (start is None and end is not None):
raise ValueError("You must provide both input and output components")

@property
def state_model(self):
if not self._state_model:
self._state_model = create_state_model_from_graph(self)
return self._state_model

def __add__(self, other):
if not isinstance(other, Graph):
raise TypeError("Can only add Graph objects")
Expand Down
67 changes: 67 additions & 0 deletions src/backend/base/langflow/graph/graph/state_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import re

from langflow.graph.state.model import create_state_model
from langflow.helpers.base_model import BaseModel


def camel_to_snake(camel_str: str) -> str:
snake_str = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
return snake_str


def create_state_model_from_graph(graph: BaseModel) -> type[BaseModel]:
"""
Create a Pydantic state model from a graph representation.

This function generates a Pydantic model that represents the state of an entire graph.
It creates getter methods for each vertex in the graph, allowing access to the state
of individual components within the graph structure.

Args:
graph (BaseModel): The graph object from which to create the state model.
This should be a Pydantic model representing the graph structure,
with a 'vertices' attribute containing all graph vertices.

Returns:
type[BaseModel]: A dynamically created Pydantic model class representing
the state of the entire graph. This model will have properties
corresponding to each vertex in the graph, with names converted from
the vertex IDs to snake case.

Raises:
ValueError: If any vertex in the graph does not have a properly initialized
component instance (i.e., if vertex._custom_component is None).

Notes:
- Each vertex in the graph must have a '_custom_component' attribute.
- The '_custom_component' must have a 'get_state_model_instance_getter' method.
- Vertex IDs are converted from camel case to snake case for the resulting model's field names.
- The resulting model uses the 'create_state_model' function with validation disabled.

Example:
>>> class Vertex(BaseModel):
... id: str
... _custom_component: Any
>>> class Graph(BaseModel):
... vertices: List[Vertex]
>>> # Assume proper setup of vertices and components
>>> graph = Graph(vertices=[...])
>>> GraphStateModel = create_state_model_from_graph(graph)
>>> graph_state = GraphStateModel()
>>> # Access component states, e.g.:
>>> print(graph_state.some_component_name)
"""
for vertex in graph.vertices:
if hasattr(vertex, "_custom_component") and vertex._custom_component is None:
raise ValueError(f"Vertex {vertex.id} does not have a component instance.")

state_model_getters = [
vertex._custom_component.get_state_model_instance_getter()
for vertex in graph.vertices
if hasattr(vertex, "_custom_component") and hasattr(vertex._custom_component, "get_state_model_instance_getter")
]
fields = {
camel_to_snake(vertex.id): state_model_getter
for vertex, state_model_getter in zip(graph.vertices, state_model_getters)
}
return create_state_model(model_name="GraphStateModel", validate=False, **fields)
Empty file.
Loading