From 3007c7912d25e600ded46ade86de63998f4d91b2 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Mon, 14 Oct 2024 14:35:49 +0800 Subject: [PATCH] [BugFix] fix pir dt2st (#9251) --- paddlenlp/transformers/chatglm_v2/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py index ff920f83c6b1..266c6ae4863a 100644 --- a/paddlenlp/transformers/chatglm_v2/modeling.py +++ b/paddlenlp/transformers/chatglm_v2/modeling.py @@ -147,7 +147,7 @@ def apply_rotary_pos_emb(x: paddle.Tensor, rope_cache: paddle.Tensor) -> paddle. -1, ) x_out2 = x_out2.flatten(3) - return paddle.concat((x_out2, x_pass), axis=-1) + return paddle.concat((x_out2, x_pass.cast(x_out2.dtype)), axis=-1) class RMSNorm(nn.Layer):