From 2e87c77c67f3ee82c77e0eccbc3a13a895b67771 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Wed, 10 Apr 2024 18:41:18 +0800 Subject: [PATCH 1/8] add a100 test ground truth --- paddlenlp/transformers/llama/modeling_auto.py | 12 +- scripts/distribute/ci_case_auto.sh | 134 ++++++++++++++++++ 2 files changed, 142 insertions(+), 4 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index 64277d744016..fce60999c214 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -181,6 +181,7 @@ def scaled_dot_product_attention( attn_output = paddle.matmul(attn_weights, value_states) attn_output = attn_output.transpose([0, 2, 1, 3]) + # [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, num_heads * head_dim] attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) return (attn_output, attn_weights) if output_attentions else attn_output @@ -399,9 +400,10 @@ def forward( alibi: Optional[paddle.Tensor] = None, ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" - # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + # [bs, seq_len, num_head * head_dim] or [seq_len / n, bs, num_head * head_dim] (if sequence_parallel) # enter tp region if self.config.sequence_parallel: + # [seq_len / n, bs, num_head * head_dim] -> [seq_len, bs, num_head * head_dim] (if sequence_parallel) hidden_states = dist.reshard( hidden_states, get_mesh(self.ipp), @@ -422,6 +424,8 @@ def forward( value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape) if self.config.sequence_parallel: + # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel) + # FA and rope not support sequence first query_states = paddle.transpose(query_states, [1, 0, 2, 3]) key_states = paddle.transpose(key_states, [1, 0, 2, 3]) value_states = paddle.transpose(value_states, [1, 0, 2, 3]) @@ -526,12 +530,12 @@ def forward( else: attn_output = outputs - # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] - # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + # [bs, q_len, num_head * head_dim] attn_output = self.o_proj(attn_output) # enter sp region if self.config.sequence_parallel: + # [bs, q_len, num_head * head_dim] -> [q_len / n, bs, num_head * head_dim] attn_output = paddle.transpose(attn_output, [1, 0, 2]) attn_output = dist.reshard( attn_output, @@ -595,7 +599,7 @@ def forward( cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states """ - # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + # [bs, seq_len, embed_dim] or [seq_len / n, bs, embed_dim] (if sequence_parallel) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 8707289b7fe2..ce10b6a218c6 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -28,6 +28,15 @@ export llm_gpt_case_path=$root_path/llm/gpt-3/auto_parallel unset CUDA_VISIBLE_DEVICES +function is_a100() { + if [ $(nvidia-smi|grep A100|wc -l) -ne 0 ];then + echo 1 + else + echo 0 + fi +} + + function gpt_case_list_auto() { gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1 gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8 @@ -108,6 +117,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() { loss_base=10.507633305 ips_base=3518 mem_base=11750.6 + if [ $(is_a100) ];then + loss_base=10.530449009 + ips_base=16763 + mem_base=11750.6 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -144,6 +158,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() { loss_base=10.570028400 ips_base=35050 mem_base=1988.9 + if [ $(is_a100) ];then + loss_base=10.559662151 + ips_base=83918 + mem_base=2022.7 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -181,6 +200,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() { loss_base=10.570028400 ips_base=35050 mem_base=1988.9 + if [ $(is_a100) ];then + loss_base=10.559662151 + ips_base=83918 + mem_base=2022.7 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -217,6 +241,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() { loss_base=10.700293922 ips_base=32518 mem_base=1535.7 + if [ $(is_a100) ];then + loss_base=10.679453373 + ips_base=79116 + mem_base=1488.2 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -253,6 +282,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() { loss_base=10.672543240 ips_base=18681 mem_base=2135.7 + if [ $(is_a100) ];then + loss_base=10.651049423 + ips_base=41174 + mem_base=2064.5 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -290,6 +324,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() { loss_base=10.672543240 ips_base=18681 mem_base=2135.7 + if [ $(is_a100) ];then + loss_base=10.651049423 + ips_base=41174 + mem_base=2064.5 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -326,6 +365,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() { loss_base=10.720068359 ips_base=15232 mem_base=1999.2 + if [ $(is_a100) ];then + loss_base=10.657777309 + ips_base=30027 + mem_base=2002.0 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -363,6 +407,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() { loss_base=10.720068359 ips_base=15232 mem_base=1999.2 + if [ $(is_a100) ];then + loss_base=10.657777309 + ips_base=30027 + mem_base=2002.0 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -399,6 +448,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() { loss_base=10.720078850 ips_base=15571 mem_base=1999.2 + if [ $(is_a100) ];then + loss_base=10.657803535 + ips_base=29166 + mem_base=2002.0 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -435,6 +489,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() { loss_base=10.681921577 ips_base=13813 mem_base=1747.6 + if [ $(is_a100) ];then + loss_base=10.662137604 + ips_base=24700 + mem_base=1750.5 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -471,6 +530,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() { loss_base=10.579057693 ips_base=19822 mem_base=1709.8 + if [ $(is_a100) ];then + loss_base=10.586785984 + ips_base=42813 + mem_base=1743.8 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -508,6 +572,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() { loss_base=10.579057693 ips_base=19822 mem_base=1709.8 + if [ $(is_a100) ];then + loss_base=10.586785984 + ips_base=42813 + mem_base=1743.8 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -544,6 +613,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() { loss_base=10.579057693 ips_base=20170 mem_base=1709.8 + if [ $(is_a100) ];then + loss_base=10.586785984 + ips_base=42995 + mem_base=1743.8 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -580,6 +654,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() { loss_base=10.585316849 ips_base=15742 mem_base=1591.6 + if [ $(is_a100) ];then + loss_base=10.555718899 + ips_base=34688 + mem_base=1625.6 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -616,6 +695,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() { loss_base=10.672568035 ips_base=19461 mem_base=1384.7 + if [ $(is_a100) ];then + loss_base=10.651032448 + ips_base=42435 + mem_base=1377.5 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -652,6 +736,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() { loss_base=10.672568035 ips_base=19652 mem_base=1384.7 + if [ $(is_a100) ];then + loss_base=10.651032448 + ips_base=43008 + mem_base=1377.5 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -689,6 +778,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() { loss_base=10.672568035 ips_base=19652 mem_base=1384.7 + if [ $(is_a100) ];then + loss_base=10.651032448 + ips_base=43008 + mem_base=1377.5 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -725,6 +819,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() { loss_base=10.696336079 ips_base=16613 mem_base=1280.5 + if [ $(is_a100) ];then + loss_base=10.705118465 + ips_base=37104 + mem_base=1217.3 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -762,6 +861,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() { loss_base=10.696336079 ips_base=16613 mem_base=1280.5 + if [ $(is_a100) ];then + loss_base=10.705118465 + ips_base=37104 + mem_base=1217.3 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -908,6 +1012,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.52110565 + if [ $(is_a100) ];then + loss_base=9.44003963 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -974,6 +1081,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.42011833 + if [ $(is_a100) ];then + loss_base=9.44003963 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1040,6 +1150,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.44299471 + if [ $(is_a100) ];then + loss_base=9.45633757 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1106,6 +1219,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.45936012 + if [ $(is_a100) ];then + loss_base=9.46121407 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1174,6 +1290,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.46707726 + if [ $(is_a100) ];then + loss_base=9.44474411 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1243,6 +1362,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=10.0859375 + if [ $(is_a100) ];then + loss_base=10.125 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1310,6 +1432,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.53389835 + if [ $(is_a100) ];then + loss_base=9.54253578 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1377,6 +1502,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.39066124 + if [ $(is_a100) ];then + loss_base=9.41613197 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1444,6 +1572,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.38235474 + if [ $(is_a100) ];then + loss_base=9.4053154 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} @@ -1512,6 +1643,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.38256836 + if [ $(is_a100) ];then + loss_base=9.4055137 + fi ips_base=-1 mem_base=-1 check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} From 50ef37e1928ca565d5cf151c1cf18a9a73bbc0f5 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 11 Apr 2024 15:59:42 +0800 Subject: [PATCH 2/8] add requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 4b676d900563..33b7f4bf9149 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ safetensors tool_helpers aistudio-sdk>=0.1.3 jinja2 +regex \ No newline at end of file From 98afa7a2992b9bd196422a453210e69e0d68efa9 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 11 Apr 2024 16:44:58 +0800 Subject: [PATCH 3/8] cache is_a100 result --- scripts/distribute/ci_case_auto.sh | 59 +++++++++++++++--------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index ce10b6a218c6..332385c85e35 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -36,6 +36,7 @@ function is_a100() { fi } +IS_A100=$(is_a100) function gpt_case_list_auto() { gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1 @@ -117,7 +118,7 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() { loss_base=10.507633305 ips_base=3518 mem_base=11750.6 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.530449009 ips_base=16763 mem_base=11750.6 @@ -158,7 +159,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() { loss_base=10.570028400 ips_base=35050 mem_base=1988.9 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.559662151 ips_base=83918 mem_base=2022.7 @@ -200,7 +201,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() { loss_base=10.570028400 ips_base=35050 mem_base=1988.9 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.559662151 ips_base=83918 mem_base=2022.7 @@ -241,7 +242,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() { loss_base=10.700293922 ips_base=32518 mem_base=1535.7 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.679453373 ips_base=79116 mem_base=1488.2 @@ -282,7 +283,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() { loss_base=10.672543240 ips_base=18681 mem_base=2135.7 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.651049423 ips_base=41174 mem_base=2064.5 @@ -324,7 +325,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() { loss_base=10.672543240 ips_base=18681 mem_base=2135.7 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.651049423 ips_base=41174 mem_base=2064.5 @@ -365,7 +366,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() { loss_base=10.720068359 ips_base=15232 mem_base=1999.2 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.657777309 ips_base=30027 mem_base=2002.0 @@ -407,7 +408,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() { loss_base=10.720068359 ips_base=15232 mem_base=1999.2 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.657777309 ips_base=30027 mem_base=2002.0 @@ -448,7 +449,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() { loss_base=10.720078850 ips_base=15571 mem_base=1999.2 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.657803535 ips_base=29166 mem_base=2002.0 @@ -489,7 +490,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() { loss_base=10.681921577 ips_base=13813 mem_base=1747.6 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.662137604 ips_base=24700 mem_base=1750.5 @@ -530,7 +531,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() { loss_base=10.579057693 ips_base=19822 mem_base=1709.8 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.586785984 ips_base=42813 mem_base=1743.8 @@ -572,7 +573,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() { loss_base=10.579057693 ips_base=19822 mem_base=1709.8 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.586785984 ips_base=42813 mem_base=1743.8 @@ -613,7 +614,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() { loss_base=10.579057693 ips_base=20170 mem_base=1709.8 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.586785984 ips_base=42995 mem_base=1743.8 @@ -654,7 +655,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() { loss_base=10.585316849 ips_base=15742 mem_base=1591.6 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.555718899 ips_base=34688 mem_base=1625.6 @@ -695,7 +696,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() { loss_base=10.672568035 ips_base=19461 mem_base=1384.7 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.651032448 ips_base=42435 mem_base=1377.5 @@ -736,7 +737,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() { loss_base=10.672568035 ips_base=19652 mem_base=1384.7 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.651032448 ips_base=43008 mem_base=1377.5 @@ -778,7 +779,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() { loss_base=10.672568035 ips_base=19652 mem_base=1384.7 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.651032448 ips_base=43008 mem_base=1377.5 @@ -819,7 +820,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() { loss_base=10.696336079 ips_base=16613 mem_base=1280.5 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.705118465 ips_base=37104 mem_base=1217.3 @@ -861,7 +862,7 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() { loss_base=10.696336079 ips_base=16613 mem_base=1280.5 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.705118465 ips_base=37104 mem_base=1217.3 @@ -1012,7 +1013,7 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.52110565 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.44003963 fi ips_base=-1 @@ -1081,7 +1082,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.42011833 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.44003963 fi ips_base=-1 @@ -1150,7 +1151,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.44299471 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.45633757 fi ips_base=-1 @@ -1219,7 +1220,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.45936012 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.46121407 fi ips_base=-1 @@ -1290,7 +1291,7 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.46707726 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.44474411 fi ips_base=-1 @@ -1362,7 +1363,7 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2 mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=10.0859375 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=10.125 fi ips_base=-1 @@ -1432,7 +1433,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.53389835 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.54253578 fi ips_base=-1 @@ -1502,7 +1503,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.39066124 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.41613197 fi ips_base=-1 @@ -1572,7 +1573,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.38235474 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.4053154 fi ips_base=-1 @@ -1643,7 +1644,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { mem=-1 echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.38256836 - if [ $(is_a100) ];then + if [ $IS_A100 -ne 0 ];then loss_base=9.4055137 fi ips_base=-1 From 62e84be650ba87adcd7b4e30976fed4ee4cab7e3 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Tue, 16 Apr 2024 14:57:24 +0800 Subject: [PATCH 4/8] update --- scripts/distribute/ci_case_auto.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 332385c85e35..b01d086b170f 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1645,7 +1645,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.38256836 if [ $IS_A100 -ne 0 ];then - loss_base=9.4055137 + loss_base=9.4055109 fi ips_base=-1 mem_base=-1 From f5287fe8fb10b039c2a9724715f1857281105130 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Tue, 16 Apr 2024 15:52:59 +0800 Subject: [PATCH 5/8] update --- scripts/distribute/ci_case_auto.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index b01d086b170f..8f637aa4fe23 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1014,7 +1014,7 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() { echo "result: loss=$loss ips=$ips mem=$mem" loss_base=9.52110565 if [ $IS_A100 -ne 0 ];then - loss_base=9.44003963 + loss_base=9.54202747 fi ips_base=-1 mem_base=-1 From fcb7462e3d470d7a781b7dfbcba13f77a473fe6f Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Wed, 17 Apr 2024 20:13:08 +0800 Subject: [PATCH 6/8] update sp allclose --- scripts/distribute/ci_case_auto.sh | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 8f637aa4fe23..9c58d896861c 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -948,8 +948,12 @@ function gpt_auto_sp_acc_check() { loss_base=`cat ${log_dir_spFalse}/workerlog.0 | grep '30/30' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` ips_base=-1 mem_base=-1 + allclose=0 echo "result: loss_spTrue=$loss loss_spFasle=$loss_base" - check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} + if [ $IS_A100 -ne 0 ];then + allclose=1 + fi + check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} ${allclose} echo "=========== $FUNCNAME run end ===========" } @@ -1948,8 +1952,17 @@ function check_result() { diff_loss=$(echo $2 $3|awk '{printf "%0.2f\n", ($2-$1)/$1*100}') echo -e "loss_base: $2 loss_test: $3 loss_diff: $diff_loss%" | tee -a ${log_path}/result.log if [ $2 != $3 ];then - echo -e "\033[31m $1 loss diff check failed! \033[0m" | tee -a ${log_path}/result.log - exit -1 + if [ -z "$8" ] || [ $8 -ne 1 ] ;then + echo -e "\033[31m $1 loss diff check failed! \033[0m" | tee -a ${log_path}/result.log + exit -1 + else + diff=$(echo "$2 $3" | awk '{print $1-$2}') + gt=$(echo "${diff#-} 1e-5" | awk '{print ($1>$2)?"1":"0"}') + if [ $st -eq 1 ];then + echo -e "\033[31m $1 loss diff check failed! \033[0m" | tee -a ${log_path}/result.log + exit -1 + fi + fi fi diff_ips=$(echo $4 $5|awk '{printf "%0.2f\n", ($2-$1)/$1*100}') From abf7b82b92c9552c2b41c23ff56a7719987ca396 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 18 Apr 2024 11:54:53 +0800 Subject: [PATCH 7/8] fix check_result --- scripts/distribute/ci_case_auto.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 9c58d896861c..2c3596c798a1 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1944,7 +1944,7 @@ function check_result() { exit -1 fi - if [ $# -ne 7 ]; then + if [ $# -ne 7 ] && [ $# -ne 8 ]; then echo -e "\033[31m $1 parameter transfer failed: $@ \033[0m" | tee -a ${log_path}/result.log exit -1 fi @@ -1958,7 +1958,7 @@ function check_result() { else diff=$(echo "$2 $3" | awk '{print $1-$2}') gt=$(echo "${diff#-} 1e-5" | awk '{print ($1>$2)?"1":"0"}') - if [ $st -eq 1 ];then + if [ $gt -eq 1 ];then echo -e "\033[31m $1 loss diff check failed! \033[0m" | tee -a ${log_path}/result.log exit -1 fi From f1b93c8e76d821182c80d1bed7c776b1ce618977 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 18 Apr 2024 16:38:55 +0800 Subject: [PATCH 8/8] add ground truth for llm_gpt_dygraph --- scripts/distribute/ci_case_auto.sh | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh index 2c3596c798a1..7dde1396d725 100755 --- a/scripts/distribute/ci_case_auto.sh +++ b/scripts/distribute/ci_case_auto.sh @@ -1710,7 +1710,7 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() { --enable_auto_parallel 1 \ --to_static 0 \ --fp16 0 \ - --fp16_opt_level "O2" + --fp16_opt_level "O2" \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` @@ -1721,6 +1721,9 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() { loss_md5_base=0ebf68698887b33b33a46518621cf412 ips_base=-1 mem_base=-1 + if [ $IS_A100 -ne 0 ];then + loss_base=10.58541679 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -1780,7 +1783,7 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() { --enable_auto_parallel 1 \ --to_static 0 \ --fp16 0 \ - --fp16_opt_level "O2" + --fp16_opt_level "O2" \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` @@ -1791,6 +1794,9 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() { loss_md5_base=6df87d01bd08113a92930f6349514b35 ips_base=-1 mem_base=-1 + if [ $IS_A100 -ne 0 ];then + loss_base=10.58452606 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -1850,7 +1856,7 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { --enable_auto_parallel 1 \ --to_static 0 \ --fp16 0 \ - --fp16_opt_level "O2" + --fp16_opt_level "O2" \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` @@ -1861,6 +1867,9 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() { loss_md5_base=6cb4e151b35f026190df90ab240d9a95 ips_base=-1 mem_base=-1 + if [ $IS_A100 -ne 0 ];then + loss_base=10.57996178 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" } @@ -1920,7 +1929,7 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { --enable_auto_parallel 1 \ --to_static 0 \ --fp16 1 \ - --fp16_opt_level "O2" + --fp16_opt_level "O2" \ >>${log_path}/$FUNCNAME 2>&1 loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'` loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'` @@ -1931,6 +1940,9 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() { loss_md5_base=e82a1f5668870d18a2d45b3ee0a25386 ips_base=-1 mem_base=-1 + if [ $IS_A100 -ne 0 ];then + loss_base=10.58061218 + fi check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem} echo "=========== $FUNCNAME run end ===========" }