Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape】【BUAA】Add placeholders #67049

Merged
merged 24 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,26 @@ bool BceLoss_OpInferSymbolicShape(
return BceLossOpInferSymbolicShape(op, infer_context);
}

// bool BincountOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool BmmOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// // pass
// return true;
// }

// bool CholeskySolveOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool CtcAlignOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape =
Expand Down Expand Up @@ -286,6 +306,12 @@ bool CrossOpInferSymbolicShape(pir::Operation *op,
return true;
}

// bool DotOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// // pass
// return true;
// }

bool EmbeddingOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const std::vector<symbol::DimExpr> &x_dims =
Expand Down Expand Up @@ -378,6 +404,20 @@ bool FillDiagonalTensor_OpInferSymbolicShape(
return FillDiagonalTensorOpInferSymbolicShape(op, infer_context);
}

// bool FusedSoftmaxMaskOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool GridSampleOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool GatherOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &input_shape_or_data =
Expand Down Expand Up @@ -500,6 +540,13 @@ bool GatherNdOpInferSymbolicShape(
return true;
}

// bool HistogramOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool IndexSampleOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &operand_shape_or_data =
Expand Down Expand Up @@ -536,6 +583,20 @@ bool KronOpInferSymbolicShape(pir::Operation *op,
return true;
}

// bool LstsqOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// {
// // pass
// return true;
// }

// bool MatrixRankTolOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool MaskedSelectOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const std::vector<symbol::DimExpr> &out_dims = [&] {
Expand Down Expand Up @@ -659,6 +720,27 @@ bool MatmulOpInferSymbolicShape(pir::Operation *op,
return true;
}

// bool PullBoxSparseOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool PullGpuPsSparseOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool PullSparseV2OpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool SearchsortedOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// The shape of output is the same as input `values` (op->operand_source(1))
Expand All @@ -670,6 +752,20 @@ bool SearchsortedOpInferSymbolicShape(
return true;
}

// bool SequenceMaskOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool SwigluOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

bool IscloseOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// The shape of output is the same as input `values` (op->operand_source(1))
Expand Down Expand Up @@ -761,6 +857,55 @@ bool TopPSamplingOpInferSymbolicShape(
return true;
}

// bool TdmChildOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool UnpoolOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool YoloBoxOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool YoloBoxHeadOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool GammainccOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool HeavisideOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

// bool NextafterOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,49 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Allclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atan2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CholeskySolve)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dot)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedSoftmaxMask)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(GridSample)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gammaincc)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Histogram)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Heaviside)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AccuracyCheck)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lstsq)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nextafter)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullBoxSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullGpuPsSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullSparseV2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceAs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Searchsorted)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(SequenceMask)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Swiglu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopPSampling)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(TdmChild)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unpool)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBox)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBoxHead)

} // namespace paddle::dialect

Expand Down
Loading