Skip to content

Commit

Permalink
【Hackathon 5th No.25】add gammaln api (#59311)
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatV authored Dec 28, 2023
1 parent 875fbfb commit beba862
Show file tree
Hide file tree
Showing 15 changed files with 538 additions and 0 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,16 @@
kernel :
func : frame_grad

- backward_op : gammaln_grad
forward : gammaln(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : gammaln_grad

- backward_op : gather_grad
forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,16 @@
data_type : dtype
backend : place

- op : gammaln
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : gammaln
inplace: (x -> out)
backward : gammaln_grad

- op : gather
args : (Tensor x, Tensor index, Scalar axis=0)
output : Tensor(out)
Expand Down
22 changes: 22 additions & 0 deletions paddle/phi/kernels/cpu/gammaln_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gammaln_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
gammaln_grad, CPU, ALL_LAYOUT, phi::GammalnGradKernel, float, double) {}
22 changes: 22 additions & 0 deletions paddle/phi/kernels/cpu/gammaln_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gammaln_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h"

PD_REGISTER_KERNEL(
gammaln, CPU, ALL_LAYOUT, phi::GammalnKernel, float, double) {}
27 changes: 27 additions & 0 deletions paddle/phi/kernels/gammaln_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void GammalnGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& d_out,
DenseTensor* d_x);
} // namespace phi
26 changes: 26 additions & 0 deletions paddle/phi/kernels/gammaln_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void GammalnKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
30 changes: 30 additions & 0 deletions paddle/phi/kernels/gpu/gammaln_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gammaln_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h"

PD_REGISTER_KERNEL(gammaln_grad,
GPU,
ALL_LAYOUT,
phi::GammalnGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
29 changes: 29 additions & 0 deletions paddle/phi/kernels/gpu/gammaln_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gammaln_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h"

PD_REGISTER_KERNEL(gammaln,
GPU,
ALL_LAYOUT,
phi::GammalnKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
92 changes: 92 additions & 0 deletions paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {
template <typename T>
HOSTDEVICE T digamma(T x) {
static T c = T{8.5};
static T euler_mascheroni = T{0.57721566490153286060};
T r;
T value;
T x2;

if (x <= T{0.0}) {
value = T{0.0};
return value;
}

if (x <= T{0.000001}) {
value = -euler_mascheroni - T{1.0} / x + T{1.6449340668482264365} * x;
return value;
}

value = T{0.0};
x2 = x;
while (x2 < c) {
value = value - T{1.0} / x2;
x2 = x2 + T{1.0};
}

r = T{1.0} / x2;
value = value + std::log(x2) - T{0.5} * r;

r = r * r;

value = value -
r * (T{1.0} / T{12.0} -
r * (T{1.0} / T{120.0} -
r * (T{1.0} / T{252.0} -
r * (T{1.0} / T{240.0} - r * (T{1.0} / T{132.0})))));

return value;
}

template <typename T>
struct GammalnGradFunctor {
GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_dout = static_cast<MT>(dout_[idx]);
const MT mp_x = static_cast<MT>(x_[idx]);
output_[idx] = static_cast<T>(mp_dout * digamma<MT>(mp_x));
}

private:
const T* dout_;
const T* x_;
T* output_;
int64_t numel_;
};
template <typename T, typename Context>
void GammalnGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& d_out,
DenseTensor* d_x) {
auto numel = d_out.numel();
auto* dout_data = d_out.data<T>();
auto* x_data = x.data<T>();
auto* dx_data =
dev_ctx.template Alloc<T>(d_x, static_cast<size_t>(numel * sizeof(T)));
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
GammalnGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
} // namespace phi
49 changes: 49 additions & 0 deletions paddle/phi/kernels/impl/gammaln_kernel_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {
template <typename T>
struct GammalnFunctor {
GammalnFunctor(const T* x, T* output, int64_t numel)
: x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(x_[idx]);
output_[idx] = static_cast<T>(std::lgamma(mp_x));
}

private:
const T* x_;
T* output_;
int64_t numel_;
};

template <typename T, typename Context>
void GammalnKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
GammalnFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
} // namespace phi
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@
frac,
frac_,
frexp,
gammaln,
gammaln_,
gcd,
gcd_,
heaviside,
Expand Down Expand Up @@ -773,6 +775,8 @@
'square_',
'divide',
'divide_',
'gammaln',
'gammaln_',
'ceil',
'atan',
'atan_',
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@
frac,
frac_,
frexp,
gammaln,
gammaln_,
gcd,
gcd_,
heaviside,
Expand Down Expand Up @@ -668,6 +670,8 @@
'real',
'imag',
'is_floating_point',
'gammaln',
'gammaln_',
'digamma',
'digamma_',
'diagonal',
Expand Down
Loading

0 comments on commit beba862

Please sign in to comment.