From a5749005bcdc3b9933af40c3b58c5b24b8c3bfa7 Mon Sep 17 00:00:00 2001
From: Ruibiao Chen
Date: Sat, 2 Mar 2024 18:36:59 +0800
Subject: [PATCH] Add SwiGLU for auto Llama (#8038)

---
 paddlenlp/transformers/llama/modeling_auto.py    | 16 +++++++++++++---
 .../transformers/llama/modeling_auto_static.py   | 16 +++++++++++++---
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 3905bf4f9efe..5f5483bc809e 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -32,6 +32,16 @@
 except ImportError:
     fused_rotary_position_embedding = None
 
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
 from paddlenlp.transformers.conversion_utils import (
     StateDictNameMapping,
     init_name_mappings,
@@ -228,10 +238,10 @@ def __init__(self, config, ipp: Optional[int] = None):
 
     def forward(self, x):
         if self.fuse_attention_ffn:
-            gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-            out = self.down_proj(F.silu(gate_out) * up_out)
+            x = swiglu(self.gate_up_fused_proj(x))
         else:
-            out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+            x = swiglu(self.gate_proj(x), self.up_proj(x))
+        out = self.down_proj(x)
         return out
 
 
diff --git a/paddlenlp/transformers/llama/modeling_auto_static.py b/paddlenlp/transformers/llama/modeling_auto_static.py
index c4ee48def480..61bf3daa2529 100644
--- a/paddlenlp/transformers/llama/modeling_auto_static.py
+++ b/paddlenlp/transformers/llama/modeling_auto_static.py
@@ -31,6 +31,16 @@
 except ImportError:
     fused_rotary_position_embedding = None
 
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
 from paddlenlp.transformers.conversion_utils import (
     StateDictNameMapping,
     init_name_mappings,
@@ -242,10 +252,10 @@ def forward(self, x):
             fleet.auto.shard_tensor(self.down_proj.weight, *get_dist_attr(["mp", None], self.ipp))
 
         if self.fuse_attention_ffn:
-            gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1)
-            out = self.down_proj(F.silu(gate_out) * up_out)
+            x = swiglu(self.gate_up_fused_proj(x))
         else:
-            out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+            x = swiglu(self.gate_proj(x), self.up_proj(x))
+        out = self.down_proj(x)
         return out
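
Note on the change: the fallback swiglu introduced by this patch computes silu(x) * y, splitting a single fused projection into gate and up halves when y is omitted. Below is a minimal standalone sketch (not part of the patch; it assumes a working Paddle install, and the tensor shapes, paddle.concat construction, and paddle.allclose checks are illustrative only) showing that both call styles reproduce the previous unfused computation:

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        # Fused form: x holds [gate, up] concatenated along the last axis.
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


gate = paddle.randn([2, 8, 32])
up = paddle.randn([2, 8, 32])

# Unfused path, as used with separate gate_proj/up_proj outputs.
assert paddle.allclose(swiglu(gate, up), F.silu(gate) * up)

# Fused path, as used with the gate_up_fused_proj output; equality is up to
# floating-point tolerance when the fused incubate kernel is available.
fused = paddle.concat([gate, up], axis=-1)
assert paddle.allclose(swiglu(fused), F.silu(gate) * up)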