Dygraph Recompute: support amp (#33251)
* Dygraph Recompute support AMP

* dygraph recompute: update unitest
JZ-LIANG authored Jun 3, 2021
1 parent c70f1ca commit 8752c91
Showing 2 changed files with 71 additions and 25 deletions.
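Context for the diff below: with this commit, a block wrapped in paddle.distributed.fleet.utils.recompute can sit inside a paddle.amp.auto_cast region; forward records the tracer's autocast flag and white/black lists, and the recomputation in backward replays under the same settings. A minimal usage sketch (the TinyNet model, shapes, and training loop are illustrative, not part of the commit; a GPU build is assumed, since recompute saves the CUDA RNG state):

import numpy as np
import paddle
from paddle.distributed.fleet.utils import recompute


class TinyNet(paddle.nn.Layer):
    # Illustrative two-block model; block2's activations are recomputed in backward.
    def __init__(self, hidden=10):
        super(TinyNet, self).__init__()
        self.block1 = paddle.nn.Linear(hidden, hidden)
        self.block2 = paddle.nn.Linear(hidden, hidden)

    def forward(self, x):
        x = self.block1(x)
        return recompute(self.block2, x)


paddle.set_device("gpu")
model = TinyNet()
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

x = paddle.to_tensor(np.random.randn(4, 10).astype("float32"))
with paddle.amp.auto_cast():   # AMP forward; backward now recomputes under the same state
    loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()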
23 changes: 16 additions & 7 deletions python/paddle/distributed/fleet/utils/recompute.py
100644 → 100755
@@ -97,10 +97,12 @@ def forward(ctx, run_function, preserve_rng_state, *args):
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
 
         # TODO support AMP
+        tracer = framework._dygraph_tracer()
+        ctx.is_fw_autocast = tracer._enable_autocast
+        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
 
         with paddle.no_grad():
             outputs = run_function(*args)
 
         return outputs
 
     @staticmethod
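The three added lines capture the autocast state that is active when forward runs. A quick way to see what gets recorded (this reads the same private tracer attributes the patch touches — internal API of this Paddle version, shown only for illustration):

import paddle
from paddle.fluid import framework

paddle.set_device("gpu")
with paddle.amp.auto_cast():
    tracer = framework._dygraph_tracer()
    is_fw_autocast = tracer._enable_autocast            # True inside auto_cast
    white_list, black_list = tracer._get_amp_op_list()  # effective AMP op lists
print(is_fw_autocast)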
@@ -119,15 +121,23 @@ def backward(ctx, *args):
             tracer = framework._dygraph_tracer()
             tracer._has_grad = True
 
-            # TODO support AMP
-
+            # NOTE support AMP
+            # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
                 with swith_rng_state(ctx.fw_cuda_rng_state):
-                    detached_inputs = detach_variable(tuple(inputs))
-                    outputs = ctx.run_function(*detached_inputs)
-            else:
-                detached_inputs = detach_variable(tuple(inputs))
-                outputs = ctx.run_function(*detached_inputs)
+                    with paddle.amp.auto_cast(
+                            enable=ctx.is_fw_autocast,
+                            custom_white_list=ctx.amp_white_list,
+                            custom_black_list=ctx.amp_black_list):
+                        detached_inputs = detach_variable(tuple(inputs))
+                        outputs = ctx.run_function(*detached_inputs)
+            else:
+                with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    outputs = ctx.run_function(*detached_inputs)
 
             if isinstance(outputs, core.VarBase):
                 outputs = (outputs, )
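Why backward has to restore this state: the same op can produce a different dtype inside and outside auto_cast, so a recomputation that ignored the saved flag would hand float32 activations to a backward pass built against float16 ones. A small illustration (assumes a GPU build with AMP support; matmul is on Paddle's default white list in this version):

import paddle

paddle.set_device("gpu")
x = paddle.rand([4, 4])
w = paddle.rand([4, 4])

with paddle.amp.auto_cast():
    y_amp = paddle.matmul(x, w)   # cast to float16 under autocast
y_fp32 = paddle.matmul(x, w)      # float32 outside autocast

print(y_amp.dtype, y_fp32.dtype)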
@@ -155,7 +165,6 @@ def backward(ctx, *args):
 
             grads = list(inp._grad_ivar() for inp in detached_inputs
                          if isinstance(inp, core.VarBase))
-
             return grads


73 changes: 55 additions & 18 deletions python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
@@ -92,15 +92,12 @@ def forward(self, inputs):
         return inputs
 
 
-def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
+def run_model(recompute_block=[], recompute_kwargs={}, enable_autocast=False):
     gen = paddle.seed(10)
     gen.manual_seed(10)
     np.random.seed(10)
     random.seed(10)
 
-    if cuda_state:
-        paddle.set_cuda_rng_state(cuda_state)
-
     batch_size, input_size = 1, 10
     model = Naive_fc_net(
         input_size,
@@ -110,19 +107,27 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
     optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                      parameters=model.parameters())
 
+    if enable_autocast:
+        scaler = paddle.amp.GradScaler()
+
     loss_ = []
     param_ = []
     grad_ = []
     for step in range(10):
 
         x_data = np.random.randn(batch_size, input_size).astype(np.float32)
         x = paddle.to_tensor(x_data)
         # x.stop_gradient = False
-        y_pred = model(x)
-        loss = y_pred.mean()
 
-        loss_.append(np.asarray(loss).tolist())
-        loss.backward()
-        optimizer.step()
+        with paddle.amp.auto_cast(True):
+            y_pred = model(x)
+            loss = y_pred.mean()
+        if enable_autocast:
+            scaler.scale(loss).backward()
+            scaler.minimize(optimizer, loss)
+        else:
+            loss_.append(np.asarray(loss).tolist())
+            loss.backward()
+            optimizer.step()
 
         param_.append(np.asarray(model.parameters()[9]).tolist())
         grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())
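The updated run_model follows Paddle's standard GradScaler recipe when enable_autocast is set. Isolated from the test harness, the pattern looks like this (a minimal sketch; the one-layer model is an illustrative stand-in for Naive_fc_net):

import numpy as np
import paddle

paddle.set_device("gpu")
model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
scaler = paddle.amp.GradScaler()

for step in range(3):
    x = paddle.to_tensor(np.random.randn(1, 10).astype("float32"))
    with paddle.amp.auto_cast(True):
        loss = model(x).mean()
    scaler.scale(loss).backward()     # scale the loss so float16 gradients do not underflow
    scaler.minimize(optimizer, loss)  # unscale, apply the optimizer step, update the scale
    optimizer.clear_grad()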
@@ -138,38 +143,70 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(param_ref, param)
             self.assertEqual(grad_ref, grad)
 
-        cuda_state = paddle.get_cuda_rng_state()
         # without recompute
-        loss_ref, param_ref, grad_ref = run_model(
-            cuda_state, recompute_block=[])
+        loss_ref, param_ref, grad_ref = run_model(recompute_block=[])
 
         # recompute second block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
+        loss, param, grad = run_model(recompute_block=[1])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # recompute fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[3])
+        loss, param, grad = run_model(recompute_block=[3])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # recompute second to fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3])
+        loss, param, grad = run_model(recompute_block=[1, 2, 3])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
         # recompute second & fourth block
-        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
+        loss, param, grad = run_model(recompute_block=[1, 3])
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
 
+    def test_fc_net_without_restore_rng(self):
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[2],
+            recompute_kwargs={"preserve_rng_state": False},
+            enable_autocast=True)
+
+    def test_fc_net_with_amp(self):
+        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
+            self.assertEqual(loss_ref, loss)
+            self.assertEqual(param_ref, param)
+            self.assertEqual(grad_ref, grad)
+
+        # without recompute
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[], enable_autocast=True)
+
+        # recompute second block
+        loss, param, grad = run_model(recompute_block=[1], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+        # recompute fourth block
+        loss, param, grad = run_model(recompute_block=[3], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+        # recompute second to fourth block
+        loss, param, grad = run_model(
+            recompute_block=[1, 2, 3], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
+        # recompute second & fourth block
+        loss, param, grad = run_model(
+            recompute_block=[1, 3], enable_autocast=True)
+        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+
     def test_recompute_kwargs(self):
         paddle.set_device("gpu")
         kwargs = {"is_test": False}
         with self.assertRaises(ValueError):
             loss_ref, param_ref, grad_ref = run_model(
-                None, recompute_block=[2], recompute_kwargs=kwargs)
+                recompute_block=[2], recompute_kwargs=kwargs)
 
     def test_recompute_cpu_rng(self):
         paddle.set_device("cpu")
         with self.assertRaises(RuntimeError):
-            loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2])
+            loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
 
 
 if __name__ == '__main__':
