Don't change device if input_data is given (#236)
* Only change device if no input-data is given

* Add .out-file for test_device_parallelism

* Move test_device_parallelism to gpu_test.py

* Assert devices unchanged

* Move input_data to device if it is given by user

* Fix small issues:

1. Move logic for setting device in `process_input` into `set_device` if `input_data` is given
2. Replace exception that is likely never raised with assertion (also in `process_input`)

* Remove unnecessary comment
snimu authored Mar 13, 2023
1 parent 4847263 commit ef9daef
Showing 4 changed files with 78 additions and 12 deletions.
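
Before the file-by-file diff, a minimal sketch of the behavior this commit enables, mirroring the test added in tests/gpu_test.py below (assumes a CUDA-capable machine and the test fixtures from this repository):

import torch
from torchinfo import summary
from tests.fixtures.models import MultiDeviceModel

# net1 is created on the CPU, net2 on the GPU.
model = MultiDeviceModel("cpu", "cuda")

# Because input_data is supplied, summary() no longer moves the model:
# each submodule stays on the device it was created on.
summary(model, input_data=torch.randn(10))

assert not next(model.net1.parameters()).is_cuda  # still on CPU
assert next(model.net2.parameters()).is_cuda      # still on CUDA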
24 changes: 24 additions & 0 deletions tests/fixtures/models.py
@@ -861,3 +861,27 @@ def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
         hx = self.projection(hx)
         hx = self.activation(hx)
         return hx
+
+
+class MultiDeviceModel(nn.Module):
+    """
+    A model living on several devices.
+    Follows the ToyModel from the Tutorial on parallelism:
+    https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
+    """
+
+    def __init__(
+        self, device1: torch.device | str, device2: torch.device | str
+    ) -> None:
+        super().__init__()
+        self.device1 = device1
+        self.device2 = device2
+
+        self.net1 = torch.nn.Linear(10, 10).to(device1)
+        self.relu = torch.nn.ReLU()
+        self.net2 = torch.nn.Linear(10, 5).to(device2)
+
+    def forward(self, x: torch.Tensor) -> Any:
+        x = self.relu(self.net1(x.to(self.device1)))
+        return self.net2(x.to(self.device2))
13 changes: 12 additions & 1 deletion tests/gpu_test.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from tests.fixtures.models import SingleInputNet
+from tests.fixtures.models import MultiDeviceModel, SingleInputNet
 from torchinfo import summary


@@ -75,3 +75,14 @@ def test_different_model_parts_on_different_devices() -> None:
         torch.nn.Linear(10, 10).to(1), torch.nn.Linear(10, 10).to(0)
     )
     summary(model)
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Need CUDA to test parallelism."
+)
+def test_device_parallelism() -> None:
+    model = MultiDeviceModel("cpu", "cuda")
+    input_data = torch.randn(10)
+    summary(model, input_data=input_data)
+    assert not next(model.net1.parameters()).is_cuda
+    assert next(model.net2.parameters()).is_cuda
18 changes: 18 additions & 0 deletions tests/test_output/device_parallelism.out
@@ -0,0 +1,18 @@
+==========================================================================================
+Layer (type:depth-idx)                   Output Shape              Param #
+==========================================================================================
+MultiDeviceModel                         [5]                       --
+├─Linear: 1-1                            [10]                      110
+├─ReLU: 1-2                              [10]                      --
+├─Linear: 1-3                            [5]                       55
+==========================================================================================
+Total params: 165
+Trainable params: 165
+Non-trainable params: 0
+Total mult-adds (M): 0.00
+==========================================================================================
+Input size (MB): 0.00
+Forward/backward pass size (MB): 0.00
+Params size (MB): 0.00
+Estimated Total Size (MB): 0.00
+==========================================================================================
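
As a sanity check on the expected output above: a Linear(in_features, out_features) layer holds in_features * out_features weights plus out_features biases, which accounts for every number in the Param # column. A quick check (not part of the commit):

# Param counts in device_parallelism.out, derived from the layer shapes:
net1_params = 10 * 10 + 10  # Linear(10, 10): weights + biases = 110
net2_params = 10 * 5 + 5    # Linear(10, 5):  weights + biases = 55
assert net1_params == 110
assert net2_params == 55
assert net1_params + net2_params == 165  # "Total params: 165"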
35 changes: 24 additions & 11 deletions torchinfo/torchinfo.py
@@ -207,7 +207,7 @@ class name as the key. If the forward pass is an expensive operation,
     cache_forward_pass = False
 
     if device is None:
-        device = get_device(model)
+        device = get_device(model, input_data)
     elif isinstance(device, str):
         device = torch.device(device)
@@ -234,7 +234,7 @@ def process_input(
     input_data: INPUT_DATA_TYPE | None,
     input_size: INPUT_SIZE_TYPE | None,
     batch_dim: int | None,
-    device: torch.device,
+    device: torch.device | None,
     dtypes: list[torch.dtype] | None = None,
 ) -> tuple[CORRECTED_INPUT_DATA_TYPE, Any]:
     """Reads sample input data to get the input size."""
@@ -247,6 +247,7 @@
             x = [x]
 
     if input_size is not None:
+        assert device is not None
         if dtypes is None:
             dtypes = [torch.float] * len(input_size)
         correct_input_size = get_correct_input_sizes(input_size)
@@ -259,7 +260,7 @@ def forward_pass(
     x: CORRECTED_INPUT_DATA_TYPE,
     batch_dim: int | None,
     cache_forward_pass: bool,
-    device: torch.device,
+    device: torch.device | None,
     mode: Mode,
     **kwargs: Any,
 ) -> list[LayerInfo]:
@@ -274,7 +275,7 @@
         set_children_layers(summary_list)
         return summary_list
 
-    kwargs = set_device(kwargs, device)
+    kwargs = set_device(kwargs, device) if device is not None else kwargs
     saved_model_mode = model.training
     try:
         if mode == Mode.TRAIN:
@@ -287,10 +288,11 @@
             )
 
         with torch.no_grad():
+            model = model.to(device) if device is not None else model
             if isinstance(x, (list, tuple)):
-                _ = model.to(device)(*x, **kwargs)
+                _ = model(*x, **kwargs)
             elif isinstance(x, dict):
-                _ = model.to(device)(**x, **kwargs)
+                _ = model(**x, **kwargs)
             else:
                 # Should not reach this point, since process_input_data ensures
                 # x is either a list, tuple, or dict
@@ -368,7 +370,7 @@ def validate_user_params(
     input_size: INPUT_SIZE_TYPE | None,
     col_names: tuple[ColumnSettings, ...],
     col_width: int,
-    device: torch.device,
+    device: torch.device | None,
     dtypes: list[torch.dtype] | None,
     verbose: int,
 ) -> None:
@@ -400,7 +402,7 @@
"output incorrect results. Try passing input_data directly.",
stacklevel=2,
)
if device.type == "cpu":
if device is not None and device.type == "cpu":
warnings.warn(
"Half precision is not supported on cpu. Set the `device` field or "
"pass `input_data` using the correct device.",
@@ -449,20 +451,31 @@ def traverse_input_data(
     return result
 
 
-def set_device(data: Any, device: torch.device) -> Any:
+def set_device(data: Any, device: torch.device | None) -> Any:
     """Sets device for all input types and collections of input types."""
+    if device is None:
+        return data
+
     return traverse_input_data(
         data,
         action_fn=lambda data: data.to(device, non_blocking=True),
         aggregate_fn=type,
     )
 
 
-def get_device(model: nn.Module) -> torch.device:
+def get_device(
+    model: nn.Module, input_data: INPUT_DATA_TYPE | None
+) -> torch.device | None:
     """
-    Gets device of first parameter of model and returns it if it is on cuda,
+    If input_data is given, the device should not be changed
+    (to allow for multi-device models, etc.)
+    Otherwise gets device of first parameter of model and returns it if it is on cuda,
     otherwise returns cuda if available or cpu if not.
     """
+    if input_data is not None:
+        return None
+
     try:
         model_parameter = next(model.parameters())
     except StopIteration:
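
Taken together, the torchinfo.py changes thread an optional device through the whole pipeline: get_device returns None whenever the user supplies input_data, and every downstream consumer (set_device, forward_pass, validate_user_params) treats None as "leave everything where it is". A condensed sketch of the resulting control flow, with simplified signatures rather than the full torchinfo source:

import torch
from torch import nn
from typing import Any

def get_device(model: nn.Module, input_data: Any | None) -> torch.device | None:
    # input_data given: return None so the user's placement is kept
    # (this is what lets multi-device models pass through untouched).
    if input_data is not None:
        return None
    # Otherwise: the first parameter's device if it is on cuda,
    # else cuda if available, else cpu.
    try:
        device = next(model.parameters()).device
        if device.type == "cuda":
            return device
    except StopIteration:
        pass
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def forward_pass(model: nn.Module, x: tuple[Any, ...], device: torch.device | None) -> None:
    # None means "do not touch": neither the model nor the inputs move.
    model = model.to(device) if device is not None else model
    with torch.no_grad():
        _ = model(*x)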
