Skip to content

Commit

Permalink
[BugFix] Filter OECD data using start_date and end_date parameters (
Browse files Browse the repository at this point in the history
#6144)

* move constant dicts to constants.py

filter by start_date and end_date

* black

* black again again
  • Loading branch information
the-praxs authored Feb 29, 2024
1 parent bbc869d commit b47c9cc
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,11 @@
CLIQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_CLI, COUNTRY_TO_CODE_CLI
from pydantic import Field, field_validator

cli_mapping = {
"USA": "united_states",
"GBR": "united_kingdom",
"JPN": "japan",
"MEX": "mexico",
"IDN": "indonesia",
"AUS": "australia",
"BRA": "brazil",
"CAN": "canada",
"ITA": "italy",
"DEU": "germany",
"TUR": "turkey",
"FRA": "france",
"ZAF": "south_africa",
"KOR": "south_korea",
"ESP": "spain",
"IND": "india",
"CHN": "china",
"G7": "g7",
"G20": "g20",
}


countries = tuple(cli_mapping.values()) + ("all",)
countries = tuple(CODE_TO_COUNTRY_CLI.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore
country_to_code = {v: k for k, v in cli_mapping.items()}


class OECDCLIQueryParams(CLIQueryParams):
Expand Down Expand Up @@ -70,10 +47,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-"))
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31)
next_month = date(year, month + 1, 1)
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
Expand All @@ -99,9 +76,10 @@ def transform_query(params: Dict[str, Any]) -> OECDCLIQueryParams:

return OECDCLIQueryParams(**transformed_params)

# pylint: disable=unused-argument
@staticmethod
def extract_data(
query: OECDCLIQueryParams, # pylint: disable=W0613
query: OECDCLIQueryParams,
credentials: Optional[Dict[str, str]],
**kwargs: Any,
) -> Dict:
Expand All @@ -112,19 +90,27 @@ def extract_data(
)

if query.country != "all":
data = data.query(f"REF_AREA == '{country_to_code[query.country]}'")
data = data.query(f"REF_AREA == '{COUNTRY_TO_CODE_CLI[query.country]}'")

# Filter down
data = data.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]].rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
data["country"] = data["country"].map(cli_mapping)
data["country"] = data["country"].map(CODE_TO_COUNTRY_CLI)

return data.to_dict(orient="records")
data = data.to_dict(orient="records")
start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data

# pylint: disable=unused-argument
@staticmethod
def transform_data(
query: OECDCLIQueryParams, data: Dict, **kwargs: Any
query: OECDCLIQueryParams,
data: Dict,
**kwargs: Any,
) -> List[OECDCLIData]:
"""Transform the data from the OECD endpoint."""
return [OECDCLIData.model_validate(d) for d in data]
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,11 @@
LTIRQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_IR, COUNTRY_TO_CODE_IR
from pydantic import Field, field_validator

ltir_mapping = {
"BEL": "belgium",
"IRL": "ireland",
"MEX": "mexico",
"IDN": "indonesia",
"NZL": "new_zealand",
"JPN": "japan",
"GBR": "united_kingdom",
"FRA": "france",
"CHL": "chile",
"CAN": "canada",
"NLD": "netherlands",
"USA": "united_states",
"KOR": "south_korea",
"NOR": "norway",
"AUT": "austria",
"ZAF": "south_africa",
"DNK": "denmark",
"CHE": "switzerland",
"HUN": "hungary",
"LUX": "luxembourg",
"AUS": "australia",
"DEU": "germany",
"SWE": "sweden",
"ISL": "iceland",
"TUR": "turkey",
"GRC": "greece",
"ISR": "israel",
"CZE": "czech_republic",
"LVA": "latvia",
"SVN": "slovenia",
"POL": "poland",
"EST": "estonia",
"LTU": "lithuania",
"PRT": "portugal",
"CRI": "costa_rica",
"SVK": "slovakia",
"FIN": "finland",
"ESP": "spain",
"RUS": "russia",
"EA19": "euro_area19",
"COL": "colombia",
"ITA": "italy",
"IND": "india",
"CHN": "china",
"HRV": "croatia",
}

countries = tuple(ltir_mapping.values()) + ("all",)
countries = tuple(CODE_TO_COUNTRY_IR.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore
country_to_code = {v: k for k, v in ltir_mapping.items()}


class OECDLTIRQueryParams(LTIRQueryParams):
Expand Down Expand Up @@ -101,10 +53,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-"))
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31)
next_month = date(year, month + 1, 1)
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
Expand Down Expand Up @@ -138,24 +90,30 @@ def extract_data(
) -> Dict:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = "" if query.country == "all" else country_to_code[query.country]
country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country]
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IRLT...."
data = helpers.get_possibly_cached_data(
url, function="economy_long_term_interest_rate"
)
query = f"FREQ=='{frequency}'"
query = query + f" & REF_AREA=='{country}'" if country else query
url_query = f"FREQ=='{frequency}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
# Filter down
data = (
data.query(query)
data.query(url_query)
.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
.rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
)
data["country"] = data["country"].map(ltir_mapping)
data["country"] = data["country"].map(CODE_TO_COUNTRY_IR)
data = data.fillna("N/A").replace("N/A", None)
return data.to_dict(orient="records")
data = data.to_dict(orient="records")

start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data

@staticmethod
def transform_data(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,11 @@
STIRQueryParams,
)
from openbb_oecd.utils import helpers
from openbb_oecd.utils.constants import CODE_TO_COUNTRY_IR, COUNTRY_TO_CODE_IR
from pydantic import Field, field_validator

stir_mapping = {
"BEL": "belgium",
"IRL": "ireland",
"MEX": "mexico",
"IDN": "indonesia",
"NZL": "new_zealand",
"JPN": "japan",
"GBR": "united_kingdom",
"FRA": "france",
"CHL": "chile",
"CAN": "canada",
"NLD": "netherlands",
"USA": "united_states",
"KOR": "south_korea",
"NOR": "norway",
"AUT": "austria",
"ZAF": "south_africa",
"DNK": "denmark",
"CHE": "switzerland",
"HUN": "hungary",
"LUX": "luxembourg",
"AUS": "australia",
"DEU": "germany",
"SWE": "sweden",
"ISL": "iceland",
"TUR": "turkey",
"GRC": "greece",
"ISR": "israel",
"CZE": "czech_republic",
"LVA": "latvia",
"SVN": "slovenia",
"POL": "poland",
"EST": "estonia",
"LTU": "lithuania",
"PRT": "portugal",
"CRI": "costa_rica",
"SVK": "slovakia",
"FIN": "finland",
"ESP": "spain",
"RUS": "russia",
"EA19": "euro_area19",
"COL": "colombia",
"ITA": "italy",
"IND": "india",
"CHN": "china",
"HRV": "croatia",
}

countries = tuple(stir_mapping.values()) + ("all",)
countries = tuple(CODE_TO_COUNTRY_IR.values()) + ("all",)
CountriesLiteral = Literal[countries] # type: ignore
country_to_code = {v: k for k, v in stir_mapping.items()}


class OECDSTIRQueryParams(STIRQueryParams):
Expand Down Expand Up @@ -101,10 +53,10 @@ def date_validate(cls, in_date: Union[date, str]): # pylint: disable=E0213
return date(_year, 12, 31)
# Now match if it is monthly, i.e 2022-01
elif re.match(r"\d{4}-\d{2}$", in_date):
year, month = map(int, in_date.split("-"))
year, month = map(int, in_date.split("-")) # type: ignore
if month == 12:
return date(year, month, 31)
next_month = date(year, month + 1, 1)
return date(year, month, 31) # type: ignore
next_month = date(year, month + 1, 1) # type: ignore
return date(next_month.year, next_month.month, 1) - timedelta(days=1)
# Now match if it is yearly, i.e 2022
elif re.match(r"\d{4}$", in_date):
Expand Down Expand Up @@ -138,24 +90,30 @@ def extract_data(
) -> Dict:
"""Return the raw data from the OECD endpoint."""
frequency = query.frequency[0].upper()
country = "" if query.country == "all" else country_to_code[query.country]
country = "" if query.country == "all" else COUNTRY_TO_CODE_IR[query.country]
url = "https://sdmx.oecd.org/public/rest/data/OECD.SDD.STES,DSD_KEI@DF_KEI,4.0/..IR3TIB...."
data = helpers.get_possibly_cached_data(
url, function="economy_short_term_interest_rate"
)
query = f"FREQ=='{frequency}'"
query = query + f" & REF_AREA=='{country}'" if country else query
url_query = f"FREQ=='{frequency}'"
url_query = url_query + f" & REF_AREA=='{country}'" if country else url_query
# Filter down
data = (
data.query(query)
data.query(url_query)
.reset_index(drop=True)[["REF_AREA", "TIME_PERIOD", "VALUE"]]
.rename(
columns={"REF_AREA": "country", "TIME_PERIOD": "date", "VALUE": "value"}
)
)
data["country"] = data["country"].map(stir_mapping)
data["country"] = data["country"].map(CODE_TO_COUNTRY_IR)
data = data.fillna("N/A").replace("N/A", None)
return data.to_dict(orient="records")
data = data.to_dict(orient="records")

start_date = query.start_date.strftime("%Y-%m-%d") # type: ignore
end_date = query.end_date.strftime("%Y-%m-%d") # type: ignore
data = list(filter(lambda x: start_date <= x["date"] <= end_date, data))

return data

@staticmethod
def transform_data(
Expand Down
Loading

0 comments on commit b47c9cc

Please sign in to comment.