Wandb logger #926

Merged (14 commits) on Apr 22, 2020
6 changes: 6 additions & 0 deletions docs/source/contrib/handlers.rst
@@ -85,6 +85,12 @@ polyaxon_logger
:members:
:inherited-members:

wandb_logger
------------

.. automodule:: ignite.contrib.handlers.wandb_logger
:members:
:inherited-members:
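
A minimal sketch of attaching the logger to an existing engine (the ``trainer`` below is assumed to be an :class:`~ignite.engine.engine.Engine` you have already created; see the full MNIST example in this PR):

.. code-block:: python

    from ignite.contrib.handlers.wandb_logger import WandBLogger, OutputHandler
    from ignite.engine import Events

    wandb_logger = WandBLogger(project="my-project")
    wandb_logger.attach(
        trainer,
        log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}),
        event_name=Events.ITERATION_COMPLETED,
    )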

More on parameter scheduling
----------------------------
175 changes: 175 additions & 0 deletions examples/contrib/mnist/mnist_with_wandb_logger.py
@@ -0,0 +1,175 @@
"""
MNIST example with training and validation monitoring using Weights & Biases

Requirements:
Weights & Biases: `pip install wandb`

Usage:

Make sure you are logged into Weights & Biases:
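```bash
wandb login
```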

Run the example:
```bash
python mnist_with_wandb_logger.py
```

Go to https://wandb.com and explore your experiment.
"""
import logging
from argparse import ArgumentParser

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint

from ignite.contrib.handlers.wandb_logger import WandBLogger, OutputHandler, OptimizerParamsHandler


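# A small two-conv-layer CNN for MNIST; forward() returns log-probabilities (log_softmax).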
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=-1)


def get_data_loaders(train_batch_size, val_batch_size):
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True
)

val_loader = DataLoader(
MNIST(download=False, root=".", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False
)
return train_loader, val_loader


def run(train_batch_size, val_batch_size, epochs, lr, momentum):
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
model = Net()
device = "cpu"

if torch.cuda.is_available():
device = "cuda"

optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.NLLLoss()  # the model already applies log_softmax, so NLLLoss (not CrossEntropyLoss) is the matching loss
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

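    # Run both evaluators after every epoch so their metrics are current when logged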
@trainer.on(Events.EPOCH_COMPLETED)
def compute_metrics(engine):
train_evaluator.run(train_loader)
validation_evaluator.run(val_loader)

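    # Initialize the W&B run; everything in `config` shows up on the run's config page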
wandb_logger = WandBLogger(
project="pytorch-ignite-integration",
name="ignite-mnist-example",
config={
"train_batch_size": train_batch_size,
"val_batch_size": val_batch_size,
"epochs": epochs,
"lr": lr,
"momentum": momentum,
},
)

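    # Log the batch loss (and any attached metrics) on the trainer every 100 iterations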
wandb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="training",
output_transform=lambda loss: {"batchloss": loss},
metric_names="all",
global_step_transform=lambda *_: trainer.state.iteration,
),
event_name=Events.ITERATION_COMPLETED(every=100),
)

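    # Log train-set loss/accuracy after each epoch, keyed to the trainer's iteration count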
wandb_logger.attach(
train_evaluator,
log_handler=OutputHandler(
tag="training", metric_names=["loss", "accuracy"], global_step_transform=lambda *_: trainer.state.iteration
),
event_name=Events.EPOCH_COMPLETED,
)

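    # Log validation loss/accuracy after each epoch, on the same global-step axis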
wandb_logger.attach(
validation_evaluator,
log_handler=OutputHandler(
tag="validation",
metric_names=["loss", "accuracy"],
global_step_transform=lambda *_: trainer.state.iteration,
),
event_name=Events.EPOCH_COMPLETED,
)

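    # Track the optimizer's learning rate every 100 iterations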
wandb_logger.attach(
trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_COMPLETED(every=100)
)
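    # Ask W&B to record the model's gradients and parameter histograms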
wandb_logger.watch(model, log="all")

def score_function(engine):
return engine.state.metrics["accuracy"]

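    # Keep the two best checkpoints (by validation accuracy) in the W&B run directory so they are synced with the run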
model_checkpoint = ModelCheckpoint(
wandb_logger.run.dir,
n_saved=2,
filename_prefix="best",
score_function=score_function,
score_name="validation_accuracy",
global_step_transform=lambda *_: trainer.state.iteration,
)
validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

# kick everything off
trainer.run(train_loader, max_epochs=epochs)
wandb_logger.close()


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training (default: 64)")
parser.add_argument(
"--val_batch_size", type=int, default=1000, help="input batch size for validation (default: 1000)"
)
parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train (default: 10)")
parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)")
parser.add_argument("--momentum", type=float, default=0.5, help="SGD momentum (default: 0.5)")

args = parser.parse_args()

    # Set up logging for ignite's Engine so progress is printed to stderr
logger = logging.getLogger("ignite.engine.engine.Engine")
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum)
1 change: 1 addition & 0 deletions ignite/contrib/handlers/__init__.py
@@ -15,5 +15,6 @@
from ignite.contrib.handlers.visdom_logger import VisdomLogger
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
from ignite.contrib.handlers.wandb_logger import WandBLogger
from ignite.contrib.handlers.base_logger import global_step_from_engine
from ignite.contrib.handlers.lr_finder import FastaiLRFinder