Skip to content

Commit

Permalink
[Feature] - Warn if inexistent kwargs (#6236)
Browse files Browse the repository at this point in the history
* feat: warn if wrong kwargs

* fix: remove query test

* fix: add tests and fix bug

---------

Co-authored-by: Igor Radovanovic <74266147+IgorWounds@users.noreply.github.com>
  • Loading branch information
montezdesousa and IgorWounds authored Mar 20, 2024
1 parent bab42a0 commit 7f4007a
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 161 deletions.
208 changes: 121 additions & 87 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Command runner module."""

import warnings
from copy import deepcopy
from dataclasses import asdict, is_dataclass
from datetime import datetime
from inspect import Parameter, signature
from sys import exc_info
from time import perf_counter_ns
from typing import Any, Callable, Dict, List, Optional, Tuple
from warnings import catch_warnings, showwarning, warn

from pydantic import ConfigDict, create_model
from fastapi.params import Query
from pydantic import BaseModel, ConfigDict, create_model

from openbb_core.app.logs.logging_service import LoggingService
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.app.model.abstract.warning import cast_warning
from openbb_core.app.model.abstract.warning import OpenBBWarning, cast_warning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.metadata import Metadata
from openbb_core.app.model.obbject import OBBject
Expand Down Expand Up @@ -175,14 +177,53 @@ def _get_default_provider(

return kwargs

@staticmethod
def _warn_kwargs(
provider_choices: Dict[str, Any],
extra_params: Dict[str, Any],
model: BaseModel,
) -> None:
"""Warn if kwargs received and ignored by the validation model."""
# We only check the extra_params annotation because ignored fields
# will always be kwargs
annotation = getattr(
model.model_fields.get("extra_params", None), "annotation", None
)
if annotation:
# When there is no annotation there is nothing to warn
valid = asdict(annotation()) if is_dataclass(annotation) else {} # type: ignore
provider = provider_choices.get("provider", None)
for p in extra_params:
if field := valid.get(p):
if provider:
providers = (
field.title
if isinstance(field, Query) and isinstance(field.title, str)
else ""
).split(",")
if provider not in providers:
warn(
message=f"Parameter '{p}' is not supported by '{provider}'."
f" Available for: {', '.join(providers)}.",
category=OpenBBWarning,
)
else:
warn(
message=f"Parameter '{p}' not found.",
category=OpenBBWarning,
)

@staticmethod
def _as_dict(obj: Any) -> Dict[str, Any]:
"""Safely convert an object to a dict."""
return asdict(obj) if is_dataclass(obj) else dict(obj)

@staticmethod
def validate_kwargs(
func: Callable,
kwargs: Dict[str, Any],
) -> Dict[str, Any]:
"""Validate kwargs and if possible coerce to the correct type."""
config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

sig = signature(func)
fields = {
n: (
Expand All @@ -191,11 +232,17 @@ def validate_kwargs(
)
for n, p in sig.parameters.items()
}
# We allow extra fields to return with model with 'cc: CommandContext'
config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
ValidationModel = create_model(func.__name__, __config__=config, **fields) # type: ignore
# Validate and coerce
model = ValidationModel(**kwargs)
result = dict(model)

return result
ParametersBuilder._warn_kwargs(
ParametersBuilder._as_dict(kwargs.get("provider_choices", {})),
ParametersBuilder._as_dict(kwargs.get("extra_params", {})),
ValidationModel,
)
return dict(model)

@classmethod
def build(
Expand Down Expand Up @@ -230,7 +277,10 @@ def build(
kwargs=kwargs,
route_default=user_settings.defaults.routes.get(route, None),
)
kwargs = cls.validate_kwargs(func=func, kwargs=kwargs)
kwargs = cls.validate_kwargs(
func=func,
kwargs=kwargs,
)
return kwargs


Expand All @@ -242,30 +292,12 @@ async def _command(
cls,
func: Callable,
kwargs: Dict[str, Any],
show_warnings: bool = True,
) -> OBBject:
"""Run a command and return the output."""

with warnings.catch_warnings(record=True) as warning_list:
obbject = await maybe_coroutine(func, **kwargs)
obbject.provider = getattr(
kwargs.get("provider_choices", None), "provider", None
)

if warning_list:
obbject.warnings = []
for w in warning_list:
obbject.warnings.append(cast_warning(w))
if show_warnings:
warnings.showwarning(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
file=w.file,
line=w.line,
)

obbject = await maybe_coroutine(func, **kwargs)
obbject.provider = getattr(
kwargs.get("provider_choices", None), "provider", None
)
return obbject

@classmethod
Expand Down Expand Up @@ -294,68 +326,70 @@ async def _execute_func(
user_settings = execution_context.user_settings
system_settings = execution_context.system_settings

# If we're on Jupyter we need to pop here because we will lose "chart" after
# ParametersBuilder.build. This needs to be fixed in a way that chart is
# added to the function signature and shared for jupyter and api
# We can check in the router decorator if the given function has a chart
# in the charting extension then we add it there. This way we can remove
# the chart parameter from the commands.py and package_builder, it will be
# added to the function signature in the router decorator
chart = kwargs.pop("chart", False)
with catch_warnings(record=True) as warning_list:
# If we're on Jupyter we need to pop here because we will lose "chart" after
# ParametersBuilder.build. This needs to be fixed in a way that chart is
# added to the function signature and shared for jupyter and api
# We can check in the router decorator if the given function has a chart
# in the charting extension then we add it there. This way we can remove
# the chart parameter from the commands.py and package_builder, it will be
# added to the function signature in the router decorator
chart = kwargs.pop("chart", False)

kwargs = ParametersBuilder.build(
args=args,
execution_context=execution_context,
func=func,
route=route,
kwargs=kwargs,
)

kwargs = ParametersBuilder.build(
args=args,
execution_context=execution_context,
func=func,
route=route,
kwargs=kwargs,
)
# If we're on the api we need to remove "chart" here because the parameter is added on
# commands.py and the function signature does not expect "chart"
kwargs.pop("chart", None)
# We also pop custom headers
model_headers = system_settings.api_settings.custom_headers or {}
custom_headers = {
name: kwargs.pop(name.replace("-", "_"), default)
for name, default in model_headers.items() or {}
} or None

# If we're on the api we need to remove "chart" here because the parameter is added on
# commands.py and the function signature does not expect "chart"
kwargs.pop("chart", None)
# We also pop custom headers
try:
obbject = await cls._command(func, kwargs)
# pylint: disable=protected-access
obbject._route = route
obbject._standard_params = kwargs.get("standard_params", None)

model_headers = (
SystemService().system_settings.api_settings.custom_headers or {}
)
custom_headers = {
name: kwargs.pop(name.replace("-", "_"), default)
for name, default in model_headers.items() or {}
} or None
if chart and obbject.results:
cls._chart(obbject, **kwargs)

try:
obbject = await cls._command(
func=func,
kwargs=kwargs,
show_warnings=user_settings.preferences.show_warnings,
)
# pylint: disable=protected-access
obbject._route = route
obbject._standard_params = kwargs.get("standard_params", None)

if chart and obbject.results:
cls._chart(
obbject=obbject,
**kwargs,
except Exception as e:
raise OpenBBError(e) from e
finally:
ls = LoggingService(system_settings, user_settings)
ls.log(
user_settings=user_settings,
system_settings=system_settings,
route=route,
func=func,
kwargs=kwargs,
exec_info=exc_info(),
custom_headers=custom_headers,
)

except Exception as e:
raise OpenBBError(e) from e
finally:
ls = LoggingService(
user_settings=user_settings, system_settings=system_settings
)
ls.log(
user_settings=user_settings,
system_settings=system_settings,
route=route,
func=func,
kwargs=kwargs,
exec_info=exc_info(),
custom_headers=custom_headers,
)

if warning_list:
obbject.warnings = []
for w in warning_list:
obbject.warnings.append(cast_warning(w))
if user_settings.preferences.show_warnings:
showwarning(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
file=w.file,
line=w.line,
)
return obbject

@classmethod
Expand Down
2 changes: 0 additions & 2 deletions openbb_platform/core/openbb_core/app/model/user_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class UserSettings(Tagged):

def __repr__(self) -> str:
"""Human readable representation of the object."""
# We use the __dict__ because Credentials.model_dump() will use the serializer
# and unmask the credentials
return f"{self.__class__.__name__}\n\n" + "\n".join(
f"{k}: {v}" for k, v in self.model_dump().items()
)
45 changes: 4 additions & 41 deletions openbb_platform/core/openbb_core/app/query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Query class."""

import warnings
from dataclasses import asdict
from typing import Any, Dict
from typing import Any

from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.provider_interface import (
ExtraParams,
Expand All @@ -30,49 +28,14 @@ def __init__(
self.standard_params = standard_params
self.extra_params = extra_params
self.name = self.standard_params.__class__.__name__
self.provider_interface = ProviderInterface()

def filter_extra_params(
self,
extra_params: ExtraParams,
provider_name: str,
) -> Dict[str, Any]:
"""Filter extra params based on the provider and warn if not supported."""
original = asdict(extra_params)
filtered = {}

query = extra_params.__class__.__name__
fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore

for k, v in original.items():
f = fields[k]
providers = f.title.split(",") if hasattr(f, "title") else []
if v != f.default:
if provider_name in providers:
filtered[k] = v
else:
available = ", ".join(providers)
warnings.warn(
message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.",
category=OpenBBWarning,
)

return filtered
self.query_executor = ProviderInterface().create_executor()

async def execute(self) -> Any:
"""Execute the query."""
standard_dict = asdict(self.standard_params)
extra_dict = (
self.filter_extra_params(self.extra_params, self.provider)
if self.extra_params
else {}
)
query_executor = self.provider_interface.create_executor()

return await query_executor.execute(
return await self.query_executor.execute(
provider_name=self.provider,
model_name=self.name,
params={**standard_dict, **extra_dict},
params={**asdict(self.standard_params), **asdict(self.extra_params)},
credentials=self.cc.user_settings.credentials.model_dump(),
preferences=self.cc.user_settings.preferences.model_dump(),
)
6 changes: 1 addition & 5 deletions openbb_platform/core/openbb_core/provider/query_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,4 @@ async def execute(
filtered_credentials = self.filter_credentials(
credentials, provider, fetcher.require_credentials
)

try:
return await fetcher.fetch_data(params, filtered_credentials, **kwargs)
except Exception as e:
raise OpenBBError(e) from e
return await fetcher.fetch_data(params, filtered_credentials, **kwargs)
Loading

0 comments on commit 7f4007a

Please sign in to comment.