From 26442cde41918378210d156ca86c0365ee552ccd Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski Date: Mon, 13 Apr 2020 17:13:20 -0700 Subject: [PATCH 01/10] initial implementation of WandBLogger with some output handlers. --- ignite/contrib/handlers/__init__.py | 1 + ignite/contrib/handlers/wandb_logger.py | 76 +++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 ignite/contrib/handlers/wandb_logger.py diff --git a/ignite/contrib/handlers/__init__.py b/ignite/contrib/handlers/__init__.py index 4365e45b95d..e876cff8819 100644 --- a/ignite/contrib/handlers/__init__.py +++ b/ignite/contrib/handlers/__init__.py @@ -15,5 +15,6 @@ from ignite.contrib.handlers.visdom_logger import VisdomLogger from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger from ignite.contrib.handlers.mlflow_logger import MLflowLogger +from ignite.contrib.handlers.wandb_logger import WandBLogger from ignite.contrib.handlers.base_logger import global_step_from_engine from ignite.contrib.handlers.lr_finder import FastaiLRFinder diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py new file mode 100644 index 00000000000..6e6955a73b5 --- /dev/null +++ b/ignite/contrib/handlers/wandb_logger.py @@ -0,0 +1,76 @@ +from ignite.contrib.handlers.base_logger import ( + BaseLogger, + BaseOutputHandler, + BaseOptimizerParamsHandler, + global_step_from_engine, +) + + +__all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", + "global_step_from_engine"] + + +class OutputHandler(BaseOutputHandler): + + def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None, + sync=None): + super().__init__(tag, metric_names, output_transform, another_engine, global_step_transform) + self.sync = sync + + def __call__(self, engine, logger, event_name): + + if not isinstance(logger, WandBLogger): + raise RuntimeError("Handler '{}' works only with WandBLogger.".format(self.__class__.__name__)) + + global_step = self.global_step_transform(engine, event_name) + if not isinstance(global_step, int): + raise TypeError( + "global_step must be int, got {}." + " Please check the output of global_step_transform.".format(type(global_step)) + ) + + metrics = self._setup_output_metrics(engine) + if self.tag is not None: + metrics = {"{tag}/{name}".format(tag=self.tag, name=name): value + for name, value in metrics.items()} + + logger.log(metrics, step=global_step, sync=self.sync) + + +class OptimizerParamsHandler(BaseOptimizerParamsHandler): + + def __init__(self, optimizer, param_name="lr", tag=None): + super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag) + + def __call__(self, engine, logger, event_name): + if not isinstance(logger, WandBLogger): + raise RuntimeError("Handler 'OptimizerParamsHandler' works only with WandBLogger") + + global_step = engine.state.get_event_attrib_value(event_name) + tag_prefix = "{}/".format(self.tag) if self.tag else "" + params = { + "{}{}/group_{}".format(tag_prefix, self.param_name, i): float(param_group[self.param_name]) + for i, param_group in enumerate(self.optimizer.param_groups) + } + logger.log(params, step=global_step) + + +class WandBLogger(BaseLogger): + + def __init__(self, *args, **kwargs): + try: + import wandb + self._wandb = wandb + except ImportError: + raise RuntimeError( + "This contrib module requires wandb to be installed. 
" + "You man install wandb with the command:\n pip install wandb\n" + ) + if kwargs.get('init', True): + wandb.init(*args, **kwargs) + + def __getattr__(self, attr): + def wrapper(*args, **kwargs): + return getattr(self._wandb, attr)(*args, **kwargs) + + return wrapper From ed3e4f64a7e23e291aeb395e0bce1d503de8eea7 Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski Date: Wed, 15 Apr 2020 17:47:20 -0700 Subject: [PATCH 02/10] added tests --- .../contrib/handlers/test_wandb_logger.py | 244 ++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/ignite/contrib/handlers/test_wandb_logger.py diff --git a/tests/ignite/contrib/handlers/test_wandb_logger.py b/tests/ignite/contrib/handlers/test_wandb_logger.py new file mode 100644 index 00000000000..a2d6439e4f8 --- /dev/null +++ b/tests/ignite/contrib/handlers/test_wandb_logger.py @@ -0,0 +1,244 @@ +from unittest.mock import call, MagicMock +import pytest +import torch + +from ignite.engine import Events, State +from ignite.contrib.handlers.wandb_logger import * + + +def test_optimizer_params_handler_wrong_setup(): + with pytest.raises(TypeError): + OptimizerParamsHandler(optimizer=None) + + optimizer = MagicMock(spec=torch.optim.Optimizer) + handler = OptimizerParamsHandler(optimizer=optimizer) + + mock_logger = MagicMock() + mock_engine = MagicMock() + with pytest.raises(RuntimeError, match="Handler 'OptimizerParamsHandler' works only with WandBLogger"): + handler(mock_engine, mock_logger, Events.ITERATION_STARTED) + + +def test_optimizer_params(): + optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01) + wrapper = OptimizerParamsHandler(optimizer=optimizer, param_name="lr") + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.iteration = 123 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123) + + wrapper = OptimizerParamsHandler(optimizer, param_name="lr", tag="generator") + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123) + + +def test_output_handler_with_wrong_logger_type(): + wrapper = OutputHandler("tag", output_transform=lambda x: x) + + mock_logger = MagicMock() + mock_engine = MagicMock() + with pytest.raises(RuntimeError, match="Handler 'OutputHandler' works only with WandBLogger"): + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + +def test_output_handler_output_transform(): + wrapper = OutputHandler("tag", output_transform=lambda x: x) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.output = 12345 + mock_engine.state.iteration = 123 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=None) + + wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x}) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=None) + + +def test_output_handler_output_transform_sync(): + wrapper = OutputHandler("tag", output_transform=lambda x: x, sync=False) + mock_logger = 
MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.output = 12345 + mock_engine.state.iteration = 123 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=False) + + wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x}, sync=True) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=True) + + +def test_output_handler_metric_names(): + wrapper = OutputHandler("tag", metric_names=["a", "b"]) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State(metrics={"a": 1, "b": 5}) + mock_engine.state.iteration = 5 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, + step=5, sync=None) + + wrapper = OutputHandler("tag", metric_names=["a", "c"]) + mock_engine = MagicMock() + mock_engine.state = State(metrics={"a": 55.56, "c": "Some text"}) + mock_engine.state.iteration = 7 + + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"tag/a": 55.56, + "tag/c": "Some text"}, + step=7, sync=None) + + # all metrics + wrapper = OutputHandler("tag", metric_names="all") + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State(metrics={"a": 12.23, "b": 23.45}) + mock_engine.state.iteration = 5 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"tag/a": 12.23, + "tag/b": 23.45}, + step=5, sync=None) + + +def test_output_handler_both(): + wrapper = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x}) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State(metrics={"a": 12.23, "b": 23.45}) + mock_engine.state.epoch = 5 + mock_engine.state.output = 12345 + + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + + mock_logger.log.assert_called_once_with( + {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5, sync=None + ) + +def test_output_handler_with_wrong_global_step_transform_output(): + def global_step_transform(*args, **kwargs): + return "a" + + wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.epoch = 5 + mock_engine.state.output = 12345 + + with pytest.raises(TypeError, match="global_step must be int"): + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + + +def test_output_handler_with_global_step_transform(): + def global_step_transform(*args, **kwargs): + return 10 + + wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.epoch = 5 + mock_engine.state.output = 
12345 + + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + mock_logger.log.assert_called_once_with( + {"tag/loss": 12345}, step=10, sync=None) + + +def test_output_handler_with_global_step_from_engine(): + + mock_another_engine = MagicMock() + mock_another_engine.state = State() + mock_another_engine.state.epoch = 10 + mock_another_engine.state.output = 12.345 + + wrapper = OutputHandler( + "tag", + output_transform=lambda x: {"loss": x}, + global_step_transform=global_step_from_engine(mock_another_engine), + ) + + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.epoch = 1 + mock_engine.state.output = 0.123 + + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + mock_logger.log.assert_called_once_with( + {"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None) + + mock_another_engine.state.epoch = 11 + mock_engine.state.output = 1.123 + + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + assert mock_logger.log.call_count == 2 + mock_logger.log.assert_has_calls( + [call({"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None)] + ) + + +@pytest.fixture +def no_site_packages(): + import sys + + wandb_client_modules = {} + for k in sys.modules: + if "wandb" in k: + wandb_client_modules[k] = sys.modules[k] + for k in wandb_client_modules: + del sys.modules[k] + + prev_path = list(sys.path) + sys.path = [p for p in sys.path if "site-packages" not in p] + yield "no_site_packages" + sys.path = prev_path + for k in wandb_client_modules: + sys.modules[k] = wandb_client_modules[k] + + +def test_no_neptune_client(no_site_packages): + + with pytest.raises(RuntimeError, match=r"This contrib module requires wandb to be installed."): + WandBLogger() From 723a667257cfb86c89da0ef10f2a7b5b35f8b076 Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski Date: Wed, 15 Apr 2020 18:21:45 -0700 Subject: [PATCH 03/10] added documentation --- ignite/contrib/handlers/wandb_logger.py | 223 +++++++++++++++++++++++- 1 file changed, 219 insertions(+), 4 deletions(-) diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py index 6e6955a73b5..b81681f94f2 100644 --- a/ignite/contrib/handlers/wandb_logger.py +++ b/ignite/contrib/handlers/wandb_logger.py @@ -6,11 +6,108 @@ ) -__all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", - "global_step_from_engine"] +__all__ = [ + "WandBLogger", + "OutputHandler", + "OptimizerParamsHandler", + "global_step_from_engine" +] class OutputHandler(BaseOutputHandler): + """Helper handler to log engine's output and/or metrics + + Examples: + + .. code-block:: python + + from ignite.contrib.handlers.wandb_logger import * + + # Create a logger. All parameters are optional. See documentation + # on wandb.init for details. + + wandb_logger = WandBLogger( + entity="shared", + project="pytorch-ignite-integration", + name="cnn-mnist", + config={"max_epochs": 10}, + tags=["pytorch-ignite", "minst"] + ) + + # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after + # each epoch. 
We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
+            # of the `trainer`:
+            wandb_logger.attach(evaluator,
+                                log_handler=OutputHandler(tag="validation",
+                                                          metric_names=["nll", "accuracy"],
+                                                          global_step_transform=global_step_from_engine(trainer)),
+                                event_name=Events.EPOCH_COMPLETED)
+
+        Another example, where model is evaluated every 500 iterations:
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.wandb_logger import *
+
+            @trainer.on(Events.ITERATION_COMPLETED(every=500))
+            def evaluate(engine):
+                evaluator.run(validation_set, max_epochs=1)
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            def global_step_transform(*args, **kwargs):
+                return trainer.state.iteration
+
+            # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
+            # every 500 iterations. Since evaluator engine does not have access to the training iteration, we
+            # provide a global_step_transform to return the trainer.state.iteration for the global_step, each time
+            # evaluator metrics are plotted on Weights & Biases.
+
+            wandb_logger.attach(evaluator,
+                                log_handler=OutputHandler(tag="validation",
+                                                          metric_names=["nll", "accuracy"],
+                                                          global_step_transform=global_step_transform),
+                                event_name=Events.EPOCH_COMPLETED)
+
+    Args:
+        tag (str): common title for all produced plots. For example, 'training'
+        metric_names (list of str, optional): list of metric names to plot or a string "all" to plot all available
+            metrics.
+        output_transform (callable, optional): output transform function to prepare `engine.state.output` as a number.
+            For example, `output_transform = lambda output: output`
+            This function can also return a dictionary, e.g `{'loss': loss1, 'another_loss': loss2}` to label the plot
+            with corresponding keys.
+        another_engine (Engine): Deprecated (see :attr:`global_step_transform`). Another engine to use to provide the
+            value of event. Typically, user can provide
+            the trainer if this handler is attached to an evaluator and thus it logs proper trainer's
+            epoch/iteration value.
+        global_step_transform (callable, optional): global step transform function to output a desired global step.
+            Input of the function is `(engine, event_name)`. Output of function should be an integer.
+            Default is None, global_step based on attached engine. If provided,
+            uses function output as global_step. To setup global step from another engine, please use
+            :meth:`~ignite.contrib.handlers.wandb_logger.global_step_from_engine`.
+        sync (bool, optional): If set to False, process calls to log in a separate thread. Default (None) uses whatever
+            the default value of wandb.log.
+
+    Note:
+
+        Example of `global_step_transform`:
+
+        .. code-block:: python
+
+            def global_step_transform(engine, event_name):
+                return engine.state.get_event_attrib_value(event_name)
+
+    """

     def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None,
                  sync=None):
@@ -38,9 +135,41 @@ def __call__(self, engine, logger, event_name):
 
 
 class OptimizerParamsHandler(BaseOptimizerParamsHandler):
+    """Helper handler to log optimizer parameters
+
+    Examples:
+
+        .. code-block:: python
 
-    def __init__(self, optimizer, param_name="lr", tag=None):
+            from ignite.contrib.handlers.wandb_logger import *
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
+            wandb_logger.attach(trainer,
+                                log_handler=OptimizerParamsHandler(optimizer),
+                                event_name=Events.ITERATION_STARTED)
+
+    Args:
+        optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
+        param_name (str): parameter name
+        tag (str, optional): common title for all produced plots. For example, 'generator'
+        sync (bool, optional): If set to False, process calls to log in a separate thread. Default (None) uses whatever
+            the default value of wandb.log.
+    """
+
+    def __init__(self, optimizer, param_name="lr", tag=None, sync=None):
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
+        self.sync = sync
 
     def __call__(self, engine, logger, event_name):
         if not isinstance(logger, WandBLogger):
@@ -52,10 +181,96 @@ def __call__(self, engine, logger, event_name):
             "{}{}/group_{}".format(tag_prefix, self.param_name, i): float(param_group[self.param_name])
             for i, param_group in enumerate(self.optimizer.param_groups)
         }
-        logger.log(params, step=global_step)
+        logger.log(params, step=global_step, sync=self.sync)
 
 
 class WandBLogger(BaseLogger):
+    """
+    Weights & Biases handler to log metrics, model/optimizer parameters, gradients during training and validation.
+    It can also be used to log model checkpoints to the Weights & Biases cloud.
+
+    .. code-block:: bash
+
+        pip install wandb
+
+    This class is also a wrapper for the wandb module. This means that you can call any wandb function using
+    this wrapper. See examples on how to save model parameters and gradients.
+
+    Args:
+        args, kwargs: Please see wandb.init for documentation of possible parameters.
+
+    Examples:
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.wandb_logger import *
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            # Attach the logger to the trainer to log training loss at each iteration
+            wandb_logger.attach(trainer,
+                                log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
+                                event_name=Events.ITERATION_COMPLETED)
+
+            # Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
+            # We setup `global_step_transform=global_step_from_engine(trainer)` to take the epoch
+            # of the `trainer` instead of `evaluator`.
+            wandb_logger.attach(evaluator,
+                                log_handler=OutputHandler(tag="validation",
+                                                          metric_names=["nll", "accuracy"],
+                                                          global_step_transform=global_step_from_engine(trainer)),
+                                event_name=Events.EPOCH_COMPLETED)
+
+            # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
+            wandb_logger.attach(trainer,
+                                log_handler=OptimizerParamsHandler(optimizer),
+                                event_name=Events.ITERATION_STARTED)
+
+    If you want to log model gradients, the model call graph, etc., use the logger as wrapper of wandb. Refer
+    to the documentation of wandb.watch for details:
+
+    ..
code-block:: python + + wandb_logger = WandBLogger( + entity="shared", + project="pytorch-ignite-integration", + name="cnn-mnist", + config={"max_epochs": 10}, + tags=["pytorch-ignite", "minst"] + ) + + model = torch.nn.Sequential(...) + wandb_logger.watch(model) + + For model checkpointing, Weights & Biases creates a local run dir, and automatically synchronizes all + files saved there at the end of the run. You can just use the `wandb_logger.run.dir` as path for the + `ModelCheckpoint`: + + .. code-block:: python + + from ignite.handlers import ModelCheckpoint + + def score_function(engine): + return engine.state.metrics['accuracy'] + + model_checkpoint = ModelCheckpoint( + wandb_logger.run.dir, n_saved=2, filename_prefix='best', + require_empty=False, score_function=score_function, + score_name="validation_accuracy", + global_step_transform=global_step_from_engine(trainer)) + evaluator.add_event_handler( + Events.COMPLETED, model_checkpoint, {'model': model}) + """ def __init__(self, *args, **kwargs): try: From ce978e018c911647035204ad7f7b2b23387b2adf Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski Date: Wed, 15 Apr 2020 18:42:28 -0700 Subject: [PATCH 04/10] fixed test --- tests/ignite/contrib/handlers/test_wandb_logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ignite/contrib/handlers/test_wandb_logger.py b/tests/ignite/contrib/handlers/test_wandb_logger.py index a2d6439e4f8..3e5cbdb1901 100644 --- a/tests/ignite/contrib/handlers/test_wandb_logger.py +++ b/tests/ignite/contrib/handlers/test_wandb_logger.py @@ -29,14 +29,14 @@ def test_optimizer_params(): mock_engine.state.iteration = 123 wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) - mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123) + mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123, sync=None) wrapper = OptimizerParamsHandler(optimizer, param_name="lr", tag="generator") mock_logger = MagicMock(spec=WandBLogger) mock_logger.log = MagicMock() wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) - mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123) + mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123, sync=None) def test_output_handler_with_wrong_logger_type(): From 26aece0125271ac021d534de5e6bd25f3de6575a Mon Sep 17 00:00:00 2001 From: AutoPEP8 <> Date: Thu, 16 Apr 2020 07:29:45 +0000 Subject: [PATCH 05/10] autopep8 fix --- ignite/contrib/handlers/wandb_logger.py | 18 ++++++--------- .../contrib/handlers/test_wandb_logger.py | 22 +++++++------------ 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py index b81681f94f2..2d9e0a5ee79 100644 --- a/ignite/contrib/handlers/wandb_logger.py +++ b/ignite/contrib/handlers/wandb_logger.py @@ -6,12 +6,7 @@ ) -__all__ = [ - "WandBLogger", - "OutputHandler", - "OptimizerParamsHandler", - "global_step_from_engine" -] +__all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"] class OutputHandler(BaseOutputHandler): @@ -109,8 +104,9 @@ def global_step_transform(engine, event_name): """ - def __init__(self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None, - sync=None): + def __init__( + self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None, sync=None + ): super().__init__(tag, metric_names, output_transform, 
another_engine, global_step_transform) self.sync = sync @@ -128,8 +124,7 @@ def __call__(self, engine, logger, event_name): metrics = self._setup_output_metrics(engine) if self.tag is not None: - metrics = {"{tag}/{name}".format(tag=self.tag, name=name): value - for name, value in metrics.items()} + metrics = {"{tag}/{name}".format(tag=self.tag, name=name): value for name, value in metrics.items()} logger.log(metrics, step=global_step, sync=self.sync) @@ -275,13 +270,14 @@ def score_function(engine): def __init__(self, *args, **kwargs): try: import wandb + self._wandb = wandb except ImportError: raise RuntimeError( "This contrib module requires wandb to be installed. " "You man install wandb with the command:\n pip install wandb\n" ) - if kwargs.get('init', True): + if kwargs.get("init", True): wandb.init(*args, **kwargs) def __getattr__(self, attr): diff --git a/tests/ignite/contrib/handlers/test_wandb_logger.py b/tests/ignite/contrib/handlers/test_wandb_logger.py index 3e5cbdb1901..554eb9e94f1 100644 --- a/tests/ignite/contrib/handlers/test_wandb_logger.py +++ b/tests/ignite/contrib/handlers/test_wandb_logger.py @@ -102,8 +102,7 @@ def test_output_handler_metric_names(): mock_engine.state.iteration = 5 wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) - mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, - step=5, sync=None) + mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, step=5, sync=None) wrapper = OutputHandler("tag", metric_names=["a", "c"]) mock_engine = MagicMock() @@ -114,9 +113,7 @@ def test_output_handler_metric_names(): mock_logger.log = MagicMock() wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) - mock_logger.log.assert_called_once_with({"tag/a": 55.56, - "tag/c": "Some text"}, - step=7, sync=None) + mock_logger.log.assert_called_once_with({"tag/a": 55.56, "tag/c": "Some text"}, step=7, sync=None) # all metrics wrapper = OutputHandler("tag", metric_names="all") @@ -128,9 +125,7 @@ def test_output_handler_metric_names(): mock_engine.state.iteration = 5 wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) - mock_logger.log.assert_called_once_with({"tag/a": 12.23, - "tag/b": 23.45}, - step=5, sync=None) + mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45}, step=5, sync=None) def test_output_handler_both(): @@ -145,9 +140,8 @@ def test_output_handler_both(): wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) - mock_logger.log.assert_called_once_with( - {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5, sync=None - ) + mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5, sync=None) + def test_output_handler_with_wrong_global_step_transform_output(): def global_step_transform(*args, **kwargs): @@ -180,8 +174,7 @@ def global_step_transform(*args, **kwargs): mock_engine.state.output = 12345 wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) - mock_logger.log.assert_called_once_with( - {"tag/loss": 12345}, step=10, sync=None) + mock_logger.log.assert_called_once_with({"tag/loss": 12345}, step=10, sync=None) def test_output_handler_with_global_step_from_engine(): @@ -207,7 +200,8 @@ def test_output_handler_with_global_step_from_engine(): wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) mock_logger.log.assert_called_once_with( - {"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None) + {"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None + ) 
mock_another_engine.state.epoch = 11 mock_engine.state.output = 1.123 From f2cfa394af47139723079312b5571e43c76c1d73 Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski Date: Sat, 18 Apr 2020 11:02:34 -0700 Subject: [PATCH 06/10] added mnist example, fixed wandb wrapping, fixed typo in tests; --- .../contrib/mnist/mnist_with_wandb_logger.py | 183 ++++++++++++++++++ ignite/contrib/handlers/wandb_logger.py | 5 +- .../contrib/handlers/test_wandb_logger.py | 2 +- 3 files changed, 185 insertions(+), 5 deletions(-) create mode 100644 examples/contrib/mnist/mnist_with_wandb_logger.py diff --git a/examples/contrib/mnist/mnist_with_wandb_logger.py b/examples/contrib/mnist/mnist_with_wandb_logger.py new file mode 100644 index 00000000000..7c2ba0b44ed --- /dev/null +++ b/examples/contrib/mnist/mnist_with_wandb_logger.py @@ -0,0 +1,183 @@ +""" + MNIST example with training and validation monitoring using Weights & Biases + + Requirements: + Weights & Biases: `pip install wandb` + + Usage: + + Make sure you are logged into Weights & Biases (use the `wandb` command). + + Run the example: + ```bash + python mnist_with_wandb_logger.py + ``` + + Go to https://wandb.com and explore your experiment. +""" +import sys +from argparse import ArgumentParser +import logging + +import torch +from torch.utils.data import DataLoader +from torch import nn +import torch.nn.functional as F +from torch.optim import SGD +from torchvision.datasets import MNIST +from torchvision.transforms import Compose, ToTensor, Normalize + +from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator +from ignite.metrics import Accuracy, Loss +from ignite.handlers import ModelCheckpoint + +from ignite.contrib.handlers.wandb_logger import * + + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=-1) + + +def get_data_loaders(train_batch_size, val_batch_size): + data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) + + train_loader = DataLoader( + MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True + ) + + val_loader = DataLoader( + MNIST(download=False, root=".", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False + ) + return train_loader, val_loader + + +def run(train_batch_size, val_batch_size, epochs, lr, momentum): + train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size) + model = Net() + device = "cpu" + + if torch.cuda.is_available(): + device = "cuda" + + optimizer = SGD(model.parameters(), lr=lr, momentum=momentum) + criterion = nn.CrossEntropyLoss() + trainer = create_supervised_trainer(model, optimizer, criterion, device=device) + + if sys.version_info > (3,): + from ignite.contrib.metrics.gpu_info import GpuInfo + + try: + GpuInfo().attach(trainer) + except RuntimeError: + print( + "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). " + "As there is no pynvml python package installed, GPU information won't be logged. 
Otherwise, please " + "install it : `pip install pynvml`" + ) + + metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)} + + train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device) + validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device) + + @trainer.on(Events.EPOCH_COMPLETED) + def compute_metrics(engine): + train_evaluator.run(train_loader) + validation_evaluator.run(val_loader) + + wandb_logger = WandBLogger( + project="pytorch-ignite-integration", + name="ignite-mnist-example", + config={ + "train_batch_size": train_batch_size, + "val_batch_size": val_batch_size, + "epochs": epochs, + "lr": lr, + "momentum": momentum, + }, + ) + + def iteration(engine): + def wrapper(_, event_name): + return engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) + return wrapper + + wandb_logger.attach( + trainer, + log_handler=OutputHandler( + tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all", + global_step_transform=iteration(trainer) + ), + event_name=Events.ITERATION_COMPLETED(every=100), + ) + + wandb_logger.attach( + train_evaluator, + log_handler=OutputHandler(tag="training", metric_names=["loss", "accuracy"], + global_step_transform=iteration(trainer)), + event_name=Events.EPOCH_COMPLETED, + ) + + wandb_logger.attach( + validation_evaluator, + log_handler=OutputHandler(tag="validation", metric_names=["loss", "accuracy"], + global_step_transform=iteration(trainer)), + event_name=Events.EPOCH_COMPLETED, + ) + + wandb_logger.attach( + trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100) + ) + wandb_logger.watch(model, log="all") + + def score_function(engine): + return engine.state.metrics['accuracy'] + + model_checkpoint = ModelCheckpoint( + wandb_logger.run.dir, n_saved=2, filename_prefix='best', score_function=score_function, + score_name="validation_accuracy", global_step_transform=iteration(trainer) + ) + validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model}) + + # kick everything off + trainer.run(train_loader, max_epochs=epochs) + wandb_logger.close() + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training (default: 64)") + parser.add_argument( + "--val_batch_size", type=int, default=1000, help="input batch size for validation (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train (default: 10)") + parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)") + parser.add_argument("--momentum", type=float, default=0.5, help="SGD momentum (default: 0.5)") + + args = parser.parse_args() + + # Setup engine logger + logger = logging.getLogger("ignite.engine.engine.Engine") + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum) diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py index 2d9e0a5ee79..6b191d7ce14 100644 --- a/ignite/contrib/handlers/wandb_logger.py +++ b/ignite/contrib/handlers/wandb_logger.py @@ -281,7 +281,4 @@ def __init__(self, *args, **kwargs): wandb.init(*args, **kwargs) def __getattr__(self, attr): - def 
wrapper(*args, **kwargs): - return getattr(self._wandb, attr)(*args, **kwargs) - - return wrapper + return getattr(self._wandb, attr) diff --git a/tests/ignite/contrib/handlers/test_wandb_logger.py b/tests/ignite/contrib/handlers/test_wandb_logger.py index 554eb9e94f1..28173c62242 100644 --- a/tests/ignite/contrib/handlers/test_wandb_logger.py +++ b/tests/ignite/contrib/handlers/test_wandb_logger.py @@ -232,7 +232,7 @@ def no_site_packages(): sys.modules[k] = wandb_client_modules[k] -def test_no_neptune_client(no_site_packages): +def test_no_wandb_client(no_site_packages): with pytest.raises(RuntimeError, match=r"This contrib module requires wandb to be installed."): WandBLogger() From dab7cdc48a380ba1261ee6b529cbc7b7d832eae9 Mon Sep 17 00:00:00 2001 From: AutoPEP8 <> Date: Sat, 18 Apr 2020 18:05:03 +0000 Subject: [PATCH 07/10] autopep8 fix --- .../contrib/mnist/mnist_with_wandb_logger.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/examples/contrib/mnist/mnist_with_wandb_logger.py b/examples/contrib/mnist/mnist_with_wandb_logger.py index 7c2ba0b44ed..5b71191b98b 100644 --- a/examples/contrib/mnist/mnist_with_wandb_logger.py +++ b/examples/contrib/mnist/mnist_with_wandb_logger.py @@ -34,7 +34,6 @@ from ignite.contrib.handlers.wandb_logger import * - class Net(nn.Module): def __init__(self): super(Net, self).__init__() @@ -116,28 +115,33 @@ def compute_metrics(engine): def iteration(engine): def wrapper(_, event_name): return engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) + return wrapper wandb_logger.attach( trainer, log_handler=OutputHandler( - tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all", - global_step_transform=iteration(trainer) + tag="training", + output_transform=lambda loss: {"batchloss": loss}, + metric_names="all", + global_step_transform=iteration(trainer), ), event_name=Events.ITERATION_COMPLETED(every=100), ) wandb_logger.attach( train_evaluator, - log_handler=OutputHandler(tag="training", metric_names=["loss", "accuracy"], - global_step_transform=iteration(trainer)), + log_handler=OutputHandler( + tag="training", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer) + ), event_name=Events.EPOCH_COMPLETED, ) wandb_logger.attach( validation_evaluator, - log_handler=OutputHandler(tag="validation", metric_names=["loss", "accuracy"], - global_step_transform=iteration(trainer)), + log_handler=OutputHandler( + tag="validation", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer) + ), event_name=Events.EPOCH_COMPLETED, ) @@ -147,13 +151,17 @@ def wrapper(_, event_name): wandb_logger.watch(model, log="all") def score_function(engine): - return engine.state.metrics['accuracy'] + return engine.state.metrics["accuracy"] model_checkpoint = ModelCheckpoint( - wandb_logger.run.dir, n_saved=2, filename_prefix='best', score_function=score_function, - score_name="validation_accuracy", global_step_transform=iteration(trainer) + wandb_logger.run.dir, + n_saved=2, + filename_prefix="best", + score_function=score_function, + score_name="validation_accuracy", + global_step_transform=iteration(trainer), ) - validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model}) + validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model}) # kick everything off trainer.run(train_loader, max_epochs=epochs) From 2d9582b5b6d25ef3c05030ecb026ebae52679854 Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski 
Date: Tue, 21 Apr 2020 16:43:16 -0700 Subject: [PATCH 08/10] updated docs, simplified example --- docs/source/contrib/handlers.rst | 6 ++++++ .../contrib/mnist/mnist_with_wandb_logger.py | 18 ++++++------------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/contrib/handlers.rst b/docs/source/contrib/handlers.rst index f24c1b56b64..a2f1cce94cc 100644 --- a/docs/source/contrib/handlers.rst +++ b/docs/source/contrib/handlers.rst @@ -85,6 +85,12 @@ polyaxon_logger :members: :inherited-members: +wandb_logger +--------------- + +.. automodule:: ignite.contrib.handlers.wandb_logger + :members: + :inherited-members: More on parameter scheduling ---------------------------- diff --git a/examples/contrib/mnist/mnist_with_wandb_logger.py b/examples/contrib/mnist/mnist_with_wandb_logger.py index 5b71191b98b..f36a2c5dae4 100644 --- a/examples/contrib/mnist/mnist_with_wandb_logger.py +++ b/examples/contrib/mnist/mnist_with_wandb_logger.py @@ -112,19 +112,13 @@ def compute_metrics(engine): }, ) - def iteration(engine): - def wrapper(_, event_name): - return engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED) - - return wrapper - wandb_logger.attach( trainer, log_handler=OutputHandler( tag="training", output_transform=lambda loss: {"batchloss": loss}, metric_names="all", - global_step_transform=iteration(trainer), + global_step_transform=lambda *_: trainer.state.iteration, ), event_name=Events.ITERATION_COMPLETED(every=100), ) @@ -132,16 +126,16 @@ def wrapper(_, event_name): wandb_logger.attach( train_evaluator, log_handler=OutputHandler( - tag="training", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer) - ), + tag="training", metric_names=["loss", "accuracy"], + global_step_transform=lambda *_: trainer.state.iteration), event_name=Events.EPOCH_COMPLETED, ) wandb_logger.attach( validation_evaluator, log_handler=OutputHandler( - tag="validation", metric_names=["loss", "accuracy"], global_step_transform=iteration(trainer) - ), + tag="validation", metric_names=["loss", "accuracy"], + global_step_transform=lambda *_: trainer.state.iteration), event_name=Events.EPOCH_COMPLETED, ) @@ -159,7 +153,7 @@ def score_function(engine): filename_prefix="best", score_function=score_function, score_name="validation_accuracy", - global_step_transform=iteration(trainer), + global_step_transform=lambda *_: trainer.state.iteration ) validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model}) From 04381cc713fd12f5c8624d3dd9cf164b43392d39 Mon Sep 17 00:00:00 2001 From: Filip Korzeniowski Date: Tue, 21 Apr 2020 17:01:29 -0700 Subject: [PATCH 09/10] removed explicit gpu logging from example --- examples/contrib/mnist/mnist_with_wandb_logger.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/examples/contrib/mnist/mnist_with_wandb_logger.py b/examples/contrib/mnist/mnist_with_wandb_logger.py index f36a2c5dae4..ca75f321cb4 100644 --- a/examples/contrib/mnist/mnist_with_wandb_logger.py +++ b/examples/contrib/mnist/mnist_with_wandb_logger.py @@ -78,18 +78,6 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum): criterion = nn.CrossEntropyLoss() trainer = create_supervised_trainer(model, optimizer, criterion, device=device) - if sys.version_info > (3,): - from ignite.contrib.metrics.gpu_info import GpuInfo - - try: - GpuInfo().attach(trainer) - except RuntimeError: - print( - "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). 
" - "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please " - "install it : `pip install pynvml`" - ) - metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)} train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device) From 07ae0a2ec18d93d7966fd47d8c9fdd7f127b1acb Mon Sep 17 00:00:00 2001 From: AutoPEP8 <> Date: Wed, 22 Apr 2020 00:03:39 +0000 Subject: [PATCH 10/10] autopep8 fix --- examples/contrib/mnist/mnist_with_wandb_logger.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/contrib/mnist/mnist_with_wandb_logger.py b/examples/contrib/mnist/mnist_with_wandb_logger.py index ca75f321cb4..8827b1cbdad 100644 --- a/examples/contrib/mnist/mnist_with_wandb_logger.py +++ b/examples/contrib/mnist/mnist_with_wandb_logger.py @@ -114,16 +114,18 @@ def compute_metrics(engine): wandb_logger.attach( train_evaluator, log_handler=OutputHandler( - tag="training", metric_names=["loss", "accuracy"], - global_step_transform=lambda *_: trainer.state.iteration), + tag="training", metric_names=["loss", "accuracy"], global_step_transform=lambda *_: trainer.state.iteration + ), event_name=Events.EPOCH_COMPLETED, ) wandb_logger.attach( validation_evaluator, log_handler=OutputHandler( - tag="validation", metric_names=["loss", "accuracy"], - global_step_transform=lambda *_: trainer.state.iteration), + tag="validation", + metric_names=["loss", "accuracy"], + global_step_transform=lambda *_: trainer.state.iteration, + ), event_name=Events.EPOCH_COMPLETED, ) @@ -141,7 +143,7 @@ def score_function(engine): filename_prefix="best", score_function=score_function, score_name="validation_accuracy", - global_step_transform=lambda *_: trainer.state.iteration + global_step_transform=lambda *_: trainer.state.iteration, ) validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})