Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RegionProcessor #213

Merged
merged 12 commits into from
Jan 10, 2023
23 changes: 23 additions & 0 deletions nomenclature/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class VariableCode(Code):
method: Optional[str] = None
check_aggregate: Optional[bool] = Field(False, alias="check-aggregate")
components: Optional[Union[List[str], List[Dict[str, List[str]]]]] = None
drop_negative_weights: Optional[bool] = None

class Config:
# this allows using both "check_aggregate" and "check-aggregate" for attribute
Expand All @@ -153,3 +154,25 @@ def named_attributes(cls) -> Set[str]:
.named_attributes()
.union(f.alias for f in cls.__dict__["__fields__"].values())
)

@property
def pyam_agg_kwargs(self) -> Dict[str, Any]:
# return a dict of all not None pyam aggregation properties
return {
field: getattr(self, field)
for field in (
"components",
phackstock marked this conversation as resolved.
Show resolved Hide resolved
"method",
"weight",
"drop_negative_weights",
)
if getattr(self, field) is not None
}

@property
def agg_kwargs(self) -> Dict[str, Any]:
return (
{**self.pyam_agg_kwargs, **{"region_aggregation": self.region_aggregation}}
if self.region_aggregation is not None
else self.pyam_agg_kwargs
)
86 changes: 44 additions & 42 deletions nomenclature/codelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@
VariableRenameTargetError,
)

# arguments of the method `pyam.IamDataFrame.aggregate_region`
# required for checking validity of variable-CodeList-attributes
PYAM_AGG_KWARGS = [
"components",
"method",
"weight",
"drop_negative_weights",
]

here = Path(__file__).parent.absolute()

Expand Down Expand Up @@ -401,43 +393,35 @@ class VariableCodeList(CodeList):
@validator("mapping")
def check_variable_region_aggregation_args(cls, v):
"""Check that any variable "region-aggregation" mappings are valid"""
items = [
(name, code)
for (name, code) in v.items()
if code.region_aggregation is not None
]

for (name, code) in items:
# ensure that there no pyam-aggregation-kwargs and
conflict_args = [
i
for i, val in code.dict().items()
if i in PYAM_AGG_KWARGS and val is not None
]
if conflict_args:
raise VariableRenameArgError(
variable=name,
file=code.file,
args=conflict_args,
)

# ensure that mapped variables are defined in the nomenclature
invalid = []
for inst in code.region_aggregation:
invalid.extend(var for var in inst if var not in v)
if invalid:
raise VariableRenameTargetError(
variable=name, file=code.file, target=invalid
)
for var in v.values():
# ensure that a variable does not have both pyam-aggregation-kwargs and
# region-aggregation
phackstock marked this conversation as resolved.
Show resolved Hide resolved
if var.region_aggregation is not None:
if conflict_args := list(var.pyam_agg_kwargs.keys()):
raise VariableRenameArgError(
variable=var.name,
file=var.file,
args=conflict_args,
)

# ensure that mapped variables are defined in the nomenclature
invalid = []
for inst in var.region_aggregation:
invalid.extend(var for var in inst if var not in v)
if invalid:
raise VariableRenameTargetError(
variable=var.name, file=var.file, target=invalid
)
return v

@validator("mapping")
def check_weight_in_vars(cls, v):
# Check that all variables specified in 'weight' are present in the codelist
if missing_weights := [
(name, code.weight, code.file)
for name, code in v.items()
if code.weight is not None and code.weight not in v
(var.name, var.weight, var.file)
for var in v.values()
if var.weight is not None and var.weight not in v
]:
raise MissingWeightError(
missing_weights="".join(
Expand All @@ -451,15 +435,33 @@ def check_weight_in_vars(cls, v):
def cast_variable_components_args(cls, v):
"""Cast "components" list of dicts to a codelist"""
# translate a list of single-key dictionaries to a simple dictionary
for name, code in v.items():
if code.components and isinstance(code.components[0], dict):
for var in v.values():
if var.components and isinstance(var.components[0], dict):
comp = {}
for val in code.components:
for val in var.components:
comp.update(val)
v[name].components = comp
v[var.name].components = comp

return v

def vars_default_args(self, variables: List[str]) -> List[VariableCode]:
# return subset of variables which does not feature any special pyam aggregation
phackstock marked this conversation as resolved.
Show resolved Hide resolved
# arguments and where skip_region_aggregation is False
return [
self[var]
for var in variables
if not self[var].agg_kwargs and not self[var].skip_region_aggregation
]

def vars_kwargs(self, variables: List[str]) -> List[VariableCode]:
phackstock marked this conversation as resolved.
Show resolved Hide resolved
# return subset of variables which features special pyam aggregation arguments
# and where skip_region_aggregation is False
return [
self[var]
for var in variables
if self[var].agg_kwargs and not self[var].skip_region_aggregation
]


class RegionCodeList(CodeList):
"""A subclass of CodeList specified for regions
Expand Down
111 changes: 47 additions & 64 deletions nomenclature/processor/region.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Union

import jsonschema
import pyam
Expand All @@ -13,7 +13,6 @@
from pydantic.error_wrappers import ErrorWrapper
from pydantic.types import DirectoryPath, FilePath

from nomenclature.codelist import PYAM_AGG_KWARGS
from nomenclature.definition import DataStructureDefinition
from nomenclature.error.region import (
ExcludeRegionOverlapError,
Expand All @@ -23,8 +22,6 @@
)
from nomenclature.processor.utils import get_relative_path

AGG_KWARGS = PYAM_AGG_KWARGS + ["region_aggregation"]

logger = logging.getLogger(__name__)

here = Path(__file__).parent.absolute()
Expand Down Expand Up @@ -289,11 +286,31 @@ def rename_mapping(self) -> Dict[str, str]:

def validate_regions(self, dsd: DataStructureDefinition) -> None:
if hasattr(dsd, "region"):
# invalid = [c for c in self.all_regions if c not in dsd.region]
invalid = dsd.region.validate_items(self.all_regions)
if invalid:
if invalid := dsd.region.validate_items(self.all_regions):
raise RegionNotDefinedError(region=invalid, file=self.file)

def check_unexpected_regions(self, df: IamDataFrame) -> None:
# Raise value error if a region in the input data is not mentioned in the model
# mapping
phackstock marked this conversation as resolved.
Show resolved Hide resolved

if regions_not_found := set(df.region) - set(
self.model_native_region_names
+ self.common_region_names
+ [
const_reg
for comm_reg in self.common_regions or []
for const_reg in comm_reg.constituent_regions
]
+ (self.exclude_regions or [])
):
raise ValueError(
f"Did not find region(s) {regions_not_found} in 'native_regions', "
"'common_regions' or 'exclude_regions' in model mapping for "
f"{self.model} in {self.file}. If they are not meant to be included "
"in the results add to the 'exclude_regions' section in the model "
"mapping to silence this error."
)


class RegionProcessor(BaseModel):
"""Region aggregation mappings for scenario processing"""
Expand Down Expand Up @@ -397,7 +414,7 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:
f"{self.mappings[model].file}"
)
# Check for regions not mentioned in the model mapping
_check_unexpected_regions(model_df, self.mappings[model])
self.mappings[model].check_unexpected_regions(model_df)
_processed_dfs = []

# Silence pyam's empty filter warnings
Expand All @@ -414,16 +431,7 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:

# Aggregate
if self.mappings[model].common_regions is not None:
vars = self._filter_dict_args(model_df.variable, dsd)
vars_default_args = [
var for var, kwargs in vars.items() if not kwargs
]
# TODO skip if required weight does not exist
vars_kwargs = {
var: kwargs
for var, kwargs in vars.items()
if var not in vars_default_args
}

for cr in self.mappings[model].common_regions:
# If the common region is only comprised of a single model
# native region, just rename
Expand All @@ -434,35 +442,48 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:
).rename(region=cr.rename_dict)
)
continue

# if there are multiple constituent regions, aggregate
regions = [cr.name, cr.constituent_regions]

# First, perform 'simple' aggregation (no arguments)
_processed_dfs.append(
model_df.aggregate_region(vars_default_args, *regions)
model_df.aggregate_region(
[
var.name
for var in dsd.variable.vars_default_args(
df.variable
)
],
*regions,
)
)

# Second, special weighted aggregation
for var, kwargs in vars_kwargs.items():
if "region_aggregation" not in kwargs:
for var in dsd.variable.vars_kwargs(df.variable):
if var.region_aggregation is None:
_df = _aggregate_region(
model_df,
var,
var.name,
*regions,
**kwargs,
**var.pyam_agg_kwargs,
)
if _df is not None and not _df.empty:
_processed_dfs.append(_df)
else:
for rename_var in kwargs["region_aggregation"]:
for rename_var in var.region_aggregation:
for _rename, _kwargs in rename_var.items():
_df = _aggregate_region(
model_df,
var,
var.name,
*regions,
**_kwargs,
)
if _df is not None and not _df.empty:
_processed_dfs.append(
_df.rename(variable={var: _rename})
_df.rename(
variable={var.name: _rename}
)
)

common_region_df = model_df.filter(
Expand All @@ -488,19 +509,6 @@ def apply(self, df: IamDataFrame, dsd: DataStructureDefinition) -> IamDataFrame:

return pyam.concat(processed_dfs)

def _filter_dict_args(
self, variables, dsd: DataStructureDefinition, keys: Set[str] = AGG_KWARGS
) -> Dict[str, Dict]:
return {
name: {
key: value
for key, value in code.dict().items()
if key in keys and value is not None
}
for name, code in dsd.variable.items()
if name in variables and not code.skip_region_aggregation
}


def _aggregate_region(df, var, *regions, **kwargs):
"""Perform region aggregation with kwargs catching inconsistent-index errors"""
Expand Down Expand Up @@ -549,28 +557,3 @@ def _check_exclude_region_overlap(values: Dict, region_type: str) -> Dict:
region=overlap, region_type=region_type, file=values["file"]
)
return values


def _check_unexpected_regions(
df: IamDataFrame, mapping: RegionAggregationMapping
) -> None:
# Raise value error if a region in the input data is not mentioned in the model
# mapping

if regions_not_found := set(df.region) - set(
mapping.model_native_region_names
+ mapping.common_region_names
+ [
const_reg
for comm_reg in mapping.common_regions or []
for const_reg in comm_reg.constituent_regions
]
+ (mapping.exclude_regions or [])
):
raise ValueError(
f"Did not find region(s) {regions_not_found} in 'native_regions', "
"'common_regions' or 'exclude_regions' in model mapping for "
f"{mapping.model} in {mapping.file}. If they are not meant to be included "
"in the results add to the 'exclude_regions' section in the model mapping "
"to silence this error."
)