add control/status API #37885

Merged: 10 commits, Dec 23, 2021
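In brief, the PR exposes two query APIs: paddle.is_grad_enabled() and paddle.is_floating_point(x). A minimal usage sketch, assuming a Paddle build that includes this change:

    import paddle

    # Dygraph mode is the default, and gradient calculation is on by default.
    print(paddle.is_grad_enabled())      # True

    x = paddle.to_tensor([1.0, 2.0])
    y = paddle.to_tensor([1, 2])
    print(paddle.is_floating_point(x))   # True  (float32)
    print(paddle.is_floating_point(y))   # False (integer dtype)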
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -70,6 +70,7 @@
from .tensor.attribute import shape # noqa: F401
from .tensor.attribute import real # noqa: F401
from .tensor.attribute import imag # noqa: F401
from .tensor.attribute import is_floating_point # noqa: F401
from .tensor.creation import to_tensor # noqa: F401
from .tensor.creation import diag # noqa: F401
from .tensor.creation import diagflat # noqa: F401
@@ -281,6 +282,7 @@
from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401
from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
from .framework import load # noqa: F401
from .framework import DataParallel # noqa: F401
2 changes: 1 addition & 1 deletion python/paddle/autograd/__init__.py
@@ -16,7 +16,7 @@
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer, PyLayerContext # noqa: F401
- from ..framework import set_grad_enabled # noqa: F401
+ from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/unittests/test_imperative_basic.py
@@ -317,6 +317,13 @@ def test_paddle_imperative_set_grad_enabled(self):
        self.assertTrue(tmp2._grad_ivar() is not None)
        self.assertTrue(l0.weight._grad_ivar() is not None)

    def test_paddle_imperative_is_grad_enabled(self):
        with fluid.dygraph.guard():
            with paddle.set_grad_enabled(False):
                self.assertTrue(paddle.is_grad_enabled() is False)
            with paddle.set_grad_enabled(True):
                self.assertTrue(paddle.is_grad_enabled())

    def test_sum_op(self):
        x = np.ones([2, 2], np.float32)
        with fluid.dygraph.guard():
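The test above also relies on set_grad_enabled working as a context manager that restores the previous mode on exit (see the tracer._has_grad = prev_mode line in framework.py below). A minimal sketch of that nesting behavior, not part of the PR, assuming dygraph mode:

    import paddle

    assert paddle.is_grad_enabled()              # on by default
    with paddle.set_grad_enabled(False):
        assert not paddle.is_grad_enabled()
        with paddle.set_grad_enabled(True):      # nested override
            assert paddle.is_grad_enabled()
        assert not paddle.is_grad_enabled()      # inner exit restores False
    assert paddle.is_grad_enabled()              # outer exit restores True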
1 change: 1 addition & 0 deletions python/paddle/framework/__init__.py
@@ -19,6 +19,7 @@
from .framework import get_default_dtype # noqa: F401
from .framework import set_default_dtype # noqa: F401
from .framework import set_grad_enabled # noqa: F401
from .framework import is_grad_enabled # noqa: F401

from ..fluid.param_attr import ParamAttr # noqa: F401
from ..fluid.layers.tensor import create_parameter # noqa: F401
25 changes: 25 additions & 0 deletions python/paddle/framework/framework.py
@@ -116,3 +116,28 @@ def set_grad_enabled(mode):
        tracer._has_grad = prev_mode
    else:
        yield


def is_grad_enabled():
    """
    Returns whether the current dygraph gradient calculation mode is enabled.

    Returns:
        bool: True if the current dygraph gradient calculation mode is enabled, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            # Dygraph gradient calculation mode is enabled by default.
            paddle.is_grad_enabled() # True

            with paddle.set_grad_enabled(False):
                paddle.is_grad_enabled() # False

            paddle.enable_static()
            paddle.is_grad_enabled() # False
    """
    tracer = _dygraph_tracer()
    return tracer._has_grad if tracer else False

Review comment (Member): You could add the comment "Dygraph gradient calculation mode is enabled by default."

Reply (Contributor Author): done
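A sketch of how downstream code might consume the new API; maybe_detach is a hypothetical helper, not part of this PR:

    import paddle

    def maybe_detach(x):
        # Hypothetical helper: cut the autograd graph when gradients are off,
        # so the same code path works in both modes.
        return x if paddle.is_grad_enabled() else x.detach()

    x = paddle.rand([2, 3])
    with paddle.set_grad_enabled(False):
        y = maybe_detach(x)  # detached, since grad mode is off here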
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -16,6 +16,7 @@
from .attribute import shape # noqa: F401
from .attribute import real # noqa: F401
from .attribute import imag # noqa: F401
from .attribute import is_floating_point # noqa: F401

Review comment (Member): I checked the code of is_floating_point, do we have to add API document for this?

Reply (Contributor Author): done

from .creation import to_tensor # noqa: F401
from .creation import diag # noqa: F401
from .creation import diagflat # noqa: F401
@@ -412,6 +413,7 @@
'shape',
'real',
'imag',
'is_floating_point',
'digamma',
'diagonal',
'trunc',
24 changes: 24 additions & 0 deletions python/paddle/tensor/attribute.py
@@ -52,6 +52,30 @@ def is_complex(x):


def is_floating_point(x):
    """
    Returns whether the dtype of `x` is one of paddle.float64, paddle.float32, paddle.float16, or paddle.bfloat16.

    Args:
        x (Tensor): The input tensor.

    Returns:
        bool: True if the dtype of `x` is a floating point type, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(1., 5., dtype='float32')
            y = paddle.arange(1, 5, dtype='int32')
            print(paddle.is_floating_point(x))
            # True
            print(paddle.is_floating_point(y))
            # False
    """
    if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
        raise TypeError("Expected Tensor, but received type of x: {}".format(
            type(x)))
    dtype = x.dtype
    is_fp_dtype = (dtype == core.VarDesc.VarType.FP32 or
                   dtype == core.VarDesc.VarType.FP64 or
                   dtype == core.VarDesc.VarType.FP16 or
                   dtype == core.VarDesc.VarType.BF16)
    return is_fp_dtype
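A sketch of a typical call site for the new API; to_float32 is a hypothetical helper, not part of this PR:

    import paddle

    def to_float32(x):
        # Hypothetical helper: cast non-floating tensors up front so later
        # math that needs a floating point dtype does not fail.
        if not paddle.is_floating_point(x):
            x = paddle.cast(x, 'float32')
        return x

    y = paddle.arange(1, 5, dtype='int32')
    print(to_float32(y).dtype)  # float32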