diff --git a/openbb_terminal/alternative/covid/covid_controller.py b/openbb_terminal/alternative/covid/covid_controller.py index 39232970da5a..2feb268aba1f 100644 --- a/openbb_terminal/alternative/covid/covid_controller.py +++ b/openbb_terminal/alternative/covid/covid_controller.py @@ -127,6 +127,8 @@ def call_country(self, other_args: List[str]): ) return country = ns_parser.country.title().replace("_", " ") + if country == "Us": + country = "US" self.country = country console.print(f"[cyan]{country}[/cyan] loaded\n") else: diff --git a/openbb_terminal/alternative/covid/covid_model.py b/openbb_terminal/alternative/covid/covid_model.py index 28c9066b635f..4f039214e21e 100644 --- a/openbb_terminal/alternative/covid/covid_model.py +++ b/openbb_terminal/alternative/covid/covid_model.py @@ -60,6 +60,9 @@ def get_global_cases(country: str) -> pd.DataFrame: .T ) cases.index = pd.to_datetime(cases.index) + if country not in cases: + console.print("[red]The selection `{country}` is not a valid option.[/red]\n") + return pd.DataFrame() cases = pd.DataFrame(cases[country]).diff().dropna() if cases.shape[1] > 1: return pd.DataFrame(cases.sum(axis=1)) @@ -89,6 +92,9 @@ def get_global_deaths(country: str) -> pd.DataFrame: .T ) deaths.index = pd.to_datetime(deaths.index) + if country not in deaths: + console.print("[red]The selection `{country}` is not a valid option.[/red]\n") + return pd.DataFrame() deaths = pd.DataFrame(deaths[country]).diff().dropna() if deaths.shape[1] > 1: return pd.DataFrame(deaths.sum(axis=1)) @@ -114,8 +120,12 @@ def get_covid_ov( pd.DataFrame Dataframe of historical cases and deaths """ + if country.lower() == "us": + country = "US" cases = get_global_cases(country) deaths = get_global_deaths(country) + if cases.empty or deaths.empty: + return pd.DataFrame() data = pd.concat([cases, deaths], axis=1) data.columns = ["Cases", "Deaths"] data.index = [x.strftime("%Y-%m-%d") for x in data.index] diff --git a/openbb_terminal/alternative/covid/covid_view.py b/openbb_terminal/alternative/covid/covid_view.py index 6f7dfb5d0e81..68c56c877dc9 100644 --- a/openbb_terminal/alternative/covid/covid_view.py +++ b/openbb_terminal/alternative/covid/covid_view.py @@ -40,6 +40,8 @@ def plot_covid_ov( """ cases = covid_model.get_global_cases(country) / 1_000 deaths = covid_model.get_global_deaths(country) + if cases.empty or deaths.empty: + return ov = pd.concat([cases, deaths], axis=1) ov.columns = ["Cases", "Deaths"] @@ -148,6 +150,8 @@ def display_covid_ov( plot: bool Flag to display historical plot """ + if country.lower() == "us": + country = "US" if plot: plot_covid_ov(country) if raw: