Typing Core Package #716

Merged: 37 commits, Jan 27, 2020
Commits (37)
b8ffba5  Add typing for utils (justusschock, Jan 23, 2020)
38e168f  Add typing for time conversions (justusschock, Jan 23, 2020)
f7aa0d6  typing for metric base class (justusschock, Jan 23, 2020)
91bf839  typing for lambda metrics (justusschock, Jan 23, 2020)
efb25d0  typing for accumulation metrics (justusschock, Jan 23, 2020)
d9cfd45  typing for accuracy metrics (justusschock, Jan 23, 2020)
0215caf  typing for confusion matrix metrics (justusschock, Jan 23, 2020)
e21ace5  typing for epoch metrics (justusschock, Jan 23, 2020)
592be26  typing for fbeta scores (justusschock, Jan 23, 2020)
2c153bc  typing for average loss scores (justusschock, Jan 23, 2020)
f5ee65c  typing for mae (justusschock, Jan 23, 2020)
0c73a18  typing for mean pairwise distance (justusschock, Jan 23, 2020)
43ef791  typing for mse (justusschock, Jan 23, 2020)
86022c7  typing for precision (justusschock, Jan 23, 2020)
817494d  typing for recall (justusschock, Jan 23, 2020)
479dfa2  typing for rmse (justusschock, Jan 23, 2020)
f50985f  typing for running average (justusschock, Jan 23, 2020)
4941a30  typing for top k (justusschock, Jan 23, 2020)
ddb837d  typing for checkpointer (justusschock, Jan 23, 2020)
231c992  typing for early stopping (justusschock, Jan 23, 2020)
ac4f7d9  typing for termination on NaN (justusschock, Jan 23, 2020)
eb4e4de  typing for timer (justusschock, Jan 23, 2020)
5a7e990  typing for engine (justusschock, Jan 23, 2020)
5a6ef24  Pep8 Fixes (justusschock, Jan 23, 2020)
b54a99e  Apply suggestions from code review (justusschock, Jan 23, 2020)
d937d1f  Add missing typing for inits (justusschock, Jan 24, 2020)
1548f34  Add missing import (justusschock, Jan 27, 2020)
1804e39  pep8 Fixes (justusschock, Jan 27, 2020)
caaa6ee  pep8 Fixes in contrib package (justusschock, Jan 27, 2020)
791ed1f  fix cyclic imports during runtime (justusschock, Jan 27, 2020)
ea6000f  move signature checks back to engine.py (justusschock, Jan 27, 2020)
e06ab0e  remove typing for check_signature as this can't be done without circu… (justusschock, Jan 27, 2020)
660ae85  pep8 (justusschock, Jan 27, 2020)
8ac63c2  readd dynamic typing (was lost during rebasing) (justusschock, Jan 27, 2020)
60dfa42  Merge branch 'master' into typing (vfdev-5, Jan 27, 2020)
85f670c  Removed useless _check_signature (vfdev-5, Jan 27, 2020)
6ea3521  Minor fix in _check_signature (vfdev-5, Jan 27, 2020)
Files changed
3 changes: 2 additions & 1 deletion ignite/_utils.py
@@ -1,9 +1,10 @@
+from typing import Union, Tuple
 
 # For compatibilty
 from ignite.utils import convert_tensor, apply_to_tensor, apply_to_type, to_onehot
 
 
-def _to_hours_mins_secs(time_taken):
+def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, int]:
     """Convert seconds to hours, mins, and seconds."""
     mins, secs = divmod(time_taken, 60)
     hours, mins = divmod(mins, 60)
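
For reference, here is the typed helper as a self-contained, runnable sketch (the return statement is assumed from the surrounding code, and note one subtlety: divmod on a float returns floats, so the Tuple[int, int, int] annotation strictly holds only for int input):

from typing import Tuple, Union


def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, int]:
    """Convert seconds to hours, mins, and seconds."""
    mins, secs = divmod(time_taken, 60)
    hours, mins = divmod(mins, 60)
    return hours, mins, secs


print(_to_hours_mins_secs(3725))  # (1, 2, 5)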
3 changes: 0 additions & 3 deletions ignite/contrib/engines/common.py
@@ -126,7 +126,6 @@ def _setup_common_distrib_training_handlers(trainer, train_sampler=None,
                                             lr_scheduler=None, with_gpu_stats=True,
                                             output_names=None, with_pbars=True, with_pbar_on_iters=True,
                                             log_every_iters=100, device='cuda'):
-
     if not (dist.is_available() and dist.is_initialized()):
         raise RuntimeError("Distributed setting is not initialized, please call `dist.init_process_group` before.")
 
@@ -162,7 +161,6 @@ def empty_cuda_cache(_):
 
 
 def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
-
     if optimizers is not None:
         from torch.optim.optimizer import Optimizer
 
@@ -279,7 +277,6 @@ def setup_plx_logging(trainer, optimizers=None, evaluators=None, log_every_iters
 
 
 def get_default_score_fn(metric_name):
-
     def wrapper(engine):
         score = engine.state.metrics[metric_name]
         return score
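
The score-function pattern in get_default_score_fn above is typically consumed by handlers such as EarlyStopping. A minimal sketch, assuming an "accuracy" metric is attached to the evaluator; the placeholder engines are illustrative only:

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping


def get_default_score_fn(metric_name):
    def wrapper(engine):
        return engine.state.metrics[metric_name]
    return wrapper


trainer = Engine(lambda engine, batch: None)    # placeholder update function
evaluator = Engine(lambda engine, batch: None)  # placeholder inference function

# Stop training when the evaluator's "accuracy" metric stops improving.
handler = EarlyStopping(patience=5, score_function=get_default_score_fn("accuracy"), trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)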
16 changes: 8 additions & 8 deletions ignite/contrib/engines/tbptt.py
(whitespace-only change: the parameter list is re-indented)
@@ -29,14 +29,14 @@ def _detach_hidden(hidden):
 
 
 def create_supervised_tbptt_trainer(
-        model,
-        optimizer,
-        loss_fn,
-        tbtt_step,
-        dim=0,
-        device=None,
-        non_blocking=False,
-        prepare_batch=_prepare_batch
+    model,
+    optimizer,
+    loss_fn,
+    tbtt_step,
+    dim=0,
+    device=None,
+    non_blocking=False,
+    prepare_batch=_prepare_batch
 ):
     """Create a trainer for truncated backprop through time supervised models.
1 change: 0 additions & 1 deletion ignite/contrib/handlers/__init__.py
@@ -1,4 +1,3 @@
-
 from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, CosineAnnealingScheduler, \
     ConcatScheduler, LRScheduler, create_lr_scheduler_with_warmup, PiecewiseLinear, ParamGroupScheduler
 
1 change: 1 addition & 0 deletions ignite/contrib/handlers/base_logger.py
@@ -14,6 +14,7 @@ class BaseLogger:
     Base logger handler. See implementations: TensorboardLogger, VisdomLogger, PolyaxonLogger, MLflowLogger, ...
 
     """
+
     def attach(self, engine, log_handler, event_name):
         """Attach the logger to the engine and execute `log_handler` function at `event_name` events.
 
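
For context, the attach contract described above is shared by all the concrete loggers. A hedged sketch using TensorboardLogger; the log directory and the placeholder trainer are assumptions for illustration:

from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler

trainer = Engine(lambda engine, batch: batch)  # placeholder update function

tb_logger = TensorboardLogger(log_dir="/tmp/tb_logs")
tb_logger.attach(trainer,
                 log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}),
                 event_name=Events.ITERATION_COMPLETED)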
9 changes: 5 additions & 4 deletions ignite/contrib/handlers/custom_events.py
@@ -73,10 +73,11 @@ def __init__(self, n_iterations=None, n_epochs=None):
 
         self.custom_state_attr = "{}_{}".format(prefix, self.period)
         event_name = "{}_{}".format(prefix.upper(), self.period)
-        setattr(self, "Events", Enum("Events", " ".join([
-            "{}_STARTED".format(event_name),
-            "{}_COMPLETED".format(event_name)])
-        ))
+        setattr(self, "Events",
+                Enum("Events",
+                     " ".join(["{}_STARTED".format(event_name),
+                               "{}_COMPLETED".format(event_name)])
+                ))
         # Update State.event_to_attr
         for e in self.Events:
             State.event_to_attr[e] = self.custom_state_attr
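
The reformatted call above uses the standard library's functional Enum API, which creates members at runtime from a space-separated string of names. A self-contained sketch:

from enum import Enum

period = 1000
event_name = "ITERATIONS_{}".format(period)
CustomEvents = Enum("Events", " ".join(["{}_STARTED".format(event_name),
                                        "{}_COMPLETED".format(event_name)]))

print([e.name for e in CustomEvents])  # ['ITERATIONS_1000_STARTED', 'ITERATIONS_1000_COMPLETED']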
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/mlflow_logger.py
@@ -6,7 +6,6 @@
 from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler, BaseOptimizerParamsHandler, \
     global_step_from_engine
 
-
 __all__ = ['MLflowLogger', 'OutputHandler', 'OptimizerParamsHandler', 'global_step_from_engine']
 
 
@@ -90,6 +89,7 @@ def global_step_transform(engine, event_name):
             return engine.state.get_event_attrib_value(event_name)
 
     """
+
     def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None):
         super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform)
 
4 changes: 2 additions & 2 deletions ignite/contrib/handlers/param_scheduler.py
@@ -782,8 +782,8 @@ def _get_start_end(self):
         elif self.milestones[-1] <= self.event_index:
             return self.event_index, self.event_index + 1, self.values[-1], self.values[-1],
         elif self.milestones[self._index] <= self.event_index < self.milestones[self._index + 1]:
-            return self.milestones[self._index], self.milestones[self._index + 1], \
-                self.values[self._index], self.values[self._index + 1]
+            return (self.milestones[self._index], self.milestones[self._index + 1],
+                    self.values[self._index], self.values[self._index + 1])
         else:
             self._index += 1
             return self._get_start_end()
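
For orientation, _get_start_end returns an (event_start, event_end, start_value, end_value) window around the current event index. A sketch (not ignite's actual code) of how such a window yields a linearly interpolated parameter value:

def interpolate(event_index, event_start, event_end, value_start, value_end):
    # Linear interpolation between the two enclosing milestones.
    progress = (event_index - event_start) / (event_end - event_start)
    return value_start + progress * (value_end - value_start)


# Example: halfway between milestones 10 and 20, with values 0.1 -> 0.01.
print(interpolate(15, 10, 20, 0.1, 0.01))  # 0.055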
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/polyaxon_logger.py
@@ -6,7 +6,6 @@
 from ignite.contrib.handlers.base_logger import BaseLogger, BaseOutputHandler, BaseOptimizerParamsHandler, \
     global_step_from_engine
 
-
 __all__ = ['PolyaxonLogger', 'OutputHandler', 'OptimizerParamsHandler', 'global_step_from_engine']
 
 
@@ -91,6 +90,7 @@ def global_step_transform(engine, event_name):
             return engine.state.get_event_attrib_value(event_name)
 
     """
+
    def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None):
        super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform)
 
@@ -225,4 +225,5 @@ def __init__(self):
     def __getattr__(self, attr):
         def wrapper(*args, **kwargs):
             return getattr(self.experiment, attr)(*args, **kwargs)
+
         return wrapper
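
The __getattr__ wrapper above is a generic delegation pattern: any attribute not found on the logger resolves to a function that forwards the call to the wrapped experiment object. A standalone sketch with a dummy experiment:

class Delegator:
    def __init__(self, experiment):
        self.experiment = experiment

    def __getattr__(self, attr):
        def wrapper(*args, **kwargs):
            # Forward the call to the wrapped object's method of the same name.
            return getattr(self.experiment, attr)(*args, **kwargs)

        return wrapper


class Experiment:
    def log_metrics(self, **kwargs):
        print("logged:", kwargs)


Delegator(Experiment()).log_metrics(loss=0.5)  # logged: {'loss': 0.5}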
5 changes: 4 additions & 1 deletion ignite/contrib/handlers/tensorboard_logger.py
@@ -6,7 +6,6 @@
 from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler, \
     BaseWeightsScalarHandler, BaseWeightsHistHandler, global_step_from_engine
 
-
 __all__ = ['TensorboardLogger', 'OptimizerParamsHandler', 'OutputHandler',
            'WeightsScalarHandler', 'WeightsHistHandler', 'GradsScalarHandler',
            'GradsHistHandler', 'global_step_from_engine']
@@ -92,6 +91,7 @@ def global_step_transform(engine, event_name):
             return engine.state.get_event_attrib_value(event_name)
 
     """
+
     def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None):
         super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform)
 
@@ -184,6 +184,7 @@ class WeightsScalarHandler(BaseWeightsScalarHandler):
         tag (str, optional): common title for all produced plots. For example, 'generator'
 
     """
+
     def __init__(self, model, reduction=torch.norm, tag=None):
         super(WeightsScalarHandler, self).__init__(model, reduction, tag=tag)
 
@@ -271,6 +272,7 @@ class GradsScalarHandler(BaseWeightsScalarHandler):
         tag (str, optional): common title for all produced plots. For example, 'generator'
 
     """
+
     def __init__(self, model, reduction=torch.norm, tag=None):
         super(GradsScalarHandler, self).__init__(model, reduction, tag=tag)
 
@@ -312,6 +314,7 @@ class GradsHistHandler(BaseWeightsHistHandler):
         tag (str, optional): common title for all produced plots. For example, 'generator'
 
     """
+
     def __init__(self, model, tag=None):
         super(GradsHistHandler, self).__init__(model, tag=tag)
 
1 change: 1 addition & 0 deletions ignite/contrib/handlers/tqdm_logger.py
@@ -203,6 +203,7 @@ class _OutputHandler(BaseOutputHandler):
     :meth:`~ignite.engine.Engine.register_events`.
 
     """
+
     def __init__(self, description, metric_names=None, output_transform=None,
                  closing_event_name=Events.EPOCH_COMPLETED):
         if metric_names is None and output_transform is None:
3 changes: 1 addition & 2 deletions ignite/contrib/handlers/visdom_logger.py
@@ -7,7 +7,6 @@
 from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler, \
     BaseWeightsScalarHandler, global_step_from_engine
 
-
 __all__ = ['VisdomLogger', 'OptimizerParamsHandler', 'OutputHandler',
            'WeightsScalarHandler', 'GradsScalarHandler', 'global_step_from_engine']
 
@@ -142,6 +141,7 @@ def global_step_transform(engine, event_name):
             return engine.state.get_event_attrib_value(event_name)
 
     """
+
     def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None,
                  show_legend=False):
         super(OutputHandler, self).__init__(tag, metric_names, output_transform, another_engine, global_step_transform)
@@ -473,7 +473,6 @@ def close(self):
 
 
 class _DummyExecutor:
-
     class _DummyFuture:
 
         def __init__(self, result):
1 change: 1 addition & 0 deletions ignite/contrib/metrics/average_precision.py
@@ -36,5 +36,6 @@ def activated_output_transform(output):
         avg_precision = AveragePrecision(activated_output_transform)
 
     """
+
     def __init__(self, output_transform=lambda x: x):
         super(AveragePrecision, self).__init__(average_precision_compute_fn, output_transform=output_transform)
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/fractional_bias.py
@@ -22,6 +22,7 @@ class FractionalBias(_BaseRegression):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def reset(self):
         self._sum_of_errors = 0.0
         self._num_examples = 0
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py
@@ -22,6 +22,7 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def reset(self):
         self._sum_y = 0.0
         self._num_examples = 0
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/manhattan_distance.py
@@ -21,6 +21,7 @@ class ManhattanDistance(_BaseRegression):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def reset(self):
         self._sum_of_errors = 0.0
 
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/mean_error.py
@@ -22,6 +22,7 @@ class MeanError(_BaseRegression):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def reset(self):
         self._sum_of_errors = 0.0
         self._num_examples = 0
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/mean_normalized_bias.py
@@ -22,6 +22,7 @@ class MeanNormalizedBias(_BaseRegression):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def reset(self):
         self._sum_of_errors = 0.0
         self._num_examples = 0
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/median_absolute_error.py
@@ -30,5 +30,6 @@ class MedianAbsoluteError(_BaseRegressionEpoch):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def __init__(self, output_transform=lambda x: x):
         super(MedianAbsoluteError, self).__init__(median_absolute_error_compute_fn, output_transform)
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/median_absolute_percentage_error.py
@@ -32,6 +32,7 @@ class MedianAbsolutePercentageError(_BaseRegressionEpoch):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def __init__(self, output_transform=lambda x: x):
         super(MedianAbsolutePercentageError, self).__init__(median_absolute_percentage_error_compute_fn,
                                                             output_transform)
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/median_relative_absolute_error.py
@@ -32,6 +32,7 @@ class MedianRelativeAbsoluteError(_BaseRegressionEpoch):
     __ https://arxiv.org/abs/1809.03006
 
     """
+
     def __init__(self, output_transform=lambda x: x):
         super(MedianRelativeAbsoluteError, self).__init__(median_relative_absolute_error_compute_fn,
                                                           output_transform)
1 change: 1 addition & 0 deletions ignite/contrib/metrics/regression/r2_score.py
@@ -19,6 +19,7 @@ class R2Score(_BaseRegression):
     - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`.
     - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
     """
+
     def reset(self):
         self._num_examples = 0
         self._sum_of_errors = 0
1 change: 1 addition & 0 deletions ignite/contrib/metrics/roc_auc.py
@@ -37,5 +37,6 @@ def activated_output_transform(output):
         roc_auc = ROC_AUC(activated_output_transform)
 
     """
+
     def __init__(self, output_transform=lambda x: x):
         super(ROC_AUC, self).__init__(roc_auc_compute_fn, output_transform=output_transform)
26 changes: 15 additions & 11 deletions ignite/engine/__init__.py
@@ -1,8 +1,10 @@
+from typing import Sequence, Union, Optional, Callable, Dict, Any, Tuple
 import torch
 
 from ignite.engine.engine import Engine
 from ignite.engine.events import State, Events
 from ignite.utils import convert_tensor
+from ignite.metrics import Metric
 
 __all__ = [
     'create_supervised_trainer',
@@ -12,7 +14,8 @@
 ]
 
 
-def _prepare_batch(batch, device=None, non_blocking=False):
+def _prepare_batch(batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None,
+                   non_blocking: bool = False):
     """Prepare batch for training: pass to a device with options.
 
     """
@@ -21,10 +24,11 @@ def _prepare_batch(batch, device=None, non_blocking=False):
             convert_tensor(y, device=device, non_blocking=non_blocking))
 
 
-def create_supervised_trainer(model, optimizer, loss_fn,
-                              device=None, non_blocking=False,
-                              prepare_batch=_prepare_batch,
-                              output_transform=lambda x, y, y_pred, loss: loss.item()):
+def create_supervised_trainer(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
+                              loss_fn: Union[Callable, torch.nn.Module],
+                              device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False,
+                              prepare_batch: Callable = _prepare_batch,
+                              output_transform: Callable = lambda x, y, y_pred, loss: loss.item()) -> Engine:
     """
     Factory function for creating a trainer for supervised models.
 
@@ -50,7 +54,7 @@ def create_supervised_trainer(model, optimizer, loss_fn,
     if device:
         model.to(device)
 
-    def _update(engine, batch):
+    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
         model.train()
         optimizer.zero_grad()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
@@ -63,10 +67,10 @@ def _update(engine, batch):
     return Engine(_update)
 
 
-def create_supervised_evaluator(model, metrics=None,
-                                device=None, non_blocking=False,
-                                prepare_batch=_prepare_batch,
-                                output_transform=lambda x, y, y_pred: (y_pred, y,)):
+def create_supervised_evaluator(model: torch.nn.Module, metrics: Optional[Dict[str, Metric]] = None,
+                                device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False,
+                                prepare_batch: Callable = _prepare_batch,
+                                output_transform: Callable = lambda x, y, y_pred: (y_pred, y,)) -> Engine:
     """
     Factory function for creating an evaluator for supervised models.
 
@@ -94,7 +98,7 @@ def create_supervised_evaluator(model, metrics=None,
     if device:
         model.to(device)
 
-    def _inference(engine, batch):
+    def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
         model.eval()
         with torch.no_grad():
             x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
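
To round off the typed signatures above, a short usage sketch of the two factory functions; the model, data, and metric choice are placeholders:

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

data = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)

trainer = create_supervised_trainer(model, optimizer, loss_fn)                    # -> Engine
evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy()})  # -> Engine

trainer.run(data, max_epochs=2)
evaluator.run(data)
print(evaluator.state.metrics["accuracy"])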