Merge branch 'master' into lazy-load-wandb
cgarciae committed Mar 23, 2022
2 parents 605bee8 + c690748 commit 979b5f5
Showing 2 changed files with 13 additions and 8 deletions.
8 changes: 4 additions & 4 deletions docs/elegy-module.md
@@ -162,7 +162,7 @@ def call(self, x):
     return x
 ```
 Here `Linear` and `Conv` are dangerously swapped based on some condition. If you want to
-do this you can just clare them unconditionally and use them inside the condition:
+do this you can just declare them unconditionally and use them inside the condition:

 ```python
 def call(self, x):
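
The corrected snippet itself is collapsed in this view. A minimal sketch of the pattern the paragraph describes (declare every sub-module unconditionally, branch only on the call), assuming hypothetical `elegy.nn.Conv2D` and `elegy.nn.Linear` constructors and argument names:

```python
def call(self, x):
    # declare both sub-modules unconditionally so module hooks run in a
    # stable order on every call (constructor names/args are assumptions)
    conv = elegy.nn.Conv2D(64, kernel_size=(3, 3))
    linear = elegy.nn.Linear(64)

    # only the use of the modules is conditional, not their declaration
    if x.ndim == 4:
        x = conv(x)
    else:
        x = linear(x)

    return x
```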
@@ -207,8 +207,8 @@ customize this name by using the `name` argument available in the `Module`'s con

 A big theme in Jax is that state and computation are separate, this is a requirement
 because in order for combinators like `jax.grad` and `jax.jit` to work you need pure functions.
-Elegy as you've seen is object oriented so additional effort ir required to properly convert all
-the global states and `Module` parameters an inputs to a function so Jax can track them. To achieve
+Elegy as you've seen is object oriented so additional effort is required to properly convert all
+the global states and `Module` parameters as inputs to a function so Jax can track them. To achieve
 Elegy implements its own `jit` and `value_and_grad` function wrappers that handle this for you.

 Lets create a low level training loop using the previous definition `MLP` along with these functions:
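
The training-loop code that this sentence introduces is collapsed in the diff view. As a rough illustration of the underlying idea (parameters and state passed explicitly as inputs to a pure function), here is a sketch in plain JAX rather than Elegy's own `jit` / `value_and_grad` wrappers; the parameter structure and learning rate are made up for the example:

```python
import jax
import jax.numpy as jnp


def loss_fn(parameters, x, y):
    # parameters are explicit inputs, so the function is pure and traceable
    y_pred = x @ parameters["w"] + parameters["b"]
    return jnp.mean((y_pred - y) ** 2)


@jax.jit
def update(parameters, x, y, lr=1e-3):
    loss, grads = jax.value_and_grad(loss_fn)(parameters, x, y)
    # plain gradient descent over the parameter pytree
    new_parameters = jax.tree_util.tree_map(lambda p, g: p - lr * g, parameters, grads)
    return loss, new_parameters
```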
@@ -258,7 +258,7 @@ def update(x, y):
 ```

 After that we just use `tree_multimap` to implement Gradient Descent
-and get our `new_parameters` and then use the `set_parameters` method our
+and get our `new_parameters` and then use the `set_parameters` method in our
 `Module` to update its state.

 ```python hl_lines="4 5 6 8"
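
The highlighted block referenced here is also collapsed. A minimal sketch of that step, assuming the `Module` exposes a `get_parameters()` counterpart to `set_parameters()` (an assumption about the exact Elegy API) and that `gradients` shares the parameters' pytree structure:

```python
import jax


def sgd_update(module, gradients, lr=1e-3):
    # hypothetical helper mirroring the step described above
    parameters = module.get_parameters()  # assumed getter
    new_parameters = jax.tree_util.tree_map(
        lambda p, g: p - lr * g, parameters, gradients
    )
    module.set_parameters(new_parameters)  # write the updated pytree back
    return new_parameters
```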
13 changes: 9 additions & 4 deletions elegy/model/model_core.py
@@ -124,17 +124,18 @@ def train_step_fn(self, model: M) -> TrainStep[M]:

 @dataclass(unsafe_hash=True)
 class JIT(DistributedStrategy):
+    # donate 'model' memory buffer since we return an updated model
     def init_step_fn(self, model: M) -> InitStep[M]:
-        return jax.jit(model.__class__._static_init_step)
+        return jax.jit(model.__class__._static_init_step, donate_argnums=0)

     def pred_step_fn(self, model: M) -> PredStep[M]:
-        return jax.jit(model.__class__._static_pred_step)
+        return jax.jit(model.__class__._static_pred_step, donate_argnums=0)

     def test_step_fn(self, model: M) -> TestStep[M]:
-        return jax.jit(model.__class__._static_test_step)
+        return jax.jit(model.__class__._static_test_step, donate_argnums=0)

     def train_step_fn(self, model: M) -> TrainStep[M]:
-        return jax.jit(model.__class__._static_train_step)
+        return jax.jit(model.__class__._static_train_step, donate_argnums=0)


 @dataclass(unsafe_hash=True)
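
For context on this hunk, `donate_argnums=0` marks the first positional argument's device buffers as donated, so XLA may reuse them for the outputs instead of allocating fresh memory. A standalone toy example (not Elegy code; names and shapes are illustrative):

```python
import jax
import jax.numpy as jnp


def increment(state):
    # returns a pytree with the same structure and shapes as the input
    return jax.tree_util.tree_map(lambda x: x + 1, state)


# donating the input lets XLA write the output into the input's buffers,
# avoiding two copies of a large state living on the device at once
increment_jit = jax.jit(increment, donate_argnums=0)

state = {"step": jnp.zeros(()), "params": jnp.ones((1024, 1024))}
state = increment_jit(state)  # the old `state` buffers are invalid after this call
```

This is why the step functions here return an updated model: the caller rebinds the result, and the old model's donated buffers can be reclaimed. On CPU, JAX may ignore donation and emit a warning.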
@@ -204,26 +205,30 @@ def init_step_fn(self, model: M) -> InitStep[M]:
         return jax.pmap(
             model.__class__._static_init_step,
             axis_name="device",
+            donate_argnums=0,
         )

     def pred_step_fn(self, model: M) -> PredStep[M]:
         return jax.pmap(
             model.__class__._static_pred_step,
             axis_name="device",
+            donate_argnums=0,
         )

     def test_step_fn(self, model: M) -> TestStep[M]:
         return jax.pmap(
             model.__class__._static_test_step,
             axis_name="device",
             out_axes=(0, None, 0),  # None = logs not replicated
+            donate_argnums=0,
         )

     def train_step_fn(self, model: M) -> TrainStep[M]:
         return jax.pmap(
             model.__class__._static_train_step,
             axis_name="device",
             out_axes=(None, 0),  # None = logs not replicated
+            donate_argnums=0,
         )


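The pmap-based strategy above combines donation with `out_axes=None` for outputs that are identical on every device (the logs). A self-contained toy sketch of that combination, with a trivial step function and made-up shapes:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def step(params, x):
    per_device_loss = jnp.mean(x)
    # pmean makes the value identical on every device, so it can use out_axes=None
    mean_loss = jax.lax.pmean(per_device_loss, axis_name="device")
    new_params = params + 1.0
    return new_params, mean_loss


# out_axes=(0, None): params stay replicated per device, the loss is returned
# un-replicated; donate_argnums=0 reuses the old params buffers for new_params
p_step = jax.pmap(step, axis_name="device", out_axes=(0, None), donate_argnums=0)

params = jnp.zeros((n_devices, 128))
x = jnp.ones((n_devices, 32))
params, loss = p_step(params, x)
```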
