Skip to content

Commit

Permalink
Do not error if there are None values in ModuleLists
Browse files Browse the repository at this point in the history
  • Loading branch information
TylerYep committed May 16, 2022
1 parent 7def795 commit e7be90b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
5 changes: 4 additions & 1 deletion tests/fixtures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,14 @@ def __init__(self) -> None:
self._layers.append(nn.Linear(5, 5))
self._layers.append(ContainerChildModule())
self._layers.append(nn.Linear(5, 5))
# Add None, but filter out this value later.
self._layers.append(None) # type: ignore[arg-type]

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = x
for layer in self._layers:
out = layer(out)
if layer is not None:
out = layer(out)
return out


Expand Down
24 changes: 12 additions & 12 deletions torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,18 +566,18 @@ def hook(module: nn.Module, inputs: Any, outputs: Any) -> None:

# module.named_modules(remove_duplicate=False) doesn't work (infinite recursion).
for name, mod in module._modules.items(): # pylint: disable=protected-access
assert mod is not None
child = (name, mod)
apply_hooks(
child,
orig_model,
batch_dim,
summary_list,
hooks,
all_layers,
curr_depth + 1,
info,
)
if mod is not None:
child = (name, mod)
apply_hooks(
child,
orig_model,
batch_dim,
summary_list,
hooks,
all_layers,
curr_depth + 1,
info,
)


def clear_cached_forward_pass() -> None:
Expand Down

0 comments on commit e7be90b

Please sign in to comment.