-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[phi] move stft to phi - Step 1 #53517
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
❌ The PR is not created using PR's template. You can refer to this Demo. |
DenseTensor* out) { | ||
using C = phi::dtype::complex<T>; | ||
|
||
out->mutable_data<C>(ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在phi下不可使用mutable_data
,请用Context.Alloc
接口代替
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/stft_kernel_impl.h" | ||
|
||
PD_REGISTER_KERNEL(stft, CPU, ALL_LAYOUT, phi::StftKernel, float, double) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个算子迁移后需要注册输出类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1; | ||
onesided_dims.at(axes.back()) = onesided_axis_size; | ||
phi::DenseTensor onesided_out; | ||
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里还有一个mutable_data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
此算子动态图下是通过python端组合其它API实现,未调用C++端kernel,C++算子仅用于静态图。test_stft_op
单测原本通过paddle.enable_static()
设置只测试静态图,动态图API在test_signal
中进行了测试。本PR将算子迁移到phi后,触发单测框架动态图测试拦截,因而在test_stft_op
中通过设置check_dygraph=False
保持静态图测试。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for check_dygraph
PR types
Others
PR changes
Others
Description
迁移
stft
相关issues