From 480ede539a96a1979174dfc093f41786c48042bb Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Wed, 14 Aug 2024 11:17:57 +0200 Subject: [PATCH] llvm: use builtin llvm.uadd.with.overflow.iXXX to try to generate optimal code (and fail for i320 and i384 https://github.com/llvm/llvm-project/issues/103717) --- constantine/math_compiler/impl_fields_sat.nim | 55 ++++-------- constantine/math_compiler/ir.nim | 8 ++ .../platforms/llvm/super_instructions.nim | 86 +++++++++++++++---- research/codegen/x86_poc.nim | 31 ++++++- 4 files changed, 124 insertions(+), 56 deletions(-) diff --git a/constantine/math_compiler/impl_fields_sat.nim b/constantine/math_compiler/impl_fields_sat.nim index a7e66545..cd52f96f 100644 --- a/constantine/math_compiler/impl_fields_sat.nim +++ b/constantine/math_compiler/impl_fields_sat.nim @@ -78,7 +78,7 @@ import const SectionName = "ctt.fields" -proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array, carry: ValueRef) = +proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M, carry: ValueRef) = ## If a >= Modulus: r <- a-M ## else: r <- a ## @@ -87,28 +87,22 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arr ## ## To be used when the final substraction can ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256) - let t = asy.makeArray(fd.fieldTy) # Mask: contains 0xFFFF or 0x0000 let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry) # Now substract the modulus, and test a < M # (underflow) with the last borrow - var b: ValueRef - (b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1) - for i in 1 ..< fd.numWords: - (b, t[i]) = asy.br.subborrow(a[i], M[i], b) + let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M) # If it underflows here, it means that it was - # smaller than the modulus and we don't need `scratch` - (b, _) = asy.br.subborrow(mask, fd.zero, b) + # smaller than the modulus and we don't need `a-M` + let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow) - for i in 0 ..< fd.numWords: - t[i] = asy.br.select(b, a[i], t[i]) + let t = asy.br.select(ctl, a, a_minus_M) + asy.store(rr, t) - asy.store(r, t) - -proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Array) = +proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M: ValueRef) = ## If a >= Modulus: r <- a-M ## else: r <- a ## @@ -117,20 +111,15 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Arra ## ## To be used when the modulus does not use the full bitwidth of the storing words ## (say using 255 bits for the modulus out of 256 available in words) - let t = asy.makeArray(fd.fieldTy) # Now substract the modulus, and test a < M # (underflow) with the last borrow - var b: ValueRef - (b, t[0]) = asy.br.subborrow(a[0], M[0], fd.zero_i1) - for i in 1 ..< fd.numWords: - (b, t[i]) = asy.br.subborrow(a[i], M[i], b) - - # If it underflows here a was smaller than the modulus, which is what we want - for i in 0 ..< fd.numWords: - t[i] = asy.br.select(b, a[i], t[i]) + let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M) - asy.store(r, t) + # If it underflows here, it means that it was + # smaller than the modulus and we don't need `a-M` + let t = asy.br.select(borrow, a, a_minus_M) + asy.store(rr, t) proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = ## Generate an optimized modular addition kernel @@ -138,7 +127,7 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = let red = if fd.spareBits >= 1: "noo" else: "mayo" - let name = "_modadd_" & red & "_u" & $fd.w & "x" & $fd.numWords + let name = "_modadd_" & red & ".u" & $fd.w & "x" & $fd.numWords asy.llvmInternalFnDef( name, SectionName, asy.void_t, toTypes([r, a, b, M]), @@ -149,21 +138,15 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = let (rr, aa, bb, MM) = llvmParams # Pointers are opaque in LLVM now - let r = asy.asArray(rr, fd.fieldTy) - let a = asy.asArray(aa, fd.fieldTy) - let b = asy.asArray(bb, fd.fieldTy) - let M = asy.asArray(MM, fd.fieldTy) - - let apb = asy.makeArray(fd.fieldTy) - var c: ValueRef - (c, apb[0]) = asy.br.addcarry(a[0], b[0], fd.zero_i1) - for i in 1 ..< fd.numWords: - (c, apb[i]) = asy.br.addcarry(a[i], b[i], c) + let a = asy.load2(fd.intBufTy, aa, "a") + let b = asy.load2(fd.intBufTy, bb, "b") + let M = asy.load2(fd.intBufTy, MM, "M") + let (carry, apb) = asy.br.llvm_add_overflow(a, b) if fd.spareBits >= 1: - asy.finalSubNoOverflow(fd, r, apb, M) + asy.finalSubNoOverflow(fd, rr, apb, M) else: - asy.finalSubMayOverflow(fd, r, apb, M, c) + asy.finalSubMayOverflow(fd, rr, apb, M, carry) asy.br.retVoid() diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 51acb1fa..ab80e036 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -196,6 +196,11 @@ proc configureField*(ctx: ContextRef, result.spareBits = uint8(next_multiple_wordsize - modBits) proc definePrimitives*(asy: Assembler_LLVM, fd: FieldDescriptor) = + asy.ctx.def_llvm_add_overflow(asy.module, fd.wordTy) + asy.ctx.def_llvm_add_overflow(asy.module, fd.intBufTy) + asy.ctx.def_llvm_sub_overflow(asy.module, fd.wordTy) + asy.ctx.def_llvm_sub_overflow(asy.module, fd.intBufTy) + asy.ctx.def_addcarry(asy.module, asy.ctx.int1_t(), fd.wordTy) asy.ctx.def_subborrow(asy.module, asy.ctx.int1_t(), fd.wordTy) @@ -524,3 +529,6 @@ proc callFn*( template load2*(asy: Assembler_LLVM, ty: TypeRef, `ptr`: ValueRef, name: cstring = ""): ValueRef = asy.br.load2(ty, `ptr`, name) + +template store*(asy: Assembler_LLVM, dst, src: ValueRef, name: cstring = "") = + asy.br.store(src, dst) diff --git a/constantine/platforms/llvm/super_instructions.nim b/constantine/platforms/llvm/super_instructions.nim index db678e95..08646288 100644 --- a/constantine/platforms/llvm/super_instructions.nim +++ b/constantine/platforms/llvm/super_instructions.nim @@ -81,7 +81,7 @@ proc hi(bld: BuilderRef, val: ValueRef, baseTy: TypeRef, oversize: uint32, prefi const SectionName = "ctt.superinstructions" -proc getInstrName(baseName: string, ty: TypeRef): string = +proc getInstrName(baseName: string, ty: TypeRef, builtin = false): string = var w, v: int # Wordsize and vector size if ty.getTypeKind() == tkInteger: w = int ty.getIntTypeWidth() @@ -93,8 +93,67 @@ proc getInstrName(baseName: string, ty: TypeRef): string = doAssert false, "Invalid input type: " & $ty return baseName & - (if v != 1: "_v" & $v else: "_") & - "u" & $w + (if v != 1: ".v" & $v else: ".") & + (if builtin: "i" else: "u") & $w + + +proc def_llvm_add_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) = + let name = "llvm.uadd.with.overflow".getInstrName(wordTy, builtin = true) + + let br {.inject.} = ctx.createBuilder() + defer: br.dispose() + + var fn = m.getFunction(cstring name) + if fn.pointer.isNil(): + let retTy = ctx.struct_t([wordTy, ctx.int1_t()]) + let fnTy = function_t(retTy, [wordTy, wordTy]) + discard m.addFunction(cstring name, fnTy) + +proc llvm_add_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[carryOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let intrin_name = "llvm.uadd.with.overflow".getInstrName(ty, builtin = true) + + let fn = br.getCurrentModule().getFunction(cstring intrin_name) + doAssert not fn.pointer.isNil, "Function '" & intrin_name & "' does not exist in the module\n" + + let ctx = br.getContext() + + let retTy = ctx.struct_t([ty, ctx.int1_t()]) + let fnTy = function_t(retTy, [ty, ty]) + let addo = br.call2(fnTy, fn, [a, b], cstring name) + let lo = br.extractValue(addo, 0, cstring(name & ".lo")) + let cOut = br.extractValue(addo, 1, cstring(name & ".carry")) + return (cOut, lo) + +proc def_llvm_sub_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) = + let name = "llvm.usub.with.overflow".getInstrName(wordTy, builtin = true) + + let br {.inject.} = ctx.createBuilder() + defer: br.dispose() + + var fn = m.getFunction(cstring name) + if fn.pointer.isNil(): + let retTy = ctx.struct_t([wordTy, ctx.int1_t()]) + let fnTy = function_t(retTy, [wordTy, wordTy]) + discard m.addFunction(cstring name, fnTy) + +proc llvm_sub_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[borrowOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let intrin_name = "llvm.usub.with.overflow".getInstrName(ty, builtin = true) + + let fn = br.getCurrentModule().getFunction(cstring intrin_name) + doAssert not fn.pointer.isNil, "Function '" & intrin_name & "' does not exist in the module\n" + + let ctx = br.getContext() + + let retTy = ctx.struct_t([ty, ctx.int1_t()]) + let fnTy = function_t(retTy, [ty, ty]) + let subo = br.call2(fnTy, fn, [a, b], cstring name) + let lo = br.extractValue(subo, 0, cstring(name & ".lo")) + let bOut = br.extractValue(subo, 1, cstring(name & ".borrow")) + return (bOut, lo) template defSuperInstruction[N: static int]( module: ModuleRef, baseName: string, @@ -139,11 +198,9 @@ proc def_addcarry*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) = m.defSuperInstruction("addcarry", retType, inType): let (a, b, carryIn) = llvmParams - let add = br.add(a, b, name = "a_plus_b") - let carry0 = br.icmp(kULT, add, b, name = "carry0") + let (carry0, add) = br.llvm_add_overflow(a, b, "a_plus_b") let cIn = br.zext(carryIn, wordTy, name = "carryIn") - let adc = br.add(cIn, add, name = "a_plus_b_plus_cIn") - let carry1 = br.icmp(kULT, adc, add, name = "carry1") + let (carry1, adc) = br.llvm_add_overflow(cIn, add, "a_plus_b_plus_cIn") let carryOut = br.`or`(carry0, carry1, name = "carryOut") var ret = br.insertValue(poison(retType), adc, 1, "lo") @@ -163,11 +220,10 @@ proc addcarry*(br: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: Valu let fnTy = function_t(retTy, [ty, ty, tyC]) let adc = br.call2(fnTy, fn, [a, b, carryIn], name = "adc") adc.setInstrCallConv(Fast) - let lo = br.extractValue(adc, 1, name = "adcLo") - let cOut = br.extractValue(adc, 0, name = "adcC") + let lo = br.extractValue(adc, 1, name = "adc.lo") + let cOut = br.extractValue(adc, 0, name = "adc.carry") return (cOut, lo) - proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) = ## Define (borrowOut, result) <- a-b-borrowIn @@ -177,11 +233,9 @@ proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) = m.defSuperInstruction("subborrow", retType, inType): let (a, b, borrowIn) = llvmParams - let sub = br.sub(a, b, name = "a_minus_b") - let borrow0 = br.icmp(kULT, a, b, name = "borrow0") + let (borrow0, sub) = br.llvm_sub_overflow(a, b, "a_minus_b") let bIn = br.zext(borrowIn, wordTy, name = "borrowIn") - let sbb = br.sub(sub, bIn, name = "sbb") - let borrow1 = br.icmp(kULT, sub, bIn, name = "borrow1") + let (borrow1, sbb) = br.llvm_sub_overflow(sub, bIn, "sbb") let borrowOut = br.`or`(borrow0, borrow1, name = "borrowOut") var ret = br.insertValue(poison(retType), sbb, 1, "lo") @@ -201,8 +255,8 @@ proc subborrow*(br: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: V let fnTy = function_t(retTy, [ty, ty, tyC]) let sbb = br.call2(fnTy, fn, [a, b, borrowIn], name = "sbb") sbb.setInstrCallConv(Fast) - let lo = br.extractValue(sbb, 1, name = "sbbLo") - let bOut = br.extractValue(sbb, 0, name = "sbbB") + let lo = br.extractValue(sbb, 1, name = "sbb.lo") + let bOut = br.extractValue(sbb, 0, name = "sbb.borrow") return (bOut, lo) proc mulExt*(bld: BuilderRef, a, b: ValueRef): tuple[hi, lo: ValueRef] = diff --git a/research/codegen/x86_poc.nim b/research/codegen/x86_poc.nim index a460a822..4ec5b67f 100644 --- a/research/codegen/x86_poc.nim +++ b/research/codegen/x86_poc.nim @@ -36,9 +36,32 @@ const Fields = [ "bls12_381_fr", 255, "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001" ), + ( + "bls12_377_fp", 377, + "01ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001" + ), + ( + "bls12_377_fr", 253, + "12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001" + ), + ( + "bls24_315_fp", 315, + "4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001" + ), + ( + "bls12_315_fr", 253, + "196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001" + ), + ( + "bls24_317_fp", 317, + "1058CA226F60892CF28FC5A0B7F9D039169A61E684C73446D6F339E43424BF7E8D512E565DAB2AAB" + ), + ( + "bls12_317_fr", 255, + "443F917EA68DAFC2D0B097F28D83CD491CD1E79196BF0E7AF000000000000001" + ), ] - proc t_field_add() = let asy = Assembler_LLVM.new(bkX86_64_Linux, cstring("x86_poc")) for F in Fields: @@ -83,7 +106,7 @@ proc t_field_add() = # - and contrary to what is claimed in https://llvm.org/docs/NewPassManager.html#id2 # the C API of PassBuilderRef is ghost town. # - # So we reproduce the optimization passes from + # So we somewhat reproduce the optimization passes from # https://reviews.llvm.org/D145835 let pbo = createPassBuilderOptions() @@ -94,8 +117,8 @@ proc t_field_add() = ",function(aa-eval)" & ",always-inline,hotcoldsplit,inferattrs,instrprof,recompute-globalsaa" & ",cgscc(argpromotion,function-attrs)" & - # ",require,partial-inliner,called-value-propagation" & - # ",scc-oz-module-inliner,inline-wrapper,module-inline" & # Buggy optimization + ",require,partial-inliner,called-value-propagation" & + ",scc-oz-module-inliner,module-inline" & # Buggy optimization ",function(verify,loop-mssa(loop-reduce),mergeicmps,expand-memcmp,instsimplify)" & ",function(lower-constant-intrinsics,consthoist,partially-inline-libcalls,ee-instrument,scalarize-masked-mem-intrin,verify)" & ",memcpyopt,sroa,dse,aggressive-instcombine,gvn,ipsccp,deadargelim,adce" &