diff --git a/ignite/_utils.py b/ignite/_utils.py index 2cc1c57a9da..0021435c96c 100644 --- a/ignite/_utils.py +++ b/ignite/_utils.py @@ -1,9 +1,10 @@ +from typing import Union, Tuple # For compatibilty from ignite.utils import convert_tensor, apply_to_tensor, apply_to_type, to_onehot -def _to_hours_mins_secs(time_taken): +def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, int]: """Convert seconds to hours, mins, and seconds.""" mins, secs = divmod(time_taken, 60) hours, mins = divmod(mins, 60) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index 4675e901aa6..511ccff7076 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -126,7 +126,6 @@ def _setup_common_distrib_training_handlers(trainer, train_sampler=None, lr_scheduler=None, with_gpu_stats=True, output_names=None, with_pbars=True, with_pbar_on_iters=True, log_every_iters=100, device='cuda'): - if not (dist.is_available() and dist.is_initialized()): raise RuntimeError("Distributed setting is not initialized, please call `dist.init_process_group` before.") @@ -162,7 +161,6 @@ def empty_cuda_cache(_): def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters): - if optimizers is not None: from torch.optim.optimizer import Optimizer @@ -279,7 +277,6 @@ def setup_plx_logging(trainer, optimizers=None, evaluators=None, log_every_iters def get_default_score_fn(metric_name): - def wrapper(engine): score = engine.state.metrics[metric_name] return score diff --git a/ignite/contrib/engines/tbptt.py b/ignite/contrib/engines/tbptt.py index fb85d58ad31..8e91dc4b216 100644 --- a/ignite/contrib/engines/tbptt.py +++ b/ignite/contrib/engines/tbptt.py @@ -29,14 +29,14 @@ def _detach_hidden(hidden): def create_supervised_tbptt_trainer( - model, - optimizer, - loss_fn, - tbtt_step, - dim=0, - device=None, - non_blocking=False, - prepare_batch=_prepare_batch + model, + optimizer, + loss_fn, + tbtt_step, + dim=0, + device=None, + non_blocking=False, + prepare_batch=_prepare_batch ): """Create a trainer for truncated backprop through time supervised models. diff --git a/ignite/contrib/handlers/__init__.py b/ignite/contrib/handlers/__init__.py index 3ff06ed6840..abb60e882be 100644 --- a/ignite/contrib/handlers/__init__.py +++ b/ignite/contrib/handlers/__init__.py @@ -1,4 +1,3 @@ - from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, CosineAnnealingScheduler, \ ConcatScheduler, LRScheduler, create_lr_scheduler_with_warmup, PiecewiseLinear, ParamGroupScheduler diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 028f7b87ebb..4c26b6c243a 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -14,6 +14,7 @@ class BaseLogger: Base logger handler. See implementations: TensorboardLogger, VisdomLogger, PolyaxonLogger, MLflowLogger, ... """ + def attach(self, engine, log_handler, event_name): """Attach the logger to the engine and execute `log_handler` function at `event_name` events. diff --git a/ignite/contrib/handlers/custom_events.py b/ignite/contrib/handlers/custom_events.py index d8eff747289..daca6ae3fe0 100644 --- a/ignite/contrib/handlers/custom_events.py +++ b/ignite/contrib/handlers/custom_events.py @@ -73,10 +73,11 @@ def __init__(self, n_iterations=None, n_epochs=None): self.custom_state_attr = "{}_{}".format(prefix, self.period) event_name = "{}_{}".format(prefix.upper(), self.period) - setattr(self, "Events", Enum("Events", " ".join([ - "{}_STARTED".format(event_name), - "{}_COMPLETED".format(event_name)]) - )) + setattr(self, "Events", + Enum("Events", + " ".join(["{}_STARTED".format(event_name), + "{}_COMPLETED".format(event_name)]) + )) # Update State.event_to_attr for e in self.Events: State.event_to_attr[e] = self.custom_state_attr diff --git a/ignite/contrib/handlers/mlflow_logger.py b/ignite/contrib/handlers/mlflow_logger.py index 8a6afdef83b..6be195637e6 100644 --- a/ignite/contrib/handlers/mlflow_logger.py +++ b/ignite/contrib/handlers/mlflow_logger.py @@ -6,7 +6,6 @@ from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler, BaseOptimizerParamsHandler, \ global_step_from_engine - __all__ = ['MLflowLogger', 'OutputHandler', 'OptimizerParamsHandler', 'global_step_from_engine'] @@ -90,6 +89,7 @@ def global_step_transform(engine, event_name): return engine.state.get_event_attrib_value(event_name) """ + def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None): super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform) diff --git a/ignite/contrib/handlers/param_scheduler.py b/ignite/contrib/handlers/param_scheduler.py index 11487697cc7..955686501fe 100644 --- a/ignite/contrib/handlers/param_scheduler.py +++ b/ignite/contrib/handlers/param_scheduler.py @@ -782,8 +782,8 @@ def _get_start_end(self): elif self.milestones[-1] <= self.event_index: return self.event_index, self.event_index + 1, self.values[-1], self.values[-1], elif self.milestones[self._index] <= self.event_index < self.milestones[self._index + 1]: - return self.milestones[self._index], self.milestones[self._index + 1], \ - self.values[self._index], self.values[self._index + 1] + return (self.milestones[self._index], self.milestones[self._index + 1], + self.values[self._index], self.values[self._index + 1]) else: self._index += 1 return self._get_start_end() diff --git a/ignite/contrib/handlers/polyaxon_logger.py b/ignite/contrib/handlers/polyaxon_logger.py index 1b93b762476..9b7e45baf1b 100644 --- a/ignite/contrib/handlers/polyaxon_logger.py +++ b/ignite/contrib/handlers/polyaxon_logger.py @@ -6,7 +6,6 @@ from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler, BaseOptimizerParamsHandler, \ global_step_from_engine - __all__ = ['PolyaxonLogger', 'OutputHandler', 'OptimizerParamsHandler', 'global_step_from_engine'] @@ -91,6 +90,7 @@ def global_step_transform(engine, event_name): return engine.state.get_event_attrib_value(event_name) """ + def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None): super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform) @@ -225,4 +225,5 @@ def __init__(self): def __getattr__(self, attr): def wrapper(*args, **kwargs): return getattr(self.experiment, attr)(*args, **kwargs) + return wrapper diff --git a/ignite/contrib/handlers/tensorboard_logger.py b/ignite/contrib/handlers/tensorboard_logger.py index 4fa2bbedca1..d75e3ce240d 100644 --- a/ignite/contrib/handlers/tensorboard_logger.py +++ b/ignite/contrib/handlers/tensorboard_logger.py @@ -6,7 +6,6 @@ from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler, \ BaseWeightsScalarHandler, BaseWeightsHistHandler, global_step_from_engine - __all__ = ['TensorboardLogger', 'OptimizerParamsHandler', 'OutputHandler', 'WeightsScalarHandler', 'WeightsHistHandler', 'GradsScalarHandler', 'GradsHistHandler', 'global_step_from_engine'] @@ -92,6 +91,7 @@ def global_step_transform(engine, event_name): return engine.state.get_event_attrib_value(event_name) """ + def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None): super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform) @@ -184,6 +184,7 @@ class WeightsScalarHandler(BaseWeightsScalarHandler): tag (str, optional): common title for all produced plots. For example, 'generator' """ + def __init__(self, model, reduction=torch.norm, tag=None): super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag) @@ -271,6 +272,7 @@ class GradsScalarHandler(BaseWeightsScalarHandler): tag (str, optional): common title for all produced plots. For example, 'generator' """ + def __init__(self, model, reduction=torch.norm, tag=None): super(GradsScalarHandler, self).__init__(model, reduction, tag=tag) @@ -312,6 +314,7 @@ class GradsHistHandler(BaseWeightsHistHandler): tag (str, optional): common title for all produced plots. For example, 'generator' """ + def __init__(self, model, tag=None): super(GradsHistHandler, self).__init__(model, tag=tag) diff --git a/ignite/contrib/handlers/tqdm_logger.py b/ignite/contrib/handlers/tqdm_logger.py index 46a467c3ffc..cf6b48bafb2 100644 --- a/ignite/contrib/handlers/tqdm_logger.py +++ b/ignite/contrib/handlers/tqdm_logger.py @@ -203,6 +203,7 @@ class _OutputHandler(BaseOutputHandler): :meth:`~ignite.engine.Engine.register_events`. """ + def __init__(self, description, metric_names=None, output_transform=None, closing_event_name=Events.EPOCH_COMPLETED): if metric_names is None and output_transform is None: diff --git a/ignite/contrib/handlers/visdom_logger.py b/ignite/contrib/handlers/visdom_logger.py index f48ca0cd74e..8aab03b5f4c 100644 --- a/ignite/contrib/handlers/visdom_logger.py +++ b/ignite/contrib/handlers/visdom_logger.py @@ -7,7 +7,6 @@ from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler, \ BaseWeightsScalarHandler, global_step_from_engine - __all__ = ['VisdomLogger', 'OptimizerParamsHandler', 'OutputHandler', 'WeightsScalarHandler', 'GradsScalarHandler', 'global_step_from_engine'] @@ -142,6 +141,7 @@ def global_step_transform(engine, event_name): return engine.state.get_event_attrib_value(event_name) """ + def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None, show_legend=False): super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform) @@ -473,7 +473,6 @@ def close(self): class _DummyExecutor: - class _DummyFuture: def __init__(self, result): diff --git a/ignite/contrib/metrics/average_precision.py b/ignite/contrib/metrics/average_precision.py index 9cde2ea6582..5e81bcefdaf 100644 --- a/ignite/contrib/metrics/average_precision.py +++ b/ignite/contrib/metrics/average_precision.py @@ -36,5 +36,6 @@ def activated_output_transform(output): avg_precision = AveragePrecision(activated_output_transform) """ + def __init__(self, output_transform=lambda x: x): super(AveragePrecision, self).__init__(average_precision_compute_fn, output_transform=output_transform) diff --git a/ignite/contrib/metrics/regression/fractional_bias.py b/ignite/contrib/metrics/regression/fractional_bias.py index 71f4803f64e..d9b35fd9750 100644 --- a/ignite/contrib/metrics/regression/fractional_bias.py +++ b/ignite/contrib/metrics/regression/fractional_bias.py @@ -22,6 +22,7 @@ class FractionalBias(_BaseRegression): __ https://arxiv.org/abs/1809.03006 """ + def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 diff --git a/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py b/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py index 1d33b2770c6..7ad389e9845 100644 --- a/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py +++ b/ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py @@ -22,6 +22,7 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression): __ https://arxiv.org/abs/1809.03006 """ + def reset(self): self._sum_y = 0.0 self._num_examples = 0 diff --git a/ignite/contrib/metrics/regression/manhattan_distance.py b/ignite/contrib/metrics/regression/manhattan_distance.py index 38c1d913c4c..044057084eb 100644 --- a/ignite/contrib/metrics/regression/manhattan_distance.py +++ b/ignite/contrib/metrics/regression/manhattan_distance.py @@ -21,6 +21,7 @@ class ManhattanDistance(_BaseRegression): __ https://arxiv.org/abs/1809.03006 """ + def reset(self): self._sum_of_errors = 0.0 diff --git a/ignite/contrib/metrics/regression/mean_error.py b/ignite/contrib/metrics/regression/mean_error.py index 70d31bf0398..cf839acf64e 100644 --- a/ignite/contrib/metrics/regression/mean_error.py +++ b/ignite/contrib/metrics/regression/mean_error.py @@ -22,6 +22,7 @@ class MeanError(_BaseRegression): __ https://arxiv.org/abs/1809.03006 """ + def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 diff --git a/ignite/contrib/metrics/regression/mean_normalized_bias.py b/ignite/contrib/metrics/regression/mean_normalized_bias.py index 88e48fc339b..2a558ee1277 100644 --- a/ignite/contrib/metrics/regression/mean_normalized_bias.py +++ b/ignite/contrib/metrics/regression/mean_normalized_bias.py @@ -22,6 +22,7 @@ class MeanNormalizedBias(_BaseRegression): __ https://arxiv.org/abs/1809.03006 """ + def reset(self): self._sum_of_errors = 0.0 self._num_examples = 0 diff --git a/ignite/contrib/metrics/regression/median_absolute_error.py b/ignite/contrib/metrics/regression/median_absolute_error.py index bdbddd67e5b..08c0741ce0d 100644 --- a/ignite/contrib/metrics/regression/median_absolute_error.py +++ b/ignite/contrib/metrics/regression/median_absolute_error.py @@ -30,5 +30,6 @@ class MedianAbsoluteError(_BaseRegressionEpoch): __ https://arxiv.org/abs/1809.03006 """ + def __init__(self, output_transform=lambda x: x): super(MedianAbsoluteError, self).__init__(median_absolute_error_compute_fn, output_transform) diff --git a/ignite/contrib/metrics/regression/median_absolute_percentage_error.py b/ignite/contrib/metrics/regression/median_absolute_percentage_error.py index 78840a08a30..f9bada485ec 100644 --- a/ignite/contrib/metrics/regression/median_absolute_percentage_error.py +++ b/ignite/contrib/metrics/regression/median_absolute_percentage_error.py @@ -32,6 +32,7 @@ class MedianAbsolutePercentageError(_BaseRegressionEpoch): __ https://arxiv.org/abs/1809.03006 """ + def __init__(self, output_transform=lambda x: x): super(MedianAbsolutePercentageError, self).__init__(median_absolute_percentage_error_compute_fn, output_transform) diff --git a/ignite/contrib/metrics/regression/median_relative_absolute_error.py b/ignite/contrib/metrics/regression/median_relative_absolute_error.py index 373ddc87412..7165f619802 100644 --- a/ignite/contrib/metrics/regression/median_relative_absolute_error.py +++ b/ignite/contrib/metrics/regression/median_relative_absolute_error.py @@ -32,6 +32,7 @@ class MedianRelativeAbsoluteError(_BaseRegressionEpoch): __ https://arxiv.org/abs/1809.03006 """ + def __init__(self, output_transform=lambda x: x): super(MedianRelativeAbsoluteError, self).__init__(median_relative_absolute_error_compute_fn, output_transform) diff --git a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index 6bcdd3b0ca6..5c84e9c7393 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -19,6 +19,7 @@ class R2Score(_BaseRegression): - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`. - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`. """ + def reset(self): self._num_examples = 0 self._sum_of_errors = 0 diff --git a/ignite/contrib/metrics/roc_auc.py b/ignite/contrib/metrics/roc_auc.py index 25e39cdd8e0..dfbb8038926 100644 --- a/ignite/contrib/metrics/roc_auc.py +++ b/ignite/contrib/metrics/roc_auc.py @@ -37,5 +37,6 @@ def activated_output_transform(output): roc_auc = ROC_AUC(activated_output_transform) """ + def __init__(self, output_transform=lambda x: x): super(ROC_AUC, self).__init__(roc_auc_compute_fn, output_transform=output_transform) diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py index 254abe9b84f..528445552e1 100644 --- a/ignite/engine/__init__.py +++ b/ignite/engine/__init__.py @@ -1,8 +1,10 @@ +from typing import Sequence, Union, Optional, Callable, Dict, Any, Tuple import torch from ignite.engine.engine import Engine from ignite.engine.events import State, Events from ignite.utils import convert_tensor +from ignite.metrics import Metric __all__ = [ 'create_supervised_trainer', @@ -12,7 +14,8 @@ ] -def _prepare_batch(batch, device=None, non_blocking=False): +def _prepare_batch(batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False): """Prepare batch for training: pass to a device with options. """ @@ -21,10 +24,11 @@ def _prepare_batch(batch, device=None, non_blocking=False): convert_tensor(y, device=device, non_blocking=non_blocking)) -def create_supervised_trainer(model, optimizer, loss_fn, - device=None, non_blocking=False, - prepare_batch=_prepare_batch, - output_transform=lambda x, y, y_pred, loss: loss.item()): +def create_supervised_trainer(model: torch.nn.Module, optimizer: torch.optim.Optimizer, + loss_fn: Union[Callable, torch.nn.Module], + device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + prepare_batch: Callable = _prepare_batch, + output_transform: Callable = lambda x, y, y_pred, loss: loss.item()) -> Engine: """ Factory function for creating a trainer for supervised models. @@ -50,7 +54,7 @@ def create_supervised_trainer(model, optimizer, loss_fn, if device: model.to(device) - def _update(engine, batch): + def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]: model.train() optimizer.zero_grad() x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) @@ -63,10 +67,10 @@ def _update(engine, batch): return Engine(_update) -def create_supervised_evaluator(model, metrics=None, - device=None, non_blocking=False, - prepare_batch=_prepare_batch, - output_transform=lambda x, y, y_pred: (y_pred, y,)): +def create_supervised_evaluator(model: torch.nn.Module, metrics: Optional[Dict[str, Metric]] = None, + device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False, + prepare_batch: Callable = _prepare_batch, + output_transform: Callable = lambda x, y, y_pred: (y_pred, y,)) -> Engine: """ Factory function for creating an evaluator for supervised models. @@ -94,7 +98,7 @@ def create_supervised_evaluator(model, metrics=None, if device: model.to(device) - def _inference(engine, batch): + def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]: model.eval() with torch.no_grad(): x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index f3bd150b34e..166275eeeab 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import logging import time @@ -6,6 +8,8 @@ import weakref import random import warnings +from typing import Union, Optional, Callable, Iterable, Iterator, Any, \ + Tuple import torch @@ -14,9 +18,7 @@ from ignite._utils import _to_hours_mins_secs __all__ = [ - 'Engine', - 'Events', - 'State' + 'Engine' ] @@ -123,7 +125,7 @@ def compute_mean_std(engine, batch): _state_dict_all_req_keys = ("seed", "epoch_length", "max_epochs") _state_dict_one_of_opt_keys = ("iteration", "epoch") - def __init__(self, process_function): + def __init__(self, process_function: Callable): self._event_handlers = defaultdict(list) self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) self._process_function = process_function @@ -143,7 +145,7 @@ def __init__(self, process_function): _check_signature(self, process_function, 'process_function', None) - def register_events(self, *event_names, **kwargs): + def register_events(self, *event_names: Union[str, int, Any], **kwargs) -> None: """Add events that can be fired. Registering an event will let the user fire these events at any point. @@ -205,9 +207,9 @@ class TBPTT_Events(CallableEvents, Enum): State.event_to_attr[e] = event_to_attr[e] @staticmethod - def _handler_wrapper(handler, event_name, event_filter): + def _handler_wrapper(handler: Callable, event_name: str, event_filter: Callable) -> Callable: - def wrapper(engine, *args, **kwargs): + def wrapper(engine: Engine, *args, **kwargs) -> Any: event = engine.state.get_event_attrib_value(event_name) if event_filter(engine, event): return handler(engine, *args, **kwargs) @@ -216,7 +218,7 @@ def wrapper(engine, *args, **kwargs): wrapper._parent = weakref.ref(handler) return wrapper - def add_event_handler(self, event_name, handler, *args, **kwargs): + def add_event_handler(self, event_name: str, handler: Callable, *args, **kwargs): """Add an event handler to be executed when the specified event is fired. Args: @@ -261,7 +263,7 @@ def print_epoch(engine): self.logger.error("attempt to add event handler to an invalid event %s.", event_name) raise ValueError("Event {} is not a valid event for this Engine.".format(event_name)) - event_args = (Exception(), ) if event_name == Events.EXCEPTION_RAISED else () + event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else () _check_signature(self, handler, 'handler', *(event_args + args), **kwargs) self._event_handlers[event_name].append((handler, args, kwargs)) @@ -270,12 +272,12 @@ def print_epoch(engine): return RemovableEventHandle(event_name, handler, self) @staticmethod - def _assert_non_callable_event(event_name): + def _assert_non_callable_event(event_name: str): if isinstance(event_name, EventWithFilter): raise TypeError("Argument event_name should not be a callable event, " "please use event without any event filtering") - def has_event_handler(self, handler, event_name=None): + def has_event_handler(self, handler: Callable, event_name: Optional[str] = None): """Check if the specified event has the specified handler. Args: @@ -298,12 +300,12 @@ def has_event_handler(self, handler, event_name=None): return False @staticmethod - def _compare_handlers(user_handler, registered_handler): + def _compare_handlers(user_handler: Callable, registered_handler: Callable) -> bool: if hasattr(registered_handler, "_parent"): registered_handler = registered_handler._parent() return registered_handler == user_handler - def remove_event_handler(self, handler, event_name): + def remove_event_handler(self, handler: Callable, event_name: str): """Remove event handler `handler` from registered handlers of the engine Args: @@ -331,12 +333,14 @@ def on(self, event_name, *args, **kwargs): **kwargs: optional keyword args to be passed to `handler`. """ - def decorator(f): + + def decorator(f: Callable) -> Callable: self.add_event_handler(event_name, f, *args, **kwargs) return f + return decorator - def _fire_event(self, event_name, *event_args, **event_kwargs): + def _fire_event(self, event_name: str, *event_args, **event_kwargs) -> None: """Execute all the handlers associated with given event. This method executes all handlers associated with the event @@ -359,7 +363,7 @@ def _fire_event(self, event_name, *event_args, **event_kwargs): kwargs.update(event_kwargs) func(self, *(event_args + args), **kwargs) - def fire_event(self, event_name): + def fire_event(self, event_name: str) -> None: """Execute all the handlers associated with given event. This method executes all handlers associated with the event @@ -382,20 +386,20 @@ def fire_event(self, event_name): """ return self._fire_event(event_name) - def terminate(self): + def terminate(self) -> None: """Sends terminate signal to the engine, so that it terminates completely the run after the current iteration. """ self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True - def terminate_epoch(self): + def terminate_epoch(self) -> None: """Sends terminate signal to the engine, so that it terminates the current epoch after the current iteration. """ self.logger.info("Terminate current epoch is signaled. " "Current epoch iteration will stop after current iteration is finished.") self.should_terminate_single_epoch = True - def _run_once_on_dataset(self): + def _run_once_on_dataset(self) -> Tuple[int, int, int]: start_time = time.time() # We need to setup iter_counter > 0 if we resume from an iteration @@ -462,13 +466,13 @@ def _run_once_on_dataset(self): return hours, mins, secs - def _handle_exception(self, e): + def _handle_exception(self, e: Exception) -> None: if Events.EXCEPTION_RAISED in self._event_handlers: self._fire_event(Events.EXCEPTION_RAISED, e) else: raise e - def state_dict(self): + def state_dict(self) -> OrderedDict: """Returns a dictionary containing engine's state: "seed", "epoch_length", "max_epochs" and "iteration" Returns: @@ -478,10 +482,10 @@ def state_dict(self): """ if self.state is None: return OrderedDict() - keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0], ) + keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) return OrderedDict([(k, getattr(self.state, k)) for k in keys]) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Mapping) -> None: """Setups engine from `state_dict`. State dictionary should contain keys: `iteration` or `epoch` and `max_epochs`, `epoch_length` and @@ -526,10 +530,11 @@ def load_state_dict(self, state_dict): self.state.iteration = self.state.epoch_length * self.state.epoch @staticmethod - def _is_done(state): + def _is_done(state: State) -> bool: return state.iteration == state.epoch_length * state.max_epochs - def run(self, data, max_epochs=None, epoch_length=None, seed=None): + def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Optional[int] = None, + seed: Optional[int] = None) -> State: """Runs the `process_function` over the passed data. Engine has a state and the following logic is applied in this function: @@ -602,7 +607,7 @@ def switch_batch(engine): self.state.dataloader = data return self._internal_run() - def _setup_engine(self): + def _setup_engine(self) -> None: try: self._dataloader_len = len(self.state.dataloader) if hasattr(self.state.dataloader, "__len__") else None @@ -637,7 +642,7 @@ def _setup_engine(self): self._init_iter.append(iteration) @staticmethod - def _from_iteration(data, iteration): + def _from_iteration(data: Union[Iterable, torch.utils.data.DataLoader], iteration: int) -> Iterator: if isinstance(data, torch.utils.data.DataLoader): try: # following is unsafe for IterableDatasets @@ -664,7 +669,7 @@ def _from_iteration(data, iteration): return data_iter @staticmethod - def _manual_seed(seed, epoch): + def _manual_seed(seed: int, epoch: int) -> None: random.seed(seed + epoch) torch.manual_seed(seed + epoch) try: @@ -673,7 +678,7 @@ def _manual_seed(seed, epoch): except ImportError: pass - def setup_seed(self): + def setup_seed(self) -> None: # seed value should be related to input data iterator length -> iteration at data iterator restart # - seed can not be epoch because during a single epoch we can have multiple `_dataloader_len` # - seed can not be iteration because when resuming from iteration we need to set the seed from the start of the @@ -681,7 +686,7 @@ def setup_seed(self): le = self._dataloader_len if self._dataloader_len is not None else 1 self._manual_seed(self.state.seed, self.state.iteration // le) - def _internal_run(self): + def _internal_run(self) -> State: self.should_terminate = self.should_terminate_single_epoch = False try: start_time = time.time() diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 4aa1c21f913..2427a377536 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +from typing import Callable, Optional, Union + from enum import Enum import numbers import weakref from ignite.engine.utils import _check_signature + __all__ = [ 'Events', 'State' @@ -12,13 +17,13 @@ class EventWithFilter: - def __init__(self, event, filter): + def __init__(self, event: CallableEvents, filter: Callable): if not callable(filter): raise TypeError("Argument filter should be callable") self.event = event self.filter = filter - def __str__(self): + def __str__(self) -> str: return "<%s event=%s, filter=%r>" % (self.__class__.__name__, self.event, self.filter) @@ -41,9 +46,11 @@ def call_on_test_event_every(engine): # do something """ - def __call__(self, event_filter=None, every=None, once=None): - if not((event_filter is not None) ^ (every is not None) ^ (once is not None)): + def __call__(self, event_filter: Optional[Callable] = None, + every: Optional[int] = None, once: Optional[int] = None) -> Union[CallableEvents, EventWithFilter]: + + if not ((event_filter is not None) ^ (every is not None) ^ (once is not None)): raise ValueError("Only one of the input arguments should be specified") if (event_filter is not None) and not callable(event_filter): @@ -70,19 +77,21 @@ def __call__(self, event_filter=None, every=None, once=None): return EventWithFilter(self, event_filter) @staticmethod - def every_event_filter(every): - def wrapper(engine, event): + def every_event_filter(every: int) -> Callable: + def wrapper(engine, event: bool): if event % every == 0: return True return False + return wrapper @staticmethod - def once_event_filter(once): - def wrapper(engine, event): + def once_event_filter(once: int) -> Callable: + def wrapper(engine, event: int) -> bool: if event == once: return True return False + return wrapper @@ -178,14 +187,14 @@ def __init__(self, **kwargs): if not hasattr(self, value): setattr(self, value, 0) - def get_event_attrib_value(self, event_name): + def get_event_attrib_value(self, event_name: Union[EventWithFilter, CallableEvents, Enum]) -> int: if isinstance(event_name, EventWithFilter): event_name = event_name.event if event_name not in State.event_to_attr: raise RuntimeError("Unknown event name '{}'".format(event_name)) return getattr(self, State.event_to_attr[event_name]) - def __repr__(self): + def __repr__(self) -> str: s = "State:\n" for attr, value in self.__dict__.items(): if not isinstance(value, (numbers.Number, str)): @@ -223,12 +232,12 @@ def print_epoch(engine): # print_epoch handler is now unregistered """ - def __init__(self, event_name, handler, engine): + def __init__(self, event_name: Union[EventWithFilter, CallableEvents, Enum], handler: Callable, engine): self.event_name = event_name self.handler = weakref.ref(handler) self.engine = weakref.ref(engine) - def remove(self): + def remove(self) -> None: """Remove handler from engine.""" handler = self.handler() engine = self.engine() @@ -239,8 +248,8 @@ def remove(self): if engine.has_event_handler(handler, self.event_name): engine.remove_event_handler(handler, self.event_name) - def __enter__(self): + def __enter__(self) -> RemovableEventHandle: return self - def __exit__(self, type, value, tb): + def __exit__(self, *args, **kwargs) -> None: self.remove() diff --git a/ignite/engine/utils.py b/ignite/engine/utils.py index 9cf02314859..9298bee3bd2 100644 --- a/ignite/engine/utils.py +++ b/ignite/engine/utils.py @@ -1,33 +1,19 @@ +from __future__ import annotations import inspect +from typing import Optional, Generator, Callable import torch -def _check_signature(engine, fn, fn_description, *args, **kwargs): - exception_msg = None - - signature = inspect.signature(fn) - try: - signature.bind(engine, *args, **kwargs) - except TypeError as exc: - fn_params = list(signature.parameters) - exception_msg = str(exc) - - if exception_msg: - passed_params = [engine] + list(args) + list(kwargs) - raise ValueError("Error adding {} '{}': " - "takes parameters {} but will be called with {} " - "({}).".format(fn, fn_description, fn_params, passed_params, exception_msg)) - - -def _update_dataloader(dataloader, new_batch_sampler): +def _update_dataloader(dataloader: torch.utils.data.DataLoader, + new_batch_sampler: torch.utils.data.sampler.BatchSampler) -> torch.utils.data.DataLoader: params_keys = [k for k in dataloader.__dict__.keys() if not k.startswith("_")] for k in ['batch_size', 'sampler', 'drop_last', 'batch_sampler', 'dataset_kind']: if k in params_keys: params_keys.remove(k) params = {k: getattr(dataloader, k) for k in params_keys} params['batch_sampler'] = new_batch_sampler - return torch.utils.data.DataLoader(**params) + return type(dataloader)(**params) class ReproducibleBatchSampler(torch.utils.data.sampler.BatchSampler): @@ -38,7 +24,8 @@ class ReproducibleBatchSampler(torch.utils.data.sampler.BatchSampler): `torch.utils.data.DataLoader` start_iteration (int, optional): optional start iteration """ - def __init__(self, batch_sampler, start_iteration=None): + + def __init__(self, batch_sampler: torch.utils.data.sampler.BatchSampler, start_iteration: Optional[int] = None): if not isinstance(batch_sampler, torch.utils.data.sampler.BatchSampler): raise TypeError("Argument batch_sampler should be torch.utils.data.sampler.BatchSampler") @@ -47,7 +34,7 @@ def __init__(self, batch_sampler, start_iteration=None): self.start_iteration = start_iteration self.sampler = self.batch_sampler.sampler - def setup_batch_indices(self): + def setup_batch_indices(self) -> None: self.batch_indices = [] for batch in self.batch_sampler: self.batch_indices.append(batch) @@ -56,7 +43,7 @@ def setup_batch_indices(self): self.batch_indices = self.batch_indices[self.start_iteration:] self.start_iteration = None - def __iter__(self): + def __iter__(self) -> Generator: if self.batch_indices is None: self.setup_batch_indices() for batch in self.batch_indices: @@ -64,5 +51,19 @@ def __iter__(self): self.batch_indices = None - def __len__(self): + def __len__(self) -> int: return len(self.batch_sampler) + + +def _check_signature(engine, fn: Callable, fn_description: str, *args, **kwargs) -> None: + + signature = inspect.signature(fn) + try: + signature.bind(engine, *args, **kwargs) + except TypeError as exc: + fn_params = list(signature.parameters) + exception_msg = str(exc) + passed_params = [engine] + list(args) + list(kwargs) + raise ValueError("Error adding {} '{}': " + "takes parameters {} but will be called with {} " + "({}).".format(fn, fn_description, fn_params, passed_params, exception_msg)) diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py index 74c8ffb31a0..39ca6932b14 100644 --- a/ignite/handlers/__init__.py +++ b/ignite/handlers/__init__.py @@ -1,3 +1,9 @@ +from typing import Callable, Any, Union +from enum import Enum + +from ignite.engine import Engine +from ignite.engine.events import EventWithFilter, CallableEvents + from ignite.handlers.checkpoint import ModelCheckpoint, Checkpoint, DiskSaver from ignite.handlers.timing import Timer from ignite.handlers.early_stopping import EarlyStopping @@ -14,7 +20,7 @@ ] -def global_step_from_engine(engine): +def global_step_from_engine(engine: Engine) -> Callable: """Helper method to setup `global_step_transform` function using another engine. This can be helpful for logging trainer epoch/iteration while output handler is attached to an evaluator. @@ -25,7 +31,7 @@ def global_step_from_engine(engine): global step """ - def wrapper(_, event_name): + def wrapper(_: Any, event_name: Union[EventWithFilter, CallableEvents, Enum]): return engine.state.get_event_attrib_value(event_name) return wrapper diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index d58860f5186..f823d070485 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -5,9 +5,11 @@ import collections.abc as collections import warnings +from typing import Optional, Callable, Mapping, Union + import torch -from ignite.engine import Events +from ignite.engine import Events, Engine __all__ = [ 'Checkpoint', @@ -123,9 +125,9 @@ def score_function(engine): Item = namedtuple("Item", ["priority", "filename"]) - def __init__(self, to_save, save_handler, filename_prefix="", - score_function=None, score_name=None, n_saved=1, - global_step_transform=None, archived=False): + def __init__(self, to_save: dict, save_handler: Callable, filename_prefix: str = "", + score_function: Optional[Callable] = None, score_name: Optional[str] = None, n_saved: int = 1, + global_step_transform: Callable = None, archived: bool = False): if not isinstance(to_save, collections.Mapping): raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save))) @@ -156,7 +158,7 @@ def __init__(self, to_save, save_handler, filename_prefix="", self.global_step_transform = global_step_transform @property - def last_checkpoint(self): + def last_checkpoint(self) -> str: if len(self._saved) < 1: return None return self._saved[0].filename @@ -166,7 +168,7 @@ def _check_lt_n_saved(self, or_equal=False): return True return len(self._saved) < self._n_saved + int(or_equal) - def __call__(self, engine): + def __call__(self, engine: Engine) -> None: suffix = "" if self.global_step_transform is not None: @@ -209,20 +211,20 @@ def __call__(self, engine): item = self._saved.pop(0) self.save_handler.remove(item.filename) - def _setup_checkpoint(self): + def _setup_checkpoint(self) -> dict: checkpoint = {} for k, obj in self.to_save.items(): checkpoint[k] = obj.state_dict() return checkpoint @staticmethod - def _check_objects(objs, attr): + def _check_objects(objs: dict, attr: str) -> None: for k, obj in objs.items(): if not hasattr(obj, attr): raise TypeError("Object {} should have `{}` method".format(type(obj), attr)) @staticmethod - def load_objects(to_load, checkpoint): + def load_objects(to_load: Mapping, checkpoint: Mapping) -> None: """Helper method to apply `load_state_dict` on the objects from `to_load` using states from `checkpoint`. Args: @@ -251,7 +253,7 @@ class DiskSaver: require_empty (bool, optional): If True, will raise exception if there are any files in the directory 'dirname'. """ - def __init__(self, dirname, atomic=True, create_dir=True, require_empty=True): + def __init__(self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True): self.dirname = os.path.expanduser(dirname) self._atomic = atomic if create_dir: @@ -269,7 +271,7 @@ def __init__(self, dirname, atomic=True, create_dir=True, require_empty=True): "directory anyway, pass `require_empty=False`." "".format(matched, dirname)) - def __call__(self, checkpoint, filename): + def __call__(self, checkpoint: Mapping, filename: str) -> None: path = os.path.join(self.dirname, filename) if not self._atomic: @@ -286,7 +288,7 @@ def __call__(self, checkpoint, filename): tmp.close() os.rename(tmp.name, path) - def remove(self, filename): + def remove(self, filename: str) -> None: path = os.path.join(self.dirname, filename) os.remove(path) @@ -357,13 +359,14 @@ class ModelCheckpoint(Checkpoint): ['/tmp/models/myprefix_mymodel_6.pth'] """ - def __init__(self, dirname, filename_prefix, - save_interval=None, - score_function=None, score_name=None, - n_saved=1, - atomic=True, require_empty=True, - create_dir=True, - save_as_state_dict=True, global_step_transform=None, archived=False): + def __init__(self, dirname: str, filename_prefix: str, + save_interval: Optional[Callable] = None, + score_function: Optional[Callable] = None, score_name: Optional[str] = None, + n_saved: int = 1, + atomic: bool = True, require_empty: bool = True, + create_dir: bool = True, + save_as_state_dict: bool = True, global_step_transform: Optional[Callable] = None, + archived: bool = False): if not save_as_state_dict: raise ValueError("Argument save_as_state_dict is deprecated and should be True") @@ -398,12 +401,12 @@ def __init__(self, dirname, filename_prefix, self.global_step_transform = global_step_transform @property - def last_checkpoint(self): + def last_checkpoint(self) -> Union[str, None]: if len(self._saved) < 1: return None return os.path.join(self.save_handler.dirname, self._saved[0].filename) - def __call__(self, engine, to_save): + def __call__(self, engine: Engine, to_save: Mapping) -> None: if len(to_save) == 0: raise RuntimeError("No objects to checkpoint found.") diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py index 4f4df95c824..fe93fe7d27c 100644 --- a/ignite/handlers/early_stopping.py +++ b/ignite/handlers/early_stopping.py @@ -1,4 +1,5 @@ import logging +from typing import Callable from ignite.engine import Engine @@ -42,7 +43,8 @@ def score_function(engine): """ - def __init__(self, patience, score_function, trainer, min_delta=0., cumulative_delta=False): + def __init__(self, patience: int, score_function: Callable, trainer: Engine, min_delta: float = 0., + cumulative_delta: bool = False): if not callable(score_function): raise TypeError("Argument score_function should be a function.") @@ -65,7 +67,7 @@ def __init__(self, patience, score_function, trainer, min_delta=0., cumulative_d self.best_score = None self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) - def __call__(self, engine): + def __call__(self, engine: Engine) -> None: score = self.score_function(engine) if self.best_score is None: diff --git a/ignite/handlers/terminate_on_nan.py b/ignite/handlers/terminate_on_nan.py index e7d50d6df2e..a4178ad0a79 100644 --- a/ignite/handlers/terminate_on_nan.py +++ b/ignite/handlers/terminate_on_nan.py @@ -1,9 +1,11 @@ import logging import numbers +from typing import Union, Callable import torch from ignite.utils import apply_to_type +from ignite.engine import Engine __all__ = [ 'TerminateOnNan' @@ -32,15 +34,15 @@ class TerminateOnNan: """ - def __init__(self, output_transform=lambda x: x): + def __init__(self, output_transform: Callable = lambda x: x): self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__) self.logger.addHandler(logging.StreamHandler()) self._output_transform = output_transform - def __call__(self, engine): + def __call__(self, engine: Engine) -> None: output = self._output_transform(engine.state.output) - def raise_error(x): + def raise_error(x: Union[numbers.Number, torch.Tensor]) -> None: if isinstance(x, numbers.Number): x = torch.tensor(x) diff --git a/ignite/handlers/timing.py b/ignite/handlers/timing.py index 26ad5f5c924..5146498e53c 100644 --- a/ignite/handlers/timing.py +++ b/ignite/handlers/timing.py @@ -1,6 +1,9 @@ +from __future__ import annotations + from time import perf_counter +from typing import Optional -from ignite.engine import Events +from ignite.engine import Events, Engine __all__ = [ 'Timer' @@ -77,7 +80,7 @@ class Timer: ... step=Events.ITERATION_COMPLETED) """ - def __init__(self, average=False): + def __init__(self, average: bool = False): self._average = average self._t0 = perf_counter() @@ -85,7 +88,8 @@ def __init__(self, average=False): self.step_count = 0. self.running = True - def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None): + def attach(self, engine: Engine, start: str = Events.STARTED, + pause: str = Events.COMPLETED, resume: Optional[str] = None, step: Optional[str] = None) -> Timer: """ Register callbacks to control the timer. Args: @@ -116,21 +120,21 @@ def attach(self, engine, start=Events.STARTED, pause=Events.COMPLETED, resume=No return self - def reset(self, *args): + def reset(self, *args) -> Timer: self.__init__(self._average) return self - def pause(self, *args): + def pause(self, *args) -> None: if self.running: self.total += self._elapsed() self.running = False - def resume(self, *args): + def resume(self, *args) -> None: if not self.running: self.running = True self._t0 = perf_counter() - def value(self): + def value(self) -> float: total = self.total if self.running: total += self._elapsed() @@ -142,8 +146,8 @@ def value(self): return total / denominator - def step(self, *args): + def step(self, *args) -> None: self.step_count += 1. - def _elapsed(self): + def _elapsed(self) -> float: return perf_counter() - self._t0 diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index 972e433afb9..4e1c9af9537 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -1,5 +1,7 @@ import numbers +from typing import Callable, Union, Any, Optional + from ignite.metrics import Metric from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced from ignite.exceptions import NotComputableError @@ -43,7 +45,8 @@ class VariableAccumulation(Metric): """ _required_output_keys = None - def __init__(self, op, output_transform=lambda x: x, device=None): + def __init__(self, op: Callable, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): if not callable(op): raise TypeError("Argument op should be a callable, but given {}".format(type(op))) self.accumulator = None @@ -53,16 +56,17 @@ def __init__(self, op, output_transform=lambda x: x, device=None): super(VariableAccumulation, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced - def reset(self): + def reset(self) -> None: self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device) self.num_examples = torch.tensor(0.0, dtype=torch.long, device=self._device) - def _check_output_type(self, output): + def _check_output_type(self, output: Union[Any, torch.Tensor, + numbers.Number]) -> None: if not (isinstance(output, numbers.Number) or isinstance(output, torch.Tensor)): raise TypeError("Output should be a number or torch.Tensor, but given {}".format(type(output))) @reinit__is_reduced - def update(self, output): + def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None: self._check_output_type(output) if self._device is not None: @@ -77,7 +81,7 @@ def update(self, output): self.num_examples += 1 @sync_all_reduce('accumulator', 'num_examples') - def compute(self): + def compute(self) -> list: return [self.accumulator, self.num_examples] @@ -120,7 +124,8 @@ class Average(VariableAccumulation): """ - def __init__(self, output_transform=lambda x: x, device=None): + def __init__(self, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): def _mean_op(a, x): if isinstance(x, torch.Tensor) and x.ndim > 1: @@ -130,7 +135,7 @@ def _mean_op(a, x): super(Average, self).__init__(op=_mean_op, output_transform=output_transform, device=device) @sync_all_reduce('accumulator', 'num_examples') - def compute(self): + def compute(self) -> Union[Any, torch.Tensor, numbers.Number]: if self.num_examples < 1: raise NotComputableError("{} must have at least one example before" " it can be computed.".format(self.__class__.__name__)) @@ -165,9 +170,11 @@ class GeometricAverage(VariableAccumulation): """ - def __init__(self, output_transform=lambda x: x, device=None): + def __init__(self, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): - def _geom_op(a, x): + def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, + torch.Tensor]) -> torch.Tensor: if not isinstance(x, torch.Tensor): x = torch.tensor(x) x = torch.log(x) @@ -178,7 +185,7 @@ def _geom_op(a, x): super(GeometricAverage, self).__init__(op=_geom_op, output_transform=output_transform, device=device) @sync_all_reduce('accumulator', 'num_examples') - def compute(self): + def compute(self) -> torch.Tensor: if self.num_examples < 1: raise NotComputableError("{} must have at least one example before" " it can be computed.".format(self.__class__.__name__)) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 64331047e1b..5db9c0b745b 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -1,3 +1,5 @@ +from typing import Callable, Union, Optional, Sequence + from ignite.metrics import Metric from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced from ignite.exceptions import NotComputableError @@ -11,17 +13,18 @@ class _BaseClassification(Metric): - def __init__(self, output_transform=lambda x: x, is_multilabel=False, device=None): + def __init__(self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, + device: Optional[Union[str, torch.device]] = None): self._is_multilabel = is_multilabel self._type = None self._num_classes = None super(_BaseClassification, self).__init__(output_transform=output_transform, device=device) - def reset(self): + def reset(self) -> None: self._type = None self._num_classes = None - def _check_shape(self, output): + def _check_shape(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()): @@ -41,7 +44,7 @@ def _check_shape(self, output): if self._is_multilabel and not (y.shape == y_pred.shape and y.ndimension() > 1 and y.shape[1] != 1): raise ValueError("y and y_pred must have same shape of (batch_size, num_categories, ...).") - def _check_binary_multilabel_cases(self, output): + def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if not torch.equal(y, y ** 2): @@ -50,7 +53,7 @@ def _check_binary_multilabel_cases(self, output): if not torch.equal(y_pred, y_pred ** 2): raise ValueError("For binary cases, y_pred must be comprised of 0's and 1's.") - def _check_type(self, output): + def _check_type(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if y.ndimension() + 1 == y_pred.ndimension(): @@ -118,7 +121,9 @@ def thresholded_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, is_multilabel=False, device=None): + def __init__(self, output_transform: Callable = lambda x: x, + is_multilabel: bool = False, + device: Optional[Union[str, torch.device]] = None): self._num_correct = None self._num_examples = None super(Accuracy, self).__init__(output_transform=output_transform, @@ -126,13 +131,13 @@ def __init__(self, output_transform=lambda x: x, is_multilabel=False, device=Non device=device) @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._num_correct = 0 self._num_examples = 0 super(Accuracy, self).reset() @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output self._check_shape((y_pred, y)) self._check_type((y_pred, y)) @@ -154,7 +159,7 @@ def update(self, output): self._num_examples += correct.shape[0] @sync_all_reduce("_num_examples", "_num_correct") - def compute(self): + def compute(self) -> torch.Tensor: if self._num_examples == 0: raise NotComputableError('Accuracy must have at least one example before it can be computed.') return self._num_correct / self._num_examples diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index ffb2e970c74..47889590768 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -1,4 +1,5 @@ import numbers +from typing import Optional, Union, Any, Callable, Sequence import torch @@ -49,7 +50,8 @@ class ConfusionMatrix(Metric): """ - def __init__(self, num_classes, average=None, output_transform=lambda x: x, device=None): + def __init__(self, num_classes: int, average: Optional[str] = None, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): if average is not None and average not in ("samples", "recall", "precision"): raise ValueError("Argument average can None or one of ['samples', 'recall', 'precision']") @@ -60,13 +62,13 @@ def __init__(self, num_classes, average=None, output_transform=lambda x: x, devi super(ConfusionMatrix, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced - def reset(self): + def reset(self) -> None: self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes, dtype=torch.int64, device=self._device) self._num_examples = 0 - def _check_shape(self, output): + def _check_shape(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if y_pred.ndimension() < 2: @@ -92,7 +94,7 @@ def _check_shape(self, output): raise ValueError("y and y_pred must have compatible shapes.") @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: self._check_shape(output) y_pred, y = output @@ -111,7 +113,7 @@ def update(self, output): self.confusion_matrix += m.to(self.confusion_matrix) @sync_all_reduce('confusion_matrix', '_num_examples') - def compute(self): + def compute(self) -> torch.Tensor: if self._num_examples == 0: raise NotComputableError('Confusion matrix must have at least one example before it can be computed.') if self.average: @@ -125,7 +127,7 @@ def compute(self): return self.confusion_matrix -def IoU(cm, ignore_index=None): +def IoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda: """Calculates Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: @@ -173,7 +175,7 @@ def ignore_index_fn(iou_vector): return iou -def mIoU(cm, ignore_index=None): +def mIoU(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda: """Calculates mean Intersection over Union using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: @@ -200,7 +202,7 @@ def mIoU(cm, ignore_index=None): return IoU(cm=cm, ignore_index=ignore_index).mean() -def cmAccuracy(cm): +def cmAccuracy(cm: ConfusionMatrix) -> MetricsLambda: """Calculates accuracy using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: @@ -214,7 +216,7 @@ def cmAccuracy(cm): return cm.diag().sum() / (cm.sum() + 1e-15) -def cmPrecision(cm, average=True): +def cmPrecision(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda: """Calculates precision using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: @@ -232,7 +234,7 @@ def cmPrecision(cm, average=True): return precision -def cmRecall(cm, average=True): +def cmRecall(cm: ConfusionMatrix, average: bool = True) -> MetricsLambda: """ Calculates recall using :class:`~ignite.metrics.ConfusionMatrix` metric. Args: @@ -250,7 +252,7 @@ def cmRecall(cm, average=True): return recall -def DiceCoefficient(cm, ignore_index=None): +def DiceCoefficient(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda: """Calculates Dice Coefficient for a given :class:`~ignite.metrics.ConfusionMatrix` metric. Args: @@ -271,7 +273,7 @@ def DiceCoefficient(cm, ignore_index=None): if ignore_index is not None: - def ignore_index_fn(dice_vector): + def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor: if ignore_index >= len(dice_vector): raise ValueError("ignore_index {} is larger than the length of Dice vector {}" .format(ignore_index, len(dice_vector))) diff --git a/ignite/metrics/epoch_metric.py b/ignite/metrics/epoch_metric.py index 098ac70754a..3c6033745cc 100644 --- a/ignite/metrics/epoch_metric.py +++ b/ignite/metrics/epoch_metric.py @@ -1,4 +1,5 @@ import warnings +from typing import Callable, Sequence import torch @@ -37,7 +38,8 @@ class EpochMetric(Metric): you want to compute the metric with respect to one of the outputs. """ - def __init__(self, compute_fn, output_transform=lambda x: x): + + def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x): if not callable(compute_fn): raise TypeError("Argument compute_fn should be callable.") @@ -45,11 +47,11 @@ def __init__(self, compute_fn, output_transform=lambda x: x): super(EpochMetric, self).__init__(output_transform=output_transform, device='cpu') self.compute_fn = compute_fn - def reset(self): + def reset(self) -> None: self._predictions = torch.tensor([], dtype=torch.float32) self._targets = torch.tensor([], dtype=torch.long) - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output if y_pred.ndimension() not in (1, 2): @@ -82,7 +84,7 @@ def update(self, output): warnings.warn("Probably, there can be a problem with `compute_fn`:\n {}.".format(e), EpochMetricWarning) - def compute(self): + def compute(self) -> None: return self.compute_fn(self._predictions, self._targets) diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index 74322e61d76..35d71f21fa4 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -1,11 +1,17 @@ -from ignite.metrics import Precision, Recall +from typing import Optional, Union, Callable __all__ = [ 'Fbeta' ] +import torch -def Fbeta(beta, average=True, precision=None, recall=None, output_transform=None, device=None): +from ignite.metrics import Precision, Recall, MetricsLambda + + +def Fbeta(beta: float, average: bool = True, precision: Optional[Precision] = None, + recall: Optional[Recall] = None, output_transform: Optional[Callable] = None, + device: Optional[Union[str, torch.device]] = None) -> MetricsLambda: """Calculates F-beta score Args: diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 807a35b27d2..ca3456e96d4 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -1,3 +1,7 @@ +from typing import Callable, Union, Optional, Sequence + +import torch + from ignite.exceptions import NotComputableError from ignite.metrics import Metric from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced @@ -33,19 +37,19 @@ class Loss(Metric): """ _required_output_keys = None - def __init__(self, loss_fn, output_transform=lambda x: x, - batch_size=lambda x: len(x), device=None): + def __init__(self, loss_fn: Callable, output_transform: Callable = lambda x: x, + batch_size: Callable = lambda x: len(x), device: Optional[Union[str, torch.device]] = None): super(Loss, self).__init__(output_transform, device=device) self._loss_fn = loss_fn self._batch_size = batch_size @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._sum = 0 self._num_examples = 0 @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None: if len(output) == 2: y_pred, y = output kwargs = {} @@ -61,7 +65,7 @@ def update(self, output): self._num_examples += N @sync_all_reduce("_sum", "_num_examples") - def compute(self): + def compute(self) -> None: if self._num_examples == 0: raise NotComputableError( 'Loss must have at least one example before it can be computed.') diff --git a/ignite/metrics/mean_absolute_error.py b/ignite/metrics/mean_absolute_error.py index 3a48073ed82..454e5b36e4a 100644 --- a/ignite/metrics/mean_absolute_error.py +++ b/ignite/metrics/mean_absolute_error.py @@ -1,3 +1,5 @@ +from typing import Sequence, Union + import torch from ignite.exceptions import NotComputableError @@ -15,20 +17,21 @@ class MeanAbsoluteError(Metric): - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`. """ + @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._sum_of_absolute_errors = 0.0 self._num_examples = 0 @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output absolute_errors = torch.abs(y_pred - y.view_as(y_pred)) self._sum_of_absolute_errors += torch.sum(absolute_errors).item() self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_absolute_errors", "_num_examples") - def compute(self): + def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError('MeanAbsoluteError must have at least one example before it can be computed.') return self._sum_of_absolute_errors / self._num_examples diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index 341ca5a86a2..0824e7bf96e 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ b/ignite/metrics/mean_pairwise_distance.py @@ -1,3 +1,5 @@ +from typing import Union, Sequence, Optional, Callable + import torch from torch.nn.functional import pairwise_distance @@ -16,7 +18,9 @@ class MeanPairwiseDistance(Metric): - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`. """ - def __init__(self, p=2, eps=1e-6, output_transform=lambda x: x, device=None): + + def __init__(self, p: int = 2, eps: float = 1e-6, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): super(MeanPairwiseDistance, self).__init__(output_transform, device=device) self._p = p self._eps = eps @@ -27,14 +31,14 @@ def reset(self): self._num_examples = 0 @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps) self._sum_of_distances += torch.sum(distances).item() self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_distances", "_num_examples") - def compute(self): + def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError('MeanAbsoluteError must have at least one example before it can be computed.') return self._sum_of_distances / self._num_examples diff --git a/ignite/metrics/mean_squared_error.py b/ignite/metrics/mean_squared_error.py index 718101ef6d2..42e27a7c1b7 100644 --- a/ignite/metrics/mean_squared_error.py +++ b/ignite/metrics/mean_squared_error.py @@ -1,3 +1,5 @@ +from typing import Union, Sequence + import torch from ignite.exceptions import NotComputableError @@ -15,20 +17,21 @@ class MeanSquaredError(Metric): - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`. """ + @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._sum_of_squared_errors = 0.0 self._num_examples = 0 @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output squared_errors = torch.pow(y_pred - y.view_as(y_pred), 2) self._sum_of_squared_errors += torch.sum(squared_errors).item() self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_squared_errors", "_num_examples") - def compute(self): + def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError('MeanSquaredError must have at least one example before it can be computed.') return self._sum_of_squared_errors / self._num_examples diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 7445272877a..23c3b4709a3 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -1,13 +1,16 @@ +from __future__ import annotations import numbers from abc import ABCMeta, abstractmethod from functools import wraps from collections.abc import Mapping import warnings +from typing import Callable, Union, Optional, Any + import torch import torch.distributed as dist -from ignite.engine import Events +from ignite.engine import Events, Engine __all__ = [ 'Metric' @@ -32,7 +35,8 @@ class Metric(metaclass=ABCMeta): """ _required_output_keys = ("y_pred", "y") - def __init__(self, output_transform=lambda x: x, device=None): + def __init__(self, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): self._output_transform = output_transform # Check device if distributed is initialized: @@ -51,7 +55,7 @@ def __init__(self, output_transform=lambda x: x, device=None): self.reset() @abstractmethod - def reset(self): + def reset(self) -> None: """ Resets the metric to it's initial state. @@ -60,7 +64,7 @@ def reset(self): pass @abstractmethod - def update(self, output): + def update(self, output) -> None: """ Updates the metric's state using the passed batch output. @@ -72,7 +76,7 @@ def update(self, output): pass @abstractmethod - def compute(self): + def compute(self) -> Any: """ Computes the metric based on it's accumulated state. @@ -86,7 +90,10 @@ def compute(self): """ pass - def _sync_all_reduce(self, tensor): + def _sync_all_reduce( + self, + tensor: Union[torch.Tensor, numbers.Number]) -> Union[torch.Tensor, + numbers.Number]: if not (dist.is_available() and dist.is_initialized()): # Nothing to reduce return tensor @@ -111,11 +118,11 @@ def _sync_all_reduce(self, tensor): return tensor.item() return tensor - def started(self, engine): + def started(self, engine: Engine) -> None: self.reset() @torch.no_grad() - def iteration_completed(self, engine): + def iteration_completed(self, engine: Engine) -> None: output = self._output_transform(engine.state.output) if isinstance(output, Mapping): if self._required_output_keys is None: @@ -128,76 +135,76 @@ def iteration_completed(self, engine): output = tuple(output[k] for k in self._required_output_keys) self.update(output) - def completed(self, engine, name): + def completed(self, engine: Engine, name: str) -> None: result = self.compute() if torch.is_tensor(result) and len(result.shape) == 0: result = result.item() engine.state.metrics[name] = result - def attach(self, engine, name): + def attach(self, engine: Engine, name: str) -> None: engine.add_event_handler(Events.EPOCH_COMPLETED, self.completed, name) if not engine.has_event_handler(self.started, Events.EPOCH_STARTED): engine.add_event_handler(Events.EPOCH_STARTED, self.started) if not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed) - def __add__(self, other): + def __add__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x + y, self, other) - def __radd__(self, other): + def __radd__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x + y, other, self) - def __sub__(self, other): + def __sub__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x - y, self, other) - def __rsub__(self, other): + def __rsub__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x - y, other, self) - def __mul__(self, other): + def __mul__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x * y, self, other) - def __rmul__(self, other): + def __rmul__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x * y, other, self) - def __pow__(self, other): + def __pow__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x ** y, self, other) - def __rpow__(self, other): + def __rpow__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x ** y, other, self) - def __mod__(self, other): + def __mod__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x % y, self, other) - def __div__(self, other): + def __div__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x.__div__(y), self, other) - def __rdiv__(self, other): + def __rdiv__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x.__div__(y), other, self) - def __truediv__(self, other): + def __truediv__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x.__truediv__(y), self, other) - def __rtruediv__(self, other): + def __rtruediv__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x.__truediv__(y), other, self) - def __floordiv__(self, other): + def __floordiv__(self, other: Metric) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x, y: x // y, self, other) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Callable: from ignite.metrics import MetricsLambda def fn(x, *args, **kwargs): @@ -205,19 +212,19 @@ def fn(x, *args, **kwargs): def wrapper(*args, **kwargs): return MetricsLambda(fn, self, *args, **kwargs) + return wrapper - def __getitem__(self, index): + def __getitem__(self, index: Any) -> Metric: from ignite.metrics import MetricsLambda return MetricsLambda(lambda x: x[index], self) -def sync_all_reduce(*attrs): - - def wrapper(func): +def sync_all_reduce(*attrs) -> Callable: + def wrapper(func: Callable) -> Callable: @wraps(func) - def another_wrapper(self, *args, **kwargs): + def another_wrapper(self: Metric, *args, **kwargs) -> Callable: if not isinstance(self, Metric): raise RuntimeError("Decorator sync_all_reduce should be used on " "ignite.metric.Metric class methods only") @@ -238,8 +245,7 @@ def another_wrapper(self, *args, **kwargs): return wrapper -def reinit__is_reduced(func): - +def reinit__is_reduced(func: Callable) -> Callable: @wraps(func) def wrapper(self, *args, **kwargs): func(self, *args, **kwargs) diff --git a/ignite/metrics/metrics_lambda.py b/ignite/metrics/metrics_lambda.py index 77c9ac6a739..8061e7d8bc2 100644 --- a/ignite/metrics/metrics_lambda.py +++ b/ignite/metrics/metrics_lambda.py @@ -1,7 +1,8 @@ import itertools +from typing import Callable, Any from ignite.metrics.metric import Metric, reinit__is_reduced -from ignite.engine import Events +from ignite.engine import Events, Engine __all__ = [ 'MetricsLambda' @@ -39,31 +40,32 @@ def Fbeta(r, p, beta): F3 = MetricsLambda(Fbeta, recall, precision, 3) F4 = MetricsLambda(Fbeta, recall, precision, 4) """ - def __init__(self, f, *args, **kwargs): + + def __init__(self, f: Callable, *args, **kwargs): self.function = f self.args = args self.kwargs = kwargs super(MetricsLambda, self).__init__(device='cpu') @reinit__is_reduced - def reset(self): + def reset(self) -> None: for i in itertools.chain(self.args, self.kwargs.values()): if isinstance(i, Metric): i.reset() @reinit__is_reduced - def update(self, output): + def update(self, output) -> None: # NB: this method does not recursively update dependency metrics, # which might cause duplicate update issue. To update this metric, # users should manually update its dependencies. pass - def compute(self): + def compute(self) -> Any: materialized = [i.compute() if isinstance(i, Metric) else i for i in self.args] materialized_kwargs = {k: (v.compute() if isinstance(v, Metric) else v) for k, v in self.kwargs.items()} return self.function(*materialized, **materialized_kwargs) - def _internal_attach(self, engine): + def _internal_attach(self, engine: Engine) -> None: for index, metric in enumerate(itertools.chain(self.args, self.kwargs.values())): if isinstance(metric, MetricsLambda): metric._internal_attach(engine) @@ -73,7 +75,7 @@ def _internal_attach(self, engine): if not engine.has_event_handler(metric.iteration_completed, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, metric.iteration_completed) - def attach(self, engine, name): + def attach(self, engine: Engine, name: str) -> None: # recursively attach all its dependencies self._internal_attach(engine) # attach only handler on EPOCH_COMPLETED diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index 8ca1304b302..d4cd93359c6 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -1,4 +1,5 @@ import warnings +from typing import Sequence, Callable, Optional, Union import torch @@ -14,7 +15,8 @@ class _BasePrecisionRecall(_BaseClassification): - def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False, device=None): + def __init__(self, output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, + device: Optional[Union[str, torch.device]] = None): if torch.distributed.is_available() and torch.distributed.is_initialized(): if (not average) and is_multilabel: warnings.warn("Precision/Recall metrics do not work in distributed setting when average=False " @@ -30,13 +32,13 @@ def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=Fa device=device) @reinit__is_reduced - def reset(self): + def reset(self) -> None: dtype = torch.float64 self._true_positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0 self._positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0 super(_BasePrecisionRecall, self).reset() - def compute(self): + def compute(self) -> torch.Tensor: if not (isinstance(self._positives, torch.Tensor) or self._positives > 0): raise NotComputableError("{} must have at least one example before" " it can be computed.".format(self.__class__.__name__)) @@ -113,12 +115,13 @@ def thresholded_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False, device=None): + def __init__(self, output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, + device: Optional[Union[str, torch.device]] = None): super(Precision, self).__init__(output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device) @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output self._check_shape(output) self._check_type((y_pred, y)) diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index 821b25377be..8a18ddc174e 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -1,3 +1,5 @@ +from typing import Sequence, Callable, Optional, Union + import torch from ignite.metrics.precision import _BasePrecisionRecall @@ -67,12 +69,13 @@ def thresholded_output_transform(output): """ - def __init__(self, output_transform=lambda x: x, average=False, is_multilabel=False, device=None): + def __init__(self, output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, + device: Optional[Union[str, torch.device]] = None): super(Recall, self).__init__(output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device) @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output self._check_shape(output) self._check_type((y_pred, y)) diff --git a/ignite/metrics/root_mean_squared_error.py b/ignite/metrics/root_mean_squared_error.py index ed894eac560..63a4c4c5283 100644 --- a/ignite/metrics/root_mean_squared_error.py +++ b/ignite/metrics/root_mean_squared_error.py @@ -1,4 +1,7 @@ import math +from typing import Union + +import torch from ignite.metrics.mean_squared_error import MeanSquaredError @@ -13,6 +16,7 @@ class RootMeanSquaredError(MeanSquaredError): - `update` must receive output of the form (y_pred, y) or `{'y_pred': y_pred, 'y': y}`. """ - def compute(self): + + def compute(self) -> Union[torch.Tensor, float]: mse = super(RootMeanSquaredError, self).compute() return math.sqrt(mse) diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index 8300b68f64d..e9c4ffbf36c 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -1,4 +1,8 @@ -from ignite.engine import Events +from typing import Optional, Union, Callable, Sequence + +import torch + +from ignite.engine import Events, Engine from ignite.metrics import Metric from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce @@ -43,7 +47,8 @@ def log_running_avg_metrics(engine): """ _required_output_keys = None - def __init__(self, src=None, alpha=0.98, output_transform=None, epoch_bound=True, device=None): + def __init__(self, src: Optional[Metric] = None, alpha: float = 0.98, output_transform: Optional[Callable] = None, + epoch_bound: bool = True, device: Optional[Union[str, torch.device]] = None): if not (isinstance(src, Metric) or src is None): raise TypeError("Argument src should be a Metric or None.") if not (0.0 < alpha <= 1.0): @@ -69,15 +74,15 @@ def __init__(self, src=None, alpha=0.98, output_transform=None, epoch_bound=True super(RunningAverage, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._value = None @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence) -> None: # Implement abstract method pass - def compute(self): + def compute(self) -> Union[torch.Tensor, float]: if self._value is None: self._value = self._get_src_value() else: @@ -85,7 +90,7 @@ def compute(self): return self._value - def attach(self, engine, name): + def attach(self, engine: Engine, name: str): if self.epoch_bound: # restart average every epoch engine.add_event_handler(Events.EPOCH_STARTED, self.started) @@ -94,17 +99,17 @@ def attach(self, engine, name): # apply running average engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name) - def _get_metric_value(self): + def _get_metric_value(self) -> Union[torch.Tensor, float]: return self.src.compute() @sync_all_reduce("src") - def _get_output_value(self): + def _get_output_value(self) -> Metric: return self.src - def _metric_iteration_completed(self, engine): + def _metric_iteration_completed(self, engine: Engine) -> None: self.src.started(engine) self.src.iteration_completed(engine) @reinit__is_reduced - def _output_update(self, output): + def _output_update(self, output: Metric) -> None: self.src = output diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index 14e3ad49879..7d9bdb664f1 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -1,3 +1,5 @@ +from typing import Union, Optional, Callable, Sequence + import torch from ignite.metrics.metric import Metric @@ -15,17 +17,19 @@ class TopKCategoricalAccuracy(Metric): - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`. """ - def __init__(self, k=5, output_transform=lambda x: x, device=None): + + def __init__(self, k=5, output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = None): super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device) self._k = k @reinit__is_reduced - def reset(self): + def reset(self) -> None: self._num_correct = 0 self._num_examples = 0 @reinit__is_reduced - def update(self, output): + def update(self, output: Sequence) -> None: y_pred, y = output sorted_indices = torch.topk(y_pred, self._k, dim=1)[1] expanded_y = y.view(-1, 1).expand(-1, self._k) @@ -34,7 +38,7 @@ def update(self, output): self._num_examples += correct.shape[0] @sync_all_reduce("_num_correct", "_num_examples") - def compute(self): + def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("TopKCategoricalAccuracy must have at" "least one example before it can be computed.") diff --git a/ignite/utils.py b/ignite/utils.py index 381a500f7c7..503712f8ad4 100644 --- a/ignite/utils.py +++ b/ignite/utils.py @@ -1,5 +1,6 @@ import collections.abc as collections import logging +from typing import Union, Optional, Callable, Any, Type, Tuple import torch @@ -12,21 +13,37 @@ ] -def convert_tensor(input_, device=None, non_blocking=False): +def convert_tensor(input_: Union[torch.Tensor, collections.Sequence, + collections.Mapping, str, bytes], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False) -> Union[torch.Tensor, + collections.Sequence, + collections.Mapping, + str, bytes]: """Move tensors to relevant device.""" - def _func(tensor): - return tensor.to(device=device, non_blocking=non_blocking) if device else tensor + + def _func(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(device=device, + non_blocking=non_blocking) if device else tensor return apply_to_tensor(input_, _func) -def apply_to_tensor(input_, func): +def apply_to_tensor(input_: Union[torch.Tensor, collections.Sequence, + collections.Mapping, str, bytes], + func: Callable) -> Union[torch.Tensor, + collections.Sequence, + collections.Mapping, str, bytes]: """Apply a function on a tensor or mapping, or sequence of tensors. """ return apply_to_type(input_, torch.Tensor, func) -def apply_to_type(input_, input_type, func): +def apply_to_type(input_: Union[Any, collections.Sequence, + collections.Mapping, str, bytes], + input_type: Union[Type, Tuple[Type[Any], Any]], + func: Callable) -> Union[Any, collections.Sequence, + collections.Mapping, str, bytes]: """Apply a function on a object of `input_type` or mapping, or sequence of objects of `input_type`. """ if isinstance(input_, input_type): @@ -42,7 +59,7 @@ def apply_to_type(input_, input_type, func): .format(input_type, type(input_)))) -def to_onehot(indices, num_classes): +def to_onehot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: """Convert a tensor of indices of any shape `(N, ...)` to a tensor of one-hot indicators of shape `(N, num_classes, ...) and of type uint8. Output's device is equal to the input's device`. @@ -53,8 +70,10 @@ def to_onehot(indices, num_classes): return onehot.scatter_(1, indices.unsqueeze(1), 1) -def setup_logger(name, level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s: %(message)s", - filepath=None, distributed_rank=0): +def setup_logger(name: str, level: int = logging.INFO, + format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s", + filepath: Optional[str] = None, + distributed_rank: int = 0) -> logging.Logger: """Setups logger: name, level, format etc. Args: