From 3c8babc6c406261fe49c62da14effbeeae006b6a Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 22 Feb 2022 20:10:16 +0000 Subject: [PATCH 1/7] Minor Bug Fix --- elegy/callbacks/wandb_callback.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index d6ab3749..4a278f0a 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -132,7 +132,6 @@ def on_epoch_end(self, epoch: int, logs=None): if log_key[:4] != "val_": log_key = "train_" + log_key self.run.log({log_key: logs[key]}, step=self.global_step) - return elif epoch % self.update_freq == 0: for key in logs: From 10e177a8f025337a8456e7526a1cfd0b979b1958 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 22 Feb 2022 20:16:50 +0000 Subject: [PATCH 2/7] updated wandb callback signature --- elegy/callbacks/wandb_callback.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index 4a278f0a..18adcdf4 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -28,6 +28,8 @@ def __init__( job_type: Optional[str] = None, config: Union[Dict, str, None] = None, update_freq: Union[str, int] = "epoch", + monitor: str = "val_loss", + mode: str = "min" **kwargs ): """ @@ -58,7 +60,10 @@ def __init__( losses and metrics to TensorBoard after each batch. The same applies for `'epoch'`. If using an integer, let's say `1000`, the callback will write the metrics and losses to TensorBoard every 1000 batches. Note that writing too frequently to TensorBoard can slow - down your training. + down your training. + monitor: (str) name of metric to monitor. Defaults to `'val_loss'`. + mode: (str) one of {`'min'`, `'max'`}. `'min'` - save model when monitor is minimized. + `'max'` - save model when monitor is maximized. Defaults to `'min'`. """ super().__init__() self.run = wandb.init( From 3f112b3ec97d1ad5b83c96e391af8c60e646fee4 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 22 Feb 2022 20:19:36 +0000 Subject: [PATCH 3/7] fixed typo --- elegy/callbacks/wandb_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index 18adcdf4..8ec79c61 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -29,7 +29,7 @@ def __init__( config: Union[Dict, str, None] = None, update_freq: Union[str, int] = "epoch", monitor: str = "val_loss", - mode: str = "min" + mode: str = "min", **kwargs ): """ From 4d08a48c10e5ad77536fafd74a350f162aed85ef Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 22 Feb 2022 20:33:36 +0000 Subject: [PATCH 4/7] Added model checkpointing --- elegy/callbacks/wandb_callback.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index 8ec79c61..f18c0d48 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -3,6 +3,7 @@ # https://github.com/poets-ai/elegy/blob/master/elegy/callbacks/tensorboard.py +import math import wandb from typing import Union, Optional, Dict @@ -78,6 +79,9 @@ def __init__( self.write_per_batch = True self._constant_fields = ["size"] self._constants = {} + self._monitor = monitor + self._mode = mode + self._monitor_metric_val = math.inf if mode == "min" else -math.inf try: self.update_freq = int(update_freq) except ValueError as e: @@ -94,6 +98,11 @@ def _gather_configs(self): for _var in module_attributes: if type(module_attributes[_var]) == str or type(module_attributes[_var]) == int: wandb.run.config[_var] = module_attributes[_var] + + def _add_model_as_artifact(self, model_path: str): + artifact = wandb.Artifact(model_path, type='model') + artifact.add_dir(model_path) + self.run.log_artifact(artifact) def on_train_begin(self, logs=None): self.steps = self.params["steps"] @@ -144,8 +153,15 @@ def on_epoch_end(self, epoch: int, logs=None): if log_key[:4] != "val_": log_key = "train_" + log_key self.run.log({log_key: logs[key]}, step=epoch) + + self._model_path = f"model-best-{epoch}-{self.run.name}" + if self._mode == "min" and logs[self._monitor] < self._monitor_metric_val: + self.model.save(self._model_path) + elif self._mode == "max" and logs[self._monitor] > self._monitor_metric_val: + self.model.save(self._model_path) def on_train_end(self, logs=None): + self._add_model_as_artifact(self._model_path) for key in self._constant_fields: wandb.run.summary[key] = self._constants[key] self.run.finish() From 8929139359fc68770afbb0b47adc70a301012eaf Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Tue, 22 Feb 2022 20:41:32 +0000 Subject: [PATCH 5/7] updated checkpoint system --- elegy/callbacks/wandb_callback.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index f18c0d48..ffd00f36 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -82,6 +82,7 @@ def __init__( self._monitor = monitor self._mode = mode self._monitor_metric_val = math.inf if mode == "min" else -math.inf + self._model_path = f"model-best-0" try: self.update_freq = int(update_freq) except ValueError as e: @@ -154,10 +155,11 @@ def on_epoch_end(self, epoch: int, logs=None): log_key = "train_" + log_key self.run.log({log_key: logs[key]}, step=epoch) - self._model_path = f"model-best-{epoch}-{self.run.name}" if self._mode == "min" and logs[self._monitor] < self._monitor_metric_val: + self._model_path = f"model-best-{epoch + 1}-{self.run.name}" self.model.save(self._model_path) elif self._mode == "max" and logs[self._monitor] > self._monitor_metric_val: + self._model_path = f"model-best-{epoch + 1}-{self.run.name}" self.model.save(self._model_path) def on_train_end(self, logs=None): From 494942e7f801510da382cbc35c515a7c82a9ed2b Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Wed, 23 Feb 2022 06:34:58 +0000 Subject: [PATCH 6/7] updated checkpoint logic --- elegy/callbacks/wandb_callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index ffd00f36..bb8b9ffe 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -158,9 +158,11 @@ def on_epoch_end(self, epoch: int, logs=None): if self._mode == "min" and logs[self._monitor] < self._monitor_metric_val: self._model_path = f"model-best-{epoch + 1}-{self.run.name}" self.model.save(self._model_path) + self._monitor_metric_val = logs[self._monitor] elif self._mode == "max" and logs[self._monitor] > self._monitor_metric_val: self._model_path = f"model-best-{epoch + 1}-{self.run.name}" self.model.save(self._model_path) + self._monitor_metric_val = logs[self._monitor] def on_train_end(self, logs=None): self._add_model_as_artifact(self._model_path) From 9522109a7536d989e5d427c06a9fda394835ffb9 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Wed, 23 Feb 2022 06:55:35 +0000 Subject: [PATCH 7/7] ipdated checkpoint system --- elegy/callbacks/wandb_callback.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/elegy/callbacks/wandb_callback.py b/elegy/callbacks/wandb_callback.py index bb8b9ffe..c56fb972 100644 --- a/elegy/callbacks/wandb_callback.py +++ b/elegy/callbacks/wandb_callback.py @@ -157,10 +157,12 @@ def on_epoch_end(self, epoch: int, logs=None): if self._mode == "min" and logs[self._monitor] < self._monitor_metric_val: self._model_path = f"model-best-{epoch + 1}-{self.run.name}" + print(f"{self._monitor} decreased at epoch {epoch}. Saving Model at {self._model_path}") self.model.save(self._model_path) self._monitor_metric_val = logs[self._monitor] elif self._mode == "max" and logs[self._monitor] > self._monitor_metric_val: self._model_path = f"model-best-{epoch + 1}-{self.run.name}" + print(f"{self._monitor} increased at epoch {epoch}. Saving Model at {self._model_path}") self.model.save(self._model_path) self._monitor_metric_val = logs[self._monitor]