From fbb4028148cf3a87f4fd464b452597c94e321374 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Tue, 11 Jan 2022 15:14:26 +0800
Subject: [PATCH] [AMP] Check call order of paddle.amp.decorate and paddle.DataParallel (#38785)

* check amp.decorate and DataParallel

* refine coverage

* fix layer dtype

* refine code
---
 python/paddle/fluid/dygraph/amp/auto_cast.py           | 4 ++++
 python/paddle/fluid/dygraph/layers.py                  | 2 ++
 .../unittests/test_imperative_auto_mixed_precision.py  | 8 ++++++++
 3 files changed, 14 insertions(+)

diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index 15adf4cb6faaf..f09e210c3c161 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -145,6 +145,10 @@ def check_models(models):
             raise RuntimeError(
                 "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".
                 format(type(model)))
+        if isinstance(model, paddle.DataParallel):
+            raise RuntimeError(
+                "For distributed AMP training, you should first use paddle.amp.decorate() to decorate the original model, and then call paddle.DataParallel to get the distributed model."
+            )
 
 
 def check_optimizers(optimizers):
diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py
index 4c37a378e0aae..6a65b3bd9c684 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -1569,6 +1569,8 @@ def _apply(self, func, device, dtype, blocking, include_sublayers=True):
         for key, buf in self._buffers.items():
             self._buffers[key] = func(buf, device, dtype, blocking)
 
+        self._dtype = dtype
+
     def _to_impl(self,
                  device=None,
                  dtype=None,
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index a8ed23f5938c0..62b40f88571d4 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -536,6 +536,14 @@ def __init__(self):
 
         self.assertRaises(TypeError, test_error_model)
 
+        def test_error_distributed_model():
+            model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
+            model = paddle.DataParallel(model)
+            with fluid.dygraph.guard():
+                model = paddle.amp.decorate(models=model, level='O2')
+
+        self.assertRaises(RuntimeError, test_error_distributed_model)
+
         def test_error_optimizer():
             class MyOptimizer(object):
                 def __init__(self):
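
Editor's note (not part of the patch): a minimal usage sketch of the call order this change enforces, assuming a dynamic-graph GPU setup launched with paddle.distributed.launch; the Linear layer, SGD optimizer, and hyperparameters below are illustrative placeholders, not code from this PR.

    import paddle
    import paddle.nn as nn

    # Initialize the parallel environment for multi-card dynamic-graph training.
    paddle.distributed.init_parallel_env()

    model = nn.Linear(10, 10)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())

    # Correct order: decorate the plain Layer for pure-fp16 (O2) AMP first ...
    model, optimizer = paddle.amp.decorate(
        models=model, optimizers=optimizer, level='O2')
    # ... then wrap the decorated model for distributed training.
    model = paddle.DataParallel(model)

    # The reversed order is what the new check in check_models() rejects:
    #     dp_model = paddle.DataParallel(nn.Linear(10, 10))
    #     paddle.amp.decorate(models=dp_model, level='O2')  # raises RuntimeError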