diff --git a/llm/llama/run_pretrain.py b/llm/llama/run_pretrain.py
index 39216138fdeb..affdb284c097 100644
--- a/llm/llama/run_pretrain.py
+++ b/llm/llama/run_pretrain.py
@@ -133,8 +133,12 @@ class ModelArguments:
         metadata={"help": "llama, use_fused_rms_norm"},
     )
     fuse_attention_qkv: bool = field(
-        default=True,
-        metadata={"help": "gpt, fuse_attention_qkv"},
+        default=False,
+        metadata={"help": "whether to fuse the attention qkv projections"},
+    )
+    fuse_attention_ffn: bool = field(
+        default=False,
+        metadata={"help": "whether to fuse the gate and up projections in the mlp block"},
     )
     recompute_granularity: str = field(
         default="full",
@@ -404,6 +408,7 @@ def main():
     config.use_flash_attention = model_args.use_flash_attention
     config.use_fused_rms_norm = model_args.use_fused_rms_norm
     config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
     config.virtual_pp_degree = model_args.virtual_pp_degree
     config.use_recompute = training_args.recompute
diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py
index 42d655cc441e..adbd130138b2 100644
--- a/paddlenlp/transformers/llama/configuration.py
+++ b/paddlenlp/transformers/llama/configuration.py
@@ -208,7 +208,9 @@ def __init__(
         use_cache=True,
         use_recompute=False,
         recompute_granularity="full",
+        fuse_attention_qkv=False,
         use_flash_attention=False,
+        fuse_attention_ffn=False,
         use_fused_rms_norm=False,
         tensor_parallel_output=True,
         lm_shift_labels=True,
@@ -230,7 +232,9 @@ def __init__(
         self.use_cache = use_cache
         self.use_recompute = use_recompute
         self.recompute_granularity = recompute_granularity
+        self.fuse_attention_qkv = fuse_attention_qkv
         self.use_flash_attention = use_flash_attention
+        self.fuse_attention_ffn = fuse_attention_ffn
         self.use_fused_rms_norm = use_fused_rms_norm
         self.tensor_parallel_output = tensor_parallel_output
         self.lm_shift_labels = lm_shift_labels
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 789974feeb5d..3d9bf67c69aa 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -316,33 +316,60 @@ def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
+        self.tensor_parallel_degree = config.tensor_parallel_degree
+        self.fuse_attention_ffn = config.fuse_attention_ffn
 
         if config.tensor_parallel_degree > 1:
-            self.gate_proj = mpu.ColumnParallelLinear(
-                self.hidden_size,
-                self.intermediate_size,
-                gather_output=False,
-                has_bias=False,
-            )
+            if config.fuse_attention_ffn:
+                self.gate_up_fused_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size * 2,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            else:
+                self.gate_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    gather_output=False,
+                    has_bias=False,
+                )
+                self.up_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    gather_output=False,
+                    has_bias=False,
+                )
+
             self.down_proj = mpu.RowParallelLinear(
                 self.intermediate_size,
                 self.hidden_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
-            self.up_proj = mpu.ColumnParallelLinear(
-                self.hidden_size,
-                self.intermediate_size,
-                gather_output=False,
-                has_bias=False,
-            )
         else:
-            self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+            if config.fuse_attention_ffn:
+                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+            else:
+                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+
             self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
-            self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
 
     def forward(self, x):
-        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+        if self.fuse_attention_ffn:
+            # [bsz, seq_len, 2 * intermediate_size / tensor_parallel_degree]
+            intermediate_parallel = self.gate_up_fused_proj(x)
+            # Strided slicing to accommodate tensor parallelism:
+            # even output channels hold gate_proj, odd channels hold up_proj
+            gate_out = intermediate_parallel[..., 0::2]
+            up_out = intermediate_parallel[..., 1::2]
+            intermediate_parallel = F.silu(gate_out) * up_out
+            # [bsz, seq_len, hidden_size]
+            out = self.down_proj(intermediate_parallel)
+        else:
+            out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+        return out
 
 
 class LlamaAttention(nn.Layer):
@@ -354,6 +381,7 @@ def __init__(self, config):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.fuse_attention_qkv = config.fuse_attention_qkv
         if config.tensor_parallel_degree > 1:
             assert (
                 self.num_heads % config.tensor_parallel_degree == 0
@@ -361,40 +389,55 @@ def __init__(self, config):
             self.num_heads = self.num_heads // config.tensor_parallel_degree
 
         if config.tensor_parallel_degree > 1:
-            self.q_proj = mpu.ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                has_bias=False,
-                gather_output=False,
-            )
-            self.k_proj = mpu.ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                has_bias=False,
-                gather_output=False,
-            )
-            self.v_proj = mpu.ColumnParallelLinear(
-                self.hidden_size,
-                self.hidden_size,
-                has_bias=False,
-                gather_output=False,
-            )
+            if self.fuse_attention_qkv:
+                self.qkv_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    3 * self.hidden_size,
+                    has_bias=False,
+                    gather_output=False,
+                )
+            else:
+                self.q_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    has_bias=False,
+                    gather_output=False,
+                )
+                self.k_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    has_bias=False,
+                    gather_output=False,
+                )
+                self.v_proj = mpu.ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    has_bias=False,
+                    gather_output=False,
+                )
         else:
-            self.q_proj = nn.Linear(
-                self.hidden_size,
-                self.hidden_size,
-                bias_attr=False,
-            )
-            self.k_proj = nn.Linear(
-                self.hidden_size,
-                self.hidden_size,
-                bias_attr=False,
-            )
-            self.v_proj = nn.Linear(
-                self.hidden_size,
-                self.hidden_size,
-                bias_attr=False,
-            )
+            if self.fuse_attention_qkv:
+                self.qkv_proj = nn.Linear(
+                    self.hidden_size,
+                    3 * self.hidden_size,
+                    bias_attr=False,
+                )
+            else:
+                self.q_proj = nn.Linear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias_attr=False,
+                )
+                self.k_proj = nn.Linear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias_attr=False,
+                )
+                self.v_proj = nn.Linear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias_attr=False,
+                )
 
         if config.tensor_parallel_degree > 1:
             self.o_proj = mpu.RowParallelLinear(
@@ -422,9 +465,14 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         bsz, q_len, _ = hidden_states.shape
-        query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
-        key_states = self.k_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
-        value_states = self.v_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
+        if self.fuse_attention_qkv:
+            mix_layer = self.qkv_proj(hidden_states)
+            mix_layer = paddle.reshape_(mix_layer, [0, 0, self.num_heads, 3 * self.head_dim])
+            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+        else:
+            query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
+            key_states = self.k_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
+            value_states = self.v_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
 
         kv_seq_len = key_states.shape[-3]
         offset = 0
@@ -575,13 +623,8 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True):
 
         def get_tensor_parallel_split_mappings(num_layers):
             final_actions = {}
+
             base_actions = {
-                # Column Linear
-                "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True),
-                "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True),
-                "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.up_proj.weight": partial(fn, is_column=True),
                 "lm_head.weight": partial(fn, is_column=True),
                 # Row Linear
                 "embed_tokens.weight": partial(fn, is_column=False),
@@ -589,6 +632,20 @@ def get_tensor_parallel_split_mappings(num_layers):
                 "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
             }
 
+            # Column Linear
+            if config.fuse_attention_qkv:
+                base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
+            else:
+                base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
+
+            if config.fuse_attention_ffn:
+                base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(fn, is_column=True)
+            else:
+                base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
+
            for key, action in base_actions.items():
                 if "layers.0." in key:
                     for i in range(num_layers):
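
Note (reviewer sketch, not part of the patch): the strided slicing in LlamaMLP.forward (gate on even output channels, up on odd) only reproduces the unfused path if the gate_up_fused_proj weight interleaves the gate_proj and up_proj output columns. The snippet below is a minimal NumPy check of that assumed layout, e.g. when converting an unfused checkpoint; w_gate, w_up and w_fused are illustrative names, not identifiers from this PR.

import numpy as np

hidden_size, intermediate_size = 8, 16
rng = np.random.default_rng(0)

w_gate = rng.standard_normal((hidden_size, intermediate_size)).astype("float32")
w_up = rng.standard_normal((hidden_size, intermediate_size)).astype("float32")

# Hypothetical fused weight: interleave gate and up output columns
# (even columns = gate, odd columns = up), matching the [..., 0::2] / [..., 1::2]
# slicing in the fused forward pass.
w_fused = np.empty((hidden_size, 2 * intermediate_size), dtype="float32")
w_fused[:, 0::2] = w_gate
w_fused[:, 1::2] = w_up

x = rng.standard_normal((2, 4, hidden_size)).astype("float32")
fused_out = x @ w_fused  # paddle.nn.Linear weights are [in_features, out_features]

def silu(t):
    return t / (1.0 + np.exp(-t))

# Fused path: slice even/odd channels, then SiLU-gate, as in the patch.
fused_act = silu(fused_out[..., 0::2]) * fused_out[..., 1::2]
# Unfused reference: F.silu(gate_proj(x)) * up_proj(x).
ref_act = silu(x @ w_gate) * (x @ w_up)
assert np.allclose(fused_act, ref_act, atol=1e-5)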
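
Note (reviewer sketch, not part of the patch): the fused QKV path reshapes to [bsz, q_len, num_heads, 3 * head_dim] and then splits the last axis into three, so it assumes the qkv_proj weight stores each head's q, k and v output columns contiguously (per-head [q | k | v] blocks) rather than a plain [W_q | W_k | W_v] concatenation. A minimal NumPy check of that assumed packing; w_q, w_k, w_v and w_qkv are illustrative names.

import numpy as np

hidden_size, num_heads = 8, 2
head_dim = hidden_size // num_heads
rng = np.random.default_rng(0)

w_q, w_k, w_v = (rng.standard_normal((hidden_size, hidden_size)).astype("float32") for _ in range(3))

# Hypothetical fused weight: for each head, place that head's q, k, v output
# columns side by side, so a [num_heads, 3 * head_dim] reshape lines them up.
blocks = []
for h in range(num_heads):
    s = slice(h * head_dim, (h + 1) * head_dim)
    blocks += [w_q[:, s], w_k[:, s], w_v[:, s]]
w_qkv = np.concatenate(blocks, axis=1)  # [hidden_size, 3 * hidden_size]

x = rng.standard_normal((2, 4, hidden_size)).astype("float32")
bsz, q_len, _ = x.shape

# Mirrors: reshape to [bsz, q_len, num_heads, 3 * head_dim], then split into q/k/v.
mix = (x @ w_qkv).reshape(bsz, q_len, num_heads, 3 * head_dim)
q, k, v = np.split(mix, 3, axis=-1)

assert np.allclose(q, (x @ w_q).reshape(bsz, q_len, num_heads, head_dim), atol=1e-5)
assert np.allclose(k, (x @ w_k).reshape(bsz, q_len, num_heads, head_dim), atol=1e-5)
assert np.allclose(v, (x @ w_v).reshape(bsz, q_len, num_heads, head_dim), atol=1e-5)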