Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2stat] Fix error when calling a sublayer's non-forward function in dy2stat #37296

Merged
merged 10 commits
Dec 24, 2021
7 changes: 6 additions & 1 deletion python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,8 +1077,13 @@ def append_var_from_block_desc_static(block,
else:
lod_level = None

if var_desc.persistable():
current_block = block.program.global_block()
else:
current_block = block

vars_append.append(
block.create_var(
current_block.create_var(
name=var_desc.name(),
dtype=data_type,
type=var_type,
Expand Down
17 changes: 6 additions & 11 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder
from .base import program_desc_tracing_guard, param_guard, in_declarative_mode
from .base import program_desc_tracing_guard, param_guard, in_declarative_mode, _convert_into_variable
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
Expand Down Expand Up @@ -914,16 +914,7 @@ def _dygraph_call_func(self, *inputs, **kwargs):
return outputs

def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.

if in_declarative_mode() and not framework.in_dygraph_mode():
with param_guard(self._parameters), param_guard(self._buffers):
return self._dygraph_call_func(*inputs, **kwargs)
else:
return self._dygraph_call_func(*inputs, **kwargs)
return self._dygraph_call_func(*inputs, **kwargs)

def forward(self, *inputs, **kwargs):
"""
Expand Down Expand Up @@ -1103,6 +1094,8 @@ def __getattr__(self, name):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in self._parameters:
if in_declarative_mode() and not framework.in_dygraph_mode():
return _convert_into_variable(self._parameters[name])
return self._parameters[name]
if '_sub_layers' in self.__dict__:
_sub_layers = self.__dict__['_sub_layers']
Expand All @@ -1111,6 +1104,8 @@ def __getattr__(self, name):
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
if in_declarative_mode() and not framework.in_dygraph_mode():
return _convert_into_variable(_buffers[name])
return _buffers[name]
return object.__getattribute__(self, name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,5 +379,34 @@ def test_raise_error(self):
net.forward.outputs


class CallNonForwardFuncNet(paddle.nn.Layer):
    """Layer whose static-graph forward delegates to a sublayer's
    non-forward method, exercising dy2stat attribute conversion."""

    def __init__(self):
        super(CallNonForwardFuncNet, self).__init__()
        # Sublayer exposing a custom `func` that forward calls instead
        # of the sublayer's own `forward`.
        self.sub = CallNonForwardFuncSubNet()

    @paddle.jit.to_static
    def forward(self):
        # Calling a non-forward method of a sublayer inside to_static.
        result = self.sub.func()
        return result


class CallNonForwardFuncSubNet(paddle.nn.Layer):
    """Sublayer holding a tensor attribute and a non-forward method
    that reads it."""

    def __init__(self):
        super(CallNonForwardFuncSubNet, self).__init__()
        # Tensor attribute read by `func`; under dy2stat this attribute
        # access goes through the layer's __getattr__ path.
        self.a = paddle.to_tensor([1, 2])

    def func(self):
        # Double the stored tensor; invoked from the parent's forward.
        return self.a * 2


class TestCallNonForwardFunc(unittest.TestCase):
    """Verify that a to_static forward may call a sublayer's
    non-forward method and produce the expected values."""

    def test_call_non_forward(self):
        # Switch to dygraph mode so the to_static-decorated forward
        # performs dynamic-to-static conversion when called.
        paddle.disable_static()
        net = CallNonForwardFuncNet()
        result = net()
        self.assertEqual(result.numpy().tolist(), [2, 4])
        # Restore static mode so later tests in this file are unaffected.
        paddle.enable_static()


# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()