diff --git a/docs/elegy-module.md b/docs/elegy-module.md
index b6ecb5f2..ae69cdb9 100644
--- a/docs/elegy-module.md
+++ b/docs/elegy-module.md
@@ -162,7 +162,7 @@ def call(self, x):
         return x
 ```
 Here `Linear` and `Conv` are dangerously swapped based on some condition. If you want to
-do this you can just clare them unconditionally and use them inside the condition:
+do this you can just declare them unconditionally and use them inside the condition:
 
 ```python
     def call(self, x):
@@ -207,8 +207,8 @@ customize this name by using the `name` argument available in the `Module`'s con
 
 A big theme in Jax is that state and computation are separate, this is a requirement
 because in order for combinators like `jax.grad` and `jax.jit` to work you need pure functions.
-Elegy as you've seen is object oriented so additional effort ir required to properly convert all
-the global states and `Module` parameters an inputs to a function so Jax can track them. To achieve
+Elegy as you've seen is object oriented so additional effort is required to properly convert all
+the global states and `Module` parameters as inputs to a function so Jax can track them. To achieve this,
 Elegy implements its own `jit` and `value_and_grad` function wrappers that handle this for you.
 
 Lets create a low level training loop using the previous definition `MLP` along with these functions:
@@ -258,7 +258,7 @@ def update(x, y):
 ```
 
 After that we just use `tree_multimap` to implement Gradient Descent
-and get our `new_parameters` and then use the `set_parameters` method our
+and get our `new_parameters` and then use the `set_parameters` method in our
 `Module` to update its state.
 
 ```python hl_lines="4 5 6 8"
diff --git a/elegy/model/model_core.py b/elegy/model/model_core.py
index c2678b3e..ae74ebc1 100644
--- a/elegy/model/model_core.py
+++ b/elegy/model/model_core.py
@@ -124,17 +124,18 @@ def train_step_fn(self, model: M) -> TrainStep[M]:
 
 @dataclass(unsafe_hash=True)
 class JIT(DistributedStrategy):
+    # donate 'model' memory buffer since we return an updated model
     def init_step_fn(self, model: M) -> InitStep[M]:
-        return jax.jit(model.__class__._static_init_step)
+        return jax.jit(model.__class__._static_init_step, donate_argnums=0)
 
     def pred_step_fn(self, model: M) -> PredStep[M]:
-        return jax.jit(model.__class__._static_pred_step)
+        return jax.jit(model.__class__._static_pred_step, donate_argnums=0)
 
     def test_step_fn(self, model: M) -> TestStep[M]:
-        return jax.jit(model.__class__._static_test_step)
+        return jax.jit(model.__class__._static_test_step, donate_argnums=0)
 
     def train_step_fn(self, model: M) -> TrainStep[M]:
-        return jax.jit(model.__class__._static_train_step)
+        return jax.jit(model.__class__._static_train_step, donate_argnums=0)
 
 
 @dataclass(unsafe_hash=True)
@@ -204,12 +205,14 @@ def init_step_fn(self, model: M) -> InitStep[M]:
         return jax.pmap(
             model.__class__._static_init_step,
             axis_name="device",
+            donate_argnums=0,
         )
 
     def pred_step_fn(self, model: M) -> PredStep[M]:
         return jax.pmap(
             model.__class__._static_pred_step,
             axis_name="device",
+            donate_argnums=0,
         )
 
     def test_step_fn(self, model: M) -> TestStep[M]:
@@ -217,6 +220,7 @@ def test_step_fn(self, model: M) -> TestStep[M]:
             model.__class__._static_test_step,
             axis_name="device",
             out_axes=(0, None, 0),  # None = logs not replicated
+            donate_argnums=0,
         )
 
     def train_step_fn(self, model: M) -> TrainStep[M]:
@@ -224,6 +228,7 @@ def train_step_fn(self, model: M) -> TrainStep[M]:
             model.__class__._static_train_step,
             axis_name="device",
             out_axes=(None, 0),  # None = logs not replicated
+            donate_argnums=0,
         )
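
For context on the code change: `donate_argnums=0` tells XLA that the caller will not reuse the first positional argument (the `model` pytree, per the comment added in the diff), so its device buffers can be recycled for the step's outputs instead of keeping the old and new model state alive at the same time. Below is a minimal standalone sketch of this behaviour, assuming nothing about Elegy itself; the `sgd_step` function, learning rate, and array shapes are made up for illustration.

```python
# Minimal sketch of jax.jit buffer donation (not Elegy code).
import jax
import jax.numpy as jnp


def sgd_step(params, grads):
    # Pure function: returns updated params without mutating its inputs.
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


# Donate the first argument: XLA may reuse the buffers of `params` for the output.
sgd_step_donated = jax.jit(sgd_step, donate_argnums=0)

params = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}
grads = {"w": jnp.full((1024, 1024), 0.5), "b": jnp.full((1024,), 0.5)}

params = sgd_step_donated(params, grads)  # rebind; the old buffers are donated
```

Donation is only a hint: backends that don't support it (notably CPU) ignore it and emit a warning, while on GPU/TPU the donated input arrays become invalid and touching them afterwards raises an error. That is why each `*_step_fn` here returns a fresh model that replaces the donated one rather than mutating it in place.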