Add "Trainable" column (#128)
* add is_trainable column

* test: testcase for is_trainable column

* model which has fully, partially, and non-trainable modules
* update tests that require all columns to display the is_trainable column as well

* docs: update README.md

* fix type ignores and nits

* Rename is_trainable to trainable

* Calculate trainable in pre_hook

* Fix readme

Co-authored-by: Tyler Yep <tyler.yep@robinhood.com>
bsridatta and TylerYep authored May 15, 2022
1 parent 793c4f5 commit 7def795
Showing 10 changed files with 106 additions and 27 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -115,7 +115,8 @@ Summarize the given PyTorch model. Summarized information includes:
2) input/output shapes,
3) kernel shape,
4) # of parameters,
5) # of operations (Mult-Adds)
5) # of operations (Mult-Adds),
6) whether layer is trainable
NOTE: If neither input_data or input_size are provided, no forward pass through the
network is performed, and the provided model information is limited to layer names.
@@ -166,6 +167,7 @@ Args:
"num_params",
"kernel_size",
"mult_adds",
"trainable",
)
Default: ("output_size", "num_params")
If input_data / input_size are not provided, only "num_params" is used.
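
For context, a minimal usage sketch (not part of this commit) showing the new column being requested; the model and shapes below are purely illustrative:

from torch import nn
from torchinfo import summary

# Illustrative model: freeze one layer so the new column reports "False" for it.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
for param in model[2].parameters():
    param.requires_grad = False

summary(
    model,
    input_size=(1, 8),
    col_names=("output_size", "num_params", "trainable"),
)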
26 changes: 26 additions & 0 deletions tests/fixtures/models.py
@@ -502,6 +502,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w * x + self.b


class MixedTrainable(nn.Module):
"""Model with fully, partial and non trainable modules."""

def __init__(self) -> None:
super().__init__()
self.fully_trainable = nn.Conv1d(1, 1, 1)

self.partially_trainable = nn.Conv1d(1, 1, 1, bias=True)
assert self.partially_trainable.bias is not None
self.partially_trainable.bias.requires_grad = False

self.non_trainable = nn.Conv1d(1, 1, 1, 1, bias=True)
self.non_trainable.weight.requires_grad = False
assert self.non_trainable.bias is not None
self.non_trainable.bias.requires_grad = False

self.dropout = nn.Dropout()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fully_trainable(x)
x = self.partially_trainable(x)
x = self.non_trainable(x)
x = self.dropout(x)
return x


class ReuseLinear(nn.Module):
"""Model that uses a reference to the same Linear layer over and over."""

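
As a quick sanity check (not part of the diff), the fixture's parameter counts line up with the expected trainable_column.out output further down: each Conv1d(1, 1, 1) holds two parameters (a 1x1x1 weight and a bias), so MixedTrainable has 6 parameters in total, 3 of which remain trainable. A sketch, with the import path assumed to match the test layout:

from tests.fixtures.models import MixedTrainable  # import path assumed

model = MixedTrainable()
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # expected: 6 3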
20 changes: 10 additions & 10 deletions tests/test_output/parameter_list.out
@@ -1,18 +1,18 @@
===================================================================================================================
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds
===================================================================================================================
ParameterListModel -- -- -- -- --
├─ParameterList: 1-1 -- -- -- 30,000 --
│ └─0 [100, 100] ├─10,000
│ └─1 [100, 200] └─20,000
===================================================================================================================
================================================================================================================================================================
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds Trainable
================================================================================================================================================================
ParameterListModel -- -- -- -- -- True
├─ParameterList: 1-1 -- -- -- 30,000 -- True
│ └─0 [100, 100] ├─10,000
│ └─1 [100, 200] └─20,000
================================================================================================================================================================
Total params: 30,000
Trainable params: 30,000
Non-trainable params: 0
Total mult-adds (M): 0.00
===================================================================================================================
================================================================================================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 0.00
Params size (MB): 0.12
Estimated Total Size (MB): 0.16
===================================================================================================================
================================================================================================================================================================
24 changes: 12 additions & 12 deletions tests/test_output/single_input_all_cols.out
@@ -1,20 +1,20 @@
============================================================================================================================================
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds
============================================================================================================================================
SingleInputNet -- -- -- -- --
├─Conv2d: 1-1 [5, 5] [7, 1, 28, 28] [7, 10, 24, 24] 260 1,048,320
├─Conv2d: 1-2 [5, 5] [7, 10, 12, 12] [7, 20, 8, 8] 5,020 2,248,960
├─Dropout2d: 1-3 -- [7, 20, 8, 8] [7, 20, 8, 8] -- --
├─Linear: 1-4 -- [7, 320] [7, 50] 16,050 112,350
├─Linear: 1-5 -- [7, 50] [7, 10] 510 3,570
============================================================================================================================================
================================================================================================================================================================
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds Trainable
================================================================================================================================================================
SingleInputNet -- -- -- -- -- True
├─Conv2d: 1-1 [5, 5] [7, 1, 28, 28] [7, 10, 24, 24] 260 1,048,320 True
├─Conv2d: 1-2 [5, 5] [7, 10, 12, 12] [7, 20, 8, 8] 5,020 2,248,960 True
├─Dropout2d: 1-3 -- [7, 20, 8, 8] [7, 20, 8, 8] -- -- --
├─Linear: 1-4 -- [7, 320] [7, 50] 16,050 112,350 True
├─Linear: 1-5 -- [7, 50] [7, 10] 510 3,570 True
================================================================================================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
============================================================================================================================================
================================================================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
============================================================================================================================================
================================================================================================================================================================
19 changes: 19 additions & 0 deletions tests/test_output/trainable_column.out
@@ -0,0 +1,19 @@
============================================================================================================================================
Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Trainable
============================================================================================================================================
MixedTrainable -- -- -- Partial
├─Conv1d: 1-1 [1] [1, 1, 1] [1, 1, 1] True
├─Conv1d: 1-2 [1] [1, 1, 1] [1, 1, 1] Partial
├─Conv1d: 1-3 [1] [1, 1, 1] [1, 1, 1] False
├─Dropout: 1-4 -- [1, 1, 1] [1, 1, 1] --
============================================================================================================================================
Total params: 6
Trainable params: 3
Non-trainable params: 3
Total mult-adds (M): 0.00
============================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
============================================================================================================================================
14 changes: 11 additions & 3 deletions tests/torchinfo_test.py
@@ -14,6 +14,7 @@
FakePrunedLayerModel,
LinearModel,
LSTMNet,
MixedTrainable,
MixedTrainableParameters,
ModuleDictModel,
MultipleInputNetDifferentDtypes,
@@ -111,13 +112,12 @@ def test_multiple_input_types() -> None:

def test_single_input_all_cols() -> None:
model = SingleInputNet()
col_names = ("kernel_size", "input_size", "output_size", "num_params", "mult_adds")
input_shape = (7, 1, 28, 28)
summary(
model,
input_data=torch.randn(*input_shape),
depth=1,
col_names=col_names,
col_names=list(ColumnSettings),
col_width=20,
)

@@ -194,7 +194,7 @@ def test_parameter_list() -> None:
input_size=(100, 100),
verbose=2,
col_names=list(ColumnSettings),
col_width=15,
col_width=20,
)


@@ -462,3 +462,11 @@ def test_pruned_adversary() -> None:
results = summary(second_model, input_size=(1,))

assert results.total_params == 32 # should be 64


def test_trainable_column() -> None:
summary(
MixedTrainable(),
input_size=(1, 1, 1),
col_names=("kernel_size", "input_size", "output_size", "trainable"),
)
1 change: 1 addition & 0 deletions torchinfo/enums.py
@@ -29,6 +29,7 @@ class ColumnSettings(str, Enum):
OUTPUT_SIZE = "output_size"
NUM_PARAMS = "num_params"
MULT_ADDS = "mult_adds"
TRAINABLE = "trainable"


@unique
2 changes: 2 additions & 0 deletions torchinfo/formatting.py
Expand Up @@ -12,6 +12,7 @@
ColumnSettings.OUTPUT_SIZE: "Output Shape",
ColumnSettings.NUM_PARAMS: "Param #",
ColumnSettings.MULT_ADDS: "Mult-Adds",
ColumnSettings.TRAINABLE: "Trainable",
}


@@ -113,6 +114,7 @@ def layer_info_to_row(
ColumnSettings.MULT_ADDS: layer_info.macs_to_str(
reached_max_depth, children_layers
),
ColumnSettings.TRAINABLE: self.str_(layer_info.trainable),
}
start_str = self.get_start_str(layer_info.depth)
layer_name = layer_info.get_layer_name(self.show_var_name, self.show_depth)
19 changes: 19 additions & 0 deletions torchinfo/layer_info.py
Expand Up @@ -59,6 +59,7 @@ def __init__(
self.param_bytes = 0
self.output_bytes = 0
self.macs = 0
self.trainable = self.is_trainable(module)

def __repr__(self) -> str:
return f"{self.class_name}: {self.depth}"
@@ -159,6 +160,24 @@ def get_kernel_size(module: nn.Module) -> int | list[int] | None:
return kernel_size
return None

@staticmethod
def is_trainable(module: nn.Module) -> str:
"""
Checks if the module is trainable. Returns:
"True", if all the parameters are trainable (`requires_grad=True`)
"False" if none of the parameters are trainable.
"Partial" if some weights are trainable, but not all.
"--" if no module has no parameters, like Dropout.
"""
module_requires_grad = [param.requires_grad for param in module.parameters()]
if not module_requires_grad:
return "--"
if all(module_requires_grad):
return "True"
if any(module_requires_grad):
return "Partial"
return "False"

def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str:
layer_name = self.class_name
if show_var_name and self.var_name:
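
The all/any check above can be exercised on its own; a minimal sketch (illustrative only, mirroring the method rather than calling it) covering the four possible statuses:

from torch import nn

def trainable_status(module: nn.Module) -> str:
    # Mirrors LayerInfo.is_trainable: "--" for parameter-free modules,
    # "True"/"False" when all/none of the parameters require grad, "Partial" otherwise.
    flags = [p.requires_grad for p in module.parameters()]
    if not flags:
        return "--"
    if all(flags):
        return "True"
    if any(flags):
        return "Partial"
    return "False"

conv = nn.Conv1d(1, 1, 1)
assert trainable_status(conv) == "True"
conv.bias.requires_grad = False
assert trainable_status(conv) == "Partial"
conv.weight.requires_grad = False
assert trainable_status(conv) == "False"
assert trainable_status(nn.Dropout()) == "--"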
4 changes: 3 additions & 1 deletion torchinfo/torchinfo.py
@@ -70,7 +70,8 @@ def summary(
2) input/output shapes,
3) kernel shape,
4) # of parameters,
5) # of operations (Mult-Adds)
5) # of operations (Mult-Adds),
6) whether layer is trainable
NOTE: If neither input_data or input_size are provided, no forward pass through the
network is performed, and the provided model information is limited to layer names.
@@ -121,6 +122,7 @@ class name as the key. If the forward pass is an expensive operation,
"num_params",
"kernel_size",
"mult_adds",
"trainable",
)
Default: ("output_size", "num_params")
If input_data / input_size are not provided, only "num_params" is used.
