-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Join
component
#5852
Closed
Closed
feat: Join
component
#5852
Changes from 12 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
bcd67c7
fix tests
ZanSara 9253431
add component
ZanSara 4392537
add test
ZanSara bfd7717
reno
ZanSara 1a25d2f
typo
ZanSara 74eed59
add serialization
ZanSara 8192e34
stray changes
ZanSara 7b1767e
add tests
ZanSara 7499871
fix order
ZanSara 0657dc2
marshalling types
ZanSara 617fffd
marshalling tests
ZanSara a80cf9b
Merge branch 'main' into join-lists
ZanSara 28246c5
docstrings update
dfokina 684c0bb
explain the type error
ZanSara 2ef4ad9
Merge branch 'main' into join-lists
ZanSara File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Type | ||
from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError | ||
from haystack.preview.utils import marshal_type, unmarshal_type | ||
|
||
|
||
@component | ||
class Join: | ||
""" | ||
Simple component that joins together a group of inputs of the same type. Works with every type that supports | ||
the + operator for joining, like lists, strings, etc. | ||
""" | ||
|
||
def __init__(self, inputs_count: int, inputs_type: Type): | ||
""" | ||
:param inputs_count: the number of inputs to expect. | ||
:param inputs_type: the type of the inputs. Every type that supports the + operator works. | ||
""" | ||
if inputs_count < 1: | ||
raise ValueError("inputs_count must be at least 1") | ||
self.inputs_count = inputs_count | ||
self.inputs_type = inputs_type | ||
component.set_input_types(self, **{f"input_{i}": inputs_type for i in range(inputs_count)}) | ||
component.set_output_types(self, output=inputs_type) | ||
|
||
def to_dict(self): | ||
return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=marshal_type(self.inputs_type)) | ||
|
||
@classmethod | ||
def from_dict(cls, data): | ||
if not "inputs_type" in data["init_parameters"]: | ||
raise DeserializationError("The inputs_type parameter for Join is missing.") | ||
data["init_parameters"]["inputs_type"] = unmarshal_type(data["init_parameters"]["inputs_type"]) | ||
return default_from_dict(cls, data) | ||
|
||
def run(self, **kwargs): | ||
""" | ||
Joins together a group of inputs of the same type. Works with every type that supports the + operator, | ||
like lists, strings, etc. | ||
""" | ||
if len(kwargs) != self.inputs_count: | ||
raise ValueError(f"Join expected {self.inputs_count} inputs, but got {len(kwargs)}") | ||
|
||
values = list(kwargs.values()) | ||
output = values[0] | ||
for values in values[1:]: | ||
output += values | ||
ZanSara marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return {"output": output} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from haystack.preview.utils.expit import expit | ||
from haystack.preview.utils.requests_utils import request_with_retry | ||
from haystack.preview.utils.filters import document_matches_filter | ||
from haystack.preview.utils.marshalling import marshal_type, unmarshal_type |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Type | ||
import builtins | ||
import sys | ||
|
||
from haystack.preview import DeserializationError | ||
|
||
|
||
def marshal_type(type_: Type) -> str: | ||
""" | ||
Given a type, return a string representation that can be unmarshalled. | ||
|
||
:param type_: The type. | ||
:return: Its string representation. | ||
""" | ||
module = type_.__module__ | ||
if module == "builtins": | ||
return type_.__name__ | ||
return f"{module}.{type_.__name__}" | ||
|
||
|
||
def unmarshal_type(type_name: str) -> Type: | ||
""" | ||
Given the string representation of a type, return the type itself. | ||
|
||
:param type_name: The string representation of the type. | ||
:return: The type itself. | ||
""" | ||
if "." not in type_name: | ||
type_ = getattr(builtins, type_name, None) | ||
if not type_: | ||
raise DeserializationError(f"Could not locate builtin called '{type_name}'") | ||
return type_ | ||
|
||
parts = type_name.split(".") | ||
module_name = ".".join(parts[:-1]) | ||
type_name = parts[-1] | ||
|
||
module = sys.modules.get(module_name, None) | ||
if not module: | ||
raise DeserializationError(f"Could not locate the module '{module_name}'") | ||
|
||
type_ = getattr(module, type_name, None) | ||
if not type_: | ||
raise DeserializationError(f"Could not locate the type '{type_name}'") | ||
|
||
return type_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
--- | ||
preview: | ||
- Add `Join`, a small component that can be used to join lists and other types supporting the + operator. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import List | ||
|
||
import pytest | ||
|
||
from haystack.preview.components.joiners.join import Join | ||
|
||
|
||
class TestJoin: | ||
@pytest.mark.unit | ||
def test_join_to_dict(self): | ||
comp = Join(inputs_count=2, inputs_type=str) | ||
assert comp.to_dict() == {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}} | ||
|
||
@pytest.mark.unit | ||
def test_join_from_dict(self): | ||
data = {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}} | ||
comp = Join.from_dict(data) | ||
assert comp.inputs_count == 2 | ||
assert comp.inputs_type == str | ||
|
||
@pytest.mark.unit | ||
def test_join_list(self): | ||
comp = Join(inputs_count=2, inputs_type=List[int]) | ||
output = comp.run(input_0=[1, 2], input_1=[3, 4]) | ||
assert output == {"output": [1, 2, 3, 4]} | ||
|
||
@pytest.mark.unit | ||
def test_join_str(self): | ||
comp = Join(inputs_count=2, inputs_type=str) | ||
output = comp.run(input_0="hello", input_1="test") | ||
assert output == {"output": "hellotest"} | ||
|
||
@pytest.mark.unit | ||
def test_join_one_input(self): | ||
comp = Join(inputs_count=1, inputs_type=str) | ||
output = comp.run(input_0="hello") | ||
assert output == {"output": "hello"} | ||
|
||
@pytest.mark.unit | ||
def test_join_zero_input(self): | ||
with pytest.raises(ValueError): | ||
Join(inputs_count=0, inputs_type=str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pytest | ||
|
||
from haystack.preview import Document, DeserializationError | ||
from haystack.preview.utils.marshalling import marshal_type, unmarshal_type | ||
|
||
|
||
TYPE_STRING_PAIRS = [(int, "int"), (Document, "haystack.preview.dataclasses.document.Document")] | ||
|
||
|
||
@pytest.mark.unit | ||
@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS) | ||
def test_marshal_type(type_, string): | ||
assert marshal_type(type_) == string | ||
|
||
|
||
@pytest.mark.unit | ||
@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS) | ||
def test_unmarshal_type(type_, string): | ||
assert unmarshal_type(string) == type_ | ||
|
||
|
||
@pytest.mark.unit | ||
def test_unmarshal_type_missing_builtin(): | ||
with pytest.raises(DeserializationError, match="Could not locate builtin called 'something'"): | ||
unmarshal_type("something") | ||
|
||
|
||
@pytest.mark.unit | ||
def test_unmarshal_type_missing_module(): | ||
with pytest.raises(DeserializationError, match="Could not locate the module 'something'"): | ||
unmarshal_type("something.int") | ||
|
||
|
||
@pytest.mark.unit | ||
def test_unmarshal_type_missing_type(): | ||
with pytest.raises(DeserializationError, match="Could not locate the type 'Documentttt'"): | ||
unmarshal_type("haystack.preview.dataclasses.document.Documentttt") |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not set it to 2 by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why having
inputs_count
at all? The degenerate case would be having only one input, that would make the component a no-opThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry @masci I don't get your point. I can set it to be minimum two, but in principle being able to set how many inputs to expect seems way more useful than fixing it to two. There are many cases where you want to aggregate several values (just look at this pipeline)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not saying to hardcode to two, I'm talking about removing it, why do we need an upper limit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not an upper limit: you need to specify in advance how many inputs you expect. I mean, I could set it to a huge number and make them all optional, but it will make quite some noise in the error messages: if you fail to connect it, the error message will list all the possible input connections and the error becomes seriously unreadable 😅 I'd rather let the user specify how many they want and stick to that. It makes debugging easier.
If this is a big deal let's discuss it offline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it would be great to remove
input_counts
parameter if possible.