Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weights and Biases Callback for Elegy #220

Merged
merged 39 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6f2be70
feature: added wandb callback
soumik12345 Feb 14, 2022
3ba4ce3
chore: added WandbCallback import inside __init__.py
soumik12345 Feb 14, 2022
08808ea
chore: applied black
soumik12345 Feb 14, 2022
90c6b21
chore: added documentation for WandbCallback
soumik12345 Feb 14, 2022
1d7cedf
chore: updated wandb run initialization + added run finish on ending …
soumik12345 Feb 17, 2022
cb7fcf1
updated poetry dependencies
soumik12345 Feb 17, 2022
b7ff92c
updated run initialization in wandb callback
soumik12345 Feb 20, 2022
ebb945e
fixed wandb import
soumik12345 Feb 20, 2022
2d1210a
fixed on_train_end in wandb callback
soumik12345 Feb 20, 2022
aac79bd
added 'size' to ignore fields + added 'train' alias to train metrics
soumik12345 Feb 22, 2022
f423035
fix: wandb run keys
soumik12345 Feb 22, 2022
dfc2a3a
fix: updated run key names
soumik12345 Feb 22, 2022
cad4ae5
minor bug fix
soumik12345 Feb 22, 2022
90bacff
feature: instead of logging constant metrics step-wise, logs them to …
soumik12345 Feb 22, 2022
db07abe
updated wandb run summary
soumik12345 Feb 22, 2022
40609df
updated configs
soumik12345 Feb 22, 2022
5e08aaf
updated callback
soumik12345 Feb 22, 2022
03bb72e
updated callback
soumik12345 Feb 22, 2022
d1abc88
added method to gather module attributes
soumik12345 Feb 22, 2022
0cf6479
fix: fixed minor bug in gathering model attributes
soumik12345 Feb 22, 2022
78d35e1
updated callback example
soumik12345 Feb 22, 2022
b21f38b
added train alias
soumik12345 Feb 22, 2022
3c8babc
Minor Bug Fix
soumik12345 Feb 22, 2022
10e177a
updated wandb callback signature
soumik12345 Feb 22, 2022
3f112b3
fixed typo
soumik12345 Feb 22, 2022
4d08a48
Added model checkpointing
soumik12345 Feb 22, 2022
8929139
updated checkpoint system
soumik12345 Feb 22, 2022
494942e
updated checkpoint logic
soumik12345 Feb 23, 2022
9522109
updated checkpoint system
soumik12345 Feb 23, 2022
f7dca2f
Merge pull request #1 from soumik12345/checkpoints
soumik12345 Feb 23, 2022
81c9383
made model checkpoint saving optional
soumik12345 Feb 23, 2022
9b609c9
updated checkpoint system to save artifacts every epoch
soumik12345 Feb 23, 2022
d5837d0
updated checkpoint system
soumik12345 Feb 23, 2022
f95db76
updated checkpoint system
soumik12345 Feb 23, 2022
9c4c42e
updated checkpoint system
soumik12345 Feb 23, 2022
03a57e9
updated checkpoint system
soumik12345 Feb 23, 2022
1cd7cfa
update deps
cgarciae Mar 23, 2022
af1ec99
Merge branch 'master' into pr/soumik12345/220
cgarciae Mar 23, 2022
8d45c2b
update lock
cgarciae Mar 23, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions elegy/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .sigint import SigInt
from .tensorboard import TensorBoard
from .terminate_nan import TerminateOnNaN
from .wandb_callback import WandbCallback

__all__ = [
"CallbackList",
Expand All @@ -21,4 +22,5 @@
"RemoteMonitor",
"CSVLogger",
"TensorBoard",
"WandbCallback"
]
211 changes: 211 additions & 0 deletions elegy/callbacks/wandb_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Implementation based on tf.keras.callbacks.py and elegy.callbacks.TensorBoard
# https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py
# https://github.com/poets-ai/elegy/blob/master/elegy/callbacks/tensorboard.py


import math
import wandb
from typing import Union, Optional, Dict

from .callback import Callback


class WandbCallback(Callback):
    """
    Callback that streams epoch results to a [Weights & Biases](https://wandb.ai/) run.

    ```python
    wandb.login()
    wandb_logger = WandbCallback(project="sample-wandb-project", job_type="train")
    model.fit(X_train, Y_train, callbacks=[wandb_logger])
    ```
    """

    def __init__(
        self,
        project: Optional[str] = None,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        job_type: Optional[str] = None,
        config: Union[Dict, str, None] = None,
        update_freq: Union[str, int] = "epoch",
        save_model: bool = False,
        monitor: str = "val_loss",
        mode: str = "min",
        **kwargs,
    ):
        """
        Arguments:
            project: (str, optional) The name of the project where you're sending the new run.
                If the project is not specified, the run is put in an "Uncategorized" project.
            name: (str, optional) A short display name for this run, which is how you'll
                identify this run in the UI. By default we generate a random two-word name that
                lets you easily cross-reference runs from the table to charts. Keeping these run
                names short makes the chart legends and tables easier to read. If you're looking
                for a place to save your hyperparameters, we recommend saving those in config.
            entity: (str, optional) An entity is a username or team name where you're sending runs.
                This entity must exist before you can send runs there, so make sure to create your
                account or team in the UI before starting to log runs. If you don't specify an entity,
                the run will be sent to your default entity, which is usually your username.
            job_type: (str, optional) Specify the type of run, which is useful when you're grouping
                runs together into larger experiments using group. For example, you might have multiple
                jobs in a group, with job types like train and eval. Setting this makes it easy to
                filter and group similar runs together in the UI so you can compare apples to apples.
            config: (dict, argparse, absl.flags, str, optional) This sets `wandb.config`, a dictionary-like
                object for saving inputs to your job, like hyperparameters for a model or settings for a
                data preprocessing job. The config will show up in a table in the UI that you can use to
                group, filter, and sort runs. Keys should not contain . in their names, and values should
                be under 10 MB. If dict, argparse or `absl.flags`: will load the key value pairs into the
                wandb.config object. If str: will look for a yaml file by that name, and load config from
                that file into the `wandb.config` object.
            update_freq: (str, int) `'batch'` or `'epoch'` or integer. When using `'batch'`, writes the
                losses and metrics to the W&B run after each batch. The same applies for `'epoch'`. If
                using an integer, let's say `1000`, the callback will write the metrics and losses to
                the W&B run every 1000 batches. Note that logging too frequently can slow down your
                training.
            save_model: (bool) Save a model when monitor beats all previous epochs if set to `True` otherwise
                don't save models.
            monitor: (str) name of metric to monitor. Defaults to `'val_loss'`.
            mode: (str) one of {`'min'`, `'max'`, `'every'`}. `'min'` - save model when monitor is minimized.
                `'max'` - save model when monitor is maximized. `'every'` - save model after every epoch.
        """
        super().__init__()
        # Reuse an already-active run (e.g. one started manually by the user)
        # instead of creating a second, nested run.
        self.run = (
            wandb.init(
                project=project,
                name=name,
                entity=entity,
                job_type=job_type,
                config=config,
                **kwargs,
            )
            if wandb.run is None
            else wandb.run
        )
        self.keys = None
        self.write_per_batch = True
        # Fields that stay constant across steps (e.g. batch "size"); they are
        # stripped from per-step logs and written once to run.summary at the end.
        self._constant_fields = ["size"]
        self._constants = {}
        self._save_model = save_model
        self._monitor = monitor
        self._mode = mode
        # Best-so-far value of the monitored metric; +/-inf so the first
        # observed value always counts as an improvement.
        self._monitor_metric_val = math.inf if mode == "min" else -math.inf
        self._best_epoch = 0
        self._model_path = "model-best-0"  # placeholder until a checkpoint is saved
        # `self.model` is attached later by the framework (via set_model), so it
        # may not exist yet at construction time — guard the lookup.
        self._model_checkpoint = getattr(self, "model", None)
        try:
            self.update_freq = int(update_freq)
        except ValueError as e:
            self.update_freq = 1
            if update_freq == "batch":
                self.write_per_batch = True
            elif update_freq == "epoch":
                self.write_per_batch = False
            else:
                raise e

    def _gather_configs(self):
        """Copy simple (str/int) attributes of the wrapped module into wandb.config."""
        module_attributes = vars(vars(self.model)["module"])
        for _var in module_attributes:
            # Deliberately an exact type check (not isinstance) so that e.g.
            # bools and str/int subclasses are not swept into the config.
            if type(module_attributes[_var]) in (str, int):
                wandb.run.config[_var] = module_attributes[_var]

    def _add_model_as_artifact(self, model_path: str, epoch: int):
        """Upload the saved checkpoint directory at `model_path` as a W&B model artifact."""
        artifact = wandb.Artifact(
            f"model-{self.run.name}",
            type="model",
            metadata={"epoch": epoch, "model_path": model_path},
        )
        artifact.add_dir(model_path)
        self.run.log_artifact(artifact)

    def _extract_constants(self, logs: Dict):
        """Remove constant fields (and their `val_` twins) from `logs` in place,
        remembering their latest values for the end-of-training summary."""
        for key in self._constant_fields:
            # Guarded lookup: "size" may be absent from logs (e.g. in
            # validation-only logs) — the old unguarded access raised KeyError.
            if key in logs:
                self._constants[key] = logs[key]
            logs.pop(key, None)
            logs.pop("val_" + key, None)

    def _log_metrics(self, logs: Dict, step: int):
        """Log every entry of `logs` to the run at `step`, prefixing
        training metrics (anything not already `val_`-prefixed) with `train_`."""
        # Iterate the logs actually received this call rather than a key set
        # frozen on the first call — avoids KeyError when keys vary per step.
        for key in logs:
            log_key = key if key.startswith("val_") else "train_" + key
            self.run.log({log_key: logs[key]}, step=step)

    def _save_checkpoint(self, epoch: int, model_path: str, message: str):
        """Save the current model to `model_path` and record it as the best epoch."""
        self._best_epoch = epoch
        self._model_path = model_path
        print(message)
        self._model_checkpoint = self.model
        self._model_checkpoint.save(self._model_path)

    def on_train_begin(self, logs=None):
        # NOTE(review): assumes the framework populates params["steps"] before
        # training starts — confirm against elegy's Callback contract.
        self.steps = self.params["steps"]
        self.global_step = 0
        self._gather_configs()

    def on_train_batch_end(self, batch: int, logs=None):
        if not self.write_per_batch:
            return
        logs = logs or {}
        self._extract_constants(logs)
        self.global_step = batch + self.current_epoch * self.steps
        if self.global_step % self.update_freq == 0:
            if self.keys is None:
                self.keys = logs.keys()
            self._log_metrics(logs, self.global_step)

    def on_epoch_begin(self, epoch: int, logs=None):
        self.current_epoch = epoch

    def on_epoch_end(self, epoch: int, logs=None):
        logs = logs or {}
        self._extract_constants(logs)

        if self.keys is None:
            self.keys = logs.keys()

        if self.write_per_batch:
            # Per-batch mode: epoch-end logs are written at the current global step.
            self._log_metrics(logs, self.global_step)
        elif epoch % self.update_freq == 0:
            self._log_metrics(logs, epoch)

        if not self._save_model:
            return

        # Guarded: the monitored metric (default "val_loss") may be missing,
        # e.g. when fitting without validation data — previously a KeyError.
        metric = logs.get(self._monitor)
        saved = False
        if self._mode == "every":
            self._save_checkpoint(
                epoch,
                f"model-{epoch + 1}-{self.run.name}",
                f"Saving Model at model-{epoch + 1}-{self.run.name}",
            )
            saved = True
        elif self._mode == "min" and metric is not None and metric < self._monitor_metric_val:
            path = f"model-best-{epoch + 1}-{self.run.name}"
            self._save_checkpoint(
                epoch,
                path,
                f"{self._monitor} decreased at epoch {epoch}. Saving Model at {path}",
            )
            self._monitor_metric_val = metric
            saved = True
        elif self._mode == "max" and metric is not None and metric > self._monitor_metric_val:
            path = f"model-best-{epoch + 1}-{self.run.name}"
            self._save_checkpoint(
                epoch,
                path,
                f"{self._monitor} increased at epoch {epoch}. Saving Model at {path}",
            )
            self._monitor_metric_val = metric
            saved = True

        # Only upload an artifact when a checkpoint was actually written this
        # epoch; the old code re-uploaded the stale directory every epoch and
        # crashed on the nonexistent placeholder path if no save ever happened.
        if saved:
            self._add_model_as_artifact(self._model_path, epoch=self._best_epoch)

    def on_train_end(self, logs=None):
        # Write the captured constant fields once to the run summary, then
        # close the run. Iterating the captured dict avoids a KeyError when a
        # constant field was never observed during training.
        for key, value in self._constants.items():
            self.run.summary[key] = value
        self.run.finish()
Loading