-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Infer Symbolic Shape No.161,162][BUAA] norm
, p_norm
#67136
[Infer Symbolic Shape No.161,162][BUAA] norm
, p_norm
#67136
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
|
int axis = op->attribute<pir::Int32Attribute>("axis").data(); | ||
bool is_test = op->attribute<pir::BoolAttribute>("is_test").data(); | ||
|
||
if (!is_test) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉恭喜遇到一个不明参数,这个参数在api文档里没有,但yaml文件里声明了,一般是用来区分训练和推理的,is_test开着代表是跑推理,麻烦下周会跟同学们分享下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
if (!is_test) { | ||
if (axis < 0) axis += x_shape.size(); | ||
|
||
auto norm_shape = x_shape; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要copy一份
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我后来看了一下好像是因为 x_shape 被加了 const 标记,所以需要 copy 一份才能进行相关修改,泓清老师说这里保留原样不用修改了。
auto x_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto &x_shape = x_shape_or_data.shape(); | ||
auto x_rank = x_shape.size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要滥用auto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
|
||
bool PNormOpInferSymbolicShape(pir::Operation *op, | ||
pir::InferSymbolicShapeContext *infer_context) { | ||
auto x_shape_or_data = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto &
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
"Current Input(X)'s shape is=[%s].", | ||
axis, | ||
x_rank, | ||
x_shape)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这么写有点丑,能不能把下面的axis 取正数放前面来,然后修改一下这两个断言
x_rank, | ||
x_shape)); | ||
|
||
std::vector<symbol::DimExpr> out_dim_vector; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有歧义,out_shape还是out_data
|
||
if (keepdim) { | ||
for (size_t i = 0; i < x_rank; ++i) { | ||
if (static_cast<int>(i) == axis) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i不如直接用int
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
} | ||
} else { | ||
for (size_t i = 0; i < x_rank; ++i) { | ||
if (static_cast<int>(i) != axis) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,非必要的cast
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
LGTM |
…e#67136) * Finished norm and p_norm * Fixed name errors * Removed unused variables * Updated norm and pnorm in unary_infer_sym.cc according to suggested changes * Removed comments * Fixed errors returned by CI * Put some old code back * Use a bool variable to make life easier
PR Category
CINN
PR Types
Others
Description
添加
norm
和p_norm
两个中等难度的算子的符号推导接口实现。