Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ConditionalRouter Haystack 2.x component #6147

Merged
merged 41 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
cc202f3
Initial commit
vblagoje Oct 20, 2023
5412f1b
First crude working example
vblagoje Oct 20, 2023
24dfd91
Small fix
vblagoje Oct 20, 2023
4d7d161
Simplify routes, change names, add pydocs
vblagoje Oct 21, 2023
6380ca3
Several improvements, add unit tests
vblagoje Oct 21, 2023
87eb309
Compile expressions eagerly
vblagoje Oct 21, 2023
7ddf516
Fix typing mistakes in unit tests
vblagoje Oct 21, 2023
fd9d913
Add release note
vblagoje Oct 21, 2023
1d57520
Use asteval
vblagoje Oct 28, 2023
f9379de
Fix pylint
vblagoje Oct 28, 2023
4186bbf
Merge branch 'main' into connection_router_v2
vblagoje Oct 28, 2023
fbbb2ff
add alternative implementation
masci Oct 29, 2023
b3d8a6b
Retire Router, keep ConditionalRouter
vblagoje Nov 1, 2023
b6fdd56
Add optional output_slot route field
vblagoje Nov 2, 2023
1290c1e
Minor test fix
vblagoje Nov 6, 2023
45b7de9
Rename output_slot to output_name, update everything
vblagoje Nov 7, 2023
88d8b50
Update haystack/preview/components/routers/conditional_router.py
vblagoje Nov 7, 2023
72e79ce
Update haystack/preview/components/routers/conditional_router.py
vblagoje Nov 7, 2023
e47b432
PR review update
vblagoje Nov 7, 2023
b93b6c5
Variable typo
vblagoje Nov 7, 2023
3b47d0e
Add (de)serialization
vblagoje Nov 8, 2023
be3d00f
Improve (de)serialization, ensure it works on all Python >= 3.8 runtimes
vblagoje Nov 8, 2023
6e83b61
Fix serialize, make it idempotent
vblagoje Nov 11, 2023
f8370d1
Rename (de)serialize methods to (de)serialize_type
vblagoje Nov 13, 2023
0f4d0f3
Merge branch 'main' into connection_router_v2
vblagoje Nov 13, 2023
d3593e5
Merge branch 'main' into connection_router_v2
vblagoje Nov 17, 2023
c9d2e2e
Update haystack/preview/components/routers/conditional_router.py
vblagoje Nov 17, 2023
be65d7b
Update haystack/preview/components/routers/conditional_router.py
vblagoje Nov 17, 2023
e7add96
Update haystack/preview/components/routers/conditional_router.py
vblagoje Nov 17, 2023
ddda60b
Update haystack/preview/components/routers/conditional_router.py
vblagoje Nov 17, 2023
0eeabc6
Simplify - make all 4 router fileds mandatory
vblagoje Nov 17, 2023
b7bbe26
Black __init__.py
vblagoje Nov 17, 2023
be1fa40
More pydocs
vblagoje Nov 17, 2023
115241e
Minor final touches
vblagoje Nov 17, 2023
fb3f3d8
Remote unit test markers
vblagoje Nov 17, 2023
fb61b11
Remove unit test markers
vblagoje Nov 17, 2023
af97e65
Merge branch 'main' into connection_router_v2
vblagoje Nov 17, 2023
332aa5f
Merge branch 'main' into connection_router_v2
vblagoje Nov 21, 2023
208b2c6
Merge branch 'connection_router_v2' of https://github.com/deepset-ai/…
vblagoje Nov 21, 2023
7eb0943
Improve (de)serialization, handle nested generics
vblagoje Nov 21, 2023
90b68a3
lg update
dfokina Nov 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions haystack/preview/components/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from haystack.preview.components.routers.document_joiner import DocumentJoiner
from haystack.preview.components.routers.file_type_router import FileTypeRouter
from haystack.preview.components.routers.metadata_router import MetadataRouter
from haystack.preview.components.routers.conditional_router import ConditionalRouter
from haystack.preview.components.routers.text_language_router import TextLanguageRouter


__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter"]
__all__ = ["DocumentJoiner", "FileTypeRouter", "MetadataRouter", "TextLanguageRouter", "ConditionalRouter"]
347 changes: 347 additions & 0 deletions haystack/preview/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
import importlib
import inspect
import logging
import sys
from typing import List, Dict, Any, Set, get_origin

from jinja2 import meta, Environment, TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment

from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError

logger = logging.getLogger(__name__)


class NoRouteSelectedException(Exception):
"""Exception raised when no route is selected in ConditionalRouter."""


class RouteConditionException(Exception):
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""


def serialize_type(target: Any) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this one (and it's sibling function deserialize_type) into an external module, so it can be reused. I think it will be handy for other components too.

"""
Serializes a type or an instance to its string representation, including the module name.

This function handles types, instances of types, and special typing objects.
It assumes that non-typing objects will have a '__name__' attribute and raises
an error if a type cannot be serialized.

:param target: The object to serialize, can be an instance or a type.
:type target: Any
:return: The string representation of the type.
:raises ValueError: If the type cannot be serialized.
"""
# If the target is a string and contains a dot, treat it as an already serialized type
if isinstance(target, str) and "." in target:
return target

# Determine if the target is a type or an instance of a typing object
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
type_obj = target if is_type_or_typing else type(target)
module = inspect.getmodule(type_obj)
type_obj_repr = repr(type_obj)

if type_obj_repr.startswith("typing."):
# e.g., typing.List[int] -> List[int], we'll add the module below
type_name = type_obj_repr.split(".", 1)[1]
elif hasattr(type_obj, "__name__"):
type_name = type_obj.__name__
else:
# If type cannot be serialized, raise an error
raise ValueError(f"Could not serialize type: {type_obj_repr}")

# Construct the full path with module name if available
if module and hasattr(module, "__name__"):
if module.__name__ == "builtins":
# omit the module name for builtins, it just clutters the output
# e.g. instead of 'builtins.str', we'll just return 'str'
full_path = type_name
else:
full_path = f"{module.__name__}.{type_name}"
else:
full_path = type_name

return full_path


def deserialize_type(type_str: str) -> Any:
"""
Deserializes a type given its full import path as a string, including nested generic types.

This function will dynamically import the module if it's not already imported
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.

:param type_str: The string representation of the type's full import path.
:return: The deserialized type object.
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
"""

def parse_generic_args(args_str):
args = []
bracket_count = 0
current_arg = ""

for char in args_str:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1

if char == "," and bracket_count == 0:
args.append(current_arg.strip())
current_arg = ""
else:
current_arg += char

if current_arg:
args.append(current_arg.strip())

return args

if "[" in type_str and type_str.endswith("]"):
# Handle generics
main_type_str, generics_str = type_str.split("[", 1)
generics_str = generics_str[:-1]

main_type = deserialize_type(main_type_str)
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))

# Reconstruct
return main_type[generic_args]

else:
# Handle non-generics
parts = type_str.split(".")
module_name = ".".join(parts[:-1]) or "builtins"
type_name = parts[-1]

module = sys.modules.get(module_name)
if not module:
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise DeserializationError(f"Could not import the module: {module_name}") from e

deserialized_type = getattr(module, type_name, None)
if not deserialized_type:
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

return deserialized_type


@component
class ConditionalRouter:
"""
ConditionalRouter in Haystack 2.x pipelines is designed to manage data routing based on specific conditions.
This is achieved by defining a list named 'routes'. Each element in this list is a dictionary representing a
single route.

A route dictionary comprises four key elements:
- 'condition': A Jinja2 string expression that determines if the route is selected.
- 'output': A Jinja2 expression defining the route's output value.
- 'output_type': The type of the output data (e.g., str, List[int]).
- 'output_name': The name under which the `output` value of the route is published. This name is used to connect
the router to other components in the pipeline.

Here's an example:

```python
from haystack.preview.components.routers import ConditionalRouter

routes = [
{
"condition": "{{streams|length > 2}}",
"output": "{{streams}}",
"output_name": "enough_streams",
"output_type": List[int],
},
{
"condition": "{{streams|length <= 2}}",
"output": "{{streams}}",
"output_name": "insufficient_streams",
"output_type": List[int],
},
]
Comment on lines +153 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, these are dictionary with these fixed 4 keys. How about a small dataclass to help with code completion?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then @silvanocerza will tell me - why did you make a data class for this thing 🤣 🤣 Perhaps in the next iteration, final release!

router = ConditionalRouter(routes)
# When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
result = router.run(**kwargs)
assert result == {"enough_streams": [1, 2, 3]}
```

In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the
stream count exceeds two. Conversely, the second route directs 'streams' to 'insufficient_streams' when there
are two or fewer streams.

In the pipeline setup, the router is connected to other components using the output names. For example, the
'enough_streams' output might be connected to another component that processes the streams, while the
'insufficient_streams' output might be connected to a component that fetches more streams, and so on.

Here is a pseudocode example of a pipeline that uses the ConditionalRouter and routes fetched ByteStreams to
different components depending on the number of streams fetched:

```
from typing import List
from haystack import Pipeline
from haystack.preview.dataclasses import ByteStream
from haystack.preview.components.routers import ConditionalRouter

routes = [
{
"condition": "{{streams|length > 2}}",
"output": "{{streams}}",
"output_name": "enough_streams",
"output_type": List[ByteStream],
},
{
"condition": "{{streams|length <= 2}}",
"output": "{{streams}}",
"output_name": "insufficient_streams",
"output_type": List[ByteStream],
},
]

pipe = Pipeline()
pipe.add_component("router", router)
...
pipe.connect("router.enough_streams", "some_component_a.streams")
pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
...
```
"""

def __init__(self, routes: List[Dict]):
"""
Initializes the ConditionalRouter with a list of routes detailing the conditions for routing.

:param routes: A list of dictionaries, each defining a route with a boolean condition expression
('condition'), an output value ('output'), the output type ('output_type') and
('output_name') that defines the output name for the variable defined in 'output'.
"""
self._validate_routes(routes)
self.routes: List[dict] = routes

# Create a Jinja native environment to inspect variables in the condition templates
env = NativeEnvironment()

# Inspect the routes to determine input and output types.
input_types: Set[str] = set() # let's just store the name, type will always be Any
output_types: Dict[str, str] = {}

for route in routes:
# extract inputs
route_input_names = self._extract_variables(env, [route["output"], route["condition"]])
input_types.update(route_input_names)

# extract outputs
output_types.update({route["output_name"]: route["output_type"]})

component.set_input_types(self, **{var: Any for var in input_types})
component.set_output_types(self, **output_types)

def to_dict(self) -> Dict[str, Any]:
for route in self.routes:
# output_type needs to be serialized to a string
route["output_type"] = serialize_type(route["output_type"])

return default_to_dict(self, routes=self.routes)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
init_params = data.get("init_parameters", {})
routes = init_params.get("routes")
for route in routes:
# output_type needs to be deserialized from a string to a type
route["output_type"] = deserialize_type(route["output_type"])
return default_from_dict(cls, data)

def run(self, **kwargs):
"""
Executes the routing logic by evaluating the specified boolean condition expressions
for each route in the order they are listed. The method directs the flow
of data to the output specified in the first route, whose expression
evaluates to True. If no route's expression evaluates to True, an exception
is raised.

:param kwargs: A dictionary containing the pipeline variables, which should
include all variables used in the "condition" templates.

:return: A dictionary containing the output and the corresponding result,
based on the first route whose expression evaluates to True.

:raises NoRouteSelectedException: If no route's expression evaluates to True.
"""
# Create a Jinja native environment to evaluate the condition templates as Python expressions
env = NativeEnvironment()

for route in self.routes:
try:
t = env.from_string(route["condition"])
if t.render(**kwargs):
# We now evaluate the `output` expression to determine the route output
t_output = env.from_string(route["output"])
output = t_output.render(**kwargs)
# and return the output as a dictionary under the output_name key
return {route["output_name"]: output}
except Exception as e:
raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e

raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting: why failing instead of dropping the input (with a loud log if necessary)? I'd say we should at least give the option to either fail or drop the value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know; it's just a simple solution for now; let's put it in the hands of users, and we'll see what they say. If we add an option, it is yet another variable to turn on/off, describe, test, confuse people, and from my perspective - unnecessary.


def _validate_routes(self, routes: List[Dict]):
"""
Validates a list of routes.

:param routes: A list of routes.
:type routes: List[Dict]
"""
env = NativeEnvironment()
for route in routes:
try:
keys = set(route.keys())
except AttributeError:
raise ValueError(f"Route must be a dictionary, got: {route}")

mandatory_fields = {"condition", "output", "output_type", "output_name"}
has_all_mandatory_fields = mandatory_fields.issubset(keys)
if not has_all_mandatory_fields:
raise ValueError(
f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
)
for field in ["condition", "output"]:
if not self._validate_template(env, route[field]):
raise ValueError(f"Invalid template for field '{field}': {route[field]}")

def _extract_variables(self, env: NativeEnvironment, templates: List[str]) -> Set[str]:
"""
Extracts all variables from a list of Jinja template strings.

:param env: A Jinja environment.
:type env: Environment
:param templates: A list of Jinja template strings.
:type templates: List[str]
:return: A set of variable names.
"""
variables = set()
for template in templates:
ast = env.parse(template)
variables.update(meta.find_undeclared_variables(ast))
return variables

def _validate_template(self, env: Environment, template_text: str):
"""
Validates a template string by parsing it with Jinja.

:param env: A Jinja environment.
:type env: Environment
:param template_text: A Jinja template string.
:type template_text: str
:return: True if the template is valid, False otherwise.
"""
try:
env.parse(template_text)
return True
except TemplateSyntaxError:
return False
6 changes: 6 additions & 0 deletions releasenotes/notes/add-router-f1f0cec79b1efe9a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
preview:
- |
Add `ConditionalRouter` component to enhance the conditional pipeline routing capabilities.
The `ConditionalRouter` component orchestrates the flow of data by evaluating specified route conditions
to determine the appropriate route among a set of provided route alternatives.
Loading