Skip to content

Commit

Permalink
WIP: Add async variants + tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mnicstruwig committed May 16, 2024
1 parent 06d7e58 commit c28724d
Show file tree
Hide file tree
Showing 10 changed files with 612 additions and 188 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:
- name: Run Pytest
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run : poetry run pytest tests/
run : poetry run pytest -n 8 tests/
289 changes: 180 additions & 109 deletions openbb_agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import logging
from typing import Callable

from langchain.vectorstores import VectorStore

from .chains import (
agenerate_subquestion_answer,
agenerate_subquestions_from_query,
asearch_tools,
generate_final_answer,
generate_subquestion_answer,
generate_subquestions_from_query,
Expand All @@ -13,85 +15,15 @@
from .models import AnsweredSubQuestion, SubQuestion
from .tools import (
build_openbb_tool_vector_index,
build_vector_index_from_openbb_function_descriptions,
map_name_to_openbb_function_description,
)
from .utils import get_dependencies

logger = logging.getLogger(__name__)


async def _fetch_tools_and_answer_subquestion(
    user_query: str,
    subquestion: SubQuestion,
    tool_vector_index: VectorStore,
    answered_subquestions: list[AnsweredSubQuestion],
) -> AnsweredSubQuestion:
    """Select tools for a subquestion, then answer it using those tools.

    Parameters
    ----------
    user_query : str
        The original user query; forwarded to the answering step.
    subquestion : SubQuestion
        The subquestion to find tools for and to answer.
    tool_vector_index : VectorStore
        The index handed to `search_tools` for tool retrieval.
    answered_subquestions : list[AnsweredSubQuestion]
        Subquestions answered so far; `get_dependencies` derives from them
        the dependencies passed to tool search and to the answering step.

    Returns
    -------
    AnsweredSubQuestion
        The value produced by `generate_subquestion_answer`.
    """
    # Log the question text itself: wrapping it in braces would build a
    # one-element set and log that set's repr instead of the question.
    logger.info("Attempting to select tools for: %s", subquestion.question)
    dependencies = get_dependencies(
        answered_subquestions=answered_subquestions, subquestion=subquestion
    )
    tools = await search_tools(
        subquestion=subquestion,
        tool_vector_index=tool_vector_index,
        answered_subquestions=dependencies,
    )
    tool_names = [tool.__name__ for tool in tools]
    logger.info("Retrieved tool(s): %s", tool_names)

    # Then attempt to answer subquestion
    logger.info("Answering subquestion: %s", subquestion.question)
    answered_subquestion = await generate_subquestion_answer(
        user_query=user_query,
        subquestion=subquestion,
        tools=tools,
        dependencies=dependencies,
    )

    logger.info("Answered subquestion: %s", answered_subquestion.answer)
    return answered_subquestion


def _get_unanswered_subquestions(
    answered_subquestions: list[AnsweredSubQuestion], subquestions: list[SubQuestion]
) -> list[SubQuestion]:
    """Return the subquestions that do not have an answer yet.

    A subquestion counts as answered when its `id` equals the
    `subquestion.id` of any entry in `answered_subquestions`. The relative
    order of `subquestions` is kept in the result.
    """
    done_ids = []
    for answered in answered_subquestions:
        done_ids.append(answered.subquestion.id)

    pending = []
    for candidate in subquestions:
        if candidate.id in done_ids:
            continue
        pending.append(candidate)
    return pending


def _is_subquestion_answerable(
    subquestion: SubQuestion, answered_subquestions: list[AnsweredSubQuestion]
) -> bool:
    """Return True when every dependency of `subquestion` has been answered.

    Parameters
    ----------
    subquestion : SubQuestion
        The subquestion to check; its `depends_on` holds the ids it needs.
    answered_subquestions : list[AnsweredSubQuestion]
        Subquestions answered so far.

    Returns
    -------
    bool
        True if `depends_on` is empty/falsy, or if each id in it matches the
        `subquestion.id` of some answered subquestion; False otherwise.
    """
    if not subquestion.depends_on:
        return True

    # Collect the answered ids once instead of rebuilding the list for every
    # dependency that is checked.
    answered_ids = [answered.subquestion.id for answered in answered_subquestions]
    return all(id_ in answered_ids for id_ in subquestion.depends_on)


def _get_answerable_subquestions(
    subquestions: list[SubQuestion], answered_subquestions: list[AnsweredSubQuestion]
) -> list[SubQuestion]:
    """Return the subquestions that `_is_subquestion_answerable` accepts.

    Each entry of `subquestions` is checked against `answered_subquestions`;
    the entries that pass are returned in their original order.
    """
    answerable = []
    for candidate in subquestions:
        if _is_subquestion_answerable(
            subquestion=candidate, answered_subquestions=answered_subquestions
        ):
            answerable.append(candidate)
    return answerable


def openbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str:
def openbb_agent(query: str, openbb_tools: list[str] | None = None) -> str:
"""Answer a query using the OpenBB Agent equipped with tools.
By default all available openbb tools are used. You can have a query
Expand All @@ -102,62 +34,98 @@ def openbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str:
----------
query : str
The query you want to have answered.
openbb_tools : optional[list[str]]
Optional. Specify the OpenBB collections or commands that you use to use. If not
openbb_tools : list[Callable]
Optional. Specify the OpenBB functions you want to use. If not
specified, every available OpenBB tool will be used.
Examples
--------
>>> # Use all OpenBB tools to answer the query
>>> openbb_agent("What is the market cap of TSLA?")
>>> openbb_agent("What is the stock price of TSLA?")
>>> # Use only the specified tools to answer the query
>>> openbb_agent("What is the market cap of TSLA?",
... openbb_tools=["/equity/fundamental", "/equity/price/historical"])
>>> openbb_agent("What is the stock price of TSLA?",
... openbb_tools=['.equity.price.quote'])
"""

tool_vector_index = _handle_tool_vector_index(openbb_tools)
subquestions = generate_subquestions_from_query(user_query=query)

logger.info("Generated subquestions: %s", subquestions)
tool_vector_index = build_openbb_tool_vector_index()

answered_subquestions = []
while unanswered_subquestions := _get_unanswered_subquestions(
answered_subquestions=answered_subquestions, subquestions=subquestions
):
logger.info("Unanswered subquestions: %s", unanswered_subquestions)
answerable_subquestions = _get_answerable_subquestions(
subquestions=unanswered_subquestions,
answered_subquestions=answered_subquestions,
)
logger.info("Answerable subquestions: %s", answerable_subquestions)
for subquestion in answerable_subquestions: # TODO: Do in parallel
for subquestion in subquestions:
if _is_subquestion_answerable(
subquestion=subquestion, answered_subquestions=answered_subquestions
):
logger.info("Answering subquestion: %s", subquestion)
answered_subquestion = _fetch_tools_and_answer_subquestion(
user_query=query,
subquestion=subquestion,
tool_vector_index=tool_vector_index,
answered_subquestions=answered_subquestions,
)
answered_subquestions.append(answered_subquestion)
else:
logger.info("Skipping unanswerable subquestion: %s", subquestion)
return generate_final_answer(
user_query=query,
answered_subquestions=answered_subquestions,
)


async def aopenbb_agent(query: str, openbb_tools: list[str] | None = None) -> str:
    """Answer a query using the OpenBB Agent equipped with tools.

    Async variant of `openbb_agent`.

    By default all available openbb tools are used. You can have a query
    answered using a smaller subset of OpenBB tools by using the `openbb_tools`
    argument.

    Parameters
    ----------
    query : str
        The query you want to have answered.
    openbb_tools : list[str] | None
        Optional. Names of the OpenBB functions you want to use. If not
        specified, every available OpenBB tool will be used.

    Returns
    -------
    str
        The value returned by `generate_final_answer` for the query and the
        answered subquestions.

    Examples
    --------
    >>> # Use all OpenBB tools to answer the query
    >>> asyncio.run(aopenbb_agent("What is the stock price of TSLA?"))
    >>> # Use only the specified tools to answer the query
    >>> asyncio.run(
    ...     aopenbb_agent(
    ...         "What is the stock price of TSLA?",
    ...         openbb_tools=['.equity.price.quote'],
    ...     )
    ... )
    """
    tool_vector_index = _handle_tool_vector_index(openbb_tools)

    subquestions = await agenerate_subquestions_from_query(user_query=query)
    answered_subquestions = await _aprocess_subquestions(
        user_query=query,
        subquestions=subquestions,
        tool_vector_index=tool_vector_index,
    )

    # NOTE(review): this final step is called without `await`, unlike the two
    # steps above -- confirm whether an async variant should be used here.
    return generate_final_answer(
        user_query=query,
        answered_subquestions=answered_subquestions,
    )


async def _process_subquestions(
user_query, subquestions, tool_vector_index
async def _aprocess_subquestions(
user_query: str, subquestions: list[SubQuestion], tool_vector_index: VectorStore
) -> list[AnsweredSubQuestion]:
answered_subquestions = []
queued_subquestions = []

tasks = []
while True:
logger.info("Loop...")
unanswered_subquestions = _get_unanswered_subquestions(
answered_subquestions=answered_subquestions, subquestions=subquestions
)
logger.info("Unanswered subquestions: %s", unanswered_subquestions)
logger.info("Pending subquestions: %s", unanswered_subquestions)

new_answerable_subquestions = _get_answerable_subquestions(
subquestions=unanswered_subquestions,
Expand All @@ -166,12 +134,12 @@ async def _process_subquestions(
logger.info("Answerable subquestions: %s", new_answerable_subquestions)

for subquestion in new_answerable_subquestions:
logger.info("Scheduling task for subquestion: %s", subquestion)
logger.info("Scheduling subquestion for answer: %s", subquestion)
# Make sure we only submit newly answerable questions (since the
# other ones have been submitted already)
if subquestion not in queued_subquestions:
task = asyncio.create_task(
_fetch_tools_and_answer_subquestion(
_afetch_tools_and_answer_subquestion(
user_query=user_query,
subquestion=subquestion,
tool_vector_index=tool_vector_index,
Expand All @@ -185,30 +153,133 @@ async def _process_subquestions(
break

done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
logger.info("Task completed...")
tasks = [task for task in tasks if not task.done()]

for task in done:
if task.exception():
logger.error("Error in task: %s", task.exception())
logger.error("Unexpected error in task: %s", task.exception())
else:
answered_subquestion = task.result()
logger.info("Finished task for subquestion: %s", answered_subquestion)
answered_subquestions.append(answered_subquestion)

tasks = [task for task in tasks if not task.done()]
return answered_subquestions


async def aopenbb_agent(query: str, openbb_tools: list[Callable] | None = None) -> str:
tool_vector_index = build_openbb_tool_vector_index()
def _fetch_tools_and_answer_subquestion(
    user_query: str,
    subquestion: SubQuestion,
    tool_vector_index: VectorStore,
    answered_subquestions: list[AnsweredSubQuestion],
) -> AnsweredSubQuestion:
    """Select tools for a subquestion, then answer it using those tools.

    Parameters
    ----------
    user_query : str
        The original user query; forwarded to the answering step.
    subquestion : SubQuestion
        The subquestion to find tools for and to answer.
    tool_vector_index : VectorStore
        The index handed to `search_tools` for tool retrieval.
    answered_subquestions : list[AnsweredSubQuestion]
        Subquestions answered so far; `get_dependencies` derives from them
        the dependencies passed to tool search and to the answering step.

    Returns
    -------
    AnsweredSubQuestion
        The value produced by `generate_subquestion_answer`.
    """
    # Log the question text itself: wrapping it in braces would build a
    # one-element set and log that set's repr instead of the question.
    logger.info("Attempting to select tools for: %s", subquestion.question)
    dependencies = get_dependencies(
        answered_subquestions=answered_subquestions, subquestion=subquestion
    )
    tools = search_tools(
        subquestion=subquestion,
        tool_vector_index=tool_vector_index,
        answered_subquestions=dependencies,
    )
    tool_names = [tool.__name__ for tool in tools]
    logger.info("Retrieved tool(s): %s", tool_names)

    # Then attempt to answer subquestion
    logger.info("Answering subquestion: %s", subquestion.question)
    answered_subquestion = generate_subquestion_answer(
        user_query=user_query,
        subquestion=subquestion,
        tools=tools,
        dependencies=dependencies,
    )

    logger.info("Answered subquestion: %s", answered_subquestion.answer)
    return answered_subquestion


async def _afetch_tools_and_answer_subquestion(
    user_query: str,
    subquestion: SubQuestion,
    tool_vector_index: VectorStore,
    answered_subquestions: list[AnsweredSubQuestion],
) -> AnsweredSubQuestion:
    """Select tools for a subquestion, then answer it, awaiting each step.

    Async counterpart of `_fetch_tools_and_answer_subquestion`: it awaits
    `asearch_tools` and `agenerate_subquestion_answer`.

    Parameters
    ----------
    user_query : str
        The original user query; forwarded to the answering step.
    subquestion : SubQuestion
        The subquestion to find tools for and to answer.
    tool_vector_index : VectorStore
        The index handed to `asearch_tools` for tool retrieval.
    answered_subquestions : list[AnsweredSubQuestion]
        Subquestions answered so far; `get_dependencies` derives from them
        the dependencies passed to tool search and to the answering step.

    Returns
    -------
    AnsweredSubQuestion
        The value produced by `agenerate_subquestion_answer`.
    """
    # Log the question text itself: wrapping it in braces would build a
    # one-element set and log that set's repr instead of the question.
    logger.info("Attempting to select tools for: %s", subquestion.question)
    dependencies = get_dependencies(
        answered_subquestions=answered_subquestions, subquestion=subquestion
    )
    tools = await asearch_tools(
        subquestion=subquestion,
        tool_vector_index=tool_vector_index,
        answered_subquestions=dependencies,
    )
    tool_names = [tool.__name__ for tool in tools]
    logger.info("Retrieved tool(s): %s", tool_names)

    # Then attempt to answer subquestion
    logger.info("Answering subquestion: %s", subquestion.question)
    answered_subquestion = await agenerate_subquestion_answer(
        user_query=user_query,
        subquestion=subquestion,
        tools=tools,
        dependencies=dependencies,
    )

    logger.info("Answered subquestion: %s", answered_subquestion.answer)
    return answered_subquestion


def _get_unanswered_subquestions(
    answered_subquestions: list[AnsweredSubQuestion], subquestions: list[SubQuestion]
) -> list[SubQuestion]:
    """Return the subquestions that do not have an answer yet.

    A subquestion counts as answered when its `id` equals the
    `subquestion.id` of any entry in `answered_subquestions`. The relative
    order of `subquestions` is kept in the result.
    """
    done_ids = []
    for answered in answered_subquestions:
        done_ids.append(answered.subquestion.id)

    pending = []
    for candidate in subquestions:
        if candidate.id in done_ids:
            continue
        pending.append(candidate)
    return pending


def _is_subquestion_answerable(
    subquestion: SubQuestion, answered_subquestions: list[AnsweredSubQuestion]
) -> bool:
    """Return True when every dependency of `subquestion` has been answered.

    Parameters
    ----------
    subquestion : SubQuestion
        The subquestion to check; its `depends_on` holds the ids it needs.
    answered_subquestions : list[AnsweredSubQuestion]
        Subquestions answered so far.

    Returns
    -------
    bool
        True if `depends_on` is empty/falsy, or if each id in it matches the
        `subquestion.id` of some answered subquestion; False otherwise.
    """
    if not subquestion.depends_on:
        return True

    # Collect the answered ids once instead of rebuilding the list for every
    # dependency that is checked.
    answered_ids = [answered.subquestion.id for answered in answered_subquestions]
    return all(id_ in answered_ids for id_ in subquestion.depends_on)


def _get_answerable_subquestions(
    subquestions: list[SubQuestion], answered_subquestions: list[AnsweredSubQuestion]
) -> list[SubQuestion]:
    """Return the subquestions that `_is_subquestion_answerable` accepts.

    Each entry of `subquestions` is checked against `answered_subquestions`;
    the entries that pass are returned in their original order.
    """
    answerable = []
    for candidate in subquestions:
        if _is_subquestion_answerable(
            subquestion=candidate, answered_subquestions=answered_subquestions
        ):
            answerable.append(candidate)
    return answerable


def _handle_tool_vector_index(openbb_tools: list[str] | None) -> VectorStore:
    """Build the tool vector index for all tools or for a named subset.

    When `openbb_tools` is falsy (None or empty), the index comes from
    `build_openbb_tool_vector_index()`. Otherwise each name is mapped through
    `map_name_to_openbb_function_description` and the index is built from
    the resulting descriptions.
    """
    if openbb_tools:
        logger.info("Using specified OpenBB tools: %s", openbb_tools)
        descriptions = list(
            map(map_name_to_openbb_function_description, openbb_tools)
        )
        return build_vector_index_from_openbb_function_descriptions(descriptions)

    logger.info("Using all available OpenBB tools.")
    return build_openbb_tool_vector_index()
Loading

0 comments on commit c28724d

Please sign in to comment.