Skip to content

Commit

Permalink
add control/status API (#37885)
Browse files Browse the repository at this point in the history
* add control/status API, test=develop

* fix import error, test=develop

* add is_grad_enabled unittest, test=develop

* add code comment for example code and API, test=develop

* add checking for type, test=develop

* add api description, test=develop

* fix docs index_en, test=document_fix

* fix doc of is_floating_point, test=document_fix
  • Loading branch information
Avin0323 authored Dec 23, 2021
1 parent 745477f commit 21b7ed3
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 1 deletion.
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .tensor.attribute import shape # noqa: F401
from .tensor.attribute import real # noqa: F401
from .tensor.attribute import imag # noqa: F401
from .tensor.attribute import is_floating_point # noqa: F401
from .tensor.creation import to_tensor # noqa: F401
from .tensor.creation import diag # noqa: F401
from .tensor.creation import diagflat # noqa: F401
Expand Down Expand Up @@ -285,6 +286,7 @@
from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401
from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
from .framework import load # noqa: F401
from .framework import DataParallel # noqa: F401
Expand Down Expand Up @@ -453,6 +455,7 @@
'shape',
'real',
'imag',
'is_floating_point',
'complex',
'reciprocal',
'rand',
Expand All @@ -468,6 +471,7 @@
'median',
'no_grad',
'set_grad_enabled',
'is_grad_enabled',
'mod',
'abs',
'tril',
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/autograd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/unittests/test_imperative_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,13 @@ def test_paddle_imperative_set_grad_enabled(self):
self.assertTrue(tmp2._grad_ivar() is not None)
self.assertTrue(l0.weight._grad_ivar() is not None)

def test_paddle_imperative_is_grad_enabled(self):
with fluid.dygraph.guard():
with paddle.set_grad_enabled(False):
self.assertTrue(paddle.is_grad_enabled() is False)
with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled())

def test_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
Expand Down
1 change: 1 addition & 0 deletions python/paddle/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .framework import get_default_dtype # noqa: F401
from .framework import set_default_dtype # noqa: F401
from .framework import set_grad_enabled # noqa: F401
from .framework import is_grad_enabled # noqa: F401

from ..fluid.param_attr import ParamAttr # noqa: F401
from ..fluid.layers.tensor import create_parameter # noqa: F401
Expand Down
25 changes: 25 additions & 0 deletions python/paddle/framework/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,28 @@ def set_grad_enabled(mode):
tracer._has_grad = prev_mode
else:
yield


def is_grad_enabled():
"""
Returns whether current dygraph gradient calculation mode is enabled.
Returns:
bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
Examples:
.. code-block:: python
import paddle
# Dygraph gradient calculation mode is enabled by default.
paddle.is_grad_enabled() # True
with paddle.set_grad_enabled(False):
paddle.is_grad_enabled() # False
paddle.enable_static()
paddle.is_grad_enabled() # False
"""
tracer = _dygraph_tracer()
return tracer._has_grad if tracer else False
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .attribute import shape # noqa: F401
from .attribute import real # noqa: F401
from .attribute import imag # noqa: F401
from .attribute import is_floating_point # noqa: F401
from .creation import to_tensor # noqa: F401
from .creation import diag # noqa: F401
from .creation import diagflat # noqa: F401
Expand Down Expand Up @@ -418,6 +419,7 @@
'shape',
'real',
'imag',
'is_floating_point',
'digamma',
'diagonal',
'trunc',
Expand Down
24 changes: 24 additions & 0 deletions python/paddle/tensor/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,30 @@ def is_complex(x):


def is_floating_point(x):
"""
Returns whether the dtype of `x` is one of paddle.float64, paddle.float32, paddle.float16, and paddle.bfloat16.
Args:
x (Tensor): The input tensor.
Returns:
bool: True if the dtype of `x` is floating type, otherwise false.
Examples:
.. code-block:: python
import paddle
x = paddle.arange(1., 5., dtype='float32')
y = paddle.arange(1, 5, dtype='int32')
print(paddle.is_floating_point(x))
# True
print(paddle.is_floating_point(y))
# False
"""
if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
raise TypeError("Expected Tensor, but received type of x: {}".format(
type(x)))
dtype = x.dtype
is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or
dtype == core.VarDesc.VarType.FP64 or
Expand Down

0 comments on commit 21b7ed3

Please sign in to comment.