Skip to content

Commit

Permalink
Lazy load wandb (#228)
Browse files Browse the repository at this point in the history
* load wandb inside __init__ and set it as a field

* update lock

* try trigger CI

* update examples with new datasets API

* import numpy in getting started guides

* Fix getting started guides
  • Loading branch information
cgarciae authored Mar 23, 2022
1 parent c690748 commit 1ebd306
Show file tree
Hide file tree
Showing 18 changed files with 142 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
- uses: actions/setup-python@v2
- uses: pre-commit/action@v2.0.3
test:
name: Run Tests
name: Run Tests
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
runs-on: ubuntu-latest
strategy:
Expand Down
82 changes: 38 additions & 44 deletions docs/getting-started/high-level-api.ipynb

Large diffs are not rendered by default.

50 changes: 30 additions & 20 deletions docs/getting-started/low-level-api.ipynb

Large diffs are not rendered by default.

32 changes: 18 additions & 14 deletions elegy/callbacks/wandb_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@
import math
from typing import Dict, Optional, Union

import wandb

from .callback import Callback


class WandbCallback(Callback):
"""
Callback that streams epoch results to a [Weights & Biases](https://wandb.ai/) run.
Callback that streams epoch results to a [Weights & Biases](https://self.wandb.ai/) run.
```python
wandb.login()
wandb_logger = WandbCallback(project="sample-wandb-project", job_type="train")
self.wandb.login()
wandb_logger = WandbCallback(project="sample-self.wandb-project", job_type="train")
model.fit(X_train, Y_train, callbacks=[wandb_logger])
```
"""
Expand Down Expand Up @@ -52,13 +50,13 @@ def __init__(
runs together into larger experiments using group. For example, you might have multiple
jobs in a group, with job types like train and eval. Setting this makes it easy to
filter and group similar runs together in the UI so you can compare apples to apples.
config: (dict, argparse, absl.flags, str, optional) This sets `wandb.config`, a dictionary-like
config: (dict, argparse, absl.flags, str, optional) This sets `self.wandb.config`, a dictionary-like
object for saving inputs to your job, like hyperparameters for a model or settings for a
data preprocessing job. The config will show up in a table in the UI that you can use to
group, filter, and sort runs. Keys should not contain . in their names, and values should
be under 10 MB. If dict, argparse or `absl.flags`: will load the key value pairs into the
wandb.config object. If str: will look for a yaml file by that name, and load config from
that file into the `wandb.config` object.
self.wandb.config object. If str: will look for a yaml file by that name, and load config from
that file into the `self.wandb.config` object.
update_freq: (str, int)`'batch'` or `'epoch'` or integer. When using `'batch'`, writes the
losses and metrics to TensorBoard after each batch. The same applies for `'epoch'`. If
using an integer, let's say `1000`, the callback will write the metrics and losses to
Expand All @@ -70,18 +68,21 @@ def __init__(
mode: (str) one of {`'min'`, `'max'`, `'every'`}. `'min'` - save model when monitor is minimized.
`'max'` - save model when monitor is maximized. `'every'` - save model after every epoch.
"""
import wandb

self.wandb = wandb
super().__init__()
self.run = (
wandb.init(
self.wandb.init(
project=project,
name=name,
entity=entity,
job_type=job_type,
config=config,
**kwargs,
)
if wandb.run is None
else wandb.run
if self.wandb.run is None
else self.wandb.run
)
self.keys = None
self.write_per_batch = True
Expand All @@ -106,16 +107,18 @@ def __init__(
raise e

def _gather_configs(self):

module_attributes = vars(vars(self.model)["module"])
for _var in module_attributes:
if (
type(module_attributes[_var]) == str
or type(module_attributes[_var]) == int
):
wandb.run.config[_var] = module_attributes[_var]
self.wandb.run.config[_var] = module_attributes[_var]

def _add_model_as_artifact(self, model_path: str, epoch: int):
artifact = wandb.Artifact(

artifact = self.wandb.Artifact(
f"model-{self.run.name}",
type="model",
metadata={"epoch": epoch, "model_path": model_path},
Expand Down Expand Up @@ -204,6 +207,7 @@ def on_epoch_end(self, epoch: int, logs=None):
self._add_model_as_artifact(self._model_path, epoch=self._best_epoch)

def on_train_end(self, logs=None):

for key in self._constant_fields:
wandb.run.summary[key] = self._constants[key]
self.wandb.run.summary[key] = self._constants[key]
self.run.finish()
4 changes: 2 additions & 2 deletions examples/elegy/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"]
X_train = np.stack(dataset["train"]["image"])
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"]
X_test = np.stack(dataset["test"]["image"])
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/elegy/mnist_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"]
X_test = dataset["test"]["image"]
X_train = np.stack(dataset["train"]["image"])
X_test = np.stack(dataset["test"]["image"])

print("X_train:", X_train.shape, X_train.dtype)
print("X_test:", X_test.shape, X_test.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/elegy/mnist_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"][..., None]
X_train = np.stack(dataset["train"]["image"])[..., None]
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"][..., None]
X_test = np.stack(dataset["test"]["image"])[..., None]
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/elegy/mnist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def __init__(self, training: bool = True):

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"]
X_train = np.stack(dataset["train"]["image"])
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"]
X_test = np.stack(dataset["test"]["image"])
y_test = dataset["test"]["label"]

if training:
Expand Down
4 changes: 2 additions & 2 deletions examples/elegy/mnist_torch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"][..., None]
X_train = np.stack(dataset["train"]["image"])[..., None]
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"][..., None]
X_test = np.stack(dataset["test"]["image"])[..., None]
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/elegy/mnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = np.array(dataset["train"]["image"], dtype=np.uint8)
X_test = np.array(dataset["test"]["image"], dtype=np.uint8)
X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)

# Now binarize data
X_train = (X_train / 255.0).astype(jnp.float32)
Expand Down
4 changes: 2 additions & 2 deletions examples/flax/mnist_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"][..., None]
X_train = np.stack(dataset["train"]["image"])[..., None]
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"][..., None]
X_test = np.stack(dataset["test"]["image"])[..., None]
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/flax/mnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def main(
logdir = os.path.join(logdir, current_time)

dataset = load_dataset("mnist")
X_train = np.array(dataset["train"]["image"], dtype=np.uint8)
X_test = np.array(dataset["test"]["image"], dtype=np.uint8)
X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)
# Now binarize data
X_train = (X_train > 0).astype(jnp.float32)
X_test = (X_test > 0).astype(jnp.float32)
Expand Down
4 changes: 2 additions & 2 deletions examples/haiku/mnist_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"][..., None]
X_train = np.stack(dataset["train"]["image"])[..., None]
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"][..., None]
X_test = np.stack(dataset["test"]["image"])[..., None]
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/haiku/mnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"]
X_test = dataset["test"]["image"]
X_train = np.stack(dataset["train"]["image"])
X_test = np.stack(dataset["test"]["image"])
# Now binarize data
X_train = (X_train > 0).astype(jnp.float32)
X_test = (X_test > 0).astype(jnp.float32)
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/linear_classifier_test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"]
X_train = np.stack(dataset["train"]["image"])
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"]
X_test = np.stack(dataset["test"]["image"])
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/linear_classifier_train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def main(

dataset = load_dataset("mnist")
dataset.set_format("np")
X_train = dataset["train"]["image"]
X_train = np.stack(dataset["train"]["image"])
y_train = dataset["train"]["label"]
X_test = dataset["test"]["image"]
X_test = np.stack(dataset["test"]["image"])
y_test = dataset["test"]["label"]

print("X_train:", X_train.shape, X_train.dtype)
Expand Down
Loading

0 comments on commit 1ebd306

Please sign in to comment.