diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index ebe8ff213d4b..f1f9c736019a 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -34,6 +34,15 @@
 except ImportError:
     fused_rotary_position_embedding = None
 
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
 from paddle.utils import try_import
@@ -568,10 +577,10 @@ def __init__(self, config):
     def forward(self, x):
         if self.fuse_attention_ffn:
-            gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-            out = self.down_proj(F.silu(gate_out) * up_out)
+            x = swiglu(self.gate_up_fused_proj(x))
         else:
-            out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+            x = swiglu(self.gate_proj(x), self.up_proj(x))
+        out = self.down_proj(x)
         return out
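
Note (not part of the patch): the Python fallback defined above assumes the fused projection returns the gate and up projections concatenated along the last axis, which is exactly what the removed `paddle.chunk(..., chunks=2, axis=-1)` code assumed. The following minimal standalone sketch, with arbitrary tensor shapes and illustrative variable names (`gate`, `up`, `fused`), checks that both call forms of the fallback reproduce the old `F.silu(gate) * up` computation:

import paddle
import paddle.nn.functional as F


def swiglu(x, y=None):
    # Pure-Paddle fallback mirroring the patch: split the fused projection
    # when only one tensor is given, then apply SwiGLU = silu(gate) * up.
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y


gate = paddle.randn([2, 8, 16])
up = paddle.randn([2, 8, 16])

fused = paddle.concat([gate, up], axis=-1)  # stands in for gate_up_fused_proj(x)
reference = F.silu(gate) * up               # the pre-patch computation

# Both call forms match the old result.
assert paddle.allclose(swiglu(fused), reference)
assert paddle.allclose(swiglu(gate, up), reference)

When `paddle.incubate.nn.functional.swiglu` is importable, the fused kernel is used instead and the fallback is never defined, so the new `forward` paths keep the same semantics on both old and new Paddle builds.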