diff --git a/paperqa/__init__.py b/paperqa/__init__.py
index 3830bbdf9..2de2ce632 100644
--- a/paperqa/__init__.py
+++ b/paperqa/__init__.py
@@ -21,29 +21,29 @@
 from .version import __version__
 
 __all__ = [
-    "Docs",
     "Answer",
-    "PromptCollection",
-    "__version__",
-    "Doc",
-    "Text",
+    "AnthropicLLMModel",
     "Context",
-    "LLMModel",
+    "Doc",
+    "Docs",
     "EmbeddingModel",
-    "OpenAIEmbeddingModel",
     "HybridEmbeddingModel",
-    "SparseEmbeddingModel",
-    "OpenAILLMModel",
-    "AnthropicLLMModel",
+    "LLMModel",
+    "LLMResult",
+    "LangchainEmbeddingModel",
     "LangchainLLMModel",
+    "LangchainVectorStore",
     "LlamaEmbeddingModel",
-    "SentenceTransformerEmbeddingModel",
-    "LangchainEmbeddingModel",
     "NumpyVectorStore",
-    "LangchainVectorStore",
+    "OpenAIEmbeddingModel",
+    "OpenAILLMModel",
+    "PromptCollection",
+    "SentenceTransformerEmbeddingModel",
+    "SparseEmbeddingModel",
+    "Text",
+    "__version__",
+    "embedding_model_factory",
+    "llm_model_factory",
     "print_callback",
-    "LLMResult",
     "vector_store_factory",
-    "llm_model_factory",
-    "embedding_model_factory",
 ]
diff --git a/paperqa/docs.py b/paperqa/docs.py
index fce995c08..2c5e7e4e1 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -142,7 +142,7 @@ def config_summary_llm_config(cls, data: Any) -> Any:
             if (
                 data.summary_llm_model is None
                 and data.llm == "default"
-                and type(data.llm_model) == OpenAILLMModel
+                and isinstance(data.llm_model, OpenAILLMModel)
             ):
                 data.summary_llm_model = OpenAILLMModel(
                     config={"model": "gpt-3.5-turbo", "temperature": 0.1}
@@ -530,8 +530,8 @@ async def adoc_match(
         if (
             rerank is None
             and (
-                type(self.llm_model) == OpenAILLMModel
-                and cast(OpenAILLMModel, self).config["model"].startswith("gpt-4")
+                isinstance(self.llm_model, OpenAILLMModel)
+                and self.config["model"].startswith("gpt-4")
             )
             or rerank is True
         ):
@@ -703,19 +703,21 @@ async def process(match):  # noqa: C901, PLR0912
             if self.strip_citations:
                 # remove citations that collide with our grounded citations (for the answer LLM)
                 context = strip_citations(context)
-            c = Context(
-                context=context,
-                text=Text(
-                    text=match.text,
-                    name=match.name,
-                    doc=match.doc.__class__(
-                        **match.doc.model_dump(exclude="embedding")
+            return (
+                Context(
+                    context=context,
+                    text=Text(
+                        text=match.text,
+                        name=match.name,
+                        doc=match.doc.__class__(
+                            **match.doc.model_dump(exclude="embedding")
+                        ),
                     ),
+                    score=score,
+                    **extras,
                 ),
-                score=score,
-                **extras,
+                llm_result,
             )
-            return c, llm_result
 
         results = await gather_with_concurrency(
             self.max_concurrent, [process(m) for m in matches]
@@ -723,13 +725,12 @@
         # update token counts
         [answer.add_tokens(r[1]) for r in results]
 
-        # filter out failures
-        contexts = [c for c, r in results if c is not None]
-
+        # filter out failures, sort by score, limit to max_sources
         answer.contexts = sorted(
-            contexts + answer.contexts, key=lambda x: x.score, reverse=True
-        )
-        answer.contexts = answer.contexts[:max_sources]
+            [c for c, r in results if c is not None] + answer.contexts,
+            key=lambda x: x.score,
+            reverse=True,
+        )[:max_sources]
         context_str = "\n\n".join(
             [
                 f"{c.text.name}: {c.context}"
diff --git a/paperqa/llms.py b/paperqa/llms.py
index 76773ae1b..355bbe350 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -81,13 +81,11 @@ def guess_model_type(model_name: str) -> str:  # noqa: PLR0911
 def is_anyscale_model(model_name: str) -> bool:
     # compares prefixes with anyscale models
     # https://docs.anyscale.com/endpoints/text-generation/query-a-model/
-    if (
+    return bool(
         os.environ.get("ANYSCALE_API_KEY")
         and os.environ.get("ANYSCALE_BASE_URL")
         and model_name.startswith(ANYSCALE_MODEL_PREFIXES)
-    ):
-        return True
-    return False
+    )
 
 
 def is_openai_model(model_name: str) -> bool:
@@ -296,8 +294,8 @@ async def execute(
             )
             result.prompt = messages
             result.prompt_count = sum(
-                [self.count_tokens(m["content"]) for m in messages]
-            ) + sum([self.count_tokens(m["role"]) for m in messages])
+                self.count_tokens(m["content"]) for m in messages
+            ) + sum(self.count_tokens(m["role"]) for m in messages)
 
             if callbacks is None:
                 output = await self.achat(client, messages)
@@ -612,14 +610,16 @@ async def similarity_search(
     def clear(self) -> None:
         pass
 
-    async def max_marginal_relevance_search(  # noqa: D417
+    async def max_marginal_relevance_search(
         self, client: Any, query: str, k: int, fetch_k: int
     ) -> tuple[Sequence[Embeddable], list[float]]:
         """Vectorized implementation of Maximal Marginal Relevance (MMR) search.
 
         Args:
+            client: TODOC.
             query: Query vector.
             k: Number of results to return.
+            fetch_k: Number of results to fetch from the vector store.
 
         Returns:
             List of tuples (doc, score) of length k.
diff --git a/paperqa/readers.py b/paperqa/readers.py
index 3cbbe1f9e..a2a2db05a 100644
--- a/paperqa/readers.py
+++ b/paperqa/readers.py
@@ -116,7 +116,7 @@ def parse_text(
         "parsing_libraries": ["tiktoken (cl100k_base)"] if use_tiktoken else [],
         "paperqa_version": str(pqa_version),
         "total_parsed_text_length": (
-            len(text) if isinstance(text, str) else sum([len(t) for t in text])
+            len(text) if isinstance(text, str) else sum(len(t) for t in text)
         ),
         "parse_type": "txt" if not html else "html",
     }
@@ -254,7 +254,7 @@ def read_doc(
 ) -> tuple[list[Text], ParsedMetadata]: ...
 
 
-def read_doc(  # noqa: PLR0912
+def read_doc(
     path: Path,
     doc: Doc,
     parsed_text_only: bool = False,
diff --git a/paperqa/utils.py b/paperqa/utils.py
index 1d54df346..f75ab1280 100644
--- a/paperqa/utils.py
+++ b/paperqa/utils.py
@@ -17,9 +17,7 @@
 def name_in_text(name: str, text: str) -> bool:
     sname = name.strip()
     pattern = rf"\b({re.escape(sname)})\b(?!\w)"
-    if re.search(pattern, text):
-        return True
-    return False
+    return bool(re.search(pattern, text))
 
 
 def maybe_is_text(s: str, thresh: float = 2.5) -> bool:
@@ -33,9 +31,7 @@ def maybe_is_text(s: str, thresh: float = 2.5) -> bool:
             entropy += -p * math.log2(p)
 
     # Check if the entropy is within a reasonable range for text
-    if entropy > thresh:
-        return True
-    return False
+    return entropy > thresh
 
 
 def maybe_is_pdf(file: BinaryIO) -> bool:
@@ -47,7 +43,7 @@
 def maybe_is_html(file: BinaryIO) -> bool:
     magic_number = file.read(4)
     file.seek(0)
-    return magic_number in (b"<htm", b"<!DO", b"<xsl", b"<!X")
+    return magic_number in {b"<htm", b"<!DO", b"<xsl", b"<!X"}
 
 
 def strings_similarity(s1: str, s2: str) -> float:
diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py
index 3686bc06d..673f6d382 100644
--- a/tests/test_paperqa.py
+++ b/tests/test_paperqa.py
@@ -1030,7 +1030,7 @@ def test_langchain_llm():
     assert docs2.summary_llm == "babbage-002"
     docs2.get_evidence(
         Answer(question="What is Frederick Bates's greatest accomplishment?"),
-        get_callbacks=lambda x: [lambda y: print(y)],  # noqa: ARG005
+        get_callbacks=lambda x: [print],  # noqa: ARG005
     )
 
 
@@ -1117,8 +1117,7 @@ async def test_langchain_vector_store():
     index.add_texts_and_embeddings(some_texts)
     assert index._store is not None
     # check search returns Text obj
-    data, score = await index.similarity_search(None, "test", k=1)  # type: ignore[unreachable]
-    print(data)
+    data, _ = await index.similarity_search(None, "test", k=1)  # type: ignore[unreachable]
     assert isinstance(data[0], Text)
 
     # now try with convenience
@@ -1386,7 +1385,7 @@ def test_parser_only_reader():
     assert any("pypdf" in t for t in parsed_text.metadata.parsing_libraries)
     assert parsed_text.metadata.chunk_metadata is None
     assert parsed_text.metadata.total_parsed_text_length == sum(
-        [len(t) for t in parsed_text.content.values()]  # type: ignore[misc,union-attr]
+        len(t) for t in parsed_text.content.values()  # type: ignore[misc,union-attr]
     )
 
 
@@ -1491,12 +1490,12 @@ def test_citation():
         f.write(r.text)
     docs = Docs()
     docs.add(doc_path)  # type: ignore[arg-type]
-    assert next(iter(docs.docs.values())).docname in (
+    assert next(iter(docs.docs.values())).docname in {
         "Wikipedia2024",
         "Frederick2024",
         "Wikipedia",
         "Frederick",
-    )
+    }
 
 
 def test_dockey_filter():