Skip to content

Commit

Permalink
AMDGPU: Support llvm.exp10 (llvm#65860)
Browse files Browse the repository at this point in the history
  • Loading branch information
arsenm authored Dec 2, 2023
1 parent 3c86bc0 commit db8b85a
Show file tree
Hide file tree
Showing 4 changed files with 7,641 additions and 16 deletions.
89 changes: 75 additions & 14 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,9 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FLOG2, MVT::f32, Custom);
setOperationAction(ISD::FROUND, {MVT::f32, MVT::f64}, Custom);

setOperationAction({ISD::FLOG, ISD::FLOG10, ISD::FEXP, ISD::FEXP2}, MVT::f32,
Custom);
setOperationAction(
{ISD::FLOG, ISD::FLOG10, ISD::FEXP, ISD::FEXP2, ISD::FEXP10}, MVT::f32,
Custom);

setOperationAction(ISD::FNEARBYINT, {MVT::f16, MVT::f32, MVT::f64}, Custom);

Expand All @@ -352,7 +353,8 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::FLOG2, ISD::FEXP2}, MVT::f16, Custom);
}

setOperationAction({ISD::FLOG10, ISD::FLOG, ISD::FEXP}, MVT::f16, Custom);
setOperationAction({ISD::FLOG10, ISD::FLOG, ISD::FEXP, ISD::FEXP10}, MVT::f16,
Custom);

// FIXME: These IS_FPCLASS vector fp types are marked custom so it reaches
// scalarization code. Can be removed when IS_FPCLASS expand isn't called by
Expand Down Expand Up @@ -457,14 +459,17 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,

for (MVT VT : FloatVectorTypes) {
setOperationAction(
{ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD,
ISD::FCEIL, ISD::FCOS, ISD::FDIV, ISD::FEXP2,
ISD::FEXP, ISD::FLOG2, ISD::FREM, ISD::FLOG,
ISD::FLOG10, ISD::FPOW, ISD::FFLOOR, ISD::FTRUNC,
ISD::FMUL, ISD::FMA, ISD::FRINT, ISD::FNEARBYINT,
ISD::FSQRT, ISD::FSIN, ISD::FSUB, ISD::FNEG,
ISD::VSELECT, ISD::SELECT_CC, ISD::FCOPYSIGN, ISD::VECTOR_SHUFFLE,
ISD::SETCC, ISD::FCANONICALIZE, ISD::FROUNDEVEN},
{ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM,
ISD::FADD, ISD::FCEIL, ISD::FCOS,
ISD::FDIV, ISD::FEXP2, ISD::FEXP,
ISD::FEXP10, ISD::FLOG2, ISD::FREM,
ISD::FLOG, ISD::FLOG10, ISD::FPOW,
ISD::FFLOOR, ISD::FTRUNC, ISD::FMUL,
ISD::FMA, ISD::FRINT, ISD::FNEARBYINT,
ISD::FSQRT, ISD::FSIN, ISD::FSUB,
ISD::FNEG, ISD::VSELECT, ISD::SELECT_CC,
ISD::FCOPYSIGN, ISD::VECTOR_SHUFFLE, ISD::SETCC,
ISD::FCANONICALIZE, ISD::FROUNDEVEN},
VT, Expand);
}

Expand Down Expand Up @@ -1322,6 +1327,7 @@ SDValue AMDGPUTargetLowering::LowerOperation(SDValue Op,
case ISD::FLOG10:
return LowerFLOGCommon(Op, DAG);
case ISD::FEXP:
case ISD::FEXP10:
return lowerFEXP(Op, DAG);
case ISD::FEXP2:
return lowerFEXP2(Op, DAG);
Expand Down Expand Up @@ -1367,6 +1373,7 @@ void AMDGPUTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(Lowered);
return;
case ISD::FEXP:
case ISD::FEXP10:
if (SDValue Lowered = lowerFEXP(SDValue(N, 0), DAG))
Results.push_back(Lowered);
return;
Expand Down Expand Up @@ -2841,12 +2848,66 @@ SDValue AMDGPUTargetLowering::lowerFEXPUnsafe(SDValue X, const SDLoc &SL,
Flags);
}

/// Emit approx-funcs appropriate lowering for exp10. inf/nan should still be
/// handled correctly.
SDValue AMDGPUTargetLowering::lowerFEXP10Unsafe(SDValue X, const SDLoc &SL,
SelectionDAG &DAG,
SDNodeFlags Flags) const {
const EVT VT = X.getValueType();
const unsigned Exp2Op = VT == MVT::f32 ? AMDGPUISD::EXP : ISD::FEXP2;

if (VT != MVT::f32 || !needsDenormHandlingF32(DAG, X, Flags)) {
// exp2(x * 0x1.a92000p+1f) * exp2(x * 0x1.4f0978p-11f);
SDValue K0 = DAG.getConstantFP(0x1.a92000p+1f, SL, VT);
SDValue K1 = DAG.getConstantFP(0x1.4f0978p-11f, SL, VT);

SDValue Mul0 = DAG.getNode(ISD::FMUL, SL, VT, X, K0, Flags);
SDValue Exp2_0 = DAG.getNode(Exp2Op, SL, VT, Mul0, Flags);
SDValue Mul1 = DAG.getNode(ISD::FMUL, SL, VT, X, K1, Flags);
SDValue Exp2_1 = DAG.getNode(Exp2Op, SL, VT, Mul1, Flags);
return DAG.getNode(ISD::FMUL, SL, VT, Exp2_0, Exp2_1);
}

// bool s = x < -0x1.2f7030p+5f;
// x += s ? 0x1.0p+5f : 0.0f;
// exp10 = exp2(x * 0x1.a92000p+1f) *
// exp2(x * 0x1.4f0978p-11f) *
// (s ? 0x1.9f623ep-107f : 1.0f);

EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

SDValue Threshold = DAG.getConstantFP(-0x1.2f7030p+5f, SL, VT);
SDValue NeedsScaling = DAG.getSetCC(SL, SetCCVT, X, Threshold, ISD::SETOLT);

SDValue ScaleOffset = DAG.getConstantFP(0x1.0p+5f, SL, VT);
SDValue ScaledX = DAG.getNode(ISD::FADD, SL, VT, X, ScaleOffset, Flags);
SDValue AdjustedX =
DAG.getNode(ISD::SELECT, SL, VT, NeedsScaling, ScaledX, X);

SDValue K0 = DAG.getConstantFP(0x1.a92000p+1f, SL, VT);
SDValue K1 = DAG.getConstantFP(0x1.4f0978p-11f, SL, VT);

SDValue Mul0 = DAG.getNode(ISD::FMUL, SL, VT, AdjustedX, K0, Flags);
SDValue Exp2_0 = DAG.getNode(Exp2Op, SL, VT, Mul0, Flags);
SDValue Mul1 = DAG.getNode(ISD::FMUL, SL, VT, AdjustedX, K1, Flags);
SDValue Exp2_1 = DAG.getNode(Exp2Op, SL, VT, Mul1, Flags);

SDValue MulExps = DAG.getNode(ISD::FMUL, SL, VT, Exp2_0, Exp2_1, Flags);

SDValue ResultScaleFactor = DAG.getConstantFP(0x1.9f623ep-107f, SL, VT);
SDValue AdjustedResult =
DAG.getNode(ISD::FMUL, SL, VT, MulExps, ResultScaleFactor, Flags);

return DAG.getNode(ISD::SELECT, SL, VT, NeedsScaling, AdjustedResult, MulExps,
Flags);
}

SDValue AMDGPUTargetLowering::lowerFEXP(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
SDLoc SL(Op);
SDValue X = Op.getOperand(0);
SDNodeFlags Flags = Op->getFlags();
const bool IsExp10 = false; // TODO: For some reason exp10 is missing
const bool IsExp10 = Op.getOpcode() == ISD::FEXP10;

if (VT.getScalarType() == MVT::f16) {
// v_exp_f16 (fmul x, log2e)
Expand All @@ -2871,8 +2932,8 @@ SDValue AMDGPUTargetLowering::lowerFEXP(SDValue Op, SelectionDAG &DAG) const {
// TODO: Interpret allowApproxFunc as ignoring DAZ. This is currently copying
// library behavior. Also, is known-not-daz source sufficient?
if (allowApproxFunc(DAG, Flags)) {
assert(!IsExp10 && "todo exp10 support");
return lowerFEXPUnsafe(X, SL, DAG, Flags);
return IsExp10 ? lowerFEXP10Unsafe(X, SL, DAG, Flags)
: lowerFEXPUnsafe(X, SL, DAG, Flags);
}

// Algorithm:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class AMDGPUTargetLowering : public TargetLowering {

SDValue lowerFEXPUnsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
SDNodeFlags Flags) const;
SDValue lowerFEXP10Unsafe(SDValue Op, const SDLoc &SL, SelectionDAG &DAG,
SDNodeFlags Flags) const;
SDValue lowerFEXP(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerCTLZ_CTTZ(SDValue Op, SelectionDAG &DAG) const;
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,8 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
Log2Ops.scalarize(0)
.lower();

auto &LogOps = getActionDefinitionsBuilder({G_FLOG, G_FLOG10, G_FEXP});
auto &LogOps =
getActionDefinitionsBuilder({G_FLOG, G_FLOG10, G_FEXP, G_FEXP10});
LogOps.customFor({S32, S16});
LogOps.clampScalar(0, MinScalarFPTy, S32)
.scalarize(0);
Expand Down Expand Up @@ -2045,6 +2046,7 @@ bool AMDGPULegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
case TargetOpcode::G_FEXP2:
return legalizeFExp2(MI, B);
case TargetOpcode::G_FEXP:
case TargetOpcode::G_FEXP10:
return legalizeFExp(MI, B);
case TargetOpcode::G_FPOW:
return legalizeFPow(MI, B);
Expand Down Expand Up @@ -3466,7 +3468,7 @@ bool AMDGPULegalizerInfo::legalizeFExp(MachineInstr &MI,
LLT Ty = MRI.getType(Dst);
const LLT F16 = LLT::scalar(16);
const LLT F32 = LLT::scalar(32);
const bool IsExp10 = false; // TODO: For some reason exp10 is missing
const bool IsExp10 = MI.getOpcode() == TargetOpcode::G_FEXP10;

if (Ty == F16) {
// v_exp_f16 (fmul x, log2e)
Expand Down
Loading

0 comments on commit db8b85a

Please sign in to comment.