Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Donate model's memory buffer to jit/pmap functions. #226

Merged
merged 1 commit into from
Mar 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions elegy/model/model_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,18 @@ def train_step_fn(self, model: M) -> TrainStep[M]:

@dataclass(unsafe_hash=True)
class JIT(DistributedStrategy):
# donate 'model' memory buffer since we return an updated model
def init_step_fn(self, model: M) -> InitStep[M]:
return jax.jit(model.__class__._static_init_step)
return jax.jit(model.__class__._static_init_step, donate_argnums=0)

def pred_step_fn(self, model: M) -> PredStep[M]:
return jax.jit(model.__class__._static_pred_step)
return jax.jit(model.__class__._static_pred_step, donate_argnums=0)

def test_step_fn(self, model: M) -> TestStep[M]:
return jax.jit(model.__class__._static_test_step)
return jax.jit(model.__class__._static_test_step, donate_argnums=0)

def train_step_fn(self, model: M) -> TrainStep[M]:
return jax.jit(model.__class__._static_train_step)
return jax.jit(model.__class__._static_train_step, donate_argnums=0)


@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -204,26 +205,30 @@ def init_step_fn(self, model: M) -> InitStep[M]:
return jax.pmap(
model.__class__._static_init_step,
axis_name="device",
donate_argnums=0,
)

def pred_step_fn(self, model: M) -> PredStep[M]:
return jax.pmap(
model.__class__._static_pred_step,
axis_name="device",
donate_argnums=0,
)

def test_step_fn(self, model: M) -> TestStep[M]:
return jax.pmap(
model.__class__._static_test_step,
axis_name="device",
out_axes=(0, None, 0), # None = logs not replicated
donate_argnums=0,
)

def train_step_fn(self, model: M) -> TrainStep[M]:
return jax.pmap(
model.__class__._static_train_step,
axis_name="device",
out_axes=(None, 0), # None = logs not replicated
donate_argnums=0,
)


Expand Down