
Adapt to upstream changes wrt. native support for BFloat16 #51

Merged: 7 commits into JuliaMath:master from codegen, Nov 2, 2023

Conversation

maleadt (Member) commented Oct 4, 2023

This PR adapts BFloat16s.jl to JuliaLang/julia#51470, where I'm adding native support for BFloat16s to Julia (using the bfloat type in LLVM). I decided to keep as much functionality as possible in this package, so Base only defines Core.BFloat16 and the necessary codegen support.
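As a rough illustration of that split (a sketch only, not necessarily the exact code in this PR), the package can simply alias the native type when it exists and keep the old emulated definition as a fallback:

@static if isdefined(Core, :BFloat16)
    # Julia 1.11+: reuse the native type, which has codegen support
    const BFloat16 = Core.BFloat16
else
    # older Julia: keep the emulated 16-bit primitive type
    primitive type BFloat16 <: AbstractFloat 16 end
end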

The main benefit of this change is that we now emit drastically simpler IR, and rely on LLVM to lower it to something that the hardware supports. For example:

julia> test(x::T) where T = T(2) * x + one(T)

julia> test(BFloat16(1))
BFloat16(3.0)

Before this PR:

julia> @code_llvm debuginfo=:none test(BFloat16(1))
define i16 @julia_test_556(i16 zeroext %0) #0 {
top:
  %1 = zext i16 %0 to i32
  %2 = shl nuw i32 %1, 16
  %bitcast_coercion = bitcast i32 %2 to float
  %3 = fmul float %bitcast_coercion, 2.000000e+00
  %4 = fcmp ord float %3, 0.000000e+00
  br i1 %4, label %L38, label %L75

L38:                                              ; preds = %top
  %bitcast_coercion4 = bitcast float %3 to i32
  %5 = lshr i32 %bitcast_coercion4, 16
  %.op5 = and i32 %5, 1
  %6 = add i32 %bitcast_coercion4, 32767
  %7 = add i32 %6, %.op5
  %8 = and i32 %7, -65536
  %phi.cast = bitcast i32 %8 to float
  %phi.bo = fadd float %phi.cast, 1.000000e+00
  %9 = fcmp ord float %phi.bo, 0.000000e+00
  br i1 %9, label %L54, label %L75

L54:                                              ; preds = %L38
  %bitcast_coercion3 = bitcast float %phi.bo to i32
  %10 = lshr i32 %bitcast_coercion3, 16
  %11 = and i32 %10, 1
  %narrow = add nuw nsw i32 %11, 32767
  %12 = zext i32 %narrow to i64
  %13 = zext i32 %bitcast_coercion3 to i64
  %14 = add nuw nsw i64 %12, %13
  %15 = lshr i64 %14, 16
  %16 = trunc i64 %15 to i16
  br label %L75

L75:                                              ; preds = %L54, %L38, %top
  %value_phi2 = phi i16 [ %16, %L54 ], [ 32704, %L38 ], [ 32704, %top ]
  ret i16 %value_phi2
}

julia> @code_native debuginfo=:none test(BFloat16(1))
julia_test_560:                         # @julia_test_560
# %bb.0:                                # %top
	shl	edi, 16
	mov	ax, 32704
	vmovd	xmm0, edi
	vaddss	xmm0, xmm0, xmm0
	vucomiss	xmm0, xmm0
	jp	.LBB0_3
# %bb.1:                                # %L38
	push	rbp
	vmovd	ecx, xmm0
	movabs	rdx, offset .LCPI0_0
	mov	rbp, rsp
	bt	ecx, 16
	adc	ecx, 32767
	and	ecx, -65536
	vmovd	xmm0, ecx
	vaddss	xmm0, xmm0, dword ptr [rdx]
	vucomiss	xmm0, xmm0
	pop	rbp
	jp	.LBB0_3
# %bb.2:                                # %L54
	vmovd	eax, xmm0
	bt	eax, 16
	adc	eax, 32767
	shr	eax, 16
.LBB0_3:                                # %L75
                                        # kill: def $ax killed $ax killed $eax
	ret

Using this PR, on JuliaLang/julia#51470:

julia> @code_llvm debuginfo=:none test(BFloat16(1))
; Function Signature: test(Core.BFloat16)
define bfloat @julia_test_2911(bfloat %"x::BFloat16") #0 {
top:
  %0 = fpext bfloat %"x::BFloat16" to float
  %1 = fmul float %0, 2.000000e+00
  %2 = fptrunc float %1 to bfloat
  %3 = fpext bfloat %2 to float
  %4 = fadd float %3, 1.000000e+00
  %5 = fptrunc float %4 to bfloat
  ret bfloat %5
}

julia> @code_native debuginfo=:none test(BFloat16(1))
julia_test_3001:                        # @julia_test_3001
; Function Signature: test(Core.BFloat16)
# %bb.0:                                # %top
	#DEBUG_VALUE: test:x <- $xmm0
	push	rbp
	mov	rbp, rsp
	push	rbx
	sub	rsp, 8
	vmovd	eax, xmm0
	movabs	rbx, offset __truncsfbf2
	shl	eax, 16
	vmovd	xmm0, eax
	vaddss	xmm0, xmm0, xmm0
	call	rbx
	vmovd	eax, xmm0
	movabs	rcx, offset .LCPI0_0
	shl	eax, 16
	vmovd	xmm0, eax
	vaddss	xmm0, xmm0, dword ptr [rcx]
	call	rbx
	add	rsp, 8
	pop	rbx
	pop	rbp
	ret

So the LLVM IR is much simpler, while the native code is (as expected) similar in complexity.

Performance is hard to compare for such simple operations, but representing BFloat16s natively should make it possible for LLVM to optimize them, and also select better instructions when possible. For example, with a CPU supporting AVX512BF16 and LLVM 17, we compile:

define <16 x bfloat> @trunc(<16 x float>) {
    %2 = fptrunc <16 x float> %0 to <16 x bfloat>
    ret <16 x bfloat> %2
}

to:

trunc:                                  # @trunc
        vcvtneps2bf16   ymm0, zmm0
        ret

So this will make it possible to use BFloat16s.jl with our vectorization packages (by using NTuple{16,Core.VecElement{BFloat16}}, which now lowers to <16 x bfloat>).
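For illustration, here is a hedged sketch of how such a vector truncation could be written from Julia with Base.llvmcall, mirroring the IR above. The names F32x16, BF16x16, and vtrunc are made up for this example, and it assumes Julia 1.11+ with Core.BFloat16 plus an LLVM build that can lower bfloat vectors for the target:

using BFloat16s

const F32x16  = NTuple{16, Core.VecElement{Float32}}
const BF16x16 = NTuple{16, Core.VecElement{BFloat16}}   # lowers to <16 x bfloat>

# 16-wide Float32 -> BFloat16 truncation; on AVX512BF16 hardware LLVM can
# select this to a single vcvtneps2bf16, as shown above
@inline vtrunc(x::F32x16) = Base.llvmcall(
    """
    %r = fptrunc <16 x float> %0 to <16 x bfloat>
    ret <16 x bfloat> %r
    """,
    BF16x16, Tuple{F32x16}, x)

In practice a vectorization package would generate such calls itself; the point here is only that the bfloat vector type is now expressible from Julia.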


This PR also switches the significand implementation, as the old one contained undefined behavior (for one(BFloat16), isig was Int16(0)). The new implementation is copied from Base.
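For reference, here is a minimal sketch of what the Base-style significand looks like when adapted by hand to the BFloat16 layout (1 sign, 8 exponent, 7 fraction bits); the masks and the helper name are spelled out for illustration and need not match the merged code exactly:

using BFloat16s

const SIGN_MASK = 0x8000
const EXP_MASK  = 0x7f80
const EXP_ONE   = 0x3f80   # bit pattern of one(BFloat16): the exponent of 1.0

function significand_sketch(x::BFloat16)
    xu = reinterpret(UInt16, x)
    xs = xu & ~SIGN_MASK
    xs >= EXP_MASK && return x                 # NaN or Inf: return as-is
    if xs <= 0x007f                            # subnormal or zero
        xs == 0x0000 && return x               # ±0
        m = leading_zeros(xs) - 8              # shift needed to normalize the fraction
        xu = (xs << m) | (xu & SIGN_MASK)
    end
    return reinterpret(BFloat16, (xu & ~EXP_MASK) | EXP_ONE)
end

Here one(BFloat16) simply goes through the normal branch and returns 1.0.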

Closes #51

codecov bot commented Oct 4, 2023

Codecov Report

Attention: 54 lines in your changes are missing coverage. Please review.

Comparison is base (a42c4fa) 65.41% compared to head (75a6d23) 22.22%.

❗ Current head 75a6d23 differs from pull request most recent head ca97442. Consider uploading reports for the commit ca97442 to get more accurate results.

Additional details and impacted files
@@             Coverage Diff             @@
##           master      #51       +/-   ##
===========================================
- Coverage   65.41%   22.22%   -43.20%     
===========================================
  Files           3        3               
  Lines         133      171       +38     
===========================================
- Hits           87       38       -49     
- Misses         46      133       +87     
Files Coverage Δ
src/bfloat16.jl 23.56% <23.94%> (-48.71%) ⬇️


maleadt (Member, Author) commented Oct 6, 2023

Interestingly, even though bfloat was added to LLVM 11 specifically to support ARM intrinsics, storage-level support on AArch64 is really limited. For example, merely synthesizing a constant results in a selection error:

target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-linux-none"

define bfloat @julia_BFloat16_2304() {
top:
  ret bfloat 0xR0000
}
LLVM ERROR: Cannot select: 0x562986b35f10: bf16 = ConstantFP<APFloat(0)>

This is fixed on LLVM 17, but aarch64 still lacks arithmetic-level support there:

; ModuleID = 'f'
source_filename = "f"
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
target triple = "aarch64-none-eabi"

define float @julia_f_572(bfloat %"x::BFloat16") {
top:
  %0 = fpext bfloat %"x::BFloat16" to float
  ret float %0
}
LLVM ERROR: Cannot select: 0x558b0b8a7ac0: f32 = fp_extend 0x558b0b8a7a50
  0x558b0b8a7a50: bf16,ch = CopyFromReg 0x558b0b82d970, Register:bf16 %0
    0x558b0b8a79e0: bf16 = Register %0

On x86, both of these work on LLVM 15+, which is the lower bound for this feature (as we only added Core.BFloat16 to Julia 1.11).

cc @vchuravy

maleadt force-pushed the codegen branch 4 times, most recently from afb5922 to 35c4799, on October 10, 2023 at 13:26
maleadt (Member, Author) commented Oct 10, 2023

Well, this is weird. I cannot reproduce the CI failure on any system of mine. I thought it was ABI-related, but it looks like LLVM somehow materializes a wrong constant here. BFloat16(1f0) is 0x3f80, and when we ask LLVM to emit IR that truncates 1f0 to BFloat16 (which gets constant-propagated by the IRBuilder) we do get that value, but not when Julia does the truncation:

bfloat 0xR3F80

vs

ret bfloat 0xR3C00

maleadt force-pushed the codegen branch 2 times, most recently from 07a1dec to d2f8bbd, on October 10, 2023 at 13:59
maleadt (Member, Author) commented Oct 10, 2023

Alright, found something that reproduces locally:

julia> f() = Core.Intrinsics.fptrunc(Core.BFloat16, 1f0)
f (generic function with 1 method)

julia> f()
Core.BFloat16(0x3c00)

julia> fptrunc(x) = Core.Intrinsics.fptrunc(Core.BFloat16, x)
fptrunc (generic function with 1 method)

julia> h() = fptrunc(1f0)
h (generic function with 1 method)

julia> h()
Core.BFloat16(0x3f80)

maleadt marked this pull request as ready for review on October 26, 2023 at 14:34
maleadt force-pushed the codegen branch 3 times, most recently from 75a6d23 to ca97442, on November 2, 2023 at 07:29
maleadt merged commit 730511b into JuliaMath:master on Nov 2, 2023
48 of 54 checks passed
maleadt deleted the codegen branch on November 2, 2023 at 08:39