
DLPack no longer works on Boolean tensors after 1.10+ #67081

Closed
BarclayII opened this issue Oct 22, 2021 · 10 comments
Labels: `module: boolean tensor`, `module: dlpack`, `module: docs`, `triaged`


BarclayII commented Oct 22, 2021

🐛 Bug

torch.utils.dlpack.to_dlpack no longer works for Boolean tensors.

To Reproduce

```python
import torch
import torch.utils.dlpack

# PyTorch 1.10 RC / nightly raises: "Bool type is not supported by dlpack"
torch.utils.dlpack.to_dlpack(torch.BoolTensor([False, True]))
```

Expected behavior

Should work.

Environment

  • PyTorch Version (e.g., 1.0): 1.10 RC and nightly
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8.3
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

According to dmlc/dgl#3406, the error has been raised since #57110. What was the reason behind removing boolean support?

This is blocking DGL's patch release with PyTorch 1.10 since we rely on DLPack to interact with PyTorch tensors in multiple places.

+@jermainewang @VoVAllen

cc @brianjo @mruberry @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser

mrshenli added the `module: boolean tensor` and `triaged` labels on Oct 24, 2021
Collaborator

mruberry commented Nov 5, 2021

cc @emcastillo @rgommers

Collaborator

mruberry commented Nov 5, 2021

So I think DLPack doesn't support boolean tensors, and the workaround for this is to first convert a tensor to uint8 before exporting it. @BarclayII, would that work for you?
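A minimal sketch of that workaround, assuming only `torch.utils.dlpack.to_dlpack` / `from_dlpack`: the sender casts bool to uint8 before export, and the receiver casts back with `.bool()`. The helper names `export_for_dlpack` and `import_mask` are hypothetical (not part of PyTorch or DGL). Note that the receiver must know out of band that the data was originally bool, because the capsule only records uint8.

```python
import torch
import torch.utils.dlpack


def export_for_dlpack(t: torch.Tensor):
    """Hypothetical helper: export a tensor as a DLPack capsule,
    casting bool to uint8 first (False -> 0, True -> 1)."""
    if t.dtype == torch.bool:
        t = t.to(torch.uint8)
    return torch.utils.dlpack.to_dlpack(t)


def import_mask(capsule) -> torch.Tensor:
    """Hypothetical helper: import a capsule that was produced from a
    bool tensor and restore the bool dtype (nonzero -> True)."""
    t = torch.utils.dlpack.from_dlpack(capsule)
    return t.bool()


mask = torch.tensor([False, True, True])
restored = import_mask(export_for_dlpack(mask))
print(restored.dtype, restored.tolist())  # torch.bool [False, True, True]
```

Keeping the cast explicit on both sides makes the dtype change visible in the calling code instead of hiding it inside `to_dlpack`.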

If I'm correct, the alternative would be that PyTorch does this bool -> uint8 conversion for the caller, but that may hide this unexpected behavior and lead to confusion when users reimported the same tensor and it was now in uint8.
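To make that confusion concrete, the sketch below simulates an implicit bool -> uint8 conversion with an explicit cast (an illustration only, not how PyTorch implements export). Operators that are logical on bool tensors, such as `~`, become bitwise on the re-imported uint8 tensor:

```python
import torch
import torch.utils.dlpack

mask = torch.tensor([False, True])

# Simulate the implicit bool -> uint8 conversion on export.
capsule = torch.utils.dlpack.to_dlpack(mask.to(torch.uint8))
reimported = torch.utils.dlpack.from_dlpack(capsule)

# On the original bool tensor, ~ is logical NOT...
print((~mask).tolist())        # [True, False]
# ...but on the re-imported uint8 tensor it is bitwise NOT.
print((~reimported).tolist())  # [255, 254]
```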

Collaborator

rgommers commented Nov 5, 2021

> So I think DLPack doesn't support boolean tensors

Yes indeed, see dmlc/dlpack#75 for discussion.

> torch.utils.dlpack.to_dlpack no longer works for Boolean tensors.

This implies it did work before - what was it doing there exactly? Did it roundtrip correctly? What happened when you'd read in such an exported bool tensor with CuPy or JAX?

Collaborator

mruberry commented Nov 5, 2021

Historically (as of PyTorch 1.9), PyTorch would export a bool tensor as uint8, like this:

```python
import torch
import torch.utils.dlpack

capsule = torch.utils.dlpack.to_dlpack(torch.tensor([False, True]))
torch.utils.dlpack.from_dlpack(capsule)
# PyTorch 1.9 output: tensor([0, 1], dtype=torch.uint8)
```

So no, it wouldn't round-trip. It was equivalent to the proposed workaround (first cast the tensor to uint8).

Collaborator

rgommers commented Nov 5, 2021

Then I would say that raising an exception as in 1.10 is the desired behavior.

> If I'm correct, the alternative would be that PyTorch does this bool -> uint8 conversion for the caller, but that may hide this unexpected behavior and lead to confusion when users reimported the same tensor and it was now in uint8.

This would also become an issue if DLPack resolves the issue I linked above and implements a bool dtype. That would likely force a backwards-incompatible change in the PyTorch implementation.

Collaborator

mruberry commented Nov 7, 2021

I agree with your thinking (as usual), @rgommers.

Unfortunately we did make an unexpected BC-breaking change by clarifying this DLPack behavior, and we're sorry that's so disruptive @BarclayII. It is probably the "right" change from a UX perspective, however.

What are your thoughts, @BarclayII?

@BarclayII
Author

I understand. As for DGL, we have temporarily worked around it, so this is no longer a major blocker.

One further question, though. Since PyTorch is deprecating the use of ByteTensors to index into an array, I'm wondering whether PyTorch plans to remove this support in the next release. If so, I think it would be better to enable boolean type support in DLPack.

Collaborator

mruberry commented Nov 8, 2021

> I understand. As for DGL, we have temporarily worked around it, so this is no longer a major blocker.

Glad to hear it!

> One further question, though. Since PyTorch is deprecating the use of ByteTensors to index into an array, I'm wondering whether PyTorch plans to remove this support in the next release. If so, I think it would be better to enable boolean type support in DLPack.

Removing support for what, exactly?

DLPack has an issue for this already that you may want to comment on, too: dmlc/dlpack#75.

@BarclayII
Author

> Removing support for what, exactly?

I meant removing the support of using ByteTensors for masked indexing.
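For code that receives masks through DLPack, one way to stay off the deprecated path is to convert the imported uint8 tensor back to bool before using it as an index. A minimal sketch with made-up example values; it relies only on the fact, mentioned in this thread, that uint8 mask indexing is deprecated (PyTorch emits a warning for it):

```python
import torch
import torch.utils.dlpack

values = torch.tensor([10, 20, 30])
mask = torch.tensor([True, False, True])

# A mask exported as uint8 comes back from DLPack as uint8.
byte_mask = torch.utils.dlpack.from_dlpack(
    torch.utils.dlpack.to_dlpack(mask.to(torch.uint8))
)

# Convert back to bool before masked indexing instead of indexing
# with the uint8 (ByteTensor) mask directly.
selected = values[byte_mask.bool()]
print(selected.tolist())  # [10, 30]
```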

Collaborator

mruberry commented Nov 8, 2021

> > Removing support for what, exactly?

> I meant removing the support of using ByteTensors for masked indexing.

Aha, I don't think anyone is actively working on removing that support.

That's a good connection to make, however.


4 participants