Commit

WIP: Use openbb coverage to retrieve tools
mnicstruwig committed May 7, 2024
1 parent 66d0fa5 commit 3762a9c
Showing 5 changed files with 807 additions and 1,273 deletions.
287 changes: 73 additions & 214 deletions openbb_agents/tools.py
@@ -1,19 +1,27 @@
"""Load OpenBB functions at OpenAI tools for function calling in Langchain"""
import inspect
from functools import wraps
from types import ModuleType
from typing import Callable, List, Union
from typing import Any

import tiktoken
from langchain.schema import Document
from langchain.tools import StructuredTool
from langchain.tools.base import ToolException
from langchain_community.vectorstores import FAISS, VectorStore
from langchain_openai import OpenAIEmbeddings
from openbb import obb
from pydantic.v1 import ValidationError, create_model
from pydantic.v1.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from pydantic import BaseModel


def enable_openbb_llm_mode():
    from openbb import obb

    obb.user.preferences.output_type = "llm"  # type: ignore
    obb.system.python_settings.docstring_sections = ["description", "examples"]  # type: ignore
    obb.system.python_settings.docstring_max_length = 1024  # type: ignore

    import openbb

    openbb.build()


enable_openbb_llm_mode()


def create_tool_index(tools: list[StructuredTool]) -> VectorStore:
@@ -27,212 +35,63 @@ def create_tool_index(tools: list[StructuredTool]) -> VectorStore:
    return vector_store
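# Illustrative sketch (not part of this commit): one way the tool index built
# by create_tool_index might be queried. It assumes each indexed Document's
# page_content holds the tool description and that the tool name is stored in
# its metadata; the helper name and query below are hypothetical.
def _example_search_tools(vector_store: VectorStore, query: str) -> list[str]:
    # Retrieve the documents most similar to the natural-language query.
    docs = vector_store.similarity_search(query, k=4)
    return [doc.metadata.get("name", doc.page_content) for doc in docs]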


def _fetch_obb_module(openbb_command_root: str) -> ModuleType:
    module_path_split = openbb_command_root.split("/")[1:]
    module_path = ".".join(module_path_split)

    # Iteratively get module
    module = obb
    for attr in module_path.split("."):
        module = getattr(module, attr)

    return module


def _fetch_schemas(openbb_command_root: str) -> dict:
    # Ugly hack to make it compatible with the look-up (even though we convert
    # it back) so that we have a nicer API for the user.
    module_root_path = openbb_command_root.replace("/", ".")
    schemas = {
        k.replace(".", "/"): v
        for k, v in obb.coverage.command_model.items()
        if module_root_path in k
    }
    return schemas


def _fetch_callables(openbb_command_root):
    module = _fetch_obb_module(openbb_command_root)

    if inspect.ismethod(
        module
    ):  # Handle case where a final command endpoint is passed.
        members_dict = {module.__name__: module}
    else:  # If a command root is passed instead
        members = inspect.getmembers(module)
        members_dict = {
            x[0]: x[1] for x in members if "__" not in x[0] and "_run" not in x[0]
        }

    schemas = _fetch_schemas(openbb_command_root)
    # Create callables dict, with the same key as used in the schemas
    callables = {}
    for k in schemas.keys():
        try:
            callables[k] = members_dict[k.split("/")[-1]]
        except (
            KeyError
        ):  # Sometimes we don't have a specific callable for an endpoint, so we skip.
            pass
    return callables


def _fetch_outputs(schema):
    outputs = []
    output_fields = schema["openbb"]["Data"]["fields"]
    for name, t in output_fields.items():
        if isinstance(t.annotation, type):
            type_str = t.annotation.__name__
        else:
            type_str = str(t.annotation).replace("typing.", "")
        outputs.append((name, type_str))
    return outputs


def from_schema_to_pydantic_model(model_name, schema):
    create_model_kwargs = {}
    for field, field_info in schema.items():
        field_type = field_info.annotation

        # Handle default values
        if not isinstance(field_info.default, PydanticUndefinedType):
            field_default_value = field_info.default
            new_field_info = (
                FieldInfo(  # Weird hack, because of how the default field value works
                    description=field_info.description,
                    default=field_default_value,
                )
            )
        else:
            new_field_info = FieldInfo(
                description=field_info.description,
            )
        create_model_kwargs[field] = (field_type, new_field_info)
    return create_model(model_name, **create_model_kwargs)


def create_document(dict):
    ...


class OpenBBFunctionDescription(BaseModel):
    name: str
    input: Any
    output: Any
    callable: Any


def get_openbb_coverage_providers() -> dict:
    return obb.coverage.providers  # type: ignore


def get_openbb_user_credentials() -> dict:
    return obb.user.credentials.model_dump()  # type: ignore


def get_openbb_coverage_command_schemas() -> dict:
    return obb.coverage.command_schemas()  # type: ignore


def get_valid_list_of_providers() -> list[str]:
    credentials = get_openbb_user_credentials()
    valid_providers = []
    for name, value in credentials.items():
        if value is not None:
            valid_providers.append(name.split("_api_key")[0].split("_token")[0])
    return valid_providers


def get_valid_openbb_function_names() -> list[str]:
    valid_providers = get_valid_list_of_providers()
    valid_function_names = set()
    for provider in valid_providers:
        valid_function_names |= set(get_openbb_coverage_providers()[provider])
    return sorted(list(valid_function_names))


def get_valid_openbb_function_descriptions() -> list[OpenBBFunctionDescription]:
    command_schemas = get_openbb_coverage_command_schemas()
    obb_function_descriptions = []
    for obb_function_name in get_valid_openbb_function_names():
        dict_ = command_schemas[obb_function_name]
        obb_function_descriptions.append(
            OpenBBFunctionDescription(
                name=obb_function_name,
                input=dict_["input"],
                output=dict_["output"],
                callable=dict_["callable"],
            )
        )
    return obb_function_descriptions
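# Illustrative sketch (not part of this commit): one way the coverage helpers
# above might be chained together. The helper name below is hypothetical.
def _example_list_usable_function_names() -> list[str]:
    # Only providers with configured credentials contribute function names.
    descriptions = get_valid_openbb_function_descriptions()
    return [d.name for d in descriptions]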


def return_results(func):
    """Return the results rather than the OBBject."""

    def wrapper_func(*args, **kwargs):
        try:
            result = func(*args, **kwargs).results
            encoding = tiktoken.encoding_for_model("gpt-4-1106-preview")
            num_tokens = len(encoding.encode(str(result)))
            if num_tokens > 90000:
                raise ToolException(
                    "The returned output is too large to fit into context. Consider using another tool, or trying again with different input arguments."  # noqa: E501
                )
            return result
        # Necessary to catch general exception in this case, since we want the
        # LLM to be able to correct a bad call, if possible.
        except Exception as err:
            raise ToolException(err) from err

    return wrapper_func


def from_openbb_to_langchain_func(
    openbb_command_root: str, openbb_callable: Callable, openbb_schema: dict
) -> StructuredTool:
    func_schema = openbb_schema["openbb"]["QueryParams"]["fields"]
    # Lookup the default provider's input arguments...
    default_provider = obb.coverage.commands[openbb_command_root.replace("/", ".")][0]
    # ... and add them to the func schema.
    func_schema.update(openbb_schema[default_provider]["QueryParams"]["fields"])
    pydantic_model = from_schema_to_pydantic_model(
        model_name=f"{openbb_command_root}InputModel", schema=func_schema
    )

    outputs = _fetch_outputs(openbb_schema)
    description = openbb_callable.__doc__.split("\n")[0]
    description += "\nThe following data is available in the output:\n\n"
    description += ", ".join(e[0].replace("_", " ") for e in outputs)

    tool = StructuredTool(
        name=openbb_command_root,  # We use the command root for the name of the tool
        func=return_results(openbb_callable),
        description=description,
        args_schema=pydantic_model,
        handle_tool_error=True,
    )

    # We have to do some magic here to prevent a bad input argument from
    # breaking the langchain flow
    # https://github.com/langchain-ai/langchain/issues/13662#issuecomment-1831242057
    def handle_validation_error(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except ValidationError as err:
                return str(err)

        return wrapper

    # Monkey-patch the run method
    object.__setattr__(tool, "run", handle_validation_error(tool.run))

    return tool


def map_openbb_functions_to_langchain_tools(
    openbb_command_root, schemas_dict, callables_dict
):
    tools = []
    for route in callables_dict.keys():
        tool = from_openbb_to_langchain_func(
            openbb_command_root=route,
            openbb_callable=callables_dict[route],
            openbb_schema=schemas_dict[route],
        )
        tools.append(tool)
    return tools


def map_openbb_routes_to_langchain_tools(
    openbb_commands_root: Union[str, List[str]],
) -> list[StructuredTool]:
    """Map a collection of OpenBB callables from a command root to StructuredTools.

    Examples
    --------
    >>> fundamental_tools = map_openbb_collection_to_langchain_tools(
    ...     "/equity/fundamental"
    ... )
    >>> crypto_price_tools = map_openbb_collection_to_langchain_tools(
    ...     "/crypto/price"
    ... )
    """
    openbb_commands_root_list = (
        [openbb_commands_root]
        if isinstance(openbb_commands_root, str)
        else openbb_commands_root
    )

    tools: List = []
    for obb_cmd_root in openbb_commands_root_list:
        schemas = _fetch_schemas(obb_cmd_root)
        callables = _fetch_callables(obb_cmd_root)
        tools += map_openbb_functions_to_langchain_tools(
            openbb_command_root=obb_cmd_root,
            schemas_dict=schemas,
            callables_dict=callables,
        )
    return tools


def get_all_openbb_tools():
    tool_routes = list(obb.coverage.commands.keys())
    tool_routes = [
        route.replace(".", "/") for route in tool_routes if "metrics" not in route
    ]

    tools = []
    for route in tool_routes:
        schema = _fetch_schemas(route)
        callables = _fetch_callables(route)
        tools += map_openbb_functions_to_langchain_tools(route, schema, callables)
    return tools


def make_vector_index_description(
    openbb_function_description: OpenBBFunctionDescription,
) -> str:
    ...
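Both create_document and make_vector_index_description are still stubs in this WIP commit. Purely as a sketch of one possible direction (not the author's implementation), the index description could fold a function's name, inputs, and outputs into a single string for embedding into the tool index; the formatting below is an assumption:

def make_vector_index_description(
    openbb_function_description: OpenBBFunctionDescription,
) -> str:
    # Sketch only: concatenate name, input schema, and output schema into one
    # searchable string for the vector index.
    d = openbb_function_description
    return f"{d.name}\ninput: {d.input}\noutput: {d.output}"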