Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.28】为 paddle.clip 进行功能增强 #69193

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/phi/kernels/clip_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ template <typename T, typename Context>
void ClipGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& min,
const Scalar& max,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* x_grad);

} // namespace phi
6 changes: 2 additions & 4 deletions paddle/phi/kernels/clip_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@

#pragma once

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"

namespace phi {

template <typename T, typename Context>
void ClipKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& min,
const Scalar& max,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out);

} // namespace phi
24 changes: 23 additions & 1 deletion paddle/phi/kernels/cpu/clip_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,29 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void ClipGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* x_grad) {
int64_t numel = out_grad.numel();
auto *d_x_data = dev_ctx.template Alloc<T>(x_grad);
const auto* x_data = x.data<T>();
const auto* min_data = min.data<T>();
const auto* max_data = max.data<T>();
auto* d_out_data = out_grad.data<T>();

for (int i = 0; i < numel; i++) {
d_x_data[i] = ( x_data[i] > min_data[i] && x_data[i] < max_data[i]) ? d_out_data[i] : static_cast<T>(0);
}
}

} // namespace phi

PD_REGISTER_KERNEL(clip_grad,
CPU,
Expand Down
31 changes: 30 additions & 1 deletion paddle/phi/kernels/cpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,36 @@

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"

namespace phi {

template <typename T, typename Context>
void ClipKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
const T* x_data = x.data<T>();
const T* min_data = min.data<T>();
const T* max_data = max.data<T>();
auto x_numel = x.numel();

T* out_data = ctx.template Alloc<T>(out);

for (int i = 0; i < x_numel; i++) {
PADDLE_ENFORCE_LE(
min_data[i],
max_data[i],
errors::InvalidArgument("max should be greater than or equal to min. "
"But received min = %f, max = %f",
static_cast<float>(min_data[i]),
static_cast<float>(max_data[i])));

out_data[i] = x_data[i] < min_data[i] ? min_data[i] : x_data[i] > max_data[i] ? max_data[i] : x_data[i];
}
}

} // namespace phi

PD_REGISTER_KERNEL(
clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}
28 changes: 27 additions & 1 deletion paddle/phi/kernels/gpu/clip_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,36 @@

#include "paddle/phi/kernels/clip_grad_kernel.h"

#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"

namespace phi {

template <typename T>
class ClipGradGPUFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T min_, const T max_) const {
return (y > min_ && y < max_) ? x : static_cast<T>(0);
}
};

template <typename T, typename Context>
void ClipGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* x_grad) {
std::vector<const DenseTensor*> ins = {&out_grad, &x, &min, &max};
std::vector<DenseTensor*> outs = {x_grad};
auto functor = ClipGradGPUFunctor<T>();
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}

} // namespace phi

PD_REGISTER_KERNEL(clip_grad,
GPU,
Expand Down
29 changes: 28 additions & 1 deletion paddle/phi/kernels/gpu/clip_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,34 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace phi {

// Cond
template <typename T>
struct ClipGPUFunctor {
inline HOSTDEVICE T operator()(const T x, const T min_, const T max_) const {
return x < min_ ? min_ : x > max_ ? max_ : x;
}
};

template <typename T, typename Context>
void ClipKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
std::vector<const DenseTensor*> ins = {&x, &min, &max};
std::vector<DenseTensor*> outs = {out};
ctx.template Alloc<T>(out);

ClipGPUFunctor<T> func;
funcs::ElementwiseKernel<T, ClipGPUFunctor<T>, 1>(ctx, ins, &outs, func);
}

} // namespace phi

PD_REGISTER_KERNEL(clip,
GPU,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/impl/clip_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ void ClipKernel(const Context& dev_ctx,
const Scalar& min,
const Scalar& max,
DenseTensor* out) {
auto max_ = max.to<T>();
auto min_ = min.to<T>();
auto max_ = max.data.to<T>();
auto min_ = min.data.to<T>();

PADDLE_ENFORCE_LE(
min_,
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/onednn/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ namespace phi {
template <typename T, typename Context>
void ClipKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& min,
const Scalar& max,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();

funcs::ClipOneDNNHandler<T> handler(
min, max, onednn_engine, dev_ctx.GetPlace(), &x);
&min, &max, onednn_engine, dev_ctx.GetPlace(), &x);

auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(out);
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/kernels/xpu/clip_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ namespace phi {
template <typename T, typename Context>
void ClipKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& min,
const Scalar& max,
const DenseTensor& min,
const DenseTensor& max,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
using XPUDataType = typename XPUTypeTrait<T>::Type;
Expand All @@ -36,8 +36,8 @@ void ClipKernel(const Context& dev_ctx,
x_data,
out_data,
x.numel(),
static_cast<XPUDataType>(min.to<T>()),
static_cast<XPUDataType>(max.to<T>()));
static_cast<XPUDataType>(min.data.to<T>()),
static_cast<XPUDataType>(max.data.to<T>()));

PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/ops/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@
func : cholesky_solve_grad

- backward_op : clip_double_grad
forward : clip_grad (Tensor x, Tensor grad_out, Scalar min = 0., Scalar max = 0.) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad, Scalar min = 0., Scalar max = 0.)
forward : clip_grad (Tensor x, Tensor grad_out, Tensor min, Tensor max) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad, Tensor min, Tensor max)
output : Tensor(grad_out_grad)
infer_meta :
func : UnchangedInferMeta
Expand All @@ -391,8 +391,8 @@
data_type : x

- backward_op : clip_grad
forward : clip (Tensor x, Scalar min, Scalar max) -> Tensor(out)
args : (Tensor x, Tensor out_grad, Scalar min = 0., Scalar max = 0.)
forward : clip (Tensor x, Tensor min, Tensor max) -> Tensor(out)
args : (Tensor x, Tensor out_grad, Tensor min, Tensor max)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
Expand Down
9 changes: 1 addition & 8 deletions paddle/phi/ops/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -589,16 +589,9 @@
- op : clip
backward : clip_grad, clip_double_grad
inputs :
x : X
{x : X, min : Min, max : Max}
outputs :
out : Out
scalar :
min :
data_type : float
tensor_name : Min
max :
data_type : float
tensor_name : Max
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : clip
args : (Tensor x, Scalar(float) min, Scalar(float) max)
args : (Tensor x, Tensor min, Tensor max)
output : Tensor(out)
inplace : (x -> out)
infer_meta :
Expand Down
Loading