From 793c4f57b6c2bb588bc8f394b90111cc55cc6cea Mon Sep 17 00:00:00 2001 From: Tyler Yep Date: Sun, 15 May 2022 14:23:48 -0700 Subject: [PATCH] Get kernel_size directly from attribute to avoid incorrect assumptions of var names --- setup.cfg | 2 +- tests/test_output/autoencoder.out | 30 ++++---- tests/test_output/frozen_layers.out | 72 ++++++++++---------- tests/test_output/lstm.out | 4 +- tests/test_output/single_input_all_cols.out | 8 +-- tests/test_output/single_input_batch_dim.out | 8 +-- tests/torchinfo_test.py | 10 ++- torchinfo/layer_info.py | 18 ++++- 8 files changed, 87 insertions(+), 65 deletions(-) diff --git a/setup.cfg b/setup.cfg index 14f7459..8b9519b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ disable = too-many-branches, too-many-locals, invalid-name, - line-too-long, + line-too-long, # Covered by flake8 no-member, fixme, duplicate-code, diff --git a/tests/test_output/autoencoder.out b/tests/test_output/autoencoder.out index e3cabf6..86cab54 100644 --- a/tests/test_output/autoencoder.out +++ b/tests/test_output/autoencoder.out @@ -1,23 +1,23 @@ -========================================================================================== -Layer (type:depth-idx) Output Shape Param # -========================================================================================== -AutoEncoder -- -- -├─Sequential: 1-1 [1, 16, 64, 64] -- -│ └─Conv2d: 2-1 [1, 16, 64, 64] 448 -│ └─ReLU: 2-2 [1, 16, 64, 64] -- -├─MaxPool2d: 1-2 [1, 16, 32, 32] -- -├─MaxUnpool2d: 1-3 [1, 16, 64, 64] -- -├─Sequential: 1-4 [1, 3, 64, 64] -- -│ └─Conv2d: 2-3 [1, 3, 64, 64] 435 -│ └─ReLU: 2-4 [1, 3, 64, 64] -- -========================================================================================== +=================================================================================================================== +Layer (type:depth-idx) Output Shape Param # Kernel Shape +=================================================================================================================== +AutoEncoder -- -- -- +├─Sequential: 1-1 [1, 16, 64, 64] -- -- +│ └─Conv2d: 2-1 [1, 16, 64, 64] 448 [3, 3] +│ └─ReLU: 2-2 [1, 16, 64, 64] -- -- +├─MaxPool2d: 1-2 [1, 16, 32, 32] -- 2 +├─MaxUnpool2d: 1-3 [1, 16, 64, 64] -- [2, 2] +├─Sequential: 1-4 [1, 3, 64, 64] -- -- +│ └─Conv2d: 2-3 [1, 3, 64, 64] 435 [3, 3] +│ └─ReLU: 2-4 [1, 3, 64, 64] -- -- +=================================================================================================================== Total params: 883 Trainable params: 883 Non-trainable params: 0 Total mult-adds (M): 3.62 -========================================================================================== +=================================================================================================================== Input size (MB): 0.05 Forward/backward pass size (MB): 0.62 Params size (MB): 0.00 Estimated Total Size (MB): 0.68 -========================================================================================== +=================================================================================================================== diff --git a/tests/test_output/frozen_layers.out b/tests/test_output/frozen_layers.out index c5a5293..d823d10 100644 --- a/tests/test_output/frozen_layers.out +++ b/tests/test_output/frozen_layers.out @@ -2,75 +2,75 @@ Layer (type:depth-idx) Output Shape Param # Kernel Shape Mult-Adds ============================================================================================================================================ ResNet -- -- -- -- 
-├─Conv2d: 1-1 [1, 64, 32, 32] (9,408) [3, 64, 7, 7] 9,633,792 -├─BatchNorm2d: 1-2 [1, 64, 32, 32] (128) [64] 128 +├─Conv2d: 1-1 [1, 64, 32, 32] (9,408) [7, 7] 9,633,792 +├─BatchNorm2d: 1-2 [1, 64, 32, 32] (128) -- 128 ├─ReLU: 1-3 [1, 64, 32, 32] -- -- -- -├─MaxPool2d: 1-4 [1, 64, 16, 16] -- -- -- +├─MaxPool2d: 1-4 [1, 64, 16, 16] -- 3 -- ├─Sequential: 1-5 [1, 64, 16, 16] -- -- -- │ └─BasicBlock: 2-1 [1, 64, 16, 16] -- -- -- -│ │ └─Conv2d: 3-1 [1, 64, 16, 16] (36,864) [64, 64, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-2 [1, 64, 16, 16] (128) [64] 128 +│ │ └─Conv2d: 3-1 [1, 64, 16, 16] (36,864) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-2 [1, 64, 16, 16] (128) -- 128 │ │ └─ReLU: 3-3 [1, 64, 16, 16] -- -- -- -│ │ └─Conv2d: 3-4 [1, 64, 16, 16] (36,864) [64, 64, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-5 [1, 64, 16, 16] (128) [64] 128 +│ │ └─Conv2d: 3-4 [1, 64, 16, 16] (36,864) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-5 [1, 64, 16, 16] (128) -- 128 │ │ └─ReLU: 3-6 [1, 64, 16, 16] -- -- -- │ └─BasicBlock: 2-2 [1, 64, 16, 16] -- -- -- -│ │ └─Conv2d: 3-7 [1, 64, 16, 16] (36,864) [64, 64, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-8 [1, 64, 16, 16] (128) [64] 128 +│ │ └─Conv2d: 3-7 [1, 64, 16, 16] (36,864) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-8 [1, 64, 16, 16] (128) -- 128 │ │ └─ReLU: 3-9 [1, 64, 16, 16] -- -- -- -│ │ └─Conv2d: 3-10 [1, 64, 16, 16] (36,864) [64, 64, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-11 [1, 64, 16, 16] (128) [64] 128 +│ │ └─Conv2d: 3-10 [1, 64, 16, 16] (36,864) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-11 [1, 64, 16, 16] (128) -- 128 │ │ └─ReLU: 3-12 [1, 64, 16, 16] -- -- -- ├─Sequential: 1-6 [1, 128, 8, 8] -- -- -- │ └─BasicBlock: 2-3 [1, 128, 8, 8] -- -- -- -│ │ └─Conv2d: 3-13 [1, 128, 8, 8] (73,728) [64, 128, 3, 3] 4,718,592 -│ │ └─BatchNorm2d: 3-14 [1, 128, 8, 8] (256) [128] 256 +│ │ └─Conv2d: 3-13 [1, 128, 8, 8] (73,728) [3, 3] 4,718,592 +│ │ └─BatchNorm2d: 3-14 [1, 128, 8, 8] (256) -- 256 │ │ └─ReLU: 3-15 [1, 128, 8, 8] -- -- -- -│ │ └─Conv2d: 3-16 [1, 128, 8, 8] (147,456) [128, 128, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-17 [1, 128, 8, 8] (256) [128] 256 +│ │ └─Conv2d: 3-16 [1, 128, 8, 8] (147,456) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-17 [1, 128, 8, 8] (256) -- 256 │ │ └─Sequential: 3-18 [1, 128, 8, 8] (8,448) -- 524,544 │ │ └─ReLU: 3-19 [1, 128, 8, 8] -- -- -- │ └─BasicBlock: 2-4 [1, 128, 8, 8] -- -- -- -│ │ └─Conv2d: 3-20 [1, 128, 8, 8] (147,456) [128, 128, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-21 [1, 128, 8, 8] (256) [128] 256 +│ │ └─Conv2d: 3-20 [1, 128, 8, 8] (147,456) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-21 [1, 128, 8, 8] (256) -- 256 │ │ └─ReLU: 3-22 [1, 128, 8, 8] -- -- -- -│ │ └─Conv2d: 3-23 [1, 128, 8, 8] (147,456) [128, 128, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-24 [1, 128, 8, 8] (256) [128] 256 +│ │ └─Conv2d: 3-23 [1, 128, 8, 8] (147,456) [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-24 [1, 128, 8, 8] (256) -- 256 │ │ └─ReLU: 3-25 [1, 128, 8, 8] -- -- -- ├─Sequential: 1-7 [1, 256, 4, 4] -- -- -- │ └─BasicBlock: 2-5 [1, 256, 4, 4] -- -- -- -│ │ └─Conv2d: 3-26 [1, 256, 4, 4] 294,912 [128, 256, 3, 3] 4,718,592 -│ │ └─BatchNorm2d: 3-27 [1, 256, 4, 4] 512 [256] 512 +│ │ └─Conv2d: 3-26 [1, 256, 4, 4] 294,912 [3, 3] 4,718,592 +│ │ └─BatchNorm2d: 3-27 [1, 256, 4, 4] 512 -- 512 │ │ └─ReLU: 3-28 [1, 256, 4, 4] -- -- -- -│ │ └─Conv2d: 3-29 [1, 256, 4, 4] 589,824 [256, 256, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-30 [1, 256, 4, 4] 512 [256] 512 +│ │ └─Conv2d: 3-29 [1, 256, 4, 4] 589,824 [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-30 [1, 256, 4, 4] 512 -- 512 │ │ └─Sequential: 3-31 [1, 256, 4, 4] 
33,280 -- 524,800 │ │ └─ReLU: 3-32 [1, 256, 4, 4] -- -- -- │ └─BasicBlock: 2-6 [1, 256, 4, 4] -- -- -- -│ │ └─Conv2d: 3-33 [1, 256, 4, 4] 589,824 [256, 256, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-34 [1, 256, 4, 4] 512 [256] 512 +│ │ └─Conv2d: 3-33 [1, 256, 4, 4] 589,824 [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-34 [1, 256, 4, 4] 512 -- 512 │ │ └─ReLU: 3-35 [1, 256, 4, 4] -- -- -- -│ │ └─Conv2d: 3-36 [1, 256, 4, 4] 589,824 [256, 256, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-37 [1, 256, 4, 4] 512 [256] 512 +│ │ └─Conv2d: 3-36 [1, 256, 4, 4] 589,824 [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-37 [1, 256, 4, 4] 512 -- 512 │ │ └─ReLU: 3-38 [1, 256, 4, 4] -- -- -- ├─Sequential: 1-8 [1, 512, 2, 2] -- -- -- │ └─BasicBlock: 2-7 [1, 512, 2, 2] -- -- -- -│ │ └─Conv2d: 3-39 [1, 512, 2, 2] 1,179,648 [256, 512, 3, 3] 4,718,592 -│ │ └─BatchNorm2d: 3-40 [1, 512, 2, 2] 1,024 [512] 1,024 +│ │ └─Conv2d: 3-39 [1, 512, 2, 2] 1,179,648 [3, 3] 4,718,592 +│ │ └─BatchNorm2d: 3-40 [1, 512, 2, 2] 1,024 -- 1,024 │ │ └─ReLU: 3-41 [1, 512, 2, 2] -- -- -- -│ │ └─Conv2d: 3-42 [1, 512, 2, 2] 2,359,296 [512, 512, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-43 [1, 512, 2, 2] 1,024 [512] 1,024 +│ │ └─Conv2d: 3-42 [1, 512, 2, 2] 2,359,296 [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-43 [1, 512, 2, 2] 1,024 -- 1,024 │ │ └─Sequential: 3-44 [1, 512, 2, 2] 132,096 -- 525,312 │ │ └─ReLU: 3-45 [1, 512, 2, 2] -- -- -- │ └─BasicBlock: 2-8 [1, 512, 2, 2] -- -- -- -│ │ └─Conv2d: 3-46 [1, 512, 2, 2] 2,359,296 [512, 512, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-47 [1, 512, 2, 2] 1,024 [512] 1,024 +│ │ └─Conv2d: 3-46 [1, 512, 2, 2] 2,359,296 [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-47 [1, 512, 2, 2] 1,024 -- 1,024 │ │ └─ReLU: 3-48 [1, 512, 2, 2] -- -- -- -│ │ └─Conv2d: 3-49 [1, 512, 2, 2] 2,359,296 [512, 512, 3, 3] 9,437,184 -│ │ └─BatchNorm2d: 3-50 [1, 512, 2, 2] 1,024 [512] 1,024 +│ │ └─Conv2d: 3-49 [1, 512, 2, 2] 2,359,296 [3, 3] 9,437,184 +│ │ └─BatchNorm2d: 3-50 [1, 512, 2, 2] 1,024 -- 1,024 │ │ └─ReLU: 3-51 [1, 512, 2, 2] -- -- -- ├─AdaptiveAvgPool2d: 1-9 [1, 512, 1, 1] -- -- -- -├─Linear: 1-10 [1, 1000] 513,000 [512, 1000] 513,000 +├─Linear: 1-10 [1, 1000] 513,000 -- 513,000 ============================================================================================================================================ Total params: 11,689,512 Trainable params: 11,006,440 diff --git a/tests/test_output/lstm.out b/tests/test_output/lstm.out index 64916b4..d3fdf18 100644 --- a/tests/test_output/lstm.out +++ b/tests/test_output/lstm.out @@ -2,7 +2,7 @@ Layer (type (var_name)) Kernel Shape Output Shape Param # Mult-Adds ======================================================================================================================== LSTMNet -- -- -- -- -├─Embedding (embedding) [300, 20] [1, 100, 300] 6,000 6,000 +├─Embedding (embedding) -- [1, 100, 300] 6,000 6,000 │ └─weight [300, 20] └─6,000 ├─LSTM (encoder) -- [1, 100, 512] 3,768,320 376,832,000 │ └─weight_ih_l0 [2048, 300] ├─614,400 @@ -13,7 +13,7 @@ LSTMNet -- -- │ └─weight_hh_l1 [2048, 512] ├─1,048,576 │ └─bias_ih_l1 [2048] ├─2,048 │ └─bias_hh_l1 [2048] └─2,048 -├─Linear (decoder) [512, 20] [1, 100, 20] 10,260 10,260 +├─Linear (decoder) -- [1, 100, 20] 10,260 10,260 │ └─weight [512, 20] ├─10,240 │ └─bias [20] └─20 ======================================================================================================================== diff --git a/tests/test_output/single_input_all_cols.out b/tests/test_output/single_input_all_cols.out index f8dc5fa..7eb08f6 100644 --- 
a/tests/test_output/single_input_all_cols.out +++ b/tests/test_output/single_input_all_cols.out @@ -2,11 +2,11 @@ Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds ============================================================================================================================================ SingleInputNet -- -- -- -- -- -├─Conv2d: 1-1 [1, 10, 5, 5] [7, 1, 28, 28] [7, 10, 24, 24] 260 1,048,320 -├─Conv2d: 1-2 [10, 20, 5, 5] [7, 10, 12, 12] [7, 20, 8, 8] 5,020 2,248,960 +├─Conv2d: 1-1 [5, 5] [7, 1, 28, 28] [7, 10, 24, 24] 260 1,048,320 +├─Conv2d: 1-2 [5, 5] [7, 10, 12, 12] [7, 20, 8, 8] 5,020 2,248,960 ├─Dropout2d: 1-3 -- [7, 20, 8, 8] [7, 20, 8, 8] -- -- -├─Linear: 1-4 [320, 50] [7, 320] [7, 50] 16,050 112,350 -├─Linear: 1-5 [50, 10] [7, 50] [7, 10] 510 3,570 +├─Linear: 1-4 -- [7, 320] [7, 50] 16,050 112,350 +├─Linear: 1-5 -- [7, 50] [7, 10] 510 3,570 ============================================================================================================================================ Total params: 21,840 Trainable params: 21,840 diff --git a/tests/test_output/single_input_batch_dim.out b/tests/test_output/single_input_batch_dim.out index 8cffc3c..14e0f38 100644 --- a/tests/test_output/single_input_batch_dim.out +++ b/tests/test_output/single_input_batch_dim.out @@ -2,11 +2,11 @@ Layer (type:depth-idx) Kernel Shape Input Shape Output Shape Param # Mult-Adds ============================================================================================================================================ SingleInputNet -- -- -- -- -- -├─Conv2d: 1-1 [1, 10, 5, 5] [1, 1, 28, 28] [1, 10, 24, 24] 260 149,760 -├─Conv2d: 1-2 [10, 20, 5, 5] [1, 10, 12, 12] [1, 20, 8, 8] 5,020 321,280 +├─Conv2d: 1-1 [5, 5] [1, 1, 28, 28] [1, 10, 24, 24] 260 149,760 +├─Conv2d: 1-2 [5, 5] [1, 10, 12, 12] [1, 20, 8, 8] 5,020 321,280 ├─Dropout2d: 1-3 -- [1, 20, 8, 8] [1, 20, 8, 8] -- -- -├─Linear: 1-4 [320, 50] [1, 320] [1, 50] 16,050 16,050 -├─Linear: 1-5 [50, 10] [1, 50] [1, 10] 510 510 +├─Linear: 1-4 -- [1, 320] [1, 50] 16,050 16,050 +├─Linear: 1-5 -- [1, 50] [1, 10] 510 510 ============================================================================================================================================ Total params: 21,840 Trainable params: 21,840 diff --git a/tests/torchinfo_test.py b/tests/torchinfo_test.py index 6a0e243..fd87683 100644 --- a/tests/torchinfo_test.py +++ b/tests/torchinfo_test.py @@ -379,7 +379,15 @@ def test_containers() -> None: def test_autoencoder() -> None: model = AutoEncoder() - summary(model, input_size=(1, 3, 64, 64)) + summary( + model, + input_size=(1, 3, 64, 64), + col_names=( + ColumnSettings.OUTPUT_SIZE, + ColumnSettings.NUM_PARAMS, + ColumnSettings.KERNEL_SIZE, + ), + ) def test_reusing_activation_layers() -> None: diff --git a/torchinfo/layer_info.py b/torchinfo/layer_info.py index f46cf5d..87189f3 100644 --- a/torchinfo/layer_info.py +++ b/torchinfo/layer_info.py @@ -54,7 +54,7 @@ def __init__( self.is_recursive = False self.input_size: list[int] = [] self.output_size: list[int] = [] - self.kernel_size: list[int] = [] + self.kernel_size = self.get_kernel_size(module) self.num_params = 0 self.param_bytes = 0 self.output_bytes = 0 @@ -145,6 +145,20 @@ def get_param_count( return parameter_count, without_suffix return param.nelement(), name + @staticmethod + def get_kernel_size(module: nn.Module) -> int | list[int] | None: + if hasattr(module, "kernel_size"): + k = module.kernel_size + kernel_size: int | list[int] + if isinstance(k, 
Iterable): + kernel_size = list(k) + elif isinstance(k, int): + kernel_size = int(k) + else: + raise TypeError(f"kernel_size has an unexpected type: {type(k)}") + return kernel_size + return None + def get_layer_name(self, show_var_name: bool, show_depth: bool) -> str: layer_name = self.class_name if show_var_name and self.var_name: @@ -171,12 +185,12 @@ def calculate_num_params(self) -> None: if param.requires_grad: self.trainable_params += cur_params + # kernel_size for inner layer parameters ksize = list(param.size()) if name == "weight": # to make [in_shape, out_shape, ksize, ksize] if len(ksize) > 1: ksize[0], ksize[1] = ksize[1], ksize[0] - self.kernel_size = ksize # RNN modules have inner weights such as weight_ih_l0 self.inner_layers[name] = {
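
For reference, a minimal standalone sketch of the kernel-size lookup this patch introduces: kernel_size is now read directly from the module attribute rather than inferred from parameter shapes, so Conv2d layers report their tuple, MaxPool2d reports the bare int it was constructed with, and modules without the attribute (BatchNorm2d, Linear, ReLU) show "--". The helper name extract_kernel_size below is illustrative only; the method added in this patch is LayerInfo.get_kernel_size.

    from __future__ import annotations

    from collections.abc import Iterable

    from torch import nn


    def extract_kernel_size(module: nn.Module) -> int | list[int] | None:
        # Read kernel_size straight from the module attribute instead of
        # guessing it from parameter shapes; absent attribute means no kernel.
        if not hasattr(module, "kernel_size"):
            return None  # e.g. BatchNorm2d, Linear, ReLU -> "--" in the summary
        k = module.kernel_size
        if isinstance(k, Iterable):
            return list(k)  # Conv2d stores a tuple such as (3, 3)
        if isinstance(k, int):
            return k  # MaxPool2d keeps the bare int it was given, e.g. 2
        raise TypeError(f"kernel_size has an unexpected type: {type(k)}")


    print(extract_kernel_size(nn.Conv2d(3, 16, 3)))   # [3, 3]
    print(extract_kernel_size(nn.MaxPool2d(2)))       # 2
    print(extract_kernel_size(nn.BatchNorm2d(16)))    # None

To surface the resulting "Kernel Shape" column in the summary table, pass col_names including ColumnSettings.KERNEL_SIZE, as the updated test_autoencoder call in this patch does.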