Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fused linear for the LLAMA MLP block and multi-head attention block #6425

Merged
merged 4 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions llm/llama/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ class ModelArguments:
metadata={"help": "llama, use_fused_rms_norm"},
)
fuse_attention_qkv: bool = field(
default=True,
metadata={"help": "gpt, fuse_attention_qkv"},
default=False,
metadata={"help": "whether to fuse attention qkv"},
)
fuse_mlp_linear: bool = field(
default=False,
metadata={"help": "whether to fuse first up and gate proj in mlp block"},
)
recompute_granularity: str = field(
default="full",
Expand Down Expand Up @@ -404,6 +408,7 @@ def main():
config.use_flash_attention = model_args.use_flash_attention
config.use_fused_rms_norm = model_args.use_fused_rms_norm
config.fuse_attention_qkv = model_args.fuse_attention_qkv
config.fuse_mlp_linear = model_args.fuse_mlp_linear
config.recompute_granularity = model_args.recompute_granularity
config.virtual_pp_degree = model_args.virtual_pp_degree
config.use_recompute = training_args.recompute
Expand Down
4 changes: 4 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def __init__(
use_cache=True,
use_recompute=False,
recompute_granularity="full",
fuse_attention_qkv=False,
use_flash_attention=False,
fuse_mlp_linear=False,
use_fused_rms_norm=False,
tensor_parallel_output=True,
lm_shift_labels=True,
Expand All @@ -230,7 +232,9 @@ def __init__(
self.use_cache = use_cache
self.use_recompute = use_recompute
self.recompute_granularity = recompute_granularity
self.fuse_attention_qkv = fuse_attention_qkv
self.use_flash_attention = use_flash_attention
self.fuse_mlp_linear = fuse_mlp_linear
self.use_fused_rms_norm = use_fused_rms_norm
self.tensor_parallel_output = tensor_parallel_output
self.lm_shift_labels = lm_shift_labels
Expand Down
143 changes: 94 additions & 49 deletions paddlenlp/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,33 +316,57 @@ def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.tensor_parallel_degree = config.tensor_parallel_degree
self.fuse_mlp_linear = config.fuse_mlp_linear

if config.tensor_parallel_degree > 1:
self.gate_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.intermediate_size,
gather_output=False,
has_bias=False,
)
# 为了减少张量并行的通信量,将两个linear合并成一个
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中文注释 删掉

if config.fuse_mlp_linear:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 fuse_mlp_linear 是不是换个名字更好?
fuse_gate_up_proj ? 或者其他?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的好的

self.gate_up_fused_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.intermediate_size * 2,
gather_output=False,
has_bias=False,
)
else:
self.gate_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.intermediate_size,
gather_output=False,
has_bias=False,
)
self.up_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.intermediate_size,
gather_output=False,
has_bias=False,
)

self.down_proj = mpu.RowParallelLinear(
self.intermediate_size,
self.hidden_size,
input_is_parallel=True,
has_bias=False,
)
self.up_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.intermediate_size,
gather_output=False,
has_bias=False,
)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
if self.tensor_parallel_degree > 1 and self.fuse_mlp_linear:
# [s, b, 4hp]
intermediate_parallel = self.gate_up_fused_proj(x)
# Special Slicing to accomodate Tensor Parallel
# Even channels is ffc_fc, odd channels is gate
gate_out = intermediate_parallel[..., 0::2]
up_out = intermediate_parallel[..., 1::2]
intermediate_parallel = F.silu(gate_out) * up_out
# [s, b, h]
out = self.down_proj(intermediate_parallel)
else:
out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
return out


class LlamaAttention(nn.Layer):
Expand All @@ -354,47 +378,63 @@ def __init__(self, config):
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
self.fuse_attention_qkv = config.fuse_attention_qkv
if config.tensor_parallel_degree > 1:
assert (
self.num_heads % config.tensor_parallel_degree == 0
), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
self.num_heads = self.num_heads // config.tensor_parallel_degree

if config.tensor_parallel_degree > 1:
self.q_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
has_bias=False,
gather_output=False,
)
self.k_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
has_bias=False,
gather_output=False,
)
self.v_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
has_bias=False,
gather_output=False,
)
if self.fuse_attention_qkv:
self.qkv_proj = mpu.ColumnParallelLinear(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

预训练参数加载转换,看看能不能根据fuse 搞成自动的。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请问大佬,具体是啥意思呀

self.hidden_size,
3 * self.hidden_size,
has_bias=False,
gather_output=False,
)
else:
self.q_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
has_bias=False,
gather_output=False,
)
self.k_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
has_bias=False,
gather_output=False,
)
self.v_proj = mpu.ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
has_bias=False,
gather_output=False,
)
else:
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.hidden_size,
3 * self.hidden_size,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
littsk marked this conversation as resolved.
Show resolved Hide resolved

if config.tensor_parallel_degree > 1:
self.o_proj = mpu.RowParallelLinear(
Expand Down Expand Up @@ -422,9 +462,14 @@ def forward(
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.shape
query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
key_states = self.k_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
value_states = self.v_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
if self.fuse_attention_qkv:
mix_layer = self.qkv_proj(hidden_states)
mix_layer = paddle.reshape_(mix_layer, [0, 0, self.num_heads, 3 * self.head_dim])
query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
else:
query_states = self.q_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
key_states = self.k_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])
value_states = self.v_proj(hidden_states).reshape(shape=[bsz, q_len, self.num_heads, self.head_dim])

kv_seq_len = key_states.shape[-3]
offset = 0
Expand Down