From 0c1a534d4d27a358f937b43c47ef37a9a3baedd7 Mon Sep 17 00:00:00 2001 From: littsk Date: Mon, 17 Jul 2023 15:04:11 +0800 Subject: [PATCH 1/4] Add fused linear for the LLAMA MLP block and multi-head attention block --- llm/llama/run_pretrain.py | 6 +- paddlenlp/transformers/llama/configuration.py | 2 + paddlenlp/transformers/llama/modeling.py | 112 +++++++++++------- 3 files changed, 72 insertions(+), 48 deletions(-) diff --git a/llm/llama/run_pretrain.py b/llm/llama/run_pretrain.py index 39216138fdeb..5d4c35dfe089 100644 --- a/llm/llama/run_pretrain.py +++ b/llm/llama/run_pretrain.py @@ -132,9 +132,9 @@ class ModelArguments: default=False, metadata={"help": "llama, use_fused_rms_norm"}, ) - fuse_attention_qkv: bool = field( + fuse_attn_qkv: bool = field( default=True, - metadata={"help": "gpt, fuse_attention_qkv"}, + metadata={"help": "whether to fuse attn qkv"}, ) recompute_granularity: str = field( default="full", @@ -403,7 +403,7 @@ def main(): config.lm_shift_labels = False config.use_flash_attention = model_args.use_flash_attention config.use_fused_rms_norm = model_args.use_fused_rms_norm - config.fuse_attention_qkv = model_args.fuse_attention_qkv + config.fuse_attn_qkv = model_args.fuse_attn_qkv config.recompute_granularity = model_args.recompute_granularity config.virtual_pp_degree = model_args.virtual_pp_degree config.use_recompute = training_args.recompute diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py index 42d655cc441e..29da4fe13abd 100644 --- a/paddlenlp/transformers/llama/configuration.py +++ b/paddlenlp/transformers/llama/configuration.py @@ -208,6 +208,7 @@ def __init__( use_cache=True, use_recompute=False, recompute_granularity="full", + fuse_attn_qkv=False, use_flash_attention=False, use_fused_rms_norm=False, tensor_parallel_output=True, @@ -230,6 +231,7 @@ def __init__( self.use_cache = use_cache self.use_recompute = use_recompute self.recompute_granularity = recompute_granularity + self.fuse_attn_qkv = fuse_attn_qkv self.use_flash_attention = use_flash_attention self.use_fused_rms_norm = use_fused_rms_norm self.tensor_parallel_output = tensor_parallel_output diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 789974feeb5d..724308c255a5 100644 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -316,11 +316,13 @@ def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size + self.tensor_parallel_degree = config.tensor_parallel_degree if config.tensor_parallel_degree > 1: - self.gate_proj = mpu.ColumnParallelLinear( + # 为了减少张量并行的通信量,将两个linear合并成一个 + self.gate_up_fused_proj = mpu.ColumnParallelLinear( self.hidden_size, - self.intermediate_size, + self.intermediate_size * 2, gather_output=False, has_bias=False, ) @@ -330,19 +332,18 @@ def __init__(self, config): input_is_parallel=True, has_bias=False, ) - self.up_proj = mpu.ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - gather_output=False, - has_bias=False, - ) else: self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) def forward(self, x): - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + if self.tensor_parallel_degree > 1: + gate_out, up_out = 
paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) + out = self.down_proj(F.silu(gate_out) * up_out) + else: + out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return out class LlamaAttention(nn.Layer): @@ -354,6 +355,7 @@ def __init__(self, config): self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.max_position_embeddings = config.max_position_embeddings + self.fuse_attn_qkv = config.fuse_attn_qkv if config.tensor_parallel_degree > 1: assert ( self.num_heads % config.tensor_parallel_degree == 0 @@ -361,40 +363,55 @@ def __init__(self, config): self.num_heads = self.num_heads // config.tensor_parallel_degree if config.tensor_parallel_degree > 1: - self.q_proj = mpu.ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - has_bias=False, - gather_output=False, - ) - self.k_proj = mpu.ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - has_bias=False, - gather_output=False, - ) - self.v_proj = mpu.ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - has_bias=False, - gather_output=False, - ) + if self.fuse_attn_qkv: + self.qkv_proj = mpu.ColumnParallelLinear( + self.hidden_size, + 3 * self.hidden_size, + has_bias=False, + gather_output=False, + ) + else: + self.q_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + gather_output=False, + ) + self.k_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + gather_output=False, + ) + self.v_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + has_bias=False, + gather_output=False, + ) else: - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias_attr=False, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias_attr=False, - ) - self.v_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias_attr=False, - ) + if self.fuse_attn_qkv: + self.qkv_proj = nn.Linear( + self.hidden_size, + 3 * self.hidden_size, + bias_attr=False, + ) + else: + self.q_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) if config.tensor_parallel_degree > 1: self.o_proj = mpu.RowParallelLinear( @@ -422,9 +439,14 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.shape - query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) - key_states = self.k_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) - value_states = self.v_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) + if self.fuse_attn_qkv: + mix_layer = self.qkv_proj(hidden_states) + mix_layer = paddle.reshape_(mix_layer, [0, 0, self.num_heads, 3 * self.head_dim]) + query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1) + else: + query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) + key_states = self.k_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) + value_states = self.v_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim]) kv_seq_len = key_states.shape[-3] offset = 0 From 
178e63b9de560984ae07613b2a8c0903ad8198d1 Mon Sep 17 00:00:00 2001 From: littsk Date: Wed, 19 Jul 2023 20:13:57 +0800 Subject: [PATCH 2/4] Refactor the config name 'fuse_attn_qkv' to 'fuse_attention_qkv' for improved readability and consistency. --- llm/llama/run_pretrain.py | 8 ++++---- paddlenlp/transformers/llama/configuration.py | 4 ++-- paddlenlp/transformers/llama/modeling.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/llm/llama/run_pretrain.py b/llm/llama/run_pretrain.py index 5d4c35dfe089..b8e03274d084 100644 --- a/llm/llama/run_pretrain.py +++ b/llm/llama/run_pretrain.py @@ -132,9 +132,9 @@ class ModelArguments: default=False, metadata={"help": "llama, use_fused_rms_norm"}, ) - fuse_attn_qkv: bool = field( - default=True, - metadata={"help": "whether to fuse attn qkv"}, + fuse_attention_qkv: bool = field( + default=False, + metadata={"help": "whether to fuse attention qkv"}, ) recompute_granularity: str = field( default="full", @@ -403,7 +403,7 @@ def main(): config.lm_shift_labels = False config.use_flash_attention = model_args.use_flash_attention config.use_fused_rms_norm = model_args.use_fused_rms_norm - config.fuse_attn_qkv = model_args.fuse_attn_qkv + config.fuse_attention_qkv = model_args.fuse_attention_qkv config.recompute_granularity = model_args.recompute_granularity config.virtual_pp_degree = model_args.virtual_pp_degree config.use_recompute = training_args.recompute diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py index 29da4fe13abd..016d6abc6bda 100644 --- a/paddlenlp/transformers/llama/configuration.py +++ b/paddlenlp/transformers/llama/configuration.py @@ -208,7 +208,7 @@ def __init__( use_cache=True, use_recompute=False, recompute_granularity="full", - fuse_attn_qkv=False, + fuse_attention_qkv=False, use_flash_attention=False, use_fused_rms_norm=False, tensor_parallel_output=True, @@ -231,7 +231,7 @@ def __init__( self.use_cache = use_cache self.use_recompute = use_recompute self.recompute_granularity = recompute_granularity - self.fuse_attn_qkv = fuse_attn_qkv + self.fuse_attention_qkv = fuse_attention_qkv self.use_flash_attention = use_flash_attention self.use_fused_rms_norm = use_fused_rms_norm self.tensor_parallel_output = tensor_parallel_output diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 724308c255a5..6a8069885763 100644 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -355,7 +355,7 @@ def __init__(self, config): self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.max_position_embeddings = config.max_position_embeddings - self.fuse_attn_qkv = config.fuse_attn_qkv + self.fuse_attention_qkv = config.fuse_attention_qkv if config.tensor_parallel_degree > 1: assert ( self.num_heads % config.tensor_parallel_degree == 0 @@ -363,7 +363,7 @@ def __init__(self, config): self.num_heads = self.num_heads // config.tensor_parallel_degree if config.tensor_parallel_degree > 1: - if self.fuse_attn_qkv: + if self.fuse_attention_qkv: self.qkv_proj = mpu.ColumnParallelLinear( self.hidden_size, 3 * self.hidden_size, @@ -390,7 +390,7 @@ def __init__(self, config): gather_output=False, ) else: - if self.fuse_attn_qkv: + if self.fuse_attention_qkv: self.qkv_proj = nn.Linear( self.hidden_size, 3 * self.hidden_size, @@ -439,7 +439,7 @@ def forward( ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: 
Batch x Time x Channel""" bsz, q_len, _ = hidden_states.shape - if self.fuse_attn_qkv: + if self.fuse_attention_qkv: mix_layer = self.qkv_proj(hidden_states) mix_layer = paddle.reshape_(mix_layer, [0, 0, self.num_heads, 3 * self.head_dim]) query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1) From 64128922c3c9b959dfb20b159b0af5fb8d2c8492 Mon Sep 17 00:00:00 2001 From: littsk Date: Mon, 24 Jul 2023 15:47:04 +0800 Subject: [PATCH 3/4] Added a switch for fuse_mlp_linear and improved the organization of the fused linear implementation. --- llm/llama/run_pretrain.py | 5 +++ paddlenlp/transformers/llama/configuration.py | 2 + paddlenlp/transformers/llama/modeling.py | 41 +++++++++++++++---- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/llm/llama/run_pretrain.py b/llm/llama/run_pretrain.py index b8e03274d084..198449330485 100644 --- a/llm/llama/run_pretrain.py +++ b/llm/llama/run_pretrain.py @@ -136,6 +136,10 @@ class ModelArguments: default=False, metadata={"help": "whether to fuse attention qkv"}, ) + fuse_mlp_linear: bool = field( + default=False, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) recompute_granularity: str = field( default="full", metadata={"help": "full core_attn"}, @@ -404,6 +408,7 @@ def main(): config.use_flash_attention = model_args.use_flash_attention config.use_fused_rms_norm = model_args.use_fused_rms_norm config.fuse_attention_qkv = model_args.fuse_attention_qkv + config.fuse_mlp_linear = model_args.fuse_mlp_linear config.recompute_granularity = model_args.recompute_granularity config.virtual_pp_degree = model_args.virtual_pp_degree config.use_recompute = training_args.recompute diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py index 016d6abc6bda..89c391499c5b 100644 --- a/paddlenlp/transformers/llama/configuration.py +++ b/paddlenlp/transformers/llama/configuration.py @@ -210,6 +210,7 @@ def __init__( recompute_granularity="full", fuse_attention_qkv=False, use_flash_attention=False, + fuse_mlp_linear=False, use_fused_rms_norm=False, tensor_parallel_output=True, lm_shift_labels=True, @@ -233,6 +234,7 @@ def __init__( self.recompute_granularity = recompute_granularity self.fuse_attention_qkv = fuse_attention_qkv self.use_flash_attention = use_flash_attention + self.fuse_mlp_linear = fuse_mlp_linear self.use_fused_rms_norm = use_fused_rms_norm self.tensor_parallel_output = tensor_parallel_output self.lm_shift_labels = lm_shift_labels diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 6a8069885763..9f836753c20d 100644 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -317,15 +317,31 @@ def __init__(self, config): self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.tensor_parallel_degree = config.tensor_parallel_degree + self.fuse_mlp_linear = config.fuse_mlp_linear if config.tensor_parallel_degree > 1: # 为了减少张量并行的通信量,将两个linear合并成一个 - self.gate_up_fused_proj = mpu.ColumnParallelLinear( - self.hidden_size, - self.intermediate_size * 2, - gather_output=False, - has_bias=False, - ) + if config.fuse_mlp_linear: + self.gate_up_fused_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.intermediate_size * 2, + gather_output=False, + has_bias=False, + ) + else: + self.gate_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) 
+ self.up_proj = mpu.ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = mpu.RowParallelLinear( self.intermediate_size, self.hidden_size, @@ -338,9 +354,16 @@ def __init__(self, config): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) def forward(self, x): - if self.tensor_parallel_degree > 1: - gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) - out = self.down_proj(F.silu(gate_out) * up_out) + if self.tensor_parallel_degree > 1 and self.fuse_mlp_linear: + # [s, b, 4hp] + intermediate_parallel = self.gate_up_fused_proj(x) + # Special Slicing to accomodate Tensor Parallel + # Even channels is ffc_fc, odd channels is gate + gate_out = intermediate_parallel[..., 0::2] + up_out = intermediate_parallel[..., 1::2] + intermediate_parallel = F.silu(gate_out) * up_out + # [s, b, h] + out = self.down_proj(intermediate_parallel) else: out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) return out From 394608f741990a7cb04d155663b33a0c1c7f85b0 Mon Sep 17 00:00:00 2001 From: littsk Date: Wed, 26 Jul 2023 10:49:32 +0800 Subject: [PATCH 4/4] add tensor parallel mappings for fused linears --- llm/llama/run_pretrain.py | 4 +-- paddlenlp/transformers/llama/configuration.py | 4 +-- paddlenlp/transformers/llama/modeling.py | 36 ++++++++++++------- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/llm/llama/run_pretrain.py b/llm/llama/run_pretrain.py index 198449330485..affdb284c097 100644 --- a/llm/llama/run_pretrain.py +++ b/llm/llama/run_pretrain.py @@ -136,7 +136,7 @@ class ModelArguments: default=False, metadata={"help": "whether to fuse attention qkv"}, ) - fuse_mlp_linear: bool = field( + fuse_attention_ffn: bool = field( default=False, metadata={"help": "whether to fuse first up and gate proj in mlp block"}, ) @@ -408,7 +408,7 @@ def main(): config.use_flash_attention = model_args.use_flash_attention config.use_fused_rms_norm = model_args.use_fused_rms_norm config.fuse_attention_qkv = model_args.fuse_attention_qkv - config.fuse_mlp_linear = model_args.fuse_mlp_linear + config.fuse_attention_ffn = model_args.fuse_attention_ffn config.recompute_granularity = model_args.recompute_granularity config.virtual_pp_degree = model_args.virtual_pp_degree config.use_recompute = training_args.recompute diff --git a/paddlenlp/transformers/llama/configuration.py b/paddlenlp/transformers/llama/configuration.py index 89c391499c5b..adbd130138b2 100644 --- a/paddlenlp/transformers/llama/configuration.py +++ b/paddlenlp/transformers/llama/configuration.py @@ -210,7 +210,7 @@ def __init__( recompute_granularity="full", fuse_attention_qkv=False, use_flash_attention=False, - fuse_mlp_linear=False, + fuse_attention_ffn=False, use_fused_rms_norm=False, tensor_parallel_output=True, lm_shift_labels=True, @@ -234,7 +234,7 @@ def __init__( self.recompute_granularity = recompute_granularity self.fuse_attention_qkv = fuse_attention_qkv self.use_flash_attention = use_flash_attention - self.fuse_mlp_linear = fuse_mlp_linear + self.fuse_attention_ffn = fuse_attention_ffn self.use_fused_rms_norm = use_fused_rms_norm self.tensor_parallel_output = tensor_parallel_output self.lm_shift_labels = lm_shift_labels diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 9f836753c20d..3d9bf67c69aa 100644 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -317,11 +317,10 @@ def 
__init__(self, config): self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.tensor_parallel_degree = config.tensor_parallel_degree - self.fuse_mlp_linear = config.fuse_mlp_linear + self.fuse_attention_ffn = config.fuse_attention_ffn if config.tensor_parallel_degree > 1: - # 为了减少张量并行的通信量,将两个linear合并成一个 - if config.fuse_mlp_linear: + if config.fuse_attention_ffn: self.gate_up_fused_proj = mpu.ColumnParallelLinear( self.hidden_size, self.intermediate_size * 2, @@ -349,12 +348,16 @@ def __init__(self, config): has_bias=False, ) else: - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + if config.fuse_attention_ffn: + self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) def forward(self, x): - if self.tensor_parallel_degree > 1 and self.fuse_mlp_linear: + if self.fuse_attention_ffn: # [s, b, 4hp] intermediate_parallel = self.gate_up_fused_proj(x) # Special Slicing to accomodate Tensor Parallel @@ -620,13 +623,8 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True): def get_tensor_parallel_split_mappings(num_layers): final_actions = {} + base_actions = { - # Column Linear - "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True), - "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True), - "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True), - "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True), - "layers.0.mlp.up_proj.weight": partial(fn, is_column=True), "lm_head.weight": partial(fn, is_column=True), # Row Linear "embed_tokens.weight": partial(fn, is_column=False), @@ -634,6 +632,20 @@ def get_tensor_parallel_split_mappings(num_layers): "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), } + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + for key, action in base_actions.items(): if "layers.0." in key: for i in range(num_layers):
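Note on the fused QKV projection (PATCH 1/4, renamed in PATCH 2/4): fusing `q_proj`/`k_proj`/`v_proj` into a single `qkv_proj` replaces three matmuls with one, and the reshape to `[0, 0, num_heads, 3 * head_dim]` followed by `paddle.split(..., 3, axis=-1)` recovers per-head query/key/value slices. For that reshape to be valid, the fused weight has to group each head's q, k and v columns next to each other. The sketch below is a minimal single-process illustration of this layout constraint, with toy sizes and hypothetical variable names; it is not the PaddleNLP implementation or its checkpoint-conversion logic.

```python
import paddle

bsz, seq, num_heads, head_dim = 2, 4, 3, 5
hidden = num_heads * head_dim
x = paddle.randn([bsz, seq, hidden])

# Unfused reference: three separate projections, reshaped to [bsz, seq, heads, head_dim].
wq, wk, wv = (paddle.randn([hidden, hidden]) for _ in range(3))
q_ref = (x @ wq).reshape([bsz, seq, num_heads, head_dim])
k_ref = (x @ wk).reshape([bsz, seq, num_heads, head_dim])
v_ref = (x @ wv).reshape([bsz, seq, num_heads, head_dim])

# Fused weight laid out per head: [q_h | k_h | v_h] blocks for each head h, so that
# reshaping the output to [bsz, seq, num_heads, 3 * head_dim] groups q/k/v per head.
per_head = []
for h in range(num_heads):
    lo, hi = h * head_dim, (h + 1) * head_dim
    per_head.append(paddle.concat([wq[:, lo:hi], wk[:, lo:hi], wv[:, lo:hi]], axis=-1))
w_qkv = paddle.concat(per_head, axis=-1)  # [hidden, 3 * hidden]

mix = (x @ w_qkv).reshape([bsz, seq, num_heads, 3 * head_dim])
q, k, v = paddle.split(mix, num_or_sections=3, axis=-1)

# All three comparisons hold up to floating-point tolerance.
print(paddle.allclose(q, q_ref), paddle.allclose(k, k_ref), paddle.allclose(v, v_ref))
```

Taken at face value, this also means a plain column-wise concatenation of separate q/k/v weights would not match the reshape used in the patch; the columns would need the per-head ordering first.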
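Note on the fused gate/up projection (`fuse_mlp_linear` in PATCH 3/4, renamed `fuse_attention_ffn` in PATCH 4/4): the even/odd slicing in `LlamaMLP.forward` assumes the gate and up weight columns are interleaved. With that layout, any contiguous column shard, such as the one a column-parallel split produces, still contains matching gate/up pairs, so `[..., 0::2]` and `[..., 1::2]` recover both activations locally without extra communication. The snippet below is a single-process sketch of why the slicing reproduces the unfused `silu(gate_proj(x)) * up_proj(x)`; sizes and names are illustrative, not the library code.

```python
import paddle
import paddle.nn.functional as F

hidden_size, intermediate_size = 8, 12
x = paddle.randn([2, 4, hidden_size])  # [batch, seq, hidden]

# Unfused reference (LLaMA MLP, no bias).
w_gate = paddle.randn([hidden_size, intermediate_size])
w_up = paddle.randn([hidden_size, intermediate_size])
ref = F.silu(x @ w_gate) * (x @ w_up)

# Fused weight with interleaved columns [g0, u0, g1, u1, ...]: a contiguous
# column shard of this matrix always holds whole gate/up pairs.
w_fused = paddle.stack([w_gate, w_up], axis=-1).reshape([hidden_size, 2 * intermediate_size])

fused_out = x @ w_fused  # one matmul instead of two
gate_out, up_out = fused_out[..., 0::2], fused_out[..., 1::2]
print(paddle.allclose(ref, F.silu(gate_out) * up_out))  # True up to float tolerance
```

Assuming the fused weight is stored with this interleaving (which is what the even/odd slicing suggests), registering `gate_up_fused_proj.weight` as an ordinary column-parallel entry in the tensor-parallel split mappings (PATCH 4/4) is sufficient: each rank's shard keeps its gate/up pairs together and applies the same slicing locally.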