Skip to content

Commit

Permalink
Make DemoteFloat16 a conditional pass (#43327)
Browse files Browse the repository at this point in the history
* add TargetMachine check

* Add initial float16 multiversioning stuff

* make check more robust and remove x86 check

* move check to inside the pass

* C++ is hard

* Comment out the check because it won't work inside the pass

* whitespace in the comment

* Change the logic not to depend on a TM

* Add preliminary support for x86 test

* Cosmetic changes
  • Loading branch information
gbaraldi authored Nov 21, 2022
1 parent c9eccfc commit d18fd47
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 2 deletions.
38 changes: 37 additions & 1 deletion src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Debug.h>
#include "julia.h"
#include "jitlayers.h"

#define DEBUG_TYPE "demote_float16"

Expand All @@ -43,13 +45,47 @@ INST_STATISTIC(FRem);
INST_STATISTIC(FCmp);
#undef INST_STATISTIC

extern JuliaOJIT *jl_ExecutionEngine;

// Compile-time answer to "does this build's CPU family have native fp16?".
//
// Returns `false` when built for x86 / x86-64, and an empty Optional on every
// other architecture, meaning the answer is not fixed at compile time and the
// caller has to decide some other way.
Optional<bool> always_have_fp16() {
#if !defined(_CPU_X86_) && !defined(_CPU_X86_64_)
    // Nothing can be said statically for this architecture.
    return {};
#else
    // x86 doesn't support fp16
    // TODO: update for sapphire rapids when it comes out
    return false;
#endif
}

namespace {

// Decide whether `caller` is being compiled for a target with native fp16.
//
// A compile-time verdict from always_have_fp16() takes precedence when one is
// available. Otherwise the function's "target-features" attribute is searched;
// if that attribute is not valid, the feature string reported by
// jl_ExecutionEngine is searched instead. AArch64 builds look for "+fp16fml"
// or "+fullfp16"; all other builds look for "+avx512fp16".
bool have_fp16(Function &caller) {
    const auto fixed_answer = always_have_fp16();
    if (fixed_answer.hasValue())
        return fixed_answer.getValue();

    Attribute feature_attr = caller.getFnAttribute("target-features");
    StringRef features = feature_attr.isValid()
        ? feature_attr.getValueAsString()
        : jl_ExecutionEngine->getTargetFeatureString();

    // True when `flag` occurs anywhere in the feature string.
    auto enabled = [&features](StringRef flag) {
        return features.find(flag) != llvm::StringRef::npos;
    };

#if defined(_CPU_AARCH64_)
    return enabled("+fp16fml") || enabled("+fullfp16");
#else
    return enabled("+avx512fp16");
#endif
}

static bool demoteFloat16(Function &F)
{
if (have_fp16(F))
return false;

auto &ctx = F.getContext();
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
Expand Down
10 changes: 10 additions & 0 deletions src/llvm-multiversioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ using namespace llvm;

extern Optional<bool> always_have_fma(Function&);

extern Optional<bool> always_have_fp16();

namespace {
constexpr uint32_t clone_mask =
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU;
Expand Down Expand Up @@ -480,6 +482,14 @@ uint32_t CloneCtx::collect_func_info(Function &F)
flag |= JL_TARGET_CLONE_MATH;
}
}
if(!always_have_fp16().hasValue()){
for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
flag |= JL_TARGET_CLONE_FLOAT16;
}
// A check for BFloat16 can be added here once Julia supports that type
}
}
if (has_veccall && (flag & JL_TARGET_CLONE_SIMD) && (flag & JL_TARGET_CLONE_MATH)) {
return flag;
}
Expand Down
2 changes: 2 additions & 0 deletions src/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ enum {
JL_TARGET_MINSIZE = 1 << 7,
// Clone when the function queries CPU features
JL_TARGET_CLONE_CPU = 1 << 8,
// Clone when the function uses fp16
JL_TARGET_CLONE_FLOAT16 = 1 << 9,
};

#define JL_FEATURE_DEF_NAME(name, bit, llvmver, str) JL_FEATURE_DEF(name, bit, llvmver)
Expand Down
9 changes: 8 additions & 1 deletion src/processor_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,12 +1602,19 @@ static void ensure_jit_target(bool imaging)
auto &t = jit_targets[i];
if (t.en.flags & JL_TARGET_CLONE_ALL)
continue;
auto &features0 = jit_targets[t.base].en.features;
// Always clone when code checks CPU features
t.en.flags |= JL_TARGET_CLONE_CPU;
static constexpr uint32_t clone_fp16[] = {Feature::fp16fml,Feature::fullfp16};
for (auto fe: clone_fp16) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_FLOAT16;
break;
}
}
// The most useful one in general...
t.en.flags |= JL_TARGET_CLONE_LOOP;
#ifdef _CPU_ARM_
auto &features0 = jit_targets[t.base].en.features;
static constexpr uint32_t clone_math[] = {Feature::vfp3, Feature::vfp4, Feature::neon};
for (auto fe: clone_math) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
Expand Down

0 comments on commit d18fd47

Please sign in to comment.