Commit

move changes
juliagsy committed Jun 21, 2023
1 parent 5419f18 commit 1fb97d8
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions ivy_models/transformers/helpers.py
@@ -4,7 +4,7 @@
 
 class PreNorm(ivy.Module):
     def __init__(self, dim, fn, context_dim=None, eps=1e-05, device=None, v=None):
-        self._fn = fn
+        self._attention = fn
         self._norm = ivy.LayerNorm([dim], eps=eps, device=device)
         if isinstance(context_dim, int):
             context_dim = [context_dim]
@@ -14,17 +14,28 @@ def __init__(self, dim, fn, context_dim=None, eps=1e-05, device=None, v=None):
             else None
         )
         ivy.Module.__init__(self, v=v, device=device)
+        if self.v.cont_has_key_chain("attention/to_q/b"):
+            self.v = self.v.cont_restructure(
+                {
+                    "attention/to_q/b": "attention/linear/b",
+                    "attention/to_q/w": "attention/linear/w",
+                }
+            )
+        elif self.v.cont_has_key_chain("attention/mlp/submodules/v0/b"):
+            self.v = self.v.cont_restructure(
+                {"norm/bias": "a_norm/bias", "norm/weight": "a_norm/weight"}
+            )
 
     def _forward(self, x, **kwargs):
         x = self._norm(x)
         if ivy.exists(self._norm_context):
             kwargs.update(context=self._norm_context(kwargs["context"]))
-        return self._fn(x, **kwargs)
+        return self._attention(x, **kwargs)
 
 
 class FeedForward(ivy.Module):
     def __init__(self, dim, dropout=0.0, device=None, v=None):
-        self._net = ivy.Sequential(
+        self._mlp = ivy.Sequential(
             ivy.Linear(dim, dim, device=device),
             ivy.GELU(),
             ivy.Linear(dim, dim, device=device),
@@ -34,4 +45,8 @@ def __init__(self, dim, dropout=0.0, device=None, v=None):
         ivy.Module.__init__(self, v=v)
 
     def _forward(self, x):
-        return self._net(x)
+        return self._mlp(x)
+
+
+def _perceiver_jax_weights_mapping():
+    return
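For context, a minimal sketch (not part of this commit) of what the new cont_restructure calls in PreNorm.__init__ do: they move variables from one key chain to another inside the module's variable container, so weights stored under the old to_q name land under the linear key chain the diff maps them to. The container contents, shapes, and the numpy backend below are illustrative assumptions; only cont_has_key_chain, cont_restructure, and the key-chain mapping come from the diff above.

import ivy

ivy.set_backend("numpy")  # illustrative backend choice

# Toy variable container laid out like a PreNorm wrapping an attention block
# (made-up values; real containers are built by ivy.Module.__init__).
v = ivy.Container(
    {
        "attention": {"to_q": {"w": ivy.ones((4, 4)), "b": ivy.zeros((4,))}},
        "norm": {"weight": ivy.ones((4,)), "bias": ivy.zeros((4,))},
    }
)

# Same remap as in the diff: relocate the "to_q" weights to the "linear" key chain.
if v.cont_has_key_chain("attention/to_q/b"):
    v = v.cont_restructure(
        {
            "attention/to_q/b": "attention/linear/b",
            "attention/to_q/w": "attention/linear/w",
        }
    )

print(v.cont_has_key_chain("attention/linear/w"))  # True after the remap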
