Skip to content

Commit

Permalink
Refactor RegionProcessor (#213)
Browse files Browse the repository at this point in the history
* Add drop_negative_weights to VariableCode

* Add aggregation aggs properties to VariableCode

* Make usage of VariableCode more consistent

* Add variable selection method based on agg args

* Remove no longer needed global variables

* Clean up unused imports

* Make check_unexpected_regions a method of RegionAggregationMapping

* Refactor region processing

* Apply suggestions from code review

Co-authored-by: Daniel Huppmann <dh@dergelbesalon.at>

* Change order of pyam agg kwargs for consistency

* Appease stickler

* Use proper docstring

Co-authored-by: Daniel Huppmann <dh@dergelbesalon.at>
  • Loading branch information
phackstock and danielhuppmann authored Jan 10, 2023
1 parent cac01c4 commit 387cb8a
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 107 deletions.
23 changes: 23 additions & 0 deletions nomenclature/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class VariableCode(Code):
method: Optional[str] = None
check_aggregate: Optional[bool] = Field(False, alias="check-aggregate")
components: Optional[Union[List[str], List[Dict[str, List[str]]]]] = None
drop_negative_weights: Optional[bool] = None

class Config:
# this allows using both "check_aggregate" and "check-aggregate" for attribute
Expand All @@ -153,3 +154,25 @@ def named_attributes(cls) -> Set[str]:
.named_attributes()
.union(f.alias for f in cls.__dict__["__fields__"].values())
)

@property
def pyam_agg_kwargs(self) -> Dict[str, Any]:
    """Return a dict of all pyam aggregation properties that are not None

    The properties considered are "weight", "method", "components" and
    "drop_negative_weights", in that order.
    """
    candidates = ("weight", "method", "components", "drop_negative_weights")
    kwargs: Dict[str, Any] = {}
    for key in candidates:
        value = getattr(self, key)
        # only explicitly set properties are reported, falsy values are kept
        if value is not None:
            kwargs[key] = value
    return kwargs

@property
def agg_kwargs(self) -> Dict[str, Any]:
    """Return the pyam aggregation kwargs, plus "region_aggregation" if it is set"""
    region_aggregation = self.region_aggregation
    if region_aggregation is None:
        return self.pyam_agg_kwargs
    return {**self.pyam_agg_kwargs, "region_aggregation": region_aggregation}
89 changes: 46 additions & 43 deletions nomenclature/codelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@
VariableRenameTargetError,
)

# arguments of the method `pyam.IamDataFrame.aggregate_region`
# required for checking validity of variable-CodeList-attributes
PYAM_AGG_KWARGS = [
"components",
"method",
"weight",
"drop_negative_weights",
]

here = Path(__file__).parent.absolute()

Expand Down Expand Up @@ -401,43 +393,35 @@ class VariableCodeList(CodeList):
@validator("mapping")
def check_variable_region_aggregation_args(cls, v):
"""Check that any variable "region-aggregation" mappings are valid"""
items = [
(name, code)
for (name, code) in v.items()
if code.region_aggregation is not None
]

for (name, code) in items:
# ensure that there no pyam-aggregation-kwargs and
conflict_args = [
i
for i, val in code.dict().items()
if i in PYAM_AGG_KWARGS and val is not None
]
if conflict_args:
raise VariableRenameArgError(
variable=name,
file=code.file,
args=conflict_args,
)

# ensure that mapped variables are defined in the nomenclature
invalid = []
for inst in code.region_aggregation:
invalid.extend(var for var in inst if var not in v)
if invalid:
raise VariableRenameTargetError(
variable=name, file=code.file, target=invalid
)
for var in v.values():
# ensure that a variable does not have both individual
# pyam-aggregation-kwargs and a 'region-aggregation' attribute
if var.region_aggregation is not None:
if conflict_args := list(var.pyam_agg_kwargs.keys()):
raise VariableRenameArgError(
variable=var.name,
file=var.file,
args=conflict_args,
)

# ensure that mapped variables are defined in the nomenclature
invalid = []
for inst in var.region_aggregation:
invalid.extend(var for var in inst if var not in v)
if invalid:
raise VariableRenameTargetError(
variable=var.name, file=var.file, target=invalid
)
return v

@validator("mapping")
def check_weight_in_vars(cls, v):
# Check that all variables specified in 'weight' are present in the codelist
"""Check that all variables specified in 'weight' are present in the codelist"""
if missing_weights := [
(name, code.weight, code.file)
for name, code in v.items()
if code.weight is not None and code.weight not in v
(var.name, var.weight, var.file)
for var in v.values()
if var.weight is not None and var.weight not in v
]:
raise MissingWeightError(
missing_weights="".join(
Expand All @@ -450,16 +434,35 @@ def check_weight_in_vars(cls, v):
@validator("mapping")
def cast_variable_components_args(cls, v):
"""Cast "components" list of dicts to a codelist"""

# translate a list of single-key dictionaries to a simple dictionary
for name, code in v.items():
if code.components and isinstance(code.components[0], dict):
for var in v.values():
if var.components and isinstance(var.components[0], dict):
comp = {}
for val in code.components:
for val in var.components:
comp.update(val)
v[name].components = comp
v[var.name].components = comp

return v

def vars_default_args(self, variables: List[str]) -> List[VariableCode]:
    """Return the subset of variables without special aggregation arguments

    A variable is selected if its ``agg_kwargs`` is empty and its
    ``skip_region_aggregation`` is False.
    """
    selected = []
    for name in variables:
        code = self[name]
        if code.agg_kwargs or code.skip_region_aggregation:
            continue
        selected.append(code)
    return selected

def vars_kwargs(self, variables: List[str]) -> List[VariableCode]:
    """Return the subset of variables which feature special aggregation arguments

    A variable is selected if its ``agg_kwargs`` is not empty and its
    ``skip_region_aggregation`` is False.
    """
    selected = []
    for name in variables:
        code = self[name]
        if not code.agg_kwargs or code.skip_region_aggregation:
            continue
        selected.append(code)
    return selected


class RegionCodeList(CodeList):
"""A subclass of CodeList specified for regions
Expand Down
110 changes: 46 additions & 64 deletions nomenclature/processor/region.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Union

import jsonschema
import pyam
Expand All @@ -13,7 +13,6 @@
from pydantic.error_wrappers import ErrorWrapper
from pydantic.types import DirectoryPath, FilePath

from nomenclature.codelist import PYAM_AGG_KWARGS
from nomenclature.definition import DataStructureDefinition
from nomenclature.error.region import (
ExcludeRegionOverlapError,
Expand All @@ -23,8 +22,6 @@
)
from nomenclature.processor.utils import get_relative_path

AGG_KWARGS = PYAM_AGG_KWARGS + ["region_aggregation"]

logger = logging.getLogger(__name__)

here = Path(__file__).parent.absolute()
Expand Down Expand Up @@ -289,11 +286,30 @@ def rename_mapping(self) -> Dict[str, str]:

def validate_regions(self, dsd: DataStructureDefinition) -> None:
if hasattr(dsd, "region"):
# invalid = [c for c in self.all_regions if c not in dsd.region]
invalid = dsd.region.validate_items(self.all_regions)
if invalid:
if invalid := dsd.region.validate_items(self.all_regions):
raise RegionNotDefinedError(region=invalid, file=self.file)

def check_unexpected_regions(self, df: IamDataFrame) -> None:
    """Raise an error if a region in the input data is not used in the model mapping

    Regions count as used if they appear as model native region, common region,
    constituent region of a common region or in the excluded regions.

    Raises
    ------
    ValueError
        If `df` contains one or more regions not covered by the mapping.
    """
    data_regions = set(df.region)

    # constituent regions of all common regions, flattened into one list
    constituent_regions = [
        constituent
        for common_region in self.common_regions or []
        for constituent in common_region.constituent_regions
    ]
    known_regions = set(
        self.model_native_region_names
        + self.common_region_names
        + constituent_regions
        + (self.exclude_regions or [])
    )

    regions_not_found = data_regions - known_regions
    if not regions_not_found:
        return

    raise ValueError(
        f"Did not find region(s) {regions_not_found} in 'native_regions', "
        "'common_regions' or 'exclude_regions' in model mapping for "
        f"{self.model} in {self.file}. If they are not meant to be included "
        "in the results add to the 'exclude_regions' section in the model "
        "mapping to silence this error."
    )


class RegionProcessor(BaseModel):
"""Region aggregation mappings for scenario processing"""
Expand Down Expand Up @@ -397,7 +413,7 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:
f"{self.mappings[model].file}"
)
# Check for regions not mentioned in the model mapping
_check_unexpected_regions(model_df, self.mappings[model])
self.mappings[model].check_unexpected_regions(model_df)
_processed_dfs = []

# Silence pyam's empty filter warnings
Expand All @@ -414,16 +430,7 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:

# Aggregate
if self.mappings[model].common_regions is not None:
vars = self._filter_dict_args(model_df.variable, dsd)
vars_default_args = [
var for var, kwargs in vars.items() if not kwargs
]
# TODO skip if required weight does not exist
vars_kwargs = {
var: kwargs
for var, kwargs in vars.items()
if var not in vars_default_args
}

for cr in self.mappings[model].common_regions:
# If the common region is only comprised of a single model
# native region, just rename
Expand All @@ -434,35 +441,48 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:
).rename(region=cr.rename_dict)
)
continue

# if there are multiple constituent regions, aggregate
regions = [cr.name, cr.constituent_regions]

# First, perform 'simple' aggregation (no arguments)
_processed_dfs.append(
model_df.aggregate_region(vars_default_args, *regions)
model_df.aggregate_region(
[
var.name
for var in dsd.variable.vars_default_args(
df.variable
)
],
*regions,
)
)

# Second, special weighted aggregation
for var, kwargs in vars_kwargs.items():
if "region_aggregation" not in kwargs:
for var in dsd.variable.vars_kwargs(df.variable):
if var.region_aggregation is None:
_df = _aggregate_region(
model_df,
var,
var.name,
*regions,
**kwargs,
**var.pyam_agg_kwargs,
)
if _df is not None and not _df.empty:
_processed_dfs.append(_df)
else:
for rename_var in kwargs["region_aggregation"]:
for rename_var in var.region_aggregation:
for _rename, _kwargs in rename_var.items():
_df = _aggregate_region(
model_df,
var,
var.name,
*regions,
**_kwargs,
)
if _df is not None and not _df.empty:
_processed_dfs.append(
_df.rename(variable={var: _rename})
_df.rename(
variable={var.name: _rename}
)
)

common_region_df = model_df.filter(
Expand All @@ -488,19 +508,6 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:

return pyam.concat(processed_dfs)

def _filter_dict_args(
self, variables, dsd: DataStructureDefinition, keys: Set[str] = AGG_KWARGS
) -> Dict[str, Dict]:
return {
name: {
key: value
for key, value in code.dict().items()
if key in keys and value is not None
}
for name, code in dsd.variable.items()
if name in variables and not code.skip_region_aggregation
}


def _aggregate_region(df, var, *regions, **kwargs):
"""Perform region aggregation with kwargs catching inconsistent-index errors"""
Expand Down Expand Up @@ -549,28 +556,3 @@ def _check_exclude_region_overlap(values: Dict, region_type: str) -> Dict:
region=overlap, region_type=region_type, file=values["file"]
)
return values


def _check_unexpected_regions(
df: IamDataFrame, mapping: RegionAggregationMapping
) -> None:
# Raise value error if a region in the input data is not mentioned in the model
# mapping

if regions_not_found := set(df.region) - set(
mapping.model_native_region_names
+ mapping.common_region_names
+ [
const_reg
for comm_reg in mapping.common_regions or []
for const_reg in comm_reg.constituent_regions
]
+ (mapping.exclude_regions or [])
):
raise ValueError(
f"Did not find region(s) {regions_not_found} in 'native_regions', "
"'common_regions' or 'exclude_regions' in model mapping for "
f"{mapping.model} in {mapping.file}. If they are not meant to be included "
"in the results add to the 'exclude_regions' section in the model mapping "
"to silence this error."
)

0 comments on commit 387cb8a

Please sign in to comment.