Skip to content

Commit

Permalink
v1.7.2
Browse files Browse the repository at this point in the history
  • Loading branch information
TylerYep committed Feb 5, 2023
1 parent 01fa0ce commit bef1aa9
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black
args: [-C]
Expand All @@ -36,7 +36,7 @@ repos:
]

- repo: https://github.com/PyCQA/pylint
rev: v2.16.0b1
rev: v2.16.1
hooks:
- id: pylint
args: ["--disable=import-error"]
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def __init__(self) -> None:
self.linear = nn.Linear(3, 1)

def forward(self, input_list: dict[str, torch.Tensor]) -> dict[str, IntWithGetitem]:
x = input_list["foo"] if input_list else torch.ones(3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = input_list["foo"] if input_list else torch.ones(3).to(device)
x = self.linear(x)
return {"foo": IntWithGetitem(x)}

Expand Down
2 changes: 1 addition & 1 deletion torchinfo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"Units",
"Verbosity",
)
__version__ = "1.7.1"
__version__ = "1.7.2"
1 change: 1 addition & 0 deletions torchinfo/model_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(

# TODO: Figure out why the below functions using max() are ever 0
# (they should always be non-negative), and remove the call to max().
# Investigation: https://github.com/TylerYep/torchinfo/pull/195
for layer_info in summary_list:
if layer_info.is_leaf_layer:
self.total_mult_adds += layer_info.macs
Expand Down
1 change: 0 additions & 1 deletion torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ def add_missing_container_layers(summary_list: list[LayerInfo]) -> None:
d not in current_hierarchy
or current_hierarchy[d].module is not hierarchy[d].module
) and hierarchy[d] is not summary_list[idx + rel_idx - 1]:

hierarchy[d].calculate_num_params()
hierarchy[d].check_recursive(layer_ids)
summary_list.insert(idx + rel_idx, hierarchy[d])
Expand Down

0 comments on commit bef1aa9

Please sign in to comment.