diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 0befc873ed696..8ccd1b26a3817 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -441,26 +441,27 @@ class ReshapeDoubleGradKernel { public: void operator()(const framework::ExecutionContext &ctx) const { auto *dd_x = ctx.Input<framework::Tensor>("DDX"); + auto *d_out = ctx.Input<framework::Tensor>("DOut"); auto *dd_out = ctx.Output<framework::Tensor>("DDOut"); dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); phi::ReshapeDoubleGradKernel( - static_cast<const phi::CPUContext &>(dev_ctx), *dd_x, dd_out); + static_cast<const phi::CPUContext &>(dev_ctx), *d_out, *dd_x, dd_out); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>(); phi::ReshapeDoubleGradKernel( - static_cast<const phi::GPUContext &>(dev_ctx), *dd_x, dd_out); + static_cast<const phi::GPUContext &>(dev_ctx), *d_out, *dd_x, dd_out); } #endif #ifdef PADDLE_WITH_XPU if (platform::is_xpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>(); phi::ReshapeDoubleGradKernel( - static_cast<const phi::XPUContext &>(dev_ctx), *dd_x, dd_out); + static_cast<const phi::XPUContext &>(dev_ctx), *d_out, *dd_x, dd_out); } #endif } @@ -658,7 +659,7 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad, Reshape2DoubleGradInferShapeFunctor, - PD_INFER_META(phi::GeneralUnaryGradInferMeta)); + PD_INFER_META(phi::ReshapeDoubleGradInferMeta)); REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp, ops::ReshapeDoubleGradInplaceInferer, diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 43d7d0393dd78..49e416fd0152d 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -409,6 +409,14 @@ void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) { dx->set_layout(out_grad.layout()); } +void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& x_grad_grad, + 
MetaTensor* out_grad_grad) { + if (out_grad_grad != nullptr) { + out_grad_grad->share_dims(out_grad); + } +} + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 432c1aacfcffe..eff3731bf2253 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -176,6 +176,10 @@ void PoolGradInferMeta(const MetaTensor& x, void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx); +void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, + const MetaTensor& x_grad_grad, + MetaTensor* out_grad_grad); + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/kernels/reshape_grad_kernel.cc b/paddle/phi/kernels/reshape_grad_kernel.cc index 38132966407dc..129a69d4e4e0f 100644 --- a/paddle/phi/kernels/reshape_grad_kernel.cc +++ b/paddle/phi/kernels/reshape_grad_kernel.cc @@ -30,6 +30,7 @@ void ReshapeGradKernel(const Context& dev_ctx, template <typename Context> void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, const DenseTensor& x_grad_grad, DenseTensor* out_grad_grad) { ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad); diff --git a/paddle/phi/kernels/reshape_grad_kernel.h b/paddle/phi/kernels/reshape_grad_kernel.h index 4eb3f68337aff..06ec3de15ab22 100644 --- a/paddle/phi/kernels/reshape_grad_kernel.h +++ b/paddle/phi/kernels/reshape_grad_kernel.h @@ -25,6 +25,7 @@ void ReshapeGradKernel(const Context& dev_ctx, template <typename Context> void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, const DenseTensor& x_grad_grad, DenseTensor* out_grad_grad); diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc index 6b528efe6d056..04f64e4035273 100644 --- a/paddle/phi/ops/compat/reshape_sig.cc +++ b/paddle/phi/ops/compat/reshape_sig.cc @@ -47,7 +47,7 @@ 
KernelSignature ReshapeGradOpArgumentMapping( KernelSignature ReshapeDoubleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"}); + return KernelSignature("reshape_double_grad", {"DOut", "DDX"}, {}, {"DDOut"}); } } // namespace phi