Skip to content

Commit

Permalink
codegen: truncate Float16 vector ops also (JuliaLang#46130)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash authored and Francesco Fucci committed Aug 11, 2022
1 parent 2b532c2 commit 02940c4
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 62 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ configure-y: | $(BUILDDIRMAKE)
configure:
ifeq ("$(origin O)", "command line")
@if [ "$$(ls '$(BUILDROOT)' 2> /dev/null)" ]; then \
echo 'WARNING: configure called on non-empty directory $(BUILDROOT)'; \
printf $(WARNCOLOR)'WARNING: configure called on non-empty directory'$(ENDCOLOR)' %s\n' '$(BUILDROOT)'; \
read -p "Proceed [y/n]? " answer; \
else \
answer=y;\
fi; \
[ $$answer = 'y' ] && $(MAKE) configure-$$answer
[ "y$$answer" = yy ] && $(MAKE) configure-$$answer
else
$(error "cannot rerun configure from within a build directory")
endif
Expand Down
120 changes: 63 additions & 57 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,21 @@ namespace {
static bool demoteFloat16(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// skip instructions that neither produce nor consume a Float16 (scalar or vector) value
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
}
if (!Float16)
continue;

switch (I.getOpcode()) {
case Instruction::FNeg:
case Instruction::FAdd:
Expand All @@ -64,6 +73,7 @@ static bool demoteFloat16(Function &F)
case Instruction::FCmp:
break;
default:
// TODO: Do calls to llvm.fma.f16 need to go to f64 to be correct?
continue;
}

Expand All @@ -75,72 +85,68 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
bool OperandsChanged = false;
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType() == T_float16) {
if (Op->getType()->getScalarType()->isHalfTy()) {
++TotalExt;
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
}
Operands[i] = (Op);
Operands[i] = Op;
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
if (OperandsChanged) {
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}

Expand Down
24 changes: 21 additions & 3 deletions test/llvmpasses/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,29 @@ include $(JULIAHOME)/Make.inc

check: .

TESTS = $(patsubst $(SRCDIR)/%,%,$(wildcard $(SRCDIR)/*.jl $(SRCDIR)/*.ll))
TESTS_ll := $(filter-out update-%,$(patsubst $(SRCDIR)/%,%,$(wildcard $(SRCDIR)/*.ll)))
TESTS_jl := $(patsubst $(SRCDIR)/%,%,$(wildcard $(SRCDIR)/*.jl))
TESTS := $(TESTS_ll) $(TESTS_jl)

. $(TESTS):
PATH=$(build_bindir):$(build_depsbindir):$$PATH \
LD_LIBRARY_PATH=${build_libdir}:$$LD_LIBRARY_PATH \
$(build_depsbindir)/lit/lit.py -v $(addprefix $(SRCDIR)/,$@)
$(build_depsbindir)/lit/lit.py -v "$(addprefix $(SRCDIR)/,$@)"

.PHONY: $(TESTS) check all .
# update-<name>.ll: regenerate the autogenerated assertions of one .ll test.
# Prints a note about needing the LLVM sources, then asks for confirmation and
# stops unless the reply is exactly `y`. The test from $(SRCDIR) is copied to a
# scratch file named after the target, with every %shlibext replaced by
# .$(SHLIB_EXT); update_test_checks.py from the LLVM source cache rewrites that
# scratch file, and the result is then moved back over the test in $(SRCDIR).
$(addprefix update-,$(TESTS_ll)):
@echo 'NOTE: This requires a llvm source files locally, such as via `make -C deps USE_BINARYBUILDER_LLVM=0 DEPS_GIT=llvm checkout-llvm`'
@read -p "$$(printf $(WARNCOLOR)'This will directly modify %s, are you sure you want to proceed? '$(ENDCOLOR) '$@')" REPLY && [ yy = "y$$REPLY" ]
sed -e 's/%shlibext/.$(SHLIB_EXT)/g' < "$(@:update-%=$(SRCDIR)/%)" > "$@"
PATH=$(build_bindir):$(build_depsbindir):$$PATH \
LD_LIBRARY_PATH=${build_libdir}:$$LD_LIBRARY_PATH \
$(JULIAHOME)/deps/srccache/llvm/llvm/utils/update_test_checks.py "$@" \
--preserve-names
mv "$@" "$(@:update-%=$(SRCDIR)/%)"

# update-help: run LLVM's update_test_checks.py with --help, using the same
# PATH and LD_LIBRARY_PATH setup as the update-<name>.ll targets.
update-help:
PATH=$(build_bindir):$(build_depsbindir):$$PATH \
LD_LIBRARY_PATH=${build_libdir}:$$LD_LIBRARY_PATH \
$(JULIAHOME)/deps/srccache/llvm/llvm/utils/update_test_checks.py \
--help

.PHONY: $(TESTS) $(addprefix update-,$(TESTS_ll)) check all .
64 changes: 64 additions & 0 deletions test/llvmpasses/float16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p
; RUN: opt -enable-new-pm=0 -load libjulia-codegen%shlibext -DemoteFloat16 -S %s | FileCheck %s
; RUN: opt -enable-new-pm=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s

; Input for the DemoteFloat16 pass test: a mix of scalar half arithmetic and one
; <2 x half> vector add. The expected output below shows every half
; fadd/fmul/fdiv (the vector fadd included) with its operands widened by fpext
; to float and its result narrowed by fptrunc back to half, while the
; insertelement/extractelement instructions appear unchanged.
define half @demotehalf_test(half %a, half %b) {
; CHECK-LABEL: @demotehalf_test(
; CHECK-NEXT: top:
; CHECK-NEXT: %0 = fpext half %a to float
; CHECK-NEXT: %1 = fpext half %b to float
; CHECK-NEXT: %2 = fadd float %0, %1
; CHECK-NEXT: %3 = fptrunc float %2 to half
; CHECK-NEXT: %4 = fpext half %3 to float
; CHECK-NEXT: %5 = fpext half %b to float
; CHECK-NEXT: %6 = fadd float %4, %5
; CHECK-NEXT: %7 = fptrunc float %6 to half
; CHECK-NEXT: %8 = fpext half %7 to float
; CHECK-NEXT: %9 = fpext half %b to float
; CHECK-NEXT: %10 = fadd float %8, %9
; CHECK-NEXT: %11 = fptrunc float %10 to half
; CHECK-NEXT: %12 = fpext half %11 to float
; CHECK-NEXT: %13 = fpext half %b to float
; CHECK-NEXT: %14 = fmul float %12, %13
; CHECK-NEXT: %15 = fptrunc float %14 to half
; CHECK-NEXT: %16 = fpext half %15 to float
; CHECK-NEXT: %17 = fpext half %b to float
; CHECK-NEXT: %18 = fdiv float %16, %17
; CHECK-NEXT: %19 = fptrunc float %18 to half
; CHECK-NEXT: %20 = insertelement <2 x half> undef, half %a, i32 0
; CHECK-NEXT: %21 = insertelement <2 x half> %20, half %b, i32 1
; CHECK-NEXT: %22 = insertelement <2 x half> undef, half %b, i32 0
; CHECK-NEXT: %23 = insertelement <2 x half> %22, half %b, i32 1
; CHECK-NEXT: %24 = fpext <2 x half> %21 to <2 x float>
; CHECK-NEXT: %25 = fpext <2 x half> %23 to <2 x float>
; CHECK-NEXT: %26 = fadd <2 x float> %24, %25
; CHECK-NEXT: %27 = fptrunc <2 x float> %26 to <2 x half>
; CHECK-NEXT: %28 = extractelement <2 x half> %27, i32 0
; CHECK-NEXT: %29 = extractelement <2 x half> %27, i32 1
; CHECK-NEXT: %30 = fpext half %28 to float
; CHECK-NEXT: %31 = fpext half %29 to float
; CHECK-NEXT: %32 = fadd float %30, %31
; CHECK-NEXT: %33 = fptrunc float %32 to half
; CHECK-NEXT: %34 = fpext half %33 to float
; CHECK-NEXT: %35 = fpext half %19 to float
; CHECK-NEXT: %36 = fadd float %34, %35
; CHECK-NEXT: %37 = fptrunc float %36 to half
; CHECK-NEXT: ret half %37
;
top:
; scalar chain on half: three adds, then a multiply and a divide
%0 = fadd half %a, %b
%1 = fadd half %0, %b
%2 = fadd half %1, %b
%3 = fmul half %2, %b
%4 = fdiv half %3, %b
; build the two <2 x half> vectors [%a, %b] and [%b, %b]
%5 = insertelement <2 x half> undef, half %a, i32 0
%6 = insertelement <2 x half> %5, half %b, i32 1
%7 = insertelement <2 x half> undef, half %b, i32 0
%8 = insertelement <2 x half> %7, half %b, i32 1
; vector add on <2 x half>
%9 = fadd <2 x half> %6, %8
; extract both lanes and add them as scalars
%10 = extractelement <2 x half> %9, i32 0
%11 = extractelement <2 x half> %9, i32 1
%12 = fadd half %10, %11
; combine the vector-derived sum with the result of the scalar chain
%13 = fadd half %12, %4
ret half %13
}

0 comments on commit 02940c4

Please sign in to comment.