Skip to content

Commit

Permalink
more changes to helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy committed Jun 21, 2023
1 parent 5b38aa1 commit f62484b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
1 change: 1 addition & 0 deletions ivy_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import convnext
from . import helpers
from .transformers import perceiver_io
from .transformers.perceiver_io import *
from .resnet import *
from .vgg import *
from .convnext import *
60 changes: 51 additions & 9 deletions ivy_models/helpers/weights_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,69 @@
import ivy


def load_torch_weights(url, ref_model, custom_mapping=None):
import torch
def _prune_keys(raw, ref, raw_keys_to_prune=[], ref_keys_to_prune=[]):
if raw_keys_to_prune != []:
for kc in raw_keys_to_prune:
raw = raw.cont_prune_key_from_key_chains(containing=kc)
if ref_keys_to_prune != []:
for kc in ref_keys_to_prune:
ref = ref.cont_prune_key_from_key_chains(containing=kc)
return raw, ref

old_backend = ivy.backend
ivy.set_backend("torch")
weights = torch.hub.load_state_dict_from_url(url)
weights_raw = ivy.to_numpy(ivy.Container(weights))

def _map_weights(raw, ref, custom_mapping=None):
mapping = {}
for old_key, new_key in zip(
weights_raw.cont_sort_by_key().cont_to_iterator_keys(),
ref_model.v.cont_sort_by_key().cont_to_iterator_keys(),
raw.cont_sort_by_key().cont_to_iterator_keys(),
ref.cont_sort_by_key().cont_to_iterator_keys(),
):
new_mapping = new_key
if custom_mapping is not None:
new_mapping = custom_mapping(old_key, new_key)
if new_mapping is None:
continue
mapping[old_key] = new_mapping
return mapping


def load_torch_weights(url, ref_model, custom_mapping=None):
import torch

ivy.set_backend("torch")
weights = torch.hub.load_state_dict_from_url(url)
weights_raw = ivy.to_numpy(ivy.Container(weights))
mapping = _map_weights(weights_raw, ref_model.v, custom_mapping=custom_mapping)

ivy.previous_backend()
w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
return ivy.asarray(w_clean)


def load_jax_weights(
url, ref_model, custom_mapping=None, raw_keys_to_prune=[], ref_keys_to_prune=[]
):
import urllib.request
import os
import pickle

ivy.set_backend("jax")
urllib.request.urlretrieve(url, filename="jax_weights.pystate")
with open("jax_weights.pystate", "rb") as f:
weights = pickle.loads(f.read())
os.remove("jax_weights.pystate")

try:
weights = {**weights["params"], **weights["state"]}
except KeyError:
pass

weights_raw = ivy.to_numpy(ivy.Container(weights))
weights_raw, weights_ref = _prune_keys(
weights_raw, ref_model.v, raw_keys_to_prune, ref_keys_to_prune
)
mapping = _map_weights(weights_raw, weights_ref, custom_mapping=custom_mapping)

ivy.set_backend(old_backend)
ivy.previous_backend()
w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
return ivy.asarray(w_clean)

Expand Down

0 comments on commit f62484b

Please sign in to comment.