diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8c7273e..c8d6202 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,4 +29,4 @@ jobs: - name: Run Pytest env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run : poetry run pytest tests/ + run : poetry run pytest -n 8 tests/ diff --git a/openbb_agents/agent.py b/openbb_agents/agent.py index fb31c2f..df7976e 100644 --- a/openbb_agents/agent.py +++ b/openbb_agents/agent.py @@ -1,10 +1,12 @@ import asyncio import logging -from typing import Callable from langchain.vectorstores import VectorStore from .chains import ( + agenerate_subquestion_answer, + agenerate_subquestions_from_query, + asearch_tools, generate_final_answer, generate_subquestion_answer, generate_subquestions_from_query, @@ -13,85 +15,15 @@ from .models import AnsweredSubQuestion, SubQuestion from .tools import ( build_openbb_tool_vector_index, + build_vector_index_from_openbb_function_descriptions, + map_name_to_openbb_function_description, ) from .utils import get_dependencies logger = logging.getLogger(__name__) -async def _fetch_tools_and_answer_subquestion( - user_query: str, - subquestion: SubQuestion, - tool_vector_index: VectorStore, - answered_subquestions: list[AnsweredSubQuestion], -) -> AnsweredSubQuestion: - logger.info("Attempting to select tools for: %s", {subquestion.question}) - dependencies = get_dependencies( - answered_subquestions=answered_subquestions, subquestion=subquestion - ) - tools = await search_tools( - subquestion=subquestion, - tool_vector_index=tool_vector_index, - answered_subquestions=dependencies, - ) - tool_names = [tool.__name__ for tool in tools] - logger.info("Retrieved tool(s): %s", tool_names) - - # Then attempt to answer subquestion - logger.info("Answering subquestion: %s", subquestion.question) - answered_subquestion = await generate_subquestion_answer( - user_query=user_query, - subquestion=subquestion, - tools=tools, - dependencies=dependencies, - ) - - 
logger.info("Answered subquestion: %s", answered_subquestion.answer) - return answered_subquestion - - -def _get_unanswered_subquestions( - answered_subquestions: list[AnsweredSubQuestion], subquestions: list[SubQuestion] -) -> list[SubQuestion]: - answered_subquestion_ids = [ - answered_subquestion.subquestion.id - for answered_subquestion in answered_subquestions - ] - return [ - subquestion - for subquestion in subquestions - if subquestion.id not in answered_subquestion_ids - ] - - -def _is_subquestion_answerable( - subquestion: SubQuestion, answered_subquestions: list[AnsweredSubQuestion] -) -> bool: - if not subquestion.depends_on: - return True - - for id_ in subquestion.depends_on: - if id_ not in [ - answered_subquestion.subquestion.id - for answered_subquestion in answered_subquestions - ]: - return False - return True - - -def _get_answerable_subquestions( - subquestions: list[SubQuestion], answered_subquestions: list[AnsweredSubQuestion] -) -> list[SubQuestion]: - return [ - subquestion - for subquestion in subquestions - if _is_subquestion_answerable( - subquestion=subquestion, answered_subquestions=answered_subquestions - ) - ] - - -def openbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str: +def openbb_agent(query: str, openbb_tools: list[str] | None = None) -> str: """Answer a query using the OpenBB Agent equipped with tools. By default all available openbb tools are used. You can have a query @@ -102,35 +34,30 @@ def openbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str: ---------- query : str The query you want to have answered. - openbb_tools : optional[list[str]] - Optional. Specify the OpenBB collections or commands that you use to use. If not + openbb_tools : list[Callable] + Optional. Specify the OpenBB functions you want to use. If not specified, every available OpenBB tool will be used. 
Examples -------- >>> # Use all OpenBB tools to answer the query - >>> openbb_agent("What is the market cap of TSLA?") + >>> openbb_agent("What is the stock price of TSLA?") >>> # Use only the specified tools to answer the query - >>> openbb_agent("What is the market cap of TSLA?", - ... openbb_tools=["/equity/fundamental", "/equity/price/historical"]) + >>> openbb_agent("What is the stock price of TSLA?", + ... openbb_tools=['.equity.price.quote']) """ - + tool_vector_index = _handle_tool_vector_index(openbb_tools) subquestions = generate_subquestions_from_query(user_query=query) + logger.info("Generated subquestions: %s", subquestions) - tool_vector_index = build_openbb_tool_vector_index() answered_subquestions = [] - while unanswered_subquestions := _get_unanswered_subquestions( - answered_subquestions=answered_subquestions, subquestions=subquestions - ): - logger.info("Unanswered subquestions: %s", unanswered_subquestions) - answerable_subquestions = _get_answerable_subquestions( - subquestions=unanswered_subquestions, - answered_subquestions=answered_subquestions, - ) - logger.info("Answerable subquestions: %s", answerable_subquestions) - for subquestion in answerable_subquestions: # TODO: Do in parallel + for subquestion in subquestions: + if _is_subquestion_answerable( + subquestion=subquestion, answered_subquestions=answered_subquestions + ): + logger.info("Answering subquestion: %s", subquestion) answered_subquestion = _fetch_tools_and_answer_subquestion( user_query=query, subquestion=subquestion, @@ -138,6 +65,48 @@ def openbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str: answered_subquestions=answered_subquestions, ) answered_subquestions.append(answered_subquestion) + else: + logger.info("Skipping unanswerable subquestion: %s", subquestion) + return generate_final_answer( + user_query=query, + answered_subquestions=answered_subquestions, + ) + + +async def aopenbb_agent(query: str, openbb_tools: list[str] | None = None) -> str: + 
"""Answer a query using the OpenBB Agent equipped with tools. + + Async variant of `openbb_agent`. + + By default all available openbb tools are used. You can have a query + answered using a smaller subset of OpenBB tools by using the `openbb_tools` + argument. + + Parameters + ---------- + query : str + The query you want to have answered. + openbb_tools : list[Callable] + Optional. Specify the OpenBB functions you want to use. If not + specified, every available OpenBB tool will be used. + + Examples + -------- + >>> # Use all OpenBB tools to answer the query + >>> openbb_agent("What is the stock price of TSLA?") + >>> # Use only the specified tools to answer the query + >>> openbb_agent("What is the stock price of TSLA?", + ... openbb_tools=['.equity.price.quote']) + + """ + tool_vector_index = _handle_tool_vector_index(openbb_tools) + + subquestions = await agenerate_subquestions_from_query(user_query=query) + answered_subquestions = await _aprocess_subquestions( + user_query=query, + subquestions=subquestions, + tool_vector_index=tool_vector_index, + ) return generate_final_answer( user_query=query, @@ -145,19 +114,18 @@ def openbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str: ) -async def _process_subquestions( - user_query, subquestions, tool_vector_index +async def _aprocess_subquestions( + user_query: str, subquestions: list[SubQuestion], tool_vector_index: VectorStore ) -> list[AnsweredSubQuestion]: answered_subquestions = [] queued_subquestions = [] tasks = [] while True: - logger.info("Loop...") unanswered_subquestions = _get_unanswered_subquestions( answered_subquestions=answered_subquestions, subquestions=subquestions ) - logger.info("Unanswered subquestions: %s", unanswered_subquestions) + logger.info("Pending subquestions: %s", unanswered_subquestions) new_answerable_subquestions = _get_answerable_subquestions( subquestions=unanswered_subquestions, @@ -166,12 +134,12 @@ async def _process_subquestions( logger.info("Answerable 
subquestions: %s", new_answerable_subquestions) for subquestion in new_answerable_subquestions: - logger.info("Scheduling task for subquestion: %s", subquestion) + logger.info("Scheduling subquestion for answer: %s", subquestion) # Make sure we only submit newly answerable questions (since the # other ones have been submitted already) if subquestion not in queued_subquestions: task = asyncio.create_task( - _fetch_tools_and_answer_subquestion( + _afetch_tools_and_answer_subquestion( user_query=user_query, subquestion=subquestion, tool_vector_index=tool_vector_index, @@ -185,30 +153,133 @@ async def _process_subquestions( break done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - logger.info("Task completed...") + tasks = [task for task in tasks if not task.done()] + for task in done: if task.exception(): - logger.error("Error in task: %s", task.exception()) + logger.error("Unexpected error in task: %s", task.exception()) else: answered_subquestion = task.result() logger.info("Finished task for subquestion: %s", answered_subquestion) answered_subquestions.append(answered_subquestion) - tasks = [task for task in tasks if not task.done()] return answered_subquestions -async def aopenbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str: - tool_vector_index = build_openbb_tool_vector_index() +def _fetch_tools_and_answer_subquestion( + user_query: str, + subquestion: SubQuestion, + tool_vector_index: VectorStore, + answered_subquestions: list[AnsweredSubQuestion], +) -> AnsweredSubQuestion: + logger.info("Attempting to select tools for: %s", {subquestion.question}) + dependencies = get_dependencies( + answered_subquestions=answered_subquestions, subquestion=subquestion + ) + tools = search_tools( + subquestion=subquestion, + tool_vector_index=tool_vector_index, + answered_subquestions=dependencies, + ) + tool_names = [tool.__name__ for tool in tools] + logger.info("Retrieved tool(s): %s", tool_names) - subquestions = await 
generate_subquestions_from_query(user_query=query) - answered_subquestions = await _process_subquestions( - user_query=query, - subquestions=subquestions, + # Then attempt to answer subquestion + logger.info("Answering subquestion: %s", subquestion.question) + answered_subquestion = generate_subquestion_answer( + user_query=user_query, + subquestion=subquestion, + tools=tools, + dependencies=dependencies, + ) + + logger.info("Answered subquestion: %s", answered_subquestion.answer) + return answered_subquestion + + +async def _afetch_tools_and_answer_subquestion( + user_query: str, + subquestion: SubQuestion, + tool_vector_index: VectorStore, + answered_subquestions: list[AnsweredSubQuestion], +) -> AnsweredSubQuestion: + logger.info("Attempting to select tools for: %s", {subquestion.question}) + dependencies = get_dependencies( + answered_subquestions=answered_subquestions, subquestion=subquestion + ) + tools = await asearch_tools( + subquestion=subquestion, tool_vector_index=tool_vector_index, + answered_subquestions=dependencies, ) + tool_names = [tool.__name__ for tool in tools] + logger.info("Retrieved tool(s): %s", tool_names) - return generate_final_answer( - user_query=query, - answered_subquestions=answered_subquestions, + # Then attempt to answer subquestion + logger.info("Answering subquestion: %s", subquestion.question) + answered_subquestion = await agenerate_subquestion_answer( + user_query=user_query, + subquestion=subquestion, + tools=tools, + dependencies=dependencies, ) + + logger.info("Answered subquestion: %s", answered_subquestion.answer) + return answered_subquestion + + +def _get_unanswered_subquestions( + answered_subquestions: list[AnsweredSubQuestion], subquestions: list[SubQuestion] +) -> list[SubQuestion]: + answered_subquestion_ids = [ + answered_subquestion.subquestion.id + for answered_subquestion in answered_subquestions + ] + return [ + subquestion + for subquestion in subquestions + if subquestion.id not in answered_subquestion_ids 
+ ] + + +def _is_subquestion_answerable( + subquestion: SubQuestion, answered_subquestions: list[AnsweredSubQuestion] +) -> bool: + if not subquestion.depends_on: + return True + + for id_ in subquestion.depends_on: + if id_ not in [ + answered_subquestion.subquestion.id + for answered_subquestion in answered_subquestions + ]: + return False + return True + + +def _get_answerable_subquestions( + subquestions: list[SubQuestion], answered_subquestions: list[AnsweredSubQuestion] +) -> list[SubQuestion]: + return [ + subquestion + for subquestion in subquestions + if _is_subquestion_answerable( + subquestion=subquestion, answered_subquestions=answered_subquestions + ) + ] + + +def _handle_tool_vector_index(openbb_tools: list[str] | None) -> VectorStore: + if not openbb_tools: + logger.info("Using all available OpenBB tools.") + tool_vector_index = build_openbb_tool_vector_index() + else: + logger.info("Using specified OpenBB tools: %s", openbb_tools) + openbb_function_descriptions = [ + map_name_to_openbb_function_description(obb_function_name) + for obb_function_name in openbb_tools + ] + tool_vector_index = build_vector_index_from_openbb_function_descriptions( + openbb_function_descriptions + ) + return tool_vector_index diff --git a/openbb_agents/chains.py b/openbb_agents/chains.py index 42cfbe5..d417195 100644 --- a/openbb_agents/chains.py +++ b/openbb_agents/chains.py @@ -9,6 +9,7 @@ FunctionCall, FunctionResultMessage, OpenaiChatModel, + ParallelFunctionCall, SystemMessage, UserMessage, chatprompt, @@ -35,8 +36,6 @@ def generate_final_answer( user_query: str, answered_subquestions: list[AnsweredSubQuestion], ) -> str: - """Generate the final response to a query, given a list of answered subquestions.""" - @prompt( FINAL_RESPONSE_PROMPT_TEMPLATE, model=OpenaiChatModel(model="gpt-4o", temperature=0.0), @@ -51,51 +50,70 @@ def _final_answer( ) -def _build_messages_for_function_call( - function_call: FunctionCall, - result: Any, -) -> list[Any]: - return [ - 
AssistantMessage(function_call), - FunctionResultMessage(content=str(result), function_call=function_call), - ] - +async def agenerate_final_answer( + user_query: str, + answered_subquestions: list[AnsweredSubQuestion], +) -> str: + @prompt( + FINAL_RESPONSE_PROMPT_TEMPLATE, + model=OpenaiChatModel(model="gpt-4o", temperature=0.0), + ) + async def _final_answer( + user_query: str, answered_subquestions: list[AnsweredSubQuestion] + ) -> str: + ... -def _build_messages_for_validation_error( - function_call: FunctionCall, - val_err: ValidationError, -) -> list[Any]: - logger.error(f"Input schema validation error: {val_err}") - # Sidestep `magentic`'s input validation, which will still - # occur when we pass `function_call` to `AssistantMessage` - # https://github.com/jackmpcollins/magentic/issues/211 - dummy = lambda *args, **kwargs: ... # noqa: E731 - dummy.__name__ = function_call.function.__name__ - new_function_call = FunctionCall( - function=dummy, - **function_call.arguments, + return await _final_answer( + user_query=user_query, answered_subquestions=answered_subquestions ) - return [ - AssistantMessage(new_function_call), - FunctionResultMessage( - content=str(val_err), - function_call=new_function_call, - ), - ] -def _build_messages_for_generic_error( - function_call: FunctionCall, - err: Exception, -) -> list[Any]: - logger.error(f"Error calling function: {err}") - return [ - AssistantMessage(function_call), - FunctionResultMessage(content=str(err), function_call=function_call), - ] +def generate_subquestion_answer( + user_query: str, + subquestion: SubQuestion, + dependencies: list[AnsweredSubQuestion], + tools: list[Callable], +) -> AnsweredSubQuestion: + current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + messages: list[Any] = [SystemMessage(SUBQUESTION_ANSWER_PROMPT)] + + answer = None + while not answer: + + @chatprompt( + *messages, + model=OpenaiChatModel(model="gpt-4o", temperature=0.0), + functions=tools, + ) + def _answer_subquestion( 
+ user_query: str, + subquestion: str, + dependencies: list[AnsweredSubQuestion], + current_datetime: str, + ) -> str | ParallelFunctionCall: + ... + + response = _answer_subquestion( # type: ignore + user_query=user_query, + subquestion=subquestion.question, + dependencies=dependencies, + current_datetime=current_datetime, + ) + + if isinstance(response, ParallelFunctionCall): + for function_call in response._function_calls: + logger.info( + "Function call: %s(%s)", + function_call.function.__name__, + function_call.arguments, + ) + messages += _handle_function_call(function_call=function_call) + elif isinstance(response, str): + answer = response + return AnsweredSubQuestion(subquestion=subquestion, answer=answer) -async def generate_subquestion_answer( +async def agenerate_subquestion_answer( user_query: str, subquestion: SubQuestion, dependencies: list[AnsweredSubQuestion], @@ -134,19 +152,7 @@ async def _answer_subquestion( function_call.function.__name__, function_call.arguments, ) - try: - result = function_call() - messages += _build_messages_for_function_call( - function_call=function_call, result=result - ) - except ValidationError as val_err: - messages += _build_messages_for_validation_error( - function_call=function_call, val_err=val_err - ) - except Exception as err: - messages += _build_messages_for_generic_error( - function_call=function_call, err=err - ) + messages += _handle_function_call(function_call=function_call) elif isinstance(response, str): answer = response return AnsweredSubQuestion(subquestion=subquestion, answer=answer) @@ -157,11 +163,48 @@ async def _answer_subquestion( UserMessage("# User query\n{user_query}"), model=OpenaiChatModel(model="gpt-4o", temperature=0.0), ) -async def generate_subquestions_from_query(user_query: str) -> list[SubQuestion]: +def generate_subquestions_from_query(user_query: str) -> list[SubQuestion]: ... 
-async def search_tools( +@chatprompt( + SystemMessage(GENERATE_SUBQUESTION_SYSTEM_PROMPT_TEMPLATE), + UserMessage("# User query\n{user_query}"), + model=OpenaiChatModel(model="gpt-4o", temperature=0.0), +) +async def agenerate_subquestions_from_query(user_query: str) -> list[SubQuestion]: + ... + + +def search_tools( + subquestion: SubQuestion, + tool_vector_index: VectorStore, + answered_subquestions: list[AnsweredSubQuestion] | None = None, +) -> list[Callable]: + def llm_query_tool_index(query: str) -> str: + """Use natural language to search the tool index for tools.""" + logger.info("Searching tool index for: %s", query) + results = tool_vector_index.similarity_search(query=query, k=4) + return "\n".join([r.page_content for r in results]) + + @prompt_chain( + TOOL_SEARCH_PROMPT_TEMPLATE, + model=OpenaiChatModel(model="gpt-3.5-turbo", temperature=0.2), + functions=[llm_query_tool_index], + ) + def _search_tools( + subquestion: str, answered_subquestions: list[AnsweredSubQuestion] | None + ) -> list[str]: + ... + + tool_names = _search_tools(subquestion.question, answered_subquestions) + callables = _get_callables_from_tool_search_results( + tool_vector_index=tool_vector_index, tool_names=tool_names + ) + return callables + + +async def asearch_tools( subquestion: SubQuestion, tool_vector_index: VectorStore, answered_subquestions: list[AnsweredSubQuestion] | None = None, @@ -183,6 +226,16 @@ async def _search_tools( ... 
tool_names = await _search_tools(subquestion.question, answered_subquestions) + callables = _get_callables_from_tool_search_results( + tool_vector_index=tool_vector_index, tool_names=tool_names + ) + return callables + + +def _get_callables_from_tool_search_results( + tool_vector_index: VectorStore, + tool_names: list[str], +) -> list[Callable]: callables = [] for tool_name in tool_names: for doc in tool_vector_index.docstore._dict.values(): # type: ignore @@ -190,3 +243,61 @@ async def _search_tools( callables.append(doc.metadata["callable"]) break return callables + + +def _handle_function_call(function_call: FunctionCall) -> list[Any]: + try: + result = function_call() + return _build_messages_for_function_call( + function_call=function_call, result=result + ) + except ValidationError as val_err: + return _build_messages_for_validation_error( + function_call=function_call, val_err=val_err + ) + except Exception as err: + return _build_messages_for_generic_error(function_call=function_call, err=err) + + +def _build_messages_for_function_call( + function_call: FunctionCall, + result: Any, +) -> list[Any]: + return [ + AssistantMessage(function_call), + FunctionResultMessage(content=str(result), function_call=function_call), + ] + + +def _build_messages_for_validation_error( + function_call: FunctionCall, + val_err: ValidationError, +) -> list[Any]: + logger.error(f"Input schema validation error: {val_err}") + # Sidestep `magentic`'s input validation, which will still + # occur when we pass `function_call` to `AssistantMessage` + # https://github.com/jackmpcollins/magentic/issues/211 + dummy = lambda *args, **kwargs: ... 
# noqa: E731 + dummy.__name__ = function_call.function.__name__ + new_function_call = FunctionCall( + function=dummy, + **function_call.arguments, + ) + return [ + AssistantMessage(new_function_call), + FunctionResultMessage( + content=str(val_err), + function_call=new_function_call, + ), + ] + + +def _build_messages_for_generic_error( + function_call: FunctionCall, + err: Exception, +) -> list[Any]: + logger.error(f"Error calling function: {err}") + return [ + AssistantMessage(function_call), + FunctionResultMessage(content=str(err), function_call=function_call), + ] diff --git a/openbb_agents/tools.py b/openbb_agents/tools.py index 2b1f794..832b549 100644 --- a/openbb_agents/tools.py +++ b/openbb_agents/tools.py @@ -3,7 +3,7 @@ from typing import Any from langchain.schema import Document -from langchain_community.vectorstores import FAISS +from langchain_community.vectorstores.faiss import FAISS from langchain_core.vectorstores import VectorStore from langchain_openai import OpenAIEmbeddings from openbb import obb @@ -63,21 +63,27 @@ def get_valid_openbb_function_names() -> list[str]: def get_valid_openbb_function_descriptions() -> list[OpenBBFunctionDescription]: - command_schemas = _get_openbb_coverage_command_schemas() obb_function_descriptions = [] for obb_function_name in get_valid_openbb_function_names(): - dict_ = command_schemas[obb_function_name] obb_function_descriptions.append( - OpenBBFunctionDescription( - name=obb_function_name, - input_model=dict_["input"], - output_model=dict_["output"], - callable=dict_["callable"], - ) + map_name_to_openbb_function_description(obb_function_name) ) return obb_function_descriptions +def map_name_to_openbb_function_description( + obb_function_name: str, +) -> OpenBBFunctionDescription: + command_schemas = _get_openbb_coverage_command_schemas() + dict_ = command_schemas[obb_function_name] + return OpenBBFunctionDescription( + name=obb_function_name, + input_model=dict_["input"], + output_model=dict_["output"], + 
callable=dict_["callable"], + ) + + def _get_flat_properties_from_pydantic_model_as_str(model: Any) -> str: output_str = "" schema_properties = model.schema()["properties"] @@ -99,7 +105,7 @@ def make_vector_index_description( return output_str -def build_vector_index( +def build_vector_index_from_openbb_function_descriptions( openbb_function_descriptions: list[OpenBBFunctionDescription], ) -> VectorStore: documents = [] @@ -119,4 +125,6 @@ def build_vector_index( def build_openbb_tool_vector_index() -> VectorStore: logger.info("Building OpenBB tool vector index...") - return build_vector_index(get_valid_openbb_function_descriptions()) + return build_vector_index_from_openbb_function_descriptions( + get_valid_openbb_function_descriptions() + ) diff --git a/poetry.lock b/poetry.lock index 43b532f..65d93f2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -725,6 +725,20 @@ six = ">=1.9.0" gmpy = ["gmpy"] gmpy2 = ["gmpy2"] +[[package]] +name = "execnet" +version = "2.1.1" +description = "execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.8" +files = [ + {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"}, + {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + [[package]] name = "executing" version = "2.0.1" @@ -2280,13 +2294,13 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] [[package]] name = "nest-asyncio" -version = "1.5.9" +version = "1.6.0" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" files = [ - {file = "nest_asyncio-1.5.9-py3-none-any.whl", hash = "sha256:61ec07ef052e72e3de22045b81b2cc7d71fceb04c568ba0b2e4b2f9f5231bec2"}, - {file = "nest_asyncio-1.5.9.tar.gz", hash = "sha256:d1e1144e9c6e3e6392e0fcf5211cb1c8374b5648a98f1ebe48e5336006b41907"}, + {file = 
"nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] [[package]] @@ -3585,6 +3599,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.6" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, + {file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-freezegun" version = "0.4.2" @@ -3600,6 +3632,26 @@ files = [ freezegun = ">0.3" pytest = ">=3.0.0" +[[package]] +name = "pytest-xdist" +version = "3.6.1" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"}, + {file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"}, +] + +[package.dependencies] +execnet = ">=2.1" +pytest = ">=7.0.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -5620,4 +5672,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p 
[metadata] lock-version = "2.0" python-versions = "^3.11,<3.12" -content-hash = "d72b4f02abd9c08f9e841f060c8e53e555269fc090e4267030accc4b55268080" +content-hash = "6219a59550206675edb997d68463df8b765f9b18720f68559277a6c1eae6feeb" diff --git a/pyproject.toml b/pyproject.toml index 66cf744..3aefa8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ langchain = "^0.1.17" langchain-community = "^0.0.37" langchain-openai = "^0.1.6" openbb-yfinance = "^1.1.5" +pytest-asyncio = "^0.23.6" +pytest-xdist = "^3.6.1" [tool.poetry.group.dev.dependencies] pre-commit = "^3.5.0" diff --git a/tests/conftest.py b/tests/conftest.py index b9d0169..24772a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,12 @@ from unittest.mock import patch import pytest +from langchain_core.vectorstores import VectorStore from openbb import obb from pydantic import BaseModel, Field +from openbb_agents.tools import build_openbb_tool_vector_index + @pytest.fixture def mock_obb_user_credentials(monkeypatch): @@ -82,3 +85,8 @@ def mock_obb_coverage_command_schema( with patch("openbb_agents.tools._get_openbb_coverage_command_schemas") as mock: mock.return_value = mock_coverage_command_schema_dict yield mock + + +@pytest.fixture +def openbb_tool_vector_index() -> VectorStore: + return build_openbb_tool_vector_index() diff --git a/tests/test_agent.py b/tests/test_agent.py index e69de29..8803d34 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -0,0 +1,29 @@ +import pytest + +from openbb_agents.agent import aopenbb_agent, openbb_agent +from openbb_agents.testing import with_llm + + +def test_openbb_agent(openbb_tool_vector_index): + test_query = "What is the stock price of AAPL and MSFT?" 
+ actual_result = openbb_agent( + query=test_query, + openbb_tools=[".equity.price.quote", ".equity.fundamental.metrics"], + ) + assert isinstance(actual_result, str) + assert with_llm(actual_result, "MSFT's stock price is in the answer.") + assert with_llm(actual_result, "AAPL's stock price is in the answer.") + assert with_llm(actual_result, "One of the stock prices is higher than the other.") + + +@pytest.mark.asyncio +async def test_aopenbb_agent(openbb_tool_vector_index): + test_query = "What is the stock price of AAPL and MSFT? Which is higher?" + actual_result = await aopenbb_agent( + query=test_query, + openbb_tools=[".equity.price.quote", ".equity.fundamental.metrics"], + ) + assert isinstance(actual_result, str) + assert with_llm(actual_result, "MSFT's stock price is in the answer.") + assert with_llm(actual_result, "AAPL's stock price is in the answer.") + assert with_llm(actual_result, "One of the stock prices is higher than the other.") diff --git a/tests/test_chains.py b/tests/test_chains.py index 7576551..d75d867 100644 --- a/tests/test_chains.py +++ b/tests/test_chains.py @@ -1,9 +1,14 @@ from typing import Literal +import pytest from openbb import obb from pydantic import BaseModel from openbb_agents.chains import ( + agenerate_final_answer, + agenerate_subquestion_answer, + agenerate_subquestions_from_query, + asearch_tools, generate_final_answer, generate_subquestion_answer, generate_subquestions_from_query, @@ -11,7 +16,6 @@ ) from openbb_agents.models import AnsweredSubQuestion, SubQuestion from openbb_agents.testing import with_llm -from openbb_agents.tools import build_openbb_tool_vector_index def test_generate_subquestions_from_query(): @@ -22,13 +26,35 @@ def test_generate_subquestions_from_query(): assert isinstance(actual_result[0], SubQuestion) -def test_search_tools_no_dependencies(): +@pytest.mark.asyncio +async def test_agenerate_subquestions_from_query(): + test_query = "Calculate the P/E ratio of AAPL." 
+ actual_result = await agenerate_subquestions_from_query(user_query=test_query) + assert isinstance(actual_result, list) + assert len(actual_result) > 0 + assert isinstance(actual_result[0], SubQuestion) + + +def test_search_tools_no_dependencies(openbb_tool_vector_index): test_subquestion = SubQuestion(id=1, question="What is the stock price of AAPL?") - test_tool_vector_index = build_openbb_tool_vector_index() actual_result = search_tools( subquestion=test_subquestion, answered_subquestions=None, - tool_vector_index=test_tool_vector_index, + tool_vector_index=openbb_tool_vector_index, + ) + + assert len(actual_result) > 0 + assert actual_result[0].__name__ == "quote" + assert callable(actual_result[0]) + + +@pytest.mark.asyncio +async def test_asearch_tools_no_dependencies(openbb_tool_vector_index): + test_subquestion = SubQuestion(id=1, question="What is the stock price of AAPL?") + actual_result = await asearch_tools( + subquestion=test_subquestion, + answered_subquestions=None, + tool_vector_index=openbb_tool_vector_index, ) assert len(actual_result) > 0 @@ -53,6 +79,24 @@ def test_generate_subquestion_answer_no_dependencies(): ) +@pytest.mark.asyncio +async def test_agenerate_subquestion_answer_no_dependencies(): + test_user_query = "What is the current stock price of AAPL?" + test_subquestion = SubQuestion( + id=1, question="What is the stock price of AAPL? Use yfinance as the provider." + ) + test_tool = obb.equity.price.quote # type: ignore + actual_result: AnsweredSubQuestion = await agenerate_subquestion_answer( + user_query=test_user_query, + subquestion=test_subquestion, + dependencies=[], + tools=[test_tool], + ) + assert with_llm( + actual_result.answer, "the stock price for apple was retrieved successfully" + ) + + def test_generate_subquestion_answer_with_dependencies(): test_user_query = "What is the current stock price of MSFT's biggest competitor?" 
test_subquestion = SubQuestion( @@ -80,6 +124,34 @@ def test_generate_subquestion_answer_with_dependencies(): ) +@pytest.mark.asyncio +async def test_agenerate_subquestion_answer_with_dependencies(): + test_user_query = "What is the current stock price of MSFT's biggest competitor?" + test_subquestion = SubQuestion( + id=1, + question="What is the stock price of MSFT's biggest competitor? Use yfinance as the provider.", # noqa: E501 + depends_on=[2], + ) + test_dependencies = [ + AnsweredSubQuestion( + subquestion=SubQuestion( + id=2, question="What is the current biggest competitor to MSFT?" + ), + answer="The current biggest competitor to MSFT is AAPL.", + ) + ] + test_tool = obb.equity.price.quote # type: ignore + actual_result: AnsweredSubQuestion = await agenerate_subquestion_answer( + user_query=test_user_query, + subquestion=test_subquestion, + dependencies=test_dependencies, + tools=[test_tool], + ) + assert with_llm( + actual_result.answer, "the stock price for apple was retrieved successfully" + ) + + def test_generate_subquestion_answer_with_generic_error_in_function_call(): test_user_query = "What is the current stock price of AAPL?" test_subquestion = SubQuestion(id=1, question="What is the stock price of AAPL?") @@ -94,7 +166,31 @@ def _get_stock_price(symbol: str) -> str: tools=[_get_stock_price], ) assert isinstance(actual_result, AnsweredSubQuestion) - assert with_llm(actual_result.answer, "The backend is offline.") + assert with_llm( + actual_result.answer, + "The backend is offline, and the answer could not be retrieved.", + ) + + +@pytest.mark.asyncio +async def test_agenerate_subquestion_answer_with_generic_error_in_function_call(): + test_user_query = "What is the current stock price of AAPL?" 
+    test_subquestion = SubQuestion(id=1, question="What is the stock price of AAPL?")
+
+    def _get_stock_price(symbol: str) -> str:
+        raise ValueError("The backend is currently offline.")
+
+    actual_result: AnsweredSubQuestion = await agenerate_subquestion_answer(
+        user_query=test_user_query,
+        subquestion=test_subquestion,
+        dependencies=[],
+        tools=[_get_stock_price],
+    )
+    assert isinstance(actual_result, AnsweredSubQuestion)
+    assert with_llm(
+        actual_result.answer,
+        "The backend is offline, and the answer could not be retrieved.",
+    )
 
 
 def test_generate_subquestion_answer_self_heals_with_input_validation_error_in_function_call():  # noqa: E501
@@ -122,6 +218,32 @@ class StockPricePayload(BaseModel):
     assert with_llm(actual_result.answer, "The stock price is 95 USD.")
 
 
+@pytest.mark.asyncio
+async def test_agenerate_subquestion_answer_self_heals_with_input_validation_error_in_function_call():  # noqa: E501
+    test_user_query = "What is the current stock price of AAPL? Preferably in EUR."
+    test_subquestion = SubQuestion(id=1, question="What is the stock price of AAPL?")
+
+    def _get_stock_price(symbol: str, currency: Literal["USD", "EUR"]) -> str:
+        class StockPricePayload(BaseModel):
+            symbol: str
+            currency: Literal["USD"]  # Only USD is allowed, but we ask for EUR.
+
+        _ = StockPricePayload(symbol=symbol, currency=currency)  # type: ignore
+        return "The stock price is USD 95."
+
+    actual_result: AnsweredSubQuestion = await agenerate_subquestion_answer(
+        user_query=test_user_query,
+        subquestion=test_subquestion,
+        dependencies=[],
+        tools=[_get_stock_price],
+    )
+    assert isinstance(actual_result, AnsweredSubQuestion)
+    assert with_llm(
+        actual_result.answer, "The stock price could not be retrieved in EUR."
+    )
+    assert with_llm(actual_result.answer, "The stock price is 95 USD.")
+
+
 def test_generate_final_answer():
     test_user_query = "Who has the highest stock price? AMZN or TSLA?"
test_answered_subquestions = [ @@ -140,3 +262,24 @@ def test_generate_final_answer(): answered_subquestions=test_answered_subquestions, ) assert with_llm(actual_result, "The answer says TSLA has the highest stock price.") + + +@pytest.mark.asyncio +async def test_agenerate_final_answer(): + test_user_query = "Who has the highest stock price? AMZN or TSLA?" + test_answered_subquestions = [ + AnsweredSubQuestion( + subquestion=SubQuestion(id=1, question="What is the stock price of AMZN?"), + answer="The stock price of AMZN is $100.", + ), + AnsweredSubQuestion( + subquestion=SubQuestion(id=2, question="What is the stock price of TSLA?"), + answer="The stock price of TSLA is $200.", + ), + ] + + actual_result = await agenerate_final_answer( + user_query=test_user_query, + answered_subquestions=test_answered_subquestions, + ) + assert with_llm(actual_result, "The answer says TSLA has the highest stock price.") diff --git a/tests/test_tools.py b/tests/test_tools.py index 29219d5..2ba8b8e 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -3,7 +3,7 @@ from openbb_agents.models import OpenBBFunctionDescription from openbb_agents.tools import ( _get_flat_properties_from_pydantic_model_as_str, - build_vector_index, + build_vector_index_from_openbb_function_descriptions, get_valid_list_of_providers, get_valid_openbb_function_descriptions, get_valid_openbb_function_names, @@ -103,7 +103,7 @@ def test_build_vector_index( ), ] - actual_result = build_vector_index( + actual_result = build_vector_index_from_openbb_function_descriptions( openbb_function_descriptions=test_openbb_function_descriptions ) assert len(actual_result.docstore._dict) == 2 # type: ignore