Skip to content

Commit

Permalink
refactor weights loading without using pickled files - vgg, resnet (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy authored Jun 21, 2023
2 parents cc72331 + bfd4ba6 commit 3c263d0
Show file tree
Hide file tree
Showing 18 changed files with 546 additions and 697 deletions.
4 changes: 3 additions & 1 deletion ivy_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from . import resnet
from . import vgg
from . import convnext
from .transformers import *
from . import helpers
from .transformers import perceiver_io
from .transformers.perceiver_io import *
from .resnet import *
from .vgg import *
from .convnext import *
54 changes: 23 additions & 31 deletions ivy_models/convnext/convnext.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ivy
import ivy_models
from ivy.stateful.module import Module
from ivy.stateful.initializers import Zeros, Ones, Constant

Expand Down Expand Up @@ -158,6 +159,25 @@ def _forward(self, x):
return self.v.weight[:, None, None] * x + self.v.bias[:, None, None]


def _convnext_torch_weights_mapping(old_key, new_key):
    """Choose how one ConvNeXt torch weight entry is mapped onto ``new_key``.

    Returns either ``new_key`` unchanged, or a dict holding ``new_key`` under
    ``"key_chain"`` together with a rearrangement ``"pattern"`` string when
    ``old_key`` names a ``downsample_layers`` or ``dwconv`` bias/weight entry
    matched below.
    """
    bias_pattern = "h -> 1 h 1 1"

    def with_pattern(pattern):
        # Mapping entry that asks for the value to be rearranged with `pattern`.
        return {"key_chain": new_key, "pattern": pattern}

    if "downsample_layers" in old_key:
        in_layer_zero = "downsample_layers/0" in old_key
        if "0/0/bias" in old_key:
            return with_pattern(bias_pattern)
        if "0/0/weight" in old_key:
            return with_pattern("b c h w-> h w c b")
        if not in_layer_zero and "1/bias" in old_key:
            return with_pattern(bias_pattern)
        if not in_layer_zero and "1/w" in old_key:
            return with_pattern("b c h w -> h w c b")
        # Any other downsample entry keeps the plain key-to-key mapping.
        return new_key

    if "dwconv" in old_key:
        if "bias" in old_key:
            return with_pattern(bias_pattern)
        if "weight" in old_key:
            return with_pattern("a 1 c d -> c d a")

    return new_key


def convnext(size: str, pretrained=True):
"""Loads a ConvNeXt with specified size, optionally pretrained."""
size_dict = {
Expand All @@ -181,36 +201,8 @@ def convnext(size: str, pretrained=True):
"large": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", # noqa
}

import torch

old_backend = ivy.backend
ivy.set_backend("torch")
weights = torch.hub.load_state_dict_from_url(weight_dl[size])
weights_raw = ivy.to_numpy(ivy.Container(weights))
reference_model = ConvNeXt(depths=depths, dims=dims)
mapping = {}
for old_key, new_key in zip(
weights_raw.cont_sort_by_key().cont_to_iterator_keys(),
reference_model.v.cont_sort_by_key().cont_to_iterator_keys(),
):
new_mapping = new_key
if "downsample_layers" in old_key:
if "0/0/bias" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "h -> 1 h 1 1"}
elif "0/0/weight" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "b c h w-> h w c b"}
elif "downsample_layers/0" not in old_key and "1/bias" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "h -> 1 h 1 1"}
elif "downsample_layers/0" not in old_key and "1/w" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"}
elif "dwconv" in old_key:
if "bias" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "h -> 1 h 1 1"}
elif "weight" in old_key:
new_mapping = {"key_chain": new_key, "pattern": "a 1 c d -> c d a"}
mapping[old_key] = new_mapping

ivy.set_backend(old_backend)
w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
w_clean = ivy.asarray(w_clean)
w_clean = ivy_models.helpers.load_torch_weights(
weight_dl[size], reference_model, custom_mapping=_convnext_torch_weights_mapping
)
return ConvNeXt(depths=depths, dims=dims, v=w_clean)
2 changes: 1 addition & 1 deletion ivy_models/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .match_weights_helpers import *
from .weights_helpers import *
67 changes: 0 additions & 67 deletions ivy_models/helpers/match_weights_helpers.py

This file was deleted.

69 changes: 69 additions & 0 deletions ivy_models/helpers/weights_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# global
import ivy


def _prune_keys(raw, ref, raw_keys_to_prune=None, ref_keys_to_prune=None):
    """Prune the requested keys from the raw and reference weight containers.

    Parameters
    ----------
    raw
        Object exposing ``cont_prune_key_from_key_chains``; pruned once for
        every entry of ``raw_keys_to_prune``.
    ref
        Object exposing ``cont_prune_key_from_key_chains``; pruned once for
        every entry of ``ref_keys_to_prune``.
    raw_keys_to_prune
        Iterable of values passed as ``containing=`` when pruning ``raw``.
        ``None`` or an empty iterable leaves ``raw`` untouched.
    ref_keys_to_prune
        Iterable of values passed as ``containing=`` when pruning ``ref``.
        ``None`` or an empty iterable leaves ``ref`` untouched.

    Returns
    -------
    tuple
        ``(raw, ref)`` after pruning.
    """
    # `None` defaults replace the previous mutable `[]` defaults; iterating an
    # empty tuple is a no-op, so no explicit emptiness check is required.
    for kc in raw_keys_to_prune or ():
        raw = raw.cont_prune_key_from_key_chains(containing=kc)
    for kc in ref_keys_to_prune or ():
        ref = ref.cont_prune_key_from_key_chains(containing=kc)
    return raw, ref


def _map_weights(raw, ref, custom_mapping=None):
    """Pair the sorted key chains of ``raw`` with those of ``ref``.

    Both containers are sorted by key and walked in lockstep. Each raw key is
    mapped to the reference key at the same position, or to whatever
    ``custom_mapping(old_key, new_key)`` returns when a callable is supplied.
    A ``None`` result from ``custom_mapping`` drops that raw key from the
    mapping.
    """
    raw_keys = raw.cont_sort_by_key().cont_to_iterator_keys()
    ref_keys = ref.cont_sort_by_key().cont_to_iterator_keys()
    paired = zip(raw_keys, ref_keys)

    if custom_mapping is None:
        return dict(paired)

    # Evaluate the custom mapping lazily, one pair at a time, and keep only
    # the entries it did not reject.
    resolved = ((src, custom_mapping(src, dst)) for src, dst in paired)
    return {src: target for src, target in resolved if target is not None}


def load_torch_weights(url, ref_model, custom_mapping=None):
    """Download a torch state dict and restructure it to match ``ref_model.v``.

    Parameters
    ----------
    url
        Location of the state dict; passed to
        ``torch.hub.load_state_dict_from_url``.
    ref_model
        Model whose ``v`` container supplies the target key chains.
    custom_mapping
        Optional callable ``(old_key, new_key)`` forwarded to ``_map_weights``
        to customise or drop individual entries.

    Returns
    -------
    The downloaded weights restructured according to the computed mapping and
    passed through ``ivy.asarray``.
    """
    import torch

    ivy.set_backend("torch")
    try:
        weights = torch.hub.load_state_dict_from_url(url)
        weights_raw = ivy.to_numpy(ivy.Container(weights))
        mapping = _map_weights(weights_raw, ref_model.v, custom_mapping=custom_mapping)
    finally:
        # Undo the backend switch even when the download or the mapping
        # raises, so a failure does not leave the global backend on "torch".
        ivy.previous_backend()

    w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
    return ivy.asarray(w_clean)


def load_jax_weights(
    url, ref_model, custom_mapping=None, raw_keys_to_prune=None, ref_keys_to_prune=None
):
    """Download pickled JAX weights and restructure them to match ``ref_model.v``.

    Parameters
    ----------
    url
        Location of the pickled weights file.
    ref_model
        Model whose ``v`` container supplies the target key chains.
    custom_mapping
        Optional callable ``(old_key, new_key)`` forwarded to ``_map_weights``.
    raw_keys_to_prune
        Keys to prune from the downloaded weights before mapping
        (``None`` means nothing is pruned).
    ref_keys_to_prune
        Keys to prune from ``ref_model.v`` before mapping
        (``None`` means nothing is pruned).

    Returns
    -------
    The downloaded weights restructured according to the computed mapping and
    passed through ``ivy.asarray``.
    """
    import os
    import pickle
    import tempfile
    import urllib.request

    # Fresh lists per call: avoids shared mutable defaults while still handing
    # `_prune_keys` the list type it has always received.
    raw_keys_to_prune = list(raw_keys_to_prune or [])
    ref_keys_to_prune = list(ref_keys_to_prune or [])

    ivy.set_backend("jax")
    try:
        # Download into a private temporary directory instead of a fixed file
        # name in the working directory: concurrent calls cannot overwrite
        # each other and the file is removed even if loading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            weights_path = os.path.join(tmp_dir, "jax_weights.pystate")
            urllib.request.urlretrieve(url, filename=weights_path)
            with open(weights_path, "rb") as f:
                # SECURITY: unpickling can execute arbitrary code. Only call
                # this function with URLs that point to trusted files.
                weights = pickle.loads(f.read())

        try:
            # Merge the "params" and "state" entries into one flat dict.
            weights = {**weights["params"], **weights["state"]}
        except KeyError:
            # Either entry is missing: keep the weights exactly as loaded.
            pass

        weights_raw = ivy.to_numpy(ivy.Container(weights))
        weights_raw, weights_ref = _prune_keys(
            weights_raw, ref_model.v, raw_keys_to_prune, ref_keys_to_prune
        )
        mapping = _map_weights(weights_raw, weights_ref, custom_mapping=custom_mapping)
    finally:
        # Undo the backend switch even when something above raises, so a
        # failure does not leave the global backend on "jax".
        ivy.previous_backend()

    w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
    return ivy.asarray(w_clean)
3 changes: 0 additions & 3 deletions ivy_models/resnet/pretrained_weights/resnet_18.pickled

This file was deleted.

24 changes: 22 additions & 2 deletions ivy_models/resnet/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# global
import builtins
import ivy
import ivy_models


class ResidualBlock(ivy.Module):
Expand Down Expand Up @@ -149,9 +151,27 @@ def _forward(self, x):
return x


def resnet_18(v=None):
def _resnet_torch_weights_mapping(old_key, new_key):
    """Choose how one ResNet torch weight entry is mapped onto ``new_key``.

    Returns a dict with ``new_key`` and the ``"b c h w -> h w c b"`` pattern
    for the weight keys listed below, ``None`` for ``num_batches_tracked``
    entries, and ``new_key`` unchanged for everything else.
    """
    rearranged_weights = ("conv1/weight", "conv2/weight", "downsample/0/weight")

    if builtins.any(fragment in old_key for fragment in rearranged_weights):
        return {"key_chain": new_key, "pattern": "b c h w -> h w c b"}
    if "num_batches_tracked" in old_key:
        return None
    return new_key


def resnet_18(pretrained=True):
"""ResNet-18 model"""
return ResNet(ResidualBlock, [2, 2, 2, 2], v=v)
if not pretrained:
return ResNet(ResidualBlock, [2, 2, 2, 2])

reference_model = ResNet(ResidualBlock, [2, 2, 2, 2])
url = "https://download.pytorch.org/models/resnet18-f37072fd.pth"
w_clean = ivy_models.helpers.load_torch_weights(
url, reference_model, custom_mapping=_resnet_torch_weights_mapping
)
return ResNet(ResidualBlock, [2, 2, 2, 2], v=w_clean)


def resnet_34(v=None):
Expand Down
3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg11.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg11_bn.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg13.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg13_bn.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg16.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg16_bn.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg19.pickled

This file was deleted.

3 changes: 0 additions & 3 deletions ivy_models/vgg/pretrained_weights/vgg19_bn.pickled

This file was deleted.

Loading

0 comments on commit 3c263d0

Please sign in to comment.