diff --git a/haystack/preview/components/joiners/__init__.py b/haystack/preview/components/joiners/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py
new file mode 100644
index 0000000000..e819e41b99
--- /dev/null
+++ b/haystack/preview/components/joiners/join.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, Type
+from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError, ComponentError
+from haystack.preview.utils import marshal_type, unmarshal_type
+
+
+@component
+class Join:
+    """
+    A simple component that joins together a group of inputs of the same type. Works with every type that supports
+    the + operator for joining, such as lists, strings, etc.
+    """
+
+    def __init__(self, inputs_count: int, inputs_type: Type):
+        """
+        :param inputs_count: The number of inputs to expect.
+        :param inputs_type: The type of the inputs. Every type that supports the + operator works.
+        :raises ValueError: If inputs_count is lower than 1.
+        """
+        if inputs_count < 1:
+            raise ValueError("inputs_count must be at least 1")
+        self.inputs_count = inputs_count
+        self.inputs_type = inputs_type
+        component.set_input_types(self, **{f"input_{i}": inputs_type for i in range(inputs_count)})
+        component.set_output_types(self, output=inputs_type)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary. The inputs type is stored as a string through `marshal_type`.
+        """
+        return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=marshal_type(self.inputs_type))
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Join":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data: A dictionary in the format produced by `to_dict`. It is not modified.
+        :raises DeserializationError: If the inputs_type parameter is missing or can't be resolved to a type.
+        """
+        init_parameters = data.get("init_parameters", {})
+        if "inputs_type" not in init_parameters:
+            raise DeserializationError("The inputs_type parameter for Join is missing.")
+        # Work on a copy so that the dictionary passed by the caller is left untouched.
+        init_parameters = {**init_parameters, "inputs_type": unmarshal_type(init_parameters["inputs_type"])}
+        return default_from_dict(cls, {**data, "init_parameters": init_parameters})
+
+    def run(self, **kwargs):
+        """
+        Joins together a group of inputs of the same type. Works with every type that supports the + operator,
+        such as lists, strings, etc. The inputs are joined in the order in which they are passed.
+
+        :raises ValueError: If the number of inputs differs from inputs_count.
+        :raises ComponentError: If the inputs can't be joined with the + operator.
+        """
+        if len(kwargs) != self.inputs_count:
+            raise ValueError(f"Join expected {self.inputs_count} inputs, but got {len(kwargs)}")
+
+        values = list(kwargs.values())
+        output = values[0]
+        try:
+            for value in values[1:]:
+                # Not `+=`: that would extend mutable inputs (e.g. lists) in place, changing the caller's data.
+                output = output + value
+        except TypeError as err:
+            raise ComponentError(
+                f"Join expected inputs of a type that supports the + operator, but got: {[type(v) for v in values]}"
+            ) from err
+        return {"output": output}
diff --git a/haystack/preview/utils/__init__.py b/haystack/preview/utils/__init__.py
index a84ea468e2..9b13aa8604 100644
--- a/haystack/preview/utils/__init__.py
+++ b/haystack/preview/utils/__init__.py
@@ -1,3 +1,4 @@
 from haystack.preview.utils.expit import expit
 from haystack.preview.utils.requests_utils import request_with_retry
 from haystack.preview.utils.filters import document_matches_filter
+from haystack.preview.utils.marshalling import marshal_type, unmarshal_type
diff --git a/haystack/preview/utils/marshalling.py b/haystack/preview/utils/marshalling.py
new file mode 100644
index 0000000000..1248b3cac0
--- /dev/null
+++ b/haystack/preview/utils/marshalling.py
@@ -0,0 +1,50 @@
+from typing import Type
+import builtins
+import sys
+
+from haystack.preview import DeserializationError
+
+
+def marshal_type(type_: Type) -> str:
+    """
+    Given a type, return a string representation that can be unmarshalled.
+
+    Builtin types are represented by their bare name, all the others by their module path followed by their name.
+
+    :param type_: The type.
+    :return: Its string representation.
+    """
+    module = type_.__module__
+    if module == "builtins":
+        return type_.__name__
+    return f"{module}.{type_.__name__}"
+
+
+def unmarshal_type(type_name: str) -> Type:
+    """
+    Given the string representation of a type, return the type itself.
+
+    Names without dots are looked up among the builtins. Dotted names are looked up in `sys.modules`, so the module
+    defining the type must have been imported already: this function does not import anything.
+
+    :param type_name: The string representation of the type.
+    :return: The type itself.
+    :raises DeserializationError: If the builtin, the module or the type can't be located.
+    """
+    if "." not in type_name:
+        type_ = getattr(builtins, type_name, None)
+        if not type_:
+            raise DeserializationError(f"Could not locate builtin called '{type_name}'")
+        return type_
+
+    module_name, _, type_name = type_name.rpartition(".")
+
+    module = sys.modules.get(module_name)
+    if not module:
+        raise DeserializationError(f"Could not locate the module '{module_name}'")
+
+    type_ = getattr(module, type_name, None)
+    if not type_:
+        raise DeserializationError(f"Could not locate the type '{type_name}'")
+
+    return type_
diff --git a/releasenotes/notes/join-lists-1307f7872a37e238.yaml b/releasenotes/notes/join-lists-1307f7872a37e238.yaml
new file mode 100644
index 0000000000..06f96b0add
--- /dev/null
+++ b/releasenotes/notes/join-lists-1307f7872a37e238.yaml
@@ -0,0 +1,3 @@
+---
+preview:
+  - Add `Join`, a small component that can be used to join lists and other types supporting the + operator.
diff --git a/test/preview/components/joiner/test_join.py b/test/preview/components/joiner/test_join.py
new file mode 100644
index 0000000000..481c04e392
--- /dev/null
+++ b/test/preview/components/joiner/test_join.py
@@ -0,0 +1,56 @@
+from typing import List
+
+import pytest
+
+from haystack.preview import ComponentError
+from haystack.preview.components.joiners.join import Join
+
+
+class TestJoin:
+    @pytest.mark.unit
+    def test_join_to_dict(self):
+        comp = Join(inputs_count=2, inputs_type=str)
+        assert comp.to_dict() == {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}}
+
+    @pytest.mark.unit
+    def test_join_from_dict(self):
+        data = {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}}
+        comp = Join.from_dict(data)
+        assert comp.inputs_count == 2
+        assert comp.inputs_type == str
+
+    @pytest.mark.unit
+    def test_join_list(self):
+        comp = Join(inputs_count=2, inputs_type=List[int])
+        output = comp.run(input_0=[1, 2], input_1=[3, 4])
+        assert output == {"output": [1, 2, 3, 4]}
+
+    @pytest.mark.unit
+    def test_join_list_does_not_modify_inputs(self):
+        comp = Join(inputs_count=2, inputs_type=List[int])
+        first = [1, 2]
+        comp.run(input_0=first, input_1=[3, 4])
+        assert first == [1, 2]
+
+    @pytest.mark.unit
+    def test_join_str(self):
+        comp = Join(inputs_count=2, inputs_type=str)
+        output = comp.run(input_0="hello", input_1="test")
+        assert output == {"output": "hellotest"}
+
+    @pytest.mark.unit
+    def test_join_one_input(self):
+        comp = Join(inputs_count=1, inputs_type=str)
+        output = comp.run(input_0="hello")
+        assert output == {"output": "hello"}
+
+    @pytest.mark.unit
+    def test_join_incompatible_inputs(self):
+        comp = Join(inputs_count=2, inputs_type=str)
+        with pytest.raises(ComponentError):
+            comp.run(input_0="hello", input_1=1)
+
+    @pytest.mark.unit
+    def test_join_zero_input(self):
+        with pytest.raises(ValueError):
+            Join(inputs_count=0, inputs_type=str)
diff --git a/test/preview/utils/test_marshalling.py b/test/preview/utils/test_marshalling.py
new file mode 100644
index 0000000000..c44d26c373
--- /dev/null
+++ b/test/preview/utils/test_marshalling.py
@@ -0,0 +1,37 @@
+import pytest
+
+from haystack.preview import Document, DeserializationError
+from haystack.preview.utils.marshalling import marshal_type, unmarshal_type
+
+
+TYPE_STRING_PAIRS = [(int, "int"), (Document, "haystack.preview.dataclasses.document.Document")]
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS)
+def test_marshal_type(type_, string):
+    assert marshal_type(type_) == string
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS)
+def test_unmarshal_type(type_, string):
+    assert unmarshal_type(string) == type_
+
+
+@pytest.mark.unit
+def test_unmarshal_type_missing_builtin():
+    with pytest.raises(DeserializationError, match="Could not locate builtin called 'something'"):
+        unmarshal_type("something")
+
+
+@pytest.mark.unit
+def test_unmarshal_type_missing_module():
+    with pytest.raises(DeserializationError, match="Could not locate the module 'something'"):
+        unmarshal_type("something.int")
+
+
+@pytest.mark.unit
+def test_unmarshal_type_missing_type():
+    with pytest.raises(DeserializationError, match="Could not locate the type 'Documentttt'"):
+        unmarshal_type("haystack.preview.dataclasses.document.Documentttt")