Skip to content

Commit

Permalink
fix: move extra_params warning to query.py (#6259)
Browse files Browse the repository at this point in the history
  • Loading branch information
montezdesousa authored Mar 27, 2024
1 parent 657fd1f commit 4b5787b
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 63 deletions.
21 changes: 1 addition & 20 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from warnings import catch_warnings, showwarning, warn

from fastapi.params import Query
from pydantic import BaseModel, ConfigDict, create_model

from openbb_core.app.logs.logging_service import LoggingService
Expand Down Expand Up @@ -179,7 +178,6 @@ def _get_default_provider(

@staticmethod
def _warn_kwargs(
provider_choices: Dict[str, Any],
extra_params: Dict[str, Any],
model: Type[BaseModel],
) -> None:
Expand All @@ -192,25 +190,9 @@ def _warn_kwargs(
if is_dataclass(annotation) and any(
t is ExtraParams for t in getattr(annotation, "__bases__", [])
):
# We only warn when endpoint defines ExtraParams, so we need
# to check if the annotation is a dataclass and child of ExtraParams
valid = asdict(annotation()) # type: ignore
provider = provider_choices.get("provider", None)
for p in extra_params:
if field := valid.get(p):
if provider:
providers = (
field.title
if isinstance(field, Query) and isinstance(field.title, str)
else ""
).split(",")
if provider not in providers:
warn(
message=f"Parameter '{p}' is not supported by '{provider}'."
f" Available for: {', '.join(providers)}.",
category=OpenBBWarning,
)
else:
if p not in valid:
warn(
message=f"Parameter '{p}' not found.",
category=OpenBBWarning,
Expand Down Expand Up @@ -246,7 +228,6 @@ def validate_kwargs(
# Validate and coerce
model = ValidationModel(**kwargs)
ParametersBuilder._warn_kwargs(
ParametersBuilder._as_dict(kwargs.get("provider_choices", {})),
ParametersBuilder._as_dict(kwargs.get("extra_params", {})),
ValidationModel,
)
Expand Down
48 changes: 44 additions & 4 deletions openbb_platform/core/openbb_core/app/query.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Query class."""

import warnings
from dataclasses import asdict
from typing import Any
from typing import Any, Dict

from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.provider_interface import (
ExtraParams,
Expand All @@ -28,14 +30,52 @@ def __init__(
self.standard_params = standard_params
self.extra_params = extra_params
self.name = self.standard_params.__class__.__name__
self.query_executor = ProviderInterface().create_executor()
self.provider_interface = ProviderInterface()

def filter_extra_params(
self,
extra_params: ExtraParams,
provider_name: str,
) -> Dict[str, Any]:
"""Filter extra params based on the provider and warn if not supported."""
original = asdict(extra_params)
filtered = {}

query = extra_params.__class__.__name__
fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore

for k, v in original.items():
f = fields[k]
providers = f.title.split(",") if hasattr(f, "title") else []

# We only filter/warn if the value is not the default, because fastapi
# Depends always sends the default value, even if it's not in the request.
if v != f.default:
if provider_name in providers:
filtered[k] = v
else:
available = ", ".join(providers)
warnings.warn(
message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.",
category=OpenBBWarning,
)

return filtered

async def execute(self) -> Any:
"""Execute the query."""
return await self.query_executor.execute(
standard_dict = asdict(self.standard_params)
extra_dict = (
self.filter_extra_params(self.extra_params, self.provider)
if self.extra_params
else {}
)
query_executor = self.provider_interface.create_executor()

return await query_executor.execute(
provider_name=self.provider,
model_name=self.name,
params={**asdict(self.standard_params), **asdict(self.extra_params)},
params={**standard_dict, **extra_dict},
credentials=self.cc.user_settings.credentials.model_dump(),
preferences=self.cc.user_settings.preferences.model_dump(),
)
45 changes: 6 additions & 39 deletions openbb_platform/core/tests/app/test_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,61 +229,28 @@ def test_parameters_builder_validate_kwargs(mock_func):


@pytest.mark.parametrize(
"provider_choices, extra_params, base, expect",
"extra_params, base, expect",
[
(
{"provider": "provider1"},
{"exists_in_2": ...},
{"exists": ...},
ExtraParams,
OpenBBWarning,
),
(
{"provider": "inexistent_provider"},
{"exists_in_both": ...},
ExtraParams,
OpenBBWarning,
None,
),
(
{},
{"inexistent_field": ...},
ExtraParams,
OpenBBWarning,
),
(
{},
{"inexistent_field": ...},
object,
None,
),
(
{"provider": "provider2"},
{"exists_in_2": ...},
ExtraParams,
None,
),
(
{"provider": "provider2"},
{"exists_in_both": ...},
ExtraParams,
None,
),
(
{},
{"exists_in_both": ...},
ExtraParams,
None,
),
],
)
def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, expect):
def test_parameters_builder__warn_kwargs(extra_params, base, expect):
"""Test _warn_kwargs."""

@dataclass
class SomeModel(base):
"""SomeModel"""

exists_in_2: QueryParam = Query(..., title="provider2")
exists_in_both: QueryParam = Query(..., title="provider1,provider2")
exists: QueryParam = Query(...)

class Model(BaseModel):
"""Model"""
Expand All @@ -293,7 +260,7 @@ class Model(BaseModel):

with pytest.warns(expect) as warning_info:
# pylint: disable=protected-access
ParametersBuilder._warn_kwargs(provider_choices, extra_params, Model)
ParametersBuilder._warn_kwargs(extra_params, Model)

if not expect:
assert len(warning_info) == 0
Expand Down
26 changes: 26 additions & 0 deletions openbb_platform/core/tests/app/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ def query_instance():
)


def test_filter_extra_params(query):
"""Test filter_extra_params."""
extra_params = create_mock_extra_params()
extra_params = query.filter_extra_params(extra_params, "fmp")

assert isinstance(extra_params, dict)
assert len(extra_params) == 0


def test_filter_extra_params_wrong_param(query):
"""Test filter_extra_params."""

@dataclass
class EquityHistorical:
"""Mock ExtraParams dataclass."""

sort: str = "desc"
limit: int = 4

extra_params = EquityHistorical()

extra = query.filter_extra_params(extra_params, "fmp")
assert isinstance(extra, dict)
assert len(extra) == 0


@pytest.mark.asyncio
async def test_execute_method_fake_credentials(query_instance: Query, mock_registry):
"""Test execute method without setting credentials."""
Expand Down

0 comments on commit 4b5787b

Please sign in to comment.