-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Infer Symbolic Shape BUAA No.27】edit_distance #66870
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -637,6 +637,75 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op, | |
return true; | ||
} | ||
|
||
bool EditDistanceOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &hyps_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto &refs_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
||
const auto &hyps_dims = hyps_shape_or_data.shape(); | ||
const auto &refs_dims = refs_shape_or_data.shape(); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
hyps_dims.size(), | ||
2, | ||
phi::errors::InvalidArgument( | ||
"Input(Hyps) must be a 2-D Tensor, but received rank %u.", | ||
hyps_dims.size())); | ||
PADDLE_ENFORCE_EQ( | ||
refs_dims.size(), | ||
2, | ||
phi::errors::InvalidArgument( | ||
"Input(Refs) must be a 2-D Tensor, but received rank %u.", | ||
refs_dims.size())); | ||
PADDLE_ENFORCE_EQ( | ||
hyps_dims[0], | ||
refs_dims[0], | ||
phi::errors::InvalidArgument("The first dimension of Input(Hyps) and " | ||
"Input(Refs) must be the same, " | ||
"but received: %d vs %d.", | ||
hyps_dims[0], | ||
refs_dims[0])); | ||
|
||
bool has_lengths = op->operand_source(2) && op->operand_source(3); | ||
if (has_lengths) { | ||
const auto &hypslength_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
const auto &refslength_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(3)); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
hypslength_shape_or_data.shape()[0], | ||
hyps_dims[0], | ||
phi::errors::InvalidArgument("Input(HypsLength) and Input(Hyps) should " | ||
"have identical first dimension.")); | ||
PADDLE_ENFORCE_EQ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 应该使用 Addequalcstr 添加两个DimExpr之间的约束 |
||
refslength_shape_or_data.shape()[0], | ||
refs_dims[0], | ||
phi::errors::InvalidArgument("Input(RefsLength) and Input(Refs) should " | ||
"have identical first dimension.")); | ||
} else { | ||
PADDLE_ENFORCE_EQ( | ||
hyps_dims[1], | ||
1, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 构造一个值为1的DimExpr,然后使用Addequalcstr 添加两个DimExpr之间的约束 |
||
phi::errors::InvalidArgument("Input(Hyps) must be a 2-D Tensor with " | ||
"the 2nd dimension equal to 1.")); | ||
PADDLE_ENFORCE_EQ( | ||
refs_dims[1], | ||
1, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
phi::errors::InvalidArgument("Input(Refs) must be a 2-D Tensor with " | ||
"the 2nd dimension equal to 1.")); | ||
} | ||
|
||
infer_context->SetShapeOrDataForValue(op->result(0), refs_shape_or_data); | ||
symbol::ShapeOrData<symbol::DimExpr> single_dim_expr({symbol::DimExpr(1)}); | ||
infer_context->SetShapeOrDataForValue( | ||
op->result(1), symbol::ShapeOrDataDimExprs({single_dim_expr})); | ||
|
||
return true; | ||
} | ||
|
||
bool FullWithTensorOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
pir::Value operand_source = op->operand_source(1); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
应该使用 Addequalcstr 添加两个DimExpr之间的约束