From 44721a707948f8deca9ec9bc7985005a4e34318b Mon Sep 17 00:00:00 2001
From: feifei-111 <2364819892@qq.com>
Date: Wed, 22 Nov 2023 17:28:37 +0800
Subject: [PATCH] [dy2s] speed up PartialProgram.__call__ (#58771)

* move Tensor construction to cpp
* mv _remove_no_value to ASTStaticFunction
* update

---
 .../eager/to_static/run_program_op_func.h     |   3 -
 paddle/fluid/pybind/eager_utils.cc            | 198 ++++++++++++++++++
 paddle/fluid/pybind/eager_utils.h             |  18 ++
 paddle/fluid/pybind/imperative.cc             |  10 -
 paddle/fluid/pybind/pir.h                     |   7 +
 paddle/fluid/pybind/pybind.cc                 |  13 +-
 .../paddle/jit/dy2static/partial_program.py   | 116 +++++-----
 .../jit/dy2static/pir_partial_program.py      |  56 +++--
 .../paddle/jit/sot/symbolic/compile_cache.py  |   2 +-
 test/dygraph_to_static/test_resnet_v2.py      |  11 +-
 test/legacy_test/test_eager_run_program.py    |   1 +
 11 files changed, 328 insertions(+), 107 deletions(-)

diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h
index 8e788bd94162e..996eabcd58296 100644
--- a/paddle/fluid/eager/to_static/run_program_op_func.h
+++ b/paddle/fluid/eager/to_static/run_program_op_func.h
@@ -165,7 +165,6 @@ inline void run_program_ad_func(
   auto x_names =
       PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
 
-  egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);
   // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
   auto grad_node = std::make_shared<GradNodeRunProgram>(1, 2);
 
@@ -270,8 +269,6 @@ inline void pir_run_program_ad_func(
   PirRunProgramAPI(
       x, params, out, middles, step_scope, require_any_grad, attrs);
   if (!is_test && require_any_grad) {
-    egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);
-
     // Set Attributes
     grad_node->SetAttrMap(attrs);
 
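The two egr::EagerUtils::PassStopGradient(false, &p_autograd_outs) calls removed above used to force every forward output to require grad after run_program returned. They become redundant in this patch: the output tensors are now built ahead of time by the C++ helpers introduced below, which stamp stop_gradient from the VarDesc (or the OpResult's attribute) at construction. A minimal, untested sketch of the property this relies on; the helper name is taken from this patch, while the toy program is purely illustrative:

    import paddle

    paddle.enable_static()
    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        x = paddle.static.data("x", [2, 3], "float32")
        y = paddle.scale(x, scale=2.0)
        y.stop_gradient = False

    outs = paddle.framework.core.create_empty_tensors_with_var_descs([y.desc])
    print(outs[0].stop_gradient)  # expected False: read from the VarDesc, not reset afterwards
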
diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc
index 1361641085357..e8d3e01ba4a74 100644
--- a/paddle/fluid/pybind/eager_utils.cc
+++ b/paddle/fluid/pybind/eager_utils.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include
 #include
 
+#include "paddle/fluid/eager/accumulation/accumulation_node.h"
 #include "paddle/fluid/eager/api/all.h"
 #include "paddle/fluid/eager/autograd_meta.h"
 #include "paddle/fluid/eager/hooks.h"
@@ -30,9 +31,12 @@ limitations under the License. */
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/operators/py_func_op.h"
 #include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/pybind/eager.h"
 #include "paddle/fluid/pybind/op_function_common.h"
+#include "paddle/fluid/pybind/pir.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
 #include "paddle/phi/common/data_type.h"
@@ -41,6 +45,7 @@ limitations under the License. */
 #include "paddle/phi/core/distributed/auto_parallel/placement_types.h"
 #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
 #include "paddle/phi/core/flags.h"
+#include "paddle/pir/core/attribute.h"
 
 PHI_DECLARE_bool(check_nan_inf);
 PHI_DECLARE_int32(check_nan_inf_level);
@@ -1858,6 +1863,180 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
 paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj) {
   return reinterpret_cast<TensorObject*>(obj)->tensor;
 }
+
+paddle::Tensor CreateTensorFromVarDesc(
+    const paddle::framework::VarDesc& var_desc) {
+  auto tensor = paddle::Tensor();
+
+  auto dtype = var_desc.GetDataType();
+  std::vector<int64_t> dims = var_desc.GetShape();
+
+  auto var_type = var_desc.GetType();
+
+  auto ddims = phi::make_ddim(dims);
+  tensor.set_name(var_desc.Name());
+  auto autograd_meta = egr::EagerUtils::autograd_meta(&tensor);
+  autograd_meta->SetPersistable(false);
+  autograd_meta->SetStopGradient(var_desc.StopGradient());
+
+  if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) {
+    // TODO(jiabin): Maybe support LOD later
+    std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
+    if (dims.size() == 1 && dims[0] == 0) {
+      std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
+      dense_tensor = std::make_shared<phi::DenseTensor>(
+          allocation_ptr,
+          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                               ddims));
+    } else {
+      // TODO(dev): we need enhance check for ddims.
+      dense_tensor = std::make_shared<phi::DenseTensor>(
+          std::make_shared<phi::Allocation>(),
+          phi::DenseTensorMeta(paddle::framework::TransToPhiDataType(dtype),
+                               ddims));
+    }
+    tensor.set_impl(dense_tensor);
+  } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) {
+    std::shared_ptr<phi::SelectedRows> selected_rows_tensor =
+        std::make_shared<phi::SelectedRows>();
+    tensor.set_impl(selected_rows_tensor);
+  }
+
+  if (!autograd_meta->GetMutableGradNode()) {
+    autograd_meta->SetGradNode(
+        std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
+  }
+
+  return tensor;
+}
+
+PyObject* GetEmpytyTensorsWithVarDesc(PyObject* self, PyObject* args) {
+  std::vector<paddle::Tensor> result;
+  std::unordered_map<std::string, paddle::Tensor> out_tensor_map;
+
+  auto var_desc_list = PyTuple_GetItem(args, 0);
+
+  if (PyList_Check(var_desc_list)) {
+    Py_ssize_t len = PyList_Size(var_desc_list);
+    for (Py_ssize_t i = 0; i < len; i++) {
+      auto var_desc = PyObjectCast<paddle::framework::VarDesc>(
+          PyList_GetItem(var_desc_list, i));
+      auto var_name = var_desc.Name();
+      if (out_tensor_map.find(var_name) == out_tensor_map.end()) {
+        paddle::Tensor tensor = CreateTensorFromVarDesc(var_desc);
+        out_tensor_map[var_name] = tensor;
+        result.emplace_back(tensor);
+      } else {
+        result.emplace_back(out_tensor_map[var_name]);
+      }
+    }
+  } else if (PyTuple_Check(var_desc_list)) {
+    Py_ssize_t len = PyTuple_Size(var_desc_list);
+    for (Py_ssize_t i = 0; i < len; i++) {
+      auto var_desc = PyObjectCast<paddle::framework::VarDesc>(
+          PyTuple_GetItem(var_desc_list, i));
+      auto var_name = var_desc.Name();
+      if (out_tensor_map.find(var_name) == out_tensor_map.end()) {
+        paddle::Tensor tensor = CreateTensorFromVarDesc(var_desc);
+        out_tensor_map[var_name] = tensor;
+        result.emplace_back(tensor);
+      } else {
+        result.emplace_back(out_tensor_map[var_name]);
+      }
+    }
+  } else if (var_desc_list != Py_None) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Argument of CreateTensorsWithVarDesc must be list of VarDesc, but got "
+        "%s",
+        (reinterpret_cast<PyTypeObject*>(var_desc_list->ob_type))->tp_name));
+  }
+  return ToPyObject(result);
+}
+
+paddle::Tensor CreateTensorFromOpResult(const pir::OpResult& op_result) {
+  auto tensor = paddle::Tensor();
+
+  auto dims = phi::vectorize(GetOpResultDims(op_result));
+  auto ddims = phi::make_ddim(dims);
+  auto autograd_meta = egr::EagerUtils::autograd_meta(&tensor);
+  autograd_meta->SetPersistable(false);
+  autograd_meta->SetStopGradient(
+      GetOpResultBoolAttr(op_result, kAttrStopGradients));
+
+  if (op_result.type().isa<paddle::dialect::DenseTensorType>()) {
+    // TODO(jiabin): Maybe support LOD later
+    std::shared_ptr<phi::DenseTensor> dense_tensor = nullptr;
+    auto dtype = paddle::dialect::TransToPhiDataType(
+        op_result.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype());
+
+    if (dims.size() == 1 && dims[0] == 0) {
+      std::shared_ptr<phi::Allocation> allocation_ptr = nullptr;
+      dense_tensor = std::make_shared<phi::DenseTensor>(
+          allocation_ptr, phi::DenseTensorMeta(dtype, ddims));
+    } else {
+      // TODO(dev): we need enhance check for ddims.
+      dense_tensor = std::make_shared<phi::DenseTensor>(
+          std::make_shared<phi::Allocation>(),
+          phi::DenseTensorMeta(dtype, ddims));
+    }
+    tensor.set_impl(dense_tensor);
+  } else if (op_result.type().isa<paddle::dialect::SelectedRowsType>()) {
+    std::shared_ptr<phi::SelectedRows> selected_rows_tensor =
+        std::make_shared<phi::SelectedRows>();
+    tensor.set_impl(selected_rows_tensor);
+  }
+
+  if (!autograd_meta->GetMutableGradNode()) {
+    autograd_meta->SetGradNode(
+        std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
+  }
+
+  return tensor;
+}
+
+PyObject* GetEmpytyTensorsWithOpResult(PyObject* self, PyObject* args) {
+  std::vector<paddle::Tensor> result;
+  std::unordered_map<pir::OpResult, paddle::Tensor> out_tensor_map;
+
+  auto op_result_list = PyTuple_GetItem(args, 0);
+
+  if (PyList_Check(op_result_list)) {
+    Py_ssize_t len = PyList_Size(op_result_list);
+    for (Py_ssize_t i = 0; i < len; i++) {
+      auto op_result =
+          PyObjectCast<pir::OpResult>(PyList_GetItem(op_result_list, i));
+      if (out_tensor_map.find(op_result) == out_tensor_map.end()) {
+        paddle::Tensor tensor = CreateTensorFromOpResult(op_result);
+        out_tensor_map[op_result] = tensor;
+        result.emplace_back(tensor);
+      } else {
+        result.emplace_back(out_tensor_map[op_result]);
+      }
+    }
+  } else if (PyTuple_Check(op_result_list)) {
+    Py_ssize_t len = PyTuple_Size(op_result_list);
+    for (Py_ssize_t i = 0; i < len; i++) {
+      auto op_result =
+          PyObjectCast<pir::OpResult>(PyTuple_GetItem(op_result_list, i));
+      if (out_tensor_map.find(op_result) == out_tensor_map.end()) {
+        paddle::Tensor tensor = CreateTensorFromOpResult(op_result);
+        out_tensor_map[op_result] = tensor;
+        result.emplace_back(tensor);
+      } else {
+        result.emplace_back(out_tensor_map[op_result]);
+      }
+    }
+  } else if (op_result_list != Py_None) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Argument of GetTensorsWithOpResultInArgs must be list of OpResult, "
+        "but got "
+        "%s",
+        (reinterpret_cast<PyTypeObject*>(op_result_list->ob_type))->tp_name));
+  }
+
+  return ToPyObject(result);
+}
+
 paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
                                               const std::string& op_type,
                                               ssize_t arg_pos) {
@@ -2484,5 +2663,24 @@ void DistTensorConverter::operator()(paddle::optional<std::vector<paddle::Tensor>>* x) {
   }
 }
 
+static PyMethodDef EagerUtilMethods[] = {
+    {"create_empty_tensors_with_var_descs",
+     (PyCFunction)(void (*)(void))GetEmpytyTensorsWithVarDesc,
+     METH_VARARGS,
+     "GetEmpytyTensorsWithVarDesc"},
+    {"create_empty_tensors_with_op_results",
+     (PyCFunction)(void (*)(void))GetEmpytyTensorsWithOpResult,
+     METH_VARARGS,
+     "GetEmpytyTensorsWithOpResult."},
+    {nullptr, nullptr, 0, nullptr}};
+
+void BindEagerUtils(PyObject* module) {
+  if (PyModule_AddFunctions(module, EagerUtilMethods) < 0) {
+    PADDLE_THROW(platform::errors::Fatal(
+        "Init Paddle error in BindEagerUtils(PyModule_AddFunctions)."));
+    return;
+  }
+}
+
 } // namespace pybind
 } // namespace paddle
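Both new entry points accept either a list or a tuple, and both deduplicate: CreateTensorFromVarDesc results are cached by variable name, CreateTensorFromOpResult results by the OpResult itself, so repeated outputs share one eager tensor. The returned tensors are "empty" in the sense that dtype, dims, name and stop_gradient are filled in while the allocation stays null (or zero-sized) until run_program writes into them. A hedged usage sketch from the Python side, using the method names registered in EagerUtilMethods above; the static program exists only to produce VarDescs:

    import paddle

    paddle.enable_static()
    prog = paddle.static.Program()
    with paddle.static.program_guard(prog):
        x = paddle.static.data("x", [4, 8], "float32")
        y = paddle.scale(x, scale=2.0)

    core = paddle.framework.core
    outs = core.create_empty_tensors_with_var_descs([x.desc, y.desc, y.desc])
    print([t.name for t in outs])  # the duplicated desc maps to one shared tensor
    print(outs[0].shape)           # [4, 8]: metadata only, nothing allocated yet
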
diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h
index bf4be9f2277e3..0cbbc893e98c9 100644
--- a/paddle/fluid/pybind/eager_utils.h
+++ b/paddle/fluid/pybind/eager_utils.h
@@ -54,6 +54,18 @@ namespace pybind {
 
 namespace py = ::pybind11;
 
+template <typename T>
+static T PyObjectCast(PyObject* obj) {
+  try {
+    return py::cast<T>(py::handle(obj));
+  } catch (py::cast_error&) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Python object is not type of %s, the real type is %s",
+        typeid(T).name(),
+        obj->ob_type->tp_name));
+  }
+}
+
 int TensorDtype2NumpyDtype(phi::DataType dtype);
 
 bool PyObject_CheckLongOrConvertToLong(PyObject** obj);
@@ -381,6 +393,10 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
                                                       bool allow_none = false);
 paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj);
 
+PyObject* GetEmpytyTensorsWithVarDesc(PyObject* self, PyObject* args);
+
+PyObject* GetEmpytyTensorsWithOpResult(PyObject* self, PyObject* args);
+
 // end of Slice related methods
 
 std::vector<paddle::framework::Scope*> GetScopePtrListFromArgs(
@@ -468,5 +484,7 @@ void ConvertAllInputsToDistTensor(const phi::distributed::ProcessMesh* mesh,
 }
 void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh);
 
+void BindEagerUtils(PyObject* module);
+
 } // namespace pybind
 } // namespace paddle
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 55efda46c86b0..8ba56008fb2b0 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -72,16 +72,6 @@ std::atomic VarBaseUniqueNameID{0};
 
 namespace py = ::pybind11;
 
-template <typename T>
-static T PyObjectCast(PyObject *obj) {
-  try {
-    return py::cast<T>(py::handle(obj));
-  } catch (py::cast_error &) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Python object is not type of %s", typeid(T).name()));
-  }
-}
-
 class PyVariableWrapperHook : public imperative::VariableWrapperHook {
  public:
   explicit PyVariableWrapperHook(PyObject *func) : py_func_(func) {
diff --git a/paddle/fluid/pybind/pir.h b/paddle/fluid/pybind/pir.h
index 5bc01c63e62e7..9ebaadc07ca09 100644
--- a/paddle/fluid/pybind/pir.h
+++ b/paddle/fluid/pybind/pir.h
@@ -15,9 +15,16 @@
 #pragma once
 
 #include
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/pir/core/op_result.h"
 
 namespace paddle {
 namespace pybind {
+using pir::OpResult;
 void BindPir(pybind11::module *m);
+phi::DataType GetOpResultDtype(const OpResult &result);
+const phi::DDim &GetOpResultDims(const OpResult &result);
+bool GetOpResultBoolAttr(const OpResult &self, const std::string &attr_name);
 } // namespace pybind
 } // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index e9877b5325357..7e2b13d430ecf 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -526,18 +526,6 @@ static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
   }
 }
 
-template <typename T>
-static T PyObjectCast(PyObject *obj) {
-  try {
-    return py::cast<T>(py::handle(obj));
-  } catch (py::cast_error &) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Python object is not type of %s, the real type is %s",
-        typeid(T).name(),
-        obj->ob_type->tp_name));
-  }
-}
-
 using PyNameVarBaseMap = std::unordered_map<std::string, std::shared_ptr<imperative::VarBase>>;
 
 static std::vector<std::shared_ptr<imperative::VarBase>> GetVarBaseList(
@@ -814,6 +802,7 @@ PYBIND11_MODULE(libpaddle, m) {
   BindJit(&m);
   BindEvalFrame(&m);
   BindCustomDevicePy(&m);
+  BindEagerUtils(m.ptr());
 
   // Not used, just make sure cpu_info.cc is linked.
   phi::backends::cpu::CpuTotalPhysicalMemory();
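BindEagerUtils registers the PyMethodDef table directly on the libpaddle module, so after this change the two helpers are plain module-level functions, reachable as paddle.framework.core.create_empty_tensors_with_var_descs and create_empty_tensors_with_op_results, which is exactly how the Python side calls them below. A quick wiring check, assuming a build that contains this patch:

    import paddle

    core = paddle.framework.core
    assert hasattr(core, "create_empty_tensors_with_var_descs")
    assert hasattr(core, "create_empty_tensors_with_op_results")
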
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index 2b6cca032beae..b61e46d97b469 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -24,7 +24,7 @@
 from paddle.base.data_feeder import check_type, convert_dtype
 from paddle.base.dygraph.base import switch_to_static_graph
 from paddle.base.framework import _apply_pass, get_flags
-from paddle.base.unique_name import guard as UniqueNameGuard
+from paddle.base.unique_name import switch
 from paddle.optimizer.lr import LRScheduler
 
 from . import logging_utils
@@ -32,7 +32,6 @@
     RETURN_NO_VALUE_MAGIC_NUM,
     backend_guard,
     construct_grad_names,
-    tensor_name_guard,
 )
 
 __all__ = []
@@ -220,33 +219,70 @@ def __init__(
         self._backend = kwargs.get('backend', None)
         self._grad_var_names = {}
 
+        self._in_var_names = []
+        for var in self._inputs:
+            if isinstance(var, framework.Variable):
+                self._in_var_names.append(var.desc.name())
+        self._out_var_descs = [
+            self._outputs[var_id].desc for var_id in self._outputs.var_ids
+        ]
+
     def __call__(self, inputs):
         """
         Execute static graph by Interpreter and Return dynamic Tensors.
         """
-        with UniqueNameGuard(self._name_generator):
-            in_vars, out_vars, in_var_names = self._prepare(inputs)
-            self._cast_fp16_if_pure_fp16(in_vars)
-            attrs = self._prepare_attributes()
-            attrs.extend(["x_names", in_var_names])
-
-            self._sync_lr_value_with_scheduler()
-
-            with tensor_name_guard(in_vars, in_var_names):
-                _legacy_C_ops.run_program(
-                    self._valid_vars(in_vars),
-                    self._valid_vars(self._params),
-                    self._valid_vars(out_vars),
-                    self._create_scope_vec(
-                        program_id=self.program_id, use_scope_cache=True
-                    ),
-                    self._cuda_graph_vec,
-                    *attrs
-                )
+        old_generator, old_para_name_checker = switch(self._name_generator)
+
+        in_vars, in_var_names = self._prepare_inputs(inputs)
+        out_vars = self._prepare_outputs()
+        self._cast_fp16_if_pure_fp16(in_vars)
+        attrs = self._prepare_attributes()
+        attrs.extend(["x_names", in_var_names])
+
+        self._sync_lr_value_with_scheduler()
+
+        _legacy_C_ops.run_program(
+            self._valid_vars(in_vars),
+            self._valid_vars(self._params),
+            self._valid_vars(out_vars),
+            self._create_scope_vec(
+                program_id=self.program_id, use_scope_cache=True
+            ),
+            self._cuda_graph_vec,
+            *attrs
+        )
 
-            self._update_stop_gradient(out_vars)
-            restored_nest_out = self._restore_out(out_vars)
-            return self._remove_no_value(restored_nest_out)
+        restored_nest_out = self._restore_out(out_vars)
+        restored_nest_out = self._remove_no_value(restored_nest_out)
+
+        switch(old_generator, old_para_name_checker)
+        return restored_nest_out
+
+    def sot_call(self, inputs):
+        """
+        In SOT mode, inputs and outputs of the partial program only contain tensors, so we can skip some steps to speed up execution.
+        """
+        old_generator, old_para_name_checker = switch(self._name_generator)
+
+        out_vars = self._prepare_outputs()
+        self._cast_fp16_if_pure_fp16(inputs)
+        attrs = self._prepare_attributes()
+        attrs.extend(["x_names", self._in_var_names])
+        self._sync_lr_value_with_scheduler()
+
+        _legacy_C_ops.run_program(
+            self._valid_vars(inputs),
+            self._valid_vars(self._params),
+            self._valid_vars(out_vars),
+            self._create_scope_vec(
+                program_id=self.program_id, use_scope_cache=True
+            ),
+            self._cuda_graph_vec,
+            *attrs
+        )
+
+        switch(old_generator, old_para_name_checker)
+        return out_vars
 
     def _sync_lr_value_with_scheduler(self):
         """Update lr_var value with calculated by lr_scheduler."""
@@ -895,7 +931,7 @@ def _parse_skip_gc_vars(self, program, backward_program=None):
                 skip_vars.append(var_name)
         return skip_vars
 
-    def _prepare(self, inputs):
+    def _prepare_inputs(self, inputs):
         """
         Prepare inputs, outputs, attrs.
         """
@@ -932,32 +968,12 @@ def _prepare(self, inputs):
                 input_var_names.append(self._inputs[i].desc.name())
             input_vars.append(var)
 
-        # mapping from name(string) -> Tensor
-        out_tensor_map = {}
-
-        def create_out(var_id):
-            var = self._outputs[var_id]
-            assert isinstance(var, framework.Variable)
-            var_desc = var.desc
-
-            if var_desc.name() in out_tensor_map:
-                return out_tensor_map[var_desc.name()]
-
-            out = core.eager.Tensor(
-                var_desc.dtype(),
-                var_desc.shape(),
-                var_desc.name(),
-                var_desc.type(),
-                False,
-            )
-            out.stop_gradient = var.stop_gradient
-            out_tensor_map[var_desc.name()] = out
-            return out
+        return input_vars, input_var_names
 
-        # Create Tensor to receive output data.
-        out_vars = list(map(create_out, self._outputs.var_ids))
-
-        return input_vars, out_vars, input_var_names
+    def _prepare_outputs(self):
+        return paddle.framework.core.create_empty_tensors_with_var_descs(
+            self._out_var_descs
+        )
 
     def _create_scope_vec(self, program_id=None, use_scope_cache=False):
         inner_scope = self._get_scope(
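Two things changed here. First, __call__ now brackets the run with paired unique_name.switch calls instead of the UniqueNameGuard context manager; switch returns the previous generator and parameter-name checker so they can be restored on the way out, avoiding context-manager overhead on every call. Second, sot_call is a trimmed twin of __call__ for the SOT runtime, which guarantees that inputs is already a flat list of correctly named eager tensors and that the caller wants the raw out_vars back, so input repacking, tensor_name_guard, output restoring and no-value stripping can all be skipped; the input names and output VarDescs are precomputed once in __init__. A small runnable sketch of the switch pairing (the generator construction is illustrative):

    from paddle.base import unique_name

    gen = unique_name.UniqueNameGenerator()
    old_gen, old_checker = unique_name.switch(gen)  # install, remember previous state
    print(unique_name.generate("tmp"))              # names now come from gen
    unique_name.switch(old_gen, old_checker)        # restore afterwards
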
diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py
index 1aa47a61c19d8..031774a25fb6e 100644
--- a/python/paddle/jit/dy2static/pir_partial_program.py
+++ b/python/paddle/jit/dy2static/pir_partial_program.py
@@ -448,7 +448,8 @@ def __call__(self, inputs):
         """
         Execute static graph by Interpreter and Return dynamic Tensors.
         """
-        in_vars, out_vars = self._prepare(inputs)
+        in_vars = self._prepare_inputs(inputs)
+        out_vars = self._prepare_outputs()
         attrs = self._prepare_attributes()
         _legacy_C_ops.pir_run_program(
             self._valid_vars(in_vars),
@@ -460,10 +461,27 @@ def __call__(self, inputs):
             self._cuda_graph_vec,
             *attrs,
         )
-        self._update_stop_gradient(out_vars)
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
 
+    def sot_call(self, inputs):
+        """
+        In SOT mode, inputs and outputs of the partial program only contain tensors, so we can skip some steps to speed up execution.
+        """
+        out_vars = self._prepare_outputs()
+        attrs = self._prepare_attributes()
+        _legacy_C_ops.pir_run_program(
+            self._valid_vars(inputs),
+            self._valid_vars(self._params),
+            self._valid_vars(out_vars),
+            self._create_scope_vec(
+                program_id=self.program_id, use_scope_cache=True
+            ),
+            self._cuda_graph_vec,
+            *attrs,
+        )
+        return out_vars
+
     @cached_property
     def origin_runable_program(self):
         inputs = list(self._inputs.var_list)
@@ -848,7 +866,7 @@ def _prepare_attributes(self):
             )
         return attrs
 
-    def _prepare(self, inputs):
+    def _prepare_inputs(self, inputs):
         """
         Prepare inputs, outputs, attrs.
         """
@@ -881,34 +899,12 @@ def _prepare(self, inputs):
             else:
                 continue
             input_vars.append(var)
+        return input_vars
 
-        # mapping from name(string) -> Tensor
-        out_tensor_map = {}
-
-        def create_out(var):
-            assert isinstance(var, OpResult)
-
-            if id(var) in out_tensor_map:
-                return out_tensor_map[id(var)]
-
-            if var.is_dense_tensor_type():
-                tensor_type = paddle.dtype(7)  # LOD TENSOR
-            else:
-                tensor_type = paddle.dtype(8)  # SELECT ROW TENSOR
-            out = core.eager.Tensor(
-                framework.paddle_type_to_proto_type[var.dtype],
-                var.shape,
-                "",
-                tensor_type,
-                False,
-            )
-            out.stop_gradient = var.stop_gradient
-            out_tensor_map[id(var)] = out
-            return out
-
-        # Create Tensor to receive output data.
-        out_vars = list(map(create_out, self._outputs.var_list))
-        return input_vars, out_vars
+    def _prepare_outputs(self):
+        return paddle.framework.core.create_empty_tensors_with_op_results(
+            self._outputs.var_list
+        )
 
     def _create_scope_vec(self, program_id=None, use_scope_cache=False):
         inner_scope = self._get_scope(
diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py
index b189f9ce2278d..98af0b9b712f1 100644
--- a/python/paddle/jit/sot/symbolic/compile_cache.py
+++ b/python/paddle/jit/sot/symbolic/compile_cache.py
@@ -113,7 +113,7 @@ def __call__(self, *args, **kwargs):
         else:
             # Speed up Resnet from 0.0068 --> 0.0057
             with EventGuard("FallbackWrapper: call partial_program"):
-                outputs = self.partial_program(*args, **kwargs)
+                outputs = self.partial_program.sot_call(*args, **kwargs)
 
         clear_eager_tensor_name(outputs)
         log_do(
diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py
index ed3519ce17cd6..c5e9696ae7dfc 100644
--- a/test/dygraph_to_static/test_resnet_v2.py
+++ b/test/dygraph_to_static/test_resnet_v2.py
@@ -290,7 +290,16 @@ def do_train(self, to_static):
 
             for batch_id, data in enumerate(data_loader()):
                 start_time = time.time()
-                img, label = data
+                img_, label = data
+
+                expected_place = paddle.framework._current_expected_place()
+                if img_.stop_gradient and not img_.place._equals(
+                    expected_place
+                ):
+                    img = img_._copy_to(expected_place, False)
+                    img.stop_gradient = True
+                else:
+                    img = img_
 
                 pred = resnet(img)
                 loss = paddle.nn.functional.cross_entropy(
diff --git a/test/legacy_test/test_eager_run_program.py b/test/legacy_test/test_eager_run_program.py
index be00a4f83c05c..6e57a4fb590d7 100644
--- a/test/legacy_test/test_eager_run_program.py
+++ b/test/legacy_test/test_eager_run_program.py
@@ -64,6 +64,7 @@ def _create_out(var):
         var_desc.type(),
         False,
     )
+    out.stop_gradient = False
 
     return out
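Both test edits are fallout from the fast path. sot_call skips _prepare_inputs entirely, so inputs must already live on the right device; test_resnet_v2 therefore moves each batch to the expected place itself before calling the network. And because run_program no longer forces stop_gradient=False on its outputs via PassStopGradient, the hand-built output tensor in test_eager_run_program must set the flag explicitly. The transfer idiom from the test, lifted into a reusable sketch (hedged: _copy_to, _equals and _current_expected_place are the same internal APIs the test itself uses):

    import paddle

    def to_expected_place(t):
        # Mirror the test_resnet_v2 change: move leaf tensors that live on the
        # wrong device, and re-mark the copy as a leaf (stop_gradient=True).
        expected = paddle.framework._current_expected_place()
        if t.stop_gradient and not t.place._equals(expected):
            moved = t._copy_to(expected, False)
            moved.stop_gradient = True
            return moved
        return t

    img = to_expected_place(paddle.randn([2, 3, 224, 224]))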