
Don't change device if input_data is given #236

Merged · 7 commits merged into TylerYep:main on Mar 13, 2023

Conversation

@snimu (Contributor) commented Mar 12, 2023

The device-setting logic in this PR works as follows (see the sketch after this list):

  1. `device` is given by the user → move the model and the data to that device.
  2. `device` is not given by the user. One of two possibilities:
     a. `input_data` is given by the user → neither the model nor the data is moved.
     b. `input_data` is not given by the user → the previous logic applies (infer the device from the model; if that fails or the model is on the CPU, move to the GPU if one is available, else use the CPU).
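
Below is a minimal sketch of that decision logic. The helper name `resolve_device` and its exact structure are illustrative assumptions, not the torchinfo implementation:

    from __future__ import annotations

    import torch
    from torch import nn

    def resolve_device(
        device: torch.device | None,
        input_data: object | None,
        model: nn.Module,
    ) -> torch.device | None:
        # Hypothetical helper, for illustration only.
        if device is not None:
            return device  # case 1: the user's choice wins; model and data move here
        if input_data is not None:
            return None    # case 2a: move nothing; use model and data where they are
        # Case 2b: previous behavior -- infer the device from the model's
        # parameters; if that fails or lands on the CPU, prefer CUDA when available.
        try:
            inferred = next(model.parameters()).device
        except StopIteration:
            inferred = torch.device("cpu")
        if inferred.type == "cpu" and torch.cuda.is_available():
            return torch.device("cuda")
        return inferred

Returning None here means "leave everything in place", which is what keeps a CPU model on the CPU when only `input_data` is passed.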

@codecov bot commented Mar 12, 2023

Codecov Report

Merging #236 (250f4a5) into main (4847263) will decrease coverage by 0.14%.
The diff coverage is 92.30%.

@@            Coverage Diff             @@
##             main     #236      +/-   ##
==========================================
- Coverage   97.63%   97.50%   -0.14%     
==========================================
  Files           6        6              
  Lines         634      640       +6     
==========================================
+ Hits          619      624       +5     
- Misses         15       16       +1     
Impacted Files            Coverage Δ
torchinfo/torchinfo.py    97.20% <92.30%> (-0.35%) ⬇️


@TylerYep linked an issue on Mar 12, 2023 that may be closed by this pull request.
Comment on lines 250 to 251:

    if device is None:  # should never happen; `get_device` should prevent it
        raise RuntimeError("`device` is None.")

@TylerYep (Owner): If this will never happen, then just use `assert device is not None`.
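
That is, the two-line guard becomes a single assertion:

    assert device is not None  # `get_device` should prevent this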

Comment on the diff in `process_input`:

          dtypes: list[torch.dtype] | None = None,
      ) -> tuple[CORRECTED_INPUT_DATA_TYPE, Any]:
          """Reads sample input data to get the input size."""
          x = None
          correct_input_size = []
          if input_data is not None:
              correct_input_size = get_input_data_sizes(input_data)
    -         x = set_device(input_data, device)
    +         x = input_data if device is None else set_device(input_data, device)

@TylerYep (Owner): Move this logic to `set_device`.
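
A sketch of what moving that check into `set_device` could look like (assumed signature and container handling, not necessarily the code that was merged):

    from __future__ import annotations

    from typing import Any

    import torch

    def set_device(data: Any, device: torch.device | str | None) -> Any:
        """Return `data` unchanged when no device was requested; otherwise
        move tensors (recursing through common containers) to `device`."""
        if device is None:
            return data  # input_data was given without a device: leave it in place
        if hasattr(data, "to"):  # tensors and modules
            return data.to(device)
        if isinstance(data, (list, tuple)):  # plain sequences; namedtuples would need special-casing
            return type(data)(set_device(d, device) for d in data)
        if isinstance(data, dict):
            return {k: set_device(v, device) for k, v in data.items()}
        return data

With the `None` check inside `set_device`, the call site in `process_input` collapses back to `x = set_device(input_data, device)`.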

@TylerYep (Owner) commented:

See https://github.com/TylerYep/torchinfo/tree/snimu/main for the fixes for the changes requested, thanks!

snimu added 2 commits on March 13, 2023 at 09:25:

1. Move logic for setting device in `process_input` into `set_device` if `input_data` is given
2. Replace exception that is likely never raised with assertion (also in `process_input`)
@snimu (Contributor, Author) commented Mar 13, 2023

Sorry, I thought the link you gave was to the normal torchinfo repo, so I fixed the code without looking at it. My version looks slightly different from yours, but I think it is just as readable (and it implements exactly the same logic, including your review comments). Do you want me to change it to your version, or is mine OK?

@TylerYep (Owner) commented:

Looks good to me; I can fix small nits after it's merged.

Thanks for the fix!

@TylerYep merged commit ef9daef into TylerYep:main on Mar 13, 2023.
@snimu (Contributor, Author) commented Mar 13, 2023

This should also close issue #224 (model parallelism).

Successfully merging this pull request may close these issues:

After calling summary(), model on CPU is pushed to GPU