diff --git a/cli/openbb_cli/argparse_translator/obbject_registry.py b/cli/openbb_cli/argparse_translator/obbject_registry.py index 372254b4b545..aa0f876942bf 100644 --- a/cli/openbb_cli/argparse_translator/obbject_registry.py +++ b/cli/openbb_cli/argparse_translator/obbject_registry.py @@ -51,10 +51,12 @@ def all(self) -> Dict[int, Dict]: def _handle_standard_params(obbject: OBBject) -> str: """Handle standard params for obbjects""" standard_params_json = "" - std_params = obbject._standard_params # pylint: disable=protected-access - if hasattr(std_params, "__dict__"): + std_params = getattr( + obbject, "_standard_params", {} + ) # pylint: disable=protected-access + if std_params: standard_params = { - k: str(v)[:30] for k, v in std_params.__dict__.items() if v + k: str(v)[:30] for k, v in std_params.items() if v and k != "data" } standard_params_json = json.dumps(standard_params) diff --git a/cli/tests/test_argparse_translator_obbject_registry.py b/cli/tests/test_argparse_translator_obbject_registry.py index a37e5a335415..53a4a1cea80f 100644 --- a/cli/tests/test_argparse_translator_obbject_registry.py +++ b/cli/tests/test_argparse_translator_obbject_registry.py @@ -35,7 +35,7 @@ def model_json_schema(self): obb.extra = {"command": "test_command"} obb._route = "/test/route" obb._standard_params = Mock() - obb._standard_params.__dict__ = {} + obb._standard_params = {} obb.results = [MockModel(1), MockModel(2)] return obb diff --git a/openbb_platform/core/openbb_core/app/command_runner.py b/openbb_platform/core/openbb_core/app/command_runner.py index e3dce19f4650..edfe22436ccf 100644 --- a/openbb_platform/core/openbb_core/app/command_runner.py +++ b/openbb_platform/core/openbb_core/app/command_runner.py @@ -237,23 +237,20 @@ def _chart( raise OpenBBError( "Charting is not installed. Please install `openbb-charting`." ) + # Here we will pop the chart_params kwargs and flatten them into the kwargs. chart_params = {} - extra_params = kwargs.get("extra_params", {}) + extra_params = getattr(obbject, "_extra_params", {}) - if hasattr(extra_params, "__dict__") and hasattr( - extra_params, "chart_params" - ): - chart_params = kwargs["extra_params"].__dict__.get("chart_params", {}) - elif isinstance(extra_params, dict) and "chart_params" in extra_params: - chart_params = kwargs["extra_params"].get("chart_params", {}) + if extra_params and "chart_params" in extra_params: + chart_params = extra_params.get("chart_params", {}) - if "chart_params" in kwargs and kwargs["chart_params"] is not None: + if kwargs.get("chart_params"): chart_params.update(kwargs.pop("chart_params", {})) - + # Verify that kwargs is not nested as kwargs so we don't miss any chart params. if ( "kwargs" in kwargs and "chart_params" in kwargs["kwargs"] - and kwargs["kwargs"].get("chart_params") is not None + and kwargs["kwargs"].get("chart_params") ): chart_params.update(kwargs.pop("kwargs", {}).get("chart_params", {})) @@ -265,6 +262,14 @@ def _chart( raise OpenBBError(e) from e warn(str(e), OpenBBWarning) + @classmethod + def _extract_params(cls, kwargs, key) -> Dict: + """Extract params models from kwargs and convert to a dictionary.""" + params = kwargs.get(key, {}) + if hasattr(params, "__dict__"): + return params.__dict__ + return params + # pylint: disable=R0913, R0914 @classmethod async def _execute_func( @@ -308,9 +313,17 @@ async def _execute_func( try: obbject = await cls._command(func, kwargs) - # pylint: disable=protected-access - obbject._route = route - obbject._standard_params = kwargs.get("standard_params", None) + + # This section prepares the obbject to pass to the charting service. + obbject._route = route # pylint: disable=protected-access + std_params = cls._extract_params(kwargs, "standard_params") or ( + kwargs if "data" in kwargs else {} + ) + extra_params = cls._extract_params(kwargs, "extra_params") + obbject._standard_params = ( # pylint: disable=protected-access + std_params + ) + obbject._extra_params = extra_params # pylint: disable=protected-access if chart and obbject.results: cls._chart(obbject, **kwargs) finally: diff --git a/openbb_platform/core/openbb_core/app/model/obbject.py b/openbb_platform/core/openbb_core/app/model/obbject.py index 67f41e9d15d8..75078ff2919b 100644 --- a/openbb_platform/core/openbb_core/app/model/obbject.py +++ b/openbb_platform/core/openbb_core/app/model/obbject.py @@ -72,6 +72,9 @@ class OBBject(Tagged, Generic[T]): _standard_params: Optional[Dict[str, Any]] = PrivateAttr( default_factory=dict, ) + _standard_params: Optional[Dict[str, Any]] = PrivateAttr( + default_factory=dict, + ) def __repr__(self) -> str: """Human readable representation of the object.""" diff --git a/openbb_platform/core/tests/app/test_command_runner.py b/openbb_platform/core/tests/app/test_command_runner.py index 41205ca75e5d..7c20059c1002 100644 --- a/openbb_platform/core/tests/app/test_command_runner.py +++ b/openbb_platform/core/tests/app/test_command_runner.py @@ -16,6 +16,7 @@ ) from openbb_core.app.model.abstract.warning import OpenBBWarning from openbb_core.app.model.command_context import CommandContext +from openbb_core.app.model.obbject import OBBject from openbb_core.app.model.system_settings import SystemSettings from openbb_core.app.model.user_settings import UserSettings from openbb_core.app.provider_interface import ExtraParams @@ -364,8 +365,15 @@ def __init__(self, results): def test_static_command_runner_chart(): """Test _chart method when charting is in obbject.accessors.""" - mock_obbject = Mock() - mock_obbject.accessors = ["charting"] + + mock_obbject = OBBject( + results=[ + {"date": "1990", "value": 100}, + {"date": "1991", "value": 200}, + {"date": "1992", "value": 300}, + ], + accessors={"charting": Mock()}, + ) mock_obbject.charting.show = Mock() StaticCommandRunner._chart(mock_obbject) # pylint: disable=protected-access diff --git a/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py b/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py index 7560b6fb3691..d20908557d4c 100644 --- a/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py +++ b/openbb_platform/obbject_extensions/charting/openbb_charting/__init__.py @@ -325,23 +325,20 @@ def show(self, render: bool = True, **kwargs): charting_function = self._get_chart_function( self._obbject._route # pylint: disable=protected-access ) - kwargs["obbject_item"] = self._obbject.results - kwargs["charting_settings"] = self._charting_settings - if ( - hasattr(self._obbject, "_standard_params") - and self._obbject._standard_params # pylint: disable=protected-access - ): - kwargs["standard_params"] = ( - self._obbject._standard_params.__dict__ # pylint: disable=protected-access - ) + kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access + kwargs["charting_settings"] = ( + self._charting_settings + ) # pylint: disable=protected-access + kwargs["standard_params"] = ( + self._obbject._standard_params + ) # pylint: disable=protected-access + kwargs["extra_params"] = ( + self._obbject._extra_params + ) # pylint: disable=protected-access kwargs["provider"] = ( self._obbject.provider ) # pylint: disable=protected-access kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access - - if "kwargs" in kwargs: - _kwargs = kwargs.pop("kwargs") - kwargs.update(_kwargs.get("chart_params", {})) fig, content = charting_function(**kwargs) fig = self._set_chart_style(fig) content = fig.show(external=True, **kwargs).to_plotly_json() @@ -448,24 +445,18 @@ def to_chart( kwargs["symbol"] = symbol kwargs["target"] = target kwargs["index"] = index - kwargs["obbject_item"] = self._obbject.results - kwargs["charting_settings"] = self._charting_settings - if ( - hasattr(self._obbject, "_standard_params") - and self._obbject._standard_params # pylint: disable=protected-access - ): - kwargs["standard_params"] = ( - self._obbject._standard_params.__dict__ # pylint: disable=protected-access - ) + kwargs["obbject_item"] = self._obbject # pylint: disable=protected-access + kwargs["charting_settings"] = ( + self._charting_settings + ) # pylint: disable=protected-access + kwargs["standard_params"] = ( + self._obbject._standard_params + ) # pylint: disable=protected-access + kwargs["extra_params"] = ( + self._obbject._extra_params + ) # pylint: disable=protected-access kwargs["provider"] = self._obbject.provider # pylint: disable=protected-access kwargs["extra"] = self._obbject.extra # pylint: disable=protected-access - metadata = kwargs["extra"].get("metadata") - kwargs["extra_params"] = ( - metadata.arguments.get("extra_params") if metadata else None - ) - if "kwargs" in kwargs: - _kwargs = kwargs.pop("kwargs") - kwargs.update(_kwargs.get("chart_params", {})) try: if has_data: self.show(data=data_as_df, render=render, **kwargs) @@ -488,7 +479,7 @@ def to_chart( def _set_chart_style(self, figure: Figure): """Set the user preference for light or dark mode.""" - style = self._charting_settings.chart_style # pylint: disable=protected-access + style = self._charting_settings.chart_style font_color = "black" if style == "light" else "white" paper_bgcolor = "white" if style == "light" else "black" figure = figure.update_layout( @@ -498,7 +489,7 @@ def _set_chart_style(self, figure: Figure): ) return figure - def toggle_chart_style(self): # pylint: disable=protected-access + def toggle_chart_style(self): """Toggle the chart style between light and dark mode.""" if not hasattr(self._obbject.chart, "fig"): raise ValueError( diff --git a/openbb_platform/obbject_extensions/charting/tests/test_charting.py b/openbb_platform/obbject_extensions/charting/tests/test_charting.py index 5849386987f3..9e148c72808d 100644 --- a/openbb_platform/obbject_extensions/charting/tests/test_charting.py +++ b/openbb_platform/obbject_extensions/charting/tests/test_charting.py @@ -24,7 +24,7 @@ def __init__(self): @pytest.fixture() def obbject(): - """Mock OOBject.""" + """Mock OBBject.""" class MockStdParams(BaseModel): """Mock Standard Parameters.""" @@ -32,17 +32,18 @@ class MockStdParams(BaseModel): param1: str param2: str - class MockOOBject: - """Mock OOBject.""" + class MockOBBject: + """Mock OBBject.""" def __init__(self): - """Mock OOBject.""" + """Mock OBBject.""" self._user_settings = UserSettings() self._system_settings = SystemSettings() self._route = "mock/route" self._standard_params = MockStdParams( param1="mock_param1", param2="mock_param2" ) + self._extra_params = {} self.results = "mock_results" self.provider = "mock_provider" @@ -54,7 +55,7 @@ def to_dataframe(self): """Mock to_dataframe.""" return mock_dataframe - return MockOOBject() + return MockOBBject() def test_charting_settings(obbject):