Lazy load wandb #228

Merged · 7 commits · Mar 23, 2022
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
@@ -13,7 +13,7 @@ jobs:
      - uses: actions/setup-python@v2
      - uses: pre-commit/action@v2.0.3
  test:
-    name: Run Tests
+    name: Run Tests
    if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
    runs-on: ubuntu-latest
    strategy:
82 changes: 38 additions & 44 deletions docs/getting-started/high-level-api.ipynb

Large diffs are not rendered by default.

50 changes: 30 additions & 20 deletions docs/getting-started/low-level-api.ipynb

Large diffs are not rendered by default.

32 changes: 18 additions & 14 deletions elegy/callbacks/wandb_callback.py
@@ -6,18 +6,16 @@
import math
from typing import Dict, Optional, Union

-import wandb
-
from .callback import Callback


class WandbCallback(Callback):
    """
-    Callback that streams epoch results to a [Weights & Biases](https://wandb.ai/) run.
+    Callback that streams epoch results to a [Weights & Biases](https://self.wandb.ai/) run.

    ```python
-    wandb.login()
-    wandb_logger = WandbCallback(project="sample-wandb-project", job_type="train")
+    self.wandb.login()
+    wandb_logger = WandbCallback(project="sample-self.wandb-project", job_type="train")
    model.fit(X_train, Y_train, callbacks=[wandb_logger])
    ```
    """
@@ -52,13 +50,13 @@ def __init__(
                runs together into larger experiments using group. For example, you might have multiple
                jobs in a group, with job types like train and eval. Setting this makes it easy to
                filter and group similar runs together in the UI so you can compare apples to apples.
-            config: (dict, argparse, absl.flags, str, optional) This sets `wandb.config`, a dictionary-like
+            config: (dict, argparse, absl.flags, str, optional) This sets `self.wandb.config`, a dictionary-like
                object for saving inputs to your job, like hyperparameters for a model or settings for a
                data preprocessing job. The config will show up in a table in the UI that you can use to
                group, filter, and sort runs. Keys should not contain . in their names, and values should
                be under 10 MB. If dict, argparse or `absl.flags`: will load the key value pairs into the
-                wandb.config object. If str: will look for a yaml file by that name, and load config from
-                that file into the `wandb.config` object.
+                self.wandb.config object. If str: will look for a yaml file by that name, and load config from
+                that file into the `self.wandb.config` object.
            update_freq: (str, int)`'batch'` or `'epoch'` or integer. When using `'batch'`, writes the
                losses and metrics to TensorBoard after each batch. The same applies for `'epoch'`. If
                using an integer, let's say `1000`, the callback will write the metrics and losses to
@@ -70,18 +68,21 @@ def __init__(
            mode: (str) one of {`'min'`, `'max'`, `'every'`}. `'min'` - save model when monitor is minimized.
                `'max'` - save model when monitor is maximized. `'every'` - save model after every epoch.
        """
+        import wandb
+
+        self.wandb = wandb
        super().__init__()
        self.run = (
-            wandb.init(
+            self.wandb.init(
                project=project,
                name=name,
                entity=entity,
                job_type=job_type,
                config=config,
                **kwargs,
            )
-            if wandb.run is None
-            else wandb.run
+            if self.wandb.run is None
+            else self.wandb.run
        )
        self.keys = None
        self.write_per_batch = True
@@ -106,16 +107,18 @@ def __init__(
            raise e

    def _gather_configs(self):
+
        module_attributes = vars(vars(self.model)["module"])
        for _var in module_attributes:
            if (
                type(module_attributes[_var]) == str
                or type(module_attributes[_var]) == int
            ):
-                wandb.run.config[_var] = module_attributes[_var]
+                self.wandb.run.config[_var] = module_attributes[_var]

    def _add_model_as_artifact(self, model_path: str, epoch: int):
-        artifact = wandb.Artifact(
+
+        artifact = self.wandb.Artifact(
            f"model-{self.run.name}",
            type="model",
            metadata={"epoch": epoch, "model_path": model_path},
@@ -204,6 +207,7 @@ def on_epoch_end(self, epoch: int, logs=None):
            self._add_model_as_artifact(self._model_path, epoch=self._best_epoch)

    def on_train_end(self, logs=None):
+
        for key in self._constant_fields:
-            wandb.run.summary[key] = self._constants[key]
+            self.wandb.run.summary[key] = self._constants[key]
        self.run.finish()
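The net effect of the diff above is that `wandb` is no longer imported when the module is loaded; `WandbCallback.__init__` imports it on first use and keeps the module on `self.wandb`, so the rest of `elegy` works without `wandb` installed. Below is a minimal sketch of the same lazy-import pattern, using a hypothetical `ExampleCallback` and an error message that is my own addition, not code from this PR:

```python
from typing import Optional


class ExampleCallback:
    """Sketch: defer an optional dependency's import until the callback is built."""

    def __init__(self, project: Optional[str] = None):
        # Importing here (not at module level) keeps wandb an optional dependency:
        # users who never construct this callback never need it installed.
        try:
            import wandb
        except ImportError as e:
            raise ImportError(
                "ExampleCallback requires wandb; install it with `pip install wandb`."
            ) from e

        self.wandb = wandb  # keep a handle so later methods can call self.wandb.*
        self.run = (
            self.wandb.init(project=project)
            if self.wandb.run is None
            else self.wandb.run
        )

    def on_train_end(self, logs=None):
        self.run.finish()
```

Reusing an existing `wandb.run` (the `if ... is None` branch) mirrors what the real callback does, so constructing it in a session that already called `wandb.init()` does not start a second run.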
4 changes: 2 additions & 2 deletions examples/elegy/mnist.py
@@ -54,9 +54,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"]
+    X_train = np.stack(dataset["train"]["image"])
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"]
+    X_test = np.stack(dataset["test"]["image"])
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
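The example-script change above, repeated in the remaining `examples/` diffs below, wraps the `datasets` image column in `np.stack`. Depending on the `datasets` version, `dataset["train"]["image"]` with `set_format("np")` can come back as a sequence of per-image 2D arrays rather than one `(N, 28, 28)` array; stacking restores the single batched array that the models and the `[..., None]` channel trick expect. A small illustrative sketch (the toy `images` list is an assumption, not data from the repo):

```python
import numpy as np

# Stand-in for what dataset["train"]["image"] may return: a list of
# separate 28x28 arrays instead of one (N, 28, 28) array.
images = [np.zeros((28, 28), dtype=np.uint8) for _ in range(3)]

X = np.stack(images)                   # shape (3, 28, 28): one contiguous batch
X_conv = np.stack(images)[..., None]   # shape (3, 28, 28, 1): adds the channel axis conv models expect

print(X.shape, X_conv.shape)
```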
4 changes: 2 additions & 2 deletions examples/elegy/mnist_autoencoder.py
@@ -64,8 +64,8 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"]
-    X_test = dataset["test"]["image"]
+    X_train = np.stack(dataset["train"]["image"])
+    X_test = np.stack(dataset["test"]["image"])

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
4 changes: 2 additions & 2 deletions examples/elegy/mnist_conv.py
@@ -68,9 +68,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"][..., None]
+    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"][..., None]
+    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
4 changes: 2 additions & 2 deletions examples/elegy/mnist_dataloader.py
@@ -20,9 +20,9 @@ def __init__(self, training: bool = True):

        dataset = load_dataset("mnist")
        dataset.set_format("np")
-        X_train = dataset["train"]["image"]
+        X_train = np.stack(dataset["train"]["image"])
        y_train = dataset["train"]["label"]
-        X_test = dataset["test"]["image"]
+        X_test = np.stack(dataset["test"]["image"])
        y_test = dataset["test"]["label"]

        if training:
4 changes: 2 additions & 2 deletions examples/elegy/mnist_torch_dataloader.py
@@ -76,9 +76,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"][..., None]
+    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"][..., None]
+    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
4 changes: 2 additions & 2 deletions examples/elegy/mnist_vae.py
@@ -148,8 +148,8 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = np.array(dataset["train"]["image"], dtype=np.uint8)
-    X_test = np.array(dataset["test"]["image"], dtype=np.uint8)
+    X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
+    X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)

    # Now binarize data
    X_train = (X_train / 255.0).astype(jnp.float32)
4 changes: 2 additions & 2 deletions examples/flax/mnist_conv.py
@@ -66,9 +66,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"][..., None]
+    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"][..., None]
+    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
4 changes: 2 additions & 2 deletions examples/flax/mnist_vae.py
@@ -125,8 +125,8 @@ def main(
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
-    X_train = np.array(dataset["train"]["image"], dtype=np.uint8)
-    X_test = np.array(dataset["test"]["image"], dtype=np.uint8)
+    X_train = np.array(np.stack(dataset["train"]["image"]), dtype=np.uint8)
+    X_test = np.array(np.stack(dataset["test"]["image"]), dtype=np.uint8)
    # Now binarize data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)
4 changes: 2 additions & 2 deletions examples/haiku/mnist_conv.py
@@ -69,9 +69,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"][..., None]
+    X_train = np.stack(dataset["train"]["image"])[..., None]
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"][..., None]
+    X_test = np.stack(dataset["test"]["image"])[..., None]
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
4 changes: 2 additions & 2 deletions examples/haiku/mnist_vae.py
@@ -131,8 +131,8 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"]
-    X_test = dataset["test"]["image"]
+    X_train = np.stack(dataset["train"]["image"])
+    X_test = np.stack(dataset["test"]["image"])
    # Now binarize data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)
4 changes: 2 additions & 2 deletions examples/jax/linear_classifier_test_step.py
@@ -103,9 +103,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"]
+    X_train = np.stack(dataset["train"]["image"])
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"]
+    X_test = np.stack(dataset["test"]["image"])
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)
4 changes: 2 additions & 2 deletions examples/jax/linear_classifier_train_step.py
@@ -129,9 +129,9 @@ def main(

    dataset = load_dataset("mnist")
    dataset.set_format("np")
-    X_train = dataset["train"]["image"]
+    X_train = np.stack(dataset["train"]["image"])
    y_train = dataset["train"]["label"]
-    X_test = dataset["test"]["image"]
+    X_test = np.stack(dataset["test"]["image"])
    y_test = dataset["test"]["label"]

    print("X_train:", X_train.shape, X_train.dtype)