Skip to content

Commit

Permalink
Fix unittest of fused_rope (#63538)
Browse files Browse the repository at this point in the history
* fix test_fused_rotary_position_embedding
* add assertion & add test_errors
  • Loading branch information
kircle888 authored Apr 16, 2024
1 parent bafb836 commit 42371b2
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def fused_rotary_position_embedding(
[[ 0.07116699, -0.90966797],
[-0.03628540, -0.20202637]]]])
"""
if (sin is None) or (cos is None):
assert (
position_ids is None
), "position_ids without sin/cos is not correctly supported now."
assert (
use_neox_rotary_style
), "rotate_half without sin/cos is not correctly supported now."

if in_dynamic_or_pir_mode():
return _C_ops.fused_rotary_position_embedding(
q,
Expand Down
151 changes: 123 additions & 28 deletions test/legacy_test/test_fused_rotary_position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def mult_qkv_rotate_half(value, cos_tensor, sin_tensor):
return query


def get_sin_cos_tensor(seq_len, head_dim, sign=1):
def get_sin_cos_tensor(seq_len, head_dim, sign=1, rotate_half=False):
pos_seq = paddle.arange(0, seq_len, 1, dtype="float32")
indices = paddle.arange(0, head_dim, 2, dtype="float32")

Expand All @@ -82,12 +82,23 @@ def get_sin_cos_tensor(seq_len, head_dim, sign=1):

i = 0

for value in iter_array:
sin_sin[i * 2] = sign * np.sin(value)
cos_cos[i * 2 + 0] = np.cos(value)
sin_sin[i * 2 + 1] = np.sin(value)
cos_cos[i * 2 + 1] = np.cos(value)
i += 1
if rotate_half:
stride = head_dim // 2
for value in iter_array:
sin_sin[i] = sign * np.sin(value)
cos_cos[i] = np.cos(value)
sin_sin[i + stride] = np.sin(value)
cos_cos[i + stride] = np.cos(value)
i += 1
if i % head_dim == stride:
i += stride
else:
for value in iter_array:
sin_sin[i * 2] = sign * np.sin(value)
cos_cos[i * 2 + 0] = np.cos(value)
sin_sin[i * 2 + 1] = np.sin(value)
cos_cos[i * 2 + 1] = np.cos(value)
i += 1

tensor_sin = paddle.reshape(
paddle.to_tensor(sin_sin),
Expand All @@ -109,7 +120,7 @@ def paddle_fused_rotary_position_embedding(
cos_tensor=None,
position_ids=None,
use_neox_rotary_style=True,
**kwargs
**kwargs,
):
# permute q, k, v from [batch_size, seq_len, num_heads, head_dim]
# to [batch_size, num_heads, seq_len, head_dim]
Expand Down Expand Up @@ -152,7 +163,7 @@ def paddle_fused_rotary_position_embedding(
"core is not compiled with CUDA ",
)
@param.parameterized_class(
("name", 'shape_q', 'shape_k', 'shape_v', 'position_ids_list'),
("name", "shape_q", "shape_k", "shape_v", "position_ids_list"),
[
(
"qkv_input",
Expand Down Expand Up @@ -198,10 +209,11 @@ def paddle_fused_rotary_position_embedding(
)
class TestFusedRotaryPositionEmbedding(unittest.TestCase):
def setUp(self):
    """Initialise the per-test configuration read by the test methods.

    Sets the tensor dtype, the training flag, the RNG seed and the
    relative/absolute tolerances used when comparing results.
    """
    self.seed = 1203
    self.training = True
    self.dtype = "float32"
    # Tolerances passed to np.testing.assert_allclose by the checks.
    self.rtol, self.atol = 1e-5, 1e-6

def get_paddle_tensor(self, shape):
if shape is None:
Expand All @@ -211,7 +223,9 @@ def get_paddle_tensor(self, shape):
tmp.stop_gradient = False
return tmp

def get_inputs(self, seed, with_sin_cos):
def get_inputs(
self, seed, with_sin_cos, with_grads=False, rotate_half=False
):
paddle.disable_static()
paddle.seed(seed)
# tensor_q shape: [batch_size, seq_len, num_heads, head_dim]
Expand All @@ -220,11 +234,27 @@ def get_inputs(self, seed, with_sin_cos):
tensor_v = self.get_paddle_tensor(self.shape_v)

tensor_sin, tensor_cos = (
get_sin_cos_tensor(tensor_q.shape[1], tensor_q.shape[3], 1)
get_sin_cos_tensor(
tensor_q.shape[1], tensor_q.shape[3], 1, rotate_half=rotate_half
)
if with_sin_cos
else (None, None)
)
return tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos
if not with_grads:
return (tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos)
tensor_grad_outq = self.get_paddle_tensor(self.shape_q)
tensor_grad_outk = self.get_paddle_tensor(self.shape_k)
tensor_grad_outv = self.get_paddle_tensor(self.shape_v)
return (
tensor_q,
tensor_k,
tensor_v,
tensor_sin,
tensor_cos,
tensor_grad_outq,
tensor_grad_outk,
tensor_grad_outv,
)

def get_forward_backward(
self,
Expand All @@ -239,8 +269,20 @@ def get_forward_backward(
fw = []
bw = []

tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs(
seed, with_sin_cos
(
tensor_q,
tensor_k,
tensor_v,
tensor_sin,
tensor_cos,
tensor_grad_outq,
tensor_grad_outk,
tensor_grad_outv,
) = self.get_inputs(
seed,
with_sin_cos,
with_grads=True,
rotate_half=not use_neox_rotary_style,
)

if test_time_major:
Expand All @@ -251,7 +293,27 @@ def get_forward_backward(
tensor_k = paddle.transpose(tensor_k, perm=[1, 0])
if tensor_v is not None:
tensor_v = paddle.transpose(tensor_v, perm=[1, 0])
if tensor_grad_outq is not None:
tensor_grad_outq = paddle.transpose(
tensor_grad_outq, perm=[1, 0]
)
if tensor_grad_outk is not None:
tensor_grad_outk = paddle.transpose(
tensor_grad_outk, perm=[1, 0]
)
if tensor_grad_outv is not None:
tensor_grad_outv = paddle.transpose(
tensor_grad_outv, perm=[1, 0]
)

tensor_q = tensor_q.detach().clone()
tensor_q.stop_gradient = False
if tensor_k is not None:
tensor_k = tensor_k.detach().clone()
tensor_k.stop_gradient = False
if tensor_v is not None:
tensor_v = tensor_v.detach().clone()
tensor_v.stop_gradient = False
out_q, out_k, out_v = rope_function(
tensor_q,
tensor_k,
Expand All @@ -268,12 +330,20 @@ def get_forward_backward(
if out_value is None or not out_value._is_initialized():
continue
fw.append(out_value)
out_init_grad.append(paddle.randn(out_value.shape, self.dtype))
for grad_value in [
tensor_grad_outq,
tensor_grad_outk,
tensor_grad_outv,
]:
if grad_value is None or not grad_value._is_initialized():
continue
out_init_grad.append(grad_value)

paddle.autograd.backward(fw, out_init_grad, True)
bw = list(
filter(lambda x: x is not None, [tensor_q, tensor_k, tensor_v])
)
bw = [x.grad for x in bw]

if test_time_major:
# transpose back
Expand All @@ -289,6 +359,8 @@ def check_results(self, p_results, f_results):
p_results[i].numpy(),
f_results[i].numpy(),
rtol=self.rtol,
atol=self.atol,
err_msg=f"Tensor {i} not match",
)

def test_fused_rope(self):
Expand Down Expand Up @@ -388,7 +460,7 @@ def test_fused_rope_position_ids(self):
def test_static(self):
paddle.disable_static()
tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs(
self.seed, True
self.seed, True, rotate_half=True
)
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding,
Expand Down Expand Up @@ -448,11 +520,11 @@ def test_static(self):
exe = paddle.static.Executor()

feed = {
'sin': tensor_sin.numpy(),
'cos': tensor_cos.numpy(),
"sin": tensor_sin.numpy(),
"cos": tensor_cos.numpy(),
}
for var_name, input_tensor in zip(
['q', 'k', 'v'], [tensor_q, tensor_k, tensor_v]
["q", "k", "v"], [tensor_q, tensor_k, tensor_v]
):
if input_tensor is not None:
feed[var_name] = input_tensor.numpy()
Expand All @@ -473,17 +545,15 @@ def test_static(self):

for i in range(len(p_fw)):
np.testing.assert_allclose(
p_fw[i].numpy(),
outs[i],
rtol=self.rtol,
p_fw[i].numpy(), outs[i], rtol=self.rtol, atol=self.atol
)
paddle.disable_static()

@test_with_pir_api
def test_static_time_major(self):
paddle.disable_static()
tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs(
self.seed, True
self.seed, True, rotate_half=True
)
p_fw, p_bw = self.get_forward_backward(
paddle_fused_rotary_position_embedding,
Expand Down Expand Up @@ -562,11 +632,11 @@ def test_static_time_major(self):
exe = paddle.static.Executor()

feed = {
'sin': tensor_sin.numpy(),
'cos': tensor_cos.numpy(),
"sin": tensor_sin.numpy(),
"cos": tensor_cos.numpy(),
}
for var_name, input_tensor in zip(
['q', 'k', 'v'], [tensor_q, tensor_k, tensor_v]
["q", "k", "v"], [tensor_q, tensor_k, tensor_v]
):
if input_tensor is not None:
feed[var_name] = input_tensor.numpy().transpose((1, 0, 2, 3))
Expand All @@ -590,9 +660,34 @@ def test_static_time_major(self):
p_fw[i].numpy(),
outs[i].transpose((1, 0, 2, 3)),
rtol=self.rtol,
atol=self.atol,
)
paddle.disable_static()

def test_errors(self):
    """Check that unsupported argument combinations raise AssertionError.

    Both cases run the fused op through ``get_forward_backward`` with
    ``with_sin_cos=False``:

    1. ``use_neox_rotary_style=False`` without sin/cos.
    2. explicit ``position_ids`` without sin/cos.
    """
    # Case 1: non-neox (rotate-half) style requested, no sin/cos given.
    with self.assertRaises(AssertionError):
        self.get_forward_backward(
            fused_rotary_position_embedding,
            seed=self.seed,
            test_time_major=False,
            with_sin_cos=False,
            use_neox_rotary_style=False,
        )

    # Case 2: position_ids given, no sin/cos given.
    with self.assertRaises(AssertionError):
        ids = paddle.to_tensor(self.position_ids_list)
        self.get_forward_backward(
            fused_rotary_position_embedding,
            seed=self.seed,
            test_time_major=False,
            with_sin_cos=False,
            position_ids=ids,
        )


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

0 comments on commit 42371b2

Please sign in to comment.