Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] - Move extra_params warning to query.py #6259

Merged
merged 2 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from warnings import catch_warnings, showwarning, warn

from fastapi.params import Query
from pydantic import BaseModel, ConfigDict, create_model

from openbb_core.app.logs.logging_service import LoggingService
Expand Down Expand Up @@ -179,7 +178,6 @@ def _get_default_provider(

@staticmethod
def _warn_kwargs(
provider_choices: Dict[str, Any],
extra_params: Dict[str, Any],
model: Type[BaseModel],
) -> None:
Expand All @@ -192,25 +190,9 @@ def _warn_kwargs(
if is_dataclass(annotation) and any(
t is ExtraParams for t in getattr(annotation, "__bases__", [])
):
# We only warn when endpoint defines ExtraParams, so we need
# to check if the annotation is a dataclass and child of ExtraParams
valid = asdict(annotation()) # type: ignore
provider = provider_choices.get("provider", None)
for p in extra_params:
if field := valid.get(p):
if provider:
providers = (
field.title
if isinstance(field, Query) and isinstance(field.title, str)
else ""
).split(",")
if provider not in providers:
warn(
message=f"Parameter '{p}' is not supported by '{provider}'."
f" Available for: {', '.join(providers)}.",
category=OpenBBWarning,
)
else:
if p not in valid:
warn(
message=f"Parameter '{p}' not found.",
category=OpenBBWarning,
Expand Down Expand Up @@ -246,7 +228,6 @@ def validate_kwargs(
# Validate and coerce
model = ValidationModel(**kwargs)
ParametersBuilder._warn_kwargs(
ParametersBuilder._as_dict(kwargs.get("provider_choices", {})),
ParametersBuilder._as_dict(kwargs.get("extra_params", {})),
ValidationModel,
)
Expand Down
48 changes: 44 additions & 4 deletions openbb_platform/core/openbb_core/app/query.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Query class."""

import warnings
from dataclasses import asdict
from typing import Any
from typing import Any, Dict

from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.provider_interface import (
ExtraParams,
Expand All @@ -28,14 +30,52 @@ def __init__(
self.standard_params = standard_params
self.extra_params = extra_params
self.name = self.standard_params.__class__.__name__
self.query_executor = ProviderInterface().create_executor()
self.provider_interface = ProviderInterface()

def filter_extra_params(
self,
extra_params: ExtraParams,
provider_name: str,
) -> Dict[str, Any]:
"""Filter extra params based on the provider and warn if not supported."""
original = asdict(extra_params)
filtered = {}

query = extra_params.__class__.__name__
fields = asdict(self.provider_interface.params[query]["extra"]()) # type: ignore

for k, v in original.items():
f = fields[k]
providers = f.title.split(",") if hasattr(f, "title") else []

# We only filter/warn if the value is not the default, because fastapi
# Depends always sends the default value, even if it's not in the request.
if v != f.default:
if provider_name in providers:
filtered[k] = v
else:
available = ", ".join(providers)
warnings.warn(
message=f"Parameter '{k}' is not supported by {provider_name}. Available for: {available}.",
category=OpenBBWarning,
)

return filtered

async def execute(self) -> Any:
"""Execute the query."""
return await self.query_executor.execute(
standard_dict = asdict(self.standard_params)
extra_dict = (
self.filter_extra_params(self.extra_params, self.provider)
if self.extra_params
else {}
)
query_executor = self.provider_interface.create_executor()

return await query_executor.execute(
provider_name=self.provider,
model_name=self.name,
params={**asdict(self.standard_params), **asdict(self.extra_params)},
params={**standard_dict, **extra_dict},
credentials=self.cc.user_settings.credentials.model_dump(),
preferences=self.cc.user_settings.preferences.model_dump(),
)
45 changes: 6 additions & 39 deletions openbb_platform/core/tests/app/test_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,61 +229,28 @@ def test_parameters_builder_validate_kwargs(mock_func):


@pytest.mark.parametrize(
"provider_choices, extra_params, base, expect",
"extra_params, base, expect",
[
(
{"provider": "provider1"},
{"exists_in_2": ...},
{"exists": ...},
ExtraParams,
OpenBBWarning,
),
(
{"provider": "inexistent_provider"},
{"exists_in_both": ...},
ExtraParams,
OpenBBWarning,
None,
),
(
{},
{"inexistent_field": ...},
ExtraParams,
OpenBBWarning,
),
(
{},
{"inexistent_field": ...},
object,
None,
),
(
{"provider": "provider2"},
{"exists_in_2": ...},
ExtraParams,
None,
),
(
{"provider": "provider2"},
{"exists_in_both": ...},
ExtraParams,
None,
),
(
{},
{"exists_in_both": ...},
ExtraParams,
None,
),
],
)
def test_parameters_builder__warn_kwargs(provider_choices, extra_params, base, expect):
def test_parameters_builder__warn_kwargs(extra_params, base, expect):
"""Test _warn_kwargs."""

@dataclass
class SomeModel(base):
"""SomeModel"""

exists_in_2: QueryParam = Query(..., title="provider2")
exists_in_both: QueryParam = Query(..., title="provider1,provider2")
exists: QueryParam = Query(...)

class Model(BaseModel):
"""Model"""
Expand All @@ -293,7 +260,7 @@ class Model(BaseModel):

with pytest.warns(expect) as warning_info:
# pylint: disable=protected-access
ParametersBuilder._warn_kwargs(provider_choices, extra_params, Model)
ParametersBuilder._warn_kwargs(extra_params, Model)

if not expect:
assert len(warning_info) == 0
Expand Down
26 changes: 26 additions & 0 deletions openbb_platform/core/tests/app/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ def query_instance():
)


def test_filter_extra_params(query):
"""Test filter_extra_params."""
extra_params = create_mock_extra_params()
extra_params = query.filter_extra_params(extra_params, "fmp")

assert isinstance(extra_params, dict)
assert len(extra_params) == 0


def test_filter_extra_params_wrong_param(query):
"""Test filter_extra_params."""

@dataclass
class EquityHistorical:
"""Mock ExtraParams dataclass."""

sort: str = "desc"
limit: int = 4

extra_params = EquityHistorical()

extra = query.filter_extra_params(extra_params, "fmp")
assert isinstance(extra, dict)
assert len(extra) == 0


@pytest.mark.asyncio
async def test_execute_method_fake_credentials(query_instance: Query, mock_registry):
"""Test execute method without setting credentials."""
Expand Down
Loading