Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make DemoteFloat16 a conditional pass #43327

Merged
merged 10 commits into from
Nov 21, 2022
38 changes: 37 additions & 1 deletion src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Debug.h>
#include "julia.h"
#include "jitlayers.h"

#define DEBUG_TYPE "demote_float16"

Expand All @@ -43,13 +45,47 @@ INST_STATISTIC(FRem);
INST_STATISTIC(FCmp);
#undef INST_STATISTIC

extern JuliaOJIT *jl_ExecutionEngine;

// Build-time answer to whether fp16 is available on the host architecture.
// Returns false when compiled for x86 / x86-64, and an empty Optional on
// every other architecture so that the caller has to decide some other way.
Optional<bool> always_have_fp16() {
#if !defined(_CPU_X86_) && !defined(_CPU_X86_64_)
    // No fixed answer on this architecture.
    return {};
#else
    // x86 doesn't support fp16
    // TODO: update for sapphire rapids when it comes out
    return false;
#endif
}

namespace {

// Reports whether fp16 is treated as available for `caller`.
// demoteFloat16 (defined right after this function) returns early without
// rewriting anything when this returns true.
//
// Resolution order:
//   1. the build-time answer from always_have_fp16(), if it has a value;
//   2. otherwise the "target-features" attribute of `caller`, falling back to
//      the feature string of jl_ExecutionEngine when the attribute is not valid;
//   3. a substring search of that feature string for the fp16 feature flags.
// Returns false when none of the searched flags is present.
bool have_fp16(Function &caller) {
// A set value is fixed per build architecture and short-circuits the lookup below.
auto unconditional = always_have_fp16();
if (unconditional.hasValue())
return unconditional.getValue();

// Prefer the per-function feature string over the engine-wide one.
// NOTE(review): FS is a non-owning StringRef; this assumes that
// getTargetFeatureString() hands back storage that outlives this function
// (i.e. not a temporary std::string) - confirm against jitlayers.h.
Attribute FSAttr = caller.getFnAttribute("target-features");
StringRef FS =
FSAttr.isValid() ? FSAttr.getValueAsString() : jl_ExecutionEngine->getTargetFeatureString();
#if defined(_CPU_AARCH64_)
// AArch64 builds: either of these two feature flags is accepted.
if (FS.find("+fp16fml") != llvm::StringRef::npos || FS.find("+fullfp16") != llvm::StringRef::npos){
return true;
}
#else
// All other builds look for the AVX-512 FP16 flag. On x86 builds
// always_have_fp16() has already returned false above, so this branch is
// not reached there; it also keeps FS from being reported as unused.
if (FS.find("+avx512fp16") != llvm::StringRef::npos){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note https://reviews.llvm.org/D107082 we are not there yet, but LLVM will support _Float16 correctly on SSE2 and above.

Note that the LLVM PR also changes the ABI to match GCC12 and thus is going to break us in fun ways. I haven't found how -fexcess-precision=16 is going to be implemented in LLVM,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added this there because GCC was complaining about FS being unused. That part of the branch doesn't matter for now since x86 is considered as never having Float16 for now.
SSE2 has f16C instructions which are just fast conversions which we might already use, I know we use them on aarch64 at least. The first native operations on float16 are the avx512 ones.

return true;
}
#endif
return false;
}

static bool demoteFloat16(Function &F)
{
if (have_fp16(F))
return false;

auto &ctx = F.getContext();
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
Expand Down
10 changes: 10 additions & 0 deletions src/llvm-multiversioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ using namespace llvm;

extern Optional<bool> always_have_fma(Function&);

extern Optional<bool> always_have_fp16();

namespace {
constexpr uint32_t clone_mask =
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU;
Expand Down Expand Up @@ -480,6 +482,14 @@ uint32_t CloneCtx::collect_func_info(Function &F)
flag |= JL_TARGET_CLONE_MATH;
}
}
if(!always_have_fp16().hasValue()){
for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here
}
}
if (has_veccall && (flag & JL_TARGET_CLONE_SIMD) && (flag & JL_TARGET_CLONE_MATH)) {
return flag;
}
Expand Down
2 changes: 2 additions & 0 deletions src/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ enum {
JL_TARGET_MINSIZE = 1 << 7,
// Clone when the function queries CPU features
JL_TARGET_CLONE_CPU = 1 << 8,
// Clone when the function uses fp16
JL_TARGET_CLONE_FLOAT16 = 1 << 9,
};

#define JL_FEATURE_DEF_NAME(name, bit, llvmver, str) JL_FEATURE_DEF(name, bit, llvmver)
Expand Down
9 changes: 8 additions & 1 deletion src/processor_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,12 +1602,19 @@ static void ensure_jit_target(bool imaging)
auto &t = jit_targets[i];
if (t.en.flags & JL_TARGET_CLONE_ALL)
continue;
auto &features0 = jit_targets[t.base].en.features;
// Always clone when code checks CPU features
t.en.flags |= JL_TARGET_CLONE_CPU;
static constexpr uint32_t clone_fp16[] = {Feature::fp16fml,Feature::fullfp16};
for (auto fe: clone_fp16) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_FLOAT16;
break;
}
}
// The most useful one in general...
t.en.flags |= JL_TARGET_CLONE_LOOP;
#ifdef _CPU_ARM_
auto &features0 = jit_targets[t.base].en.features;
static constexpr uint32_t clone_math[] = {Feature::vfp3, Feature::vfp4, Feature::neon};
for (auto fe: clone_math) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
Expand Down