
[BUG] Compressed Adam optimizers - RuntimeError: Bool type is not supported by dlpack #1859

Closed
jhoareau opened this issue Mar 23, 2022 · 7 comments · Fixed by #1894

jhoareau commented Mar 23, 2022

Describe the bug
Traceback:

  File "deepspeed/__init__.py", line 119, in initialize
    engine = DeepSpeedEngine(args=args,
  File "deepspeed/runtime/engine.py", line 293, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "deepspeed/runtime/engine.py", line 1106, in _configure_optimizer
    self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
  File "deepspeed/runtime/engine.py", line 1243, in _configure_fp16_optimizer
    optimizer = FP16_Optimizer(
  File "deepspeed/runtime/fp16/fused_optimizer.py", line 111, in __init__
    self.initialize_optimizer_states()
  File "deepspeed/runtime/fp16/fused_optimizer.py", line 119, in initialize_optimizer_states
    self.optimizer.step()
  File "torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "deepspeed/runtime/fp16/onebit/zoadam.py", line 239, in step
    self.comm_backend_handle.compressed_allreduce(
  File "deepspeed/runtime/comm/nccl.py", line 72, in compressed_allreduce
    self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),
  File "deepspeed/runtime/compression/cupy.py", line 15, in torch2cupy
    return cupy.fromDlpack(to_dlpack(tensor))
RuntimeError: Bool type is not supported by dlpack

When using the OneBitAdam and ZeroOneAdam implementations with comm_backend_name set to nccl, the error from the title appears.

This is linked to the operations:

self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),

compensated_server_m.sign_().add_(1).bool()),

and to the fact that the bool type is no longer supported by DLPack as of PyTorch 1.10: see pytorch/pytorch#67081.
Google's JAX repo recommends casting to uint8 instead of bool: jax-ml/jax#4719
Beware that when I tried to implement the casting locally, I got terrible performance with ZeroOneAdam.
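
For reference, a minimal standalone reproduction of the failing conversion outside DeepSpeed, together with the uint8 cast that goes through. This is only a sketch; it assumes torch >= 1.10, cupy, and a CUDA device:

import torch
import cupy
from torch.utils.dlpack import to_dlpack

buffer_m = torch.randn(8, device="cuda")

# Mirrors deepspeed/runtime/compression/cupy.py::torch2cupy with a bool tensor.
# On PyTorch >= 1.10 this raises "RuntimeError: Bool type is not supported by dlpack".
try:
    cupy.fromDlpack(to_dlpack(buffer_m.sign().add(1).bool()))
except RuntimeError as e:
    print(e)

# Casting to uint8 instead of bool converts fine.
ok = cupy.fromDlpack(to_dlpack(buffer_m.sign().add(1).to(dtype=torch.uint8)))
print(ok.dtype)  # uint8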

Expected behavior
The ZeroOneAdam optimizer working with nccl and the latest PyTorch version.

ds_report output
DeepSpeed general environment info:
torch version .................... 1.11.0+cu115
torch cuda version ............... 11.5
torch hip version ................ None
nvcc version ..................... 11.4
deepspeed info ................... 0.6.1+208d45b, 208d45b, master
deepspeed wheel compiled w. ...... torch 1.11, cuda 11.5, hip 0.0

Launcher context
Pytorch-Lightning DeepSpeedPlugin, Python 3.8

@jhoareau jhoareau added the bug Something isn't working label Mar 23, 2022
@jhoareau jhoareau changed the title [BUG] ZeroOneAdam RuntimeError: Bool type is not supported by dlpack [BUG] Compressed Adam optimizers - RuntimeError: Bool type is not supported by dlpack Mar 23, 2022
@jhoareau
cc @conglongli @awan-10 @samyam @jeffra

conglongli commented Mar 31, 2022

Hi @jhoareau, thanks for reporting this issue. It seems like there is an ongoing discussion at dlpack about whether or not to support the bool type: dmlc/dlpack#75.

On the other hand, for your comment

Google's JAX repo recommends casting to uint8 instead of bool: jax-ml/jax#4719 Beware that, when I tried to implement the casting locally I got terrible performance with ZeroOneAdam.

I assume you added the casting in DeepSpeed/deepspeed/runtime/comm/nccl.py. One of the reasons you see bad performance is that here the bool type is expected to take only 1 bit, not 8 bits (just as the name "1-bit compression" implies), so by casting to uint8 you actually lose the 8x compression ratio. Actually I was wrong: we have cupy.packbits to compress uint8 into bits. Could you then share how you implemented the casting locally so we can investigate?
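
For context, a tiny sketch of what packbits buys here (an illustration of the 8x ratio only, not the actual DeepSpeed code path):

import cupy

# 1-bit compression idea: each sign should cost one bit, not one byte.
signs = (cupy.random.rand(1024) > 0.5).astype(cupy.uint8)  # one 0/1 value per element
packed = cupy.packbits(signs)  # 8 sign values per output byte

print(signs.nbytes)   # 1024 bytes as uint8
print(packed.nbytes)  # 128 bytes after packing, i.e. the 8x ratio is back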

Overall, the only way to solve this problem currently is to use an older PyTorch (on our side we verified that torch 1.8 works). @awan-10 and I need to have some internal discussion about whether there is any solution to make it work without the bool type, but it might take some time. We also need to see dlpack's decision.

jhoareau commented Apr 4, 2022

Hi @conglongli, I just added .to(dtype=torch.uint8) to the two occurrences I mentioned in the original issue. I didn't dig too much into the performance, though; it could have been affected by other pieces of my code, and it might be worth trying it out on your side.

The dlpack issue has been open for a while, so I think going for a workaround like this is best. The reason it works in Torch up to 1.9 is that they used to do the casting internally: pytorch/pytorch#67081 (comment)

I would recommend casting to uint8 only when the PyTorch version is detected as 1.10 or newer: that leaves performance untouched for earlier versions while at least restoring functionality for PyTorch 1.10 and over, and lets performance be what it is on those versions. Something along the lines of the sketch below.
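
Just a sketch of the idea; sign_buffer is a hypothetical helper name for illustration, the real call sites are the two nccl.py lines quoted above, and the version check could live wherever is most convenient:

import torch

# Keep bool below torch 1.10 (where torch casted to uint8 internally before
# handing the tensor to dlpack), use uint8 from 1.10 on.
TORCH_MAJOR, TORCH_MINOR = (int(v) for v in torch.__version__.split(".")[:2])
SIGN_DTYPE = torch.uint8 if (TORCH_MAJOR, TORCH_MINOR) >= (1, 10) else torch.bool

def sign_buffer(buffer_m: torch.Tensor) -> torch.Tensor:
    # Same transform as the two call sites in nccl.py, with the final dtype
    # picked per torch version; for bool the extra .to() is a no-op.
    return buffer_m.sign_().add_(1).bool().to(dtype=SIGN_DTYPE)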

@conglongli
I see. Yes, we will investigate this on our side. But please understand that because we need to test both performance and convergence, and because of bandwidth limitations, this will take some time. Until then I would recommend using an older PyTorch if possible.

@conglongli
Hi @jhoareau, I created a PR based on your suggestion (#1894), and I did test that torch 1.10 with this fix provides the same performance benefit and the same training loss curve on BERT pretraining as torch 1.8 without the fix. Could you check whether this PR fixes your issue?

@jhoareau
Hi @conglongli the PR indeed fixes the issue. Thanks a lot for the quick PR turnaround!

@conglongli
Thanks for confirming, @jhoareau, will merge the PR then.
