hotfix, dual bias in FusedMLP
blefaudeux committed May 31, 2022
1 parent 2957a71 commit a188d5c
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions xformers/components/feedforward/fused_mlp.py
@@ -46,16 +46,26 @@ def __init__(
         dim_mlp = hidden_layer_multiplier * dim_model

         self.mlp = nn.Sequential(
-            nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias),
+            nn.Linear(
+                in_features=dim_model, out_features=dim_mlp, bias=False
+            ),  # bias is handled in the next layer
             # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
             # the `FusedLinear` import.
             FusedDropoutBias(
-                p=dropout, bias_shape=dim_mlp, activation=activation
+                p=dropout,
+                bias_shape=dim_mlp if bias else None,
+                activation=activation,
             ),
-            nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias),
+            nn.Linear(
+                in_features=dim_mlp, out_features=dim_model, bias=False
+            ),  # bias is handled in the next layer
             # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
             # the `FusedLinear` import.
-            FusedDropoutBias(p=dropout, bias_shape=dim_model, activation=None),
+            FusedDropoutBias(
+                p=dropout,
+                bias_shape=dim_model if bias else None,
+                activation=None,
+            ),
         )
         self.requires_cuda = True
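What the hunk fixes: previously bias=True applied the bias twice per block, once in nn.Linear and once more in the following FusedDropoutBias, while bias=False still left the fused bias active. After this change the nn.Linear layers never carry a bias and FusedDropoutBias adds it exactly once, and only when requested. Below is a minimal usage sketch, not part of the commit; the constructor arguments mirror the names visible in the hunk (dim_model, dropout, activation, hidden_layer_multiplier, bias), while the exact import paths, the Activation value, and the tensor sizes are assumptions for illustration.

import torch

from xformers.components.activations import Activation
from xformers.components.feedforward.fused_mlp import FusedMLP

# Hypothetical sizes, chosen only for illustration
mlp = FusedMLP(
    dim_model=512,
    dropout=0.1,
    activation=Activation.GeLU,
    hidden_layer_multiplier=4,
    bias=True,  # after the fix, the bias is added once, by the fused dropout + bias kernels
).cuda()

x = torch.randn(2, 64, 512, device="cuda")  # FusedMLP sets requires_cuda = True, so a GPU is needed
y = mlp(x)  # output has the same shape as x: (2, 64, 512)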

