Skip to content

Commit

Permalink
[mlir][complex] Support Fastmath flag in conversion of complex.sqrt t…
Browse files Browse the repository at this point in the history
…o standard (llvm#85019)

When converting complex.sqrt op to standard, we need to keep the fast
math flag given to the op.

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981
  • Loading branch information
Lewuathe authored Mar 14, 2024
1 parent e6048b7 commit 34ba907
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 27 deletions.
15 changes: 8 additions & 7 deletions mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,21 +845,22 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
auto type = cast<ComplexType>(op.getType());
Type elementType = type.getElementType();
Value arg = adaptor.getComplex();
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

Value zero =
b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));

Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());

Value absLhs = b.create<math::AbsFOp>(real);
Value absArg = b.create<complex::AbsOp>(elementType, arg);
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
Value absLhs = b.create<math::AbsFOp>(real, fmf);
Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);

Value half = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 0.5));
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);

Value realIsNegative =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
Expand All @@ -869,7 +870,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
Value resultReal = sqrtAddAbs;

Value imagDivTwoResultReal = b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultReal, resultReal));
imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);

Value negativeResultReal = b.create<arith::NegFOp>(resultReal);

Expand All @@ -882,7 +883,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
resultReal = b.create<arith::SelectOp>(
realIsNegative,
b.create<arith::DivFOp>(
imag, b.create<arith::AddFOp>(resultImag, resultImag)),
imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
resultReal);

Value realIsZero =
Expand Down
146 changes: 126 additions & 20 deletions mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -708,11 +708,60 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
// -----

// CHECK-LABEL: func @complex_sqrt
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg : complex<f32>
return %sqrt : complex<f32>
}

// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] : f32
// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] : f32
// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] : f32
// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] : f32
// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] : f32
// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] : f32
// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] : f32
// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] : f32
// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] : f32
// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] : f32
// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] : f32
// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] : f32
// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] : f32
// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] : f32
// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] : f32
// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] : f32
// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] : f32
// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
// CHECK: return %[[VAR41]] : complex<f32>

// -----

// CHECK-LABEL: func @complex_conj
Expand Down Expand Up @@ -1254,42 +1303,42 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR187:.*]] = complex.re %[[VAR186]] : complex<f32>
// CHECK: %[[VAR188:.*]] = complex.im %[[VAR186]] : complex<f32>
// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] : f32
// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] fastmath<nnan,contract> : f32
// CHECK: %[[CST_7:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST_8:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR190:.*]] = complex.re %[[VAR186]] : complex<f32>
// CHECK: %[[VAR191:.*]] = complex.im %[[VAR186]] : complex<f32>
// CHECK: %[[VAR192:.*]] = arith.cmpf oeq, %[[VAR190]], %[[CST_7]] : f32
// CHECK: %[[VAR193:.*]] = arith.cmpf oeq, %[[VAR191]], %[[CST_7]] : f32
// CHECK: %[[VAR194:.*]] = arith.divf %[[VAR191]], %[[VAR190]] : f32
// CHECK: %[[VAR195:.*]] = arith.mulf %[[VAR194]], %[[VAR194]] : f32
// CHECK: %[[VAR196:.*]] = arith.addf %[[VAR195]], %[[CST_8]] : f32
// CHECK: %[[VAR197:.*]] = math.sqrt %[[VAR196]] : f32
// CHECK: %[[VAR198:.*]] = math.absf %[[VAR190]] : f32
// CHECK: %[[VAR199:.*]] = arith.mulf %[[VAR197]], %[[VAR198]] : f32
// CHECK: %[[VAR200:.*]] = arith.divf %[[VAR190]], %[[VAR191]] : f32
// CHECK: %[[VAR201:.*]] = arith.mulf %[[VAR200]], %[[VAR200]] : f32
// CHECK: %[[VAR202:.*]] = arith.addf %[[VAR201]], %[[CST_8]] : f32
// CHECK: %[[VAR203:.*]] = math.sqrt %[[VAR202]] : f32
// CHECK: %[[VAR204:.*]] = math.absf %[[VAR191]] : f32
// CHECK: %[[VAR205:.*]] = arith.mulf %[[VAR203]], %[[VAR204]] : f32
// CHECK: %[[VAR194:.*]] = arith.divf %[[VAR191]], %[[VAR190]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR195:.*]] = arith.mulf %[[VAR194]], %[[VAR194]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR196:.*]] = arith.addf %[[VAR195]], %[[CST_8]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR197:.*]] = math.sqrt %[[VAR196]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR198:.*]] = math.absf %[[VAR190]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR199:.*]] = arith.mulf %[[VAR197]], %[[VAR198]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR200:.*]] = arith.divf %[[VAR190]], %[[VAR191]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR201:.*]] = arith.mulf %[[VAR200]], %[[VAR200]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR202:.*]] = arith.addf %[[VAR201]], %[[CST_8]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR203:.*]] = math.sqrt %[[VAR202]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR204:.*]] = math.absf %[[VAR191]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR205:.*]] = arith.mulf %[[VAR203]], %[[VAR204]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR206:.*]] = arith.cmpf ogt, %[[VAR190]], %[[VAR191]] : f32
// CHECK: %[[VAR207:.*]] = arith.select %[[VAR206]], %[[VAR199]], %[[VAR205]] : f32
// CHECK: %[[VAR208:.*]] = arith.select %[[VAR193]], %[[VAR198]], %[[VAR207]] : f32
// CHECK: %[[VAR209:.*]] = arith.select %[[VAR192]], %[[VAR204]], %[[VAR208]] : f32
// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[VAR209]] : f32
// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[VAR209]] fastmath<nnan,contract> : f32
// CHECK: %[[CST_9:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] : f32
// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] : f32
// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR213:.*]] = arith.cmpf olt, %[[VAR187]], %[[CST_6]] : f32
// CHECK: %[[VAR214:.*]] = arith.cmpf olt, %[[VAR188]], %[[CST_6]] : f32
// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] : f32
// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] : f32
// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR217:.*]] = arith.negf %[[VAR212]] : f32
// CHECK: %[[VAR218:.*]] = arith.select %[[VAR214]], %[[VAR217]], %[[VAR212]] : f32
// CHECK: %[[VAR219:.*]] = arith.select %[[VAR213]], %[[VAR218]], %[[VAR216]] : f32
// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] : f32
// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] : f32
// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR222:.*]] = arith.select %[[VAR213]], %[[VAR221]], %[[VAR212]] : f32
// CHECK: %[[VAR223:.*]] = arith.cmpf oeq, %[[VAR187]], %[[CST_6]] : f32
// CHECK: %[[VAR224:.*]] = arith.cmpf oeq, %[[VAR188]], %[[CST_6]] : f32
Expand Down Expand Up @@ -1728,3 +1777,60 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>

// -----

// CHECK-LABEL: func @complex_sqrt_with_fmf
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
%sqrt = complex.sqrt %arg fastmath<nnan,contract> : complex<f32>
return %sqrt : complex<f32>
}

// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] fastmath<nnan,contract> : f32
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] fastmath<nnan,contract> : f32
// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
// CHECK: return %[[VAR41]] : complex<f32>

0 comments on commit 34ba907

Please sign in to comment.