
Commit

update helpers
juliagsy committed Jun 22, 2023
1 parent 1fb97d8 commit 8fd5214
Showing 2 changed files with 58 additions and 21 deletions.
53 changes: 45 additions & 8 deletions ivy_models/helpers/weights_helpers.py
@@ -3,13 +3,15 @@
 
 
 def _prune_keys(raw, ref, raw_keys_to_prune=[], ref_keys_to_prune=[]):
-    if raw_keys_to_prune != []:
+    pruned_ref = []
+    if raw_keys_to_prune:
         for kc in raw_keys_to_prune:
-            raw = raw.cont_prune_key_from_key_chains(containing=kc)
-    if ref_keys_to_prune != []:
+            raw = raw.cont_prune_key_from_key_chains(absolute=kc)
+    if ref_keys_to_prune:
         for kc in ref_keys_to_prune:
-            ref = ref.cont_prune_key_from_key_chains(containing=kc)
-    return raw, ref
+            pruned_ref.append(ref.cont_at_keys(kc))
+            ref = ref.cont_prune_key_from_key_chains(absolute=kc)
+    return raw, ref, pruned_ref
 
 
 def _map_weights(raw, ref, custom_mapping=None):
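Note on the hunk above: _prune_keys now matches exact key chains (absolute=) rather than substrings (containing=), and it returns the pruned reference subtrees so load_jax_weights can splice them back in after mapping. A minimal sketch of the two ivy.Container calls involved, with made-up keys:

import ivy

ivy.set_backend("numpy")  # any backend works for this sketch

ref = ivy.Container({
    "cls_token": ivy.zeros(4),  # hypothetical key to prune
    "linear": {"w": ivy.zeros((4, 4)), "b": ivy.zeros(4)},
})

saved = ref.cont_at_keys("cls_token")  # keep the subtree before pruning it
ref = ref.cont_prune_key_from_key_chains(absolute="cls_token")
print("cls_token" in ref)  # False: the key chain is gone from the reference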
@@ -40,14 +42,38 @@ def load_torch_weights(url, ref_model, custom_mapping=None):
     return ivy.asarray(w_clean)
 
 
+def _rename_weights(raw, ref, rename_dict):
+    renamed_ref = []
+    for raw_key, ref_key in rename_dict.items():
+        old_v = raw.cont_at_keys(raw_key)
+        new_v = ref.cont_at_keys(ref_key)
+        mapping = {}
+        for old_key, new_key in zip(
+            old_v.cont_sort_by_key().cont_to_iterator_keys(),
+            new_v.cont_sort_by_key().cont_to_iterator_keys(),
+        ):
+            mapping[old_key] = new_key
+
+        raw = raw.cont_prune_key_from_key_chains(absolute=raw_key)
+        ref = ref.cont_prune_key_from_key_chains(absolute=ref_key)
+        renamed_ref.append(old_v.cont_restructure(mapping, keep_orig=False))
+    return raw, ref, renamed_ref
+
+
 def load_jax_weights(
-    url, ref_model, custom_mapping=None, raw_keys_to_prune=[], ref_keys_to_prune=[]
+    url,
+    ref_model,
+    custom_mapping=None,
+    raw_keys_to_prune=[],
+    ref_keys_to_prune=[],
+    special_rename={},
 ):
     import urllib.request
     import os
     import pickle
 
     ivy.set_backend("jax")
+    # todo: refactor this into a url load helper
     urllib.request.urlretrieve(url, filename="jax_weights.pystate")
     with open("jax_weights.pystate", "rb") as f:
         weights = pickle.loads(f.read())
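The new _rename_weights pairs up the leaves under a raw key with the leaves under a reference key (both sorted), builds an old-to-new key-chain mapping, and restructures the raw subtree onto the reference names. A toy sketch of that restructuring step, with invented key chains:

import ivy

ivy.set_backend("numpy")

raw = ivy.Container({"proj_weights": {"b": ivy.zeros(2), "w": ivy.ones((2, 2))}})

# old key chain -> new key chain, as _rename_weights builds it leaf by leaf
mapping = {
    "proj_weights/b": "projection/bias",
    "proj_weights/w": "projection/weight",
}
renamed = raw.cont_restructure(mapping, keep_orig=False)
# renamed is now keyed projection/bias and projection/weight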
@@ -59,11 +85,22 @@ def load_jax_weights(
         pass
 
     weights_raw = ivy.to_numpy(ivy.Container(weights))
-    weights_raw, weights_ref = _prune_keys(
-        weights_raw, ref_model.v, raw_keys_to_prune, ref_keys_to_prune
+    weights_ref = ref_model.v
+    weights_raw, weights_ref, pruned_ref = _prune_keys(
+        weights_raw, weights_ref, raw_keys_to_prune, ref_keys_to_prune
     )
+
+    if special_rename:
+        weights_raw, weights_ref, renamed_ref = _rename_weights(
+            weights_raw, weights_ref, special_rename
+        )
     mapping = _map_weights(weights_raw, weights_ref, custom_mapping=custom_mapping)
 
     ivy.previous_backend()
     w_clean = weights_raw.cont_restructure(mapping, keep_orig=False)
+
+    if special_rename:
+        w_clean = ivy.Container.cont_combine(w_clean, *renamed_ref)
+    if ref_keys_to_prune:
+        w_clean = ivy.Container.cont_combine(w_clean, *pruned_ref)
     return ivy.asarray(w_clean)
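Putting the new parameters together, a hypothetical call might look like the following (the URL and all key names are invented for illustration; special_rename maps a raw key onto the differently-named reference key it should be restructured against):

from ivy_models.helpers.weights_helpers import load_jax_weights

# ref_model stands in for any instantiated ivy.Module whose variable
# structure the checkpoint should be mapped onto
w = load_jax_weights(
    "https://example.com/jax_weights.pystate",
    ref_model,
    raw_keys_to_prune=["optimizer"],  # dropped from the raw weights outright
    ref_keys_to_prune=["cls_token"],  # pruned from ref, then re-combined into w
    special_rename={"proj_weights": "projection"},  # raw key -> ref key
)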
26 changes: 13 additions & 13 deletions ivy_models/transformers/helpers.py
@@ -14,17 +14,17 @@ def __init__(self, dim, fn, context_dim=None, eps=1e-05, device=None, v=None):
             else None
         )
         ivy.Module.__init__(self, v=v, device=device)
-        if self.v.cont_has_key_chain("attention/to_q/b"):
-            self.v = self.v.cont_restructure(
-                {
-                    "attention/to_q/b": "attention/linear/b",
-                    "attention/to_q/w": "attention/linear/w",
-                }
-            )
-        elif self.v.cont_has_key_chain("attention/mlp/submodules/v0/b"):
-            self.v = self.v.cont_restructure(
-                {"norm/bias": "a_norm/bias", "norm/weight": "a_norm/weight"}
-            )
+        # if self.v.cont_has_key_chain("attention/to_q/b"):
+        #     self.v = self.v.cont_restructure(
+        #         {
+        #             "attention/to_q/b": "attention/linear/b",
+        #             "attention/to_q/w": "attention/linear/w",
+        #         }
+        #     )
+        # elif self.v.cont_has_key_chain("attention/mlp/submodules/v0/b"):
+        #     self.v = self.v.cont_restructure(
+        #         {"norm/bias": "a_norm/bias", "norm/weight": "a_norm/weight"}
+        #     )
 
     def _forward(self, x, **kwargs):
         x = self._norm(x)
@@ -35,7 +35,7 @@ def _forward(self, x, **kwargs):
 
 class FeedForward(ivy.Module):
     def __init__(self, dim, dropout=0.0, device=None, v=None):
-        self._mlp = ivy.Sequential(
+        self._net = ivy.Sequential(
             ivy.Linear(dim, dim, device=device),
             ivy.GELU(),
             ivy.Linear(dim, dim, device=device),
Expand All @@ -45,7 +45,7 @@ def __init__(self, dim, dropout=0.0, device=None, v=None):
ivy.Module.__init__(self, v=v)

def _forward(self, x):
return self._mlp(x)
return self._net(x)


def _perceiver_jax_weights_mapping():
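On the _mlp to _net rename: ivy.Module derives variable key chains from attribute names (the commented-out block above references chains like attention/mlp/submodules/v0/b), so renaming the attribute shifts the FeedForward variables from mlp/... to net/.... This is an inference from those key chains, not something the commit states; a tiny sketch of the behaviour:

import ivy

ivy.set_backend("numpy")

class Tiny(ivy.Module):
    def __init__(self):
        self._net = ivy.Linear(2, 2)  # attribute name (minus underscore) becomes the key prefix
        ivy.Module.__init__(self)

    def _forward(self, x):
        return self._net(x)

print(Tiny().v.cont_all_key_chains())  # e.g. ['net/b', 'net/w']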
