From 22750d342ca3ad44481ba8704016174e954e20f3 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 11 Jul 2023 13:36:23 +0200 Subject: [PATCH] test: Refactor some retriever tests into unit tests (#5306) * Modify and reactivate two unit tests * Refactor openai embedding tests into unit tests * Update test_retriever.py * Changing tests --- test/nodes/test_retriever.py | 91 ++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 07b86adc02..d86b30dfc1 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -33,6 +33,7 @@ from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever from haystack.nodes.retriever.multimodal import MultiModalRetriever +from haystack.nodes.retriever._openai_encoder import _OpenAIEmbeddingEncoder from ..conftest import MockBaseRetriever, fail_at_version @@ -120,21 +121,24 @@ def test_retrieval_with_filters(retriever_with_docs: BaseRetriever, document_sto assert len(result) == 0 -def test_tfidf_retriever_multiple_indexes(): +@pytest.mark.unit +@pytest.mark.parametrize("document_store", ["memory"], indirect=True) +def test_tfidf_retriever_multiple_indexes(document_store: BaseDocumentStore): docs_index_0 = [Document(content="test_1"), Document(content="test_2"), Document(content="test_3")] docs_index_1 = [Document(content="test_4"), Document(content="test_5")] - ds = InMemoryDocumentStore(index="index_0") - tfidf_retriever = TfidfRetriever(document_store=ds) + tfidf_retriever = TfidfRetriever(document_store=document_store) - ds.write_documents(docs_index_0) - tfidf_retriever.fit(ds, index="index_0") - ds.write_documents(docs_index_1, index="index_1") - tfidf_retriever.fit(ds, index="index_1") + document_store.write_documents(docs_index_0, index="index_0") + tfidf_retriever.fit(document_store, index="index_0") + document_store.write_documents(docs_index_1, index="index_1") + tfidf_retriever.fit(document_store, index="index_1") - assert tfidf_retriever.document_counts["index_0"] == ds.get_document_count(index="index_0") - assert tfidf_retriever.document_counts["index_1"] == ds.get_document_count(index="index_1") + assert tfidf_retriever.document_counts["index_0"] == document_store.get_document_count(index="index_0") + assert tfidf_retriever.document_counts["index_1"] == document_store.get_document_count(index="index_1") +@pytest.mark.unit +@pytest.mark.parametrize("document_store", ["memory"], indirect=True) def test_retrieval_empty_query(document_store: BaseDocumentStore): # test with empty query using the run() method mock_document = Document(id="0", content="test") @@ -318,34 +322,51 @@ def test_retribert_embedding(document_store, retriever, docs_with_ids): assert isclose(embedding[0], expected_value, rel_tol=0.001) -def test_openai_embedding_retriever_selection(): - # OpenAI released (Dec 2022) a unifying embedding model called text-embedding-ada-002 - # make sure that we can use it with the retriever selection - er = EmbeddingRetriever(embedding_model="text-embedding-ada-002", document_store=None) - assert er.model_format == "openai" - assert er.embedding_encoder.query_encoder_model == "text-embedding-ada-002" - assert er.embedding_encoder.doc_encoder_model == "text-embedding-ada-002" - assert er.api_base == "https://api.openai.com/v1" - - # but also support old ada and other text-search--*-001 models - er = EmbeddingRetriever(embedding_model="ada", document_store=None) - assert er.model_format == "openai" - assert er.embedding_encoder.query_encoder_model == "text-search-ada-query-001" - assert er.embedding_encoder.doc_encoder_model == "text-search-ada-doc-001" - assert er.api_base == "https://api.openai.com/v1" - - # but also support old babbage and other text-search--*-001 models - er = EmbeddingRetriever(embedding_model="babbage", document_store=None) - assert er.model_format == "openai" - assert er.embedding_encoder.query_encoder_model == "text-search-babbage-query-001" - assert er.embedding_encoder.doc_encoder_model == "text-search-babbage-doc-001" +@pytest.mark.unit +def test_openai_embedding_retriever_model_format(): + # support text-embedding-ada-002 + assert ( + EmbeddingRetriever._infer_model_format(model_name_or_path="text-embedding-ada-002", use_auth_token=None) + == "openai" + ) + + # support old ada and other text-search--*-001 models + assert EmbeddingRetriever._infer_model_format(model_name_or_path="ada", use_auth_token=None) == "openai" + + # support old babbage and other text-search--*-001 models + assert EmbeddingRetriever._infer_model_format(model_name_or_path="babbage", use_auth_token=None) == "openai" + + # make sure that we can handle potential unreleased models + assert ( + EmbeddingRetriever._infer_model_format(model_name_or_path="text-embedding-babbage-002", use_auth_token=None) + == "openai" + ) + + +@pytest.mark.unit +def test_openai_encoder_setup_encoding_models(): + with patch("haystack.nodes.retriever._openai_encoder._OpenAIEmbeddingEncoder.__init__") as mock_encoder_init: + mock_encoder_init.return_value = None + encoder = _OpenAIEmbeddingEncoder(retriever=None) # type: ignore + + encoder._setup_encoding_models(model_class="ada", model_name="text-embedding-ada-002", max_seq_len=512) + assert encoder.query_encoder_model == "text-embedding-ada-002" + assert encoder.doc_encoder_model == "text-embedding-ada-002" + + # support old ada and other text-search--*-001 models + encoder._setup_encoding_models(model_class="ada", model_name="ada", max_seq_len=512) + assert encoder.query_encoder_model == "text-search-ada-query-001" + assert encoder.doc_encoder_model == "text-search-ada-doc-001" + + # support old babbage and other text-search--*-001 models + encoder._setup_encoding_models(model_class="babbage", model_name="babbage", max_seq_len=512) + assert encoder.query_encoder_model == "text-search-babbage-query-001" + assert encoder.doc_encoder_model == "text-search-babbage-doc-001" # make sure that we can handle potential unreleased models - er = EmbeddingRetriever(embedding_model="text-embedding-babbage-002", document_store=None) - assert er.model_format == "openai" - assert er.embedding_encoder.query_encoder_model == "text-embedding-babbage-002" - assert er.embedding_encoder.doc_encoder_model == "text-embedding-babbage-002" - # etc etc. + encoder._setup_encoding_models(model_class="babbage", model_name="text-embedding-babbage-002", max_seq_len=512) + assert encoder.query_encoder_model == "text-embedding-babbage-002" + assert encoder.doc_encoder_model == "text-embedding-babbage-002" @pytest.mark.integration