Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inner sum in BGV #513

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Changelog
All notable changes to this library are documented in this file.

## [6.x.x] - 16.12.2024
- Refactoring of the InnerSum methods:
- `rlwe.Evaluator.InnerSum` has been replaced by `rlwe.Evaluator.PartialTracesSum`, which applies the automorphisms that correspond to rotations at the scheme level (and sum the results).
- Introduction of the `bgv.Evaluator.InnerSum` and `ckks.Evaluator.InnerSum` methods, which have the same behaviour as the old `InnerSum` method for parameters `n` and `batchSize` s.t. `0 < n*batchSize <= ctIn.Slots()` divides the number of slots. Parameters not satisfying these conditions are rejected.
- Introduction of the `bgv.Evaluator.RotateAndAdd` and `ckks.Evaluator.RotateAndAdd` methods, which have the same behaviour as the old `InnerSum` method for all parameters.

## [6.1.0] - 04.10.2024
- Update of `PrecisionStats` in `ckks/precision.go`:
- The precision is now computed as the min/max/average/... of the log of the error (instead of the log of the min/max/average/... of the error).
Expand Down
34 changes: 11 additions & 23 deletions core/rlwe/inner_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,14 @@ func GaloisElementsForTrace(params ParameterProvider, logN int) (galEls []uint64
return
}

// InnerSum applies an optimized inner sum on the Ciphertext (log2(n) + HW(n) rotations with double hoisting).
// The operation assumes that `ctIn` encrypts Slots/`batchSize` sub-vectors of size `batchSize` and will add them together (in parallel) in groups of `n`.
// It outputs in opOut a [Ciphertext] for which the "leftmost" sub-vector of each group is equal to the sum of the group.
//
// The inner sum is computed in a tree fashion. Example for batchSize=2 & n=4 (garbage slots are marked by 'x'):
//
// 1. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
//
// 2. [{a, b}, {c, d}, {e, f}, {g, h}, {a, b}, {c, d}, {e, f}, {g, h}]
// +
// [{c, d}, {e, f}, {g, h}, {x, x}, {c, d}, {e, f}, {g, h}, {x, x}] (rotate batchSize * 2^{0})
// =
// [{a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}, {a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}]
//
// 3. [{a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}, {a+c, b+d}, {x, x}, {e+g, f+h}, {x, x}] (rotate batchSize * 2^{1})
// +
// [{e+g, f+h}, {x, x}, {x, x}, {x, x}, {e+g, f+h}, {x, x}, {x, x}, {x, x}] =
// =
// [{a+c+e+g, b+d+f+h}, {x, x}, {x, x}, {x, x}, {a+c+e+g, b+d+f+h}, {x, x}, {x, x}, {x, x}]
func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Ciphertext) (err error) {
// PartialTracesSum applies a set of automorphisms on the input ciphertext and sum the results.
// The automorphisms are of the form phi(i*offset, X), 0 <= i < n, where phi(k, X): X -> X^{5^k}
// i.e. opOut = \sum_{i = 0}^{n-1} phi(i*offset, ctIn).
// At the scheme level, this function is used to perform inner sums or efficiently replicate slots.
func (eval Evaluator) PartialTracesSum(ctIn *Ciphertext, offset, n int, opOut *Ciphertext) (err error) {
if n == 0 || offset == 0 {
return fmt.Errorf("partialtrace: invalid parameter (n = 0 or batchSize = 0)")
}

params := eval.GetRLWEParameters()

Expand Down Expand Up @@ -236,7 +224,7 @@ func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Cipher
if j&1 == 1 {

k := n - (n & ((2 << i) - 1))
k *= batchSize
k *= offset

// If the rotation is not zero
if k != 0 {
Expand Down Expand Up @@ -281,7 +269,7 @@ func (eval Evaluator) InnerSum(ctIn *Ciphertext, batchSize, n int, opOut *Cipher

if !state {

rot := params.GaloisElement((1 << i) * batchSize)
rot := params.GaloisElement((1 << i) * offset)

// ctInNTT = ctInNTT + Rotate(ctInNTT, 2^i)
if err = eval.AutomorphismHoisted(levelQ, ctInNTT, eval.BuffDecompQP, rot, cQ); err != nil {
Expand Down Expand Up @@ -486,7 +474,7 @@ func GaloisElementsForInnerSum(params ParameterProvider, batch, n int) (galEls [
// two consecutive sub-vectors to replicate.
// This method is faster than Replicate when the number of rotations is large and it uses log2(n) + HW(n) instead of n.
func (eval Evaluator) Replicate(ctIn *Ciphertext, batchSize, n int, opOut *Ciphertext) (err error) {
return eval.InnerSum(ctIn, -batchSize, n, opOut)
return eval.PartialTracesSum(ctIn, -batchSize, n, opOut)
}

// GaloisElementsForReplicate returns the list of Galois elements necessary to perform the
Expand Down
5 changes: 3 additions & 2 deletions core/rlwe/rlwe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ func testSlotOperations(tc *TestContext, level, bpw2 int, t *testing.T) {
enc := tc.enc
dec := tc.dec

t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/InnerSum"), func(t *testing.T) {
t.Run(testString(params, level, params.MaxLevelP(), bpw2, "Evaluator/PartialTrace"), func(t *testing.T) {

if params.MaxLevelP() == -1 {
t.Skip("test requires #P > 0")
Expand All @@ -1099,14 +1099,15 @@ func testSlotOperations(tc *TestContext, level, bpw2 int, t *testing.T) {
ringQ := tc.params.RingQ().AtLevel(level)

pt := genPlaintext(params, level, 1<<30)
pt.LogDimensions = ring.Dimensions{Rows: 1, Cols: params.logN - 1}
ptInnerSum := *pt.Value.CopyNew()
ct, err := enc.EncryptNew(pt)
require.NoError(t, err)

// Galois Keys
evk := NewMemEvaluationKeySet(nil, kgen.GenGaloisKeysNew(GaloisElementsForInnerSum(params, batch, n), sk)...)

require.NoError(t, eval.WithKey(evk).InnerSum(ct, batch, n, ct))
require.NoError(t, eval.WithKey(evk).PartialTracesSum(ct, batch, n, ct))

dec.Decrypt(ct, pt)

Expand Down
42 changes: 35 additions & 7 deletions examples/singleparty/tutorials/ckks/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,13 +590,13 @@ func main() {

// The `circuits/lintrans` package provides a multiple handy linear transformations.
// We will start with the inner sum.
// Thus method allows to aggregate `n` sub-vectors of size `batch`.
// For example given a vector [x0, x1, x2, x3, x4, x5, x6, x7], batch = 2 and n = 3
// it will return the vector [x0+x2+x4, x1+x3+x5, x2+x4+x6, x3+x5+x7, x4+x6+x0, x5+x7+x1, x6+x0+x2, x7+x1+x3]
// Observe that the inner sum wraps around the vector, this behavior must be taken into account.
// This method allows to aggregate `n` sub-vectors of size `batch` and it stores the result in the leftmost sub-vector of each "group".
// For example given a vector [x0, x1, x2, x3, x4, x5, x6, x7], batch = 2 and n = 4
// it will return the vector [x0+x2+x4+x6, x1+x3+x5+x7, X, X, X, X, X, X], where X marks garbage slots.
// Note that n*batch must divide the length of the vector (i.e. the number of slots).

batch := 37
n := 127
batch := 32
n := 128

// The innersum operations is carried out with log2(n) + HW(n) automorphisms and we need to
// generate the corresponding Galois keys and provide them to the `Evaluator`.
Expand All @@ -619,7 +619,35 @@ func main() {
// apply the innersum and then only apply the rescaling.
fmt.Printf("Innersum %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())

// The replicate operation is exactly the same as the innersum operation, but in reverse
// Sometimes we wish to compute an inner sum on the first values of the vector only.
// In this case, n*batch does not necessarily divide the length of the vector and the RotateAndAdd function must be used instead.
// This method allows to repeatedly shift the vector by batch values and add (i.e. \sum_{i=0}^{n-1} v << (i*batch), where v is the input vector).
// For example given a vector [x0, x1, x2, x3, x4, x5, x6, x7], batch = 2 and n = 3
// it will return the vector [x0+x2+x4, x1+x3+x5, x2+x4+x6, x3+x5+x7, x4+x6+x0, x5+x7+x1, x6+x0+x2, x7+x1+x3].
// Observe that the inner sum wraps around the vector, this behavior must be taken into account.

batch = 37
n = 127
eval = eval.WithKey(rlwe.NewMemEvaluationKeySet(rlk, kgen.GenGaloisKeysNew(params.GaloisElementsForInnerSum(batch, n), sk)...))

// Plaintext circuit
copy(want, values1)
for i := 1; i < n; i++ {
for j, vi := range utils.RotateSlice(values1, i*batch) {
want[j] += vi
}
}

if err := eval.RotateAndAdd(ct1, batch, n, res); err != nil {
panic(err)
}

// Note that this method can obviously be used to average values.
// For a good noise management, it is recommended to first multiply the values by 1/n, then
// apply the inner sum and then only apply the rescaling.
fmt.Printf("RotateAndAdd %s", ckks.GetPrecisionStats(params, ecd, dec, want, res, 0, false).String())

// The replicate operation is exactly the same as the rotate and add operation, but in reverse
eval = eval.WithKey(rlwe.NewMemEvaluationKeySet(rlk, kgen.GenGaloisKeysNew(params.GaloisElementsForReplicate(batch, n), sk)...))

// Plaintext circuit
Expand Down
90 changes: 90 additions & 0 deletions schemes/bgv/bgv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/tuneinsight/lattigo/v6/core/rlwe"
"github.com/tuneinsight/lattigo/v6/ring"
"github.com/tuneinsight/lattigo/v6/utils"
)

var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise")
Expand Down Expand Up @@ -665,6 +666,95 @@ func testEvaluatorBvg(tc *TestContext, t *testing.T) {
}
})
}

// Naive implementation of the inner sum for reference
innersum := func(values []uint64, n, batchSize int, rotateAndAdd bool) {
aggregate := false
if n*batchSize == len(values) && !rotateAndAdd {
aggregate = true
n = n / 2
}
halfN := len(values) >> 1
tmp1 := make([]uint64, halfN)
tmp2 := make([]uint64, halfN)
copy(tmp1, values[:halfN])
copy(tmp2, values[halfN:])
for i := 1; i < n; i++ {
rot1 := utils.RotateSlice(tmp1, i*batchSize)
rot2 := utils.RotateSlice(tmp2, i*batchSize)
for j := range rot1 {
values[j] = (values[j] + rot1[j]) % tc.Params.PlaintextModulus()
values[j+halfN] = (values[j+halfN] + rot2[j]) % tc.Params.PlaintextModulus()
}
}
if aggregate {
for i := range tmp1 {
values[i] = (values[i] + values[i+halfN]) % tc.Params.PlaintextModulus()
}
}
}

for _, i := range []int{0, 1, 2} {
// n*batchSize = N, N/2, N/8
for _, offset := range []int{0, 1, 3} {
for _, lvl := range testLevel {
t.Run(name("Evaluator/InnerSum/", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}
n := tc.Params.MaxSlots() >> (i + offset)
batchSize := 1 << i

galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n)
evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))

want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))

innersum(want, n, batchSize, false)

receiver := NewCiphertext(tc.Params, 1, lvl)

require.NoError(t, evl.InnerSum(ciphertext0, batchSize, n, receiver))

have := make([]uint64, len(want))
require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have))

for i := 0; i < len(want); i += n * batchSize {
require.Equal(t, want[i:i+batchSize], have[i:i+batchSize])
}
})
}
}

// Test RotateAndAdd with n*batchSize dividing and not dividing #slots
for _, n := range []int{tc.Params.MaxSlots() >> 3, 7} {
for _, batchSize := range []int{8, 3} {
for _, lvl := range testLevel {
t.Run(name("Evaluator/RotateAndAdd/", tc, lvl), func(t *testing.T) {
if lvl == 0 {
t.Skip("Skipping: Level = 0")
}

galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n)
evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...))

want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3))

innersum(want, n, batchSize, true)

receiver := NewCiphertext(tc.Params, 1, lvl)

require.NoError(t, evl.RotateAndAdd(ciphertext0, batchSize, n, receiver))

have := make([]uint64, len(want))
require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have))

require.Equal(t, want, have)
})
}
}
}
}
}

func testEvaluatorBfv(tc *TestContext, t *testing.T) {
Expand Down
80 changes: 80 additions & 0 deletions schemes/bgv/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,86 @@ func (eval Evaluator) RotateHoistedLazyNew(level int, rotations []int, op0 *rlwe
return
}

// InnerSum divides each row of the underlying plaintext in sub-vectors of size batchSize and add n of these together.
// If n*batchSize = ctIn.Slots(), the inner sum is computed as if the plaintext was a 1-D vector of dimension ctIn.Slots()
// (we recall that a BGV/BFV plaintext is represented as a 2 x ctIn.Slots()/2 matrix).
//
// WARNING: 0 < n*batchSize <= ctIn.Slots() must divide the number of slots ctIn.Slots(). For other parameters, consider using [Evaluator.RotateAndAdd].
//
// Example for batchSize=2, n=4 and 32 slots (garbage slots are marked as X):
//
// Input:
//
// [[{a1, b1}, {c1, d1}, {e1, f1}, {g1, h1}, {i1, j1}, {k1, l1}, {m1, n1}, {o1, p1}]
//
// [{a2, b2}, {c2, d2}, {e2, f2}, {g2, h2}, {i2, j2}, {k2, l2}, {m2, n2}, {o2, p2}]]
//
// Output:
//
// [[{a1+c1+e1+g1, b1+d1+f1+h1}, {X, X}, {X, X}, {X, X}, {i1+k1+m1+o1, j1+l1+n1+p1}, {X, X}, {X, X}, {X, X}]
//
// [{a2+c2+e2+g2, b2+d2+f2+h2}, {X, X}, {X, X}, {X, X}, {i2+k2+m2+o2, j2+l2+n2+p2}, {X, X}, {X, X}, {X, X}]]
func (eval Evaluator) InnerSum(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) {
lehugueni marked this conversation as resolved.
Show resolved Hide resolved
N := ctIn.Slots()
l := n * batchSize
lehugueni marked this conversation as resolved.
Show resolved Hide resolved

if n <= 0 || batchSize <= 0 {
return fmt.Errorf("innersum: invalid parameter (n <= 0 or batchSize <= 0)")
}
if l > N {
return fmt.Errorf("innersum: invalid parameters (n*batchSize=%d > #slots=%d)", l, N)
}
if l&(l-1) != 0 {
return fmt.Errorf("innersum: invalid parameters (n*batchSize=%d does not divide #slots=%d)", l, N)
}

if l == N {
if n == 1 {
opOut.Copy(ctIn)
return
}

if err = eval.Evaluator.PartialTracesSum(ctIn, batchSize, n/2, opOut); err != nil {
return
}

ctTmp := &rlwe.Ciphertext{Element: rlwe.Element[ring.Poly]{Value: []ring.Poly{eval.BuffQP[2].Q, eval.BuffQP[3].Q}}}
ctTmp.MetaData = opOut.MetaData
if err = eval.RotateRows(opOut, ctTmp); err != nil {
return
}

if err = eval.Add(opOut, ctTmp, opOut); err != nil {
return
}

return
}

err = eval.Evaluator.PartialTracesSum(ctIn, batchSize, n, opOut)
return
}

// RotateAndAdd computes the sum of pt_i, 0 <= i < n, where pt_i is the underlying plaintext rotated ([Evaluator.RotateRows]) by batchSize*i slots.
//
// Example: for batchSize=3, n=2, ctIn.Slots()=16:
//
// Input (recall that a BGV/BFV plaintext is represented as a 2 x ctIn.Slots()/2 matrix):
//
// [[a, b, c, d, e, f, g, h]
// [i, j, k, l, m, n, o, p]]
//
// Output:
//
// [[a, b, c, d, e, f, g, h] + [[d, e, f, g, h, a, b, c] = [[a+d, b+e, c+f, d+g, e+h, f+a, g+b, h+c]
// [i, j, k, l, m, n, o, p]] [l, m, n, o, p, i, j, k]] [i+l, j+m, k+n, l+o, m+p, n+i, o+j, p+k]]
//
// Calling RotateAndAdd(ctIn, 1, n, opOut) can be used to compute the inner sum of the first n slots of a plaintext.
func (eval Evaluator) RotateAndAdd(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) {
err = eval.Evaluator.PartialTracesSum(ctIn, batchSize, n, opOut)
return
}

// MatchScalesAndLevel updates the both input ciphertexts to ensures that their scale matches.
// To do so it computes t0 * a = opOut * b such that:
// - ct0.Scale * a = opOut.Scale: make the scales match.
Expand Down
2 changes: 1 addition & 1 deletion schemes/bgv/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (p Parameters) GaloisElementForRowRotation() uint64 {
// InnerSum operation with parameters batch and n.
func (p Parameters) GaloisElementsForInnerSum(batch, n int) (galEls []uint64) {
galEls = rlwe.GaloisElementsForInnerSum(p, batch, n)
if n > p.N()>>1 {
if n*batch > p.MaxSlots()>>1 {
galEls = append(galEls, p.GaloisElementForRowRotation())
}
return
Expand Down
Loading
Loading