Skip to content

Commit

Permalink
[dtype] add fp16 support for dist_kernel (#56184)
Browse files Browse the repository at this point in the history
* [dtype] add fp16 support for dist_kernel

* fix typo

* fix CE

* fix CE

* fix CE

* fix CE

* fix CE

* refactor

* fix CE

* fix CE

* fix varname

* add bf16

* add ut for bf16

* fix CE
  • Loading branch information
jinyouzhi committed Aug 15, 2023
1 parent ac44d79 commit ea590ef
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 42 deletions.
10 changes: 8 additions & 2 deletions paddle/phi/kernels/dist_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
PD_REGISTER_KERNEL(dist_grad,
GPU,
ALL_LAYOUT,
phi::DistGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#endif
90 changes: 54 additions & 36 deletions paddle/phi/kernels/gpu/dist_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/dist_kernel.h"
#include <algorithm>

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
Expand All @@ -24,47 +27,53 @@ namespace phi {

#define FULL_MASK 0xffffffff

template <typename T>
template <typename Tx, typename Ty = Tx>
struct ZeroOrderFunctor {
public:
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>((x - y) != 0);
HOSTDEVICE explicit inline ZeroOrderFunctor() {}
HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(x != y);
}
};

template <typename T>
template <typename Tx, typename Ty = Tx>
struct OtherOrderFunctor {
explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>(pow(abs(x - y), p_order_));
HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& p_order)
: p_order_(p_order) {}

HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(
pow(abs(static_cast<Ty>(x) - static_cast<Ty>(y)), p_order_));
}

private:
T p_order_;
Ty p_order_;
};

template <typename T>
template <typename Tx, typename Ty = Tx>
struct PowFunctor {
explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(pow(x, p_order_));
HOSTDEVICE explicit inline PowFunctor(const Ty& p_order)
: p_order_(p_order) {}
HOSTDEVICE inline Tx operator()(const Tx x) const {
return static_cast<Tx>(pow(static_cast<Ty>(x), p_order_));
}
T p_order_;
Ty p_order_;
};

template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
T sum_val = 0;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT sum_val(0.0);
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
sum_val += func(x[i], y[i]);
}

__syncthreads();
sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = sum_val;
out[blockIdx.x] = static_cast<T>(sum_val);
}
}

Expand All @@ -73,16 +82,17 @@ __global__ void ReduceMaxWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T max_val = -1e10f;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT max_val = std::numeric_limits<MT>::min();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
max_val = max(max_val, abs(x[i] - y[i]));
max_val = max(max_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
}

__syncthreads();
max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
max_val = phi::funcs::BlockReduceMax<MT>(max_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = max_val;
out[blockIdx.x] = static_cast<T>(max_val);
}
}

Expand All @@ -91,16 +101,17 @@ __global__ void ReduceMinWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T min_val = 1e10f;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT min_val = std::numeric_limits<MT>::max();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
min_val = min(min_val, abs(x[i] - y[i]));
min_val = min(min_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
}

__syncthreads();
min_val = phi::funcs::BlockReduceMin(min_val, FULL_MASK);
min_val = phi::funcs::BlockReduceMin<MT>(min_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = min_val;
out[blockIdx.x] = static_cast<T>(min_val);
}
}

Expand All @@ -110,6 +121,7 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
Expand All @@ -130,10 +142,9 @@ void DistKernel(const Context& dev_ctx,
if (p == 0) {
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T, MT>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
} else if (p == INFINITY) {
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
Expand All @@ -150,19 +161,19 @@ void DistKernel(const Context& dev_ctx,
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

} else {
T p_order = static_cast<T>(p);
MT p_order = static_cast<MT>(p);
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);

const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out};
T p_order_ = static_cast<T>(1. / p_order);
MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
}

} else {
Expand All @@ -173,4 +184,11 @@ void DistKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
PD_REGISTER_KERNEL(dist,
GPU,
ALL_LAYOUT,
phi::DistKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
12 changes: 8 additions & 4 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,8 @@ def dist(x, y, p=2, name=None):
||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}
Args:
x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
x (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -701,8 +701,12 @@ def dist(x, y, p=2, name=None):
if in_dynamic_mode():
return _C_ops.dist(x, y, p)

check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(
x, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
)
check_variable_and_dtype(
y, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
)
check_type(p, 'p', (float, int), 'dist')
helper = LayerHelper("dist", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
Expand Down
80 changes: 80 additions & 0 deletions test/legacy_test/test_dist_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,86 @@ def init_case(self):
self.p = 1.5


class TestDistBF16Op(OpTest):
def init_data_type(self):
self.data_type = 'bfloat16'


class TestDistBF16OpCase1(TestDistBF16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0


class TestDistBF16OpCase2(TestDistBF16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0


class TestDistBF16OpCase3(TestDistBF16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")


class TestDistBF16OpCase4(TestDistBF16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")


class TestDistBF16OpCase5(TestDistBF16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5


class TestDistFP16Op(OpTest):
def init_data_type(self):
self.data_type = 'float16'


class TestDistFP16OpCase1(TestDistFP16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0


class TestDistFP16OpCase2(TestDistFP16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0


class TestDistFP16OpCase3(TestDistFP16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")


class TestDistFP16OpCase4(TestDistFP16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")


class TestDistFP16OpCase5(TestDistFP16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5


class TestDistAPI(unittest.TestCase):
def init_data_type(self):
self.data_type = (
Expand Down

0 comments on commit ea590ef

Please sign in to comment.