
Commit

wip
Marigold committed Dec 19, 2024
1 parent a24ce28 commit cb83cfb
Showing 4 changed files with 7 additions and 204 deletions.
190 changes: 0 additions & 190 deletions etl/db.py
@@ -72,196 +72,6 @@ def get_engine_async(conf: Optional[Dict[str, Any]] = None) -> AsyncEngine:
    return engine


def get_dataset_id(
    dataset_name: str,
    db_conn: Optional[pymysql.Connection] = None,
    version: Optional[str] = None,
) -> Any:
    """Get the dataset ID of a specific dataset name from database.

    If more than one dataset is found for the same name, or if no dataset is found, an error is raised.

    Parameters
    ----------
    dataset_name : str
        Dataset name.
    db_conn : pymysql.Connection
        Connection to database. Defaults to None, in which case a default connection is created (uses etl.config).
    version : str
        ETL version of the dataset. This is necessary when multiple datasets have the same title. In such a case, if
        version is not given, the function will raise an error.

    Returns
    -------
    dataset_id : int
        Dataset ID.
    """
    if db_conn is None:
        db_conn = get_connection()

    query = f"""
        SELECT id
        FROM datasets
        WHERE name = '{dataset_name}'
    """

    if version:
        query += f" AND version = '{version}'"

    with db_conn.cursor() as cursor:
        cursor.execute(query)
        result = cursor.fetchall()

    assert len(result) == 1, f"Ambiguous or unknown dataset name '{dataset_name}'"
    dataset_id = result[0][0]
    return dataset_id
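
For reference, this helper is deleted from etl/db.py here, while an identical copy remains in etl/grapher_io.py (reformatted further down). A minimal usage sketch, assuming a database configured via etl.config; the dataset name and version are hypothetical:

from etl.grapher_io import get_dataset_id

# Hypothetical name/version; raises if zero or several datasets match.
dataset_id = get_dataset_id(dataset_name="Example dataset", version="2024-12-01")
print(dataset_id)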


def get_variables_in_dataset(
    dataset_id: int,
    only_used_in_charts: bool = False,
    db_conn: Optional[pymysql.Connection] = None,
) -> Any:
    """Get all variables data for a specific dataset ID from database.

    Parameters
    ----------
    dataset_id : int
        Dataset ID.
    only_used_in_charts : bool
        True to select variables only if they have been used in at least one chart. False to select all variables.
    db_conn : pymysql.Connection
        Connection to database. Defaults to None, in which case a default connection is created (uses etl.config).

    Returns
    -------
    variables_data : pd.DataFrame
        Variables data for considered dataset.
    """
    if db_conn is None:
        db_conn = get_connection()

    query = f"""
        SELECT *
        FROM variables
        WHERE datasetId = {dataset_id}
    """
    if only_used_in_charts:
        query += """
            AND id IN (
                SELECT DISTINCT variableId
                FROM chart_dimensions
            )
        """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        variables_data = pd.read_sql(query, con=db_conn)
    return variables_data
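
Again, the surviving copy lives in etl/grapher_io.py, where it is marked deprecated. A usage sketch with a hypothetical dataset ID:

from etl.grapher_io import get_variables_in_dataset

# Hypothetical dataset ID; keep only variables referenced by at least one chart.
variables_data = get_variables_in_dataset(dataset_id=1234, only_used_in_charts=True)
print(variables_data.shape)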


def _get_variables_data_with_filter(
    field_name: Optional[str] = None,
    field_values: Optional[List[Any]] = None,
    db_conn: Optional[pymysql.Connection] = None,
) -> Any:
    if db_conn is None:
        db_conn = get_connection()

    if field_values is None:
        field_values = []

    # Construct the SQL query with a placeholder for each value in the list.
    query = "SELECT * FROM variables"

    if (field_name is not None) and (len(field_values) > 0):
        query += f"\nWHERE {field_name} IN ({', '.join(['%s'] * len(field_values))});"

    # Execute the query.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        variables_data = pd.read_sql(query, con=db_conn, params=field_values)

    assert set(variables_data[field_name]) <= set(field_values), f"Unexpected values for {field_name}."

    # Warn about values that were not found.
    missing_values = set(field_values) - set(variables_data[field_name])
    if len(missing_values) > 0:
        log.warning(f"Values of {field_name} not found in database: {missing_values}")

    return variables_data
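
The query construction above emits one %s placeholder per value, which pd.read_sql then binds via params. A standalone sketch of that expansion, with illustrative values:

field_name = "id"
field_values = [123456, 234567, 345678]

# One "%s" per value; the driver substitutes them safely at execution time.
placeholders = ", ".join(["%s"] * len(field_values))
query = f"SELECT * FROM variables\nWHERE {field_name} IN ({placeholders});"
print(query)
# SELECT * FROM variables
# WHERE id IN (%s, %s, %s);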


def get_variables_data(
    filter: Optional[Dict[str, Any]] = None,
    condition: Optional[str] = "OR",
    db_conn: Optional[pymysql.Connection] = None,
) -> pd.DataFrame:
    """Get data from variables table, given a certain condition.

    Parameters
    ----------
    filter : Optional[Dict[str, Any]], optional
        Filter to apply to the data, which must contain a field name and a list of field values,
        e.g. {"id": [123456, 234567, 345678]}.
        In principle, multiple filters can be given.
    condition : Optional[str], optional
        In case multiple filters are given, this parameter specifies whether the output filters should be the union
        ("OR") or the intersection ("AND").
    db_conn : pymysql.Connection
        Connection to database. Defaults to None, in which case a default connection is created (uses etl.config).

    Returns
    -------
    df : pd.DataFrame
        Variables data.
    """
    # NOTE: This function should be optimized. Instead of fetching data for each filter, their conditions should be
    # combined with OR or AND before executing the query.

    # Initialize an empty dataframe.
    if filter is not None:
        df = pd.DataFrame({"id": []}).astype({"id": int})
        for field_name, field_values in filter.items():
            _df = _get_variables_data_with_filter(field_name=field_name, field_values=field_values, db_conn=db_conn)
            if condition == "OR":
                df = pd.concat([df, _df], axis=0)
            elif condition == "AND":
                df = pd.merge(df, _df, on="id", how="inner")
            else:
                raise ValueError(f"Invalid condition: {condition}")
    else:
        # Fetch data for all variables.
        df = _get_variables_data_with_filter(db_conn=db_conn)

    return df
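
Illustrative calls of this deleted function, with hypothetical filter values ("id" and "datasetId" are columns used elsewhere in this module): "OR" concatenates the per-filter results, while "AND" keeps only rows whose id appears in every result.

from etl.db import get_variables_data  # as it existed before this commit

# Union of variables matching either filter.
df_or = get_variables_data(filter={"id": [123456, 234567], "datasetId": [42]}, condition="OR")

# Intersection: only variables matching both filters.
df_and = get_variables_data(filter={"id": [123456], "datasetId": [42]}, condition="AND")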


def get_all_datasets(archived: bool = True, db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame:
    """Get all datasets in database.

    Parameters
    ----------
    db_conn : pymysql.connections.Connection
        Connection to database. Defaults to None, in which case a default connection is created (uses etl.config).

    Returns
    -------
    datasets : pd.DataFrame
        All datasets in database. Table with three columns: dataset ID, dataset name, dataset namespace.
    """
    if db_conn is None:
        db_conn = get_connection()

    query = " SELECT namespace, name, id, updatedAt FROM datasets"
    if not archived:
        query += " WHERE isArchived = 0"
    datasets = pd.read_sql(query, con=db_conn)
    return datasets.sort_values(["name", "namespace"])
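
A quick usage sketch; the etl/grapher_io.py copy shown further down keeps the same signature:

from etl.grapher_io import get_all_datasets

# Exclude archived datasets; the result is sorted by name and namespace.
datasets = get_all_datasets(archived=False)
print(datasets[["namespace", "name", "id", "updatedAt"]].head())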


def dict_to_object(d):
    return type("DynamicObject", (object,), d)()
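
dict_to_object turns a plain dict into an object whose attributes are the dict's keys, by building a throwaway class with the dict as its class attributes. Using the definition above:

obj = dict_to_object({"name": "population", "unit": "people"})
print(obj.name)  # population
print(obj.unit)  # people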

1 change: 0 additions & 1 deletion etl/grapher_helpers.py
@@ -667,7 +667,6 @@ def _adapt_table_for_grapher(table: catalog.Table, engine: Engine) -> catalog.Table:
    # Add entity code and name
    with Session(engine) as session:
        table = add_entity_code_and_name(session, table).copy_metadata(table)
-        # table = dm.add_entity_code_and_name(session, table).copy_metadata(table)

    table = table.set_index(["entityId", "entityCode", "entityName", "year"] + dim_names)

10 changes: 7 additions & 3 deletions etl/grapher_io.py
@@ -559,7 +559,9 @@ def variable_data_table_from_catalog(


def get_dataset_id(
-    dataset_name: str, db_conn: Optional[pymysql.Connection] = None, version: Optional[str] = None
+    dataset_name: str,
+    db_conn: Optional[pymysql.Connection] = None,
+    version: Optional[str] = None,
) -> Any:
    """Get the dataset ID of a specific dataset name from database.
@@ -604,7 +606,9 @@ def get_dataset_id(

@deprecated("This function is deprecated. Its logic will be soon moved to etl.grapher_model.Dataset.")
def get_variables_in_dataset(
-    dataset_id: int, only_used_in_charts: bool = False, db_conn: Optional[pymysql.Connection] = None
+    dataset_id: int,
+    only_used_in_charts: bool = False,
+    db_conn: Optional[pymysql.Connection] = None,
) -> Any:
    """Get all variables data for a specific dataset ID from database.
@@ -660,7 +664,7 @@ def get_all_datasets(archived: bool = True, db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame:
    if db_conn is None:
        db_conn = get_connection()

-    query = " SELECT namespace, name, id, updatedAt, isArchived FROM datasets"
+    query = " SELECT namespace, name, id, updatedAt FROM datasets"
    if not archived:
        query += " WHERE isArchived = 0"
    datasets = pd.read_sql(query, con=db_conn)
10 changes: 0 additions & 10 deletions lib/catalog/owid/catalog/utils.py
@@ -327,13 +327,3 @@ def dataclass_from_dict(cls: Optional[Type[T]], d: Dict[str, Any]) -> T:
            init_args[field_name] = v

    return cls(**init_args)
-        return hash(tuple([hash_any(y) for y in x]))
-    elif isinstance(x, dict):
-        return hash(tuple([(hash_any(k), hash_any(v)) for k, v in sorted(x.items())]))
-    elif isinstance(x, str):
-        # get md5 of the string and truncate to 64 bits
-        return int(hashlib.md5(x.encode()).hexdigest(), 16) & ((1 << 64) - 1)
-    elif x is None:
-        return 0
-    else:
-        return hash(x)
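
The deleted lines are the tail of a recursive hash_any helper: a sequence branch, a dict branch, an md5-based string branch, and fallbacks. The function head falls outside this hunk, so the following reconstruction is only a sketch; the signature and the sequence-branch isinstance check are assumptions.

import hashlib
from typing import Any

def hash_any(x: Any) -> int:
    # Assumed head: the orphaned first return implies a branch for sequences.
    if isinstance(x, (list, tuple)):
        return hash(tuple([hash_any(y) for y in x]))
    elif isinstance(x, dict):
        return hash(tuple([(hash_any(k), hash_any(v)) for k, v in sorted(x.items())]))
    elif isinstance(x, str):
        # get md5 of the string and truncate to 64 bits
        return int(hashlib.md5(x.encode()).hexdigest(), 16) & ((1 << 64) - 1)
    elif x is None:
        return 0
    else:
        return hash(x)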
