Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Join component #5852

Closed
wants to merge 15 commits into from
Empty file.
47 changes: 47 additions & 0 deletions haystack/preview/components/joiners/join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Type
from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
from haystack.preview.utils import marshal_type, unmarshal_type


@component
class Join:
"""
Simple component that joins together a group of inputs of the same type. Works with every type that supports
the + operator for joining, like lists, strings, etc.
"""

def __init__(self, inputs_count: int, inputs_type: Type):
"""
:param inputs_count: the number of inputs to expect.
:param inputs_type: the type of the inputs. Every type that supports the + operator works.
"""
if inputs_count < 1:
raise ValueError("inputs_count must be at least 1")
self.inputs_count = inputs_count
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not set it to 2 by default?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why having inputs_count at all? The degenerate case would be having only one input, that would make the component a no-op

Copy link
Contributor Author

@ZanSara ZanSara Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @masci I don't get your point. I can set it to be minimum two, but in principle being able to set how many inputs to expect seems way more useful than fixing it to two. There are many cases where you want to aggregate several values (just look at this pipeline)

Copy link
Contributor

@masci masci Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why having inputs_count at all?

I'm not saying to hardcode to two, I'm talking about removing it, why do we need an upper limit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not an upper limit: you need to specify in advance how many inputs you expect. I mean, I could set it to a huge number and make them all optional, but it will make quite some noise in the error messages: if you fail to connect it, the error message will list all the possible input connections and the error becomes seriously unreadable 😅 I'd rather let the user specify how many they want and stick to that. It makes debugging easier.

If this is a big deal let's discuss it offline.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it would be great to remove input_counts parameter if possible.

self.inputs_type = inputs_type
component.set_input_types(self, **{f"input_{i}": inputs_type for i in range(inputs_count)})
component.set_output_types(self, output=inputs_type)

def to_dict(self):
return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=marshal_type(self.inputs_type))

@classmethod
def from_dict(cls, data):
if not "inputs_type" in data["init_parameters"]:
raise DeserializationError("The inputs_type parameter for Join is missing.")
data["init_parameters"]["inputs_type"] = unmarshal_type(data["init_parameters"]["inputs_type"])
return default_from_dict(cls, data)

def run(self, **kwargs):
"""
Joins together a group of inputs of the same type. Works with every type that supports the + operator,
like lists, strings, etc.
"""
if len(kwargs) != self.inputs_count:
raise ValueError(f"Join expected {self.inputs_count} inputs, but got {len(kwargs)}")

values = list(kwargs.values())
output = values[0]
for values in values[1:]:
output += values
ZanSara marked this conversation as resolved.
Show resolved Hide resolved
return {"output": output}
1 change: 1 addition & 0 deletions haystack/preview/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from haystack.preview.utils.expit import expit
from haystack.preview.utils.requests_utils import request_with_retry
from haystack.preview.utils.filters import document_matches_filter
from haystack.preview.utils.marshalling import marshal_type, unmarshal_type
46 changes: 46 additions & 0 deletions haystack/preview/utils/marshalling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Type
import builtins
import sys

from haystack.preview import DeserializationError


def marshal_type(type_: Type) -> str:
"""
Given a type, return a string representation that can be unmarshalled.

:param type_: The type.
:return: Its string representation.
"""
module = type_.__module__
if module == "builtins":
return type_.__name__
return f"{module}.{type_.__name__}"


def unmarshal_type(type_name: str) -> Type:
"""
Given the string representation of a type, return the type itself.

:param type_name: The string representation of the type.
:return: The type itself.
"""
if "." not in type_name:
type_ = getattr(builtins, type_name, None)
if not type_:
raise DeserializationError(f"Could not locate builtin called '{type_name}'")
return type_

parts = type_name.split(".")
module_name = ".".join(parts[:-1])
type_name = parts[-1]

module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module '{module_name}'")

type_ = getattr(module, type_name, None)
if not type_:
raise DeserializationError(f"Could not locate the type '{type_name}'")

return type_
3 changes: 3 additions & 0 deletions releasenotes/notes/join-lists-1307f7872a37e238.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
preview:
- Add `Join`, a small component that can be used to join lists and other types supporting the + operator.
42 changes: 42 additions & 0 deletions test/preview/components/joiner/test_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import List

import pytest

from haystack.preview.components.joiners.join import Join


class TestJoin:
@pytest.mark.unit
def test_join_to_dict(self):
comp = Join(inputs_count=2, inputs_type=str)
assert comp.to_dict() == {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}}

@pytest.mark.unit
def test_join_from_dict(self):
data = {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}}
comp = Join.from_dict(data)
assert comp.inputs_count == 2
assert comp.inputs_type == str

@pytest.mark.unit
def test_join_list(self):
comp = Join(inputs_count=2, inputs_type=List[int])
output = comp.run(input_0=[1, 2], input_1=[3, 4])
assert output == {"output": [1, 2, 3, 4]}

@pytest.mark.unit
def test_join_str(self):
comp = Join(inputs_count=2, inputs_type=str)
output = comp.run(input_0="hello", input_1="test")
assert output == {"output": "hellotest"}

@pytest.mark.unit
def test_join_one_input(self):
comp = Join(inputs_count=1, inputs_type=str)
output = comp.run(input_0="hello")
assert output == {"output": "hello"}

@pytest.mark.unit
def test_join_zero_input(self):
with pytest.raises(ValueError):
Join(inputs_count=0, inputs_type=str)
37 changes: 37 additions & 0 deletions test/preview/utils/test_marshalling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from haystack.preview import Document, DeserializationError
from haystack.preview.utils.marshalling import marshal_type, unmarshal_type


TYPE_STRING_PAIRS = [(int, "int"), (Document, "haystack.preview.dataclasses.document.Document")]


@pytest.mark.unit
@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS)
def test_marshal_type(type_, string):
assert marshal_type(type_) == string


@pytest.mark.unit
@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS)
def test_unmarshal_type(type_, string):
assert unmarshal_type(string) == type_


@pytest.mark.unit
def test_unmarshal_type_missing_builtin():
with pytest.raises(DeserializationError, match="Could not locate builtin called 'something'"):
unmarshal_type("something")


@pytest.mark.unit
def test_unmarshal_type_missing_module():
with pytest.raises(DeserializationError, match="Could not locate the module 'something'"):
unmarshal_type("something.int")


@pytest.mark.unit
def test_unmarshal_type_missing_type():
with pytest.raises(DeserializationError, match="Could not locate the type 'Documentttt'"):
unmarshal_type("haystack.preview.dataclasses.document.Documentttt")