Skip to content

Commit

Permalink
llvm: use builtin llvm.uadd.with.overflow.iXXX to try to generate opt…
Browse files Browse the repository at this point in the history
…imal code (and fail for i320 and i384 llvm/llvm-project#103717)
  • Loading branch information
mratsim committed Aug 14, 2024
1 parent b415418 commit 480ede5
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 56 deletions.
55 changes: 19 additions & 36 deletions constantine/math_compiler/impl_fields_sat.nim
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ import

const SectionName = "ctt.fields"

proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array, carry: ValueRef) =
proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M, carry: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand All @@ -87,28 +87,22 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arr
##
## To be used when the final subtraction can
## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
let t = asy.makeArray(fd.fieldTy)

# Mask: contains 0xFFFF or 0x0000
let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry)

# Now subtract the modulus, and test a < M
# (underflow) with the last borrow
var b: ValueRef
(b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1)
for i in 1 ..< fd.numWords:
(b, t[i]) = asy.br.subborrow(a[i], M[i], b)
let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `scratch`
(b, _) = asy.br.subborrow(mask, fd.zero, b)
# smaller than the modulus and we don't need `a-M`
let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow)

for i in 0 ..< fd.numWords:
t[i] = asy.br.select(b, a[i], t[i])
let t = asy.br.select(ctl, a, a_minus_M)
asy.store(rr, t)

asy.store(r, t)

proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) =
proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand All @@ -117,28 +111,23 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arra
##
## To be used when the modulus does not use the full bitwidth of the storing words
## (say using 255 bits for the modulus out of 256 available in words)
let t = asy.makeArray(fd.fieldTy)

# Now subtract the modulus, and test a < M
# (underflow) with the last borrow
var b: ValueRef
(b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1)
for i in 1 ..< fd.numWords:
(b, t[i]) = asy.br.subborrow(a[i], M[i], b)

# If it underflows here a was smaller than the modulus, which is what we want
for i in 0 ..< fd.numWords:
t[i] = asy.br.select(b, a[i], t[i])
let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M)

asy.store(r, t)
# If it underflows here, it means that it was
# smaller than the modulus and we don't need `a-M`
let t = asy.br.select(borrow, a, a_minus_M)
asy.store(rr, t)

proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
## Generate an optimized modular addition kernel
## with parameters `a, b, modulus: Limbs -> Limbs`

let red = if fd.spareBits >= 1: "noo"
else: "mayo"
let name = "_modadd_" & red & "_u" & $fd.w & "x" & $fd.numWords
let name = "_modadd_" & red & ".u" & $fd.w & "x" & $fd.numWords
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, a, b, M]),
Expand All @@ -149,21 +138,15 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
let (rr, aa, bb, MM) = llvmParams

# Pointers are opaque in LLVM now
let r = asy.asArray(rr, fd.fieldTy)
let a = asy.asArray(aa, fd.fieldTy)
let b = asy.asArray(bb, fd.fieldTy)
let M = asy.asArray(MM, fd.fieldTy)

let apb = asy.makeArray(fd.fieldTy)
var c: ValueRef
(c, apb[0]) = asy.br.addcarry(a[0], b[0], fd.zero_i1)
for i in 1 ..< fd.numWords:
(c, apb[i]) = asy.br.addcarry(a[i], b[i], c)
let a = asy.load2(fd.intBufTy, aa, "a")
let b = asy.load2(fd.intBufTy, bb, "b")
let M = asy.load2(fd.intBufTy, MM, "M")

let (carry, apb) = asy.br.llvm_add_overflow(a, b)
if fd.spareBits >= 1:
asy.finalSubNoOverflow(fd, r, apb, M)
asy.finalSubNoOverflow(fd, rr, apb, M)
else:
asy.finalSubMayOverflow(fd, r, apb, M, c)
asy.finalSubMayOverflow(fd, rr, apb, M, carry)

asy.br.retVoid()

Expand Down
8 changes: 8 additions & 0 deletions constantine/math_compiler/ir.nim
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ proc configureField*(ctx: ContextRef,
result.spareBits = uint8(next_multiple_wordsize - modBits)

proc definePrimitives*(asy: Assembler_LLVM, fd: FieldDescriptor) =
asy.ctx.def_llvm_add_overflow(asy.module, fd.wordTy)
asy.ctx.def_llvm_add_overflow(asy.module, fd.intBufTy)
asy.ctx.def_llvm_sub_overflow(asy.module, fd.wordTy)
asy.ctx.def_llvm_sub_overflow(asy.module, fd.intBufTy)

asy.ctx.def_addcarry(asy.module, asy.ctx.int1_t(), fd.wordTy)
asy.ctx.def_subborrow(asy.module, asy.ctx.int1_t(), fd.wordTy)

Expand Down Expand Up @@ -524,3 +529,6 @@ proc callFn*(

template load2*(asy: Assembler_LLVM, ty: TypeRef, `ptr`: ValueRef, name: cstring = ""): ValueRef =
  ## Convenience forwarder: calls `load2(ty, ptr, name)` on the assembler's
  ## builder `asy.br` and yields its result.
  asy.br.load2(ty, `ptr`, name)

template store*(asy: Assembler_LLVM, dst, src: ValueRef, name: cstring = "") =
  ## Convenience forwarder to the `store` of the assembler's builder `asy.br`.
  ##
  ## Mind the argument order: this template takes the destination first,
  ## while the builder call below receives `(src, dst)`.
  ##
  ## `name` is accepted but is not forwarded to the builder call.
  asy.br.store(src, dst)
86 changes: 70 additions & 16 deletions constantine/platforms/llvm/super_instructions.nim
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ proc hi(bld: BuilderRef, val: ValueRef, baseTy: TypeRef, oversize: uint32, prefi

const SectionName = "ctt.superinstructions"

proc getInstrName(baseName: string, ty: TypeRef): string =
proc getInstrName(baseName: string, ty: TypeRef, builtin = false): string =
var w, v: int # Wordsize and vector size
if ty.getTypeKind() == tkInteger:
w = int ty.getIntTypeWidth()
Expand All @@ -93,8 +93,67 @@ proc getInstrName(baseName: string, ty: TypeRef): string =
doAssert false, "Invalid input type: " & $ty

return baseName &
(if v != 1: "_v" & $v else: "_") &
"u" & $w
(if v != 1: ".v" & $v else: ".") &
(if builtin: "i" else: "u") & $w


proc def_llvm_add_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) =
  ## Declare the LLVM builtin `llvm.uadd.with.overflow.<ty>` in module `m`
  ## for the operand type `wordTy`.
  ##
  ## The declared function type is `{wordTy, i1} (wordTy, wordTy)`.
  ## Nothing is added when `m` already contains a function with that name,
  ## so the proc can be called repeatedly for the same type.
  let name = "llvm.uadd.with.overflow".getInstrName(wordTy, builtin = true)

  # Only a declaration is added, no body is emitted:
  # no instruction builder is required here.
  if m.getFunction(cstring name).pointer.isNil():
    let retTy = ctx.struct_t([wordTy, ctx.int1_t()])
    let fnTy = function_t(retTy, [wordTy, wordTy])
    discard m.addFunction(cstring name, fnTy)

proc llvm_add_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[carryOut, r: ValueRef] =
  ## (carryOut, r) <- a+b
  ##
  ## Emit a call to the builtin `llvm.uadd.with.overflow.<ty>`
  ## where `<ty>` is the type of `a`. There is no carry-in operand.
  ##
  ## The builtin must already be declared in the builder's current module
  ## (see `def_llvm_add_overflow`), otherwise the `doAssert` below fails.
  ##
  ## `name` is used for the call instruction; the extracted values are
  ## named `name & ".lo"` (field 0) and `name & ".carry"` (field 1).
  let ty = a.getTypeOf()
  let intrin_name = "llvm.uadd.with.overflow".getInstrName(ty, builtin = true)

  let fn = br.getCurrentModule().getFunction(cstring intrin_name)
  doAssert not fn.pointer.isNil, "Function '" & intrin_name & "' does not exist in the module\n"

  let ctx = br.getContext()

  # Same function type as the one declared by `def_llvm_add_overflow`
  let retTy = ctx.struct_t([ty, ctx.int1_t()])
  let fnTy = function_t(retTy, [ty, ty])
  let addo = br.call2(fnTy, fn, [a, b], cstring name)
  # The returned struct is {sum, overflow flag}
  let lo = br.extractValue(addo, 0, cstring(name & ".lo"))
  let cOut = br.extractValue(addo, 1, cstring(name & ".carry"))
  return (cOut, lo)

proc def_llvm_sub_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) =
  ## Declare the LLVM builtin `llvm.usub.with.overflow.<ty>` in module `m`
  ## for the operand type `wordTy`.
  ##
  ## The declared function type is `{wordTy, i1} (wordTy, wordTy)`.
  ## Nothing is added when `m` already contains a function with that name,
  ## so the proc can be called repeatedly for the same type.
  let name = "llvm.usub.with.overflow".getInstrName(wordTy, builtin = true)

  # Only a declaration is added, no body is emitted:
  # no instruction builder is required here.
  if m.getFunction(cstring name).pointer.isNil():
    let retTy = ctx.struct_t([wordTy, ctx.int1_t()])
    let fnTy = function_t(retTy, [wordTy, wordTy])
    discard m.addFunction(cstring name, fnTy)

proc llvm_sub_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[borrowOut, r: ValueRef] =
  ## (borrowOut, r) <- a-b
  ##
  ## Emit a call to the builtin `llvm.usub.with.overflow.<ty>`
  ## where `<ty>` is the type of `a`. There is no borrow-in operand.
  ##
  ## The builtin must already be declared in the builder's current module
  ## (see `def_llvm_sub_overflow`), otherwise the `doAssert` below fails.
  ##
  ## `name` is used for the call instruction; the extracted values are
  ## named `name & ".lo"` (field 0) and `name & ".borrow"` (field 1).
  let ty = a.getTypeOf()
  let intrin_name = "llvm.usub.with.overflow".getInstrName(ty, builtin = true)

  let fn = br.getCurrentModule().getFunction(cstring intrin_name)
  doAssert not fn.pointer.isNil, "Function '" & intrin_name & "' does not exist in the module\n"

  let ctx = br.getContext()

  # Same function type as the one declared by `def_llvm_sub_overflow`
  let retTy = ctx.struct_t([ty, ctx.int1_t()])
  let fnTy = function_t(retTy, [ty, ty])
  let subo = br.call2(fnTy, fn, [a, b], cstring name)
  # The returned struct is {difference, overflow flag}
  let lo = br.extractValue(subo, 0, cstring(name & ".lo"))
  let bOut = br.extractValue(subo, 1, cstring(name & ".borrow"))
  return (bOut, lo)

template defSuperInstruction[N: static int](
module: ModuleRef, baseName: string,
Expand Down Expand Up @@ -139,11 +198,9 @@ proc def_addcarry*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) =
m.defSuperInstruction("addcarry", retType, inType):
let (a, b, carryIn) = llvmParams

let add = br.add(a, b, name = "a_plus_b")
let carry0 = br.icmp(kULT, add, b, name = "carry0")
let (carry0, add) = br.llvm_add_overflow(a, b, "a_plus_b")
let cIn = br.zext(carryIn, wordTy, name = "carryIn")
let adc = br.add(cIn, add, name = "a_plus_b_plus_cIn")
let carry1 = br.icmp(kULT, adc, add, name = "carry1")
let (carry1, adc) = br.llvm_add_overflow(cIn, add, "a_plus_b_plus_cIn")
let carryOut = br.`or`(carry0, carry1, name = "carryOut")

var ret = br.insertValue(poison(retType), adc, 1, "lo")
Expand All @@ -163,11 +220,10 @@ proc addcarry*(br: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: Valu
let fnTy = function_t(retTy, [ty, ty, tyC])
let adc = br.call2(fnTy, fn, [a, b, carryIn], name = "adc")
adc.setInstrCallConv(Fast)
let lo = br.extractValue(adc, 1, name = "adcLo")
let cOut = br.extractValue(adc, 0, name = "adcC")
let lo = br.extractValue(adc, 1, name = "adc.lo")
let cOut = br.extractValue(adc, 0, name = "adc.carry")
return (cOut, lo)


proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) =
## Define (borrowOut, result) <- a-b-borrowIn

Expand All @@ -177,11 +233,9 @@ proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) =
m.defSuperInstruction("subborrow", retType, inType):
let (a, b, borrowIn) = llvmParams

let sub = br.sub(a, b, name = "a_minus_b")
let borrow0 = br.icmp(kULT, a, b, name = "borrow0")
let (borrow0, sub) = br.llvm_sub_overflow(a, b, "a_minus_b")
let bIn = br.zext(borrowIn, wordTy, name = "borrowIn")
let sbb = br.sub(sub, bIn, name = "sbb")
let borrow1 = br.icmp(kULT, sub, bIn, name = "borrow1")
let (borrow1, sbb) = br.llvm_sub_overflow(sub, bIn, "sbb")
let borrowOut = br.`or`(borrow0, borrow1, name = "borrowOut")

var ret = br.insertValue(poison(retType), sbb, 1, "lo")
Expand All @@ -201,8 +255,8 @@ proc subborrow*(br: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: V
let fnTy = function_t(retTy, [ty, ty, tyC])
let sbb = br.call2(fnTy, fn, [a, b, borrowIn], name = "sbb")
sbb.setInstrCallConv(Fast)
let lo = br.extractValue(sbb, 1, name = "sbbLo")
let bOut = br.extractValue(sbb, 0, name = "sbbB")
let lo = br.extractValue(sbb, 1, name = "sbb.lo")
let bOut = br.extractValue(sbb, 0, name = "sbb.borrow")
return (bOut, lo)

proc mulExt*(bld: BuilderRef, a, b: ValueRef): tuple[hi, lo: ValueRef] =
Expand Down
31 changes: 27 additions & 4 deletions research/codegen/x86_poc.nim
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,32 @@ const Fields = [
"bls12_381_fr", 255,
"73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001"
),
(
"bls12_377_fp", 377,
"01ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001"
),
(
"bls12_377_fr", 253,
"12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001"
),
(
"bls24_315_fp", 315,
"4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001"
),
(
"bls12_315_fr", 253,
"196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001"
),
(
"bls24_317_fp", 317,
"1058CA226F60892CF28FC5A0B7F9D039169A61E684C73446D6F339E43424BF7E8D512E565DAB2AAB"
),
(
"bls12_317_fr", 255,
"443F917EA68DAFC2D0B097F28D83CD491CD1E79196BF0E7AF000000000000001"
),
]


proc t_field_add() =
let asy = Assembler_LLVM.new(bkX86_64_Linux, cstring("x86_poc"))
for F in Fields:
Expand Down Expand Up @@ -83,7 +106,7 @@ proc t_field_add() =
# - and contrary to what is claimed in https://llvm.org/docs/NewPassManager.html#id2
# the C API of PassBuilderRef is ghost town.
#
# So we reproduce the optimization passes from
# So we somewhat reproduce the optimization passes from
# https://reviews.llvm.org/D145835

let pbo = createPassBuilderOptions()
Expand All @@ -94,8 +117,8 @@ proc t_field_add() =
",function(aa-eval)" &
",always-inline,hotcoldsplit,inferattrs,instrprof,recompute-globalsaa" &
",cgscc(argpromotion,function-attrs)" &
# ",require<inline-advisor>,partial-inliner,called-value-propagation" &
# ",scc-oz-module-inliner,inline-wrapper,module-inline" & # Buggy optimization
",require<inline-advisor>,partial-inliner,called-value-propagation" &
",scc-oz-module-inliner,module-inline" & # Buggy optimization
",function(verify,loop-mssa(loop-reduce),mergeicmps,expand-memcmp,instsimplify)" &
",function(lower-constant-intrinsics,consthoist,partially-inline-libcalls,ee-instrument<post-inline>,scalarize-masked-mem-intrin,verify)" &
",memcpyopt,sroa,dse,aggressive-instcombine,gvn,ipsccp,deadargelim,adce" &
Expand Down

0 comments on commit 480ede5

Please sign in to comment.