
Donate model's memory buffer to jit/pmap functions. #226

Merged
merged 1 commit into from
Mar 23, 2022

Conversation

lkhphuc
Contributor

@lkhphuc lkhphuc commented Mar 23, 2022

As discussed on Discord, using donate_argnums=1 in jit/pmap reduces GPU/TPU memory usage by about one-third.

Before: [screenshot: device memory profile]
After: [screenshot: device memory profile]

Technically, only donating the argument in train_step_fn is necessary, since every other *_step_fn is called inside train_step_fn anyway. But for consistency I added donate_argnums to every *_step_fn.
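For context, a minimal sketch of the buffer-donation pattern this PR applies. The function and parameter names here are illustrative, not Elegy's actual API; the point is that donate_argnums lets XLA reuse the input buffer for the output, so the old and new model states never coexist on the device:

```python
import jax
import jax.numpy as jnp
from functools import partial

# donate_argnums=0 tells XLA it may reuse the buffer of `params`
# for the output. Without donation, both the old and the updated
# parameters occupy device memory at the same time.
@partial(jax.jit, donate_argnums=0)
def train_step(params, grads, lr=0.1):
    # Illustrative SGD update over an arbitrary pytree of parameters.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.ones((3,))}
grads = {"w": jnp.full((3,), 0.5)}
params = train_step(params, grads)  # rebind the result; the old buffer is gone
```

Note that the caller must always rebind the donated argument to the function's return value, since the donated input buffer is invalidated by the call (on GPU/TPU; CPU ignores donation with a warning).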

@cgarciae
Collaborator

Amazing! Thanks a lot for doing this.
LGTM.

Merging.

@cgarciae cgarciae merged commit c690748 into poets-ai:master Mar 23, 2022
@cgarciae cgarciae added the fix label Mar 23, 2022
@lkhphuc lkhphuc deleted the donate_argnum branch March 24, 2022 00:41
@bhoov

bhoov commented Apr 30, 2022

My code randomly failed at the end of a training epoch. After some digging I found that the call to train_on_batch caused the failure. Rolling back to elegy==0.8.5 fixed it, so I'm fairly sure this change is breaking (I am running the simple MNIST example code).

2022-04-30 18:39:44.022094: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2140] Execution of replica 0 failed: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [52], in <cell line: 1>()
----> 1 history = model.fit(
      2     inputs=X_train,
      3     labels=y_train,
      4     epochs=5,
      5     steps_per_epoch=2,
      6     batch_size=10,
      7     validation_data=(X_test, y_test),
      8     shuffle=True,
      9     callbacks=[eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True)],
     10 )

Input In [39], in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining)
    102 if drop_remaining and not data_utils.has_batch_size(
    103     batch, data_handler.batch_size
    104 ):
    105     continue
--> 107 tmp_logs = self.train_on_batch(
    108     inputs=inputs,
    109     labels=labels,
    110 )

Input In [50], in train_on_batch(self, inputs, labels)
     88     labels = dict(target=labels)
     91 train_step_fn = self.train_step_fn[self._distributed_strategy]
---> 92 logs, model = train_step_fn(self, inputs, labels)
     93 print("Ending train step")
     94 return {}

ValueError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
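For anyone hitting this: the error text matches what JAX raises when a donated buffer is passed to a jitted function a second time. On GPU/TPU the donated input is invalidated after the call, so any stale reference to it fails on reuse; on CPU donation is silently ignored (JAX only warns), which is why the bug may not reproduce locally. A minimal sketch of that mechanism, as my reading of the symptom rather than a confirmed diagnosis of the Elegy bug:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, donate_argnums=0)
def step(x):
    return x + 1.0

x = jnp.zeros((4,))
y = step(x)
# On GPU/TPU, x's buffer was donated by the call above, so passing
# x again raises "Donation requested for invalid buffer". On CPU,
# donation is a no-op (JAX emits a warning), so the reuse succeeds.
```

If the training loop keeps a reference to the pre-call model state and passes it again (e.g. after a checkpoint callback or at an epoch boundary), that would produce exactly this ValueError.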

@lkhphuc
Contributor Author

lkhphuc commented May 1, 2022

Hi @bhoov, can you paste a full stack trace, and a minimal example? I will try to take a look.
