From 58a38a0d2c1066615c1908facae38db50e982fbf Mon Sep 17 00:00:00 2001 From: "Tian Zheng (Engrg-Hardware 1)" Date: Thu, 30 Nov 2023 03:36:41 +0000 Subject: [PATCH] Review changes --- paddle/phi/infermeta/fusion.cc | 21 ++++++++++++------- .../gpu/fused_dconv_drelu_dbn_kernel.cu | 2 +- .../test_fused_dconv_drelu_dbn_op.py | 2 +- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 740b5cf24ad3b..cb4de6f93f600 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -2247,18 +2247,25 @@ void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output, !!bn2_beta, !!bn2_input)); } - grad_weight->set_dims(weight.dims()); - grad_bn1_input->set_dims(bn1_input.dims()); - grad_bn1_gamma->set_dims(bn1_gamma.dims()); - grad_bn1_beta->set_dims(bn1_beta.dims()); + + auto set_unchanged_meta = [](MetaTensor* out, const MetaTensor& input) { + out->set_dims(input.dims()); + out->set_dtype(input.dtype()); + out->set_layout(input.layout()); + }; + + set_unchanged_meta(grad_weight, weight); + set_unchanged_meta(grad_bn1_input, bn1_input); + set_unchanged_meta(grad_bn1_gamma, bn1_gamma); + set_unchanged_meta(grad_bn1_beta, bn1_beta); if (grad_bn2_input) { - grad_bn2_input->set_dims(bn1_input.dims()); + set_unchanged_meta(grad_bn2_input, bn1_input); } if (grad_bn2_gamma) { - grad_bn2_gamma->set_dims(bn1_gamma.dims()); + set_unchanged_meta(grad_bn2_gamma, bn1_gamma); } if (grad_bn2_beta) { - grad_bn2_beta->set_dims(bn1_beta.dims()); + set_unchanged_meta(grad_bn2_beta, bn1_beta); } } diff --git a/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu index e194ae3f4756b..6b041753c1a38 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dconv_drelu_dbn_kernel.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py index 8862404b6acb9..d038d8a83caa2 100644 --- a/test/legacy_test/test_fused_dconv_drelu_dbn_op.py +++ b/test/legacy_test/test_fused_dconv_drelu_dbn_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.