diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
index 84c3bc82ee3aa..409fb3b713d3b 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
@@ -782,6 +782,58 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op,
   return true;
 }
 
+// Symbolic shape inference for fake_quantize_moving_average_abs_max.
+// Result 0 takes the shape/data entry of operand 0; every other result is
+// given the one-element shape {1}. The check on 'bit_length' fails with an
+// InvalidArgument error unless 1 <= bit_length <= 16.
+bool FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  const symbol::ShapeOrDataDimExprs &x_shape =
+      infer_context->GetShapeOrDataForValue(op->operand_source(0));
+
+  // Validate the bit_length attribute (an int, read via Int32Attribute)
+  const int bit_length =
+      op->attribute<pir::Int32Attribute>("bit_length").data();
+  PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "'bit_length' should be between 1 and 16, but "
+                        "the received is %d",
+                        bit_length));
+
+  // Set the shape for the output tensor 'out', same as input tensor 'x'
+  infer_context->SetShapeOrDataForValue(op->result(0), x_shape);
+
+  // Create the one-element shape {1} used for the other output tensors
+  symbol::TensorShapeOrDataDimExprs scalar_shape(
+      std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
+
+  // Set the shape for all scalar output tensors: 'out_scale', 'out_state',
+  // 'out_accum'
+  for (size_t i = 1; i < op->num_results(); ++i) {
+    infer_context->SetShapeOrDataForValue(op->result(i), scalar_shape);
+  }
+
+  return true;
+}
+
+// The trailing-underscore op variant reuses the same inference.
+bool FakeQuantizeMovingAverageAbsMax_OpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
+}
+
+// fake_quantize_dequantize_moving_average_abs_max reuses the same inference.
+bool FakeQuantizeDequantizeMovingAverageAbsMaxOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
+}
+
+bool FakeQuantizeDequantizeMovingAverageAbsMax_OpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  return FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape(op, infer_context);
+}
+
 bool FullWithTensorOpInferSymbolicShape(
     pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
   pir::Value operand_source = op->operand_source(1);
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h
index a06c91cfde013..650b374b8cec1 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h
@@ -35,6 +35,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeMovingAverageAbsMax)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeMovingAverageAbsMax_)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax_)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CoalesceTensor)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CoalesceTensor_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor)
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 4b0b2fdd1eec9..0d6ac49f92181 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -943,6 +943,36 @@ bool FlattenOpInferSymbolicShape(
   return true;
 }
 
+// Symbolic shape inference for fake_quantize_dequantize_abs_max.
+// Result 0 ('out') takes the shape/data entry of operand 0 and result 1
+// ('out_scale') is given the one-element shape {1}. The check on 'bit_length'
+// fails with an InvalidArgument error unless 1 <= bit_length <= 16.
+bool FakeQuantizeDequantizeAbsMaxOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  const symbol::ShapeOrDataDimExprs &x_shape =
+      infer_context->GetShapeOrDataForValue(op->operand_source(0));
+
+  // Validate the bit_length attribute (an int, read via Int32Attribute)
+  const int bit_length =
+      op->attribute<pir::Int32Attribute>("bit_length").data();
+  PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
+                    true,
+                    phi::errors::InvalidArgument(
+                        "'bit_length' should be between 1 and 16, but "
+                        "the received is %d",
+                        bit_length));
+
+  // Set the shape for the output tensor 'out', same as input tensor 'x'
+  infer_context->SetShapeOrDataForValue(op->result(0), x_shape);
+
+  // Set the shape for the output tensor 'out_scale' as a scalar {1}
+  symbol::TensorShapeOrDataDimExprs scalar_shape(
+      std::vector<symbol::DimExpr>{symbol::DimExpr(1)});
+  infer_context->SetShapeOrDataForValue(op->result(1), scalar_shape);
+
+  return true;
+}
+
 bool Flatten_OpInferSymbolicShape(
     pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
   return FlattenOpInferSymbolicShape(op, infer_context);
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
index 9148f36e3796a..a9c978c6f63e6 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
@@ -62,6 +62,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonal_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten_)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeDequantizeAbsMax)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Inverse)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(GumbelSoftmax)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(IdentityLoss)
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index c0be672232e0e..7fa78304e5aa0 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ 
-1715,6 +1715,7 @@ func : fake_quantize_dequantize_abs_max data_type : x backward : fake_quantize_dequantize_abs_max_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : fake_quantize_dequantize_moving_average_abs_max args : (Tensor x, Tensor in_scale, Tensor in_accum, Tensor in_state, float moving_rate = 0.9, int bit_length = 8, bool is_test = false, int round_type = 1) @@ -1727,6 +1728,7 @@ optional : in_accum, in_state, out_state, out_accum backward : fake_quantize_dequantize_moving_average_abs_max_grad inplace: (in_scale -> out_scale) + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : fake_quantize_moving_average_abs_max args : (Tensor x, Tensor in_scale, Tensor in_accum, Tensor in_state, float moving_rate = 0.9, int bit_length = 8, bool is_test = false, int round_type = 1) @@ -1738,6 +1740,7 @@ data_type : x optional : in_accum, in_state, out_state, out_accum inplace: (in_scale -> out_scale) + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : fake_quantize_range_abs_max args : (Tensor x, Tensor in_scale, Tensor iter, int window_size = 10000, int bit_length = 8, bool is_test = false, int round_type = 1)