Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Convert Params Models To Dictionary Before Assigning As Private Attribute In OBBject. #6492

Merged
merged 11 commits into from
Jun 13, 2024
8 changes: 5 additions & 3 deletions cli/openbb_cli/argparse_translator/obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def all(self) -> Dict[int, Dict]:
def _handle_standard_params(obbject: OBBject) -> str:
"""Handle standard params for obbjects"""
standard_params_json = ""
std_params = obbject._standard_params # pylint: disable=protected-access
if hasattr(std_params, "__dict__"):
std_params = getattr(
obbject, "_standard_params", {}
) # pylint: disable=protected-access
if std_params:
standard_params = {
k: str(v)[:30] for k, v in std_params.__dict__.items() if v
k: str(v)[:30] for k, v in std_params.items() if v and k != "data"
}
standard_params_json = json.dumps(standard_params)

Expand Down
2 changes: 1 addition & 1 deletion cli/tests/test_argparse_translator_obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def model_json_schema(self):
obb.extra = {"command": "test_command"}
obb._route = "/test/route"
obb._standard_params = Mock()
obb._standard_params.__dict__ = {}
obb._standard_params = {}
obb.results = [MockModel(1), MockModel(2)]
return obb

Expand Down
39 changes: 26 additions & 13 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,20 @@ def _chart(
raise OpenBBError(
"Charting is not installed. Please install `openbb-charting`."
)
# Here we will pop the chart_params kwargs and flatten them into the kwargs.
chart_params = {}
extra_params = kwargs.get("extra_params", {})
extra_params = getattr(obbject, "_extra_params", {})
deeleeramone marked this conversation as resolved.
Show resolved Hide resolved

if hasattr(extra_params, "__dict__") and hasattr(
extra_params, "chart_params"
):
chart_params = kwargs["extra_params"].__dict__.get("chart_params", {})
elif isinstance(extra_params, dict) and "chart_params" in extra_params:
chart_params = kwargs["extra_params"].get("chart_params", {})
if extra_params and "chart_params" in extra_params:
chart_params = extra_params.get("chart_params", {})

if "chart_params" in kwargs and kwargs["chart_params"] is not None:
if kwargs.get("chart_params"):
chart_params.update(kwargs.pop("chart_params", {}))

# Verify that kwargs is not nested as kwargs so we don't miss any chart params.
if (
"kwargs" in kwargs
and "chart_params" in kwargs["kwargs"]
and kwargs["kwargs"].get("chart_params") is not None
and kwargs["kwargs"].get("chart_params")
):
chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {}))

Expand All @@ -265,6 +262,14 @@ def _chart(
raise OpenBBError(e) from e
warn(str(e), OpenBBWarning)

@classmethod
def _extract_params(cls, kwargs, key) -> Dict:
"""Extract params models from kwargs and convert to a dictionary."""
params = kwargs.get(key, {})
if hasattr(params, "__dict__"):
return params.__dict__
return params

# pylint: disable=R0913, R0914
@classmethod
async def _execute_func(
Expand Down Expand Up @@ -308,9 +313,17 @@ async def _execute_func(

try:
obbject = await cls._command(func, kwargs)
# pylint: disable=protected-access
obbject._route = route
obbject._standard_params = kwargs.get("standard_params", None)

# This section prepares the obbject to pass to the charting service.
obbject._route = route # pylint: disable=protected-access
std_params = cls._extract_params(kwargs, "standard_params") or (
kwargs if "data" in kwargs else {}
)
extra_params = cls._extract_params(kwargs, "extra_params")
obbject._standard_params = ( # pylint: disable=protected-access
std_params
)
obbject._extra_params = extra_params # pylint: disable=protected-access
if chart and obbject.results:
cls._chart(obbject, **kwargs)
finally:
Expand Down
3 changes: 3 additions & 0 deletions openbb_platform/core/openbb_core/app/model/obbject.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class OBBject(Tagged, Generic[T]):
_standard_params: Optional[Dict[str, Any]] = PrivateAttr(
default_factory=dict,
)
_standard_params: Optional[Dict[str, Any]] = PrivateAttr(
default_factory=dict,
)

def __repr__(self) -> str:
"""Human readable representation of the object."""
Expand Down
12 changes: 10 additions & 2 deletions openbb_platform/core/tests/app/test_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams
Expand Down Expand Up @@ -364,8 +365,15 @@ def __init__(self, results):

def test_static_command_runner_chart():
"""Test _chart method when charting is in obbject.accessors."""
mock_obbject = Mock()
mock_obbject.accessors = ["charting"]

mock_obbject = OBBject(
results=[
{"date": "1990", "value": 100},
{"date": "1991", "value": 200},
{"date": "1992", "value": 300},
],
accessors={"charting": Mock()},
)
mock_obbject.charting.show = Mock()

StaticCommandRunner._chart(mock_obbject) # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,23 +325,20 @@ def show(self, render: bool = True, **kwargs):
charting_function = self._get_chart_function(
self._obbject._route # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access
kwargs["charting_settings"] = (
self._charting_settings
) # pylint: disable=protected-access
kwargs["standard_params"] = (
self._obbject._standard_params
) # pylint: disable=protected-access
kwargs["extra_params"] = (
self._obbject._extra_params
) # pylint: disable=protected-access
kwargs["provider"] = (
self._obbject.provider
) # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access

if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
fig, content = charting_function(**kwargs)
fig = self._set_chart_style(fig)
content = fig.show(external=True, **kwargs).to_plotly_json()
Expand Down Expand Up @@ -448,24 +445,18 @@ def to_chart(
kwargs["symbol"] = symbol
kwargs["target"] = target
kwargs["index"] = index
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access
kwargs["charting_settings"] = (
self._charting_settings
) # pylint: disable=protected-access
kwargs["standard_params"] = (
self._obbject._standard_params
) # pylint: disable=protected-access
kwargs["extra_params"] = (
self._obbject._extra_params
) # pylint: disable=protected-access
kwargs["provider"] = self._obbject.provider # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access
metadata = kwargs["extra"].get("metadata")
kwargs["extra_params"] = (
metadata.arguments.get("extra_params") if metadata else None
)
if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
try:
if has_data:
self.show(data=data_as_df, render=render, **kwargs)
Expand All @@ -488,7 +479,7 @@ def to_chart(

def _set_chart_style(self, figure: Figure):
"""Set the user preference for light or dark mode."""
style = self._charting_settings.chart_style # pylint: disable=protected-access
style = self._charting_settings.chart_style
font_color = "black" if style == "light" else "white"
paper_bgcolor = "white" if style == "light" else "black"
figure = figure.update_layout(
Expand All @@ -498,7 +489,7 @@ def _set_chart_style(self, figure: Figure):
)
return figure

def toggle_chart_style(self): # pylint: disable=protected-access
def toggle_chart_style(self):
"""Toggle the chart style between light and dark mode."""
if not hasattr(self._obbject.chart, "fig"):
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,26 @@ def __init__(self):

@pytest.fixture()
def obbject():
"""Mock OOBject."""
"""Mock OBBject."""

class MockStdParams(BaseModel):
"""Mock Standard Parameters."""

param1: str
param2: str

class MockOOBject:
"""Mock OOBject."""
class MockOBBject:
"""Mock OBBject."""

def __init__(self):
"""Mock OOBject."""
"""Mock OBBject."""
self._user_settings = UserSettings()
self._system_settings = SystemSettings()
self._route = "mock/route"
self._standard_params = MockStdParams(
param1="mock_param1", param2="mock_param2"
)
self._extra_params = {}
self.results = "mock_results"

self.provider = "mock_provider"
Expand All @@ -54,7 +55,7 @@ def to_dataframe(self):
"""Mock to_dataframe."""
return mock_dataframe

return MockOOBject()
return MockOBBject()


def test_charting_settings(obbject):
Expand Down
Loading