Add trainable_params to layer_info
TylerYep committed Nov 24, 2021
1 parent 0548f78 commit 483e9c6
Showing 6 changed files with 35 additions and 6 deletions.
12 changes: 12 additions & 0 deletions fixtures/models.py
@@ -469,3 +469,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


class MixedTrainableParameters(nn.Module):
"""Model with trainable and non-trainable parameters in the same layer."""

def __init__(self) -> None:
super().__init__()
self.w = nn.Parameter(torch.empty(10), requires_grad=True)
self.b = nn.Parameter(torch.empty(10), requires_grad=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w * x
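
The 20 total / 10 trainable split this fixture is meant to exercise can be checked with plain PyTorch, independent of torchinfo. A minimal sketch (not part of the commit; it assumes the fixture above is importable from fixtures.models, as the tests below do):

from fixtures.models import MixedTrainableParameters

model = MixedTrainableParameters()
# nelement() counts tensor elements; requires_grad marks a parameter as trainable
total = sum(p.nelement() for p in model.parameters())
trainable = sum(p.nelement() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 20 10
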
9 changes: 9 additions & 0 deletions tests/test_output/mixed_trainable_parameters.out
@@ -0,0 +1,9 @@
=================================================================
Layer (type:depth-idx) Param #
=================================================================
MixedTrainableParameters 20
=================================================================
Total params: 20
Trainable params: 10
Non-trainable params: 10
=================================================================
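
This expected output comes from summarizing the new fixture without input data, so only the Param # column appears. A rough usage sketch (assuming torchinfo at this commit):

from torchinfo import summary
from fixtures.models import MixedTrainableParameters

stats = summary(MixedTrainableParameters())  # prints a table like the one above
print(stats.trainable_params)  # 10
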
8 changes: 8 additions & 0 deletions tests/torchinfo_test.py
@@ -14,6 +14,7 @@
EmptyModule,
LinearModel,
LSTMNet,
MixedTrainableParameters,
ModuleDictModel,
MultipleInputNetDifferentDtypes,
NamedTuple,
@@ -482,3 +483,10 @@ def test_reusing_activation_layers() -> None:
result_2 = summary(model2)

assert len(result_1.summary_list) == len(result_2.summary_list) == 6


def test_mixed_trainable_parameters() -> None:
result = summary(MixedTrainableParameters())

assert result.trainable_params == 10
assert result.total_params == 20
2 changes: 1 addition & 1 deletion torchinfo/__init__.py
@@ -4,4 +4,4 @@
from .torchinfo import summary

__all__ = ("ModelStatistics", "summary", "ALL_COLUMN_SETTINGS", "ALL_ROW_SETTINGS")
__version__ = "1.5.3"
__version__ = "1.5.4"
7 changes: 4 additions & 3 deletions torchinfo/layer_info.py
@@ -39,7 +39,7 @@ def __init__(
self.is_leaf_layer = not any(self.module.children())

# Statistics
- self.trainable = True
+ self.trainable_params = 0
self.is_recursive = False
self.input_size: List[int] = []
self.output_size: List[int] = []
@@ -122,7 +122,8 @@ def calculate_num_params(self) -> None:
name = ""
for name, param in self.module.named_parameters():
self.num_params += param.nelement()
self.trainable &= param.requires_grad
if param.requires_grad:
self.trainable_params += param.nelement()

ksize = list(param.size())
if name == "weight":
@@ -194,7 +195,7 @@ def num_params_to_str(self, reached_max_depth: bool) -> str:
return "(recursive)"
if self.num_params > 0 and (reached_max_depth or self.is_leaf_layer):
param_count_str = f"{self.num_params:,}"
- return param_count_str if self.trainable else f"({param_count_str})"
+ return param_count_str if self.trainable_params else f"({param_count_str})"
return "--"


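Distilled from the change above: per layer, num_params now counts every parameter element while trainable_params counts only the elements with requires_grad=True. A standalone sketch of that counting pattern (not the torchinfo implementation itself):

from typing import Tuple

import torch.nn as nn


def count_layer_params(module: nn.Module) -> Tuple[int, int]:
    """Return (num_params, trainable_params) for a single module."""
    num_params = 0
    trainable_params = 0
    for _, param in module.named_parameters():
        num_params += param.nelement()
        if param.requires_grad:
            trainable_params += param.nelement()
    return num_params, trainable_params
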
3 changes: 1 addition & 2 deletions torchinfo/model_statistics.py
@@ -28,8 +28,7 @@ def __init__(
if layer_info.is_recursive:
continue
self.total_params += layer_info.num_params
- if layer_info.trainable:
-     self.trainable_params += layer_info.num_params
+ self.trainable_params += layer_info.trainable_params
if layer_info.num_params > 0:
# x2 for gradients
self.total_output += 2 * prod(layer_info.output_size)
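This also refines the model-level totals for mixed layers: under the old boolean flag (trainable &= param.requires_grad), a layer became non-trainable as soon as any one of its parameters had requires_grad=False, so MixedTrainableParameters would have contributed 0 trainable params; summing the per-layer trainable_params now credits it with 10. A sketch of the new roll-up, given a hypothetical summary_list of LayerInfo objects (mirroring the loop above, not the actual ModelStatistics code):

total_params = sum(info.num_params for info in summary_list if not info.is_recursive)
trainable_params = sum(
    info.trainable_params for info in summary_list if not info.is_recursive
)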
