diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index 59501150d6479..093876a402763 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -168,8 +168,11 @@ UniqueFlattendCUDATensor(const Context& context, #ifdef PADDLE_WITH_HIP hipMemset(inv_loc_data_ptr, 0, sizeof(IndexT)); #else - cudaMemsetAsync(inv_loc_data_ptr, 0, sizeof(IndexT), context.stream()); + thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); + inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault #endif + +#ifdef PADDLE_WITH_HIP size_t temp_storage_bytes = 0; cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, @@ -185,6 +188,12 @@ UniqueFlattendCUDATensor(const Context& context, inv_loc_data_ptr, num_input, context.stream()); +#else + thrust::inclusive_scan(exec_policy, + inv_loc_data_ptr, + inv_loc_data_ptr + num_input, + inv_loc_data_ptr); +#endif thrust::scatter(exec_policy, inv_loc_data_ptr, inv_loc_data_ptr + num_input,