Modify Ops from complex64/128 to complex<float/double> types. #33133

Merged
merged 6 commits on May 27, 2021
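This PR replaces the two concrete complex types, paddle::platform::complex64 and paddle::platform::complex128, with the single class template paddle::platform::complex<T> (T = float or double) across the pad, slice, kron, and reshape operators. One templated code path then serves both precisions, which removes duplicated specializations such as the two KronGradElemFunctor variants below. A minimal sketch of the idea — the actual paddle/fluid/platform/complex.h is more elaborate (HOSTDEVICE annotations, conversions, a full operator set), so treat this as an illustration, not the real header:

#include <type_traits>

namespace paddle {
namespace platform {

// Illustrative stand-in for the real complex<T> in platform/complex.h.
template <typename T>
struct complex {
  static_assert(std::is_floating_point<T>::value,
                "complex<T> expects T = float or double");
  T real;
  T imag;

  complex() : real(0), imag(0) {}
  complex(T r, T i) : real(r), imag(i) {}

  // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
  complex operator*(const complex& o) const {
    return complex(real * o.real - imag * o.imag,
                   real * o.imag + imag * o.real);
  }
};

}  // namespace platform
}  // namespace paddle

int main() {
  // complex<float> plays the role of the old complex64,
  // complex<double> the role of the old complex128.
  paddle::platform::complex<float> a(1.0f, 2.0f), b(3.0f, -1.0f);
  paddle::platform::complex<float> c = a * b;  // (5, 5)
  return (c.real == 5.0f && c.imag == 5.0f) ? 0 : 1;
}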
Changes from all commits
7 changes: 3 additions & 4 deletions paddle/fluid/operators/eigen/pad.cc
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
 namespace operators {
@@ -56,8 +55,8 @@ INSTANTIATION(EigenPad, int);
 INSTANTIATION(EigenPad, int64_t);
 INSTANTIATION(EigenPad, float);
 INSTANTIATION(EigenPad, double);
-INSTANTIATION(EigenPad, platform::complex64);
-INSTANTIATION(EigenPad, platform::complex128);
+INSTANTIATION(EigenPad, platform::complex<float>);
+INSTANTIATION(EigenPad, platform::complex<double>);
 #undef INSTANTIATION
 
 }  // namespace operators
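For the Eigen functor files, only the INSTANTIATION macro arguments change: the macro explicitly instantiates EigenPad for each element type, and the two complex lines now name the template instead of the old concrete types. A self-contained sketch of that explicit-instantiation idiom, with a simplified one-dimensional pad functor standing in for the real EigenPad (which is declared in eigen_function.h and is more general):

// Illustrative only; Paddle's EigenPad and its INSTANTIATION macro are
// templated over more than the element type.
template <typename T>
struct PadFunctor {
  // Write `pad_value` everywhere, then copy `in` into the middle (1-D pad).
  static void Eval(T* out, const T* in, int n, int pad, T pad_value) {
    for (int i = 0; i < n + 2 * pad; ++i) out[i] = pad_value;
    for (int i = 0; i < n; ++i) out[pad + i] = in[i];
  }
};

// Explicit template instantiation: force the compiler to emit code for these
// element types in this translation unit, as the diff above does per type.
#define INSTANTIATION(FUNCTOR, TYPE) template struct FUNCTOR<TYPE>
INSTANTIATION(PadFunctor, float);
INSTANTIATION(PadFunctor, double);
#undef INSTANTIATION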
7 changes: 3 additions & 4 deletions paddle/fluid/operators/eigen/pad.cu
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -58,8 +57,8 @@ INSTANTIATION(EigenPad, int64_t);
 INSTANTIATION(EigenPad, float);
 INSTANTIATION(EigenPad, double);
 INSTANTIATION(EigenPad, platform::float16);
-INSTANTIATION(EigenPad, platform::complex64);
-INSTANTIATION(EigenPad, platform::complex128);
+INSTANTIATION(EigenPad, platform::complex<float>);
+INSTANTIATION(EigenPad, platform::complex<double>);
 #undef INSTANTIATION
 
 }  // namespace operators
4 changes: 0 additions & 4 deletions paddle/fluid/operators/eigen/slice.cc
@@ -14,8 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -69,8 +67,6 @@ INSTANTIATION(EigenSlice, float);
 INSTANTIATION(EigenSlice, double);
 INSTANTIATION(EigenSlice, platform::float16);
 INSTANTIATION(EigenSlice, platform::bfloat16);
-INSTANTIATION(EigenSlice, platform::complex64);
-INSTANTIATION(EigenSlice, platform::complex128);
 INSTANTIATION(EigenSlice, platform::complex<float>);
 INSTANTIATION(EigenSlice, platform::complex<double>);
 #undef INSTANTIATION
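Note that slice.cc differs from the other Eigen files: the templated complex<float>/complex<double> instantiations were already present there, so this file is pure deletion (0 additions & 4 deletions) of the now-redundant complex64/complex128 includes and instantiations.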
7 changes: 3 additions & 4 deletions paddle/fluid/operators/eigen/slice.cu
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -58,8 +57,8 @@ INSTANTIATION(EigenSlice, int64_t);
 INSTANTIATION(EigenSlice, float);
 INSTANTIATION(EigenSlice, double);
 INSTANTIATION(EigenSlice, platform::float16);
-INSTANTIATION(EigenSlice, platform::complex64);
-INSTANTIATION(EigenSlice, platform::complex128);
+INSTANTIATION(EigenSlice, platform::complex<float>);
+INSTANTIATION(EigenSlice, platform::complex<double>);
 #undef INSTANTIATION
 
 }  // namespace operators
11 changes: 5 additions & 6 deletions paddle/fluid/operators/kron_op.cc
@@ -18,8 +18,7 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -185,9 +184,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
     ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex64>,
+                    paddle::platform::complex<float>>,
     ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex128>);
+                    paddle::platform::complex<double>>);
 
 REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
 REGISTER_OP_CPU_KERNEL(
@@ -198,6 +197,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);
11 changes: 5 additions & 6 deletions paddle/fluid/operators/kron_op.cu
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
@@ -26,9 +25,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
     ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex64>,
+                    paddle::platform::complex<float>>,
     ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex128>);
+                    paddle::platform::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -38,6 +37,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);
92 changes: 16 additions & 76 deletions paddle/fluid/operators/kron_op.h
@@ -26,9 +26,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using complex64 = paddle::platform::complex64;
-using complex128 = paddle::platform::complex128;
-
 // Process an element in the output, used with a parallel-for
 template <typename T>
 struct KronElemFunctor {
@@ -175,72 +172,13 @@ struct KronGradElemFunctor {
   const int ndims_;
 };
 
-template <>
-struct KronGradElemFunctor<complex64> {
-  KronGradElemFunctor(const complex64* dout, const complex64* A,
-                      const complex64* B, complex64* dout_a, complex64* dout_b,
-                      const int64_t* stride_dout, const int64_t* stride_a,
-                      const int64_t* stride_b, const int64_t* shape_b,
-                      const int64_t numel_a, const int64_t numel_b,
-                      const int ndims)
-      : dout_(dout),
-        A_(A),
-        B_(B),
-        dout_a_(dout_a),
-        dout_b_(dout_b),
-        stride_dout_(stride_dout),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        shape_b_(shape_b),
-        numel_a_(numel_a),
-        numel_b_(numel_b),
-        ndims_(ndims) {}
-
-  HOSTDEVICE void operator()(int64_t idx) {
-    int64_t index = idx;
-    int64_t index_a = 0;
-    int64_t index_b = 0;
-    for (int i = 0; i < ndims_; i++) {
-      auto pos_i = index / stride_dout_[i];
-      index = index % stride_dout_[i];
-      auto pos_ai = pos_i / shape_b_[i];
-      auto pos_bi = pos_i % shape_b_[i];
-      index_a += stride_a_[i] * pos_ai;
-      index_b += stride_b_[i] * pos_bi;
-    }
-
-    if (dout_a_) {
-      size_t index_out_a = index_a * numel_b_ + index_b;
-      dout_a_[index_out_a] =
-          dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
-    }
-    if (dout_b_) {
-      size_t index_out_b = index_b * numel_a_ + index_a;
-      dout_b_[index_out_b] =
-          dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
-    }
-  }
-
- private:
-  const complex64* dout_;
-  const complex64* A_;
-  const complex64* B_;
-  complex64* dout_a_;
-  complex64* dout_b_;
-  const int64_t* stride_dout_;
-  const int64_t* stride_a_;
-  const int64_t* stride_b_;
-  const int64_t* shape_b_;
-  const int64_t numel_a_;
-  const int64_t numel_b_;
-  const int ndims_;
-};
-
-template <>
-struct KronGradElemFunctor<complex128> {
-  KronGradElemFunctor(const complex128* dout, const complex128* A,
-                      const complex128* B, complex128* dout_a,
-                      complex128* dout_b, const int64_t* stride_dout,
+template <typename T>
+struct KronGradElemFunctor<platform::complex<T>> {
+  KronGradElemFunctor(const platform::complex<T>* dout,
+                      const platform::complex<T>* A,
+                      const platform::complex<T>* B,
+                      platform::complex<T>* dout_a,
+                      platform::complex<T>* dout_b, const int64_t* stride_dout,
                       const int64_t* stride_a, const int64_t* stride_b,
                       const int64_t* shape_b, const int64_t numel_a,
                       const int64_t numel_b, const int ndims)
@@ -273,21 +211,23 @@ struct KronGradElemFunctor<complex128> {
     if (dout_a_) {
       size_t index_out_a = index_a * numel_b_ + index_b;
       dout_a_[index_out_a] =
-          dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
+          dout_[idx] *
+          platform::complex<T>(B_[index_b].real, -B_[index_b].imag);
     }
     if (dout_b_) {
       size_t index_out_b = index_b * numel_a_ + index_a;
       dout_b_[index_out_b] =
-          dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
+          dout_[idx] *
+          platform::complex<T>(A_[index_a].real, -A_[index_a].imag);
     }
   }
 
  private:
-  const complex128* dout_;
-  const complex128* A_;
-  const complex128* B_;
-  complex128* dout_a_;
-  complex128* dout_b_;
+  const platform::complex<T>* dout_;
+  const platform::complex<T>* A_;
+  const platform::complex<T>* B_;
+  platform::complex<T>* dout_a_;
+  platform::complex<T>* dout_b_;
   const int64_t* stride_dout_;
   const int64_t* stride_a_;
   const int64_t* stride_b_;
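The largest reduction is here in kron_op.h: the two full specializations KronGradElemFunctor<complex64> and KronGradElemFunctor<complex128> collapse into one partial specialization over platform::complex<T>. The complex case still needs its own specialization because the Kron gradient multiplies by the conjugate of the other factor, hence the complex<T>(x.real, -x.imag) terms in the functor. A toy, self-contained example of the partial-specialization pattern (Complex and ConjFunctor are illustrative names, not Paddle code):

#include <cstdio>

template <typename T>
struct Complex {
  T real, imag;
};

// Primary template: real scalars are their own conjugate.
template <typename T>
struct ConjFunctor {
  T operator()(T x) const { return x; }
};

// One partial specialization covers Complex<float> and Complex<double>,
// where two full specializations (one per precision) were needed before.
template <typename T>
struct ConjFunctor<Complex<T>> {
  Complex<T> operator()(Complex<T> x) const {
    return Complex<T>{x.real, -x.imag};  // negate the imaginary part
  }
};

int main() {
  ConjFunctor<Complex<double>> conj;
  Complex<double> z = conj(Complex<double>{1.0, 2.0});
  std::printf("%f %f\n", z.real, z.imag);  // prints 1.000000 -2.000000
  return 0;
}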
34 changes: 18 additions & 16 deletions paddle/fluid/operators/reshape_op.cc
@@ -613,23 +613,24 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t,
     ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel,
     int64_t, ops::ReshapeKernel, bool, ops::ReshapeKernel,
-    paddle::platform::bfloat16, ops::ReshapeKernel, paddle::platform::complex64,
-    ops::ReshapeKernel, paddle::platform::complex128, ops::ReshapeKernel);
+    paddle::platform::bfloat16, ops::ReshapeKernel,
+    paddle::platform::complex<float>, ops::ReshapeKernel,
+    paddle::platform::complex<double>, ops::ReshapeKernel);
 
 REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2_grad, float, ops::ReshapeGradKernel, double,
     ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
     ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool,
     ops::ReshapeGradKernel, paddle::platform::bfloat16, ops::ReshapeGradKernel,
-    paddle::platform::complex64, ops::ReshapeGradKernel,
-    paddle::platform::complex128, ops::ReshapeGradKernel);
+    paddle::platform::complex<float>, ops::ReshapeGradKernel,
+    paddle::platform::complex<double>, ops::ReshapeGradKernel);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(
     reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
     ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
     ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool,
     ops::ReshapeDoubleGradKernel, paddle::platform::bfloat16,
-    ops::ReshapeDoubleGradKernel, paddle::platform::complex64,
-    ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex<float>,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex<double>,
     ops::ReshapeDoubleGradKernel);
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -650,37 +651,38 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                 uint8_t, ops::ReshapeKernel, int64_t,
                                 ops::ReshapeKernel, plat::float16,
                                 ops::ReshapeKernel, bool, ops::ReshapeKernel,
-                                plat::complex64, ops::ReshapeKernel,
-                                plat::complex128, ops::ReshapeKernel);
+                                plat::complex<float>, ops::ReshapeKernel,
+                                plat::complex<double>, ops::ReshapeKernel);
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(
     reshape2_grad, float, ops::ReshapeGradKernel, double,
     ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
     ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16,
-    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex64,
-    ops::ReshapeGradKernel, plat::complex128, ops::ReshapeGradKernel);
+    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex<float>,
+    ops::ReshapeGradKernel, plat::complex<double>, ops::ReshapeGradKernel);
 
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(
     reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
     ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
     ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel,
     plat::float16, ops::ReshapeDoubleGradKernel, bool,
-    ops::ReshapeDoubleGradKernel, plat::complex64, ops::ReshapeDoubleGradKernel,
-    plat::complex128, ops::ReshapeDoubleGradKernel);
+    ops::ReshapeDoubleGradKernel, plat::complex<float>,
+    ops::ReshapeDoubleGradKernel, plat::complex<double>,
+    ops::ReshapeDoubleGradKernel);
 #endif
 
 #ifdef PADDLE_WITH_XPU
 REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int, ops::ReshapeKernel,
                                int64_t, ops::ReshapeKernel, plat::float16,
                                ops::ReshapeKernel, bool, ops::ReshapeKernel,
-                               plat::complex64, ops::ReshapeKernel,
-                               plat::complex128, ops::ReshapeKernel);
+                               plat::complex<float>, ops::ReshapeKernel,
+                               plat::complex<double>, ops::ReshapeKernel);
 REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel, plat::float16,
                                ops::ReshapeGradKernel, bool,
-                               ops::ReshapeGradKernel, plat::complex64,
-                               ops::ReshapeGradKernel, plat::complex128,
+                               ops::ReshapeGradKernel, plat::complex<float>,
+                               ops::ReshapeGradKernel, plat::complex<double>,
                                ops::ReshapeGradKernel);
 #endif
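The registration macros above take an op name followed by alternating (element type, kernel) pairs, so the migration only swaps the type names inside each pair. A rough, hypothetical analogue of that registration shape in plain C++ — Paddle's actual REGISTER_OP_*_KERNEL_FUNCTOR machinery is considerably more involved, and the registry below is invented purely for illustration:

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Invented registry: (op name, dtype name) -> kernel entry point.
using KernelFn = std::function<void()>;
static std::map<std::pair<std::string, std::string>, KernelFn> g_registry;

void RegisterKernel(const std::string& op, const std::string& dtype,
                    KernelFn fn) {
  g_registry[{op, dtype}] = std::move(fn);
}

int main() {
  // Alternating (type, kernel) pairs, as in REGISTER_OP_CPU_KERNEL_FUNCTOR;
  // after this PR the complex entries are spelled complex<float>/<double>.
  RegisterKernel("reshape2", "complex<float>",
                 [] { std::cout << "reshape2 for complex<float>\n"; });
  RegisterKernel("reshape2", "complex<double>",
                 [] { std::cout << "reshape2 for complex<double>\n"; });

  g_registry[{"reshape2", "complex<double>"}]();  // dtype-based dispatch
  return 0;
}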
16 changes: 8 additions & 8 deletions paddle/fluid/operators/slice_op.cc
@@ -436,19 +436,19 @@ REGISTER_OP_CPU_KERNEL(
     ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::SliceKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
@@ -458,9 +458,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
                      paddle::platform::float16>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
-                     paddle::platform::complex64>,
+                     paddle::platform::complex<float>>,
     ops::SliceKernel<paddle::platform::CUDADeviceContext,
-                     paddle::platform::complex128>);
+                     paddle::platform::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     slice_grad,
@@ -471,6 +471,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
                          paddle::platform::float16>,
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex64>,
+                         paddle::platform::complex<float>>,
     ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
-                         paddle::platform::complex128>);
+                         paddle::platform::complex<double>>);