Skip to content

Commit

Permalink
test: Refactor some retriever tests into unit tests (#5306)
Browse files Browse the repository at this point in the history
* Modify and reactivate two unit tests

* Refactor openai embedding tests into unit tests

* Update test_retriever.py

* Changing tests
  • Loading branch information
sjrl authored Jul 11, 2023
1 parent 514f93a commit 22750d3
Showing 1 changed file with 56 additions and 35 deletions.
91 changes: 56 additions & 35 deletions test/nodes/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.nodes.retriever.multimodal import MultiModalRetriever
from haystack.nodes.retriever._openai_encoder import _OpenAIEmbeddingEncoder

from ..conftest import MockBaseRetriever, fail_at_version

Expand Down Expand Up @@ -120,21 +121,24 @@ def test_retrieval_with_filters(retriever_with_docs: BaseRetriever, document_sto
assert len(result) == 0


def test_tfidf_retriever_multiple_indexes():
@pytest.mark.unit
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_tfidf_retriever_multiple_indexes(document_store: BaseDocumentStore):
docs_index_0 = [Document(content="test_1"), Document(content="test_2"), Document(content="test_3")]
docs_index_1 = [Document(content="test_4"), Document(content="test_5")]
ds = InMemoryDocumentStore(index="index_0")
tfidf_retriever = TfidfRetriever(document_store=ds)
tfidf_retriever = TfidfRetriever(document_store=document_store)

ds.write_documents(docs_index_0)
tfidf_retriever.fit(ds, index="index_0")
ds.write_documents(docs_index_1, index="index_1")
tfidf_retriever.fit(ds, index="index_1")
document_store.write_documents(docs_index_0, index="index_0")
tfidf_retriever.fit(document_store, index="index_0")
document_store.write_documents(docs_index_1, index="index_1")
tfidf_retriever.fit(document_store, index="index_1")

assert tfidf_retriever.document_counts["index_0"] == ds.get_document_count(index="index_0")
assert tfidf_retriever.document_counts["index_1"] == ds.get_document_count(index="index_1")
assert tfidf_retriever.document_counts["index_0"] == document_store.get_document_count(index="index_0")
assert tfidf_retriever.document_counts["index_1"] == document_store.get_document_count(index="index_1")


@pytest.mark.unit
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_retrieval_empty_query(document_store: BaseDocumentStore):
# test with empty query using the run() method
mock_document = Document(id="0", content="test")
Expand Down Expand Up @@ -318,34 +322,51 @@ def test_retribert_embedding(document_store, retriever, docs_with_ids):
assert isclose(embedding[0], expected_value, rel_tol=0.001)


def test_openai_embedding_retriever_selection():
# OpenAI released (Dec 2022) a unifying embedding model called text-embedding-ada-002
# make sure that we can use it with the retriever selection
er = EmbeddingRetriever(embedding_model="text-embedding-ada-002", document_store=None)
assert er.model_format == "openai"
assert er.embedding_encoder.query_encoder_model == "text-embedding-ada-002"
assert er.embedding_encoder.doc_encoder_model == "text-embedding-ada-002"
assert er.api_base == "https://api.openai.com/v1"

# but also support old ada and other text-search-<modelname>-*-001 models
er = EmbeddingRetriever(embedding_model="ada", document_store=None)
assert er.model_format == "openai"
assert er.embedding_encoder.query_encoder_model == "text-search-ada-query-001"
assert er.embedding_encoder.doc_encoder_model == "text-search-ada-doc-001"
assert er.api_base == "https://api.openai.com/v1"

# but also support old babbage and other text-search-<modelname>-*-001 models
er = EmbeddingRetriever(embedding_model="babbage", document_store=None)
assert er.model_format == "openai"
assert er.embedding_encoder.query_encoder_model == "text-search-babbage-query-001"
assert er.embedding_encoder.doc_encoder_model == "text-search-babbage-doc-001"
@pytest.mark.unit
def test_openai_embedding_retriever_model_format():
# support text-embedding-ada-002
assert (
EmbeddingRetriever._infer_model_format(model_name_or_path="text-embedding-ada-002", use_auth_token=None)
== "openai"
)

# support old ada and other text-search-<modelname>-*-001 models
assert EmbeddingRetriever._infer_model_format(model_name_or_path="ada", use_auth_token=None) == "openai"

# support old babbage and other text-search-<modelname>-*-001 models
assert EmbeddingRetriever._infer_model_format(model_name_or_path="babbage", use_auth_token=None) == "openai"

# make sure that we can handle potential unreleased models
assert (
EmbeddingRetriever._infer_model_format(model_name_or_path="text-embedding-babbage-002", use_auth_token=None)
== "openai"
)


@pytest.mark.unit
def test_openai_encoder_setup_encoding_models():
with patch("haystack.nodes.retriever._openai_encoder._OpenAIEmbeddingEncoder.__init__") as mock_encoder_init:
mock_encoder_init.return_value = None
encoder = _OpenAIEmbeddingEncoder(retriever=None) # type: ignore

encoder._setup_encoding_models(model_class="ada", model_name="text-embedding-ada-002", max_seq_len=512)
assert encoder.query_encoder_model == "text-embedding-ada-002"
assert encoder.doc_encoder_model == "text-embedding-ada-002"

# support old ada and other text-search-<modelname>-*-001 models
encoder._setup_encoding_models(model_class="ada", model_name="ada", max_seq_len=512)
assert encoder.query_encoder_model == "text-search-ada-query-001"
assert encoder.doc_encoder_model == "text-search-ada-doc-001"

# support old babbage and other text-search-<modelname>-*-001 models
encoder._setup_encoding_models(model_class="babbage", model_name="babbage", max_seq_len=512)
assert encoder.query_encoder_model == "text-search-babbage-query-001"
assert encoder.doc_encoder_model == "text-search-babbage-doc-001"

# make sure that we can handle potential unreleased models
er = EmbeddingRetriever(embedding_model="text-embedding-babbage-002", document_store=None)
assert er.model_format == "openai"
assert er.embedding_encoder.query_encoder_model == "text-embedding-babbage-002"
assert er.embedding_encoder.doc_encoder_model == "text-embedding-babbage-002"
# etc etc.
encoder._setup_encoding_models(model_class="babbage", model_name="text-embedding-babbage-002", max_seq_len=512)
assert encoder.query_encoder_model == "text-embedding-babbage-002"
assert encoder.doc_encoder_model == "text-embedding-babbage-002"


@pytest.mark.integration
Expand Down

0 comments on commit 22750d3

Please sign in to comment.