apply_to_collection breaks named tuples (force casts to tuples) #206

Closed
artbataev opened this issue Dec 14, 2023 · 1 comment · Fixed by #210
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

@artbataev

artbataev commented Dec 14, 2023

🐛 Bug

lightning_utilities.core.apply_func.apply_to_collection breaks homogeneous named tuples (those whose elements all match dtype) by force-casting them to plain tuples. As a result, in Lightning, NamedTuple-based batches from the dataloader are converted to plain tuples.
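The underlying mechanism can be seen without Lightning or PyTorch. The following is a minimal sketch, assuming a hypothetical Batch NamedTuple whose float fields stand in for tensors: a NamedTuple instance is also a tuple, and rebuilding it with tuple(...) keeps the values but drops the NamedTuple type.

```python
from typing import NamedTuple


class Batch(NamedTuple):
    # Hypothetical stand-in for NamedTupleBatch; floats replace tensors so no torch is needed
    x: float
    y: float


batch = Batch(x=1.0, y=2.0)

# A NamedTuple instance passes isinstance(..., tuple), so it reaches the tuple branch
print(isinstance(batch, tuple))  # True

# Rebuilding with tuple(...) keeps the values but loses the NamedTuple type
rebuilt = tuple(v * 2 for v in batch)
print(type(rebuilt).__name__)  # tuple
print(hasattr(rebuilt, "x"))  # False
```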

To Reproduce

Reproducing the bug in Lightning requires creating a model and a dataloader that yields NamedTuple batches.

Reproducing the problem directly in Lightning Utilities is easier; see the snippet below:

```python
from lightning_utilities.core.apply_func import apply_to_collection
import torch
from typing import NamedTuple

class NamedTupleBatch(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor

def test_apply_to_collection():
    batch = NamedTupleBatch(x=torch.rand(10, 10), y=torch.rand(10, 10))
    assert isinstance(batch, NamedTupleBatch)  # before
    batch_out = apply_to_collection(batch, torch.Tensor, lambda x: x.to("cpu"))
    assert isinstance(batch_out, NamedTupleBatch)  # after - broken
```

Expected behavior

If the input is a NamedTuple, apply_to_collection should return an instance of the same NamedTuple type instead of force-casting it to a plain tuple.

Additional context

I think the problem was introduced in #160: Lightning Utilities 0.9.0 is fine, but 0.10.0 breaks named tuples.

I think the incorrect behavior is in the branch that handles 1d homogeneous tuples:

```python
if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous tuple
    return tuple(function(x, *args, **kwargs) for x in data)
```

Potential fix:

```python
if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous tuple
    if is_namedtuple(data):
        return type(data)(*[function(x, *args, **kwargs) for x in data])
    else:
        return tuple(function(x, *args, **kwargs) for x in data)
```

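A standalone sketch of the proposed fix follows, so it can be tried without installing Lightning Utilities. It assumes a local duck-typed is_namedtuple (re-implemented here; the library's own helper may differ in detail) and a hypothetical apply_homogeneous_tuple that models only the 1d homogeneous tuple branch, not the full recursive apply_to_collection.

```python
from typing import Any, Callable, NamedTuple


def is_namedtuple(obj: object) -> bool:
    # Local, duck-typed check (assumption): a tuple exposing the namedtuple helpers
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def apply_homogeneous_tuple(
    data: tuple, dtype: type, function: Callable[..., Any], *args: Any, **kwargs: Any
) -> tuple:
    # Hypothetical stand-in for the "1d homogeneous tuple" branch only
    if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):
        values = [function(x, *args, **kwargs) for x in data]
        if is_namedtuple(data):
            # NamedTuple constructors take one positional argument per field,
            # so the values are unpacked with * to rebuild the same type
            return type(data)(*values)
        # A plain tuple is built from a single iterable
        return tuple(values)
    # Mixed types and nested collections are out of scope for this sketch
    raise TypeError("sketch only handles 1d homogeneous tuples")


class Batch(NamedTuple):
    # Floats stand in for tensors
    x: float
    y: float


out = apply_homogeneous_tuple(Batch(x=1.0, y=2.0), float, lambda v, k: v * k, 3)
print(out)  # Batch(x=3.0, y=6.0)
print(apply_homogeneous_tuple((1.0, 2.0), float, lambda v: -v))  # (-1.0, -2.0)
```

The * unpacking is the key detail: type(data)(values) would work for a plain tuple subclass that accepts an iterable, but a NamedTuple constructor expects each field as a separate argument. The namedtuple classmethod type(data)._make(values) is an equivalent alternative.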
Environment details
  • PyTorch Version (e.g., 1.0): any (problem is not related to PyTorch)
  • OS (e.g., Linux): macOS, Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:
@artbataev artbataev added bug Something isn't working help wanted Extra attention is needed labels Dec 14, 2023
@artbataev
Author

Ah, I'm sorry, this seems to be a duplicate of #196. Anyway, please fix the issue :)
