Skip to content

Commit

Permalink
【Infer Symbolic Shape BUAA No.84】add numel (#66607)
Browse files Browse the repository at this point in the history
* add numel

* remove unused var
  • Loading branch information
uanu2002 authored Jul 31, 2024
1 parent 2be8b7e commit ac6df27
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,17 @@ bool NonzeroOpInferSymbolicShape(
return true;
}

bool NumelOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
std::vector<symbol::DimExpr> out_shape = {};
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});

return true;
}

bool PadOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
// input(0): Tensor x
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3348,6 +3348,7 @@
data_transform:
skip_transform : x
no_need_buffer : x
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : one_hot
args : (Tensor x, Scalar(int) num_classes)
Expand Down

0 comments on commit ac6df27

Please sign in to comment.