refactor!: improve the public interface of Generators (#6374)
* merge lazy import blocks

* refactor generators

* release note

* revert unrelated changes
anakin87 authored Nov 22, 2023
1 parent b751978 commit e91f7a8
Showing 7 changed files with 84 additions and 78 deletions.
6 changes: 3 additions & 3 deletions haystack/preview/components/generators/chat/openai.py
@@ -65,7 +65,7 @@ def __init__(
         model_name: str = "gpt-3.5-turbo",
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         api_base_url: str = API_BASE_URL,
-        **generation_kwargs,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates an instance of ChatGPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
@@ -110,7 +110,7 @@ def __init__(
         openai.api_key = api_key

         self.model_name = model_name
-        self.generation_kwargs = generation_kwargs
+        self.generation_kwargs = generation_kwargs or {}
         self.streaming_callback = streaming_callback

         self.api_base_url = api_base_url
@@ -133,7 +133,7 @@ def to_dict(self) -> Dict[str, Any]:
             model_name=self.model_name,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
-            **self.generation_kwargs,
+            generation_kwargs=self.generation_kwargs,
         )

     @classmethod
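With this change, OpenAI sampling options for the chat generator travel in one explicit dictionary instead of being collected through `**generation_kwargs`. A minimal sketch of the new constructor call, with illustrative placeholder values taken from the updated tests further down:

from haystack.preview.components.generators.chat.openai import GPTChatGenerator

# generation parameters are now grouped in a single, serializable dictionary
component = GPTChatGenerator(
    api_key="test-api-key",  # illustrative placeholder, not a real key
    model_name="gpt-4",
    generation_kwargs={"max_tokens": 10, "temperature": 0.5},
)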
58 changes: 33 additions & 25 deletions haystack/preview/components/generators/hugging_face_local.py
@@ -86,32 +86,34 @@ def __init__(
         device: Optional[str] = None,
         token: Optional[Union[str, bool]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        pipeline_kwargs: Optional[Dict[str, Any]] = None,
+        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
         stop_words: Optional[List[str]] = None,
     ):
         """
         :param model_name_or_path: The name or path of a Hugging Face model for text generation,
             for example, "google/flan-t5-large".
-            If the model is also specified in the `pipeline_kwargs`, this parameter will be ignored.
+            If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
         :param task: The task for the Hugging Face pipeline.
             Possible values are "text-generation" and "text2text-generation".
             Generally, decoder-only models like GPT support "text-generation",
             while encoder-decoder models like T5 support "text2text-generation".
-            If the task is also specified in the `pipeline_kwargs`, this parameter will be ignored.
+            If the task is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
             If not specified, the component will attempt to infer the task from the model name,
             calling the Hugging Face Hub API.
         :param device: The device on which the model is loaded. (e.g., "cpu", "cuda:0").
-            If `device` or `device_map` is specified in the `pipeline_kwargs`, this parameter will be ignored.
+            If `device` or `device_map` is specified in the `huggingface_pipeline_kwargs`,
+            this parameter will be ignored.
         :param token: The token to use as HTTP bearer authorization for remote files.
             If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
-            If the token is also specified in the `pipeline_kwargs`, this parameter will be ignored.
+            If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
         :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
             Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
             See Hugging Face's documentation for more information:
             - https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation
             - https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
-        :param pipeline_kwargs: Dictionary containing keyword arguments used to initialize the pipeline.
-            These keyword arguments provide fine-grained control over the pipeline.
+        :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
+            Hugging Face pipeline for text generation.
+            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
             In case of duplication, these kwargs override `model_name_or_path`, `task`, `device`, and `token` init parameters.
             See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
             for more information on the available kwargs.
@@ -125,28 +127,34 @@ def __init__(
         """
         torch_and_transformers_import.check()

-        pipeline_kwargs = pipeline_kwargs or {}
+        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
         generation_kwargs = generation_kwargs or {}

-        # check if the pipeline_kwargs contain the essential parameters
+        # check if the huggingface_pipeline_kwargs contain the essential parameters
         # otherwise, populate them with values from other init parameters
-        pipeline_kwargs.setdefault("model", model_name_or_path)
-        pipeline_kwargs.setdefault("token", token)
-        if device is not None and "device" not in pipeline_kwargs and "device_map" not in pipeline_kwargs:
-            pipeline_kwargs["device"] = device
+        huggingface_pipeline_kwargs.setdefault("model", model_name_or_path)
+        huggingface_pipeline_kwargs.setdefault("token", token)
+        if (
+            device is not None
+            and "device" not in huggingface_pipeline_kwargs
+            and "device_map" not in huggingface_pipeline_kwargs
+        ):
+            huggingface_pipeline_kwargs["device"] = device

         # task identification and validation
         if task is None:
-            if "task" in pipeline_kwargs:
-                task = pipeline_kwargs["task"]
-            elif isinstance(pipeline_kwargs["model"], str):
-                task = model_info(pipeline_kwargs["model"], token=pipeline_kwargs["token"]).pipeline_tag
+            if "task" in huggingface_pipeline_kwargs:
+                task = huggingface_pipeline_kwargs["task"]
+            elif isinstance(huggingface_pipeline_kwargs["model"], str):
+                task = model_info(
+                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
+                ).pipeline_tag

         if task not in SUPPORTED_TASKS:
             raise ValueError(
                 f"Task '{task}' is not supported. " f"The supported tasks are: {', '.join(SUPPORTED_TASKS)}."
             )
-        pipeline_kwargs["task"] = task
+        huggingface_pipeline_kwargs["task"] = task

         # if not specified, set return_full_text to False for text-generation
         # only generated text is returned (excluding prompt)
@@ -159,7 +167,7 @@ def __init__(
                 "Please specify only one of them."
             )

-        self.pipeline_kwargs = pipeline_kwargs
+        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
         self.generation_kwargs = generation_kwargs
         self.stop_words = stop_words
         self.pipeline = None
@@ -169,13 +177,13 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
         """
         Data that is sent to Posthog for usage analytics.
         """
-        if isinstance(self.pipeline_kwargs["model"], str):
-            return {"model": self.pipeline_kwargs["model"]}
-        return {"model": f"[object of type {type(self.pipeline_kwargs['model'])}]"}
+        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
+            return {"model": self.huggingface_pipeline_kwargs["model"]}
+        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

     def warm_up(self):
         if self.pipeline is None:
-            self.pipeline = pipeline(**self.pipeline_kwargs)
+            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

         if self.stop_words and self.stopping_criteria_list is None:
             stop_words_criteria = StopWordsCriteria(
@@ -187,15 +195,15 @@ def to_dict(self) -> Dict[str, Any]:
         """
         Serialize this component to a dictionary.
         """
-        pipeline_kwargs_to_serialize = deepcopy(self.pipeline_kwargs)
+        pipeline_kwargs_to_serialize = deepcopy(self.huggingface_pipeline_kwargs)

         # we don't want to serialize valid tokens
         if isinstance(pipeline_kwargs_to_serialize["token"], str):
             pipeline_kwargs_to_serialize["token"] = None

         return default_to_dict(
             self,
-            pipeline_kwargs=pipeline_kwargs_to_serialize,
+            huggingface_pipeline_kwargs=pipeline_kwargs_to_serialize,
             generation_kwargs=self.generation_kwargs,
             stop_words=self.stop_words,
         )
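The rename from `pipeline_kwargs` to `huggingface_pipeline_kwargs` makes it explicit that this dictionary is forwarded to Hugging Face's `pipeline()` factory, and, as the docstring above notes, its entries win over the convenience parameters. A hedged sketch of the new interface (model and device values are illustrative):

from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    model_name_or_path="google/flan-t5-large",
    task="text2text-generation",
    device="cpu",  # ignored here: device_map in huggingface_pipeline_kwargs takes precedence
    generation_kwargs={"max_new_tokens": 100},
    huggingface_pipeline_kwargs={"device_map": "auto"},
)
generator.warm_up()  # lazily builds the transformers pipeline from huggingface_pipeline_kwargs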
6 changes: 3 additions & 3 deletions haystack/preview/components/generators/openai.py
@@ -58,7 +58,7 @@ def __init__(
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         api_base_url: str = API_BASE_URL,
         system_prompt: Optional[str] = None,
-        **generation_kwargs,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
@@ -105,7 +105,7 @@ def __init__(
         openai.api_key = api_key

         self.model_name = model_name
-        self.generation_kwargs = generation_kwargs
+        self.generation_kwargs = generation_kwargs or {}
         self.system_prompt = system_prompt
         self.streaming_callback = streaming_callback

@@ -129,7 +129,7 @@ def to_dict(self) -> Dict[str, Any]:
             model_name=self.model_name,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
-            **self.generation_kwargs,
+            generation_kwargs=self.generation_kwargs,
             system_prompt=self.system_prompt,
         )
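For callers, the migration is mechanical: gather what used to be loose keyword arguments into one dictionary. A before/after sketch with illustrative values:

from haystack.preview.components.generators.openai import GPTGenerator

# before this commit, parameters were absorbed via **generation_kwargs:
#   generator = GPTGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5)

# after this commit, the same parameters are passed as an explicit dict:
generator = GPTGenerator(
    api_key="test-api-key",  # illustrative placeholder, not a real key
    generation_kwargs={"max_tokens": 10, "temperature": 0.5},
)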
@@ -0,0 +1,6 @@
+---
+preview:
+  - |
+    Improve the public interface of the Generators:
+    - make `generation_kwargs` a dictionary
+    - rename `pipeline_kwargs` (in `HuggingFaceLocalGenerator`) to `huggingface_pipeline_kwargs`
26 changes: 10 additions & 16 deletions test/preview/components/generators/chat/test_openai.py
@@ -88,10 +88,9 @@ def test_init_with_parameters(self):
         component = GPTChatGenerator(
             api_key="test-api-key",
             model_name="gpt-4",
-            max_tokens=10,
-            some_test_param="test-params",
             streaming_callback=default_streaming_callback,
             api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
         assert openai.api_key == "test-api-key"
         assert component.model_name == "gpt-4"
@@ -110,6 +109,7 @@ def test_to_dict_default(self):
                 "model_name": "gpt-3.5-turbo",
                 "streaming_callback": None,
                 "api_base_url": "https://api.openai.com/v1",
+                "generation_kwargs": {},
             },
         }

@@ -118,20 +118,18 @@ def test_to_dict_with_parameters(self):
         component = GPTChatGenerator(
             api_key="test-api-key",
             model_name="gpt-4",
-            max_tokens=10,
-            some_test_param="test-params",
             streaming_callback=default_streaming_callback,
             api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
         data = component.to_dict()
         assert data == {
             "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator",
             "init_parameters": {
                 "model_name": "gpt-4",
-                "max_tokens": 10,
-                "some_test_param": "test-params",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }

@@ -140,20 +138,18 @@ def test_to_dict_with_lambda_streaming_callback(self):
         component = GPTChatGenerator(
             api_key="test-api-key",
             model_name="gpt-4",
-            max_tokens=10,
-            some_test_param="test-params",
             streaming_callback=lambda x: x,
             api_base_url="test-base-url",
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
         data = component.to_dict()
         assert data == {
             "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator",
             "init_parameters": {
                 "model_name": "gpt-4",
-                "max_tokens": 10,
-                "some_test_param": "test-params",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "chat.test_openai.<lambda>",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }

@@ -164,10 +160,9 @@ def test_from_dict(self, monkeypatch):
             "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator",
             "init_parameters": {
                 "model_name": "gpt-4",
-                "max_tokens": 10,
-                "some_test_param": "test-params",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }
         component = GPTChatGenerator.from_dict(data)
@@ -184,10 +179,9 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
             "type": "haystack.preview.components.generators.chat.openai.GPTChatGenerator",
             "init_parameters": {
                 "model_name": "gpt-4",
-                "max_tokens": 10,
-                "some_test_param": "test-params",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
+                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }
         with pytest.raises(ValueError, match="GPTChatGenerator expects an OpenAI API key"):
@@ -207,7 +201,7 @@ def test_run(self, chat_messages, mock_chat_completion):

     @pytest.mark.unit
     def test_run_with_params(self, chat_messages, mock_chat_completion):
-        component = GPTChatGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5)
+        component = GPTChatGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
         response = component.run(chat_messages)

         # check that the component calls the OpenAI API with the correct parameters
@@ -288,7 +282,7 @@ def test_check_abnormal_completions(self, caplog):
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
+        component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
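As the updated tests pin down, `to_dict` now nests the generation parameters under a `generation_kwargs` key instead of splatting them into `init_parameters`, and `from_dict` restores them unchanged. A small round-trip sketch, assuming `OPENAI_API_KEY` is set in the environment (as test_from_dict arranges via monkeypatch):

from haystack.preview.components.generators.chat.openai import GPTChatGenerator

component = GPTChatGenerator(
    api_key="test-api-key",  # illustrative placeholder, not a real key
    model_name="gpt-4",
    generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
data = component.to_dict()
assert data["init_parameters"]["generation_kwargs"] == {"max_tokens": 10, "some_test_param": "test-params"}

restored = GPTChatGenerator.from_dict(data)  # reads the API key from OPENAI_API_KEY
assert restored.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}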