Skip to content

Commit

Permalink
stash prep for Barret Reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Oct 10, 2023
1 parent 4dd0a02 commit 30ccbb7
Show file tree
Hide file tree
Showing 7 changed files with 438 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import
./limbs_mod2k,
./limbs_multiprec,
./limbs_extmul,
./limbs_divmod
./limbs_divmod,
./limbs_divmod_vartime

# No exceptions allowed
{.push raises: [], checks: off.}
Expand Down Expand Up @@ -61,6 +62,7 @@ func powOddMod_vartime*(

if eBits == 1:
r.view().reduce(a.view(), aBits, M.view(), mBits)
# discard r.reduce_vartime(a, M)
return

let L = wordsRequired(mBits)
Expand All @@ -77,8 +79,9 @@ func powOddMod_vartime*(
# For now, we call explicit reduction as it can handle all sizes.
# TODO: explicit reduction uses constant-time division which is **very** expensive
if a.len != M.len:
let t = allocStackArray(SecretWord, L)
var t = allocStackArray(SecretWord, L)
t.LimbsViewMut.reduce(a.view(), aBits, M.view(), mBits)
# discard t.toOpenArray(0, L-1).reduce_vartime(a, M)
rMont.LimbsViewMut.getMont(LimbsViewConst t, M.view(), LimbsViewConst r2.view(), m0ninv, mBits)
else:
rMont.LimbsViewMut.getMont(a.view(), M.view(), LimbsViewConst r2.view(), m0ninv, mBits)
Expand Down
22 changes: 11 additions & 11 deletions constantine/math_arbitrary_precision/arithmetic/limbs_divmod.nim
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int,
a1 = (a[^1] shl (WordBitWidth-R)) or (a[^2] shr R)
m0 = (M[^1] shl (WordBitWidth-R)) or (M[^2] shr R)

# m0 has its high bit set. (a0, a1)/p0 fits in a limb.
# m0 has its high bit set. (a0, a1)/m0 fits in a limb.
# Get a quotient q, at most we will be 2 iterations off
# from the true quotient
var q, r: SecretWord
Expand All @@ -78,29 +78,29 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int,

# Now substract a*2^64 - q*p
var carry = Zero
var over_p = CtTrue # Track if quotient greater than the modulus
var overM = CtTrue # Track if quotient greater than the modulus

for i in 0 ..< MLen:
var qp_lo: SecretWord
var qm_lo: SecretWord

block: # q*p
# q * p + carry (doubleword) carry from previous limb
muladd1(carry, qp_lo, q, M[i], carry)
block: # q*m
# q * m + carry (doubleword) carry from previous limb
muladd1(carry, qm_lo, q, M[i], carry)

block: # a*2^64 - q*p
var borrow: Borrow
subB(borrow, a[i], a[i], qp_lo, Borrow(0))
subB(borrow, a[i], a[i], qm_lo, Borrow(0))
carry += SecretWord(borrow) # Adjust if borrow

over_p = mux(a[i] == M[i], over_p, a[i] > M[i])
overM = mux(a[i] == M[i], overM, a[i] > M[i])

# Fix quotient, the true quotient is either q-1, q or q+1
#
# if carry < q or carry == q and over_p we must do "a -= p"
# if carry > hi (negative result) we must do "a += p"
# if carry < q or carry == q and over_p we must do "a -= m"
# if carry > hi (negative result) we must do "a += m"

result.neg = carry > hi
result.tooBig = not(result.neg) and (over_p or (carry < hi))
result.tooBig = not(result.neg) and (overM or (carry < hi))

func shlAddMod(a: LimbsViewMut, aLen: int,
c: SecretWord, M: LimbsViewConst, mBits: int) =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
../../platforms/abstractions,
../../platforms/intrinsics/extended_precision_vartime,
./limbs_views,
./limbs_fixedprec

# No exceptions allowed
{.push raises: [].}

# ############################################################
#
# Division and Modular Reduction
# (variable-time)
#
# ############################################################

func shlAddMod_multiprec_vartime(
a: var openArray[SecretWord], c: SecretWord,
M: openArray[SecretWord], mBits: int): SecretWord {.meter.} =
## Fused modular left-shift + add
## Computes: a <- a shl 2ʷ + c (mod M)
## Returns: (a shl 2ʷ + c) / M
##
## with w the base word width, usually 32 on 32-bit platforms and 64 on 64-bit platforms
##
## The modulus `M` most-significant bit at `mBits` MUST be set.
##
## Specialized for M being a multi-precision integer.
# Assuming 64-bit words
let hi = a[^1] # Save the high word to detect carries
let R = mBits and (WordBitWidth - 1) # R = mBits mod 64

var a0, a1, m0: SecretWord
if R == 0: # If the number of mBits is a multiple of 64
a0 = a[^1] #
copyWords(a.view(), 1, # we can just shift words
a.view(), 0, a.len-1) #
a[0] = c # and replace the first one by c
a1 = a[^1]
m0 = M[^1]
else: # Else: need to deal with partial word shifts at the edge.
let clz = WordBitWidth-R
a0 = (a[^1] shl clz) or (a[^2] shr R)
copyWords(a.view(), 1,
a.view(), 0, a.len-1)
a[0] = c
a1 = (a[^1] shl clz) or (a[^2] shr R)
m0 = (M[^1] shl clz) or (M[^2] shr R)

# m0 has its high bit set. (a0, a1)/m0 fits in a limb.
# Get a quotient q, at most we will be 2 iterations off
# from the true quotient
var q: SecretWord # Estimate quotient
if bool(a0 == m0): # if a_hi == divisor
q = MaxWord # quotient = MaxWord (0b1111...1111)
elif bool(a0.isZero()) and
bool(a1 < m0): # elif q == 0, true quotient = 0
q = Zero
return q
else:
var r: SecretWord
div2n1n(q, r, a0, a1, m0) # else instead of being of by 0, 1 or 2
q -= One # we return q-1 to be off by -1, 0 or 1

# Now substract a*2^64 - q*m
var carry = Zero
var overM = true # Track if quotient greater than the modulus

for i in 0 ..< M.len:
var qm_lo: SecretWord
block: # q*m
# q * m + carry (doubleword) carry from previous limb
muladd1(carry, qm_lo, q, M[i], carry)

block: # a*2^64 - q*m
var borrow: Borrow
subB(borrow, a[i], a[i], qm_lo, Borrow(0))
carry += SecretWord(borrow) # Adjust if borrow

if bool(a[i] != M[i]):
overM = bool(a[i] > M[i])

# Fix quotient, the true quotient is either q-1, q or q+1
#
# if carry < q or carry == q and overM we must do "a -= M"
# if carry > hi (negative result) we must do "a += M"
if bool(carry > hi):
var c = Carry(0)
for i in 0 ..< a.len:
addC(c, a[i], a[i], M[i], c)
q -= One
elif overM or bool(carry < hi):
var b = Borrow(0)
for i in 0 ..< a.len:
subB(b, a[i], a[i], M[i], b)
q += One

return q

func shlAddMod_vartime(a: var openArray[SecretWord], c: SecretWord,
M: openArray[SecretWord], mBits: int): SecretWord {.meter.} =
## Fused modular left-shift + add
## Computes: a <- a shl 2ʷ + c (mod M)
## Returns: (a shl 2ʷ + c) / M
##
## with w the base word width, usually 32 on 32-bit platforms and 64 on 64-bit platforms
##
## The modulus `M` most-significant bit at `mBits` MUST be set.
if mBits <= WordBitWidth:
# If M fits in a single limb

# We normalize M with clz so that the MSB is set
# And normalize (a * 2^64 + c) by R as well to maintain the result
# This ensures that (a0, a1)/m0 fits in a limb.
let R = mBits and (WordBitWidth - 1)

# (hi, lo) = a * 2^64 + c
if R == 0:
# We can delegate this R == 0 case to the
# shlAddMod_multiprec, with the same result.
# But isn't it faster to handle it here?
var q, r: SecretWord
div2n1n_vartime(q, r, a[0], c, M[0])
a[0] = r
return q
else:
let clz = WordBitWidth-R
let hi = (a[0] shl clz) or (c shr R)
let lo = c shl clz
let m0 = M[0] shl clz

var q, r: SecretWord
div2n1n(q, r, hi, lo, m0)
a[0] = r shr clz
return q
else:
return shlAddMod_multiprec_vartime(a, c, M, mBits)

func divRem_vartime*(
q, r: var openArray[SecretWord],
a, b: openArray[SecretWord]): bool {.meter.} =
# q = a div b
# r = a mod b
#
# Requirements:
# b != 0
# q.len > a.len - b.len
# r.len >= b.len

let aBits = getBits_LE_vartime(a)
let bBits = getBits_LE_vartime(b)
let aLen = wordsRequired(aBits)
let bLen = wordsRequired(bBits)
let rLen = bLen

let aOffset = a.len - b.len

# Note: don't confuse a.len and aLen (actually used words)

if unlikely(bool(r.len < bLen)):
# remainder buffer cannot store up to modulus size
return false

if unlikely(bool(q.len < aOffset+1)):
# quotient buffer cannot store up to quotient size
return false

if unlikely(bBits == 0):
# Divide by zero
return false

if aBits < bBits:
# if a uses less bits than b,
# a < b, so q = 0 and r = a
copyWords(r.view(), 0, a.view(), 0, aLen)
for i in aLen ..< r.len:
r[i] = Zero
for i in 0 ..< q.len:
q[i] = Zero
else:
# The length of a is at least the divisor
# We can copy bLen-1 words
# and modular shift-left-add the rest

copyWords(r.view(), 0, a.view(), aOffset+1, b.len-1)
r[rLen-1] = Zero
# Now shift-left the copied words while adding the new word mod b

for i in countdown(aOffset, 0):
q[i] = shlAddMod_multiprec_vartime(
r.toOpenArray(0, rLen-1),
a[i],
b.toOpenArray(0, bLen-1),
bBits)

# Clean up extra words
for i in aOffset+1 ..< q.len:
q[i] = Zero
for i in rLen ..< r.len:
r[i] = Zero

return true

func reduce_vartime*(r: var openArray[SecretWord],
a, b: openArray[SecretWord]): bool {.noInline, meter.} =
let aBits = getBits_LE_vartime(a)
let bBits = getBits_LE_vartime(b)
let aLen = wordsRequired(aBits)
let bLen = wordsRequired(bBits)

let aOffset = a.len - b.len
var qBuf = allocHeapArray(SecretWord, aOffset+1)
template q: untyped = qBuf.toOpenArray(0, aOffset)
result = divRem_vartime(q, r, a, b)
freeHeap(qBuf)

# ############################################################
#
# Barrett Reduction
#
# ############################################################

# - https://en.wikipedia.org/wiki/Barrett_reduction
# - Handbook of Applied Cryptography
# Alfred J. Menezes, Paul C. van Oorschot and Scott A. Vanstone
# https://cacr.uwaterloo.ca/hac/about/chap14.pdf
# - Modern Computer Arithmetic
# Richard P. Brent and Paul Zimmermann
# https://members.loria.fr/PZimmermann/mca/mca-cup-0.5.9.pdf
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import
./limbs_views,
./limbs_mod,
./limbs_fixedprec,
./limbs_divmod
./limbs_divmod,
./limbs_divmod_vartime

# No exceptions allowed
{.push raises: [], checks: off.}
Expand Down Expand Up @@ -72,13 +73,13 @@ func oneMont_vartime*(r: var openArray[SecretWord], M: openArray[SecretWord]) {.

# r.r_powmod_vartime(M, 1)

let mBits = getBits_LE_vartime(M)

let t = allocStackArray(SecretWord, M.len + 1)
zeroMem(t, M.len*sizeof(SecretWord))
t[M.len] = One

let mBits = getBits_LE_vartime(M)
r.view().reduce(LimbsViewMut t, M.len*WordBitWidth+1, M.view(), mBits)
# discard r.reduce_vartime(t.toOpenArray(0, M.len), M)

func r2_vartime*(r: var openArray[SecretWord], M: openArray[SecretWord]) {.meter.} =
## Returns the Montgomery domain magic constant for the input modulus:
Expand All @@ -90,14 +91,13 @@ func r2_vartime*(r: var openArray[SecretWord], M: openArray[SecretWord]) {.meter

# r.r_powmod_vartime(M, 2)

let mBits = getBits_LE_vartime(M)

let t = allocStackArray(SecretWord, 2*M.len + 1)
zeroMem(t, 2*M.len*sizeof(SecretWord))
t[2*M.len] = One

let mBits = getBits_LE_vartime(M)
r.view().reduce(LimbsViewMut t, 2*M.len*WordBitWidth+1, M.view(), mBits)

# discard r.reduce_vartime(t.toOpenArray(0, 2*M.len), M)

# Montgomery multiplication
# ------------------------------------------
Expand Down
Loading

0 comments on commit 30ccbb7

Please sign in to comment.