Updated metrics serialization (#3095)
* Added _tree_map and _tree_apply2 to simplify checkpoint/metric code

* Fixed Metric tests

* Added state_dict support for MetricsLambda and contrib metrics

* Few other fixes

* Test fix

* More fixes in Metrics for distributed mode
vfdev-5 authored Oct 17, 2023
1 parent 29ebf54 commit f3124c9
Showing 25 changed files with 480 additions and 169 deletions.
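The two helpers named in the commit message drive most of the refactoring below. A minimal sketch of their semantics, inferred from the call sites in this diff; the actual implementations added to ignite/utils.py may differ in details such as keyword handling:

from collections.abc import Mapping, Sequence
from typing import Any, Callable, Optional, Union


def _tree_map(
    func: Callable, x: Union[Any, Sequence, Mapping], key: Optional[Union[int, str]] = None
) -> Union[Any, Sequence, Mapping]:
    # Apply func to every leaf of a nested structure of mappings/sequences,
    # preserving the container structure in the result.
    if isinstance(x, Mapping):
        return {k: _tree_map(func, v, key=k) for k, v in x.items()}
    if isinstance(x, Sequence) and not isinstance(x, str):
        return [_tree_map(func, v, key=i) for i, v in enumerate(x)]
    return func(x, key=key)


def _tree_apply2(func: Callable, x: Union[Any, Sequence, Mapping], y: Union[Any, Sequence, Mapping]) -> None:
    # Walk two parallel nested structures and call func on matching leaves,
    # failing loudly when their keys or lengths disagree.
    if isinstance(x, Mapping) and isinstance(y, Mapping):
        for k, v in x.items():
            if k not in y:
                raise ValueError(f"Key '{k}' from the first tree is not found in the second: {list(y.keys())}")
            _tree_apply2(func, v, y[k])
    elif isinstance(x, Sequence) and isinstance(y, Sequence) and not isinstance(x, str):
        if len(x) != len(y):
            raise ValueError(f"Length mismatch between the two trees: {len(x)} vs {len(y)}")
        for v1, v2 in zip(x, y):
            _tree_apply2(func, v1, v2)
    else:
        func(x, y)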
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/canberra_metric.py
@@ -63,6 +63,7 @@ class CanberraMetric(_BaseRegression):
         - Fixed implementation: ``abs`` in denominator.
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors",)
 
     @reinit__is_reduced
     def reset(self) -> None:
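This one-line pattern, repeated across the regression metrics below, registers which attributes Serializable.state_dict() must capture and restore. A minimal single-process round-trip sketch; the state layout itself is an internal detail of the serialization scheme this commit introduces:

import torch

from ignite.contrib.metrics.regression import CanberraMetric

metric = CanberraMetric()
metric.update((torch.tensor([1.0, 2.0]), torch.tensor([1.5, 1.5])))

state = metric.state_dict()   # captures "_sum_of_errors"

fresh = CanberraMetric()
fresh.load_state_dict(state)  # restores the accumulator
assert fresh.compute() == metric.compute()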
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/fractional_absolute_error.py
@@ -58,6 +58,7 @@ class FractionalAbsoluteError(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors", "_num_examples")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/fractional_bias.py
@@ -58,6 +58,7 @@ class FractionalBias(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors", "_num_examples")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/geometric_mean_absolute_error.py
@@ -58,6 +58,7 @@ class GeometricMeanAbsoluteError(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors", "_num_examples")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py
@@ -69,6 +69,7 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):
             0.0...
     """
+    _state_dict_all_req_keys = ("_predictions", "_targets")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/manhattan_distance.py
@@ -59,6 +59,7 @@ class ManhattanDistance(_BaseRegression):
         - Fixed sklearn compatibility.
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors",)
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/maximum_absolute_error.py
@@ -58,6 +58,7 @@ class MaximumAbsoluteError(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_max_of_absolute_errors",)
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/mean_absolute_relative_error.py
@@ -58,6 +58,7 @@ class MeanAbsoluteRelativeError(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_absolute_relative_errors", "_num_samples")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/mean_error.py
@@ -55,6 +55,7 @@ class MeanError(_BaseRegression):
             0.625...
     """
+    _state_dict_all_req_keys = ("_sum_of_errors", "_num_examples")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/mean_normalized_bias.py
@@ -58,6 +58,7 @@ class MeanNormalizedBias(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors", "_num_examples")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/r2_score.py
@@ -56,6 +56,7 @@ class R2Score(_BaseRegression):
     .. versionchanged:: 0.4.3
         Works with DDP.
     """
+    _state_dict_all_req_keys = ("_num_examples", "_sum_of_errors", "_y_sq_sum", "_y_sum")
 
     @reinit__is_reduced
     def reset(self) -> None:
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/wave_hedges_distance.py
@@ -57,6 +57,7 @@ class WaveHedgesDistance(_BaseRegression):
     .. versionchanged:: 0.4.5
         - Works with DDP.
     """
+    _state_dict_all_req_keys = ("_sum_of_errors",)
 
     @reinit__is_reduced
     def reset(self) -> None:
51 changes: 19 additions & 32 deletions ignite/handlers/checkpoint.py
@@ -6,7 +6,7 @@
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Mapping, NamedTuple, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -22,6 +22,7 @@
 import ignite.distributed as idist
 from ignite.base import Serializable
 from ignite.engine import Engine, Events
+from ignite.utils import _tree_apply2, _tree_map
 
 __all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
 
@@ -276,7 +277,7 @@ class Checkpoint(Serializable):
     """
 
     Item = NamedTuple("Item", [("priority", int), ("filename", str)])
-    _state_dict_all_req_keys = ("saved",)
+    _state_dict_all_req_keys = ("_saved",)
 
     def __init__(
         self,
@@ -465,24 +466,19 @@ def __call__(self, engine: Engine) -> None:
         except TypeError:
             self.save_handler(checkpoint, filename)
 
-    def _setup_checkpoint_recursive(self, objs: Mapping) -> Dict[str, Dict[Any, Any]]:
-        checkpoint = {}
-        for k, obj in objs.items():
-            if isinstance(obj, Mapping):
-                checkpoint[k] = self._setup_checkpoint_recursive(obj)
-            else:
-                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-                    obj = obj.module
-                elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
-                    obj.consolidate_state_dict(to=self.save_on_rank)
-                    if self.save_on_rank != idist.get_rank():
-                        continue
-                checkpoint[k] = obj.state_dict()
-        return checkpoint
-
-    def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
+    def _setup_checkpoint(self) -> Dict[str, Any]:
         if self.to_save is not None:
-            return self._setup_checkpoint_recursive(self.to_save)
+
+            def func(obj: Any, **kwargs: Any) -> Dict:
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
+                    obj.consolidate_state_dict(to=self.save_on_rank)
+                return obj.state_dict()
+
+            return cast(Dict[str, Any], _tree_map(func, self.to_save))
         return {}
 
     @staticmethod
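Both the old recursion and the new _tree_map walk nested containers, so to_save may be a nested mapping of stateful objects; each leaf is serialized via its state_dict(). A hedged sketch (the directory path is illustrative):

import torch
import torch.nn as nn

from ignite.handlers import Checkpoint, DiskSaver

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Nested mapping: _setup_checkpoint applies state_dict() to every leaf,
# keeping the container structure in the checkpoint that gets written.
to_save = {"model": model, "training": {"optimizer": optimizer}}
handler = Checkpoint(to_save, DiskSaver("/tmp/checkpoints", create_dir=True), n_saved=2)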
@@ -538,12 +534,12 @@ def setup_filename_pattern(

     @staticmethod
     def _check_objects(objs: Mapping, attr: str) -> None:
-        for obj in objs.values():
-            if isinstance(obj, Mapping):
-                Checkpoint._check_objects(obj, attr=attr)
-            elif not hasattr(obj, attr):
+        def func(obj: Any, **kwargs: Any) -> None:
+            if not hasattr(obj, attr):
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
+        _tree_map(func, objs)
+
     @staticmethod
     def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
@@ -625,17 +621,7 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
                 _load_object(obj, checkpoint_obj)
                 return
 
-        def _load_objects_recursive(objs: Mapping, chkpt_objs: Mapping) -> None:
-            for k, obj in objs.items():
-                if k not in chkpt_objs:
-                    raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
-                if isinstance(obj, Mapping):
-                    _load_objects_recursive(obj, chkpt_objs[k])
-                else:
-                    _load_object(obj, chkpt_objs[k])
-
         # multiple objects to load
-        _load_objects_recursive(to_load, checkpoint_obj)
+        _tree_apply2(_load_object, to_load, checkpoint_obj)

     def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
@@ -721,11 +707,12 @@ def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, *

         Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)
 
-    def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
+    def state_dict(self) -> OrderedDict:
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
         Can be used to save internal state of the class.
         """
-        return OrderedDict([("saved", [(p, f) for p, f in self._saved])])
+        # TODO: this method should use _state_dict_all_req_keys
+        return OrderedDict([("_saved", [(p, f) for p, f in self._saved])])

     def load_state_dict(self, state_dict: Mapping) -> None:
         """Method replaces internal state of the class with provided state dict data.
@@ -734,7 +721,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
             state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
         """
         super().load_state_dict(state_dict)
-        self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]]
+        self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["_saved"]]

     @staticmethod
     def get_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable:
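The handler's own bookkeeping now round-trips under the renamed "_saved" key, so state dicts produced by earlier versions under "saved" would need that key remapped. Continuing the sketch above:

state = handler.state_dict()
# OrderedDict([("_saved", [(priority, filename), ...])])

new_handler = Checkpoint(to_save, DiskSaver("/tmp/checkpoints", create_dir=True), n_saved=2)
new_handler.load_state_dict(state)  # restores the list of retained checkpoints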
4 changes: 2 additions & 2 deletions ignite/metrics/classification_report.py
@@ -131,12 +131,12 @@ def _wrapper(
                 dict_obj[_get_label_for_class(idx)] = {
                     "precision": p_label.item(),
                     "recall": re[idx].item(),
-                    "f{0}-score".format(beta): f[idx].item(),
+                    f"f{beta}-score": f[idx].item(),
                 }
             dict_obj["macro avg"] = {
                 "precision": a_pr.item(),
                 "recall": a_re.item(),
-                "f{0}-score".format(beta): a_f.item(),
+                f"f{beta}-score": a_f.item(),
             }
             return dict_obj if output_dict else json.dumps(dict_obj)

125 changes: 71 additions & 54 deletions ignite/metrics/metric.py
@@ -11,6 +11,7 @@
 
 from ignite.base.mixins import Serializable
 from ignite.engine import CallableEventWithFilter, Engine, Events
+from ignite.utils import _CollectionItem, _tree_apply2, _tree_map
 
 if TYPE_CHECKING:
     from ignite.metrics.metrics_lambda import MetricsLambda
@@ -549,47 +550,62 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise
         usage = self._check_usage(usage)
         return engine.has_event_handler(self.completed, usage.COMPLETED)
 
-    def state_dict(self) -> OrderedDict:
-        """Method returns state dict with attributes of the metric specified in its
-        `_state_dict_all_req_keys` attribute. Can be used to save internal state of the class.
-
-        If there's an active distributed configuration, some collective operations is done and
-        the list of values across ranks is saved under each attribute's name in the dict, for numeric
-        and tensor values.
-        """
-        state: Dict[str, Union[torch.Tensor, List, Dict, None]] = OrderedDict()
+    def _state_dict_per_rank(self) -> OrderedDict:
+        def func(
+            x: Union[torch.Tensor, Metric, None, float], **kwargs: Any
+        ) -> Union[torch.Tensor, float, OrderedDict, None]:
+            if isinstance(x, Metric):
+                return x._state_dict_per_rank()
+            if x is None or isinstance(x, (int, float, torch.Tensor)):
+                return x
+            else:
+                raise TypeError(
+                    "Found attribute of unsupported type. Currently, supported types include"
+                    " numeric types, tensor, Metric or sequence/mapping of metrics."
+                )
+
+        state: OrderedDict[str, Union[torch.Tensor, List, Dict, None]] = OrderedDict()
         for attr_name in self._state_dict_all_req_keys:
             if attr_name not in self.__dict__:
                 raise ValueError(
                     f"Found a value in _state_dict_all_req_keys that is not among metric attributes: {attr_name}"
                 )
             attr = getattr(self, attr_name)
-            if isinstance(attr, Mapping):
-                state[attr_name] = {k: m.state_dict() for k, m in attr.items()}
-            elif isinstance(attr, Sequence):
-                state[attr_name] = [m.state_dict() for m in attr]
-            elif isinstance(attr, Metric):
-                state[attr_name] = attr.state_dict()
-            elif isinstance(attr, (int, float, torch.Tensor)):
-                if idist.get_world_size() == 1:
-                    state[attr_name] = [attr]
-                else:
-                    if isinstance(attr, (int, float)):
-                        attr_type = type(attr)
-                        attr = float(attr)
-                    gathered_attr = idist.all_gather(attr)
-                    if isinstance(attr, float):
-                        gathered_attr = [attr_type(process_attr) for process_attr in cast(torch.Tensor, gathered_attr)]
-                    state[attr_name] = cast(Union[torch.Tensor, List], gathered_attr)
-            # Some attributes might be `None` upon serialization e.g. `RunningAverage`'s initial `_value`.
-            elif attr is None:
-                state[attr_name] = None
-            else:
-                raise TypeError(
-                    "Found attribute of unsupported type. Currently, supported types include"
-                    " numeric types, tensor, Metric or sequence/mapping of metrics."
-                )
-        return cast(OrderedDict, state)
+            state[attr_name] = _tree_map(func, attr)  # type: ignore[assignment]
+
+        return state
+
+    __state_dict_key_per_rank: str = "__metric_state_per_rank"
+
+    def state_dict(self) -> OrderedDict:
+        """Method returns state dict with attributes of the metric specified in its
+        `_state_dict_all_req_keys` attribute. Can be used to save internal state of the class.
+        """
+        state = self._state_dict_per_rank()
+
+        if idist.get_world_size() > 1:
+            return OrderedDict([(Metric.__state_dict_key_per_rank, idist.all_gather(state))])
+        return OrderedDict([(Metric.__state_dict_key_per_rank, [state])])
+
+    def _load_state_dict_per_rank(self, state_dict: Mapping) -> None:
+        super().load_state_dict(state_dict)
+
+        def func(x: Any, y: Any) -> None:
+            if isinstance(x, Metric):
+                x._load_state_dict_per_rank(y)
+            elif isinstance(x, _CollectionItem):
+                value = x.value()
+                if y is None or isinstance(y, _CollectionItem.types_as_collection_item):
+                    x.load_value(y)
+                elif isinstance(value, Metric):
+                    value._load_state_dict_per_rank(y)
+                else:
+                    raise ValueError(f"Unsupported type for provided state_dict data: {type(y)}")
+
+        for attr_name in self._state_dict_all_req_keys:
+            attr = getattr(self, attr_name)
+            attr = _CollectionItem.wrap(self.__dict__, attr_name, attr)
+            _tree_apply2(func, attr, state_dict[attr_name])

     def load_state_dict(self, state_dict: Mapping) -> None:
         """Method replaces internal state of the class with provided state dict data.
@@ -601,28 +617,29 @@ def load_state_dict(self, state_dict: Mapping) -> None:
             state_dict: a dict containing attributes of the metric specified in its `_state_dict_all_req_keys`
                 attribute.
         """
-        super().load_state_dict(state_dict)
-        rank = idist.get_rank()
-        for attr_name in self._state_dict_all_req_keys:
-            attr = getattr(self, attr_name)
-            if isinstance(attr, Mapping):
-                for metric_name in attr:
-                    attr[metric_name].load_state_dict(state_dict[attr_name][metric_name])
-            elif isinstance(attr, Sequence):
-                for i, metric in enumerate(attr):
-                    metric.load_state_dict(state_dict[attr_name][i])
-            elif isinstance(attr, Metric):
-                attr.load_state_dict(state_dict[attr_name])
-            elif state_dict[attr_name] is None:
-                setattr(self, attr_name, None)
-            else:
-                world_size = idist.get_world_size()
-                len_rank_slice = len(state_dict[attr_name]) // world_size
-                if len_rank_slice == 1:
-                    setattr(self, attr_name, state_dict[attr_name][rank])
-                else:
-                    rank_slice = slice(rank * len_rank_slice, (rank + 1) * len_rank_slice)
-                    setattr(self, attr_name, state_dict[attr_name][rank_slice])
+        if not isinstance(state_dict, Mapping):
+            raise TypeError(f"Argument state_dict should be a dictionary, but given {type(state_dict)}")
+
+        if not (len(state_dict) == 1 and Metric.__state_dict_key_per_rank in state_dict):
+            raise ValueError(
+                "Incorrect state_dict object. Argument state_dict should be a dictionary "
+                "provided by Metric.state_dict(). "
+                f"Expected single key: {Metric.__state_dict_key_per_rank}, but given {state_dict.keys()}"
+            )
+
+        list_state_dicts_per_rank = state_dict[Metric.__state_dict_key_per_rank]
+        rank = idist.get_rank()
+        world_size = idist.get_world_size()
+        if len(list_state_dicts_per_rank) != world_size:
+            raise ValueError(
+                "Incorrect state_dict object. Argument state_dict should be a dictionary "
+                "provided by Metric.state_dict(). "
+                f"Expected a list of state_dicts of size equal world_size: {world_size}, "
+                f"but got {len(list_state_dicts_per_rank)}"
+            )
+
+        state_dict = list_state_dicts_per_rank[rank]
+        self._load_state_dict_per_rank(state_dict)

     def __add__(self, other: Any) -> "MetricsLambda":
         from ignite.metrics.metrics_lambda import MetricsLambda
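Loading must write restored leaves back into their containers, and a bare number or tensor held in self.__dict__ cannot be updated through the reference alone; _CollectionItem wraps a (container, key) slot for that purpose. A hedged sketch of the idea, inferred from the call sites above (the real helper added to ignite/utils.py may differ):

import torch
from typing import Any, Union


class _CollectionItem:
    # Leaf types that are replaced wholesale rather than mutated in place.
    types_as_collection_item: tuple = (int, float, torch.Tensor)

    def __init__(self, collection: Union[dict, list], key: Union[int, str]) -> None:
        self.collection = collection
        self.key = key

    def value(self) -> Any:
        return self.collection[self.key]

    def load_value(self, value: Any) -> None:
        # Write the restored value back through the container slot.
        self.collection[self.key] = value

    @staticmethod
    def wrap(obj: Union[dict, list], key: Union[int, str], value: Any) -> Union[Any, "_CollectionItem"]:
        # Wrap only plain leaves (or None); metrics and containers pass through.
        if value is None or isinstance(value, _CollectionItem.types_as_collection_item):
            return _CollectionItem(obj, key)
        return value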
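Net effect: each rank now serializes its own shard of the accumulators under one well-known key, gathering across ranks at save time instead of slicing flat lists heuristically at load time. A single-process sketch (the exact inner attribute names are Accuracy internals and shown only for illustration):

import torch

from ignite.metrics import Accuracy

acc = Accuracy()
acc.update((torch.tensor([[0.8, 0.2], [0.3, 0.7]]), torch.tensor([0, 1])))

sd = acc.state_dict()
# With world_size == 1, roughly:
# OrderedDict([("__metric_state_per_rank", [OrderedDict([...per-rank state...])])])

restored = Accuracy()
restored.load_state_dict(sd)
assert restored.compute() == acc.compute()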
4 changes: 3 additions & 1 deletion ignite/metrics/metrics_lambda.py
@@ -90,9 +90,11 @@ def Fbeta(r, p, beta):
         assert not precision.is_attached(engine)
     """
 
+    _state_dict_all_req_keys = ("_updated", "args", "kwargs")
+
     def __init__(self, f: Callable, *args: Any, **kwargs: Any) -> None:
         self.function = f
-        self.args = args
+        self.args = list(args)  # we need args to be a list instead of a tuple for state_dict/load_state_dict feature
         self.kwargs = kwargs
         self.engine: Optional[Engine] = None
         self._updated = False
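args must become a mutable container because restoring state writes leaves back in place through _CollectionItem, and a tuple slot cannot be assigned. A hedged single-process sketch:

from ignite.metrics import MetricsLambda, Precision, Recall

precision, recall = Precision(average=False), Recall(average=False)
f1 = MetricsLambda(lambda p, r: (2 * p * r / (p + r + 1e-20)).mean().item(), precision, recall)

sd = f1.state_dict()    # per-rank state covers "_updated", "args" and "kwargs"
f1.load_state_dict(sd)  # metric leaves are restored through f1.args entries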