Skip to content

Commit

Permalink
feat: add dynamic state model creation and update (#3271)
Browse files Browse the repository at this point in the history
* feat: add initial implementation of dynamic state model creation and output getter in graph state module

* feat: implement _reset_all_output_values method to initialize component outputs in custom_component class

* feat: add state model management with lazy initialization and dynamic instance getter in custom_component class

* feat: Refactor Component class to use public method get_output_by_method

Refactor the Component class in the custom_component module to change the visibility of the method `_get_output_by_method` to public by renaming it to `get_output_by_method`. This change improves the accessibility and clarity of the method for external use.

* feat: add output setter utility to manage output values in state model properties

* feat: implement validation for methods' classes in output getter/setter utilities in state model to ensure proper structure

* feat: add state model creation from graph in state_model.py

* feat: enhance Graph class with lazy loading for state model creation from graph

* feat: add unit tests for state model creation and validation in test_state_model.py

* feat: add unit tests for state model creation and validation in test_state_model.py

* feat: add functional test for graph state update and validation in test_graph_state_model.py

* fix: update _instance_getter function to accept a parameter in component.py for state model instance retrieval

* refactor: rename test to clarify purpose in test_state_model.py for functional state update validation

* chore: import Finish constant in test_graph_state_model.py for improved clarity and usage in state model tests

* refactor: add optional validation in output getter/setter methods for improved method integrity in state model handling

* refactor: enhance state model creation with optional validation and error handling for output methods in model.py

* refactor: serialize and deserialize GraphStateModel in test_graph_state_model.py

* refactor: improve error message and add verbose mode for graph start in test_state_model.py

* refactor: remove verbose flag from graph.start in TestCreateStateModel for consistency in test_state_model.py

* refactor: disable validation when creating GraphStateModel in state_model.py for improved flexibility

* refactor: add validation documentation for method attributes in model.py to enhance code clarity and usability

* refactor: expand docstring for build_output_getter in model.py to clarify usage and validation details

* refactor: add detailed docstring for build_output_setter in model.py to improve clarity on functionality and usage scenarios

* refactor: add comprehensive docstring for create_state_model in model.py to clarify functionality and usage examples

* refactor: enhance docstring for create_state_model_from_graph in state_model.py to clarify functionality and provide examples

* test: add JSON schema validation in graph state model tests for improved structure and correctness verification

* refactor: Improve graph_state_model.json_schema unit test readability and structure.
  • Loading branch information
ogabrielluiz authored Aug 13, 2024
1 parent 2ffd723 commit c5d9cba
Show file tree
Hide file tree
Showing 7 changed files with 654 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import yaml
from pydantic import BaseModel

from langflow.graph.state.model import create_state_model
from langflow.helpers.custom import format_type
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.schema.data import Data
Expand Down Expand Up @@ -35,7 +36,7 @@ class Component(CustomComponent):
def __init__(self, **kwargs):
# if key starts with _ it is a config
# else it is an input

self._reset_all_output_values()
inputs = {}
config = {}
for key, value in kwargs.items():
Expand All @@ -50,6 +51,7 @@ def __init__(self, **kwargs):
self._parameters = inputs or {}
self._edges: list[EdgeData] = []
self._components: list[Component] = []
self._state_model = None
self.set_attributes(self._parameters)
self._output_logs = {}
config = config or {}
Expand All @@ -70,6 +72,30 @@ def __init__(self, **kwargs):
self._set_output_types()
self.set_class_code()

def _reset_all_output_values(self):
for output in self.outputs:
setattr(output, "value", UNDEFINED)

def _build_state_model(self):
if self._state_model:
return self._state_model
name = self.name or self.__class__.__name__
model_name = f"{name}StateModel"
fields = {}
for output in self.outputs:
fields[output.name] = getattr(self, output.method)
self._state_model = create_state_model(model_name=model_name, **fields)
return self._state_model

def get_state_model_instance_getter(self):
state_model = self._build_state_model()

def _instance_getter(_):
return state_model()

_instance_getter.__annotations__["return"] = state_model
return _instance_getter

def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
Expand Down Expand Up @@ -247,7 +273,7 @@ def _set_output_types(self):
output.add_types(return_types)
output.set_selected()

def _get_output_by_method(self, method: Callable):
def get_output_by_method(self, method: Callable):
# method is a callable and output.method is a string
# we need to find the output that has the same method
output = next((output for output in self.outputs if output.method == method.__name__), None)
Expand All @@ -268,7 +294,7 @@ def _method_is_valid_output(self, method: Callable):
method_is_output = (
hasattr(method, "__self__")
and isinstance(method.__self__, Component)
and method.__self__._get_output_by_method(method)
and method.__self__.get_output_by_method(method)
)
return method_is_output

Expand Down Expand Up @@ -298,7 +324,7 @@ def _get_or_create_input(self, key):
def _connect_to_component(self, key, value, _input):
component = value.__self__
self._components.append(component)
output = component._get_output_by_method(value)
output = component.get_output_by_method(value)
self._add_edge(component, key, output, _input)

def _add_edge(self, component, key, output, _input):
Expand Down
8 changes: 8 additions & 0 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.schema import GraphData, GraphDump, VertexBuildResult
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.state_model import create_state_model_from_graph
from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex, VertexStates
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
log_config = {"disable": False}
configure(**log_config)
self._start = start
self._state_model = None
self._end = end
self._prepared = False
self._runs = 0
Expand Down Expand Up @@ -109,6 +111,12 @@ def __init__(
if (start is not None and end is None) or (start is None and end is not None):
raise ValueError("You must provide both input and output components")

@property
def state_model(self):
if not self._state_model:
self._state_model = create_state_model_from_graph(self)
return self._state_model

def __add__(self, other):
if not isinstance(other, Graph):
raise TypeError("Can only add Graph objects")
Expand Down
67 changes: 67 additions & 0 deletions src/backend/base/langflow/graph/graph/state_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import re

from langflow.graph.state.model import create_state_model
from langflow.helpers.base_model import BaseModel


def camel_to_snake(camel_str: str) -> str:
snake_str = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()
return snake_str


def create_state_model_from_graph(graph: BaseModel) -> type[BaseModel]:
"""
Create a Pydantic state model from a graph representation.
This function generates a Pydantic model that represents the state of an entire graph.
It creates getter methods for each vertex in the graph, allowing access to the state
of individual components within the graph structure.
Args:
graph (BaseModel): The graph object from which to create the state model.
This should be a Pydantic model representing the graph structure,
with a 'vertices' attribute containing all graph vertices.
Returns:
type[BaseModel]: A dynamically created Pydantic model class representing
the state of the entire graph. This model will have properties
corresponding to each vertex in the graph, with names converted from
the vertex IDs to snake case.
Raises:
ValueError: If any vertex in the graph does not have a properly initialized
component instance (i.e., if vertex._custom_component is None).
Notes:
- Each vertex in the graph must have a '_custom_component' attribute.
- The '_custom_component' must have a 'get_state_model_instance_getter' method.
- Vertex IDs are converted from camel case to snake case for the resulting model's field names.
- The resulting model uses the 'create_state_model' function with validation disabled.
Example:
>>> class Vertex(BaseModel):
... id: str
... _custom_component: Any
>>> class Graph(BaseModel):
... vertices: List[Vertex]
>>> # Assume proper setup of vertices and components
>>> graph = Graph(vertices=[...])
>>> GraphStateModel = create_state_model_from_graph(graph)
>>> graph_state = GraphStateModel()
>>> # Access component states, e.g.:
>>> print(graph_state.some_component_name)
"""
for vertex in graph.vertices:
if hasattr(vertex, "_custom_component") and vertex._custom_component is None:
raise ValueError(f"Vertex {vertex.id} does not have a component instance.")

state_model_getters = [
vertex._custom_component.get_state_model_instance_getter()
for vertex in graph.vertices
if hasattr(vertex, "_custom_component") and hasattr(vertex._custom_component, "get_state_model_instance_getter")
]
fields = {
camel_to_snake(vertex.id): state_model_getter
for vertex, state_model_getter in zip(graph.vertices, state_model_getters)
}
return create_state_model(model_name="GraphStateModel", validate=False, **fields)
Empty file.
Loading

0 comments on commit c5d9cba

Please sign in to comment.