add tensor parallel mappings for fused linears
littsk committed Jul 26, 2023
1 parent 6412892 commit 394608f
Showing 3 changed files with 28 additions and 16 deletions.
4 changes: 2 additions & 2 deletions llm/llama/run_pretrain.py
@@ -136,7 +136,7 @@ class ModelArguments:
         default=False,
         metadata={"help": "whether to fuse attention qkv"},
     )
-    fuse_mlp_linear: bool = field(
+    fuse_attention_ffn: bool = field(
         default=False,
         metadata={"help": "whether to fuse first up and gate proj in mlp block"},
     )
@@ -408,7 +408,7 @@ def main():
     config.use_flash_attention = model_args.use_flash_attention
     config.use_fused_rms_norm = model_args.use_fused_rms_norm
     config.fuse_attention_qkv = model_args.fuse_attention_qkv
-    config.fuse_mlp_linear = model_args.fuse_mlp_linear
+    config.fuse_attention_ffn = model_args.fuse_attention_ffn
     config.recompute_granularity = model_args.recompute_granularity
     config.virtual_pp_degree = model_args.virtual_pp_degree
     config.use_recompute = training_args.recompute
4 changes: 2 additions & 2 deletions paddlenlp/transformers/llama/configuration.py
@@ -210,7 +210,7 @@ def __init__(
         recompute_granularity="full",
         fuse_attention_qkv=False,
         use_flash_attention=False,
-        fuse_mlp_linear=False,
+        fuse_attention_ffn=False,
         use_fused_rms_norm=False,
         tensor_parallel_output=True,
         lm_shift_labels=True,
@@ -234,7 +234,7 @@ def __init__(
         self.recompute_granularity = recompute_granularity
         self.fuse_attention_qkv = fuse_attention_qkv
         self.use_flash_attention = use_flash_attention
-        self.fuse_mlp_linear = fuse_mlp_linear
+        self.fuse_attention_ffn = fuse_attention_ffn
         self.use_fused_rms_norm = use_fused_rms_norm
         self.tensor_parallel_output = tensor_parallel_output
         self.lm_shift_labels = lm_shift_labels
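For context, a minimal usage sketch of the renamed flags, assuming LlamaConfig is importable from paddlenlp.transformers as defined in configuration.py above; it mirrors how run_pretrain.py copies the ModelArguments fields onto the config object. This is not part of the commit.

# Hedged sketch only; assumes paddlenlp with the LlamaConfig shown in this diff.
from paddlenlp.transformers import LlamaConfig

config = LlamaConfig()
config.fuse_attention_qkv = True  # fuse the q/k/v projections into a single qkv_proj
config.fuse_attention_ffn = True  # fuse the gate/up projections into gate_up_fused_proj
print(config.fuse_attention_qkv, config.fuse_attention_ffn)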
36 changes: 24 additions & 12 deletions paddlenlp/transformers/llama/modeling.py
@@ -317,11 +317,10 @@ def __init__(self, config):
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.tensor_parallel_degree = config.tensor_parallel_degree
-        self.fuse_mlp_linear = config.fuse_mlp_linear
+        self.fuse_attention_ffn = config.fuse_attention_ffn
 
         if config.tensor_parallel_degree > 1:
-            # to reduce tensor-parallel communication, fuse the two linears into one
-            if config.fuse_mlp_linear:
+            if config.fuse_attention_ffn:
                 self.gate_up_fused_proj = mpu.ColumnParallelLinear(
                     self.hidden_size,
                     self.intermediate_size * 2,
@@ -349,12 +348,16 @@ def __init__(self, config):
                     has_bias=False,
                 )
         else:
-            self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+            if config.fuse_attention_ffn:
+                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+            else:
+                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+
             self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
-            self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
 
     def forward(self, x):
-        if self.tensor_parallel_degree > 1 and self.fuse_mlp_linear:
+        if self.fuse_attention_ffn:
             # [s, b, 4hp]
             intermediate_parallel = self.gate_up_fused_proj(x)
             # Special slicing to accommodate tensor parallel
@@ -620,20 +623,29 @@ def _get_tensor_parallel_mappings(cls, config, is_split=True):
 
         def get_tensor_parallel_split_mappings(num_layers):
             final_actions = {}
+
             base_actions = {
-                # Column Linear
-                "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True),
-                "layers.0.self_attn.k_proj.weight": partial(fn, is_column=True),
-                "layers.0.self_attn.v_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.up_proj.weight": partial(fn, is_column=True),
                 "lm_head.weight": partial(fn, is_column=True),
                 # Row Linear
                 "embed_tokens.weight": partial(fn, is_column=False),
                 "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
                 "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
             }
 
+            # Column Linear
+            if config.fuse_attention_qkv:
+                base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
+            else:
+                base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
+
+            if config.fuse_attention_ffn:
+                base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(fn, is_column=True)
+            else:
+                base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
+
             for key, action in base_actions.items():
                 if "layers.0." in key:
                     for i in range(num_layers):
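To make the intent of the new column-parallel entries concrete, here is a small NumPy sketch. It is not taken from PaddleNLP and is independent of the actual split function bound to fn above; it assumes a layout in which each tensor-parallel rank stores its slice of the gate weight followed by its slice of the up weight, so the rank-local fused output can later be cut back into gate and up halves (the "special slicing" referenced in the MLP forward). The helper name split_fused_ffn_weight is hypothetical.

# Hypothetical illustration; not the PaddleNLP split function used by the mapping above.
import numpy as np

def split_fused_ffn_weight(weight, tp_degree, rank):
    """Column-split a fused [gate | up] weight for one tensor-parallel rank.

    weight: [hidden_size, 2 * intermediate_size], columns ordered [gate_proj | up_proj].
    Returns a shard of shape [hidden_size, 2 * intermediate_size // tp_degree].
    """
    gate, up = np.split(weight, 2, axis=1)            # undo the fusion
    gate_shard = np.split(gate, tp_degree, axis=1)[rank]
    up_shard = np.split(up, tp_degree, axis=1)[rank]
    # Re-fuse per rank so the rank-local output can be cut back into gate/up halves.
    return np.concatenate([gate_shard, up_shard], axis=1)

# Example: hidden_size=4, intermediate_size=6, two tensor-parallel ranks.
full_weight = np.arange(4 * 12).reshape(4, 12)
shard = split_fused_ffn_weight(full_weight, tp_degree=2, rank=0)
assert shard.shape == (4, 6)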
