Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6268 enhance hovernet load pretrained function #6269

Merged
21 changes: 17 additions & 4 deletions monai/networks/nets/hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ class HoVerNet(nn.Module):
adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this
value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format,
this value should be `True`.
pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True. It is used to
extract the expected state dict.
freeze_encoder: whether to freeze the encoder of the network.
"""

Expand All @@ -461,6 +463,7 @@ def __init__(
dropout_prob: float = 0.0,
pretrained_url: str | None = None,
adapt_standard_resnet: bool = False,
pretrained_state_dict_key: str | None = None,
freeze_encoder: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -566,7 +569,7 @@ def __init__(

if pretrained_url is not None:
if adapt_standard_resnet:
weights = _remap_standard_resnet_model(pretrained_url)
weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key)
else:
weights = _remap_preact_resnet_model(pretrained_url)
_load_pretrained_encoder(self, weights)
Expand Down Expand Up @@ -609,6 +612,12 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict):

model_dict.update(state_dict)
model.load_state_dict(model_dict)
if len(state_dict.keys()) == 0:
warnings.warn(
"no key will be updated. Please double confirm if the 'pretrained_url' is correct or `pretrained_state_dict_key` is reasonably set."
)
else:
print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.")


def _remap_preact_resnet_model(model_url: str):
Expand All @@ -619,7 +628,9 @@ def _remap_preact_resnet_model(model_url: str):
# download the pretrained weights into torch hub's default dir
weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
state_dict = torch.load(weights_dir, map_location=None)["desc"]
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
"desc"
]
for key in list(state_dict.keys()):
new_key = None
if pattern_conv0.match(key):
Expand All @@ -639,7 +650,7 @@ def _remap_preact_resnet_model(model_url: str):
return state_dict


def _remap_standard_resnet_model(model_url: str):
def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None):
pattern_conv0 = re.compile(r"^conv1\.(.+)$")
pattern_bn1 = re.compile(r"^bn1\.(.+)$")
pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$")
Expand All @@ -652,7 +663,9 @@ def _remap_standard_resnet_model(model_url: str):
# download the pretrained weights into torch hub's default dir
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
state_dict = torch.load(weights_dir, map_location=None)
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
if state_dict_key is not None:
state_dict = state_dict[state_dict_key]

for key in list(state_dict.keys()):
new_key = None
Expand Down