From 94e0376b67207a076a2d9dae313281662525b810 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 21 Dec 2022 02:04:56 +0000 Subject: [PATCH] cherry-pick #75b734 --- .../unittests/test_sparse_attention_op.py | 220 +++++++++--------- 1 file changed, 116 insertions(+), 104 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py index 4337461d48d42..7f27300e53e87 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py @@ -93,14 +93,9 @@ def get_csr_value(mat, layout, nnz): return value -def ref_sparse_attention(q, - k, - v, - offset, - columns, - kp_mask=None, - attn_mask=None, - bsz=None): +def ref_sparse_attention( + q, k, v, offset, columns, kp_mask=None, attn_mask=None, bsz=None +): row, col, nnz = q.shape[0], q.shape[1], columns.shape[0] mat = np.zeros((row, row)) for cur_row in range(row): @@ -111,7 +106,7 @@ def ref_sparse_attention(q, mat[cur_row][cur_col] = 1 a = np.dot(q, k.T) * mat a_value = get_csr_value(a, mat, nnz) - scaling = float(col)**-0.5 + scaling = float(col) ** -0.5 a = scaling * a for i in range(row): for j in range(row): @@ -127,13 +122,9 @@ def ref_sparse_attention(q, return result, a_value, b_value -def ref_batch_sparse_attention(q, - k, - v, - offset, - columns, - kp_mask=None, - attn_mask=None): +def ref_batch_sparse_attention( + q, k, v, offset, columns, kp_mask=None, attn_mask=None +): batch_size, num_heads, row, col = q.shape nnz = columns.shape[2] result = np.zeros((batch_size, num_heads, row, col)) @@ -141,11 +132,16 @@ def ref_batch_sparse_attention(q, result_softmax = np.zeros((batch_size, num_heads, nnz)) for i in range(batch_size): for j in range(num_heads): - cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j] + cur_q, cur_k, cur_v, = ( + q[i][j], + k[i][j], + v[i][j], + ) cur_offset, cur_columns = offset[i][j], columns[i][j] if kp_mask is None and attn_mask is None: cur_result, cur_sdd, cur_softmax = ref_sparse_attention( - cur_q, cur_k, cur_v, cur_offset, cur_columns) + cur_q, cur_k, cur_v, cur_offset, cur_columns + ) else: cur_result, cur_sdd, cur_softmax = ref_sparse_attention( cur_q, @@ -155,7 +151,8 @@ def ref_batch_sparse_attention(q, cur_columns, kp_mask=kp_mask, attn_mask=attn_mask, - bsz=i) + bsz=i, + ) result[i][j] = cur_result result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax return result, result_sdd, result_softmax @@ -193,10 +190,9 @@ def init_csr_format(batch_size, num_heads, rows, blocksize): @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11030, - "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3", ) class TestSparseAttentionOp(OpTest): - def config(self): self.shape = (1, 1, 16, 16) self.blocksize = 4 @@ -212,8 +208,9 @@ def setUp(self): self.k = np.random.random(self.shape).astype(self.dtype) self.v = np.random.random(self.shape).astype(self.dtype) # init CSR tensor - offset, columns = init_csr_format(self.shape[0], self.shape[1], - self.shape[2], self.blocksize) + offset, columns = init_csr_format( + self.shape[0], self.shape[1], self.shape[2], self.blocksize + ) self.offset = offset.astype('int32') self.columns = columns.astype('int32') # init mask tensor @@ -234,10 +231,12 @@ def setUp(self): self.offset, self.columns, kp_mask=self.key_padding_mask, - attn_mask=self.attn_mask) + 
attn_mask=self.attn_mask, + ) else: result, result_sdd, result_softmax = ref_batch_sparse_attention( - self.q, self.k, self.v, self.offset, self.columns) + self.q, self.k, self.v, self.offset, self.columns + ) if self.use_mask == True: self.inputs = { @@ -260,7 +259,7 @@ def setUp(self): self.outputs = { 'Out': result.astype(self.dtype), 'SparseDotSdd': result_sdd.astype(self.dtype), - 'Softmax': result_softmax.astype(self.dtype) + 'Softmax': result_softmax.astype(self.dtype), } def test_check_output(self): @@ -273,7 +272,6 @@ def test_check_grad(self): class TestSparseAttentionOpFp32Test(TestSparseAttentionOp): - def config(self): self.shape = (1, 1, 8, 16) self.blocksize = 2 @@ -282,7 +280,6 @@ def config(self): class TestSparseAttentionOpShapeTest(TestSparseAttentionOp): - def config(self): self.shape = (2, 2, 32, 8) self.blocksize = 8 @@ -292,10 +289,9 @@ def config(self): @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11030, - "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3", ) class TestSparseAttentionAPI(unittest.TestCase): - def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (1, 1, 8, 4) @@ -310,54 +306,62 @@ def test_static_graph(self): K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype) V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype) - batch_size, num_heads, rows = self.shape[0], self.shape[ - 1], self.shape[2] + batch_size, num_heads, rows = ( + self.shape[0], + self.shape[1], + self.shape[2], + ) block_num = rows / self.blocksize block_last = rows % self.blocksize - sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last + sparse_nnz_num = ( + block_num * self.blocksize * self.blocksize + + block_last * block_last + ) offset_shape = (batch_size, num_heads, rows + 1) columns_shape = (batch_size, num_heads, int(sparse_nnz_num)) - offset = paddle.static.data(name="Offset", - shape=offset_shape, - dtype="int32") - columns = paddle.static.data(name="Columns", - shape=columns_shape, - dtype="int32") + offset = paddle.static.data( + name="Offset", shape=offset_shape, dtype="int32" + ) + columns = paddle.static.data( + name="Columns", shape=columns_shape, dtype="int32" + ) key_padding_mask_shape = (self.shape[0], self.shape[2]) attn_mask_shape = (self.shape[2], self.shape[2]) if self.use_mask == True: key_padding_mask = paddle.static.data( name="KeyPaddingMask", shape=key_padding_mask_shape, - dtype=self.dtype) - attn_mask = paddle.static.data(name="AttnMask", - shape=attn_mask_shape, - dtype=self.dtype) - Out = F.sparse_attention(Q, - K, - V, - offset, - columns, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask) + dtype=self.dtype, + ) + attn_mask = paddle.static.data( + name="AttnMask", shape=attn_mask_shape, dtype=self.dtype + ) + Out = F.sparse_attention( + Q, + K, + V, + offset, + columns, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) else: Out = F.sparse_attention(Q, K, V, offset, columns) Q_np = np.random.random(self.shape).astype(self.dtype) K_np = np.random.random(self.shape).astype(self.dtype) V_np = np.random.random(self.shape).astype(self.dtype) - offset_np, columns_np = init_csr_format(self.shape[0], - self.shape[1], - self.shape[2], - self.blocksize) + offset_np, columns_np = init_csr_format( + self.shape[0], self.shape[1], self.shape[2], self.blocksize + ) offset_np = offset_np.astype('int32') columns_np = 
columns_np.astype('int32') # init mask tensor - key_padding_mask_np = np.random.randint(0, - 2, - size=key_padding_mask_shape) + key_padding_mask_np = np.random.randint( + 0, 2, size=key_padding_mask_shape + ) attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape) key_padding_mask_np = init_mask(key_padding_mask_np) attn_mask_np = init_mask(attn_mask_np) @@ -366,16 +370,18 @@ def test_static_graph(self): exe = fluid.Executor(self.place) if self.use_mask == True: - fetches_result = exe.run(feed={ - "Q": Q_np, - "K": K_np, - "V": V_np, - "Offset": offset_np, - "Columns": columns_np, - 'KeyPaddingMask': key_padding_mask_np, - 'AttnMask': attn_mask_np - }, - fetch_list=[Out]) + fetches_result = exe.run( + feed={ + "Q": Q_np, + "K": K_np, + "V": V_np, + "Offset": offset_np, + "Columns": columns_np, + 'KeyPaddingMask': key_padding_mask_np, + 'AttnMask': attn_mask_np, + }, + fetch_list=[Out], + ) expected_result, __, __ = ref_batch_sparse_attention( Q_np, K_np, @@ -383,28 +389,32 @@ def test_static_graph(self): offset_np, columns_np, kp_mask=key_padding_mask_np, - attn_mask=attn_mask_np) + attn_mask=attn_mask_np, + ) else: - fetches_result = exe.run(feed={ - "Q": Q_np, - "K": K_np, - "V": V_np, - "Offset": offset_np, - "Columns": columns_np - }, - fetch_list=[Out]) + fetches_result = exe.run( + feed={ + "Q": Q_np, + "K": K_np, + "V": V_np, + "Offset": offset_np, + "Columns": columns_np, + }, + fetch_list=[Out], + ) expected_result, __, __ = ref_batch_sparse_attention( - Q_np, K_np, V_np, offset_np, columns_np) + Q_np, K_np, V_np, offset_np, columns_np + ) - np.testing.assert_allclose(fetches_result, - expected_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + fetches_result[0], expected_result, rtol=1e-05, atol=1e-05 + ) def test_dygraph(self): paddle.disable_static() - offset, columns = init_csr_format(self.shape[0], self.shape[1], - self.shape[2], self.blocksize) + offset, columns = init_csr_format( + self.shape[0], self.shape[1], self.shape[2], self.blocksize + ) offset = offset.astype('int32') columns = columns.astype('int32') query = np.random.random(self.shape).astype(self.dtype) @@ -429,13 +439,15 @@ def test_dygraph(self): paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place) if self.use_mask == True: - paddle_result = F.sparse_attention(paddle_query, - paddle_key, - paddle_value, - paddle_offset, - paddle_colunmns, - key_padding_mask=paddle_kp_mask, - attn_mask=paddle_attn_mask) + paddle_result = F.sparse_attention( + paddle_query, + paddle_key, + paddle_value, + paddle_offset, + paddle_colunmns, + key_padding_mask=paddle_kp_mask, + attn_mask=paddle_attn_mask, + ) numpy_result, __, __ = ref_batch_sparse_attention( query, @@ -444,25 +456,29 @@ def test_dygraph(self): offset, columns, kp_mask=key_padding_mask, - attn_mask=attn_mask) + attn_mask=attn_mask, + ) numpy_result = numpy_result.astype(self.dtype) else: - paddle_result = F.sparse_attention(paddle_query, paddle_key, - paddle_value, paddle_offset, - paddle_colunmns) + paddle_result = F.sparse_attention( + paddle_query, + paddle_key, + paddle_value, + paddle_offset, + paddle_colunmns, + ) numpy_result, __, __ = ref_batch_sparse_attention( - query, key, value, offset, columns) + query, key, value, offset, columns + ) numpy_result = numpy_result.astype(self.dtype) - np.testing.assert_allclose(paddle_result.numpy(), - numpy_result, - rtol=1e-05, - atol=1e-05) + np.testing.assert_allclose( + paddle_result.numpy(), numpy_result, rtol=1e-05, atol=1e-05 + ) class 
TestSparseAttentionAPITestFloat(TestSparseAttentionAPI): - def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (2, 2, 8, 4) @@ -472,7 +488,6 @@ def setUp(self): class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI): - def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (2, 2, 64, 32) @@ -482,7 +497,6 @@ def setUp(self): class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI): - def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (2, 1, 64, 32) @@ -492,7 +506,6 @@ def setUp(self): class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI): - def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (4, 4, 128, 32) @@ -502,7 +515,6 @@ def setUp(self): class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI): - def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (3, 3, 35, 15)
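
Editor's sketch (not part of the patch above): a minimal standalone example of how the API exercised by these tests, paddle.nn.functional.sparse_attention, is invoked in dygraph mode, mirroring TestSparseAttentionAPI.test_dygraph. q/k/v are (batch, num_heads, seq_len, head_dim) tensors and the block-sparse layout is passed as per-head CSR offset/columns int32 tensors. The block-diagonal layout is written out by hand here instead of using the file's init_csr_format helper; the imports are assumed (the test file's import block is not shown in this hunk), and a CUDA build of Paddle with CUDA >= 11.3 is assumed, as the skip conditions above require.

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    place = paddle.CUDAPlace(0)

    batch, heads, seq_len, head_dim, blocksize = 1, 1, 8, 4, 2

    # Block-diagonal CSR layout built by hand for this tiny case; for
    # seq_len % blocksize == 0 it matches what init_csr_format produces:
    # every row attends to the blocksize columns of its own block.
    offset = np.zeros((batch, heads, seq_len + 1), dtype='int32')
    columns = np.zeros((batch, heads, seq_len * blocksize), dtype='int32')
    for r in range(seq_len):
        offset[0, 0, r + 1] = (r + 1) * blocksize
        block_start = (r // blocksize) * blocksize
        columns[0, 0, r * blocksize:(r + 1) * blocksize] = np.arange(
            block_start, block_start + blocksize, dtype='int32'
        )

    q = np.random.random((batch, heads, seq_len, head_dim)).astype('float32')
    k = np.random.random((batch, heads, seq_len, head_dim)).astype('float32')
    v = np.random.random((batch, heads, seq_len, head_dim)).astype('float32')

    out = F.sparse_attention(
        paddle.to_tensor(q, place=place),
        paddle.to_tensor(k, place=place),
        paddle.to_tensor(v, place=place),
        paddle.to_tensor(offset, place=place),
        paddle.to_tensor(columns, place=place),
    )
    print(out.shape)  # [1, 1, 8, 4]

As in the tests, key_padding_mask and attn_mask are optional keyword arguments; when supplied they must match the (batch, seq_len) and (seq_len, seq_len) shapes used above.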