fix A100 fused linear grad add ut bug (PaddlePaddle#56136)
FeixLiu committed Aug 29, 2023
1 parent 22fe39f commit 272b375
Showing 1 changed file with 39 additions and 22 deletions.
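For orientation (not part of the commit): the diff below threads a new has_bias flag through the test helpers so the bias-gradient path is only generated and checked when the linear layer actually has a bias. A purely illustrative sketch of the combinations the updated test_main sweeps; the itertools form is an assumption for brevity, the real test uses nested for-loops as shown in the diff:

# Illustrative only: equivalent enumeration of the combinations that the
# updated test_main sweeps (the diff itself uses nested for-loops).
import itertools

for has_dweight, has_bias, has_dbias, multi_precision in itertools.product(
    [False, True], repeat=4
):
    # Each combination is passed to check_main(...) in the real test.
    print(has_dweight, has_bias, has_dbias, multi_precision)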
61 changes: 39 additions & 22 deletions test/legacy_test/test_fused_linear_param_grad_add.py
@@ -54,7 +54,7 @@ def recreate(x, multi_precision):
     return paddle.to_tensor(x.numpy())
 
 
-def run_ground_truth(x, dy, dweight, dbias, multi_precision):
+def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
     x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)
 
     dweight_tmp = paddle.matmul(
@@ -69,24 +69,35 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
         assert dweight.dtype == dweight.dtype
         dweight += dweight_tmp
 
-    dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
-    if dbias is None:
-        dbias = dbias_tmp
-    else:
-        assert dbias.shape == dbias_tmp.shape
-        assert dbias.dtype == dbias_tmp.dtype
-        dbias += dbias_tmp
+    if has_bias:
+        dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
+        if dbias is None:
+            dbias = dbias_tmp
+        else:
+            assert dbias.shape == dbias_tmp.shape
+            assert dbias.dtype == dbias_tmp.dtype
+            dbias += dbias_tmp
 
-    return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+        return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+    else:
+        return promote_dtype(dweight).numpy()
 
 
-def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
+def run_fused_linear_param_grad_add(
+    x, dy, dweight, dbias, multi_precision, has_bias
+):
     dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
-        x, dy, dweight, dbias, multi_precision
+        x, dy, dweight, dbias, multi_precision, has_bias
     )
     if dweight is not None:
         assert dweight_new.data_ptr() == dweight.data_ptr()
-    return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
+    if has_bias:
+        return (
+            promote_dtype(dweight_new).numpy(),
+            promote_dtype(dbias_new).numpy(),
+        )
+    else:
+        return promote_dtype(dweight_new).numpy()
 
 
 class TestMainClassBase(unittest.TestCase):
@@ -103,7 +114,9 @@ def rand(self, shape, dtype=None):
         x = paddle.to_tensor(x)
         return x.astype(dtype or self.dtype)
 
-    def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
+    def generate_rand_inputs(
+        self, has_dweight, has_dbias, multi_precision, has_bias
+    ):
         x_shape = self.shape
         dy_shape = self.shape[:-1] + [self.output_size]
         dweight_shape = [self.shape[-1], self.output_size]
@@ -118,22 +131,23 @@ def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
         else:
             dweight = None
 
-        if has_dbias:
+        if has_bias and has_dbias:
             dbias = self.rand(dbias_shape)
             if multi_precision:
                 dbias = promote_dtype(dbias)
         else:
             dbias = None
         return x, dy, dweight, dbias
 
-    def check_main(self, has_dweight, has_dbias, multi_precision):
-        print(has_dweight, has_dbias, multi_precision)
+    def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
         x, dy, dweight, dbias = self.generate_rand_inputs(
-            has_dweight, has_dbias, multi_precision
+            has_dweight, has_dbias, multi_precision, has_bias
         )
-        res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
+        res1 = run_ground_truth(
+            x, dy, dweight, dbias, multi_precision, has_bias
+        )
         res2 = run_fused_linear_param_grad_add(
-            x, dy, dweight, dbias, multi_precision
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
         self.assertEqual(len(res1), len(res2))
         for r1, r2 in zip(res1, res2):
@@ -153,9 +167,12 @@ def test_main(self):
             return
 
         for has_dweight in [False, True]:
-            for has_dbias in [False, True]:
-                for multi_precision in [False, True]:
-                    self.check_main(has_dweight, has_dbias, multi_precision)
+            for has_bias in [False, True]:
+                for has_dbias in [False, True]:
+                    for multi_precision in [False, True]:
+                        self.check_main(
+                            has_dweight, has_dbias, multi_precision, has_bias
+                        )
 
 
 class TestMainClassBF16(TestMainClassBase):
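For context (not part of the PR): the ground-truth helper above mirrors the standard linear-layer gradient math that the fused op is checked against. A minimal NumPy sketch with hypothetical shapes, assuming a layer y = x @ w + b whose leading dimensions are flattened before the gradients are accumulated:

# Illustrative sketch only (not part of the commit): the gradient math the
# ground-truth helper mirrors for a linear layer y = x @ w + b.
import numpy as np

x = np.random.rand(2, 3, 4).astype("float32")   # hypothetical [batch, seq, in]
dy = np.random.rand(2, 3, 5).astype("float32")  # hypothetical [batch, seq, out]

# dweight: flatten the leading dims, then x^T @ dy.
dweight = np.matmul(x.reshape(-1, x.shape[-1]).T, dy.reshape(-1, dy.shape[-1]))

# dbias: sum dy over the flattened leading dims, only when the layer has a bias.
has_bias = True
dbias = dy.reshape(-1, dy.shape[-1]).sum(axis=0) if has_bias else None

print(dweight.shape, None if dbias is None else dbias.shape)  # (4, 5) (5,)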
