Skip to content

Commit

Permalink
chore: move HuggingFaceLocalGenerator to the generators directory (
Browse files Browse the repository at this point in the history
…#6264)

* move HuggingFaceLocalGenerator to right directory

* fix tests
  • Loading branch information
anakin87 authored Nov 9, 2023
1 parent 2b3c77e commit f95937b
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 15 deletions.
2 changes: 1 addition & 1 deletion haystack/preview/components/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator
from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
from haystack.preview.components.generators.openai import GPTGenerator

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HuggingFaceLocalGenerator:
Usage example:
```python
from haystack.preview.components.generators.hugging_face import HuggingFaceLocalGenerator
from haystack.preview.components.generators import HuggingFaceLocalGenerator
generator = HuggingFaceLocalGenerator(model="google/flan-t5-large",
task="text2text-generation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
import pytest
import torch

from haystack.preview.components.generators.hugging_face.hugging_face_local import (
HuggingFaceLocalGenerator,
StopWordsCriteria,
)
from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria


class TestHuggingFaceLocalGenerator:
@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.model_info")
@patch("haystack.preview.components.generators.hugging_face_local.model_info")
def test_init_default(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation"
generator = HuggingFaceLocalGenerator()
Expand Down Expand Up @@ -71,7 +68,7 @@ def test_init_task_in_pipeline_kwargs(self):
}

@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.model_info")
@patch("haystack.preview.components.generators.hugging_face_local.model_info")
def test_init_task_inferred_from_model_name(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation"
generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-base")
Expand Down Expand Up @@ -140,7 +137,7 @@ def test_init_fails_with_both_stopwords_and_stoppingcriteria(self):
)

@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.model_info")
@patch("haystack.preview.components.generators.hugging_face_local.model_info")
def test_to_dict_default(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation"

Expand Down Expand Up @@ -183,7 +180,7 @@ def test_to_dict_with_parameters(self):
}

@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
@patch("haystack.preview.components.generators.hugging_face_local.pipeline")
def test_warm_up(self, pipeline_mock):
generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
Expand All @@ -197,7 +194,7 @@ def test_warm_up(self, pipeline_mock):
)

@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
@patch("haystack.preview.components.generators.hugging_face_local.pipeline")
def test_warm_up_doesn_reload(self, pipeline_mock):
generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base", task="text2text-generation", token="test-token"
Expand Down Expand Up @@ -229,7 +226,7 @@ def test_run(self):
assert results == {"replies": ["Rome"]}

@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
@patch("haystack.preview.components.generators.hugging_face_local.pipeline")
def test_run_empty_prompt(self, pipeline_mock):
generator = HuggingFaceLocalGenerator(
model_name_or_path="google/flan-t5-base",
Expand Down Expand Up @@ -309,9 +306,9 @@ def test_stop_words_criteria(self):
assert present_and_continuous

@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.StopWordsCriteria")
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.StoppingCriteriaList")
@patch("haystack.preview.components.generators.hugging_face_local.pipeline")
@patch("haystack.preview.components.generators.hugging_face_local.StopWordsCriteria")
@patch("haystack.preview.components.generators.hugging_face_local.StoppingCriteriaList")
def test_warm_up_set_stopping_criteria_list(
self, pipeline_mock, stop_words_criteria_mock, stopping_criteria_list_mock
):
Expand Down

0 comments on commit f95937b

Please sign in to comment.