apply_func.apply_to_collection force-updates its return type. #196

Closed
GdoongMathew opened this issue Nov 21, 2023 · 4 comments · Fixed by #210
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@GdoongMathew
Contributor

🐛 Bug

While using torchmetrics, I implemented a custom classification metric that returns a namedtuple. Up to lightning_utilities 0.9.0 this worked, because apply_func.apply_to_collection did not change the type of the returned namedtuple instance. Since this behavior changed in 0.10.0, I'd like to propose updating apply_to_collection so that it keeps the original input type.

To Reproduce

Steps to reproduce the behavior...

>>> from torchmetrics.utilities.data import _squeeze_scalar_element_tensor, _squeeze_if_scalar
>>> from collections import namedtuple
>>> import torch
>>> State = namedtuple("State", ["gt", "tp", "fp", "tn", "fn"])
>>> state = State(torch.tensor(1), torch.tensor(1), torch.tensor(1), torch.tensor(1), torch.tensor(1))
>>> x = _squeeze_if_scalar(state)
>>> x
(tensor(1), tensor(1), tensor(1), tensor(1), tensor(1))

Expected behavior

>>> x = _squeeze_if_scalar(state)
>>> x
State(gt=tensor(1), tp=tensor(1), fp=tensor(1), tn=tensor(1), fn=tensor(1))
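
For reference, a minimal sketch that reproduces the same type loss directly with apply_to_collection (assuming lightning_utilities 0.10.0), without going through torchmetrics:

from collections import namedtuple

import torch
from lightning_utilities.core.apply_func import apply_to_collection

State = namedtuple("State", ["gt", "tp", "fp", "tn", "fn"])
state = State(*(torch.tensor(1) for _ in range(5)))

# a no-op-like function on scalars is enough to show the container type changing
out = apply_to_collection(state, torch.Tensor, lambda t: t.squeeze())
print(type(out))  # reportedly <class 'tuple'> on 0.10.0, while State is expected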

Additional context

Environment details
  • lightning_utilities: 0.10.0

Proposal

from typing import Any, Callable, Optional, Tuple, Union


def apply_to_collection(
    data: Any,
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
    include_none: bool = True,
    allow_frozen: bool = False,
    **kwargs: Any,
) -> Any:
    if include_none is False or wrong_dtype is not None or allow_frozen is True:
        # not worth implementing these on the fast path: go with the slower option
        return _apply_to_collection_slow(
            data,
            dtype,
            function,
            *args,
            wrong_dtype=wrong_dtype,
            include_none=include_none,
            allow_frozen=allow_frozen,
            **kwargs,
        )
    # fast path for the most common cases:
    if isinstance(data, dtype):  # single element
        return function(data, *args, **kwargs)
    # remember the concrete container class so namedtuples, dict subclasses, etc. are preserved
    ori_class = data.__class__
    if isinstance(data, list) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous list
        return ori_class(function(x, *args, **kwargs) for x in data)
    if isinstance(data, tuple) and all(isinstance(x, dtype) for x in data):  # 1d homogeneous tuple
        elems = [function(x, *args, **kwargs) for x in data]
        # namedtuples take positional fields; plain tuples take a single iterable
        return ori_class(*elems) if hasattr(data, "_fields") else ori_class(elems)
    if isinstance(data, dict) and all(isinstance(x, dtype) for x in data.values()):  # 1d homogeneous dict
        # note: keyword expansion assumes string keys
        return ori_class(**{k: function(v, *args, **kwargs) for k, v in data.items()})
    # anything else: fall back to the slower recursive implementation
    return _apply_to_collection_slow(
        data,
        dtype,
        function,
        *args,
        wrong_dtype=wrong_dtype,
        include_none=include_none,
        allow_frozen=allow_frozen,
        **kwargs,
    )
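
For illustration, a minimal sketch of a test for the namedtuple case under this proposal (the test name and values are made up here, not taken from the eventual fix):

import torch
from collections import namedtuple

from lightning_utilities.core.apply_func import apply_to_collection


def test_apply_to_collection_preserves_namedtuple():
    State = namedtuple("State", ["tp", "fp"])
    state = State(torch.tensor([1]), torch.tensor([2]))
    out = apply_to_collection(state, torch.Tensor, lambda t: t * 2)
    # the result should still be a State, not a plain tuple
    assert isinstance(out, State)
    assert torch.equal(out.tp, torch.tensor([2]))
    assert torch.equal(out.fp, torch.tensor([4]))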
GdoongMathew added the bug and help wanted labels on Nov 21, 2023
@shubhodeepMitra

Hi, I am open to working on it. Is there any timeframe for this?

@carmocca
Contributor

@shubhodeepMitra See my comment here about what I suggest: #199 (comment)

@GdoongMathew
Contributor Author

Hi @carmocca
Do you think the proposal mentioned above is workable?
I think the only thing the 0.10.0 modification to apply_to_collection lacks is keeping the original input type; other than that, everything looks fine to me.

@carmocca
Contributor

I made the change that caused this issue to try to have torch.compile not graph break when it encounters this function for some simple datatypes.

If your proposal introduces a graph break, it would defeat the point of the fast vs. slow paths. I haven't tested whether it does.

Since avoiding a graph break is no longer that important to me, I suggest we revert to the old implementation from before #160. That should be the simpler and safer option.

However, if you open a PR with your suggestion, including a test for the namedtuple case and a script showing that apply_to_collection still doesn't graph break on the basic types, I will merge it without issues.
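
For reference, one way such a graph-break check could look is a small script compiled with fullgraph=True, which makes torch.compile raise instead of silently graph breaking. This is only an illustrative sketch, not the script from the eventual PR:

import torch
from lightning_utilities.core.apply_func import apply_to_collection


def fn(x: torch.Tensor) -> torch.Tensor:
    # basic container types that should stay on the fast path
    data = {"a": x, "b": [x, x], "c": (x, x)}
    out = apply_to_collection(data, torch.Tensor, lambda t: t * 2)
    return out["a"] + out["b"][0] + out["c"][1]


# fullgraph=True raises on any graph break instead of falling back to eager
compiled = torch.compile(fn, fullgraph=True)
print(compiled(torch.ones(2)))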

GdoongMathew added a commit to GdoongMathew/utilities that referenced this issue Dec 18, 2023