diff --git a/docs/source/contrib/handlers.rst b/docs/source/contrib/handlers.rst
index f24c1b56b64..a2f1cce94cc 100644
--- a/docs/source/contrib/handlers.rst
+++ b/docs/source/contrib/handlers.rst
@@ -85,6 +85,12 @@ polyaxon_logger
     :members:
     :inherited-members:
 
+wandb_logger
+------------
+
+.. automodule:: ignite.contrib.handlers.wandb_logger
+    :members:
+    :inherited-members:
 
 More on parameter scheduling
 ----------------------------
diff --git a/examples/contrib/mnist/mnist_with_wandb_logger.py b/examples/contrib/mnist/mnist_with_wandb_logger.py
new file mode 100644
index 00000000000..8827b1cbdad
--- /dev/null
+++ b/examples/contrib/mnist/mnist_with_wandb_logger.py
@@ -0,0 +1,175 @@
+"""
+ MNIST example with training and validation monitoring using Weights & Biases.
+
+ Requirements:
+    Weights & Biases: `pip install wandb`
+
+ Usage:
+
+    Make sure you are logged into Weights & Biases (run `wandb login` once).
+
+    Run the example:
+    ```bash
+    python mnist_with_wandb_logger.py
+    ```
+
+    Go to https://wandb.com and explore your experiment.
+"""
+from argparse import ArgumentParser
+import logging
+
+import torch
+from torch.utils.data import DataLoader
+from torch import nn
+import torch.nn.functional as F
+from torch.optim import SGD
+from torchvision.datasets import MNIST
+from torchvision.transforms import Compose, ToTensor, Normalize
+
+from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
+from ignite.metrics import Accuracy, Loss
+from ignite.handlers import ModelCheckpoint
+
+from ignite.contrib.handlers.wandb_logger import *
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=-1)
+
+
+def get_data_loaders(train_batch_size, val_batch_size):
+    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
+
+    train_loader = DataLoader(
+        MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True
+    )
+
+    val_loader = DataLoader(
+        MNIST(download=False, root=".", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False
+    )
+    return train_loader, val_loader
+
+
+def run(train_batch_size, val_batch_size, epochs, lr, momentum):
+    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
+    model = Net()
+    device = "cpu"
+
+    if torch.cuda.is_available():
+        device = "cuda"
+
+    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
+    criterion = nn.NLLLoss()  # the model returns log-probabilities (log_softmax), so NLLLoss is the matching loss
+    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
+
+    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}
+
+    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
+    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def compute_metrics(engine):
+        train_evaluator.run(train_loader)
+        validation_evaluator.run(val_loader)
+
+    wandb_logger = WandBLogger(
+        project="pytorch-ignite-integration",
+        name="ignite-mnist-example",
+        config={
"train_batch_size": train_batch_size, + "val_batch_size": val_batch_size, + "epochs": epochs, + "lr": lr, + "momentum": momentum, + }, + ) + + wandb_logger.attach( + trainer, + log_handler=OutputHandler( + tag="training", + output_transform=lambda loss: {"batchloss": loss}, + metric_names="all", + global_step_transform=lambda *_: trainer.state.iteration, + ), + event_name=Events.ITERATION_COMPLETED(every=100), + ) + + wandb_logger.attach( + train_evaluator, + log_handler=OutputHandler( + tag="training", metric_names=["loss", "accuracy"], global_step_transform=lambda *_: trainer.state.iteration + ), + event_name=Events.EPOCH_COMPLETED, + ) + + wandb_logger.attach( + validation_evaluator, + log_handler=OutputHandler( + tag="validation", + metric_names=["loss", "accuracy"], + global_step_transform=lambda *_: trainer.state.iteration, + ), + event_name=Events.EPOCH_COMPLETED, + ) + + wandb_logger.attach( + trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100) + ) + wandb_logger.watch(model, log="all") + + def score_function(engine): + return engine.state.metrics["accuracy"] + + model_checkpoint = ModelCheckpoint( + wandb_logger.run.dir, + n_saved=2, + filename_prefix="best", + score_function=score_function, + score_name="validation_accuracy", + global_step_transform=lambda *_: trainer.state.iteration, + ) + validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model}) + + # kick everything off + trainer.run(train_loader, max_epochs=epochs) + wandb_logger.close() + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training (default: 64)") + parser.add_argument( + "--val_batch_size", type=int, default=1000, help="input batch size for validation (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train (default: 10)") + parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)") + parser.add_argument("--momentum", type=float, default=0.5, help="SGD momentum (default: 0.5)") + + args = parser.parse_args() + + # Setup engine logger + logger = logging.getLogger("ignite.engine.engine.Engine") + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum) diff --git a/ignite/contrib/handlers/__init__.py b/ignite/contrib/handlers/__init__.py index 4365e45b95d..e876cff8819 100644 --- a/ignite/contrib/handlers/__init__.py +++ b/ignite/contrib/handlers/__init__.py @@ -15,5 +15,6 @@ from ignite.contrib.handlers.visdom_logger import VisdomLogger from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger from ignite.contrib.handlers.mlflow_logger import MLflowLogger +from ignite.contrib.handlers.wandb_logger import WandBLogger from ignite.contrib.handlers.base_logger import global_step_from_engine from ignite.contrib.handlers.lr_finder import FastaiLRFinder diff --git a/ignite/contrib/handlers/wandb_logger.py b/ignite/contrib/handlers/wandb_logger.py new file mode 100644 index 00000000000..6b191d7ce14 --- /dev/null +++ b/ignite/contrib/handlers/wandb_logger.py @@ -0,0 +1,284 @@ +from ignite.contrib.handlers.base_logger import ( + BaseLogger, + BaseOutputHandler, + BaseOptimizerParamsHandler, + 
+    global_step_from_engine,
+)
+
+
+__all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"]
+
+
+class OutputHandler(BaseOutputHandler):
+    """Helper handler to log engine's output and/or metrics.
+
+    Examples:
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.wandb_logger import *
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
+            # each epoch. We set up `global_step_transform=global_step_from_engine(trainer)` to take the epoch
+            # of the `trainer`:
+            wandb_logger.attach(evaluator,
+                                log_handler=OutputHandler(tag="validation",
+                                                          metric_names=["nll", "accuracy"],
+                                                          global_step_transform=global_step_from_engine(trainer)),
+                                event_name=Events.EPOCH_COMPLETED)
+
+        Another example, where the model is evaluated every 500 iterations:
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.wandb_logger import *
+
+            @trainer.on(Events.ITERATION_COMPLETED(every=500))
+            def evaluate(engine):
+                evaluator.run(validation_set, max_epochs=1)
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            def global_step_transform(*args, **kwargs):
+                return trainer.state.iteration
+
+            # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
+            # every 500 iterations. Since the evaluator engine does not have access to the training iteration, we
+            # provide a global_step_transform that returns trainer.state.iteration as the global_step, each time
+            # evaluator metrics are plotted on Weights & Biases.
+
+            wandb_logger.attach(evaluator,
+                                log_handler=OutputHandler(tag="validation",
+                                                          metric_names=["nll", "accuracy"],
+                                                          global_step_transform=global_step_transform),
+                                event_name=Events.EPOCH_COMPLETED)
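+
+        A further sketch: with ``sync=False``, calls to `wandb.log` are handed off to a separate
+        thread instead of blocking the training loop (`trainer` and `wandb_logger` as created above):
+
+        .. code-block:: python
+
+            wandb_logger.attach(trainer,
+                                log_handler=OutputHandler(tag="training",
+                                                          output_transform=lambda loss: {"loss": loss},
+                                                          sync=False),
+                                event_name=Events.ITERATION_COMPLETED)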
+
+    Args:
+        tag (str): common title for all produced plots. For example, 'training'
+        metric_names (list of str, optional): list of metric names to plot, or a string "all" to plot all available
+            metrics.
+        output_transform (callable, optional): output transform function to prepare `engine.state.output` as a number.
+            For example, `output_transform = lambda output: output`.
+            This function can also return a dictionary, e.g. `{'loss': loss1, 'another_loss': loss2}` to label the plot
+            with corresponding keys.
+        another_engine (Engine): Deprecated (see :attr:`global_step_transform`). Another engine to use to provide the
+            value of event. Typically, the user can provide
+            the trainer if this handler is attached to an evaluator, so that it logs the proper trainer
+            epoch/iteration value.
+        global_step_transform (callable, optional): global step transform function to output a desired global step.
+            Input of the function is `(engine, event_name)`. Output of the function should be an integer.
+            Default is None, i.e. the global_step is based on the attached engine. If provided, the
+            function output is used as global_step. To set up the global step from another engine, please use
+            :meth:`~ignite.contrib.handlers.wandb_logger.global_step_from_engine`.
+        sync (bool, optional): if set to False, calls to `wandb.log` are processed in a separate thread instead
+            of blocking the training loop. Default (None) uses the default behaviour of `wandb.log`.
+
+    Note:
+
+        Example of `global_step_transform`:
+
+        .. code-block:: python
+
+            def global_step_transform(engine, event_name):
+                return engine.state.get_event_attrib_value(event_name)
+
+    """
+
+    def __init__(
+        self, tag, metric_names=None, output_transform=None, another_engine=None, global_step_transform=None, sync=None
+    ):
+        super().__init__(tag, metric_names, output_transform, another_engine, global_step_transform)
+        self.sync = sync
+
+    def __call__(self, engine, logger, event_name):
+
+        if not isinstance(logger, WandBLogger):
+            raise RuntimeError("Handler '{}' works only with WandBLogger.".format(self.__class__.__name__))
+
+        global_step = self.global_step_transform(engine, event_name)
+        if not isinstance(global_step, int):
+            raise TypeError(
+                "global_step must be int, got {}."
+                " Please check the output of global_step_transform.".format(type(global_step))
+            )
+
+        metrics = self._setup_output_metrics(engine)
+        if self.tag is not None:
+            metrics = {"{tag}/{name}".format(tag=self.tag, name=name): value for name, value in metrics.items()}
+
+        logger.log(metrics, step=global_step, sync=self.sync)
+
+
+class OptimizerParamsHandler(BaseOptimizerParamsHandler):
+    """Helper handler to log optimizer parameters.
+
+    Examples:
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.wandb_logger import *
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
+            wandb_logger.attach(trainer,
+                                log_handler=OptimizerParamsHandler(optimizer),
+                                event_name=Events.ITERATION_STARTED)
+
+    Args:
+        optimizer (torch.optim.Optimizer): torch optimizer whose parameters to log
+        param_name (str): name of the optimizer parameter to log, e.g. "lr" (learning rate)
+        tag (str, optional): common title for all produced plots. For example, 'generator'
+        sync (bool, optional): if set to False, calls to `wandb.log` are processed in a separate thread instead
+            of blocking the training loop. Default (None) uses the default behaviour of `wandb.log`.
+    """
+
+    def __init__(self, optimizer, param_name="lr", tag=None, sync=None):
+        super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
+        self.sync = sync
+
+    def __call__(self, engine, logger, event_name):
+        if not isinstance(logger, WandBLogger):
+            raise RuntimeError("Handler 'OptimizerParamsHandler' works only with WandBLogger")
+
+        global_step = engine.state.get_event_attrib_value(event_name)
+        tag_prefix = "{}/".format(self.tag) if self.tag else ""
+        params = {
+            "{}{}/group_{}".format(tag_prefix, self.param_name, i): float(param_group[self.param_name])
+            for i, param_group in enumerate(self.optimizer.param_groups)
+        }
+        logger.log(params, step=global_step, sync=self.sync)
+
+
+class WandBLogger(BaseLogger):
+    """
+    Weights & Biases handler to log metrics, model/optimizer parameters and gradients during training and validation.
+    It can also be used to log model checkpoints to the Weights & Biases cloud.
+
+    .. code-block:: bash
+
+        pip install wandb
+
+    This class is also a wrapper for the wandb module. This means that you can call any wandb function using
+    this wrapper. See the examples below for saving model parameters and gradients.
+
+    Args:
+        args, kwargs: Please see wandb.init for documentation of possible parameters.
+
+    Examples:
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.wandb_logger import *
+
+            # Create a logger. All parameters are optional. See documentation
+            # on wandb.init for details.
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            # Attach the logger to the trainer to log training loss at each iteration
+            wandb_logger.attach(trainer,
+                                log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
+                                event_name=Events.ITERATION_COMPLETED)
+
+            # Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
+            # each epoch. We set up `global_step_transform=global_step_from_engine(trainer)` to take the epoch
+            # of the `trainer` instead of `evaluator`.
+            wandb_logger.attach(evaluator,
+                                log_handler=OutputHandler(tag="validation",
+                                                          metric_names=["nll", "accuracy"],
+                                                          global_step_transform=global_step_from_engine(trainer)),
+                                event_name=Events.EPOCH_COMPLETED)
+
+            # Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
+            wandb_logger.attach(trainer,
+                                log_handler=OptimizerParamsHandler(optimizer),
+                                event_name=Events.ITERATION_STARTED)
+
+        If you want to log model gradients, the model call graph, etc., use the logger as a wrapper of wandb. Refer
+        to the documentation of wandb.watch for details:
+
+        .. code-block:: python
+
+            wandb_logger = WandBLogger(
+                entity="shared",
+                project="pytorch-ignite-integration",
+                name="cnn-mnist",
+                config={"max_epochs": 10},
+                tags=["pytorch-ignite", "mnist"]
+            )
+
+            model = torch.nn.Sequential(...)
+            wandb_logger.watch(model)
+
+        For model checkpointing, Weights & Biases creates a local run directory and automatically synchronizes
+        all files saved there at the end of the run. You can use `wandb_logger.run.dir` as the path for
+        `ModelCheckpoint` (pass `require_empty=False`, since the run directory already contains wandb files):
+
+        .. code-block:: python
+
+            from ignite.handlers import ModelCheckpoint
+
+            def score_function(engine):
+                return engine.state.metrics['accuracy']
+
+            model_checkpoint = ModelCheckpoint(
+                wandb_logger.run.dir, n_saved=2, filename_prefix='best',
+                require_empty=False, score_function=score_function,
+                score_name="validation_accuracy",
+                global_step_transform=global_step_from_engine(trainer))
+            evaluator.add_event_handler(
+                Events.COMPLETED, model_checkpoint, {'model': model})
+    """
+
+    def __init__(self, *args, **kwargs):
+        try:
+            import wandb
+
+            self._wandb = wandb
+        except ImportError:
+            raise RuntimeError(
+                "This contrib module requires wandb to be installed. "
" + "You man install wandb with the command:\n pip install wandb\n" + ) + if kwargs.get("init", True): + wandb.init(*args, **kwargs) + + def __getattr__(self, attr): + return getattr(self._wandb, attr) diff --git a/tests/ignite/contrib/handlers/test_wandb_logger.py b/tests/ignite/contrib/handlers/test_wandb_logger.py new file mode 100644 index 00000000000..28173c62242 --- /dev/null +++ b/tests/ignite/contrib/handlers/test_wandb_logger.py @@ -0,0 +1,238 @@ +from unittest.mock import call, MagicMock +import pytest +import torch + +from ignite.engine import Events, State +from ignite.contrib.handlers.wandb_logger import * + + +def test_optimizer_params_handler_wrong_setup(): + with pytest.raises(TypeError): + OptimizerParamsHandler(optimizer=None) + + optimizer = MagicMock(spec=torch.optim.Optimizer) + handler = OptimizerParamsHandler(optimizer=optimizer) + + mock_logger = MagicMock() + mock_engine = MagicMock() + with pytest.raises(RuntimeError, match="Handler 'OptimizerParamsHandler' works only with WandBLogger"): + handler(mock_engine, mock_logger, Events.ITERATION_STARTED) + + +def test_optimizer_params(): + optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01) + wrapper = OptimizerParamsHandler(optimizer=optimizer, param_name="lr") + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.iteration = 123 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"lr/group_0": 0.01}, step=123, sync=None) + + wrapper = OptimizerParamsHandler(optimizer, param_name="lr", tag="generator") + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"generator/lr/group_0": 0.01}, step=123, sync=None) + + +def test_output_handler_with_wrong_logger_type(): + wrapper = OutputHandler("tag", output_transform=lambda x: x) + + mock_logger = MagicMock() + mock_engine = MagicMock() + with pytest.raises(RuntimeError, match="Handler 'OutputHandler' works only with WandBLogger"): + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + +def test_output_handler_output_transform(): + wrapper = OutputHandler("tag", output_transform=lambda x: x) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.output = 12345 + mock_engine.state.iteration = 123 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=None) + + wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x}) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=None) + + +def test_output_handler_output_transform_sync(): + wrapper = OutputHandler("tag", output_transform=lambda x: x, sync=False) + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.output = 12345 + mock_engine.state.iteration = 123 + + wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED) + + mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=False) + + wrapper = OutputHandler("another_tag", 
+
+
+def test_output_handler_with_wrong_logger_type():
+    wrapper = OutputHandler("tag", output_transform=lambda x: x)
+
+    mock_logger = MagicMock()
+    mock_engine = MagicMock()
+    with pytest.raises(RuntimeError, match="Handler 'OutputHandler' works only with WandBLogger"):
+        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+
+
+def test_output_handler_output_transform():
+    wrapper = OutputHandler("tag", output_transform=lambda x: x)
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State()
+    mock_engine.state.output = 12345
+    mock_engine.state.iteration = 123
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+
+    mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=None)
+
+    wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+    mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=None)
+
+
+def test_output_handler_output_transform_sync():
+    wrapper = OutputHandler("tag", output_transform=lambda x: x, sync=False)
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State()
+    mock_engine.state.output = 12345
+    mock_engine.state.iteration = 123
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+
+    mock_logger.log.assert_called_once_with({"tag/output": 12345}, step=123, sync=False)
+
+    wrapper = OutputHandler("another_tag", output_transform=lambda x: {"loss": x}, sync=True)
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+    mock_logger.log.assert_called_once_with({"another_tag/loss": 12345}, step=123, sync=True)
+
+
+def test_output_handler_metric_names():
+    wrapper = OutputHandler("tag", metric_names=["a", "b"])
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State(metrics={"a": 1, "b": 5})
+    mock_engine.state.iteration = 5
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+    mock_logger.log.assert_called_once_with({"tag/a": 1, "tag/b": 5}, step=5, sync=None)
+
+    wrapper = OutputHandler("tag", metric_names=["a", "c"])
+    mock_engine = MagicMock()
+    mock_engine.state = State(metrics={"a": 55.56, "c": "Some text"})
+    mock_engine.state.iteration = 7
+
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+    mock_logger.log.assert_called_once_with({"tag/a": 55.56, "tag/c": "Some text"}, step=7, sync=None)
+
+    # all metrics
+    wrapper = OutputHandler("tag", metric_names="all")
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
+    mock_engine.state.iteration = 5
+
+    wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
+    mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45}, step=5, sync=None)
+
+
+def test_output_handler_both():
+    wrapper = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State(metrics={"a": 12.23, "b": 23.45})
+    mock_engine.state.epoch = 5
+    mock_engine.state.output = 12345
+
+    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
+
+    mock_logger.log.assert_called_once_with({"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}, step=5, sync=None)
+
+
+def test_output_handler_with_wrong_global_step_transform_output():
+    def global_step_transform(*args, **kwargs):
+        return "a"
+
+    wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform)
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State()
+    mock_engine.state.epoch = 5
+    mock_engine.state.output = 12345
+
+    with pytest.raises(TypeError, match="global_step must be int"):
+        wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
+
+
+def test_output_handler_with_global_step_transform():
+    def global_step_transform(*args, **kwargs):
+        return 10
+
+    wrapper = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=global_step_transform)
+    mock_logger = MagicMock(spec=WandBLogger)
+    mock_logger.log = MagicMock()
+
+    mock_engine = MagicMock()
+    mock_engine.state = State()
+    mock_engine.state.epoch = 5
+    mock_engine.state.output = 12345
+
+    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
+    mock_logger.log.assert_called_once_with({"tag/loss": 12345}, step=10, sync=None)
"tag", + output_transform=lambda x: {"loss": x}, + global_step_transform=global_step_from_engine(mock_another_engine), + ) + + mock_logger = MagicMock(spec=WandBLogger) + mock_logger.log = MagicMock() + + mock_engine = MagicMock() + mock_engine.state = State() + mock_engine.state.epoch = 1 + mock_engine.state.output = 0.123 + + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + mock_logger.log.assert_called_once_with( + {"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None + ) + + mock_another_engine.state.epoch = 11 + mock_engine.state.output = 1.123 + + wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED) + assert mock_logger.log.call_count == 2 + mock_logger.log.assert_has_calls( + [call({"tag/loss": mock_engine.state.output}, step=mock_another_engine.state.epoch, sync=None)] + ) + + +@pytest.fixture +def no_site_packages(): + import sys + + wandb_client_modules = {} + for k in sys.modules: + if "wandb" in k: + wandb_client_modules[k] = sys.modules[k] + for k in wandb_client_modules: + del sys.modules[k] + + prev_path = list(sys.path) + sys.path = [p for p in sys.path if "site-packages" not in p] + yield "no_site_packages" + sys.path = prev_path + for k in wandb_client_modules: + sys.modules[k] = wandb_client_modules[k] + + +def test_no_wandb_client(no_site_packages): + + with pytest.raises(RuntimeError, match=r"This contrib module requires wandb to be installed."): + WandBLogger()