-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't change device if input_data is given #236
Conversation
Codecov Report
@@ Coverage Diff @@
## main #236 +/- ##
==========================================
- Coverage 97.63% 97.50% -0.14%
==========================================
Files 6 6
Lines 634 640 +6
==========================================
+ Hits 619 624 +5
- Misses 15 16 +1
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
torchinfo/torchinfo.py
Outdated
if device is None: # should never happen; 'get_device' should prevent it | ||
raise RuntimeError("`device` is None.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this will never happen then just use assert device is not None
torchinfo/torchinfo.py
Outdated
dtypes: list[torch.dtype] | None = None, | ||
) -> tuple[CORRECTED_INPUT_DATA_TYPE, Any]: | ||
"""Reads sample input data to get the input size.""" | ||
x = None | ||
correct_input_size = [] | ||
if input_data is not None: | ||
correct_input_size = get_input_data_sizes(input_data) | ||
x = set_device(input_data, device) | ||
x = input_data if device is None else set_device(input_data, device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this logic to set_device
See https://github.com/TylerYep/torchinfo/tree/snimu/main for the fixes for the changes requested, thanks! |
1. Move logic for setting device in `process_input` into `set_device` if `input_data` is given 2. Replace exception that is likely never raised with assertion (also in `process_input`)
Sorry, I thought the link you gave is to the normal torchinfo repo, and fixed the code without looking at it. It looks slightly different than your version, but just as readable I think (and exactly the same logic, of course, and implementing your comments). Do you want me to change it to your version or is mine ok? |
Looks good to me, I can fix small nits after it's merged. Thanks for the fix! |
This should also close issue #224 (model parallelism) |
The way that the logic of device-setting works in this PR is the following:
device
is given by the user → Move model and data to that devicedevice
is not given by the user. One of two possibilities:a.
input_data
is given by the user → neither model nor data are movedb.
input_data
is not given by the user → previous logic applies (try to find out device from model, if this fails or is on CPU, try to move to GPU, else use CPU)