Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dtype] add fp16 support for dist_kernel #56184

Merged
merged 14 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions paddle/phi/kernels/dist_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
PD_REGISTER_KERNEL(dist_grad,
GPU,
ALL_LAYOUT,
phi::DistGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat16也记得注册一下,另外添加对应单测

#endif
90 changes: 54 additions & 36 deletions paddle/phi/kernels/gpu/dist_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/dist_kernel.h"
#include <algorithm>

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
Expand All @@ -24,47 +27,53 @@ namespace phi {

#define FULL_MASK 0xffffffff

template <typename T>
template <typename Tx, typename Ty = Tx>
struct ZeroOrderFunctor {
public:
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>((x - y) != 0);
HOSTDEVICE explicit inline ZeroOrderFunctor() {}
HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(x != y);
}
};

template <typename T>
template <typename Tx, typename Ty = Tx>
struct OtherOrderFunctor {
explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>(pow(abs(x - y), p_order_));
HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& p_order)
: p_order_(p_order) {}

HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(
pow(abs(static_cast<Ty>(x) - static_cast<Ty>(y)), p_order_));
}

private:
T p_order_;
Ty p_order_;
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

变量的命名请参考google-cpp-styleguide


template <typename T>
template <typename Tx, typename Ty = Tx>
struct PowFunctor {
explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(pow(x, p_order_));
HOSTDEVICE explicit inline PowFunctor(const Ty& p_order)
: p_order_(p_order) {}
HOSTDEVICE inline Tx operator()(const Tx x) const {
return static_cast<Tx>(pow(static_cast<Ty>(x), p_order_));
}
T p_order_;
Ty p_order_;
};

template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
T sum_val = 0;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT sum_val(0.0);
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
sum_val += func(x[i], y[i]);
}

__syncthreads();
sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = sum_val;
out[blockIdx.x] = static_cast<T>(sum_val);
}
}

Expand All @@ -73,16 +82,17 @@ __global__ void ReduceMaxWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T max_val = -1e10f;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT max_val = std::numeric_limits<MT>::min();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
max_val = max(max_val, abs(x[i] - y[i]));
max_val = max(max_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
}

__syncthreads();
max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
max_val = phi::funcs::BlockReduceMax<MT>(max_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = max_val;
out[blockIdx.x] = static_cast<T>(max_val);
}
}

Expand All @@ -91,16 +101,17 @@ __global__ void ReduceMinWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T min_val = 1e10f;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT min_val = std::numeric_limits<MT>::max();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
min_val = min(min_val, abs(x[i] - y[i]));
min_val = min(min_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
}

__syncthreads();
min_val = phi::funcs::BlockReduceMin(min_val, FULL_MASK);
min_val = phi::funcs::BlockReduceMin<MT>(min_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = min_val;
out[blockIdx.x] = static_cast<T>(min_val);
}
}

Expand All @@ -110,6 +121,7 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
Expand All @@ -130,10 +142,9 @@ void DistKernel(const Context& dev_ctx,
if (p == 0) {
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T, MT>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
} else if (p == INFINITY) {
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
Expand All @@ -150,19 +161,19 @@ void DistKernel(const Context& dev_ctx,
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);

} else {
T p_order = static_cast<T>(p);
MT p_order = static_cast<MT>(p);
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);

const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out};
T p_order_ = static_cast<T>(1. / p_order);
MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
}

} else {
Expand All @@ -173,4 +184,11 @@ void DistKernel(const Context& dev_ctx,

} // namespace phi

PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
PD_REGISTER_KERNEL(dist,
GPU,
ALL_LAYOUT,
phi::DistKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
12 changes: 8 additions & 4 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,8 @@ def dist(x, y, p=2, name=None):
||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}

Args:
x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
x (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -701,8 +701,12 @@ def dist(x, y, p=2, name=None):
if in_dynamic_mode():
return _C_ops.dist(x, y, p)

check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(
x, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
)
check_variable_and_dtype(
y, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,bfloat16类型的支持

check_type(p, 'p', (float, int), 'dist')
helper = LayerHelper("dist", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
Expand Down
80 changes: 80 additions & 0 deletions test/legacy_test/test_dist_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,86 @@ def init_case(self):
self.p = 1.5


class TestDistBF16Op(OpTest):
def init_data_type(self):
self.data_type = 'bfloat16'


class TestDistBF16OpCase1(TestDistBF16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0


class TestDistBF16OpCase2(TestDistBF16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0


class TestDistBF16OpCase3(TestDistBF16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")


class TestDistBF16OpCase4(TestDistBF16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")


class TestDistBF16OpCase5(TestDistBF16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5


class TestDistFP16Op(OpTest):
def init_data_type(self):
self.data_type = 'float16'


class TestDistFP16OpCase1(TestDistFP16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0


class TestDistFP16OpCase2(TestDistFP16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0


class TestDistFP16OpCase3(TestDistFP16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")


class TestDistFP16OpCase4(TestDistFP16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")


class TestDistFP16OpCase5(TestDistFP16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5


class TestDistAPI(unittest.TestCase):
def init_data_type(self):
self.data_type = (
Expand Down