Skip to content

Commit

Permalink
dependency update
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy committed Jul 26, 2023
1 parent 80abcf8 commit b950494
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
3 changes: 3 additions & 0 deletions ivy_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@

from . import densenet
from .densenet import *

from . import bert
from .bert import *
12 changes: 9 additions & 3 deletions ivy_models/bert/bert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ivy
from ivy_models.base import BaseModel, BaseSpec
from ivy_mdoels.helpers import load_transformers_weights
from ivy_models.helpers import load_transformers_weights
from .layers import BertAttention, BertFeedForward, BertEmbedding


Expand Down Expand Up @@ -194,6 +194,10 @@ def __init__(self, config: BertConfig, pooler_out=False, v=None):
self.pooler_out = pooler_out
super(BertModel, self).__init__(v=v)

@classmethod
def get_spec_class(cls):
    """Return the spec class used to configure this model.

    Returns
    -------
    type
        ``BertConfig``, the :class:`BaseSpec` subclass that holds this
        model's configuration.
    """
    # Fix: a @classmethod receives the class as its first argument, so the
    # conventional name is ``cls`` — the original named it ``self``, which is
    # misleading (no instance is ever passed). Callers are unaffected.
    return BertConfig

def _build(self, *args, **kwargs):
self.embeddings = BertEmbedding(**self.config.get_embd_attrs())
self.encoder = BertEncoder(self.config)
Expand Down Expand Up @@ -290,6 +294,8 @@ def bert_base_uncased(pretrained=True):
)
model = BertModel(config, pooler_out=True)
if pretrained:
mapping = load_transformers_weights(model, _bert_weights_mapping)
model = BertModel(config, True, v=mapping)
w_clean = load_transformers_weights(
"bert-base-uncased", model, _bert_weights_mapping
)
model.v = w_clean
return model
13 changes: 7 additions & 6 deletions ivy_models/helpers/weights_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import urllib
import os
import copy
from transformers import AutoModel


def _prune_keys(raw, ref, raw_keys_to_prune=[], ref_keys_to_prune=[]):
Expand Down Expand Up @@ -164,18 +163,20 @@ def unflatten_set(container, name, to_set, split_on="__"):
cont[splits[-1]] = to_set


def load_transformers_weights(
model, map_fn, model_name="bert-base-uncased", split_on="__"
):
base = AutoModel.from_pretrained(model_name)
def load_transformers_weights(hf_repo, model, map_fn, split_on="__"):
from transformers import AutoModel

base = AutoModel.from_pretrained(hf_repo)
ref_weights = base.state_dict()
ref_weights = ivy.to_numpy(ivy.Container(ref_weights))

ivy.set_backend("torch")
ref_weights = ivy.to_numpy(ivy.Container(ref_weights))
old_mapping = copy.deepcopy(model.v)
param_names = old_mapping.cont_flatten_key_chains().keys()
mapping_list = map(lambda x: map_fn(x), param_names)
mapping = dict(zip(param_names, mapping_list))
ivy.previous_backend()

for old_name, ref_name in mapping.items():
to_set = ivy.asarray(ref_weights[ref_name])
unflatten_set(old_mapping, old_name, to_set, split_on)
Expand Down

0 comments on commit b950494

Please sign in to comment.