Commit

add a100 test ground truth

zhiqiu committed Apr 11, 2024
1 parent 79cb8b6 commit fdfc7bf

Showing 2 changed files with 142 additions and 4 deletions.
12 changes: 8 additions & 4 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -181,6 +181,7 @@ def scaled_dot_product_attention(

attn_output = paddle.matmul(attn_weights, value_states)
attn_output = attn_output.transpose([0, 2, 1, 3])
# [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, num_heads * head_dim]
attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
return (attn_output, attn_weights) if output_attentions else attn_output
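
The new comment documents the shape flow out of attention. As an illustrative sanity check (not part of the commit), the same transpose-then-reshape can be reproduced with NumPy stand-ins for paddle tensors:

import numpy as np

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
attn_output = np.zeros((bsz, num_heads, q_len, head_dim))
# [bsz, num_heads, q_len, head_dim] -> [bsz, q_len, num_heads, head_dim]
attn_output = attn_output.transpose(0, 2, 1, 3)
# [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, num_heads * head_dim]
attn_output = attn_output.reshape(bsz, q_len, num_heads * head_dim)
assert attn_output.shape == (2, 8, 64)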

@@ -399,9 +400,10 @@ def forward(
alibi: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
# [bs, seq_len, num_head * head_dim] or [seq_len / n, bs, num_head * head_dim] (if sequence_parallel)
# enter tp region
if self.config.sequence_parallel:
# [seq_len / n, bs, num_head * head_dim] -> [seq_len, bs, num_head * head_dim] (if sequence_parallel)
hidden_states = dist.reshard(
hidden_states,
get_mesh(self.ipp),
@@ -422,6 +424,8 @@ def forward(
value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)

if self.config.sequence_parallel:
# [seq_len, bs, num_head, head_dim] -> [bs, seq_len, num_head, head_dim] (if sequence_parallel)
# FlashAttention (FA) and RoPE do not support sequence-first input
query_states = paddle.transpose(query_states, [1, 0, 2, 3])
key_states = paddle.transpose(key_states, [1, 0, 2, 3])
value_states = paddle.transpose(value_states, [1, 0, 2, 3])
@@ -526,12 +530,12 @@ def forward(
else:
attn_output = outputs

# if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
# else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
# [bs, q_len, num_head * head_dim]
attn_output = self.o_proj(attn_output)

# enter sp region
if self.config.sequence_parallel:
# [bs, q_len, num_head * head_dim] -> [q_len / n, bs, num_head * head_dim]
attn_output = paddle.transpose(attn_output, [1, 0, 2])
attn_output = dist.reshard(
attn_output,
@@ -595,7 +599,7 @@ def forward(
cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
"""

# [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
# [bs, seq_len, embed_dim] or [seq_len / n, bs, embed_dim] (if sequence_parallel)
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
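The layout comments in this file contrast a batch-first [bs, seq_len, hidden] activation with the sequence-parallel [seq_len / n, bs, hidden] layout, where n is the parallel degree; dist.reshard performs the real redistribution over the process mesh. An illustrative NumPy sketch of that mapping (not part of the commit):

import numpy as np

bs, seq_len, hidden, n = 2, 8, 16, 4
x = np.arange(bs * seq_len * hidden, dtype=np.float32).reshape(bs, seq_len, hidden)
# sequence-first: [bs, seq_len, hidden] -> [seq_len, bs, hidden]
x_seq_first = x.transpose(1, 0, 2)
# shard along the sequence axis: each of the n ranks holds [seq_len / n, bs, hidden]
shards = np.split(x_seq_first, n, axis=0)
assert shards[0].shape == (seq_len // n, bs, hidden)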
134 changes: 134 additions & 0 deletions scripts/distribute/ci_case_auto.sh
@@ -27,6 +27,15 @@ export llama_data_path=/llama_data

unset CUDA_VISIBLE_DEVICES

function is_a100() {
    # Print a non-empty marker only on A100 machines. Call sites test
    # `if [ $(is_a100) ];then`, and `[ ... ]` is true for any non-empty
    # string, so echoing 0 in an else branch would make every call site
    # take the A100 branch.
    if [ $(nvidia-smi | grep -c A100) -ne 0 ]; then
        echo 1
    fi
}
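
The helper prints its marker only when nvidia-smi reports an A100, so the `[ $(is_a100) ]` tests in the cases below are true exactly on A100 machines. For illustration only (not part of the script), the same ground-truth selection pattern in Python, using the baselines of the first GPT case:

import subprocess

def is_a100() -> bool:
    try:
        out = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
    except FileNotFoundError:  # no NVIDIA driver on this machine
        return False
    return "A100" in out

# default baselines, overridden with the A100 ground truth when detected
loss_base, ips_base, mem_base = 10.507633305, 3518, 11750.6
if is_a100():
    loss_base, ips_base, mem_base = 10.530449009, 16763, 11750.6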


function gpt_case_list_auto() {
gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
@@ -100,6 +109,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
loss_base=10.507633305
ips_base=3518
mem_base=11750.6
if [ $(is_a100) ];then
loss_base=10.530449009
ips_base=16763
mem_base=11750.6
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
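
Every hardware-specific override below follows the same pattern as this first case: an `if [ $(is_a100) ];then ... fi` block swaps in the A100 baselines just before check_result, which (judging by its argument order) compares each measured loss/ips/mem value against the corresponding baseline.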
@@ -136,6 +150,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
loss_base=10.570028400
ips_base=35050
mem_base=1988.9
if [ $(is_a100) ];then
loss_base=10.559662151
ips_base=83918
mem_base=2022.7
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -173,6 +192,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
loss_base=10.570028400
ips_base=35050
mem_base=1988.9
if [ $(is_a100) ];then
loss_base=10.559662151
ips_base=83918
mem_base=2022.7
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -209,6 +233,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
loss_base=10.700293922
ips_base=32518
mem_base=1535.7
if [ $(is_a100) ];then
loss_base=10.679453373
ips_base=79116
mem_base=1488.2
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -245,6 +274,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
loss_base=10.672543240
ips_base=18681
mem_base=2135.7
if [ $(is_a100) ];then
loss_base=10.651049423
ips_base=41174
mem_base=2064.5
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -282,6 +316,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
loss_base=10.672543240
ips_base=18681
mem_base=2135.7
if [ $(is_a100) ];then
loss_base=10.651049423
ips_base=41174
mem_base=2064.5
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -318,6 +357,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
loss_base=10.720068359
ips_base=15232
mem_base=1999.2
if [ $(is_a100) ];then
loss_base=10.657777309
ips_base=30027
mem_base=2002.0
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -355,6 +399,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
loss_base=10.720068359
ips_base=15232
mem_base=1999.2
if [ $(is_a100) ];then
loss_base=10.657777309
ips_base=30027
mem_base=2002.0
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -391,6 +440,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
loss_base=10.720078850
ips_base=15571
mem_base=1999.2
if [ $(is_a100) ];then
loss_base=10.657803535
ips_base=29166
mem_base=2002.0
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -427,6 +481,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
loss_base=10.681921577
ips_base=13813
mem_base=1747.6
if [ $(is_a100) ];then
loss_base=10.662137604
ips_base=24700
mem_base=1750.5
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -463,6 +522,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
loss_base=10.579057693
ips_base=19822
mem_base=1709.8
if [ $(is_a100) ];then
loss_base=10.586785984
ips_base=42813
mem_base=1743.8
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -500,6 +564,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
loss_base=10.579057693
ips_base=19822
mem_base=1709.8
if [ $(is_a100) ];then
loss_base=10.586785984
ips_base=42813
mem_base=1743.8
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -536,6 +605,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
loss_base=10.579057693
ips_base=20170
mem_base=1709.8
if [ $(is_a100) ];then
loss_base=10.586785984
ips_base=42995
mem_base=1743.8
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -572,6 +646,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
loss_base=10.585316849
ips_base=15742
mem_base=1591.6
if [ $(is_a100) ];then
loss_base=10.555718899
ips_base=34688
mem_base=1625.6
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -608,6 +687,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
loss_base=10.672568035
ips_base=19461
mem_base=1384.7
if [ $(is_a100) ];then
loss_base=10.651032448
ips_base=42435
mem_base=1377.5
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -644,6 +728,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
loss_base=10.672568035
ips_base=19652
mem_base=1384.7
if [ $(is_a100) ];then
loss_base=10.651032448
ips_base=43008
mem_base=1377.5
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -681,6 +770,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
loss_base=10.672568035
ips_base=19652
mem_base=1384.7
if [ $(is_a100) ];then
loss_base=10.651032448
ips_base=43008
mem_base=1377.5
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -717,6 +811,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
loss_base=10.696336079
ips_base=16613
mem_base=1280.5
if [ $(is_a100) ];then
loss_base=10.705118465
ips_base=37104
mem_base=1217.3
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -754,6 +853,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
loss_base=10.696336079
ips_base=16613
mem_base=1280.5
if [ $(is_a100) ];then
loss_base=10.705118465
ips_base=37104
mem_base=1217.3
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
@@ -900,6 +1004,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.52110565
if [ $(is_a100) ];then
loss_base=9.44003963
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
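Note that the llama cases pin ips_base and mem_base to -1, presumably a sentinel telling check_result to skip those comparisons, so only the loss baseline needs an A100-specific value.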
@@ -966,6 +1073,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.42011833
if [ $(is_a100) ];then
loss_base=9.44003963
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1032,6 +1142,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.44299471
if [ $(is_a100) ];then
loss_base=9.45633757
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1098,6 +1211,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.45936012
if [ $(is_a100) ];then
loss_base=9.46121407
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1166,6 +1282,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.46707726
if [ $(is_a100) ];then
loss_base=9.44474411
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1235,6 +1354,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=10.0859375
if [ $(is_a100) ];then
loss_base=10.125
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1302,6 +1424,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.53389835
if [ $(is_a100) ];then
loss_base=9.54253578
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1369,6 +1494,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.39066124
if [ $(is_a100) ];then
loss_base=9.41613197
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1436,6 +1564,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.38235474
if [ $(is_a100) ];then
loss_base=9.4053154
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1504,6 +1635,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.38256836
if [ $(is_a100) ];then
loss_base=9.4055137
fi
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
