Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CKKS square and power functions #100

Merged
merged 13 commits into from
Jul 15, 2020
8 changes: 8 additions & 0 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("neg", &CKKSVector::negate)
.def("neg_", &CKKSVector::negate_inplace)
.def("neg_inplace", &CKKSVector::negate_inplace)
.def("square", &CKKSVector::square)
.def("square_", &CKKSVector::square_inplace)
.def("square_inplace", &CKKSVector::square_inplace)
.def("pow", &CKKSVector::power)
.def("pow_", &CKKSVector::power_inplace)
.def("pow_inplace", &CKKSVector::power_inplace)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The python API have have this inplace operations terminated by a _ (e.g. add_) only and we aren't providing function_name_inplace kind of methods.

.def("add", &CKKSVector::add)
.def("add_", &CKKSVector::add_inplace)
.def("add_plain", py::overload_cast<double>(&CKKSVector::add_plain))
Expand Down Expand Up @@ -130,6 +136,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("mm_", &CKKSVector::matmul_plain_inplace)
// python arithmetic
.def("__neg__", &CKKSVector::negate)
.def("__pow__", &CKKSVector::power)
.def("__ipow__", &CKKSVector::power_inplace)
.def("__add__", &CKKSVector::add)
.def("__add__", py::overload_cast<double>(&CKKSVector::add_plain))
.def("__add__",
Expand Down
60 changes: 58 additions & 2 deletions tenseal/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,63 @@ CKKSVector& CKKSVector::negate_inplace() {
return *this;
}

CKKSVector CKKSVector::square() {
CKKSVector new_vector = *this;
new_vector.square_inplace();

return new_vector;
}

CKKSVector& CKKSVector::square_inplace() {
this->context->evaluator->square_inplace(this->ciphertext);

if (this->context->auto_relin()) {
this->context->evaluator->relinearize_inplace(
this->ciphertext, *this->context->relin_keys());
}

if (this->context->auto_rescale()) {
this->context->evaluator->rescale_to_next_inplace(this->ciphertext);
this->ciphertext.scale() = this->init_scale;
}

return *this;
}
Comment on lines +94 to +108
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not calling mul_inplace which should take care of relin/rescale?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling mul_inplace requires coping the ciphertext (i.e. this->mul_inplace(*this)).


CKKSVector CKKSVector::power(unsigned int power) {
CKKSVector new_vector = *this;
new_vector.power_inplace(power);

return new_vector;
}

CKKSVector& CKKSVector::power_inplace(unsigned int power) {
// if the power is zero, return a new encrypted vector of ones
if (power == 0) {
vector<double> ones(this->size(), 1);
*this = CKKSVector(this->context, ones, this->init_scale);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can here just replace the ciphertext (ciphertext = this->encrypt(...)) instead of creating a whole new CKKSVector.

return *this;
}

if (power == 1) {
return *this;
}

if (power == 2) {
this->square_inplace();
return *this;
}

int closest_power_of_2 = 1 << static_cast<int>(floor(log2(power)));
power -= closest_power_of_2;
if (power == 0) {
this->power_inplace(closest_power_of_2 / 2).square_inplace();
} else {
this->power_inplace(power).mul_inplace(this->power(closest_power_of_2));
philomath213 marked this conversation as resolved.
Show resolved Hide resolved
}
return *this;
}

CKKSVector CKKSVector::add(CKKSVector to_add) {
CKKSVector new_vector = *this;
new_vector.add_inplace(to_add);
Expand Down Expand Up @@ -454,8 +511,7 @@ CKKSVector& CKKSVector::polyval_inplace(const vector<double>& coefficients) {
x_squares.reserve(max_square + 1);
x_squares.push_back(x); // x
for (int i = 1; i <= max_square; i++) {
// TODO: use square
x.mul_inplace(x);
x.square_inplace();
x_squares.push_back(x); // x^(2^i)
}

Expand Down
12 changes: 12 additions & 0 deletions tenseal/tensors/ckksvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ class CKKSVector {
CKKSVector negate();
CKKSVector& negate_inplace();

/*
Compute the square of the CKKSVector.
*/
CKKSVector square();
CKKSVector& square_inplace();

/*
Compute the power of the CKKSVector with minimal multiplication depth.
*/
CKKSVector power(unsigned int power);
CKKSVector& power_inplace(unsigned int power);

/*
Encrypted evaluation function operates on two encrypted vectors and
returns a new CKKSVector which is the result of either addition,
Expand Down
94 changes: 94 additions & 0 deletions tests/tensors/test_ckks_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,100 @@ def test_negate_inplace(context, plain_vec):
assert _almost_equal(decrypted_result, expected, 1), "Decryption of vector is incorrect"


@pytest.mark.parametrize(
"plain_vec, power",
[
([], 2),
([0], 3),
([0, 1, -1, 2, -2], 0),
([1, -1, 2, -2], 1),
([1, -1, 2, -2], 2),
([1, -1, 2, -2], 3),
([1, -2, 3, -4], 1),
([1, -2, 3, -4], 2),
([1, -2, 3, -4], 3),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raising to the power of 4 should be also possible with the default parameters, we need a test-case for this to make sure the circuit is optimum. Same for the other tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this should work with power of 4 as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still missing a testcase of 4 here

],
)
def test_power(context, plain_vec, power):
ckks_vec = ts.ckks_vector(context, plain_vec)
expected = [np.power(v, power) for v in plain_vec]
new_vec = ckks_vec ** power
decrypted_result = new_vec.decrypt()
assert _almost_equal(decrypted_result, expected, 1), "Decryption of vector is incorrect"
assert _almost_equal(ckks_vec.decrypt(), plain_vec, 1), "Something went wrong in memory."


@pytest.mark.parametrize(
"plain_vec, power",
[
([], 2),
([0], 3),
([0, 1, -1, 2, -2], 0),
([1, -1, 2, -2], 1),
([1, -1, 2, -2], 2),
([1, -1, 2, -2], 3),
([1, -2, 3, -4], 4),
([1, -2, 3, -4], 1),
([1, -2, 3, -4], 2),
([1, -2, 3, -4], 3),
([1, -2, 3, -4], 4),
],
)
def test_power_inplace(context, plain_vec, power):
ckks_vec = ts.ckks_vector(context, plain_vec)
expected = [np.power(v, power) for v in plain_vec]
ckks_vec **= power
decrypted_result = ckks_vec.decrypt()
assert _almost_equal(decrypted_result, expected, 1), "Decryption of vector is incorrect"


@pytest.mark.parametrize(
"plain_vec",
[
[],
[0],
[1],
[2],
[1, -1, 2, -2],
[1, -1, 6, -2],
[2, 1, -2, -2],
[1, -2, 3, -4],
[3, -2, 5, -4],
[1, -4, 3, 5],
],
)
def test_square(context, plain_vec):
ckks_vec = ts.ckks_vector(context, plain_vec)
expected = [np.power(v, 2) for v in plain_vec]
new_vec = ckks_vec.square()
decrypted_result = new_vec.decrypt()
assert _almost_equal(decrypted_result, expected, 1), "Decryption of vector is incorrect"
assert _almost_equal(ckks_vec.decrypt(), plain_vec, 1), "Something went wrong in memory."


@pytest.mark.parametrize(
"plain_vec",
[
[],
[0],
[1],
[2],
[1, -1, 2, -2],
[1, -1, 6, -2],
[2, 1, -2, -2],
[1, -2, 3, -4],
[3, -2, 5, -4],
[1, -4, 3, 5],
],
)
def test_square_inplace(context, plain_vec):
ckks_vec = ts.ckks_vector(context, plain_vec)
expected = [np.power(v, 2) for v in plain_vec]
ckks_vec.square_()
decrypted_result = ckks_vec.decrypt()
assert _almost_equal(decrypted_result, expected, 1), "Decryption of vector is incorrect"


@pytest.mark.parametrize(
"vec1, vec2",
[
Expand Down