Skip to content

Commit

Permalink
Separate nested_list_size function, add some documentation, improve m…
Browse files Browse the repository at this point in the history
…ypy for setuptools (#220)
  • Loading branch information
TylerYep authored Feb 5, 2023
1 parent c879e2a commit 01fa0ce
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 47 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install mypy pytest pytest-cov
pip install torch==${{ matrix.pytorch-version }} torchvision
pip install transformers
pip install torch==${{ matrix.pytorch-version }} torchvision transformers
pip install compressai
- name: mypy
if: ${{ matrix.pytorch-version == '1.13' }}
run: |
mypy .
mypy --install-types --non-interactive .
- name: pytest
if: ${{ matrix.pytorch-version == '1.13' }}
run: |
Expand Down
2 changes: 1 addition & 1 deletion profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random

import torchvision # type: ignore[import] # pylint: disable=unused-import # noqa
from tqdm import trange # type: ignore[import] # pylint: disable=unused-import # noqa
from tqdm import trange # pylint: disable=unused-import # noqa

from torchinfo import summary # pylint: disable=unused-import # noqa

Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ torchinfo = py.typed
[mypy]
strict = True
implicit_reexport = True
show_error_codes = True
enable_error_code = ignore-without-code

[pylint.main]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from setuptools import setup # type: ignore[import]
from setuptools import setup

setup()
55 changes: 26 additions & 29 deletions torchinfo/layer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,38 +93,10 @@ def calculate_size(
Returns the corrected shape of `inputs` and the size of
a single element in bytes.
"""

def nested_list_size(
inputs: Sequence[Any] | torch.Tensor,
) -> tuple[list[int], int]:
"""Flattens nested list size."""

if hasattr(inputs, "tensors"):
size, elem_bytes = nested_list_size(inputs.tensors)
elif isinstance(inputs, torch.Tensor):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif not hasattr(inputs, "__getitem__") or not inputs:
size, elem_bytes = [], 0
elif isinstance(inputs, dict):
size, elem_bytes = nested_list_size(list(inputs.values()))
elif (
hasattr(inputs, "size")
and callable(inputs.size)
and hasattr(inputs, "element_size")
and callable(inputs.element_size)
):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif isinstance(inputs, (list, tuple)):
size, elem_bytes = nested_list_size(inputs[0])
else:
size, elem_bytes = [], 0

return size, elem_bytes

if inputs is None:
size, elem_bytes = [], 0

# pack_padded_seq and pad_packed_seq store feature into data attribute
# pack_padded_seq and pad_packed_seq store feature into data attribute
elif (
isinstance(inputs, (list, tuple)) and inputs and hasattr(inputs[0], "data")
):
Expand Down Expand Up @@ -337,6 +309,31 @@ def leftover_trainable_params(self) -> int:
)


def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], int]:
"""Flattens nested list size."""
if hasattr(inputs, "tensors"):
size, elem_bytes = nested_list_size(inputs.tensors)
elif isinstance(inputs, torch.Tensor):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif not hasattr(inputs, "__getitem__") or not inputs:
size, elem_bytes = [], 0
elif isinstance(inputs, dict):
size, elem_bytes = nested_list_size(list(inputs.values()))
elif (
hasattr(inputs, "size")
and callable(inputs.size)
and hasattr(inputs, "element_size")
and callable(inputs.element_size)
):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif isinstance(inputs, (list, tuple)):
size, elem_bytes = nested_list_size(inputs[0])
else:
size, elem_bytes = [], 0

return size, elem_bytes


def prod(num_list: Iterable[int] | torch.Size) -> int:
result = 1
if isinstance(num_list, Iterable):
Expand Down
18 changes: 6 additions & 12 deletions torchinfo/model_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(
self.total_params, self.trainable_params = 0, 0
self.total_param_bytes, self.total_output_bytes = 0, 0

# TODO: Figure out why the below functions using max() are ever 0
# (they should always be non-negative), and remove the call to max().
for layer_info in summary_list:
if layer_info.is_leaf_layer:
self.total_mult_adds += layer_info.macs
Expand All @@ -33,24 +35,16 @@ def __init__(
self.total_output_bytes += layer_info.output_bytes * 2
if layer_info.is_recursive:
continue
self.total_params += (
layer_info.num_params if layer_info.num_params > 0 else 0
)
self.total_params += max(layer_info.num_params, 0)
self.total_param_bytes += layer_info.param_bytes
self.trainable_params += (
layer_info.trainable_params
if layer_info.trainable_params > 0
else 0
)
self.trainable_params += max(layer_info.trainable_params, 0)
else:
if layer_info.is_recursive:
continue
leftover_params = layer_info.leftover_params()
leftover_trainable_params = layer_info.leftover_trainable_params()
self.total_params += leftover_params if leftover_params > 0 else 0
self.trainable_params += (
leftover_trainable_params if leftover_trainable_params > 0 else 0
)
self.total_params += max(leftover_params, 0)
self.trainable_params += max(leftover_trainable_params, 0)
self.formatting.set_layer_name_width(summary_list)

def __repr__(self) -> str:
Expand Down

0 comments on commit 01fa0ce

Please sign in to comment.