Skip to content

Commit

Permalink
Review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Zheng committed Nov 30, 2023
1 parent 2226cf9 commit 58a38a0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
21 changes: 14 additions & 7 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2247,18 +2247,25 @@ void FusedDconvDreluDbnInferMeta(const MetaTensor& grad_output,
!!bn2_beta,
!!bn2_input));
}
grad_weight->set_dims(weight.dims());
grad_bn1_input->set_dims(bn1_input.dims());
grad_bn1_gamma->set_dims(bn1_gamma.dims());
grad_bn1_beta->set_dims(bn1_beta.dims());

auto set_unchanged_meta = [](MetaTensor* out, const MetaTensor& input) {
out->set_dims(input.dims());
out->set_dtype(input.dtype());
out->set_layout(input.layout());
};

set_unchanged_meta(grad_weight, weight);
set_unchanged_meta(grad_bn1_input, bn1_input);
set_unchanged_meta(grad_bn1_gamma, bn1_gamma);
set_unchanged_meta(grad_bn1_beta, bn1_beta);
if (grad_bn2_input) {
grad_bn2_input->set_dims(bn1_input.dims());
set_unchanged_meta(grad_bn2_input, bn1_input);
}
if (grad_bn2_gamma) {
grad_bn2_gamma->set_dims(bn1_gamma.dims());
set_unchanged_meta(grad_bn2_gamma, bn1_gamma);
}
if (grad_bn2_beta) {
grad_bn2_beta->set_dims(bn1_beta.dims());
set_unchanged_meta(grad_bn2_beta, bn1_beta);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion test/legacy_test/test_fused_dconv_drelu_dbn_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 58a38a0

Please sign in to comment.