Skip to content

Commit

Permalink
Merge pull request #7 from xymyeah/reduce_mean_for_xpu
Browse files Browse the repository at this point in the history
fix reduce_mean_op_xpu op bug for cvrq
  • Loading branch information
tiancaitzp authored Sep 13, 2023
2 parents de6dd8f + 5c961ae commit d2010f0
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {

bool reduce_all = ctx.Attr<bool>("reduce_all");
auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
bool keep_dim = ctx.Attr<bool>("keep_dim");

std::vector<int> xdims;
for (int i = 0; i < input->dims().size(); i++) {
Expand All @@ -114,6 +115,13 @@ class ReduceMeanGradXPUKernel : public framework::OpKernel<T> {
reduce_numel *= xdims[d];
}

if (keep_dim != true) {
sort(reduce_dims.begin(), reduce_dims.end());
for (auto& d : reduce_dims) {
ydims.insert(ydims.begin() + d, 1);
}
}

float val = 1.0f / static_cast<float>(reduce_numel);

auto& dev_ctx = ctx.template device_context<DeviceContext>();
Expand Down

0 comments on commit d2010f0

Please sign in to comment.