This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Loading of checkpoint fails #7

Open
frdnd opened this issue Jun 23, 2021 · 5 comments
frdnd commented Jun 23, 2021

The semseg training command (from semantic_segmentation/README.md)

tools/dist_train.sh configs/xcit/sem_fpn/sem_fpn_xcit_small_12_p16_80k_ade20k.py 8 --work-dir /path/to/save --seed 0 --deterministic --options model.pretrained=https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth

fails with

TypeError: EncoderDecoder: XCiT: __init__() got an unexpected keyword argument 'pretrained'
frdnd commented Jun 23, 2021

This is a fix that works for checkpoints provided as local files; however, it still fails for models given by URL.

I think this is due to the internals of mmcv-full's load_checkpoint handling. It succeeds with local checkpoints, since the internally called load_from_local returns the full checkpoint dict, whereas a URL triggers load_from_http, which returns the state_dict directly, and the subsequent handling crashes.

Which version of mmcv-full did you use for your project? Maybe this error is only present in the latest version I'm using (mmcv-full==1.3.7).
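The asymmetry described above can be worked around by normalizing whatever the loader returns before calling load_state_dict. A minimal sketch, assuming the common wrapper keys 'state_dict' and 'model' (the helper name unwrap_state_dict is hypothetical):

```python
def unwrap_state_dict(checkpoint):
    """Hypothetical helper: load_from_local returns the full checkpoint
    dict, while load_from_http can return the state_dict directly.
    Normalize both cases to a plain state_dict mapping."""
    if isinstance(checkpoint, dict):
        # Common wrapper keys used by training frameworks.
        for key in ('state_dict', 'model'):
            if key in checkpoint and isinstance(checkpoint[key], dict):
                return checkpoint[key]
    return checkpoint
```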

cyrilzakka commented

Tried your fix, but I'm still getting an error where none of the weights are loaded from the pretrained model. Will try load_checkpoint() from here and report back.

frdnd commented Jun 30, 2021

Did you succeed? The fix also uses load_checkpoint via init_weights (see here), which loads the backbone weights into the XCiT model. The neck weights included in the checkpoint aren't loaded with this approach, but it was better than nothing and helped with transfer learning.
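If you want to load only the backbone part of a full segmentation checkpoint, the keys can be filtered by prefix before loading. A minimal sketch; the 'backbone.' prefix is an assumption based on how mmseg typically names backbone parameters:

```python
def filter_backbone(state_dict, prefix='backbone.'):
    """Hypothetical helper: keep only keys under the given prefix and
    strip it, so the result can be loaded into the backbone alone."""
    return {k[len(prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(prefix)}
```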

cyrilzakka commented Jul 16, 2021

No luck! I get this:
missing keys in source state_dict:
patch_embed.proj.0.0.weight, patch_embed.proj.0.1.weight, patch_embed.proj.0.1.bias, patch_embed.proj.0.1.running_mean, patch_embed.proj.0.1.running_var, patch_embed.proj.2.0.weight, patch_embed.proj.2.1.weight, patch_embed.proj.2.1.bias, patch_embed.proj.2.1.running_mean, patch_embed.proj.2.1.running_var, patch_embed.proj.4.0.weight, patch_embed.proj.4.1.weight, patch_embed.proj.4.1.bias, patch_embed.proj.4.1.running_mean, patch_embed.proj.4.1.running_var, patch_embed.proj.6.0.weight, patch_embed.proj.6.1.weight, patch_embed.proj.6.1.bias, patch_embed.proj.6.1.running_mean, patch_embed.proj.6.1.running_var, blocks.0.gamma1, blocks.0.gamma2, blocks.0.gamma3, blocks.0.norm1.weight, blocks.0.norm1.bias, blocks.0.attn.temperature, blocks.0.attn.qkv.weight, blocks.0.attn.qkv.bias, blocks.0.attn.proj.weight, blocks.0.attn.proj.bias, blocks.0.norm2.weight, blocks.0.norm2.bias, blocks.0.mlp.fc1.weight, blocks.0.mlp.fc1.bias, blocks.0.mlp.fc2.weight, blocks.0.mlp.fc2.bias, blocks.0.norm3.weight, blocks.0.norm3.bias, blocks.0.local_mp.conv1.weight, blocks.0.local_mp.conv1.bias, blocks.0.local_mp.bn.weight, blocks.0.local_mp.bn.bias, blocks.0.local_mp.bn.running_mean, blocks.0.local_mp.bn.running_var, blocks.0.local_mp.conv2.weight, blocks.0.local_mp.conv2.bias, blocks.1.gamma1, blocks.1.gamma2, blocks.1.gamma3, blocks.1.norm1.weight, blocks.1.norm1.bias, blocks.1.attn.temperature, blocks.1.attn.qkv.weight, blocks.1.attn.qkv.bias, blocks.1.attn.proj.weight, blocks.1.attn.proj.bias, blocks.1.norm2.weight, blocks.1.norm2.bias, blocks.1.mlp.fc1.weight, blocks.1.mlp.fc1.bias, blocks.1.mlp.fc2.weight, blocks.1.mlp.fc2.bias, blocks.1.norm3.weight, blocks.1.norm3.bias, blocks.1.local_mp.conv1.weight, blocks.1.local_mp.conv1.bias, blocks.1.local_mp.bn.weight, blocks.1.local_mp.bn.bias, blocks.1.local_mp.bn.running_mean, blocks.1.local_mp.bn.running_var, blocks.1.local_mp.conv2.weight, blocks.1.local_mp.conv2.bias, blocks.2.gamma1, blocks.2.gamma2, blocks.2.gamma3, 
blocks.2.norm1.weight, blocks.2.norm1.bias, blocks.2.attn.temperature, blocks.2.attn.qkv.weight, blocks.2.attn.qkv.bias, blocks.2.attn.proj.weight, blocks.2.attn.proj.bias, blocks.2.norm2.weight, blocks.2.norm2.bias, blocks.2.mlp.fc1.weight, blocks.2.mlp.fc1.bias, blocks.2.mlp.fc2.weight, blocks.2.mlp.fc2.bias, blocks.2.norm3.weight, blocks.2.norm3.bias, blocks.2.local_mp.conv1.weight, blocks.2.local_mp.conv1.bias, blocks.2.local_mp.bn.weight, blocks.2.local_mp.bn.bias, blocks.2.local_mp.bn.running_mean, blocks.2.local_mp.bn.running_var, blocks.2.local_mp.conv2.weight, blocks.2.local_mp.conv2.bias, blocks.3.gamma1, blocks.3.gamma2, blocks.3.gamma3, blocks.3.norm1.weight, blocks.3.norm1.bias, blocks.3.attn.temperature, blocks.3.attn.qkv.weight, blocks.3.attn.qkv.bias, blocks.3.attn.proj.weight, blocks.3.attn.proj.bias, blocks.3.norm2.weight, blocks.3.norm2.bias, blocks.3.mlp.fc1.weight, blocks.3.mlp.fc1.bias, blocks.3.mlp.fc2.weight, blocks.3.mlp.fc2.bias, blocks.3.norm3.weight, blocks.3.norm3.bias, blocks.3.local_mp.conv1.weight, blocks.3.local_mp.conv1.bias, blocks.3.local_mp.bn.weight, blocks.3.local_mp.bn.bias, blocks.3.local_mp.bn.running_mean, blocks.3.local_mp.bn.running_var, blocks.3.local_mp.conv2.weight, blocks.3.local_mp.conv2.bias, blocks.4.gamma1, blocks.4.gamma2, blocks.4.gamma3, blocks.4.norm1.weight, blocks.4.norm1.bias, blocks.4.attn.temperature, blocks.4.attn.qkv.weight, blocks.4.attn.qkv.bias, blocks.4.attn.proj.weight, blocks.4.attn.proj.bias, blocks.4.norm2.weight, blocks.4.norm2.bias, blocks.4.mlp.fc1.weight, blocks.4.mlp.fc1.bias, blocks.4.mlp.fc2.weight, blocks.4.mlp.fc2.bias, blocks.4.norm3.weight, blocks.4.norm3.bias, blocks.4.local_mp.conv1.weight, blocks.4.local_mp.conv1.bias, blocks.4.local_mp.bn.weight, blocks.4.local_mp.bn.bias, blocks.4.local_mp.bn.running_mean, blocks.4.local_mp.bn.running_var, blocks.4.local_mp.conv2.weight, blocks.4.local_mp.conv2.bias, blocks.5.gamma1, blocks.5.gamma2, blocks.5.gamma3, blocks.5.norm1.weight, 
blocks.5.norm1.bias, blocks.5.attn.temperature, blocks.5.attn.qkv.weight, blocks.5.attn.qkv.bias, blocks.5.attn.proj.weight, blocks.5.attn.proj.bias, blocks.5.norm2.weight, blocks.5.norm2.bias, blocks.5.mlp.fc1.weight, blocks.5.mlp.fc1.bias, blocks.5.mlp.fc2.weight, blocks.5.mlp.fc2.bias, blocks.5.norm3.weight, blocks.5.norm3.bias, blocks.5.local_mp.conv1.weight, blocks.5.local_mp.conv1.bias, blocks.5.local_mp.bn.weight, blocks.5.local_mp.bn.bias, blocks.5.local_mp.bn.running_mean, blocks.5.local_mp.bn.running_var, blocks.5.local_mp.conv2.weight, blocks.5.local_mp.conv2.bias, blocks.6.gamma1, blocks.6.gamma2, blocks.6.gamma3, blocks.6.norm1.weight, blocks.6.norm1.bias, blocks.6.attn.temperature, blocks.6.attn.qkv.weight, blocks.6.attn.qkv.bias, blocks.6.attn.proj.weight, blocks.6.attn.proj.bias, blocks.6.norm2.weight, blocks.6.norm2.bias, blocks.6.mlp.fc1.weight, blocks.6.mlp.fc1.bias, blocks.6.mlp.fc2.weight, blocks.6.mlp.fc2.bias, blocks.6.norm3.weight, blocks.6.norm3.bias, blocks.6.local_mp.conv1.weight, blocks.6.local_mp.conv1.bias, blocks.6.local_mp.bn.weight, blocks.6.local_mp.bn.bias, blocks.6.local_mp.bn.running_mean, blocks.6.local_mp.bn.running_var, blocks.6.local_mp.conv2.weight, blocks.6.local_mp.conv2.bias, blocks.7.gamma1, blocks.7.gamma2, blocks.7.gamma3, blocks.7.norm1.weight, blocks.7.norm1.bias, blocks.7.attn.temperature, blocks.7.attn.qkv.weight, blocks.7.attn.qkv.bias, blocks.7.attn.proj.weight, blocks.7.attn.proj.bias, blocks.7.norm2.weight, blocks.7.norm2.bias, blocks.7.mlp.fc1.weight, blocks.7.mlp.fc1.bias, blocks.7.mlp.fc2.weight, blocks.7.mlp.fc2.bias, blocks.7.norm3.weight, blocks.7.norm3.bias, blocks.7.local_mp.conv1.weight, blocks.7.local_mp.conv1.bias, blocks.7.local_mp.bn.weight, blocks.7.local_mp.bn.bias, blocks.7.local_mp.bn.running_mean, blocks.7.local_mp.bn.running_var, blocks.7.local_mp.conv2.weight, blocks.7.local_mp.conv2.bias, blocks.8.gamma1, blocks.8.gamma2, blocks.8.gamma3, blocks.8.norm1.weight, blocks.8.norm1.bias, 
blocks.8.attn.temperature, blocks.8.attn.qkv.weight, blocks.8.attn.qkv.bias, blocks.8.attn.proj.weight, blocks.8.attn.proj.bias, blocks.8.norm2.weight, blocks.8.norm2.bias, blocks.8.mlp.fc1.weight, blocks.8.mlp.fc1.bias, blocks.8.mlp.fc2.weight, blocks.8.mlp.fc2.bias, blocks.8.norm3.weight, blocks.8.norm3.bias, blocks.8.local_mp.conv1.weight, blocks.8.local_mp.conv1.bias, blocks.8.local_mp.bn.weight, blocks.8.local_mp.bn.bias, blocks.8.local_mp.bn.running_mean, blocks.8.local_mp.bn.running_var, blocks.8.local_mp.conv2.weight, blocks.8.local_mp.conv2.bias, blocks.9.gamma1, blocks.9.gamma2, blocks.9.gamma3, blocks.9.norm1.weight, blocks.9.norm1.bias, blocks.9.attn.temperature, blocks.9.attn.qkv.weight, blocks.9.attn.qkv.bias, blocks.9.attn.proj.weight, blocks.9.attn.proj.bias, blocks.9.norm2.weight, blocks.9.norm2.bias, blocks.9.mlp.fc1.weight, blocks.9.mlp.fc1.bias, blocks.9.mlp.fc2.weight, blocks.9.mlp.fc2.bias, blocks.9.norm3.weight, blocks.9.norm3.bias, blocks.9.local_mp.conv1.weight, blocks.9.local_mp.conv1.bias, blocks.9.local_mp.bn.weight, blocks.9.local_mp.bn.bias, blocks.9.local_mp.bn.running_mean, blocks.9.local_mp.bn.running_var, blocks.9.local_mp.conv2.weight, blocks.9.local_mp.conv2.bias, blocks.10.gamma1, blocks.10.gamma2, blocks.10.gamma3, blocks.10.norm1.weight, blocks.10.norm1.bias, blocks.10.attn.temperature, blocks.10.attn.qkv.weight, blocks.10.attn.qkv.bias, blocks.10.attn.proj.weight, blocks.10.attn.proj.bias, blocks.10.norm2.weight, blocks.10.norm2.bias, blocks.10.mlp.fc1.weight, blocks.10.mlp.fc1.bias, blocks.10.mlp.fc2.weight, blocks.10.mlp.fc2.bias, blocks.10.norm3.weight, blocks.10.norm3.bias, blocks.10.local_mp.conv1.weight, blocks.10.local_mp.conv1.bias, blocks.10.local_mp.bn.weight, blocks.10.local_mp.bn.bias, blocks.10.local_mp.bn.running_mean, blocks.10.local_mp.bn.running_var, blocks.10.local_mp.conv2.weight, blocks.10.local_mp.conv2.bias, blocks.11.gamma1, blocks.11.gamma2, blocks.11.gamma3, blocks.11.norm1.weight, 
blocks.11.norm1.bias, blocks.11.attn.temperature, blocks.11.attn.qkv.weight, blocks.11.attn.qkv.bias, blocks.11.attn.proj.weight, blocks.11.attn.proj.bias, blocks.11.norm2.weight, blocks.11.norm2.bias, blocks.11.mlp.fc1.weight, blocks.11.mlp.fc1.bias, blocks.11.mlp.fc2.weight, blocks.11.mlp.fc2.bias, blocks.11.norm3.weight, blocks.11.norm3.bias, blocks.11.local_mp.conv1.weight, blocks.11.local_mp.conv1.bias, blocks.11.local_mp.bn.weight, blocks.11.local_mp.bn.bias, blocks.11.local_mp.bn.running_mean, blocks.11.local_mp.bn.running_var, blocks.11.local_mp.conv2.weight, blocks.11.local_mp.conv2.bias, pos_embeder.token_projection.weight, pos_embeder.token_projection.bias, fpn1.0.weight, fpn1.0.bias, fpn1.1.weight, fpn1.1.bias, fpn1.1.running_mean, fpn1.1.running_var, fpn1.3.weight, fpn1.3.bias, fpn2.0.weight, fpn2.0.bias
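A "missing keys" report like the one above usually means the checkpoint and model key sets don't line up (for example, because of a stray prefix). Comparing the two sets directly before loading can show which side is off; a minimal sketch (the helper name diff_keys is hypothetical):

```python
def diff_keys(model_keys, ckpt_keys):
    """Hypothetical helper: report keys the checkpoint is missing and
    keys the model doesn't expect, mirroring what strict=False logs."""
    missing = sorted(set(model_keys) - set(ckpt_keys))
    unexpected = sorted(set(ckpt_keys) - set(model_keys))
    return missing, unexpected
```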

cyrilzakka commented

Here's how I fixed it:

    # Needs: import torch, import torch.nn as nn,
    # and trunc_normal_ (e.g. from timm.models.layers).
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        # Randomly initialize first, then overwrite with checkpoint weights.
        self.apply(_init_weights)

        if pretrained is None:
            pretrained = 'pretrained/xcit_small_12_cp16_dino.pth'
        print("Loading pretrained weights from checkpoint", pretrained)
        checkpoint = torch.load(pretrained, map_location='cpu')

        checkpoint_model = checkpoint['model']
        # strict=False so keys absent from the checkpoint (e.g. the neck)
        # keep their random init instead of raising an error.
        self.load_state_dict(checkpoint_model, strict=False)
