Skip to content

Commit

Permalink
miscellaneous fixes for python APIs (PaddlePaddle#26)
Browse files Browse the repository at this point in the history
* add placeholder for unittests

* resize fft inputs before computation is n or s is provided.

* add complex kernels for pad and pad_grad

* simplify argument checking.

* add type promotion

* add int to float or complex promotion

* fix output data type for static mode

* fix fft's input dtype dispatch, import fft to paddle
  • Loading branch information
Feiyu Chan authored Sep 9, 2021
1 parent 75c2ca0 commit 7a22378
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 168 deletions.
25 changes: 21 additions & 4 deletions paddle/fluid/operators/pad_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/pad_op.h"
#include <memory>
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -170,20 +171,36 @@ REGISTER_OP_CPU_KERNEL(
pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::PadKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::PadKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::PadGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
ops::PadKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
paddle::platform::float16>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::PadKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
paddle::platform::float16>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::PadGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/spectral_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ class FFTR2COp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of FFTC2ROp should not be null.", "X"));
"Input(%s) of FFTR2COp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(%s) of FFTC2ROp should not be null.", "Out"));
"Output(%s) of FFTR2COp should not be null.", "Out"));
const auto axes = ctx->Attrs().Get<std::vector<int64_t>>("axes");
const bool onesided = ctx->Attrs().Get<bool>("onesided");
if (!onesided) {
Expand Down
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import paddle.static # noqa: F401
import paddle.vision # noqa: F401

from .tensor import fft
from .tensor.random import bernoulli # noqa: F401

from .tensor.attribute import rank # noqa: F401
Expand Down
19 changes: 19 additions & 0 deletions python/paddle/tensor/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def _complex_to_real_dtype(dtype):
return dtype


def _real_to_complex_dtype(dtype):
if dtype == core.VarDesc.VarType.FP32:
return core.VarDesc.VarType.COMPLEX64
elif dtype == core.VarDesc.VarType.FP64:
return core.VarDesc.VarType.COMPLEX128
else:
return dtype


def is_complex(x):
dtype = x.dtype
is_complex_dtype = (dtype == core.VarDesc.VarType.COMPLEX64 or
Expand All @@ -51,6 +60,16 @@ def is_floating_point(x):
return is_fp_dtype


def is_interger(x):
dtype = x.dtype
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
dtype == core.VarDesc.VarType.INT8 or
dtype == core.VarDesc.VarType.INT16 or
dtype == core.VarDesc.VarType.INT32 or
dtype == core.VarDesc.VarType.INT64)
return is_int_dtype


def real(x, name=None):
"""
Returns a new tensor containing real values of the input tensor.
Expand Down
Loading

0 comments on commit 7a22378

Please sign in to comment.