Skip to content

Commit

Permalink
[CustomDevice] fix set_constant (#52360)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Mar 31, 2023
1 parent 4e23af7 commit f22b966
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions paddle/phi/kernels/funcs/math_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,16 @@ void set_constant_with_place<phi::IPUPlace>(const phi::DeviceContext& context,
template <>
void set_constant_with_place<phi::CustomPlace>(
const phi::DeviceContext& context, phi::DenseTensor* tensor, float value) {
PADDLE_THROW(phi::errors::Unimplemented("CustomPlace is not supported"));
phi::DenseTensor tmp_tensor;
tmp_tensor.Resize(tensor->dims());
context.HostAlloc(&tmp_tensor, tensor->dtype());
phi::VisitDataType(tmp_tensor.dtype(),
TensorSetConstantCPU(&tmp_tensor, value));
phi::memory_utils::Copy(tensor->place(),
tensor->data(),
phi::CPUPlace(),
tmp_tensor.data(),
tensor->numel() * phi::SizeOf(tensor->dtype()));
}

template <>
Expand Down Expand Up @@ -230,7 +239,7 @@ void set_constant(const phi::DeviceContext& context,
TensorSetConstantWithPlace func(context, tensor, value);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (context.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
func(phi::CPUPlace());
func(phi::CustomPlace());
return;
}
#endif
Expand Down

0 comments on commit f22b966

Please sign in to comment.