diff --git a/include/matx/transforms/reduce.h b/include/matx/transforms/reduce.h
index 2e03f601..8410819c 100644
--- a/include/matx/transforms/reduce.h
+++ b/include/matx/transforms/reduce.h
@@ -52,6 +52,18 @@ union HalfBits {
   __nv_bfloat16 b;
 };
 
+/**
+ * One 32-bit word viewed either as raw bits or as its two 16-bit float halves.
+ * Lets 16-bit atomics be emulated with the 32-bit atomicCAS.
+ */
+union PascalHalfBits {
+  constexpr PascalHalfBits(unsigned int x) : i(x) {}
+  PascalHalfBits() = default;
+  unsigned int i;
+  __half h[2];
+  __nv_bfloat16 b[2];
+};
+
 #ifdef __CUDACC__
 /**
  * Warp shuffle down with a complex float
@@ -159,20 +171,59 @@ __MATX_DEVICE__ __MATX_INLINE__ void atomicMax(float *addr, float val)
  */
 __MATX_DEVICE__ __MATX_INLINE__ __nv_bfloat16 atomicMax(__nv_bfloat16 *addr, __nv_bfloat16 val)
 {
+#if __CUDA_ARCH__ >= 700 // 16-bit atomicCAS needs compute capability 7.0+
   HalfBits tmpval;
   HalfBits old;
-  unsigned short *address_as_ushort = (unsigned short *)addr;
+  unsigned short *address_as_other = (unsigned short *)addr;
+
   unsigned short assumed;
-  old.i = *address_as_ushort;
+  old.i = *address_as_other;
   tmpval.b = val;
 
   // nan should be ok here but should verify
   while (val > old.b) {
     assumed = old.i;
-    old.b = static_cast(atomicCAS(address_as_ushort, assumed, tmpval.i));
+    // atomicCAS hands back raw bits: keep them as bits, not as a number
+    old.i = atomicCAS(address_as_other, assumed, tmpval.i);
+    if (old.i == assumed) {
+      break; // swap succeeded; old holds the value we replaced
+    }
+  }
+
+  return old.b;
+#else // Pascal (sm_6x) has no 16-bit atomicCAS: CAS the enclosing 32-bit word
+  PascalHalfBits tmpval;
+  PascalHalfBits old;
+  unsigned int *address_as_other;
+  int offset;
+
+  // Step back to the enclosing 4-byte boundary; address bit 1 selects the half
+  if ((uintptr_t)addr & 0x2) {
+    address_as_other = (unsigned int *)(reinterpret_cast<char *>(addr) - 2);
+    offset = 1;
+  }
+  else {
+    address_as_other = (unsigned int *)(addr);
+    offset = 0;
+  }
+
+  unsigned int assumed;
+  old.i = *address_as_other;
+
+  // nan should be ok here but should verify
+  while (val > old.b[offset]) {
+    assumed = old.i;
+    // Replace only our half; the neighbouring element is written back as-is
+    tmpval.i = assumed;
+    tmpval.b[offset] = val;
+    old.i = atomicCAS(address_as_other, assumed, tmpval.i);
+    if (old.i == assumed) {
+      break; // swap succeeded; old holds the value we replaced
+    }
   }
 
-  return old.b;
+  return old.b[offset];
+#endif
 };
 
 
@@ -188,20 +239,59 @@ __MATX_DEVICE__ __MATX_INLINE__ __nv_bfloat16 atomicMax(__nv_bfloat16 *addr, __n
  */
 __MATX_DEVICE__ __MATX_INLINE__ __half atomicMax(__half *addr, __half val)
 {
+#if __CUDA_ARCH__ >= 700 // 16-bit atomicCAS needs compute capability 7.0+
   HalfBits tmpval;
   HalfBits old;
-  unsigned short *address_as_ushort = (unsigned short *)addr;
+  unsigned short *address_as_other = (unsigned short *)addr;
+
   unsigned short assumed;
-  old.i = *address_as_ushort;
+  old.i = *address_as_other;
   tmpval.h = val;
 
   // nan should be ok here but should verify
   while (val > old.h) {
     assumed = old.i;
-    old.h = atomicCAS(address_as_ushort, assumed, tmpval.i);
+    // atomicCAS hands back raw bits: keep them as bits, not as a number
+    old.i = atomicCAS(address_as_other, assumed, tmpval.i);
+    if (old.i == assumed) {
+      break; // swap succeeded; old holds the value we replaced
+    }
+  }
+
+  return old.h;
+#else // Pascal (sm_6x) has no 16-bit atomicCAS: CAS the enclosing 32-bit word
+  PascalHalfBits tmpval;
+  PascalHalfBits old;
+  unsigned int *address_as_other;
+  int offset;
+
+  // Step back to the enclosing 4-byte boundary; address bit 1 selects the half
+  if ((uintptr_t)addr & 0x2) {
+    address_as_other = (unsigned int *)(reinterpret_cast<char *>(addr) - 2);
+    offset = 1;
+  }
+  else {
+    address_as_other = (unsigned int *)(addr);
+    offset = 0;
+  }
+
+  unsigned int assumed;
+  old.i = *address_as_other;
+
+  // nan should be ok here but should verify
+  while (val > old.h[offset]) {
+    assumed = old.i;
+    // Replace only our half; the neighbouring element is written back as-is
+    tmpval.i = assumed;
+    tmpval.h[offset] = val;
+    old.i = atomicCAS(address_as_other, assumed, tmpval.i);
+    if (old.i == assumed) {
+      break; // swap succeeded; old holds the value we replaced
+    }
   }
 
-  return old.h;
+  return old.h[offset];
+#endif
 };
 
 /**