Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API -part #59674

Merged
merged 25 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1762,8 +1762,8 @@
optional : boxes_num

- backward_op : put_along_axis_grad
forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign") -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true) -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor values, Tensor out, Tensor out_grad, int axis, str reduce, bool include_self)
output : Tensor(arr_grad), Tensor(values_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2432,7 +2432,7 @@
outputs :
out : Result
attrs :
{axis : Axis, reduce : Reduce}
{axis : Axis, reduce : Reduce, include_self: Include_self}

- op : pylayer
backward : pylayer_grad
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2035,7 +2035,7 @@
backward : psroi_pool_grad

- op : put_along_axis
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a parameter of broadcast in Python API, shall we also add it here as include_self or delete it from API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broadcast is processed in the Python interface, so there is no need to pass it into the C interface again

output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
Expand Down
175 changes: 175 additions & 0 deletions paddle/phi/backends/gpu/gpu_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,181 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
CudaAtomicAdd(imag, val.imag));
}

// For atomicMul.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些atomicMul的计算算法,能提供下参考吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的atomicMul都是参考的前面的atomicAdd以及后面的atomicMin这些,只是把加改成了乘

CUDA_ATOMIC_WRAPPER(Mul, int) {
int res = *address, old = res; // NOLINT
do {
old = res;
res = atomicCAS(address, // NOLINT
old, // NOLINT
val * old); // NOLINT
} while (old != res);
return res;
}

CUDA_ATOMIC_WRAPPER(Mul, unsigned int) {
unsigned int res = *address, old = res; // NOLINT
do {
old = res;
res = atomicCAS(address, // NOLINT
old, // NOLINT
val * old); // NOLINT
} while (old != res);
return res;
}
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// It because unsigned long long int is not necessarily uint64_t
CUDA_ATOMIC_WRAPPER(Mul, unsigned long long int) { // NOLINT
unsigned long long int old = *address, assumed; // NOLINT

do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是否也应该有一个返回值

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

CUDA_ATOMIC_WRAPPER(Mul, int64_t) {
// Here, we check long long int must be int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
long long int res = *address, old = res; // NOLINT
do {
old = res;
res = (long long int)atomicCAS( // NOLINT
(unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val * (unsigned long long int)old); // NOLINT
} while (old != res);
return res;
}

CUDA_ATOMIC_WRAPPER(Mul, float) {
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;

do {
assumed = old;
old = atomicCAS(
address_as_i, assumed, __float_as_int(val * __int_as_float(assumed)));
} while (assumed != old);

return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Mul, double) {
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT

do {
assumed = old;

old = atomicCAS(address_as_ull,
assumed,
__double_as_longlong(val * __longlong_as_double(assumed)));
} while (assumed != old);

return __longlong_as_double(old);
}

#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t mul_to_low_half(uint32_t val, float x) {
phi::dtype::float16 low_half;
// The float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<phi::dtype::float16>(static_cast<float>(low_half) * x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t mul_to_high_half(uint32_t val, float x) {
phi::dtype::float16 high_half;
// The float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::float16>(static_cast<float>(high_half) * x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) {
if (*address >= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, mul_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, mul_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif

inline static __device__ uint32_t bf16_mul_to_low_half(uint32_t val, float x) {
phi::dtype::bfloat16 low_half;
// The bfloat16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half =
static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) * x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t bf16_mul_to_high_half(uint32_t val, float x) {
phi::dtype::bfloat16 high_half;
// The bfloat16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) * x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) {
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The bfloat16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_mul_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The bfloat16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_mul_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ void CummaxGradKernel(const Context& dev_ctx,
}
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
}
}

Expand All @@ -61,10 +61,10 @@ void CumminGradKernel(const Context& dev_ctx,
}
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
}
}

Expand Down
Loading