Skip to content

Commit

Permalink
fixup! codegen: explicitly handle Float16 intrinsics
Browse files Browse the repository at this point in the history
Also need to handle vectors, since the vectorizer may have introduced them.

Also change our runtime emulation versions to f32 for consistency.
  • Loading branch information
vtjnash committed May 10, 2022
1 parent d2a7deb commit 09859cc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
54 changes: 41 additions & 13 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ inline AttributeSet getRetAttrs(const AttributeList &Attrs)
#endif
}

static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetType, ArrayRef<Value*> args)
static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
{
Intrinsic::ID ID = call->getIntrinsicID();
assert(ID);
Expand All @@ -73,7 +73,7 @@ static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetType, Arr
SmallVector<Type*, 8> argTys(nargs);
for (unsigned i = 0; i < nargs; i++)
argTys[i] = args[i]->getType();
auto newfType = FunctionType::get(RetType, argTys, oldfType->isVarArg());
auto newfType = FunctionType::get(RetTy, argTys, oldfType->isVarArg());

// Accumulate an array of overloaded types for the given intrinsic
// and compute the new name mangling schema
Expand Down Expand Up @@ -104,6 +104,8 @@ static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetType, Arr

static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
{
Type *SrcTy = V->getType();
Type *RetTy = DestTy;
if (auto *VC = dyn_cast<Constant>(V)) {
// The input IR often has things of the form
// fcmp olt half %0, 0xH7C00
Expand All @@ -113,10 +115,22 @@ static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy,
if (VC)
return VC;
}
assert(SrcTy->isVectorTy() == DestTy->isVectorTy());
if (SrcTy->isVectorTy()) {
unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements();
assert(cast<FixedVectorType>(DestTy)->getNumElements() == NumElems && "Mismatched cast");
Value *NewV = UndefValue::get(DestTy);
RetTy = RetTy->getScalarType();
for (unsigned i = 0; i < NumElems; ++i) {
Value *I = builder.getInt32(i);
Value *Vi = builder.CreateExtractElement(V, I);
Vi = CreateFPCast(opcode, Vi, RetTy, builder);
NewV = builder.CreateInsertElement(NewV, Vi, I);
}
return NewV;
}
auto &M = *builder.GetInsertBlock()->getModule();
auto &ctx = M.getContext();
Type *SrcTy = V->getType();
Type *RetTy = DestTy;
// Pick the Function to call in the Julia runtime
StringRef Name;
switch (opcode) {
Expand Down Expand Up @@ -147,10 +161,21 @@ static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy,
SrcTy = Type::getFloatTy(ctx);
break;
default:
errs() << Instruction::getOpcodeName(opcode) << ' ';
V->getType()->print(errs());
errs() << " to ";
DestTy->print(errs());
errs() << " is an ";
llvm_unreachable("invalid cast");
}
if (Name.empty())
if (Name.empty()) {
errs() << Instruction::getOpcodeName(opcode) << ' ';
V->getType()->print(errs());
errs() << " to ";
DestTy->print(errs());
errs() << " is an ";
llvm_unreachable("illegal cast");
}
// Coerce the source to the required size and type
auto T_int16 = Type::getInt16Ty(ctx);
if (SrcTy->isHalfTy())
Expand Down Expand Up @@ -188,10 +213,10 @@ static bool demoteFloat16(Function &F)
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
bool Float16 = I.getType()->isHalfTy();
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->isHalfTy())
if (Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
}
if (!Float16)
Expand Down Expand Up @@ -232,12 +257,13 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->isHalfTy()) {
if (Op->getType()->getScalarType()->isHalfTy()) {
++TotalExt;
Op = CreateFPCast(Instruction::FPExt, Op, T_float32, builder);
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
}
Operands[i] = (Op);
}
Expand Down Expand Up @@ -285,10 +311,12 @@ static bool demoteFloat16(Function &F)
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
Type *RetType = I.getType();
if (RetType->isHalfTy())
RetType = T_float32;
NewI = replaceIntrinsicWith(intrinsic, RetType, Operands);
// XXX: this is not correct in general
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
Type *RetTy = I.getType();
if (RetTy->getScalarType()->isHalfTy())
RetTy = RetTy->getWithNewType(T_float32);
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
break;
}
abort();
Expand Down
11 changes: 4 additions & 7 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
const unsigned int host_char_bit = 8;

// float16 intrinsics
// TODO: use LLVM's compiler-rt on all platforms (Xcode already links compiler-rt)

#if !defined(_OS_DARWIN_)

static inline float half_to_float(uint16_t ival) JL_NOTSAFEPOINT
{
Expand Down Expand Up @@ -238,7 +235,7 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
//HANDLE_LIBCALL(F16, I128, __fixunshfti)
//HANDLE_LIBCALL(I128, F16, __floattihf)
//HANDLE_LIBCALL(I128, F16, __floatuntihf)
#endif


// run time version of bitcast intrinsic
JL_DLLEXPORT jl_value_t *jl_bitcast(jl_value_t *ty, jl_value_t *v)
Expand Down Expand Up @@ -564,9 +561,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT
}

#define fp_select(a, func) \
sizeof(a) == sizeof(float) ? func##f((float)a) : func(a)
sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a)
#define fp_select2(a, b, func) \
sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b)
sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b)

// fast-function generators //

Expand Down Expand Up @@ -1331,7 +1328,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \
fpiseq_n(float, 32)
fpiseq_n(double, 64)
#define fpiseq(a,b) \
sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)
sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)

bool_fintrinsic(eq,eq_float)
bool_fintrinsic(ne,ne_float)
Expand Down

0 comments on commit 09859cc

Please sign in to comment.