Update on "MemEff: Raise if wrong bias"
[ghstack-poisoned]
danthe3rd committed Nov 7, 2022
2 parents 2916a0e + 4d6d4fc commit af5321a
Showing 3 changed files with 17 additions and 18 deletions.
10 changes: 7 additions & 3 deletions tests/test_swiglu.py
@@ -83,9 +83,13 @@ def generate_test_shapes():
     shapes = [
         # Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
         # ViT-Giant
-        (9456, 1536, 4096),
-        (4440, 1536, 4096),
-        (4728, 1536, 4096),
+        (9456, 1536, 2736),
+        (4440, 1536, 2736),
+        (4728, 1536, 2736),
+        # GPT-3 (small)
+        (2048, 2048, 5632),
+        # Chinchilla
+        (2048, 8192, 22016),
     ]
     # Add some random shapes
     r = random.Random(0)
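The new hidden sizes match what `_SwiGLUModule` previously computed internally from an MLP hidden size of 4096: take 2/3 of it and round up to a multiple of `align_as=8`, which gives 2736. With that computation removed from the module (see the `swiglu.py` diff below), the tests, and the benchmark shapes further down, now pass the already-rounded SwiGLU hidden dimension directly. A minimal sketch of the caller-side arithmetic, using the formula from the removed lines; the helper name `swiglu_hidden` is only for illustration:

```python
def swiglu_hidden(mlp_hidden: int, align_as: int = 8) -> int:
    # Formula previously applied inside _SwiGLUModule.__init__:
    # take 2/3 of the MLP hidden size, then round up to a multiple of `align_as`.
    h = int(2 * mlp_hidden / 3)
    return (h + align_as - 1) // align_as * align_as

assert swiglu_hidden(4096) == 2736  # the value now hard-coded in the test shapes
```

The `align_as` rounding presumably existed to keep the hidden dimension a multiple of 8; after this change that responsibility sits with the caller.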
6 changes: 3 additions & 3 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -21,9 +21,9 @@
 SHAPES = [
     # Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
     # ViT-Giant
-    (9456, 1536, 4096),
-    (4440, 1536, 4096),
-    (4728, 1536, 4096),
+    (9456, 1536, 2736),
+    (4440, 1536, 2736),
+    (4728, 1536, 2736),
     # Some smaller shapes as well
     (4728, 1536, 1024),
     # GPT-3 (small)
19 changes: 7 additions & 12 deletions xformers/ops/swiglu.py
@@ -22,30 +22,25 @@ class _SwiGLUModule(nn.Module):
     def __init__(
         self,
         in_features: int,
-        hidden_features: Optional[int] = None,
+        hidden_features: int,
         out_features: Optional[int] = None,
-        align_as: int = 8,
         pack_weights: bool = False,
         bias: bool = True,
     ) -> None:
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        swiglu_hidden_features = int(2 * hidden_features / 3)
-        swiglu_hidden_features = (
-            (swiglu_hidden_features + align_as - 1) // align_as * align_as
-        )

         self.w12: Optional[nn.Linear]
         if pack_weights:
-            self.w12 = nn.Linear(in_features, 2 * swiglu_hidden_features, bias=bias)
+            self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
         else:
             self.w12 = None
-        self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias)
-        self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias)
-        self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias)
+        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

-        self.swiglu_hidden_features = swiglu_hidden_features
+        self.hidden_features = hidden_features
         self.out_features = out_features
         self.in_features = in_features

@@ -60,7 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ```
         """
         if self.w12 is not None:
-            x12 = self.w12(x).view([x.shape[0], 2, self.swiglu_hidden_features])
+            x12 = self.w12(x).view([x.shape[0], 2, self.hidden_features])
             x1, x2 = unbind(x12, dim=1)
         else:
             x1 = self.w1(x)
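After this change `hidden_features` is required and is used verbatim for `w1`, `w2`, and the packed `w12`, so the caller picks the SwiGLU hidden size. A minimal usage sketch under the new signature; the import path mirrors the file shown here, and the output shape relies on `out_features` defaulting to `in_features`:

```python
import torch

from xformers.ops.swiglu import _SwiGLUModule

# hidden_features is now the exact SwiGLU hidden size (e.g. 2736 for a
# ViT-Giant style 4096 MLP); no 2/3 scaling or alignment happens inside
# the module anymore.
m = _SwiGLUModule(in_features=1536, hidden_features=2736, pack_weights=True)

x = torch.randn(4728, 1536)
y = m(x)
print(y.shape)  # expected: torch.Size([4728, 1536]), out_features defaults to in_features
```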
