diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index e51a5b5c4e852..b94d37acd7b4d 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -91,6 +91,14 @@ def fused_rotary_position_embedding( [[ 0.07116699, -0.90966797], [-0.03628540, -0.20202637]]]]) """ + if (sin is None) or (cos is None): + assert ( + position_ids is None + ), "position_ids without sin/cos is not correctly supported now." + assert ( + use_neox_rotary_style + ), "rotate_half without sin/cos is not correctly supported now." + if in_dynamic_or_pir_mode(): return _C_ops.fused_rotary_position_embedding( q, diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index 33e6aef4a68c9..1e08b93bca925 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -68,7 +68,7 @@ def mult_qkv_rotate_half(value, cos_tensor, sin_tensor): return query -def get_sin_cos_tensor(seq_len, head_dim, sign=1): +def get_sin_cos_tensor(seq_len, head_dim, sign=1, rotate_half=False): pos_seq = paddle.arange(0, seq_len, 1, dtype="float32") indices = paddle.arange(0, head_dim, 2, dtype="float32") @@ -82,12 +82,23 @@ def get_sin_cos_tensor(seq_len, head_dim, sign=1): i = 0 - for value in iter_array: - sin_sin[i * 2] = sign * np.sin(value) - cos_cos[i * 2 + 0] = np.cos(value) - sin_sin[i * 2 + 1] = np.sin(value) - cos_cos[i * 2 + 1] = np.cos(value) - i += 1 + if rotate_half: + stride = head_dim // 2 + for value in iter_array: + sin_sin[i] = sign * np.sin(value) + cos_cos[i] = np.cos(value) + sin_sin[i + stride] = np.sin(value) + cos_cos[i + stride] = np.cos(value) + i += 1 + if i % head_dim == stride: + i += stride + else: + for value in iter_array: + sin_sin[i * 2] = sign * np.sin(value) + cos_cos[i * 2 + 0] = np.cos(value) + sin_sin[i * 2 + 1] = np.sin(value) + cos_cos[i * 2 + 1] = np.cos(value) + i += 1 tensor_sin = paddle.reshape( paddle.to_tensor(sin_sin), @@ -109,7 +120,7 @@ def paddle_fused_rotary_position_embedding( cos_tensor=None, position_ids=None, use_neox_rotary_style=True, - **kwargs + **kwargs, ): # permute q, k, v from [batch_size, seq_len, num_heads, head_dim] # to [batch_size, num_heads, seq_len, head_dim] @@ -152,7 +163,7 @@ def paddle_fused_rotary_position_embedding( "core is not compiled with CUDA ", ) @param.parameterized_class( - ("name", 'shape_q', 'shape_k', 'shape_v', 'position_ids_list'), + ("name", "shape_q", "shape_k", "shape_v", "position_ids_list"), [ ( "qkv_input", @@ -198,10 +209,11 @@ def paddle_fused_rotary_position_embedding( ) class TestFusedRotaryPositionEmbedding(unittest.TestCase): def setUp(self): - self.dtype = 'float32' + self.dtype = "float32" self.training = True self.seed = 1203 self.rtol = 1e-5 + self.atol = 1e-6 def get_paddle_tensor(self, shape): if shape is None: @@ -211,7 +223,9 @@ def get_paddle_tensor(self, shape): tmp.stop_gradient = False return tmp - def get_inputs(self, seed, with_sin_cos): + def get_inputs( + self, seed, with_sin_cos, with_grads=False, rotate_half=False + ): paddle.disable_static() paddle.seed(seed) # tensor_q shape: [batch_size, seq_len, num_heads, head_dim] @@ -220,11 +234,27 @@ def get_inputs(self, seed, with_sin_cos): tensor_v = self.get_paddle_tensor(self.shape_v) tensor_sin, tensor_cos = ( - get_sin_cos_tensor(tensor_q.shape[1], tensor_q.shape[3], 1) + get_sin_cos_tensor( + tensor_q.shape[1], tensor_q.shape[3], 1, rotate_half=rotate_half + ) if with_sin_cos else (None, None) ) - return tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos + if not with_grads: + return (tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos) + tensor_grad_outq = self.get_paddle_tensor(self.shape_q) + tensor_grad_outk = self.get_paddle_tensor(self.shape_k) + tensor_grad_outv = self.get_paddle_tensor(self.shape_v) + return ( + tensor_q, + tensor_k, + tensor_v, + tensor_sin, + tensor_cos, + tensor_grad_outq, + tensor_grad_outk, + tensor_grad_outv, + ) def get_forward_backward( self, @@ -239,8 +269,20 @@ def get_forward_backward( fw = [] bw = [] - tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( - seed, with_sin_cos + ( + tensor_q, + tensor_k, + tensor_v, + tensor_sin, + tensor_cos, + tensor_grad_outq, + tensor_grad_outk, + tensor_grad_outv, + ) = self.get_inputs( + seed, + with_sin_cos, + with_grads=True, + rotate_half=not use_neox_rotary_style, ) if test_time_major: @@ -251,7 +293,27 @@ def get_forward_backward( tensor_k = paddle.transpose(tensor_k, perm=[1, 0]) if tensor_v is not None: tensor_v = paddle.transpose(tensor_v, perm=[1, 0]) + if tensor_grad_outq is not None: + tensor_grad_outq = paddle.transpose( + tensor_grad_outq, perm=[1, 0] + ) + if tensor_grad_outk is not None: + tensor_grad_outk = paddle.transpose( + tensor_grad_outk, perm=[1, 0] + ) + if tensor_grad_outv is not None: + tensor_grad_outv = paddle.transpose( + tensor_grad_outv, perm=[1, 0] + ) + tensor_q = tensor_q.detach().clone() + tensor_q.stop_gradient = False + if tensor_k is not None: + tensor_k = tensor_k.detach().clone() + tensor_k.stop_gradient = False + if tensor_v is not None: + tensor_v = tensor_v.detach().clone() + tensor_v.stop_gradient = False out_q, out_k, out_v = rope_function( tensor_q, tensor_k, @@ -268,12 +330,20 @@ def get_forward_backward( if out_value is None or not out_value._is_initialized(): continue fw.append(out_value) - out_init_grad.append(paddle.randn(out_value.shape, self.dtype)) + for grad_value in [ + tensor_grad_outq, + tensor_grad_outk, + tensor_grad_outv, + ]: + if grad_value is None or not grad_value._is_initialized(): + continue + out_init_grad.append(grad_value) paddle.autograd.backward(fw, out_init_grad, True) bw = list( filter(lambda x: x is not None, [tensor_q, tensor_k, tensor_v]) ) + bw = [x.grad for x in bw] if test_time_major: # transpose back @@ -289,6 +359,8 @@ def check_results(self, p_results, f_results): p_results[i].numpy(), f_results[i].numpy(), rtol=self.rtol, + atol=self.atol, + err_msg=f"Tensor {i} not match", ) def test_fused_rope(self): @@ -388,7 +460,7 @@ def test_fused_rope_position_ids(self): def test_static(self): paddle.disable_static() tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( - self.seed, True + self.seed, True, rotate_half=True ) p_fw, p_bw = self.get_forward_backward( paddle_fused_rotary_position_embedding, @@ -448,11 +520,11 @@ def test_static(self): exe = paddle.static.Executor() feed = { - 'sin': tensor_sin.numpy(), - 'cos': tensor_cos.numpy(), + "sin": tensor_sin.numpy(), + "cos": tensor_cos.numpy(), } for var_name, input_tensor in zip( - ['q', 'k', 'v'], [tensor_q, tensor_k, tensor_v] + ["q", "k", "v"], [tensor_q, tensor_k, tensor_v] ): if input_tensor is not None: feed[var_name] = input_tensor.numpy() @@ -473,9 +545,7 @@ def test_static(self): for i in range(len(p_fw)): np.testing.assert_allclose( - p_fw[i].numpy(), - outs[i], - rtol=self.rtol, + p_fw[i].numpy(), outs[i], rtol=self.rtol, atol=self.atol ) paddle.disable_static() @@ -483,7 +553,7 @@ def test_static(self): def test_static_time_major(self): paddle.disable_static() tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( - self.seed, True + self.seed, True, rotate_half=True ) p_fw, p_bw = self.get_forward_backward( paddle_fused_rotary_position_embedding, @@ -562,11 +632,11 @@ def test_static_time_major(self): exe = paddle.static.Executor() feed = { - 'sin': tensor_sin.numpy(), - 'cos': tensor_cos.numpy(), + "sin": tensor_sin.numpy(), + "cos": tensor_cos.numpy(), } for var_name, input_tensor in zip( - ['q', 'k', 'v'], [tensor_q, tensor_k, tensor_v] + ["q", "k", "v"], [tensor_q, tensor_k, tensor_v] ): if input_tensor is not None: feed[var_name] = input_tensor.numpy().transpose((1, 0, 2, 3)) @@ -590,9 +660,34 @@ def test_static_time_major(self): p_fw[i].numpy(), outs[i].transpose((1, 0, 2, 3)), rtol=self.rtol, + atol=self.atol, ) paddle.disable_static() + def test_errors(self): + def test_error1(): + f_fw, f_bw = self.get_forward_backward( + fused_rotary_position_embedding, + seed=self.seed, + test_time_major=False, + with_sin_cos=False, + use_neox_rotary_style=False, + ) + + self.assertRaises(AssertionError, test_error1) + + def test_error2(): + position_ids = paddle.to_tensor(self.position_ids_list) + f_fw, f_bw = self.get_forward_backward( + fused_rotary_position_embedding, + seed=self.seed, + test_time_major=False, + with_sin_cos=False, + position_ids=position_ids, + ) + + self.assertRaises(AssertionError, test_error2) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()