Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: explicitly handle Float16 intrinsics #45249

Merged
merged 2 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
6 changes: 0 additions & 6 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
14 changes: 12 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1523,8 +1523,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
296 changes: 242 additions & 54 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,194 @@ INST_STATISTIC(FCmp);

namespace {

// Fetch the function-level attribute set stored in `Attrs`.
// The accessor is spelled getFnAttrs() when JL_LLVM_VERSION >= 140000 and
// getFnAttributes() on older versions; this wrapper hides the difference.
inline AttributeSet getFnAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION < 140000
    return Attrs.getFnAttributes();
#else
    return Attrs.getFnAttrs();
#endif
}

// Fetch the return-value attribute set stored in `Attrs`.
// The accessor is spelled getRetAttrs() when JL_LLVM_VERSION >= 140000 and
// getRetAttributes() on older versions; this wrapper hides the difference.
inline AttributeSet getRetAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION < 140000
    return Attrs.getRetAttributes();
#else
    return Attrs.getRetAttrs();
#endif
}

// Re-emit the intrinsic call `call` as a call to the same intrinsic ID, but
// overloaded on the types of `args` and returning `RetTy`.
//
//   call  - the intrinsic call being replaced; the new call is inserted
//           immediately before it. `call` itself is neither erased nor RAUW'd.
//   RetTy - return type of the new intrinsic declaration.
//   args  - the full operand list of `call` with replacement values: the call
//           arguments followed by the callee as the last element.
//
// Returns the newly created call. Function and return attributes are carried
// over from `call`; parameter attributes are dropped.
static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
{
    Intrinsic::ID ID = call->getIntrinsicID();
    assert(ID);
    FunctionType *origFTy = call->getFunctionType();
    unsigned numParams = origFTy->getNumParams();
    // `args` carries the callee as a trailing operand, hence strictly greater.
    assert(args.size() > numParams);

    // Signature of the replacement: same arity, parameter types taken from `args`.
    SmallVector<Type*, 8> paramTys;
    paramTys.reserve(numParams);
    for (unsigned idx = 0; idx != numParams; ++idx)
        paramTys.push_back(args[idx]->getType());
    FunctionType *replFTy = FunctionType::get(RetTy, paramTys, origFTy->isVarArg());

    // Recover the list of overloaded types for this intrinsic from the new
    // signature; it determines the mangled name of the declaration.
    SmallVector<Type*, 4> overloadTys;
    {
        SmallVector<Intrinsic::IITDescriptor, 8> descriptors;
        getIntrinsicInfoTableEntries(ID, descriptors);
        // The matchers take this ArrayRef by reference, so the same lvalue is
        // passed to both calls, in this order.
        ArrayRef<Intrinsic::IITDescriptor> remaining = descriptors;
        auto sigMatch = Intrinsic::matchIntrinsicSignature(replFTy, remaining, overloadTys);
        assert(sigMatch == Intrinsic::MatchIntrinsicTypes_Match);
        (void)sigMatch;
        bool varargMismatch = Intrinsic::matchIntrinsicVarArg(replFTy->isVarArg(), remaining);
        assert(!varargMismatch);
        (void)varargMismatch;
    }

    Function *replF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys);
    assert(replF->getFunctionType() == replFTy);
    replF->setCallingConv(call->getCallingConv());

    // Strip the trailing callee operand before building the new call.
    assert(args.back() == call->getCalledFunction());
    CallInst *replCall = CallInst::Create(replF, args.drop_back(), "", call);
    replCall->setTailCallKind(call->getTailCallKind());

    // Keep function and return attributes; drop parameter attributes.
    AttributeList origAttrs = call->getAttributes();
    AttributeSet fnAttrs = getFnAttrs(origAttrs);
    AttributeSet retAttrs = getRetAttrs(origAttrs);
    replCall->setAttributes(AttributeList::get(call->getContext(), fnAttrs, retAttrs, {}));
    return replCall;
}


// Emit IR that performs the cast `opcode` on `V`, producing a value of type
// `DestTy`, by calling the Julia runtime conversion helpers
// (julia__gnu_h2f_ieee, julia__gnu_f2h_ieee, julia__truncdfhf2).
//
//   opcode  - FPExt, FPTrunc, FPToSI, FPToUI, SIToFP or UIToFP; anything else
//             prints a diagnostic and hits llvm_unreachable.
//   V       - the value to convert (scalar or fixed-width vector).
//   DestTy  - the requested result type; must be a vector iff V's type is.
//   builder - where the new instructions are inserted.
//
// Constants are folded when ConstantExpr::getCast can reduce them. Vectors are
// handled by converting each element separately and reassembling the result.
static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
{
    Type *SrcTy = V->getType();
    Type *RetTy = DestTy;

    if (auto *VC = dyn_cast<Constant>(V)) {
        // Input IR frequently contains constant operands such as
        //   fcmp olt half %0, 0xH7C00
        // Prefer folding the constant to the new type over emitting a call.
        if (Constant *Folded = ConstantExpr::getCast(opcode, VC, DestTy, true))
            return Folded;
    }

    assert(SrcTy->isVectorTy() == DestTy->isVectorTy());
    if (SrcTy->isVectorTy()) {
        // Scalarize: convert lane by lane, then rebuild the vector.
        unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements();
        assert(cast<FixedVectorType>(DestTy)->getNumElements() == NumElems && "Mismatched cast");
        Type *ElemTy = RetTy->getScalarType();
        Value *Result = UndefValue::get(DestTy);
        for (unsigned lane = 0; lane < NumElems; ++lane) {
            Value *Idx = builder.getInt32(lane);
            Value *Elem = builder.CreateExtractElement(V, Idx);
            Elem = CreateFPCast(opcode, Elem, ElemTy, builder);
            Result = builder.CreateInsertElement(Result, Elem, Idx);
        }
        return Result;
    }

    auto &M = *builder.GetInsertBlock()->getModule();
    auto &ctx = M.getContext();

    // Prints "<opcode> <src type> to <dest type> is an " ahead of the
    // unreachable message supplied by the caller.
    auto describeCast = [&]() {
        errs() << Instruction::getOpcodeName(opcode) << ' ';
        V->getType()->print(errs());
        errs() << " to ";
        DestTy->print(errs());
        errs() << " is an ";
    };

    // Select the runtime function, and the float type on its non-half side.
    StringRef Name;
    switch (opcode) {
    case Instruction::FPExt:
        // Extension is exact, so a single half->float conversion suffices.
        assert(SrcTy->isHalfTy());
        Name = "julia__gnu_h2f_ieee";
        RetTy = Type::getFloatTy(ctx);
        break;
    case Instruction::FPTrunc:
        // NOTE(review): a reviewer asked whether our variants of these functions
        // could instead be registered with the TLI, which would avoid having to
        // rewrite the code this early.
        assert(DestTy->isHalfTy());
        if (SrcTy->isFloatTy())
            Name = "julia__gnu_f2h_ieee";
        else if (SrcTy->isDoubleTy())
            Name = "julia__truncdfhf2";
        break;
    // Every F16 value fits exactly in an Int32 (-65504 to 65504)
    case Instruction::FPToSI:
    case Instruction::FPToUI:
        assert(SrcTy->isHalfTy());
        Name = "julia__gnu_h2f_ieee";
        RetTy = Type::getFloatTy(ctx);
        break;
    case Instruction::SIToFP:
    case Instruction::UIToFP:
        assert(DestTy->isHalfTy());
        Name = "julia__gnu_f2h_ieee";
        SrcTy = Type::getFloatTy(ctx);
        break;
    default:
        describeCast();
        llvm_unreachable("invalid cast");
    }
    if (Name.empty()) {
        // FPTrunc from a source type that is neither float nor double.
        describeCast();
        llvm_unreachable("illegal cast");
    }

    // Bring the operand to the callee's parameter type. Half values cross the
    // call boundary as i16 bit patterns.
    auto T_int16 = Type::getInt16Ty(ctx);
    if (SrcTy->isHalfTy())
        SrcTy = T_int16;
    switch (opcode) {
    case Instruction::SIToFP:
        V = builder.CreateSIToFP(V, SrcTy);
        break;
    case Instruction::UIToFP:
        V = builder.CreateUIToFP(V, SrcTy);
        break;
    default:
        V = builder.CreateBitCast(V, SrcTy);
        break;
    }

    // Declare (or reuse) the runtime function and call it.
    if (RetTy->isHalfTy())
        RetTy = T_int16;
    FunctionType *FT = FunctionType::get(RetTy, {SrcTy}, false);
    FunctionCallee Callee = M.getOrInsertFunction(Name, FT);
    Value *Converted = builder.CreateCall(Callee, {V});

    // Bring the call result to the type the caller asked for.
    switch (opcode) {
    case Instruction::FPToSI:
        return builder.CreateFPToSI(Converted, DestTy);
    case Instruction::FPToUI:
        return builder.CreateFPToUI(Converted, DestTy);
    case Instruction::FPExt:
        return builder.CreateFPCast(Converted, DestTy);
    default:
        return builder.CreateBitCast(Converted, DestTy);
    }
}

static bool demoteFloat16(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
}
if (!Float16)
continue;

if (auto CI = dyn_cast<CastInst>(&I)) {
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
++TotalChanged;
IRBuilder<> builder(&I);
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
continue;
}

switch (I.getOpcode()) {
case Instruction::FNeg:
case Instruction::FAdd:
Expand All @@ -64,6 +243,9 @@ static bool demoteFloat16(Function &F)
case Instruction::FCmp:
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
if (intrinsic->getIntrinsicID())
break;
continue;
}

Expand All @@ -75,72 +257,78 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
bool OperandsChanged = false;
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType() == T_float16) {
if (Op->getType()->getScalarType()->isHalfTy()) {
++TotalExt;
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
}
Operands[i] = (Op);
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
if (OperandsChanged) {
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
// XXX: this is not correct in general
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
Type *RetTy = I.getType();
if (RetTy->getScalarType()->isHalfTy())
RetTy = RetTy->getWithNewType(T_float32);
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}

Expand Down
Loading