Skip to content

Commit

Permalink
[Enhancement] Convert Params Models To Dictionary Before Assigning As…
Browse files Browse the repository at this point in the history
… Private Attribute In OBBject. (#6492)

* convert params models to dict

* update cli test

* fix test_static_command_runner_chart

* mock_obbject results

* minor linting adjustments

* review changes

* charting test mock obbject

---------

Co-authored-by: Henrique Joaquim <henriquecjoaquim@gmail.com>
  • Loading branch information
deeleeramone and hjoaquim authored Jun 13, 2024
1 parent 99d2256 commit 8b9f461
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 55 deletions.
8 changes: 5 additions & 3 deletions cli/openbb_cli/argparse_translator/obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def all(self) -> Dict[int, Dict]:
def _handle_standard_params(obbject: OBBject) -> str:
"""Handle standard params for obbjects"""
standard_params_json = ""
std_params = obbject._standard_params # pylint: disable=protected-access
if hasattr(std_params, "__dict__"):
std_params = getattr(
obbject, "_standard_params", {}
) # pylint: disable=protected-access
if std_params:
standard_params = {
k: str(v)[:30] for k, v in std_params.__dict__.items() if v
k: str(v)[:30] for k, v in std_params.items() if v and k != "data"
}
standard_params_json = json.dumps(standard_params)

Expand Down
2 changes: 1 addition & 1 deletion cli/tests/test_argparse_translator_obbject_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def model_json_schema(self):
obb.extra = {"command": "test_command"}
obb._route = "/test/route"
obb._standard_params = Mock()
obb._standard_params.__dict__ = {}
obb._standard_params = {}
obb.results = [MockModel(1), MockModel(2)]
return obb

Expand Down
39 changes: 26 additions & 13 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,20 @@ def _chart(
raise OpenBBError(
"Charting is not installed. Please install `openbb-charting`."
)
# Here we will pop the chart_params kwargs and flatten them into the kwargs.
chart_params = {}
extra_params = kwargs.get("extra_params", {})
extra_params = getattr(obbject, "_extra_params", {})

if hasattr(extra_params, "__dict__") and hasattr(
extra_params, "chart_params"
):
chart_params = kwargs["extra_params"].__dict__.get("chart_params", {})
elif isinstance(extra_params, dict) and "chart_params" in extra_params:
chart_params = kwargs["extra_params"].get("chart_params", {})
if extra_params and "chart_params" in extra_params:
chart_params = extra_params.get("chart_params", {})

if "chart_params" in kwargs and kwargs["chart_params"] is not None:
if kwargs.get("chart_params"):
chart_params.update(kwargs.pop("chart_params", {}))

# Verify that kwargs is not nested as kwargs so we don't miss any chart params.
if (
"kwargs" in kwargs
and "chart_params" in kwargs["kwargs"]
and kwargs["kwargs"].get("chart_params") is not None
and kwargs["kwargs"].get("chart_params")
):
chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {}))

Expand All @@ -265,6 +262,14 @@ def _chart(
raise OpenBBError(e) from e
warn(str(e), OpenBBWarning)

@classmethod
def _extract_params(cls, kwargs, key) -> Dict:
"""Extract params models from kwargs and convert to a dictionary."""
params = kwargs.get(key, {})
if hasattr(params, "__dict__"):
return params.__dict__
return params

# pylint: disable=R0913, R0914
@classmethod
async def _execute_func(
Expand Down Expand Up @@ -308,9 +313,17 @@ async def _execute_func(

try:
obbject = await cls._command(func, kwargs)
# pylint: disable=protected-access
obbject._route = route
obbject._standard_params = kwargs.get("standard_params", None)

# This section prepares the obbject to pass to the charting service.
obbject._route = route # pylint: disable=protected-access
std_params = cls._extract_params(kwargs, "standard_params") or (
kwargs if "data" in kwargs else {}
)
extra_params = cls._extract_params(kwargs, "extra_params")
obbject._standard_params = ( # pylint: disable=protected-access
std_params
)
obbject._extra_params = extra_params # pylint: disable=protected-access
if chart and obbject.results:
cls._chart(obbject, **kwargs)
finally:
Expand Down
3 changes: 3 additions & 0 deletions openbb_platform/core/openbb_core/app/model/obbject.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class OBBject(Tagged, Generic[T]):
_standard_params: Optional[Dict[str, Any]] = PrivateAttr(
default_factory=dict,
)
_standard_params: Optional[Dict[str, Any]] = PrivateAttr(
default_factory=dict,
)

def __repr__(self) -> str:
"""Human readable representation of the object."""
Expand Down
12 changes: 10 additions & 2 deletions openbb_platform/core/tests/app/test_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from openbb_core.app.model.abstract.warning import OpenBBWarning
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.provider_interface import ExtraParams
Expand Down Expand Up @@ -364,8 +365,15 @@ def __init__(self, results):

def test_static_command_runner_chart():
"""Test _chart method when charting is in obbject.accessors."""
mock_obbject = Mock()
mock_obbject.accessors = ["charting"]

mock_obbject = OBBject(
results=[
{"date": "1990", "value": 100},
{"date": "1991", "value": 200},
{"date": "1992", "value": 300},
],
accessors={"charting": Mock()},
)
mock_obbject.charting.show = Mock()

StaticCommandRunner._chart(mock_obbject) # pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,23 +325,20 @@ def show(self, render: bool = True, **kwargs):
charting_function = self._get_chart_function(
self._obbject._route # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access

This comment has been minimized.

Copy link
@hjoaquim

hjoaquim Jun 20, 2024

Author Contributor

@deeleeramone this issue with the charting extension was introduced here. this one slipped from our attention

kwargs["charting_settings"] = (
self._charting_settings
) # pylint: disable=protected-access
kwargs["standard_params"] = (
self._obbject._standard_params
) # pylint: disable=protected-access
kwargs["extra_params"] = (
self._obbject._extra_params
) # pylint: disable=protected-access
kwargs["provider"] = (
self._obbject.provider
) # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access

if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
fig, content = charting_function(**kwargs)
fig = self._set_chart_style(fig)
content = fig.show(external=True, **kwargs).to_plotly_json()
Expand Down Expand Up @@ -448,24 +445,18 @@ def to_chart(
kwargs["symbol"] = symbol
kwargs["target"] = target
kwargs["index"] = index
kwargs["obbject_item"] = self._obbject.results
kwargs["charting_settings"] = self._charting_settings
if (
hasattr(self._obbject, "_standard_params")
and self._obbject._standard_params # pylint: disable=protected-access
):
kwargs["standard_params"] = (
self._obbject._standard_params.__dict__ # pylint: disable=protected-access
)
kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access
kwargs["charting_settings"] = (
self._charting_settings
) # pylint: disable=protected-access
kwargs["standard_params"] = (
self._obbject._standard_params
) # pylint: disable=protected-access
kwargs["extra_params"] = (
self._obbject._extra_params
) # pylint: disable=protected-access
kwargs["provider"] = self._obbject.provider # pylint: disable=protected-access
kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access
metadata = kwargs["extra"].get("metadata")
kwargs["extra_params"] = (
metadata.arguments.get("extra_params") if metadata else None
)
if "kwargs" in kwargs:
_kwargs = kwargs.pop("kwargs")
kwargs.update(_kwargs.get("chart_params", {}))
try:
if has_data:
self.show(data=data_as_df, render=render, **kwargs)
Expand All @@ -488,7 +479,7 @@ def to_chart(

def _set_chart_style(self, figure: Figure):
"""Set the user preference for light or dark mode."""
style = self._charting_settings.chart_style # pylint: disable=protected-access
style = self._charting_settings.chart_style
font_color = "black" if style == "light" else "white"
paper_bgcolor = "white" if style == "light" else "black"
figure = figure.update_layout(
Expand All @@ -498,7 +489,7 @@ def _set_chart_style(self, figure: Figure):
)
return figure

def toggle_chart_style(self): # pylint: disable=protected-access
def toggle_chart_style(self):
"""Toggle the chart style between light and dark mode."""
if not hasattr(self._obbject.chart, "fig"):
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,26 @@ def __init__(self):

@pytest.fixture()
def obbject():
"""Mock OOBject."""
"""Mock OBBject."""

class MockStdParams(BaseModel):
"""Mock Standard Parameters."""

param1: str
param2: str

class MockOOBject:
"""Mock OOBject."""
class MockOBBject:
"""Mock OBBject."""

def __init__(self):
"""Mock OOBject."""
"""Mock OBBject."""
self._user_settings = UserSettings()
self._system_settings = SystemSettings()
self._route = "mock/route"
self._standard_params = MockStdParams(
param1="mock_param1", param2="mock_param2"
)
self._extra_params = {}
self.results = "mock_results"

self.provider = "mock_provider"
Expand All @@ -54,7 +55,7 @@ def to_dataframe(self):
"""Mock to_dataframe."""
return mock_dataframe

return MockOOBject()
return MockOBBject()


def test_charting_settings(obbject):
Expand Down

0 comments on commit 8b9f461

Please sign in to comment.