Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape BUAA No.42-44】Add ops #66930

Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,64 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0));

// Validate the bit_length attribute
int bit_length = op->attribute<pir::Int32Attribute>("bit_length").data();
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
phi::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));

// Set the symbolic shape for the output tensors
// Set the shape for the output tensor 'out', same as input tensor 'x'
infer_context->SetShapeOrDataForValue(op->result(0), x_shape);

// Set the shape for the output tensor 'out_scale'
if (op->num_results() > 1 && op->result(1) != nullptr) {
symbol::TensorShapeOrDataDimExprs scalar_shape(
std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
Guanhuachen2003 marked this conversation as resolved.
Show resolved Hide resolved
infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape);
}

// Set the shape for the output tensor 'out_state'
if (op->num_results() > 2 && op->result(2) != nullptr) {
symbol::TensorShapeOrDataDimExprs scalar_shape(
std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
infer_context->SetShapeOrDataForValue(op->result(2), scalar_shape);
}

// Set the shape for the output tensor 'out_accum'
if (op->num_results() > 3 && op->result(3) != nullptr) {
Guanhuachen2003 marked this conversation as resolved.
Show resolved Hide resolved
symbol::TensorShapeOrDataDimExprs scalar_shape(
std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
infer_context->SetShapeOrDataForValue(op->result(3), scalar_shape);
}

return true;
}

bool FakeQuantizeMovingAverageAbsMax_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
}

bool FakeQuantizeDequantizeMovingAverageAbsMaxOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// share the same infermeta FakeQuantOrWithDequantMovingAverageAbsMaxInferMeta
Guanhuachen2003 marked this conversation as resolved.
Show resolved Hide resolved
return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
}

bool FakeQuantizeDequantizeMovingAverageAbsMax_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
}

bool FullWithTensorOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
pir::Value operand_source = op->operand_source(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeMovingAverageAbsMax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeMovingAverageAbsMax_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,34 @@ bool FlattenOpInferSymbolicShape(
return true;
}

bool FakeQuantizeDequantizeAbsMaxOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x_shape =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x_shape_or_data

infer_context->GetShapeOrDataForValue(op->operand_source(0));

// Validate the bit_length attribute
int bit_length = op->attribute<pir::Int32Attribute>("bit_length").data();
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
phi::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));

// Set the symbolic shape for the output tensors
// Set the shape for the output tensor 'out', same as input tensor 'x'
infer_context->SetShapeOrDataForValue(op->result(0), x_shape);

// Set the shape for the output tensor 'out_scale' as a scalar {1}
if (op->num_results() > 1 && op->result(1) != nullptr) {
Guanhuachen2003 marked this conversation as resolved.
Show resolved Hide resolved
symbol::TensorShapeOrDataDimExprs scalar_shape(
std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape);
}

return true;
}

bool Flatten_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return FlattenOpInferSymbolicShape(op, infer_context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeAbsMax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IdentityLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IdentityLoss_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue)
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,7 @@
func : fake_quantize_dequantize_abs_max
data_type : x
backward : fake_quantize_dequantize_abs_max_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : fake_quantize_dequantize_moving_average_abs_max
args : (Tensor x, Tensor in_scale, Tensor in_accum, Tensor in_state, float moving_rate = 0.9, int bit_length = 8, bool is_test = false, int round_type = 1)
Expand All @@ -1713,6 +1714,7 @@
optional : in_accum, in_state, out_state, out_accum
backward : fake_quantize_dequantize_moving_average_abs_max_grad
inplace: (in_scale -> out_scale)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : fake_quantize_moving_average_abs_max
args : (Tensor x, Tensor in_scale, Tensor in_accum, Tensor in_state, float moving_rate = 0.9, int bit_length = 8, bool is_test = false, int round_type = 1)
Expand All @@ -1724,6 +1726,7 @@
data_type : x
optional : in_accum, in_state, out_state, out_accum
inplace: (in_scale -> out_scale)
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : fake_quantize_range_abs_max
args : (Tensor x, Tensor in_scale, Tensor iter, int window_size = 10000, int bit_length = 8, bool is_test = false, int round_type = 1)
Expand Down