-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Weights and Biases Callback for Elegy (#220)
* feature: added wandb callback * chore: added WandbCallback import inside __init__.py * chore: applied black * chore: added documentation for WandbCallback * chore: updated wandb run initialization + added run finish on ending training * updated poetry dependencies * updated run initialization in wandb callback * fixed wandb import * fixed on_train_end in wandb callback * added 'size' to ignore fields + added 'train' alias to train metrics * fix: wandb run keys * fix: updated run key names * minor bug fix * feature: instead of logging constant metrics step-wise, logs them to the run summary * updated wandb run summary * updated configs * updated callback * updated callback * added method to gather module attributes * fix: fixed minor bug in gathering model attributes * updated callback example * added train alias * Minor Bug Fix * updated wandb callback signature * fixed typo * Added model checkpointing * updated checkpoint system * updated checkpoint logic * updated checkpoint system * made model checkpoint saving optional * updated checkpoint system to save artifacts every epoch * updated checkpoint system * updated checkpoint system * updated checkpoint system * updated checkpoint system * update deps * update lock Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
- Loading branch information
1 parent
3914938
commit c6ad092
Showing
4 changed files
with
1,493 additions
and
990 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
# Implementation based on tf.keras.callbacks.py and elegy.callbacks.TensorBoard | ||
# https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/callbacks.py | ||
# https://github.com/poets-ai/elegy/blob/master/elegy/callbacks/tensorboard.py | ||
|
||
|
||
import math | ||
import wandb | ||
from typing import Union, Optional, Dict | ||
|
||
from .callback import Callback | ||
|
||
|
||
class WandbCallback(Callback):
    """
    Callback that streams epoch results to a [Weights & Biases](https://wandb.ai/) run.

    ```python
    wandb.login()
    wandb_logger = WandbCallback(project="sample-wandb-project", job_type="train")
    model.fit(X_train, Y_train, callbacks=[wandb_logger])
    ```
    """

    def __init__(
        self,
        project: Optional[str] = None,
        name: Optional[str] = None,
        entity: Optional[str] = None,
        job_type: Optional[str] = None,
        config: Union[Dict, str, None] = None,
        update_freq: Union[str, int] = "epoch",
        save_model: bool = False,
        monitor: str = "val_loss",
        mode: str = "min",
        **kwargs,
    ):
        """
        Arguments:
            project: (str, optional) The name of the project where you're sending the new run.
                If the project is not specified, the run is put in an "Uncategorized" project.
            name: (str, optional) A short display name for this run, which is how you'll
                identify this run in the UI. By default we generate a random two-word name that
                lets you easily cross-reference runs from the table to charts. Keeping these run
                names short makes the chart legends and tables easier to read. If you're looking
                for a place to save your hyperparameters, we recommend saving those in config.
            entity: (str, optional) An entity is a username or team name where you're sending runs.
                This entity must exist before you can send runs there, so make sure to create your
                account or team in the UI before starting to log runs. If you don't specify an entity,
                the run will be sent to your default entity, which is usually your username.
            job_type: (str, optional) Specify the type of run, which is useful when you're grouping
                runs together into larger experiments using group. For example, you might have multiple
                jobs in a group, with job types like train and eval. Setting this makes it easy to
                filter and group similar runs together in the UI so you can compare apples to apples.
            config: (dict, argparse, absl.flags, str, optional) This sets `wandb.config`, a dictionary-like
                object for saving inputs to your job, like hyperparameters for a model or settings for a
                data preprocessing job. The config will show up in a table in the UI that you can use to
                group, filter, and sort runs. Keys should not contain . in their names, and values should
                be under 10 MB. If dict, argparse or `absl.flags`: will load the key value pairs into the
                wandb.config object. If str: will look for a yaml file by that name, and load config from
                that file into the `wandb.config` object.
            update_freq: (str, int) `'batch'` or `'epoch'` or integer. When using `'batch'`, writes the
                losses and metrics to W&B after each batch. The same applies for `'epoch'`. If
                using an integer, let's say `1000`, the callback will write the metrics and losses
                every 1000 batches. Note that writing too frequently can slow down your training.
            save_model: (bool) Save a model when monitor beats all previous epochs if set to `True`,
                otherwise don't save models.
            monitor: (str) name of metric to monitor. Defaults to `'val_loss'`.
            mode: (str) one of {`'min'`, `'max'`, `'every'`}. `'min'` - save model when monitor is
                minimized. `'max'` - save model when monitor is maximized. `'every'` - save model
                after every epoch.

        Raises:
            ValueError: if `update_freq` is a string other than `'batch'` or `'epoch'`
                and cannot be parsed as an integer.
        """
        super().__init__()
        # Reuse an already-active run (e.g. one started by the user before
        # training) instead of spawning a second run.
        self.run = (
            wandb.init(
                project=project,
                name=name,
                entity=entity,
                job_type=job_type,
                config=config,
                **kwargs,
            )
            if wandb.run is None
            else wandb.run
        )
        self.keys = None
        self.write_per_batch = True
        # Fields whose value never changes during training; they are stripped
        # from per-step logs and written once to the run summary at the end.
        self._constant_fields = ["size"]
        self._constants = {}
        self._save_model = save_model
        self._monitor = monitor
        self._mode = mode
        # Best value of the monitored metric seen so far.
        self._monitor_metric_val = math.inf if mode == "min" else -math.inf
        self._best_epoch = 0
        # FIX: was f"model-best-0" — an f-string with no placeholders.
        self._model_path = "model-best-0"
        # `self.model` is attached later via `set_model`; don't assume it
        # exists at construction time.
        self._model_checkpoint = getattr(self, "model", None)
        try:
            self.update_freq = int(update_freq)
        except ValueError as e:
            self.update_freq = 1
            if update_freq == "batch":
                self.write_per_batch = True
            elif update_freq == "epoch":
                self.write_per_batch = False
            else:
                raise e

    def _extract_constants(self, logs: Dict) -> None:
        """Pop constant fields (and their `val_` variants) out of `logs`,
        remembering their latest values for the final run summary.

        FIX: the original indexed `logs[key]` unconditionally, raising
        KeyError whenever the field was absent — absence is clearly
        anticipated since the pops right after use a default.
        """
        for key in self._constant_fields:
            if key in logs:
                self._constants[key] = logs[key]
            logs.pop(key, None)
            logs.pop("val_" + key, None)

    def _log_metrics(self, logs: Dict, step: int) -> None:
        """Log every metric in `logs` to the run at `step`, prefixing
        training metrics (anything not already `val_`-prefixed) with
        `train_` so they group separately in the W&B UI."""
        for key in logs:
            log_key = key if key.startswith("val_") else "train_" + key
            self.run.log({log_key: logs[key]}, step=step)

    def _gather_configs(self):
        """Copy primitive (str/int) attributes of the model's module into the
        run config so hyperparameters show up in the W&B UI."""
        module_attributes = vars(vars(self.model)["module"])
        for attr_name, attr_value in module_attributes.items():
            if isinstance(attr_value, (str, int)):
                self.run.config[attr_name] = attr_value

    def _add_model_as_artifact(self, model_path: str, epoch: int):
        """Upload the checkpoint directory at `model_path` as a versioned
        W&B model artifact tagged with the epoch it came from."""
        artifact = wandb.Artifact(
            f"model-{self.run.name}",
            type="model",
            metadata={"epoch": epoch, "model_path": model_path},
        )
        artifact.add_dir(model_path)
        self.run.log_artifact(artifact)

    def on_train_begin(self, logs=None):
        self.steps = self.params["steps"]
        self.global_step = 0
        self._gather_configs()

    def on_train_batch_end(self, batch: int, logs=None):
        if not self.write_per_batch:
            return
        logs = logs or {}
        self._extract_constants(logs)
        self.global_step = batch + self.current_epoch * self.steps
        if self.global_step % self.update_freq == 0:
            if self.keys is None:
                # FIX: snapshot the keys — the original stored a live
                # dict view bound to the first batch's logs dict and then
                # indexed later (different) logs dicts with it, which can
                # raise KeyError. Logging the current dict's own keys is
                # both safe and equivalent in the steady state.
                self.keys = list(logs.keys())
            self._log_metrics(logs, self.global_step)

    def on_epoch_begin(self, epoch: int, logs=None):
        self.current_epoch = epoch

    def on_epoch_end(self, epoch: int, logs=None):
        logs = logs or {}
        self._extract_constants(logs)

        if self.keys is None:
            self.keys = list(logs.keys())

        if self.write_per_batch:
            self._log_metrics(logs, self.global_step)
        elif epoch % self.update_freq == 0:
            self._log_metrics(logs, epoch)

        if not self._save_model:
            return

        saved = False
        if self._mode == "every":
            self._best_epoch = epoch
            self._model_path = f"model-{epoch + 1}-{self.run.name}"
            print(f"Saving Model at {self._model_path}")
            saved = True
        # FIX: guard against the monitored metric being absent from `logs`
        # (the original raised KeyError in that case).
        elif self._monitor in logs:
            current = logs[self._monitor]
            if self._mode == "min" and current < self._monitor_metric_val:
                self._best_epoch = epoch
                self._model_path = f"model-best-{epoch + 1}-{self.run.name}"
                print(
                    f"{self._monitor} decreased at epoch {epoch}. Saving Model at {self._model_path}"
                )
                self._monitor_metric_val = current
                saved = True
            elif self._mode == "max" and current > self._monitor_metric_val:
                self._best_epoch = epoch
                self._model_path = f"model-best-{epoch + 1}-{self.run.name}"
                print(
                    f"{self._monitor} increased at epoch {epoch}. Saving Model at {self._model_path}"
                )
                self._monitor_metric_val = current
                saved = True

        if saved:
            self._model_checkpoint = self.model
            self._model_checkpoint.save(self._model_path)
            # FIX: only upload an artifact when a checkpoint was actually
            # written this epoch. The original uploaded every epoch, which
            # re-logged stale checkpoints and crashed (`add_dir` on a
            # nonexistent path) before the first improvement.
            self._add_model_as_artifact(self._model_path, epoch=self._best_epoch)

    def on_train_end(self, logs=None):
        # Constant fields go to the run summary once, not as a time series.
        # FIX: guard the lookup — `self._constants` stays empty when the
        # field never appeared in any logs dict.
        for key in self._constant_fields:
            if key in self._constants:
                self.run.summary[key] = self._constants[key]
        self.run.finish()
Oops, something went wrong.