Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Artifact-based Checkpoint Storage #1

Merged
merged 7 commits into from
Feb 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions elegy/callbacks/wandb_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://github.com/poets-ai/elegy/blob/master/elegy/callbacks/tensorboard.py


import math
import wandb
from typing import Union, Optional, Dict

Expand All @@ -28,6 +29,8 @@ def __init__(
job_type: Optional[str] = None,
config: Union[Dict, str, None] = None,
update_freq: Union[str, int] = "epoch",
monitor: str = "val_loss",
mode: str = "min",
**kwargs
):
"""
Expand Down Expand Up @@ -58,7 +61,10 @@ def __init__(
losses and metrics to TensorBoard after each batch. The same applies for `'epoch'`. If
using an integer, let's say `1000`, the callback will write the metrics and losses to
TensorBoard every 1000 batches. Note that writing too frequently to TensorBoard can slow
down your training.
down your training.
monitor: (str) name of metric to monitor. Defaults to `'val_loss'`.
mode: (str) one of {`'min'`, `'max'`}. `'min'` - save model when monitor is minimized.
`'max'` - save model when monitor is maximized. Defaults to `'min'`.
"""
super().__init__()
self.run = wandb.init(
Expand All @@ -73,6 +79,10 @@ def __init__(
self.write_per_batch = True
self._constant_fields = ["size"]
self._constants = {}
self._monitor = monitor
self._mode = mode
self._monitor_metric_val = math.inf if mode == "min" else -math.inf
self._model_path = f"model-best-0"
try:
self.update_freq = int(update_freq)
except ValueError as e:
Expand All @@ -89,6 +99,11 @@ def _gather_configs(self):
for _var in module_attributes:
if type(module_attributes[_var]) == str or type(module_attributes[_var]) == int:
wandb.run.config[_var] = module_attributes[_var]

def _add_model_as_artifact(self, model_path: str):
    """Upload the saved model directory at *model_path* to W&B as a 'model' artifact.

    The artifact is named after the directory path and logged to the active run.
    NOTE(review): assumes *model_path* already exists on disk (i.e. the model was
    saved at least once during training) — `add_dir` raises otherwise; confirm.
    """
    model_artifact = wandb.Artifact(model_path, type='model')
    model_artifact.add_dir(model_path)
    self.run.log_artifact(model_artifact)

def on_train_begin(self, logs=None):
self.steps = self.params["steps"]
Expand Down Expand Up @@ -132,16 +147,27 @@ def on_epoch_end(self, epoch: int, logs=None):
if log_key[:4] != "val_":
log_key = "train_" + log_key
self.run.log({log_key: logs[key]}, step=self.global_step)
return

elif epoch % self.update_freq == 0:
for key in logs:
log_key = key
if log_key[:4] != "val_":
log_key = "train_" + log_key
self.run.log({log_key: logs[key]}, step=epoch)

if self._mode == "min" and logs[self._monitor] < self._monitor_metric_val:
self._model_path = f"model-best-{epoch + 1}-{self.run.name}"
print(f"{self._monitor} decreased at epoch {epoch}. Saving Model at {self._model_path}")
self.model.save(self._model_path)
self._monitor_metric_val = logs[self._monitor]
elif self._mode == "max" and logs[self._monitor] > self._monitor_metric_val:
self._model_path = f"model-best-{epoch + 1}-{self.run.name}"
print(f"{self._monitor} increased at epoch {epoch}. Saving Model at {self._model_path}")
self.model.save(self._model_path)
self._monitor_metric_val = logs[self._monitor]

def on_train_end(self, logs=None):
    """Finalize the W&B run when training completes.

    Uploads the best saved model as an artifact, copies the tracked constant
    fields (e.g. "size") into the run summary, and closes the run.
    """
    # Publish the last best-model checkpoint directory as a W&B artifact.
    self._add_model_as_artifact(self._model_path)
    # Mirror each recorded constant into the run's summary for display.
    for field_name in self._constant_fields:
        wandb.run.summary[field_name] = self._constants[field_name]
    self.run.finish()