From b2533dc385471338a3eb4d4f808cfd08ee2cd333 Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Mon, 21 Dec 2020 16:55:53 +0200 Subject: [PATCH 1/7] add operation guard --- tenseal/cpp/tensors/bfvvector.cpp | 58 +++++++------ tenseal/cpp/tensors/bfvvector.h | 75 ++++++++-------- tenseal/cpp/tensors/ckkstensor.cpp | 47 +++++----- tenseal/cpp/tensors/ckkstensor.h | 68 ++++++++------- tenseal/cpp/tensors/ckksvector.cpp | 70 ++++++++------- tenseal/cpp/tensors/ckksvector.h | 78 +++++++++-------- tenseal/cpp/tensors/encrypted_tensor.h | 114 ++++++++++++++++++++----- tenseal/cpp/tensors/encrypted_vector.h | 72 ++++++++++------ 8 files changed, 353 insertions(+), 229 deletions(-) diff --git a/tenseal/cpp/tensors/bfvvector.cpp b/tenseal/cpp/tensors/bfvvector.cpp index 59c25bab..25b23cca 100644 --- a/tenseal/cpp/tensors/bfvvector.cpp +++ b/tenseal/cpp/tensors/bfvvector.cpp @@ -77,7 +77,7 @@ BFVVector::plain_t BFVVector::decrypt(const shared_ptr& sk) const { return vector(result.cbegin(), result.cbegin() + this->size()); } -shared_ptr BFVVector::power_inplace(unsigned int power) { +shared_ptr BFVVector::power_inplace_impl(unsigned int power) { // if the power is zero, return a new encrypted vector of ones if (power == 0) { vector ones(this->size(), 1); @@ -97,29 +97,29 @@ shared_ptr BFVVector::power_inplace(unsigned int power) { int closest_power_of_2 = 1 << static_cast(floor(log2(power))); power -= closest_power_of_2; if (power == 0) { - this->power_inplace(closest_power_of_2 / 2)->square_inplace(); + this->power_inplace_impl(closest_power_of_2 / 2)->square_inplace_impl(); } else { auto closest_pow2_vector = this->power(closest_power_of_2); - this->power_inplace(power)->mul_inplace(closest_pow2_vector); + this->power_inplace_impl(power)->mul_inplace_impl(closest_pow2_vector); } return shared_from_this(); } -shared_ptr BFVVector::negate_inplace() { +shared_ptr BFVVector::negate_inplace_impl() { this->tenseal_context()->evaluator->negate_inplace(this->_ciphertext); return shared_from_this(); } -shared_ptr BFVVector::square_inplace() { +shared_ptr BFVVector::square_inplace_impl() { this->tenseal_context()->evaluator->square_inplace(this->_ciphertext); this->auto_relin(_ciphertext); return shared_from_this(); } -shared_ptr BFVVector::add_inplace( +shared_ptr BFVVector::add_inplace_impl( const shared_ptr& other) { auto to_add = other->copy(); if (!this->tenseal_context()->equals(to_add->tenseal_context())) { @@ -136,7 +136,7 @@ shared_ptr BFVVector::add_inplace( return shared_from_this(); } -shared_ptr BFVVector::sub_inplace( +shared_ptr BFVVector::sub_inplace_impl( const shared_ptr& other) { auto to_sub = other->copy(); if (!this->tenseal_context()->equals(to_sub->tenseal_context())) { @@ -153,7 +153,7 @@ shared_ptr BFVVector::sub_inplace( return shared_from_this(); } -shared_ptr BFVVector::mul_inplace( +shared_ptr BFVVector::mul_inplace_impl( const shared_ptr& other) { auto to_mul = other->copy(); if (!this->tenseal_context()->equals(to_mul->tenseal_context())) { @@ -171,34 +171,34 @@ shared_ptr BFVVector::mul_inplace( return shared_from_this(); } -shared_ptr BFVVector::dot_product_inplace( +shared_ptr BFVVector::dot_product_inplace_impl( const shared_ptr& to_mul) { - this->mul_inplace(to_mul); - this->sum_inplace(); + this->mul_inplace_impl(to_mul); + this->sum_inplace_impl(); return shared_from_this(); } -shared_ptr BFVVector::dot_product_plain_inplace( +shared_ptr BFVVector::dot_product_plain_inplace_impl( const BFVVector::plain_t& to_mul) { - this->mul_plain_inplace(to_mul); - this->sum_inplace(); + this->mul_plain_inplace_impl(to_mul); + this->sum_inplace_impl(); return shared_from_this(); } -shared_ptr BFVVector::sum_inplace(size_t /*axis=0*/) { +shared_ptr BFVVector::sum_inplace_impl(size_t /*axis=0*/) { sum_vector(this->tenseal_context(), this->_ciphertext, this->size()); this->_size = 1; return shared_from_this(); } -shared_ptr BFVVector::add_plain_inplace( +shared_ptr BFVVector::add_plain_inplace_impl( const plain_t::dtype& to_add) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::add_plain_inplace( +shared_ptr BFVVector::add_plain_inplace_impl( const BFVVector::plain_t& to_add) { if (this->size() != to_add.size()) { throw invalid_argument("can't add vectors of different sizes"); @@ -213,12 +213,12 @@ shared_ptr BFVVector::add_plain_inplace( return shared_from_this(); } -shared_ptr BFVVector::sub_plain_inplace( +shared_ptr BFVVector::sub_plain_inplace_impl( const plain_t::dtype& to_sub) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::sub_plain_inplace( +shared_ptr BFVVector::sub_plain_inplace_impl( const BFVVector::plain_t& to_sub) { if (this->size() != to_sub.size()) { throw invalid_argument("can't sub vectors of different sizes"); @@ -233,12 +233,12 @@ shared_ptr BFVVector::sub_plain_inplace( return shared_from_this(); } -shared_ptr BFVVector::mul_plain_inplace( +shared_ptr BFVVector::mul_plain_inplace_impl( const plain_t::dtype& to_sub) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::mul_plain_inplace( +shared_ptr BFVVector::mul_plain_inplace_impl( const BFVVector::plain_t& to_mul) { if (this->size() != to_mul.size()) { throw invalid_argument("can't multiply vectors of different sizes"); @@ -262,31 +262,31 @@ shared_ptr BFVVector::mul_plain_inplace( return shared_from_this(); } -shared_ptr BFVVector::matmul_plain_inplace( +shared_ptr BFVVector::matmul_plain_inplace_impl( const BFVVector::plain_t& matrix, size_t n_jobs) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::polyval_inplace( +shared_ptr BFVVector::polyval_inplace_impl( const vector& coefficients) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::conv2d_im2col_inplace( +shared_ptr BFVVector::conv2d_im2col_inplace_impl( const BFVVector::plain_t& kernel, const size_t windows_nb) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::enc_matmul_plain_inplace( +shared_ptr BFVVector::enc_matmul_plain_inplace_impl( const BFVVector::plain_t& plain_vec, const size_t rows_nb) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::replicate_first_slot_inplace(size_t n) { +shared_ptr BFVVector::replicate_first_slot_inplace_impl(size_t n) { // mask vector mask(this->_size, 0); mask[0] = 1; - this->mul_plain_inplace(mask); + this->mul_plain_inplace_impl(mask); // replicate Ciphertext tmp = this->_ciphertext; @@ -353,4 +353,8 @@ shared_ptr BFVVector::deepcopy() const { BFVVectorProto vec = this->save_proto(); return BFVVector::Create(ctx, vec); } + +bool BFVVector::_check_operation_sanity(){ + return true; +} } // namespace tenseal diff --git a/tenseal/cpp/tensors/bfvvector.h b/tenseal/cpp/tensors/bfvvector.h index f5f6db01..8dfdef75 100644 --- a/tenseal/cpp/tensors/bfvvector.h +++ b/tenseal/cpp/tensors/bfvvector.h @@ -28,36 +28,49 @@ class BFVVector return shared_ptr( new BFVVector(std::forward(args)...)); } - /** * Decrypts and returns the plaintext representation of the encrypted vector *of integers using the secret-key. **/ plain_t decrypt(const shared_ptr& sk) const override; + /** + * Load/Save the vector from/to a serialized protobuffer. + **/ + void load(const string& vec) override; + string save() const override; + /** + *Recreates a new BFVVector from the current one, without any + *pointer/reference to this one. + * **/ + encrypted_t copy() const override; + encrypted_t deepcopy() const override; + double scale() const override { throw logic_error("not implemented"); } + + protected: /** * Compute the power of the BFVVector with minimal multiplication depth. **/ - encrypted_t power_inplace(unsigned int power) override; + encrypted_t power_inplace_impl(unsigned int power) override; /** * Negates a BFVVector. **/ - encrypted_t negate_inplace() override; + encrypted_t negate_inplace_impl() override; /** * Compute the square of the BFVVector. **/ - encrypted_t square_inplace() override; + encrypted_t square_inplace_impl() override; /** * Encrypted evaluation function operates on two encrypted vectors and * returns a new BFVVector which is the result of either *addition, substraction or multiplication in an element-wise fashion. *in_place functions return a reference to the same object. **/ - encrypted_t add_inplace(const encrypted_t& to_add) override; - encrypted_t sub_inplace(const encrypted_t& to_sub) override; - encrypted_t mul_inplace(const encrypted_t& to_mul) override; - encrypted_t dot_product_inplace(const encrypted_t& to_mul) override; - encrypted_t dot_product_plain_inplace(const plain_t& to_mul) override; - encrypted_t sum_inplace(size_t axis = 0) override; + encrypted_t add_inplace_impl(const encrypted_t& to_add) override; + encrypted_t sub_inplace_impl(const encrypted_t& to_sub) override; + encrypted_t mul_inplace_impl(const encrypted_t& to_mul) override; + encrypted_t dot_product_inplace_impl(const encrypted_t& to_mul) override; + encrypted_t dot_product_plain_inplace_impl(const plain_t& to_mul) override; + encrypted_t sum_inplace_impl(size_t axis = 0) override; /** * Plain evaluation function operates on an encrypted vector and plaintext @@ -65,56 +78,48 @@ class BFVVector * either addition, substraction or multiplication in an element-wise *fashion. in_place functions return a reference to the same object. **/ - encrypted_t add_plain_inplace(const plain_t::dtype& to_add) override; - encrypted_t add_plain_inplace(const plain_t& to_add) override; - encrypted_t sub_plain_inplace(const plain_t::dtype& to_sub) override; - encrypted_t sub_plain_inplace(const plain_t& to_sub) override; - encrypted_t mul_plain_inplace(const plain_t::dtype& to_mul) override; - encrypted_t mul_plain_inplace(const plain_t& to_mul) override; + encrypted_t add_plain_inplace_impl(const plain_t::dtype& to_add) override; + encrypted_t add_plain_inplace_impl(const plain_t& to_add) override; + encrypted_t sub_plain_inplace_impl(const plain_t::dtype& to_sub) override; + encrypted_t sub_plain_inplace_impl(const plain_t& to_sub) override; + encrypted_t mul_plain_inplace_impl(const plain_t::dtype& to_mul) override; + encrypted_t mul_plain_inplace_impl(const plain_t& to_mul) override; /** * Encrypted Vector multiplication with plain matrix. **/ - encrypted_t matmul_plain_inplace(const plain_t& matrix, - size_t n_jobs = 0) override; + encrypted_t matmul_plain_inplace_impl(const plain_t& matrix, + size_t n_jobs = 0) override; /** * Encrypted Matrix multiplication with plain vector. **/ - encrypted_t enc_matmul_plain_inplace(const plain_t& plain_vec, - size_t row_size) override; + encrypted_t enc_matmul_plain_inplace_impl(const plain_t& plain_vec, + size_t row_size) override; /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * *x^i **/ - encrypted_t polyval_inplace(const vector& coefficients) override; + encrypted_t polyval_inplace_impl( + const vector& coefficients) override; /* * Image Block to Columns. * The input matrix should be encoded in a vertical scan (column-major). * The kernel vector should be padded with zeros to the next power of 2 */ - encrypted_t conv2d_im2col_inplace(const plain_t& kernel, - const size_t windows_nb) override; + encrypted_t conv2d_im2col_inplace_impl(const plain_t& kernel, + const size_t windows_nb) override; /** * Replicate the first slot of a ciphertext n times. Requires a *multiplication. **/ - encrypted_t replicate_first_slot_inplace(size_t n) override; - /** - * Load/Save the vector from/to a serialized protobuffer. - **/ - void load(const string& vec) override; - string save() const override; + encrypted_t replicate_first_slot_inplace_impl(size_t n) override; /** - *Recreates a new BFVVector from the current one, without any - *pointer/reference to this one. + * Check tensor sanity * **/ - encrypted_t copy() const override; - encrypted_t deepcopy() const override; - - double scale() const override { throw logic_error("not implemented"); } + bool _check_operation_sanity() override; private: BFVVector(const shared_ptr& ctx, const plain_t& vec); diff --git a/tenseal/cpp/tensors/ckkstensor.cpp b/tenseal/cpp/tensors/ckkstensor.cpp index 27551e3e..a19f9ae0 100644 --- a/tenseal/cpp/tensors/ckkstensor.cpp +++ b/tenseal/cpp/tensors/ckkstensor.cpp @@ -118,13 +118,13 @@ PlainTensor CKKSTensor::decrypt(const shared_ptr& sk) const { } } -shared_ptr CKKSTensor::negate_inplace() { +shared_ptr CKKSTensor::negate_inplace_impl() { for (auto& ct : _data) this->tenseal_context()->evaluator->negate_inplace(ct); return shared_from_this(); } -shared_ptr CKKSTensor::square_inplace() { +shared_ptr CKKSTensor::square_inplace_impl() { for (auto& ct : _data) { this->tenseal_context()->evaluator->square_inplace(ct); this->auto_relin(ct); @@ -133,7 +133,7 @@ shared_ptr CKKSTensor::square_inplace() { return shared_from_this(); } -shared_ptr CKKSTensor::power_inplace(unsigned int power) { +shared_ptr CKKSTensor::power_inplace_impl(unsigned int power) { if (power == 0) { auto ones = PlainTensor::repeat_value(1, this->shape()); *this = CKKSTensor(this->tenseal_context(), ones, this->_init_scale, @@ -146,17 +146,17 @@ shared_ptr CKKSTensor::power_inplace(unsigned int power) { } if (power == 2) { - this->square_inplace(); + this->square_inplace_impl(); return shared_from_this(); } int closest_power_of_2 = 1 << static_cast(floor(log2(power))); power -= closest_power_of_2; if (power == 0) { - this->power_inplace(closest_power_of_2 / 2)->square_inplace(); + this->power_inplace_impl(closest_power_of_2 / 2)->square_inplace_impl(); } else { auto closest_pow2_vector = this->power(closest_power_of_2); - this->power_inplace(power)->mul_inplace(closest_pow2_vector); + this->power_inplace_impl(power)->mul_inplace_impl(closest_pow2_vector); } return shared_from_this(); @@ -349,61 +349,61 @@ shared_ptr CKKSTensor::op_plain_inplace(const double& operand, return shared_from_this(); } -shared_ptr CKKSTensor::add_inplace( +shared_ptr CKKSTensor::add_inplace_impl( const shared_ptr& to_add) { return this->op_inplace(to_add, OP::ADD); } -shared_ptr CKKSTensor::sub_inplace( +shared_ptr CKKSTensor::sub_inplace_impl( const shared_ptr& to_sub) { return this->op_inplace(to_sub, OP::SUB); } -shared_ptr CKKSTensor::mul_inplace( +shared_ptr CKKSTensor::mul_inplace_impl( const shared_ptr& to_mul) { return this->op_inplace(to_mul, OP::MUL); } -shared_ptr CKKSTensor::dot_product_inplace( +shared_ptr CKKSTensor::dot_product_inplace_impl( const shared_ptr& to_mul) { // TODO return shared_from_this(); } -shared_ptr CKKSTensor::add_plain_inplace( +shared_ptr CKKSTensor::add_plain_inplace_impl( const PlainTensor& to_add) { return this->op_plain_inplace(to_add, OP::ADD); } -shared_ptr CKKSTensor::sub_plain_inplace( +shared_ptr CKKSTensor::sub_plain_inplace_impl( const PlainTensor& to_sub) { return this->op_plain_inplace(to_sub, OP::SUB); } -shared_ptr CKKSTensor::mul_plain_inplace( +shared_ptr CKKSTensor::mul_plain_inplace_impl( const PlainTensor& to_mul) { return this->op_plain_inplace(to_mul, OP::MUL); } -shared_ptr CKKSTensor::dot_product_plain_inplace( +shared_ptr CKKSTensor::dot_product_plain_inplace_impl( const PlainTensor& to_mul) { // TODO return shared_from_this(); } -shared_ptr CKKSTensor::add_plain_inplace(const double& to_add) { +shared_ptr CKKSTensor::add_plain_inplace_impl(const double& to_add) { return this->op_plain_inplace(to_add, OP::ADD); } -shared_ptr CKKSTensor::sub_plain_inplace(const double& to_sub) { +shared_ptr CKKSTensor::sub_plain_inplace_impl(const double& to_sub) { return this->op_plain_inplace(to_sub, OP::SUB); } -shared_ptr CKKSTensor::mul_plain_inplace(const double& to_mul) { +shared_ptr CKKSTensor::mul_plain_inplace_impl(const double& to_mul) { return this->op_plain_inplace(to_mul, OP::MUL); } -shared_ptr CKKSTensor::sum_inplace(size_t axis) { +shared_ptr CKKSTensor::sum_inplace_impl(size_t axis) { if (axis >= shape_with_batch().size()) throw invalid_argument("invalid axis"); @@ -454,7 +454,7 @@ shared_ptr CKKSTensor::sum_batch_inplace() { return shared_from_this(); } -shared_ptr CKKSTensor::polyval_inplace( +shared_ptr CKKSTensor::polyval_inplace_impl( const vector& coefficients) { if (coefficients.size() == 0) { throw invalid_argument( @@ -485,7 +485,7 @@ shared_ptr CKKSTensor::polyval_inplace( x_squares.reserve(max_square + 1); x_squares.push_back(x->copy()); // x for (int i = 1; i <= max_square; i++) { - x->square_inplace(); + x->square_inplace_impl(); x_squares.push_back(x->copy()); // x^(2^i) } @@ -499,7 +499,7 @@ shared_ptr CKKSTensor::polyval_inplace( for (int i = 1; i <= degree; i++) { if (coefficients[i] == 0.0) continue; x = compute_polynomial_term(i, coefficients[i], x_squares); - result->add_inplace(x); + result->add_inplace_impl(x); } this->_data = TensorStorage(result->data(), result->shape()); @@ -604,4 +604,9 @@ shared_ptr CKKSTensor::reshape_inplace( } double CKKSTensor::scale() const { return _init_scale; } + +bool CKKSTensor::_check_operation_sanity(){ + return true; +} + } // namespace tenseal diff --git a/tenseal/cpp/tensors/ckkstensor.h b/tenseal/cpp/tensors/ckkstensor.h index 24cda956..ecafb6cc 100644 --- a/tenseal/cpp/tensors/ckkstensor.h +++ b/tenseal/cpp/tensors/ckkstensor.h @@ -28,41 +28,11 @@ class CKKSTensor : public EncryptedTensor>, PlainTensor decrypt(const shared_ptr& sk) const override; - shared_ptr negate_inplace() override; - shared_ptr square_inplace() override; - shared_ptr power_inplace(unsigned int power) override; - - shared_ptr add_inplace( - const shared_ptr& to_add) override; - shared_ptr sub_inplace( - const shared_ptr& to_sub) override; - shared_ptr mul_inplace( - const shared_ptr& to_mul) override; - shared_ptr dot_product_inplace( - const shared_ptr& to_mul) override; - - shared_ptr add_plain_inplace(const double& to_add) override; - shared_ptr sub_plain_inplace(const double& to_sub) override; - shared_ptr mul_plain_inplace(const double& to_mul) override; - - shared_ptr add_plain_inplace( - const PlainTensor& to_add) override; - shared_ptr sub_plain_inplace( - const PlainTensor& to_sub) override; - shared_ptr mul_plain_inplace( - const PlainTensor& to_mul) override; - shared_ptr dot_product_plain_inplace( - const PlainTensor& to_mul) override; - - shared_ptr sum_inplace(size_t axis = 0) override; shared_ptr sum_batch() { return this->copy()->sum_batch_inplace(); } shared_ptr sum_batch_inplace(); - shared_ptr polyval_inplace( - const vector& coefficients) override; - void load(const string& vec) override; string save() const override; @@ -76,6 +46,44 @@ class CKKSTensor : public EncryptedTensor>, vector shape_with_batch() const; double scale() const override; + protected: + shared_ptr negate_inplace_impl() override; + shared_ptr square_inplace_impl() override; + shared_ptr power_inplace_impl(unsigned int power) override; + + shared_ptr add_inplace_impl( + const shared_ptr& to_add) override; + shared_ptr sub_inplace_impl( + const shared_ptr& to_sub) override; + shared_ptr mul_inplace_impl( + const shared_ptr& to_mul) override; + shared_ptr dot_product_inplace_impl( + const shared_ptr& to_mul) override; + + shared_ptr add_plain_inplace_impl( + const double& to_add) override; + shared_ptr sub_plain_inplace_impl( + const double& to_sub) override; + shared_ptr mul_plain_inplace_impl( + const double& to_mul) override; + + shared_ptr add_plain_inplace_impl( + const PlainTensor& to_add) override; + shared_ptr sub_plain_inplace_impl( + const PlainTensor& to_sub) override; + shared_ptr mul_plain_inplace_impl( + const PlainTensor& to_mul) override; + shared_ptr dot_product_plain_inplace_impl( + const PlainTensor& to_mul) override; + + shared_ptr sum_inplace_impl(size_t axis = 0) override; + shared_ptr polyval_inplace_impl( + const vector& coefficients) override; + /** + * Check tensor sanity + * **/ + bool _check_operation_sanity() override; + private: TensorStorage _data; double _init_scale; diff --git a/tenseal/cpp/tensors/ckksvector.cpp b/tenseal/cpp/tensors/ckksvector.cpp index 7cd05e2a..84333ef1 100644 --- a/tenseal/cpp/tensors/ckksvector.cpp +++ b/tenseal/cpp/tensors/ckksvector.cpp @@ -82,7 +82,7 @@ CKKSVector::plain_t CKKSVector::decrypt(const shared_ptr& sk) const { return vector(result.cbegin(), result.cbegin() + this->size()); } -shared_ptr CKKSVector::power_inplace(unsigned int power) { +shared_ptr CKKSVector::power_inplace_impl(unsigned int power) { // if the power is zero, return a new encrypted vector of ones if (power == 0) { vector ones(this->size(), 1); @@ -95,29 +95,29 @@ shared_ptr CKKSVector::power_inplace(unsigned int power) { } if (power == 2) { - this->square_inplace(); + this->square_inplace_impl(); return shared_from_this(); } int closest_power_of_2 = 1 << static_cast(floor(log2(power))); power -= closest_power_of_2; if (power == 0) { - this->power_inplace(closest_power_of_2 / 2)->square_inplace(); + this->power_inplace_impl(closest_power_of_2 / 2)->square_inplace_impl(); } else { auto closest_pow2_vector = this->power(closest_power_of_2); - this->power_inplace(power)->mul_inplace(closest_pow2_vector); + this->power_inplace_impl(power)->mul_inplace_impl(closest_pow2_vector); } return shared_from_this(); } -shared_ptr CKKSVector::negate_inplace() { +shared_ptr CKKSVector::negate_inplace_impl() { this->tenseal_context()->evaluator->negate_inplace(this->_ciphertext); return shared_from_this(); } -shared_ptr CKKSVector::square_inplace() { +shared_ptr CKKSVector::square_inplace_impl() { this->tenseal_context()->evaluator->square_inplace(_ciphertext); this->auto_relin(_ciphertext); this->auto_rescale(_ciphertext); @@ -125,7 +125,7 @@ shared_ptr CKKSVector::square_inplace() { return shared_from_this(); } -shared_ptr CKKSVector::add_inplace( +shared_ptr CKKSVector::add_inplace_impl( const shared_ptr& other) { auto to_add = other; if (!this->tenseal_context()->equals(to_add->tenseal_context())) { @@ -143,7 +143,7 @@ shared_ptr CKKSVector::add_inplace( return shared_from_this(); } -shared_ptr CKKSVector::sub_inplace( +shared_ptr CKKSVector::sub_inplace_impl( const shared_ptr& other) { auto to_sub = other; if (!this->tenseal_context()->equals(to_sub->tenseal_context())) { @@ -161,7 +161,7 @@ shared_ptr CKKSVector::sub_inplace( return shared_from_this(); } -shared_ptr CKKSVector::mul_inplace( +shared_ptr CKKSVector::mul_inplace_impl( const shared_ptr& other) { auto to_mul = other; if (!this->tenseal_context()->equals(to_mul->tenseal_context())) { @@ -182,36 +182,36 @@ shared_ptr CKKSVector::mul_inplace( return shared_from_this(); } -shared_ptr CKKSVector::dot_product_inplace( +shared_ptr CKKSVector::dot_product_inplace_impl( const shared_ptr& to_mul) { - this->mul_inplace(to_mul); - this->sum_inplace(); + this->mul_inplace_impl(to_mul); + this->sum_inplace_impl(); return shared_from_this(); } -shared_ptr CKKSVector::dot_product_plain_inplace( +shared_ptr CKKSVector::dot_product_plain_inplace_impl( const plain_t& to_mul) { - this->mul_plain_inplace(to_mul); - this->sum_inplace(); + this->mul_plain_inplace_impl(to_mul); + this->sum_inplace_impl(); return shared_from_this(); } -shared_ptr CKKSVector::sum_inplace(size_t /*axis = 0*/) { +shared_ptr CKKSVector::sum_inplace_impl(size_t /*axis = 0*/) { sum_vector(this->tenseal_context(), this->_ciphertext, this->size()); this->_size = 1; return shared_from_this(); } -shared_ptr CKKSVector::add_plain_inplace(const plain_t& to_add) { +shared_ptr CKKSVector::add_plain_inplace_impl(const plain_t& to_add) { if (this->size() != to_add.size()) { throw invalid_argument("can't add vectors of different sizes"); } return this->_add_plain_inplace(to_add.data()); } -shared_ptr CKKSVector::add_plain_inplace(const double& to_add) { +shared_ptr CKKSVector::add_plain_inplace_impl(const double& to_add) { return this->_add_plain_inplace(to_add); } @@ -226,14 +226,14 @@ shared_ptr CKKSVector::_add_plain_inplace(const T& to_add) { return shared_from_this(); } -shared_ptr CKKSVector::sub_plain_inplace(const plain_t& to_sub) { +shared_ptr CKKSVector::sub_plain_inplace_impl(const plain_t& to_sub) { if (this->size() != to_sub.size()) { throw invalid_argument("can't sub vectors of different sizes"); } return this->_sub_plain_inplace(to_sub.data()); } -shared_ptr CKKSVector::sub_plain_inplace(const double& to_sub) { +shared_ptr CKKSVector::sub_plain_inplace_impl(const double& to_sub) { return this->_sub_plain_inplace(to_sub); } @@ -250,7 +250,7 @@ shared_ptr CKKSVector::_sub_plain_inplace(const T& to_sub) { return shared_from_this(); } -shared_ptr CKKSVector::mul_plain_inplace(const plain_t& to_mul) { +shared_ptr CKKSVector::mul_plain_inplace_impl(const plain_t& to_mul) { if (this->size() != to_mul.size()) { throw invalid_argument("can't multiply vectors of different sizes"); } @@ -258,7 +258,7 @@ shared_ptr CKKSVector::mul_plain_inplace(const plain_t& to_mul) { return this->_mul_plain_inplace(to_mul.data()); } -shared_ptr CKKSVector::mul_plain_inplace(const double& to_mul) { +shared_ptr CKKSVector::mul_plain_inplace_impl(const double& to_mul) { return this->_mul_plain_inplace(to_mul); } @@ -288,7 +288,7 @@ shared_ptr CKKSVector::_mul_plain_inplace(const T& to_mul) { return this->copy(); } -shared_ptr CKKSVector::matmul_plain_inplace( +shared_ptr CKKSVector::matmul_plain_inplace_impl( const CKKSVector::plain_t& matrix, size_t n_jobs) { this->_ciphertext = this->diagonal_ct_vector_matmul(matrix, n_jobs); @@ -298,7 +298,7 @@ shared_ptr CKKSVector::matmul_plain_inplace( return shared_from_this(); } -shared_ptr CKKSVector::polyval_inplace( +shared_ptr CKKSVector::polyval_inplace_impl( const vector& coefficients) { if (coefficients.size() == 0) { throw invalid_argument( @@ -335,7 +335,7 @@ shared_ptr CKKSVector::polyval_inplace( x_squares.reserve(max_square + 1); x_squares.push_back(x->copy()); // x for (int i = 1; i <= max_square; i++) { - x->square_inplace(); + x->square_inplace_impl(); x_squares.push_back(x->copy()); // x^(2^i) } @@ -343,14 +343,14 @@ shared_ptr CKKSVector::polyval_inplace( for (int i = 1; i <= degree; i++) { if (coefficients[i] == 0.0) continue; x = compute_polynomial_term(i, coefficients[i], x_squares); - result->add_inplace(x); + result->add_inplace_impl(x); } this->_ciphertext = result->ciphertext(); return shared_from_this(); } -shared_ptr CKKSVector::conv2d_im2col_inplace( +shared_ptr CKKSVector::conv2d_im2col_inplace_impl( const CKKSVector::plain_t& kernel, const size_t windows_nb) { if (windows_nb == 0) { throw invalid_argument("Windows number can't be zero"); @@ -362,11 +362,11 @@ shared_ptr CKKSVector::conv2d_im2col_inplace( // flat the kernel auto flatten_kernel = kernel.horizontal_scan(); - this->enc_matmul_plain_inplace(flatten_kernel, windows_nb); + this->enc_matmul_plain_inplace_impl(flatten_kernel, windows_nb); return shared_from_this(); } -shared_ptr CKKSVector::enc_matmul_plain_inplace( +shared_ptr CKKSVector::enc_matmul_plain_inplace_impl( const CKKSVector::plain_t& plain_vec, const size_t rows_nb) { if (plain_vec.empty()) { throw invalid_argument("Plain vector can't be empty"); @@ -400,7 +400,7 @@ shared_ptr CKKSVector::enc_matmul_plain_inplace( replicate_vector(new_plain_vec, slot_count); this->_size = slot_count; - this->mul_plain_inplace(new_plain_vec); + this->mul_plain_inplace_impl(new_plain_vec); auto galois_keys = this->tenseal_context()->galois_keys(); @@ -412,7 +412,7 @@ shared_ptr CKKSVector::enc_matmul_plain_inplace( 1 << (static_cast(ceil(log2(chunks_nb))) - 1)); tmp->rotate_vector_inplace(static_cast(rows_nb * chunks_nb), *galois_keys); - this->add_inplace(tmp); + this->add_inplace_impl(tmp); } this->_size = rows_nb; @@ -420,11 +420,11 @@ shared_ptr CKKSVector::enc_matmul_plain_inplace( return shared_from_this(); } -shared_ptr CKKSVector::replicate_first_slot_inplace(size_t n) { +shared_ptr CKKSVector::replicate_first_slot_inplace_impl(size_t n) { // mask vector mask(this->_size, 0); mask[0] = 1; - this->mul_plain_inplace(mask); + this->mul_plain_inplace_impl(mask); // replicate Ciphertext tmp = this->_ciphertext; @@ -491,4 +491,8 @@ shared_ptr CKKSVector::deepcopy() const { return CKKSVector::Create(ctx, vec); } +bool CKKSVector::_check_operation_sanity(){ + return true; +} + } // namespace tenseal diff --git a/tenseal/cpp/tensors/ckksvector.h b/tenseal/cpp/tensors/ckksvector.h index 792de59d..7e2b5170 100644 --- a/tenseal/cpp/tensors/ckksvector.h +++ b/tenseal/cpp/tensors/ckksvector.h @@ -30,31 +30,46 @@ class CKKSVector *of real numbers using the secret-key. **/ plain_t decrypt(const shared_ptr& sk) const override; + /** + * Load/Save the vector from/to a serialized protobuffer. + **/ + void load(const string& vec) override; + string save() const override; + + /** + *Recreates a new CKKSVector from the current one, without any + *pointer/reference to this one. + **/ + encrypted_t copy() const override; + encrypted_t deepcopy() const override; + + double scale() const override { return _init_scale; } + protected: /** * Compute the power of the CKKSVector with minimal multiplication depth. **/ - encrypted_t power_inplace(unsigned int power) override; + encrypted_t power_inplace_impl(unsigned int power) override; /** * Negates a CKKSVector. **/ - encrypted_t negate_inplace() override; + encrypted_t negate_inplace_impl() override; /** * Compute the square of the CKKSVector. **/ - encrypted_t square_inplace() override; + encrypted_t square_inplace_impl() override; /** * Encrypted evaluation function operates on two encrypted vectors and * returns a new CKKSVector which is the result of either *addition, substraction or multiplication in an element-wise fashion. *in_place functions return a reference to the same object. **/ - encrypted_t add_inplace(const encrypted_t& to_add) override; - encrypted_t sub_inplace(const encrypted_t& to_sub) override; - encrypted_t mul_inplace(const encrypted_t& to_mul) override; - encrypted_t dot_product_inplace(const encrypted_t& to_mul) override; - encrypted_t dot_product_plain_inplace(const plain_t& to_mul) override; - encrypted_t sum_inplace(size_t axis = 0) override; + encrypted_t add_inplace_impl(const encrypted_t& to_add) override; + encrypted_t sub_inplace_impl(const encrypted_t& to_sub) override; + encrypted_t mul_inplace_impl(const encrypted_t& to_mul) override; + encrypted_t dot_product_inplace_impl(const encrypted_t& to_mul) override; + encrypted_t dot_product_plain_inplace_impl(const plain_t& to_mul) override; + encrypted_t sum_inplace_impl(size_t axis = 0) override; /** * Plain evaluation function operates on an encrypted vector and plaintext @@ -62,58 +77,49 @@ class CKKSVector * either addition, substraction or multiplication in an element-wise *fashion. in_place functions return a reference to the same object. **/ - encrypted_t add_plain_inplace(const plain_t::dtype& to_add) override; - encrypted_t add_plain_inplace(const plain_t& to_add) override; - encrypted_t sub_plain_inplace(const plain_t::dtype& to_sub) override; - encrypted_t sub_plain_inplace(const plain_t& to_sub) override; - encrypted_t mul_plain_inplace(const plain_t::dtype& to_mul) override; - encrypted_t mul_plain_inplace(const plain_t& to_mul) override; + encrypted_t add_plain_inplace_impl(const plain_t::dtype& to_add) override; + encrypted_t add_plain_inplace_impl(const plain_t& to_add) override; + encrypted_t sub_plain_inplace_impl(const plain_t::dtype& to_sub) override; + encrypted_t sub_plain_inplace_impl(const plain_t& to_sub) override; + encrypted_t mul_plain_inplace_impl(const plain_t::dtype& to_mul) override; + encrypted_t mul_plain_inplace_impl(const plain_t& to_mul) override; /** * Encrypted Vector multiplication with plain matrix. **/ - encrypted_t matmul_plain_inplace(const plain_t& matrix, - size_t n_jobs = 0) override; + encrypted_t matmul_plain_inplace_impl(const plain_t& matrix, + size_t n_jobs = 0) override; /** * Encrypted Matrix multiplication with plain vector. **/ - encrypted_t enc_matmul_plain_inplace(const plain_t& plain_vec, - size_t row_size) override; + encrypted_t enc_matmul_plain_inplace_impl(const plain_t& plain_vec, + size_t row_size) override; /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * *x^i **/ - encrypted_t polyval_inplace(const vector& coefficients) override; + encrypted_t polyval_inplace_impl( + const vector& coefficients) override; /* * Image Block to Columns. * The input matrix should be encoded in a vertical scan (column-major). * The kernel vector should be padded with zeros to the next power of 2 */ - encrypted_t conv2d_im2col_inplace(const plain_t& kernel, - const size_t windows_nb) override; + encrypted_t conv2d_im2col_inplace_impl(const plain_t& kernel, + const size_t windows_nb) override; /** * Replicate the first slot of a ciphertext n times. Requires a *multiplication. **/ - encrypted_t replicate_first_slot_inplace(size_t n) override; + encrypted_t replicate_first_slot_inplace_impl(size_t n) override; /** - * Load/Save the vector from/to a serialized protobuffer. - **/ - void load(const string& vec) override; - string save() const override; - - /** - *Recreates a new CKKSVector from the current one, without any - *pointer/reference to this one. - **/ - encrypted_t copy() const override; - encrypted_t deepcopy() const override; - - double scale() const override { return _init_scale; } + * Check tensor sanity + * **/ + bool _check_operation_sanity() override; private: double _init_scale; diff --git a/tenseal/cpp/tensors/encrypted_tensor.h b/tenseal/cpp/tensors/encrypted_tensor.h index c9fa083a..f62ef306 100644 --- a/tenseal/cpp/tensors/encrypted_tensor.h +++ b/tenseal/cpp/tensors/encrypted_tensor.h @@ -38,17 +38,22 @@ class EncryptedTensor { }; virtual PlainTensor decrypt( const shared_ptr& sk) const = 0; - /** * Negates a EncryptedTensor. **/ encrypted_t negate() const { return this->copy()->negate_inplace(); }; - virtual encrypted_t negate_inplace() = 0; + encrypted_t negate_inplace() { + this->_check_operation_sanity(); + return negate_inplace_impl(); + }; /** * Compute the square of the EncryptedTensor. **/ encrypted_t square() const { return this->copy()->square_inplace(); }; - virtual encrypted_t square_inplace() = 0; + encrypted_t square_inplace() { + this->_check_operation_sanity(); + return square_inplace_impl(); + }; /** * Compute the power of the EncryptedTensor with *minimal multiplication depth. @@ -56,7 +61,10 @@ class EncryptedTensor { encrypted_t power(unsigned int power) const { return this->copy()->power_inplace(power); }; - virtual encrypted_t power_inplace(unsigned int power) = 0; + encrypted_t power_inplace(unsigned int power) { + this->_check_operation_sanity(); + return power_inplace_impl(power); + }; /** * Encrypted evaluation function operates on two encrypted tensors and * returns a new EncryptedTensor @@ -67,19 +75,31 @@ class EncryptedTensor { encrypted_t add(const encrypted_t& to_add) const { return this->copy()->add_inplace(to_add); }; - virtual encrypted_t add_inplace(const encrypted_t& to_add) = 0; + encrypted_t add_inplace(const encrypted_t& to_add) { + this->_check_operation_sanity(); + return add_inplace_impl(to_add); + }; encrypted_t sub(encrypted_t to_sub) const { return this->copy()->sub_inplace(to_sub); }; - virtual encrypted_t sub_inplace(const encrypted_t& to_sub) = 0; + encrypted_t sub_inplace(const encrypted_t& to_sub) { + this->_check_operation_sanity(); + return sub_inplace_impl(to_sub); + }; encrypted_t mul(encrypted_t to_mul) const { return this->copy()->mul_inplace(to_mul); }; - virtual encrypted_t mul_inplace(const encrypted_t& to_mul) = 0; + encrypted_t mul_inplace(const encrypted_t& to_mul) { + this->_check_operation_sanity(); + return mul_inplace_impl(to_mul); + }; encrypted_t dot_product(encrypted_t to_mul) const { return this->copy()->dot_product_inplace(to_mul); }; - virtual encrypted_t dot_product_inplace(const encrypted_t& to_mul) = 0; + encrypted_t dot_product_inplace(const encrypted_t& to_mul) { + this->_check_operation_sanity(); + return dot_product_inplace_impl(to_mul); + }; /** * Plain evaluation function operates on an encrypted tensors and plaintext * tensors and returns a new EncryptedTensor @@ -93,9 +113,14 @@ class EncryptedTensor { encrypted_t add_plain(const PlainTensor& to_add) const { return this->copy()->add_plain_inplace(to_add); }; - virtual encrypted_t add_plain_inplace(const plain_data_t& to_add) = 0; - virtual encrypted_t add_plain_inplace( - const PlainTensor& to_add) = 0; + encrypted_t add_plain_inplace(const plain_data_t& to_add) { + this->_check_operation_sanity(); + return add_plain_inplace_impl(to_add); + }; + encrypted_t add_plain_inplace(const PlainTensor& to_add) { + this->_check_operation_sanity(); + return add_plain_inplace_impl(to_add); + }; encrypted_t sub_plain(const plain_data_t& to_sub) const { return this->copy()->sub_plain_inplace(to_sub); @@ -103,9 +128,14 @@ class EncryptedTensor { encrypted_t sub_plain(const PlainTensor& to_sub) const { return this->copy()->sub_plain_inplace(to_sub); }; - virtual encrypted_t sub_plain_inplace(const plain_data_t& to_sub) = 0; - virtual encrypted_t sub_plain_inplace( - const PlainTensor& to_sub) = 0; + encrypted_t sub_plain_inplace(const plain_data_t& to_sub) { + this->_check_operation_sanity(); + return sub_plain_inplace_impl(to_sub); + }; + encrypted_t sub_plain_inplace(const PlainTensor& to_sub) { + this->_check_operation_sanity(); + return sub_plain_inplace_impl(to_sub); + }; encrypted_t mul_plain(const plain_data_t& to_mul) const { return this->copy()->mul_plain_inplace(to_mul); @@ -113,20 +143,31 @@ class EncryptedTensor { encrypted_t mul_plain(const PlainTensor& to_mul) const { return this->copy()->mul_plain_inplace(to_mul); }; - virtual encrypted_t mul_plain_inplace(const plain_data_t& to_mul) = 0; - virtual encrypted_t mul_plain_inplace( - const PlainTensor& to_mul) = 0; + encrypted_t mul_plain_inplace(const plain_data_t& to_mul) { + this->_check_operation_sanity(); + return mul_plain_inplace_impl(to_mul); + }; + encrypted_t mul_plain_inplace(const PlainTensor& to_mul) { + this->_check_operation_sanity(); + return mul_plain_inplace_impl(to_mul); + }; encrypted_t dot_product_plain( const PlainTensor& to_mul) const { return this->copy()->dot_product_plain_inplace(to_mul); }; - virtual encrypted_t dot_product_plain_inplace( - const PlainTensor& to_mul) = 0; + encrypted_t dot_product_plain_inplace( + const PlainTensor& to_mul) { + this->_check_operation_sanity(); + return dot_product_plain_inplace_impl(to_mul); + }; encrypted_t sum(size_t axis = 0) const { return this->copy()->sum_inplace(axis); }; - virtual encrypted_t sum_inplace(size_t axis) = 0; + encrypted_t sum_inplace(size_t axis = 0) { + this->_check_operation_sanity(); + return sum_inplace_impl(axis); + }; /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * @@ -135,7 +176,10 @@ class EncryptedTensor { encrypted_t polyval(const vector& coefficients) const { return this->copy()->polyval_inplace(coefficients); }; - virtual encrypted_t polyval_inplace(const vector& coefficients) = 0; + encrypted_t polyval_inplace(const vector& coefficients) { + this->_check_operation_sanity(); + return polyval_inplace_impl(coefficients); + }; /** * Load/Save the Tensor from/to a serialized protobuffer. **/ @@ -253,6 +297,34 @@ class EncryptedTensor { protected: shared_ptr _context; + /** + * Sanity checks for the tensor + * */ + virtual bool _check_operation_sanity() = 0; + /** + * Implementations for the operations + * **/ + virtual encrypted_t negate_inplace_impl() = 0; + virtual encrypted_t square_inplace_impl() = 0; + virtual encrypted_t power_inplace_impl(unsigned int power) = 0; + virtual encrypted_t add_inplace_impl(const encrypted_t& to_add) = 0; + virtual encrypted_t sub_inplace_impl(const encrypted_t& to_sub) = 0; + virtual encrypted_t mul_inplace_impl(const encrypted_t& to_mul) = 0; + virtual encrypted_t dot_product_inplace_impl(const encrypted_t& to_mul) = 0; + virtual encrypted_t add_plain_inplace_impl(const plain_data_t& to_add) = 0; + virtual encrypted_t add_plain_inplace_impl( + const PlainTensor& to_add) = 0; + virtual encrypted_t sub_plain_inplace_impl(const plain_data_t& to_sub) = 0; + virtual encrypted_t sub_plain_inplace_impl( + const PlainTensor& to_sub) = 0; + virtual encrypted_t mul_plain_inplace_impl(const plain_data_t& to_mul) = 0; + virtual encrypted_t mul_plain_inplace_impl( + const PlainTensor& to_mul) = 0; + virtual encrypted_t dot_product_plain_inplace_impl( + const PlainTensor& to_mul) = 0; + virtual encrypted_t sum_inplace_impl(size_t axis) = 0; + virtual encrypted_t polyval_inplace_impl( + const vector& coefficients) = 0; private: }; diff --git a/tenseal/cpp/tensors/encrypted_vector.h b/tenseal/cpp/tensors/encrypted_vector.h index 39be5e68..7ad42135 100644 --- a/tenseal/cpp/tensors/encrypted_vector.h +++ b/tenseal/cpp/tensors/encrypted_vector.h @@ -20,25 +20,25 @@ using namespace std; *EncryptedVector pure methods: * * vector EncryptedTensor::decrypt(const shared_ptr&) *const = 0; - * * encrypted_t negate_inplace(); - * * encrypted_t square_inplace(); - * * encrypted_t add_inplace(encrypted_t to_add); - * * encrypted_t sub_inplace(encrypted_t to_sub); - * * encrypted_t mul_inplace(encrypted_t to_mul); - * * encrypted_t dot_product_inplace(encrypted_t to_mul); - * * encrypted_t dot_product_plain_inplace( const vector& to_mul); - * * encrypted_t sum_inplace(); - * * encrypted_t EncryptedTensor::power_inplace(unsigned int power) = 0; - * * encrypted_t EncryptedTensor::add_plain_inplace(plain_t to_add) = 0; - * * encrypted_t EncryptedTensor::add_plain_inplace(const PlainTensor& - *to_add) = 0; - * * encrypted_t EncryptedTensor::sub_plain_inplace(plain_t to_sub) = 0; - * * encrypted_t EncryptedTensor::sub_plain_inplace(const PlainTensor& - *to_sub) = 0; - * * encrypted_t EncryptedTensor::mul_plain_inplace(plain_t to_mul) = 0; - * * encrypted_t EncryptedTensor::mul_plain_inplace(const PlainTensor& - *to_mul) = 0; - * * encrypted_t EncryptedTensor::polyval_inplace(const vector& + * * encrypted_t negate_inplace_impl(); + * * encrypted_t square_inplace_impl(); + * * encrypted_t add_inplace_impl(encrypted_t to_add); + * * encrypted_t sub_inplace_impl(encrypted_t to_sub); + * * encrypted_t mul_inplace_impl(encrypted_t to_mul); + * * encrypted_t dot_product_inplace_impl(encrypted_t to_mul); + * * encrypted_t dot_product_plain_inplace_impl( const vector& to_mul); + * * encrypted_t sum_inplace_impl(); + * * encrypted_t EncryptedTensor::power_inplace_impl(unsigned int power) = 0; + * * encrypted_t EncryptedTensor::add_plain_inplace_impl(plain_t to_add) = 0; + * * encrypted_t EncryptedTensor::add_plain_inplace_impl(const + *PlainTensor& to_add) = 0; + * * encrypted_t EncryptedTensor::sub_plain_inplace_impl(plain_t to_sub) = 0; + * * encrypted_t EncryptedTensor::sub_plain_inplace_impl(const + *PlainTensor& to_sub) = 0; + * * encrypted_t EncryptedTensor::mul_plain_inplace_impl(plain_t to_mul) = 0; + * * encrypted_t EncryptedTensor::mul_plain_inplace_impl(const + *PlainTensor& to_mul) = 0; + * * encrypted_t EncryptedTensor::polyval_inplace_impl(const vector& *coefficients) = 0; * * void EncryptedTensor::load(const string& vec) = 0; * * string EncryptedTensor::save() const = 0; @@ -67,7 +67,10 @@ class EncryptedVector : public EncryptedTensor { encrypted_t replicate_first_slot(size_t n) const { return this->copy()->replicate_first_slot_inplace(n); } - virtual encrypted_t replicate_first_slot_inplace(size_t n) = 0; + encrypted_t replicate_first_slot_inplace(size_t n) { + this->_check_operation_sanity(); + return replicate_first_slot_inplace_impl(n); + }; /** * Adjust two vectors to match sizes. * @return the right operand, in case it was copied and altered. @@ -95,8 +98,11 @@ class EncryptedVector : public EncryptedTensor { size_t n_jobs = 0) const { return this->copy()->matmul_plain_inplace(matrix, n_jobs); } - virtual encrypted_t matmul_plain_inplace(const PlainTensor& matrix, - size_t n_jobs = 0) = 0; + encrypted_t matmul_plain_inplace(const PlainTensor& matrix, + size_t n_jobs = 0) { + this->_check_operation_sanity(); + return matmul_plain_inplace_impl(matrix, n_jobs); + }; /** * Encrypted Matrix multiplication with plain vector. **/ @@ -104,8 +110,11 @@ class EncryptedVector : public EncryptedTensor { size_t row_size) const { return this->copy()->enc_matmul_plain_inplace(plain_vec, row_size); } - virtual encrypted_t enc_matmul_plain_inplace( - const PlainTensor& plain_vec, size_t row_size) = 0; + encrypted_t enc_matmul_plain_inplace(const PlainTensor& plain_vec, + size_t row_size) { + this->_check_operation_sanity(); + return enc_matmul_plain_inplace_impl(plain_vec, row_size); + }; /** * Image Block to Columns. * The input matrix should be encoded in a vertical scan (column-major). @@ -115,8 +124,11 @@ class EncryptedVector : public EncryptedTensor { const size_t windows_nb) const { return this->copy()->conv2d_im2col_inplace(kernel, windows_nb); } - virtual encrypted_t conv2d_im2col_inplace( - const PlainTensor& kernel, const size_t windows_nb) = 0; + encrypted_t conv2d_im2col_inplace(const PlainTensor& kernel, + const size_t windows_nb) { + this->_check_operation_sanity(); + return conv2d_im2col_inplace_impl(kernel, windows_nb); + }; /** * Rotate encrypted plaintext cyclically @@ -233,6 +245,14 @@ class EncryptedVector : public EncryptedTensor { protected: size_t _size; Ciphertext _ciphertext; + + virtual encrypted_t replicate_first_slot_inplace_impl(size_t n) = 0; + virtual encrypted_t matmul_plain_inplace_impl( + const PlainTensor& matrix, size_t n_jobs = 0) = 0; + virtual encrypted_t enc_matmul_plain_inplace_impl( + const PlainTensor& plain_vec, size_t row_size) = 0; + virtual encrypted_t conv2d_im2col_inplace_impl( + const PlainTensor& kernel, const size_t windows_nb) = 0; }; } // namespace tenseal From c4a111044b9da21e0e6c61373d9e27a8fb50b85e Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Mon, 21 Dec 2020 17:38:21 +0200 Subject: [PATCH 2/7] Revert "add operation guard" This reverts commit b2533dc385471338a3eb4d4f808cfd08ee2cd333. --- tenseal/cpp/tensors/bfvvector.cpp | 58 ++++++------- tenseal/cpp/tensors/bfvvector.h | 75 ++++++++-------- tenseal/cpp/tensors/ckkstensor.cpp | 47 +++++----- tenseal/cpp/tensors/ckkstensor.h | 68 +++++++-------- tenseal/cpp/tensors/ckksvector.cpp | 70 +++++++-------- tenseal/cpp/tensors/ckksvector.h | 78 ++++++++--------- tenseal/cpp/tensors/encrypted_tensor.h | 114 +++++-------------------- tenseal/cpp/tensors/encrypted_vector.h | 72 ++++++---------- 8 files changed, 229 insertions(+), 353 deletions(-) diff --git a/tenseal/cpp/tensors/bfvvector.cpp b/tenseal/cpp/tensors/bfvvector.cpp index 25b23cca..59c25bab 100644 --- a/tenseal/cpp/tensors/bfvvector.cpp +++ b/tenseal/cpp/tensors/bfvvector.cpp @@ -77,7 +77,7 @@ BFVVector::plain_t BFVVector::decrypt(const shared_ptr& sk) const { return vector(result.cbegin(), result.cbegin() + this->size()); } -shared_ptr BFVVector::power_inplace_impl(unsigned int power) { +shared_ptr BFVVector::power_inplace(unsigned int power) { // if the power is zero, return a new encrypted vector of ones if (power == 0) { vector ones(this->size(), 1); @@ -97,29 +97,29 @@ shared_ptr BFVVector::power_inplace_impl(unsigned int power) { int closest_power_of_2 = 1 << static_cast(floor(log2(power))); power -= closest_power_of_2; if (power == 0) { - this->power_inplace_impl(closest_power_of_2 / 2)->square_inplace_impl(); + this->power_inplace(closest_power_of_2 / 2)->square_inplace(); } else { auto closest_pow2_vector = this->power(closest_power_of_2); - this->power_inplace_impl(power)->mul_inplace_impl(closest_pow2_vector); + this->power_inplace(power)->mul_inplace(closest_pow2_vector); } return shared_from_this(); } -shared_ptr BFVVector::negate_inplace_impl() { +shared_ptr BFVVector::negate_inplace() { this->tenseal_context()->evaluator->negate_inplace(this->_ciphertext); return shared_from_this(); } -shared_ptr BFVVector::square_inplace_impl() { +shared_ptr BFVVector::square_inplace() { this->tenseal_context()->evaluator->square_inplace(this->_ciphertext); this->auto_relin(_ciphertext); return shared_from_this(); } -shared_ptr BFVVector::add_inplace_impl( +shared_ptr BFVVector::add_inplace( const shared_ptr& other) { auto to_add = other->copy(); if (!this->tenseal_context()->equals(to_add->tenseal_context())) { @@ -136,7 +136,7 @@ shared_ptr BFVVector::add_inplace_impl( return shared_from_this(); } -shared_ptr BFVVector::sub_inplace_impl( +shared_ptr BFVVector::sub_inplace( const shared_ptr& other) { auto to_sub = other->copy(); if (!this->tenseal_context()->equals(to_sub->tenseal_context())) { @@ -153,7 +153,7 @@ shared_ptr BFVVector::sub_inplace_impl( return shared_from_this(); } -shared_ptr BFVVector::mul_inplace_impl( +shared_ptr BFVVector::mul_inplace( const shared_ptr& other) { auto to_mul = other->copy(); if (!this->tenseal_context()->equals(to_mul->tenseal_context())) { @@ -171,34 +171,34 @@ shared_ptr BFVVector::mul_inplace_impl( return shared_from_this(); } -shared_ptr BFVVector::dot_product_inplace_impl( +shared_ptr BFVVector::dot_product_inplace( const shared_ptr& to_mul) { - this->mul_inplace_impl(to_mul); - this->sum_inplace_impl(); + this->mul_inplace(to_mul); + this->sum_inplace(); return shared_from_this(); } -shared_ptr BFVVector::dot_product_plain_inplace_impl( +shared_ptr BFVVector::dot_product_plain_inplace( const BFVVector::plain_t& to_mul) { - this->mul_plain_inplace_impl(to_mul); - this->sum_inplace_impl(); + this->mul_plain_inplace(to_mul); + this->sum_inplace(); return shared_from_this(); } -shared_ptr BFVVector::sum_inplace_impl(size_t /*axis=0*/) { +shared_ptr BFVVector::sum_inplace(size_t /*axis=0*/) { sum_vector(this->tenseal_context(), this->_ciphertext, this->size()); this->_size = 1; return shared_from_this(); } -shared_ptr BFVVector::add_plain_inplace_impl( +shared_ptr BFVVector::add_plain_inplace( const plain_t::dtype& to_add) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::add_plain_inplace_impl( +shared_ptr BFVVector::add_plain_inplace( const BFVVector::plain_t& to_add) { if (this->size() != to_add.size()) { throw invalid_argument("can't add vectors of different sizes"); @@ -213,12 +213,12 @@ shared_ptr BFVVector::add_plain_inplace_impl( return shared_from_this(); } -shared_ptr BFVVector::sub_plain_inplace_impl( +shared_ptr BFVVector::sub_plain_inplace( const plain_t::dtype& to_sub) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::sub_plain_inplace_impl( +shared_ptr BFVVector::sub_plain_inplace( const BFVVector::plain_t& to_sub) { if (this->size() != to_sub.size()) { throw invalid_argument("can't sub vectors of different sizes"); @@ -233,12 +233,12 @@ shared_ptr BFVVector::sub_plain_inplace_impl( return shared_from_this(); } -shared_ptr BFVVector::mul_plain_inplace_impl( +shared_ptr BFVVector::mul_plain_inplace( const plain_t::dtype& to_sub) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::mul_plain_inplace_impl( +shared_ptr BFVVector::mul_plain_inplace( const BFVVector::plain_t& to_mul) { if (this->size() != to_mul.size()) { throw invalid_argument("can't multiply vectors of different sizes"); @@ -262,31 +262,31 @@ shared_ptr BFVVector::mul_plain_inplace_impl( return shared_from_this(); } -shared_ptr BFVVector::matmul_plain_inplace_impl( +shared_ptr BFVVector::matmul_plain_inplace( const BFVVector::plain_t& matrix, size_t n_jobs) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::polyval_inplace_impl( +shared_ptr BFVVector::polyval_inplace( const vector& coefficients) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::conv2d_im2col_inplace_impl( +shared_ptr BFVVector::conv2d_im2col_inplace( const BFVVector::plain_t& kernel, const size_t windows_nb) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::enc_matmul_plain_inplace_impl( +shared_ptr BFVVector::enc_matmul_plain_inplace( const BFVVector::plain_t& plain_vec, const size_t rows_nb) { throw std::logic_error("not implemented"); } -shared_ptr BFVVector::replicate_first_slot_inplace_impl(size_t n) { +shared_ptr BFVVector::replicate_first_slot_inplace(size_t n) { // mask vector mask(this->_size, 0); mask[0] = 1; - this->mul_plain_inplace_impl(mask); + this->mul_plain_inplace(mask); // replicate Ciphertext tmp = this->_ciphertext; @@ -353,8 +353,4 @@ shared_ptr BFVVector::deepcopy() const { BFVVectorProto vec = this->save_proto(); return BFVVector::Create(ctx, vec); } - -bool BFVVector::_check_operation_sanity(){ - return true; -} } // namespace tenseal diff --git a/tenseal/cpp/tensors/bfvvector.h b/tenseal/cpp/tensors/bfvvector.h index 8dfdef75..f5f6db01 100644 --- a/tenseal/cpp/tensors/bfvvector.h +++ b/tenseal/cpp/tensors/bfvvector.h @@ -28,49 +28,36 @@ class BFVVector return shared_ptr( new BFVVector(std::forward(args)...)); } + /** * Decrypts and returns the plaintext representation of the encrypted vector *of integers using the secret-key. **/ plain_t decrypt(const shared_ptr& sk) const override; - /** - * Load/Save the vector from/to a serialized protobuffer. - **/ - void load(const string& vec) override; - string save() const override; - /** - *Recreates a new BFVVector from the current one, without any - *pointer/reference to this one. - * **/ - encrypted_t copy() const override; - encrypted_t deepcopy() const override; - double scale() const override { throw logic_error("not implemented"); } - - protected: /** * Compute the power of the BFVVector with minimal multiplication depth. **/ - encrypted_t power_inplace_impl(unsigned int power) override; + encrypted_t power_inplace(unsigned int power) override; /** * Negates a BFVVector. **/ - encrypted_t negate_inplace_impl() override; + encrypted_t negate_inplace() override; /** * Compute the square of the BFVVector. **/ - encrypted_t square_inplace_impl() override; + encrypted_t square_inplace() override; /** * Encrypted evaluation function operates on two encrypted vectors and * returns a new BFVVector which is the result of either *addition, substraction or multiplication in an element-wise fashion. *in_place functions return a reference to the same object. **/ - encrypted_t add_inplace_impl(const encrypted_t& to_add) override; - encrypted_t sub_inplace_impl(const encrypted_t& to_sub) override; - encrypted_t mul_inplace_impl(const encrypted_t& to_mul) override; - encrypted_t dot_product_inplace_impl(const encrypted_t& to_mul) override; - encrypted_t dot_product_plain_inplace_impl(const plain_t& to_mul) override; - encrypted_t sum_inplace_impl(size_t axis = 0) override; + encrypted_t add_inplace(const encrypted_t& to_add) override; + encrypted_t sub_inplace(const encrypted_t& to_sub) override; + encrypted_t mul_inplace(const encrypted_t& to_mul) override; + encrypted_t dot_product_inplace(const encrypted_t& to_mul) override; + encrypted_t dot_product_plain_inplace(const plain_t& to_mul) override; + encrypted_t sum_inplace(size_t axis = 0) override; /** * Plain evaluation function operates on an encrypted vector and plaintext @@ -78,48 +65,56 @@ class BFVVector * either addition, substraction or multiplication in an element-wise *fashion. in_place functions return a reference to the same object. **/ - encrypted_t add_plain_inplace_impl(const plain_t::dtype& to_add) override; - encrypted_t add_plain_inplace_impl(const plain_t& to_add) override; - encrypted_t sub_plain_inplace_impl(const plain_t::dtype& to_sub) override; - encrypted_t sub_plain_inplace_impl(const plain_t& to_sub) override; - encrypted_t mul_plain_inplace_impl(const plain_t::dtype& to_mul) override; - encrypted_t mul_plain_inplace_impl(const plain_t& to_mul) override; + encrypted_t add_plain_inplace(const plain_t::dtype& to_add) override; + encrypted_t add_plain_inplace(const plain_t& to_add) override; + encrypted_t sub_plain_inplace(const plain_t::dtype& to_sub) override; + encrypted_t sub_plain_inplace(const plain_t& to_sub) override; + encrypted_t mul_plain_inplace(const plain_t::dtype& to_mul) override; + encrypted_t mul_plain_inplace(const plain_t& to_mul) override; /** * Encrypted Vector multiplication with plain matrix. **/ - encrypted_t matmul_plain_inplace_impl(const plain_t& matrix, - size_t n_jobs = 0) override; + encrypted_t matmul_plain_inplace(const plain_t& matrix, + size_t n_jobs = 0) override; /** * Encrypted Matrix multiplication with plain vector. **/ - encrypted_t enc_matmul_plain_inplace_impl(const plain_t& plain_vec, - size_t row_size) override; + encrypted_t enc_matmul_plain_inplace(const plain_t& plain_vec, + size_t row_size) override; /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * *x^i **/ - encrypted_t polyval_inplace_impl( - const vector& coefficients) override; + encrypted_t polyval_inplace(const vector& coefficients) override; /* * Image Block to Columns. * The input matrix should be encoded in a vertical scan (column-major). * The kernel vector should be padded with zeros to the next power of 2 */ - encrypted_t conv2d_im2col_inplace_impl(const plain_t& kernel, - const size_t windows_nb) override; + encrypted_t conv2d_im2col_inplace(const plain_t& kernel, + const size_t windows_nb) override; /** * Replicate the first slot of a ciphertext n times. Requires a *multiplication. **/ - encrypted_t replicate_first_slot_inplace_impl(size_t n) override; + encrypted_t replicate_first_slot_inplace(size_t n) override; + /** + * Load/Save the vector from/to a serialized protobuffer. + **/ + void load(const string& vec) override; + string save() const override; /** - * Check tensor sanity + *Recreates a new BFVVector from the current one, without any + *pointer/reference to this one. * **/ - bool _check_operation_sanity() override; + encrypted_t copy() const override; + encrypted_t deepcopy() const override; + + double scale() const override { throw logic_error("not implemented"); } private: BFVVector(const shared_ptr& ctx, const plain_t& vec); diff --git a/tenseal/cpp/tensors/ckkstensor.cpp b/tenseal/cpp/tensors/ckkstensor.cpp index a19f9ae0..27551e3e 100644 --- a/tenseal/cpp/tensors/ckkstensor.cpp +++ b/tenseal/cpp/tensors/ckkstensor.cpp @@ -118,13 +118,13 @@ PlainTensor CKKSTensor::decrypt(const shared_ptr& sk) const { } } -shared_ptr CKKSTensor::negate_inplace_impl() { +shared_ptr CKKSTensor::negate_inplace() { for (auto& ct : _data) this->tenseal_context()->evaluator->negate_inplace(ct); return shared_from_this(); } -shared_ptr CKKSTensor::square_inplace_impl() { +shared_ptr CKKSTensor::square_inplace() { for (auto& ct : _data) { this->tenseal_context()->evaluator->square_inplace(ct); this->auto_relin(ct); @@ -133,7 +133,7 @@ shared_ptr CKKSTensor::square_inplace_impl() { return shared_from_this(); } -shared_ptr CKKSTensor::power_inplace_impl(unsigned int power) { +shared_ptr CKKSTensor::power_inplace(unsigned int power) { if (power == 0) { auto ones = PlainTensor::repeat_value(1, this->shape()); *this = CKKSTensor(this->tenseal_context(), ones, this->_init_scale, @@ -146,17 +146,17 @@ shared_ptr CKKSTensor::power_inplace_impl(unsigned int power) { } if (power == 2) { - this->square_inplace_impl(); + this->square_inplace(); return shared_from_this(); } int closest_power_of_2 = 1 << static_cast(floor(log2(power))); power -= closest_power_of_2; if (power == 0) { - this->power_inplace_impl(closest_power_of_2 / 2)->square_inplace_impl(); + this->power_inplace(closest_power_of_2 / 2)->square_inplace(); } else { auto closest_pow2_vector = this->power(closest_power_of_2); - this->power_inplace_impl(power)->mul_inplace_impl(closest_pow2_vector); + this->power_inplace(power)->mul_inplace(closest_pow2_vector); } return shared_from_this(); @@ -349,61 +349,61 @@ shared_ptr CKKSTensor::op_plain_inplace(const double& operand, return shared_from_this(); } -shared_ptr CKKSTensor::add_inplace_impl( +shared_ptr CKKSTensor::add_inplace( const shared_ptr& to_add) { return this->op_inplace(to_add, OP::ADD); } -shared_ptr CKKSTensor::sub_inplace_impl( +shared_ptr CKKSTensor::sub_inplace( const shared_ptr& to_sub) { return this->op_inplace(to_sub, OP::SUB); } -shared_ptr CKKSTensor::mul_inplace_impl( +shared_ptr CKKSTensor::mul_inplace( const shared_ptr& to_mul) { return this->op_inplace(to_mul, OP::MUL); } -shared_ptr CKKSTensor::dot_product_inplace_impl( +shared_ptr CKKSTensor::dot_product_inplace( const shared_ptr& to_mul) { // TODO return shared_from_this(); } -shared_ptr CKKSTensor::add_plain_inplace_impl( +shared_ptr CKKSTensor::add_plain_inplace( const PlainTensor& to_add) { return this->op_plain_inplace(to_add, OP::ADD); } -shared_ptr CKKSTensor::sub_plain_inplace_impl( +shared_ptr CKKSTensor::sub_plain_inplace( const PlainTensor& to_sub) { return this->op_plain_inplace(to_sub, OP::SUB); } -shared_ptr CKKSTensor::mul_plain_inplace_impl( +shared_ptr CKKSTensor::mul_plain_inplace( const PlainTensor& to_mul) { return this->op_plain_inplace(to_mul, OP::MUL); } -shared_ptr CKKSTensor::dot_product_plain_inplace_impl( +shared_ptr CKKSTensor::dot_product_plain_inplace( const PlainTensor& to_mul) { // TODO return shared_from_this(); } -shared_ptr CKKSTensor::add_plain_inplace_impl(const double& to_add) { +shared_ptr CKKSTensor::add_plain_inplace(const double& to_add) { return this->op_plain_inplace(to_add, OP::ADD); } -shared_ptr CKKSTensor::sub_plain_inplace_impl(const double& to_sub) { +shared_ptr CKKSTensor::sub_plain_inplace(const double& to_sub) { return this->op_plain_inplace(to_sub, OP::SUB); } -shared_ptr CKKSTensor::mul_plain_inplace_impl(const double& to_mul) { +shared_ptr CKKSTensor::mul_plain_inplace(const double& to_mul) { return this->op_plain_inplace(to_mul, OP::MUL); } -shared_ptr CKKSTensor::sum_inplace_impl(size_t axis) { +shared_ptr CKKSTensor::sum_inplace(size_t axis) { if (axis >= shape_with_batch().size()) throw invalid_argument("invalid axis"); @@ -454,7 +454,7 @@ shared_ptr CKKSTensor::sum_batch_inplace() { return shared_from_this(); } -shared_ptr CKKSTensor::polyval_inplace_impl( +shared_ptr CKKSTensor::polyval_inplace( const vector& coefficients) { if (coefficients.size() == 0) { throw invalid_argument( @@ -485,7 +485,7 @@ shared_ptr CKKSTensor::polyval_inplace_impl( x_squares.reserve(max_square + 1); x_squares.push_back(x->copy()); // x for (int i = 1; i <= max_square; i++) { - x->square_inplace_impl(); + x->square_inplace(); x_squares.push_back(x->copy()); // x^(2^i) } @@ -499,7 +499,7 @@ shared_ptr CKKSTensor::polyval_inplace_impl( for (int i = 1; i <= degree; i++) { if (coefficients[i] == 0.0) continue; x = compute_polynomial_term(i, coefficients[i], x_squares); - result->add_inplace_impl(x); + result->add_inplace(x); } this->_data = TensorStorage(result->data(), result->shape()); @@ -604,9 +604,4 @@ shared_ptr CKKSTensor::reshape_inplace( } double CKKSTensor::scale() const { return _init_scale; } - -bool CKKSTensor::_check_operation_sanity(){ - return true; -} - } // namespace tenseal diff --git a/tenseal/cpp/tensors/ckkstensor.h b/tenseal/cpp/tensors/ckkstensor.h index ecafb6cc..24cda956 100644 --- a/tenseal/cpp/tensors/ckkstensor.h +++ b/tenseal/cpp/tensors/ckkstensor.h @@ -28,11 +28,41 @@ class CKKSTensor : public EncryptedTensor>, PlainTensor decrypt(const shared_ptr& sk) const override; + shared_ptr negate_inplace() override; + shared_ptr square_inplace() override; + shared_ptr power_inplace(unsigned int power) override; + + shared_ptr add_inplace( + const shared_ptr& to_add) override; + shared_ptr sub_inplace( + const shared_ptr& to_sub) override; + shared_ptr mul_inplace( + const shared_ptr& to_mul) override; + shared_ptr dot_product_inplace( + const shared_ptr& to_mul) override; + + shared_ptr add_plain_inplace(const double& to_add) override; + shared_ptr sub_plain_inplace(const double& to_sub) override; + shared_ptr mul_plain_inplace(const double& to_mul) override; + + shared_ptr add_plain_inplace( + const PlainTensor& to_add) override; + shared_ptr sub_plain_inplace( + const PlainTensor& to_sub) override; + shared_ptr mul_plain_inplace( + const PlainTensor& to_mul) override; + shared_ptr dot_product_plain_inplace( + const PlainTensor& to_mul) override; + + shared_ptr sum_inplace(size_t axis = 0) override; shared_ptr sum_batch() { return this->copy()->sum_batch_inplace(); } shared_ptr sum_batch_inplace(); + shared_ptr polyval_inplace( + const vector& coefficients) override; + void load(const string& vec) override; string save() const override; @@ -46,44 +76,6 @@ class CKKSTensor : public EncryptedTensor>, vector shape_with_batch() const; double scale() const override; - protected: - shared_ptr negate_inplace_impl() override; - shared_ptr square_inplace_impl() override; - shared_ptr power_inplace_impl(unsigned int power) override; - - shared_ptr add_inplace_impl( - const shared_ptr& to_add) override; - shared_ptr sub_inplace_impl( - const shared_ptr& to_sub) override; - shared_ptr mul_inplace_impl( - const shared_ptr& to_mul) override; - shared_ptr dot_product_inplace_impl( - const shared_ptr& to_mul) override; - - shared_ptr add_plain_inplace_impl( - const double& to_add) override; - shared_ptr sub_plain_inplace_impl( - const double& to_sub) override; - shared_ptr mul_plain_inplace_impl( - const double& to_mul) override; - - shared_ptr add_plain_inplace_impl( - const PlainTensor& to_add) override; - shared_ptr sub_plain_inplace_impl( - const PlainTensor& to_sub) override; - shared_ptr mul_plain_inplace_impl( - const PlainTensor& to_mul) override; - shared_ptr dot_product_plain_inplace_impl( - const PlainTensor& to_mul) override; - - shared_ptr sum_inplace_impl(size_t axis = 0) override; - shared_ptr polyval_inplace_impl( - const vector& coefficients) override; - /** - * Check tensor sanity - * **/ - bool _check_operation_sanity() override; - private: TensorStorage _data; double _init_scale; diff --git a/tenseal/cpp/tensors/ckksvector.cpp b/tenseal/cpp/tensors/ckksvector.cpp index 84333ef1..7cd05e2a 100644 --- a/tenseal/cpp/tensors/ckksvector.cpp +++ b/tenseal/cpp/tensors/ckksvector.cpp @@ -82,7 +82,7 @@ CKKSVector::plain_t CKKSVector::decrypt(const shared_ptr& sk) const { return vector(result.cbegin(), result.cbegin() + this->size()); } -shared_ptr CKKSVector::power_inplace_impl(unsigned int power) { +shared_ptr CKKSVector::power_inplace(unsigned int power) { // if the power is zero, return a new encrypted vector of ones if (power == 0) { vector ones(this->size(), 1); @@ -95,29 +95,29 @@ shared_ptr CKKSVector::power_inplace_impl(unsigned int power) { } if (power == 2) { - this->square_inplace_impl(); + this->square_inplace(); return shared_from_this(); } int closest_power_of_2 = 1 << static_cast(floor(log2(power))); power -= closest_power_of_2; if (power == 0) { - this->power_inplace_impl(closest_power_of_2 / 2)->square_inplace_impl(); + this->power_inplace(closest_power_of_2 / 2)->square_inplace(); } else { auto closest_pow2_vector = this->power(closest_power_of_2); - this->power_inplace_impl(power)->mul_inplace_impl(closest_pow2_vector); + this->power_inplace(power)->mul_inplace(closest_pow2_vector); } return shared_from_this(); } -shared_ptr CKKSVector::negate_inplace_impl() { +shared_ptr CKKSVector::negate_inplace() { this->tenseal_context()->evaluator->negate_inplace(this->_ciphertext); return shared_from_this(); } -shared_ptr CKKSVector::square_inplace_impl() { +shared_ptr CKKSVector::square_inplace() { this->tenseal_context()->evaluator->square_inplace(_ciphertext); this->auto_relin(_ciphertext); this->auto_rescale(_ciphertext); @@ -125,7 +125,7 @@ shared_ptr CKKSVector::square_inplace_impl() { return shared_from_this(); } -shared_ptr CKKSVector::add_inplace_impl( +shared_ptr CKKSVector::add_inplace( const shared_ptr& other) { auto to_add = other; if (!this->tenseal_context()->equals(to_add->tenseal_context())) { @@ -143,7 +143,7 @@ shared_ptr CKKSVector::add_inplace_impl( return shared_from_this(); } -shared_ptr CKKSVector::sub_inplace_impl( +shared_ptr CKKSVector::sub_inplace( const shared_ptr& other) { auto to_sub = other; if (!this->tenseal_context()->equals(to_sub->tenseal_context())) { @@ -161,7 +161,7 @@ shared_ptr CKKSVector::sub_inplace_impl( return shared_from_this(); } -shared_ptr CKKSVector::mul_inplace_impl( +shared_ptr CKKSVector::mul_inplace( const shared_ptr& other) { auto to_mul = other; if (!this->tenseal_context()->equals(to_mul->tenseal_context())) { @@ -182,36 +182,36 @@ shared_ptr CKKSVector::mul_inplace_impl( return shared_from_this(); } -shared_ptr CKKSVector::dot_product_inplace_impl( +shared_ptr CKKSVector::dot_product_inplace( const shared_ptr& to_mul) { - this->mul_inplace_impl(to_mul); - this->sum_inplace_impl(); + this->mul_inplace(to_mul); + this->sum_inplace(); return shared_from_this(); } -shared_ptr CKKSVector::dot_product_plain_inplace_impl( +shared_ptr CKKSVector::dot_product_plain_inplace( const plain_t& to_mul) { - this->mul_plain_inplace_impl(to_mul); - this->sum_inplace_impl(); + this->mul_plain_inplace(to_mul); + this->sum_inplace(); return shared_from_this(); } -shared_ptr CKKSVector::sum_inplace_impl(size_t /*axis = 0*/) { +shared_ptr CKKSVector::sum_inplace(size_t /*axis = 0*/) { sum_vector(this->tenseal_context(), this->_ciphertext, this->size()); this->_size = 1; return shared_from_this(); } -shared_ptr CKKSVector::add_plain_inplace_impl(const plain_t& to_add) { +shared_ptr CKKSVector::add_plain_inplace(const plain_t& to_add) { if (this->size() != to_add.size()) { throw invalid_argument("can't add vectors of different sizes"); } return this->_add_plain_inplace(to_add.data()); } -shared_ptr CKKSVector::add_plain_inplace_impl(const double& to_add) { +shared_ptr CKKSVector::add_plain_inplace(const double& to_add) { return this->_add_plain_inplace(to_add); } @@ -226,14 +226,14 @@ shared_ptr CKKSVector::_add_plain_inplace(const T& to_add) { return shared_from_this(); } -shared_ptr CKKSVector::sub_plain_inplace_impl(const plain_t& to_sub) { +shared_ptr CKKSVector::sub_plain_inplace(const plain_t& to_sub) { if (this->size() != to_sub.size()) { throw invalid_argument("can't sub vectors of different sizes"); } return this->_sub_plain_inplace(to_sub.data()); } -shared_ptr CKKSVector::sub_plain_inplace_impl(const double& to_sub) { +shared_ptr CKKSVector::sub_plain_inplace(const double& to_sub) { return this->_sub_plain_inplace(to_sub); } @@ -250,7 +250,7 @@ shared_ptr CKKSVector::_sub_plain_inplace(const T& to_sub) { return shared_from_this(); } -shared_ptr CKKSVector::mul_plain_inplace_impl(const plain_t& to_mul) { +shared_ptr CKKSVector::mul_plain_inplace(const plain_t& to_mul) { if (this->size() != to_mul.size()) { throw invalid_argument("can't multiply vectors of different sizes"); } @@ -258,7 +258,7 @@ shared_ptr CKKSVector::mul_plain_inplace_impl(const plain_t& to_mul) return this->_mul_plain_inplace(to_mul.data()); } -shared_ptr CKKSVector::mul_plain_inplace_impl(const double& to_mul) { +shared_ptr CKKSVector::mul_plain_inplace(const double& to_mul) { return this->_mul_plain_inplace(to_mul); } @@ -288,7 +288,7 @@ shared_ptr CKKSVector::_mul_plain_inplace(const T& to_mul) { return this->copy(); } -shared_ptr CKKSVector::matmul_plain_inplace_impl( +shared_ptr CKKSVector::matmul_plain_inplace( const CKKSVector::plain_t& matrix, size_t n_jobs) { this->_ciphertext = this->diagonal_ct_vector_matmul(matrix, n_jobs); @@ -298,7 +298,7 @@ shared_ptr CKKSVector::matmul_plain_inplace_impl( return shared_from_this(); } -shared_ptr CKKSVector::polyval_inplace_impl( +shared_ptr CKKSVector::polyval_inplace( const vector& coefficients) { if (coefficients.size() == 0) { throw invalid_argument( @@ -335,7 +335,7 @@ shared_ptr CKKSVector::polyval_inplace_impl( x_squares.reserve(max_square + 1); x_squares.push_back(x->copy()); // x for (int i = 1; i <= max_square; i++) { - x->square_inplace_impl(); + x->square_inplace(); x_squares.push_back(x->copy()); // x^(2^i) } @@ -343,14 +343,14 @@ shared_ptr CKKSVector::polyval_inplace_impl( for (int i = 1; i <= degree; i++) { if (coefficients[i] == 0.0) continue; x = compute_polynomial_term(i, coefficients[i], x_squares); - result->add_inplace_impl(x); + result->add_inplace(x); } this->_ciphertext = result->ciphertext(); return shared_from_this(); } -shared_ptr CKKSVector::conv2d_im2col_inplace_impl( +shared_ptr CKKSVector::conv2d_im2col_inplace( const CKKSVector::plain_t& kernel, const size_t windows_nb) { if (windows_nb == 0) { throw invalid_argument("Windows number can't be zero"); @@ -362,11 +362,11 @@ shared_ptr CKKSVector::conv2d_im2col_inplace_impl( // flat the kernel auto flatten_kernel = kernel.horizontal_scan(); - this->enc_matmul_plain_inplace_impl(flatten_kernel, windows_nb); + this->enc_matmul_plain_inplace(flatten_kernel, windows_nb); return shared_from_this(); } -shared_ptr CKKSVector::enc_matmul_plain_inplace_impl( +shared_ptr CKKSVector::enc_matmul_plain_inplace( const CKKSVector::plain_t& plain_vec, const size_t rows_nb) { if (plain_vec.empty()) { throw invalid_argument("Plain vector can't be empty"); @@ -400,7 +400,7 @@ shared_ptr CKKSVector::enc_matmul_plain_inplace_impl( replicate_vector(new_plain_vec, slot_count); this->_size = slot_count; - this->mul_plain_inplace_impl(new_plain_vec); + this->mul_plain_inplace(new_plain_vec); auto galois_keys = this->tenseal_context()->galois_keys(); @@ -412,7 +412,7 @@ shared_ptr CKKSVector::enc_matmul_plain_inplace_impl( 1 << (static_cast(ceil(log2(chunks_nb))) - 1)); tmp->rotate_vector_inplace(static_cast(rows_nb * chunks_nb), *galois_keys); - this->add_inplace_impl(tmp); + this->add_inplace(tmp); } this->_size = rows_nb; @@ -420,11 +420,11 @@ shared_ptr CKKSVector::enc_matmul_plain_inplace_impl( return shared_from_this(); } -shared_ptr CKKSVector::replicate_first_slot_inplace_impl(size_t n) { +shared_ptr CKKSVector::replicate_first_slot_inplace(size_t n) { // mask vector mask(this->_size, 0); mask[0] = 1; - this->mul_plain_inplace_impl(mask); + this->mul_plain_inplace(mask); // replicate Ciphertext tmp = this->_ciphertext; @@ -491,8 +491,4 @@ shared_ptr CKKSVector::deepcopy() const { return CKKSVector::Create(ctx, vec); } -bool CKKSVector::_check_operation_sanity(){ - return true; -} - } // namespace tenseal diff --git a/tenseal/cpp/tensors/ckksvector.h b/tenseal/cpp/tensors/ckksvector.h index 7e2b5170..792de59d 100644 --- a/tenseal/cpp/tensors/ckksvector.h +++ b/tenseal/cpp/tensors/ckksvector.h @@ -30,46 +30,31 @@ class CKKSVector *of real numbers using the secret-key. **/ plain_t decrypt(const shared_ptr& sk) const override; - /** - * Load/Save the vector from/to a serialized protobuffer. - **/ - void load(const string& vec) override; - string save() const override; - - /** - *Recreates a new CKKSVector from the current one, without any - *pointer/reference to this one. - **/ - encrypted_t copy() const override; - encrypted_t deepcopy() const override; - - double scale() const override { return _init_scale; } - protected: /** * Compute the power of the CKKSVector with minimal multiplication depth. **/ - encrypted_t power_inplace_impl(unsigned int power) override; + encrypted_t power_inplace(unsigned int power) override; /** * Negates a CKKSVector. **/ - encrypted_t negate_inplace_impl() override; + encrypted_t negate_inplace() override; /** * Compute the square of the CKKSVector. **/ - encrypted_t square_inplace_impl() override; + encrypted_t square_inplace() override; /** * Encrypted evaluation function operates on two encrypted vectors and * returns a new CKKSVector which is the result of either *addition, substraction or multiplication in an element-wise fashion. *in_place functions return a reference to the same object. **/ - encrypted_t add_inplace_impl(const encrypted_t& to_add) override; - encrypted_t sub_inplace_impl(const encrypted_t& to_sub) override; - encrypted_t mul_inplace_impl(const encrypted_t& to_mul) override; - encrypted_t dot_product_inplace_impl(const encrypted_t& to_mul) override; - encrypted_t dot_product_plain_inplace_impl(const plain_t& to_mul) override; - encrypted_t sum_inplace_impl(size_t axis = 0) override; + encrypted_t add_inplace(const encrypted_t& to_add) override; + encrypted_t sub_inplace(const encrypted_t& to_sub) override; + encrypted_t mul_inplace(const encrypted_t& to_mul) override; + encrypted_t dot_product_inplace(const encrypted_t& to_mul) override; + encrypted_t dot_product_plain_inplace(const plain_t& to_mul) override; + encrypted_t sum_inplace(size_t axis = 0) override; /** * Plain evaluation function operates on an encrypted vector and plaintext @@ -77,49 +62,58 @@ class CKKSVector * either addition, substraction or multiplication in an element-wise *fashion. in_place functions return a reference to the same object. **/ - encrypted_t add_plain_inplace_impl(const plain_t::dtype& to_add) override; - encrypted_t add_plain_inplace_impl(const plain_t& to_add) override; - encrypted_t sub_plain_inplace_impl(const plain_t::dtype& to_sub) override; - encrypted_t sub_plain_inplace_impl(const plain_t& to_sub) override; - encrypted_t mul_plain_inplace_impl(const plain_t::dtype& to_mul) override; - encrypted_t mul_plain_inplace_impl(const plain_t& to_mul) override; + encrypted_t add_plain_inplace(const plain_t::dtype& to_add) override; + encrypted_t add_plain_inplace(const plain_t& to_add) override; + encrypted_t sub_plain_inplace(const plain_t::dtype& to_sub) override; + encrypted_t sub_plain_inplace(const plain_t& to_sub) override; + encrypted_t mul_plain_inplace(const plain_t::dtype& to_mul) override; + encrypted_t mul_plain_inplace(const plain_t& to_mul) override; /** * Encrypted Vector multiplication with plain matrix. **/ - encrypted_t matmul_plain_inplace_impl(const plain_t& matrix, - size_t n_jobs = 0) override; + encrypted_t matmul_plain_inplace(const plain_t& matrix, + size_t n_jobs = 0) override; /** * Encrypted Matrix multiplication with plain vector. **/ - encrypted_t enc_matmul_plain_inplace_impl(const plain_t& plain_vec, - size_t row_size) override; + encrypted_t enc_matmul_plain_inplace(const plain_t& plain_vec, + size_t row_size) override; /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * *x^i **/ - encrypted_t polyval_inplace_impl( - const vector& coefficients) override; + encrypted_t polyval_inplace(const vector& coefficients) override; /* * Image Block to Columns. * The input matrix should be encoded in a vertical scan (column-major). * The kernel vector should be padded with zeros to the next power of 2 */ - encrypted_t conv2d_im2col_inplace_impl(const plain_t& kernel, - const size_t windows_nb) override; + encrypted_t conv2d_im2col_inplace(const plain_t& kernel, + const size_t windows_nb) override; /** * Replicate the first slot of a ciphertext n times. Requires a *multiplication. **/ - encrypted_t replicate_first_slot_inplace_impl(size_t n) override; + encrypted_t replicate_first_slot_inplace(size_t n) override; /** - * Check tensor sanity - * **/ - bool _check_operation_sanity() override; + * Load/Save the vector from/to a serialized protobuffer. + **/ + void load(const string& vec) override; + string save() const override; + + /** + *Recreates a new CKKSVector from the current one, without any + *pointer/reference to this one. + **/ + encrypted_t copy() const override; + encrypted_t deepcopy() const override; + + double scale() const override { return _init_scale; } private: double _init_scale; diff --git a/tenseal/cpp/tensors/encrypted_tensor.h b/tenseal/cpp/tensors/encrypted_tensor.h index f62ef306..c9fa083a 100644 --- a/tenseal/cpp/tensors/encrypted_tensor.h +++ b/tenseal/cpp/tensors/encrypted_tensor.h @@ -38,22 +38,17 @@ class EncryptedTensor { }; virtual PlainTensor decrypt( const shared_ptr& sk) const = 0; + /** * Negates a EncryptedTensor. **/ encrypted_t negate() const { return this->copy()->negate_inplace(); }; - encrypted_t negate_inplace() { - this->_check_operation_sanity(); - return negate_inplace_impl(); - }; + virtual encrypted_t negate_inplace() = 0; /** * Compute the square of the EncryptedTensor. **/ encrypted_t square() const { return this->copy()->square_inplace(); }; - encrypted_t square_inplace() { - this->_check_operation_sanity(); - return square_inplace_impl(); - }; + virtual encrypted_t square_inplace() = 0; /** * Compute the power of the EncryptedTensor with *minimal multiplication depth. @@ -61,10 +56,7 @@ class EncryptedTensor { encrypted_t power(unsigned int power) const { return this->copy()->power_inplace(power); }; - encrypted_t power_inplace(unsigned int power) { - this->_check_operation_sanity(); - return power_inplace_impl(power); - }; + virtual encrypted_t power_inplace(unsigned int power) = 0; /** * Encrypted evaluation function operates on two encrypted tensors and * returns a new EncryptedTensor @@ -75,31 +67,19 @@ class EncryptedTensor { encrypted_t add(const encrypted_t& to_add) const { return this->copy()->add_inplace(to_add); }; - encrypted_t add_inplace(const encrypted_t& to_add) { - this->_check_operation_sanity(); - return add_inplace_impl(to_add); - }; + virtual encrypted_t add_inplace(const encrypted_t& to_add) = 0; encrypted_t sub(encrypted_t to_sub) const { return this->copy()->sub_inplace(to_sub); }; - encrypted_t sub_inplace(const encrypted_t& to_sub) { - this->_check_operation_sanity(); - return sub_inplace_impl(to_sub); - }; + virtual encrypted_t sub_inplace(const encrypted_t& to_sub) = 0; encrypted_t mul(encrypted_t to_mul) const { return this->copy()->mul_inplace(to_mul); }; - encrypted_t mul_inplace(const encrypted_t& to_mul) { - this->_check_operation_sanity(); - return mul_inplace_impl(to_mul); - }; + virtual encrypted_t mul_inplace(const encrypted_t& to_mul) = 0; encrypted_t dot_product(encrypted_t to_mul) const { return this->copy()->dot_product_inplace(to_mul); }; - encrypted_t dot_product_inplace(const encrypted_t& to_mul) { - this->_check_operation_sanity(); - return dot_product_inplace_impl(to_mul); - }; + virtual encrypted_t dot_product_inplace(const encrypted_t& to_mul) = 0; /** * Plain evaluation function operates on an encrypted tensors and plaintext * tensors and returns a new EncryptedTensor @@ -113,14 +93,9 @@ class EncryptedTensor { encrypted_t add_plain(const PlainTensor& to_add) const { return this->copy()->add_plain_inplace(to_add); }; - encrypted_t add_plain_inplace(const plain_data_t& to_add) { - this->_check_operation_sanity(); - return add_plain_inplace_impl(to_add); - }; - encrypted_t add_plain_inplace(const PlainTensor& to_add) { - this->_check_operation_sanity(); - return add_plain_inplace_impl(to_add); - }; + virtual encrypted_t add_plain_inplace(const plain_data_t& to_add) = 0; + virtual encrypted_t add_plain_inplace( + const PlainTensor& to_add) = 0; encrypted_t sub_plain(const plain_data_t& to_sub) const { return this->copy()->sub_plain_inplace(to_sub); @@ -128,14 +103,9 @@ class EncryptedTensor { encrypted_t sub_plain(const PlainTensor& to_sub) const { return this->copy()->sub_plain_inplace(to_sub); }; - encrypted_t sub_plain_inplace(const plain_data_t& to_sub) { - this->_check_operation_sanity(); - return sub_plain_inplace_impl(to_sub); - }; - encrypted_t sub_plain_inplace(const PlainTensor& to_sub) { - this->_check_operation_sanity(); - return sub_plain_inplace_impl(to_sub); - }; + virtual encrypted_t sub_plain_inplace(const plain_data_t& to_sub) = 0; + virtual encrypted_t sub_plain_inplace( + const PlainTensor& to_sub) = 0; encrypted_t mul_plain(const plain_data_t& to_mul) const { return this->copy()->mul_plain_inplace(to_mul); @@ -143,31 +113,20 @@ class EncryptedTensor { encrypted_t mul_plain(const PlainTensor& to_mul) const { return this->copy()->mul_plain_inplace(to_mul); }; - encrypted_t mul_plain_inplace(const plain_data_t& to_mul) { - this->_check_operation_sanity(); - return mul_plain_inplace_impl(to_mul); - }; - encrypted_t mul_plain_inplace(const PlainTensor& to_mul) { - this->_check_operation_sanity(); - return mul_plain_inplace_impl(to_mul); - }; + virtual encrypted_t mul_plain_inplace(const plain_data_t& to_mul) = 0; + virtual encrypted_t mul_plain_inplace( + const PlainTensor& to_mul) = 0; encrypted_t dot_product_plain( const PlainTensor& to_mul) const { return this->copy()->dot_product_plain_inplace(to_mul); }; - encrypted_t dot_product_plain_inplace( - const PlainTensor& to_mul) { - this->_check_operation_sanity(); - return dot_product_plain_inplace_impl(to_mul); - }; + virtual encrypted_t dot_product_plain_inplace( + const PlainTensor& to_mul) = 0; encrypted_t sum(size_t axis = 0) const { return this->copy()->sum_inplace(axis); }; - encrypted_t sum_inplace(size_t axis = 0) { - this->_check_operation_sanity(); - return sum_inplace_impl(axis); - }; + virtual encrypted_t sum_inplace(size_t axis) = 0; /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * @@ -176,10 +135,7 @@ class EncryptedTensor { encrypted_t polyval(const vector& coefficients) const { return this->copy()->polyval_inplace(coefficients); }; - encrypted_t polyval_inplace(const vector& coefficients) { - this->_check_operation_sanity(); - return polyval_inplace_impl(coefficients); - }; + virtual encrypted_t polyval_inplace(const vector& coefficients) = 0; /** * Load/Save the Tensor from/to a serialized protobuffer. **/ @@ -297,34 +253,6 @@ class EncryptedTensor { protected: shared_ptr _context; - /** - * Sanity checks for the tensor - * */ - virtual bool _check_operation_sanity() = 0; - /** - * Implementations for the operations - * **/ - virtual encrypted_t negate_inplace_impl() = 0; - virtual encrypted_t square_inplace_impl() = 0; - virtual encrypted_t power_inplace_impl(unsigned int power) = 0; - virtual encrypted_t add_inplace_impl(const encrypted_t& to_add) = 0; - virtual encrypted_t sub_inplace_impl(const encrypted_t& to_sub) = 0; - virtual encrypted_t mul_inplace_impl(const encrypted_t& to_mul) = 0; - virtual encrypted_t dot_product_inplace_impl(const encrypted_t& to_mul) = 0; - virtual encrypted_t add_plain_inplace_impl(const plain_data_t& to_add) = 0; - virtual encrypted_t add_plain_inplace_impl( - const PlainTensor& to_add) = 0; - virtual encrypted_t sub_plain_inplace_impl(const plain_data_t& to_sub) = 0; - virtual encrypted_t sub_plain_inplace_impl( - const PlainTensor& to_sub) = 0; - virtual encrypted_t mul_plain_inplace_impl(const plain_data_t& to_mul) = 0; - virtual encrypted_t mul_plain_inplace_impl( - const PlainTensor& to_mul) = 0; - virtual encrypted_t dot_product_plain_inplace_impl( - const PlainTensor& to_mul) = 0; - virtual encrypted_t sum_inplace_impl(size_t axis) = 0; - virtual encrypted_t polyval_inplace_impl( - const vector& coefficients) = 0; private: }; diff --git a/tenseal/cpp/tensors/encrypted_vector.h b/tenseal/cpp/tensors/encrypted_vector.h index 7ad42135..39be5e68 100644 --- a/tenseal/cpp/tensors/encrypted_vector.h +++ b/tenseal/cpp/tensors/encrypted_vector.h @@ -20,25 +20,25 @@ using namespace std; *EncryptedVector pure methods: * * vector EncryptedTensor::decrypt(const shared_ptr&) *const = 0; - * * encrypted_t negate_inplace_impl(); - * * encrypted_t square_inplace_impl(); - * * encrypted_t add_inplace_impl(encrypted_t to_add); - * * encrypted_t sub_inplace_impl(encrypted_t to_sub); - * * encrypted_t mul_inplace_impl(encrypted_t to_mul); - * * encrypted_t dot_product_inplace_impl(encrypted_t to_mul); - * * encrypted_t dot_product_plain_inplace_impl( const vector& to_mul); - * * encrypted_t sum_inplace_impl(); - * * encrypted_t EncryptedTensor::power_inplace_impl(unsigned int power) = 0; - * * encrypted_t EncryptedTensor::add_plain_inplace_impl(plain_t to_add) = 0; - * * encrypted_t EncryptedTensor::add_plain_inplace_impl(const - *PlainTensor& to_add) = 0; - * * encrypted_t EncryptedTensor::sub_plain_inplace_impl(plain_t to_sub) = 0; - * * encrypted_t EncryptedTensor::sub_plain_inplace_impl(const - *PlainTensor& to_sub) = 0; - * * encrypted_t EncryptedTensor::mul_plain_inplace_impl(plain_t to_mul) = 0; - * * encrypted_t EncryptedTensor::mul_plain_inplace_impl(const - *PlainTensor& to_mul) = 0; - * * encrypted_t EncryptedTensor::polyval_inplace_impl(const vector& + * * encrypted_t negate_inplace(); + * * encrypted_t square_inplace(); + * * encrypted_t add_inplace(encrypted_t to_add); + * * encrypted_t sub_inplace(encrypted_t to_sub); + * * encrypted_t mul_inplace(encrypted_t to_mul); + * * encrypted_t dot_product_inplace(encrypted_t to_mul); + * * encrypted_t dot_product_plain_inplace( const vector& to_mul); + * * encrypted_t sum_inplace(); + * * encrypted_t EncryptedTensor::power_inplace(unsigned int power) = 0; + * * encrypted_t EncryptedTensor::add_plain_inplace(plain_t to_add) = 0; + * * encrypted_t EncryptedTensor::add_plain_inplace(const PlainTensor& + *to_add) = 0; + * * encrypted_t EncryptedTensor::sub_plain_inplace(plain_t to_sub) = 0; + * * encrypted_t EncryptedTensor::sub_plain_inplace(const PlainTensor& + *to_sub) = 0; + * * encrypted_t EncryptedTensor::mul_plain_inplace(plain_t to_mul) = 0; + * * encrypted_t EncryptedTensor::mul_plain_inplace(const PlainTensor& + *to_mul) = 0; + * * encrypted_t EncryptedTensor::polyval_inplace(const vector& *coefficients) = 0; * * void EncryptedTensor::load(const string& vec) = 0; * * string EncryptedTensor::save() const = 0; @@ -67,10 +67,7 @@ class EncryptedVector : public EncryptedTensor { encrypted_t replicate_first_slot(size_t n) const { return this->copy()->replicate_first_slot_inplace(n); } - encrypted_t replicate_first_slot_inplace(size_t n) { - this->_check_operation_sanity(); - return replicate_first_slot_inplace_impl(n); - }; + virtual encrypted_t replicate_first_slot_inplace(size_t n) = 0; /** * Adjust two vectors to match sizes. * @return the right operand, in case it was copied and altered. @@ -98,11 +95,8 @@ class EncryptedVector : public EncryptedTensor { size_t n_jobs = 0) const { return this->copy()->matmul_plain_inplace(matrix, n_jobs); } - encrypted_t matmul_plain_inplace(const PlainTensor& matrix, - size_t n_jobs = 0) { - this->_check_operation_sanity(); - return matmul_plain_inplace_impl(matrix, n_jobs); - }; + virtual encrypted_t matmul_plain_inplace(const PlainTensor& matrix, + size_t n_jobs = 0) = 0; /** * Encrypted Matrix multiplication with plain vector. **/ @@ -110,11 +104,8 @@ class EncryptedVector : public EncryptedTensor { size_t row_size) const { return this->copy()->enc_matmul_plain_inplace(plain_vec, row_size); } - encrypted_t enc_matmul_plain_inplace(const PlainTensor& plain_vec, - size_t row_size) { - this->_check_operation_sanity(); - return enc_matmul_plain_inplace_impl(plain_vec, row_size); - }; + virtual encrypted_t enc_matmul_plain_inplace( + const PlainTensor& plain_vec, size_t row_size) = 0; /** * Image Block to Columns. * The input matrix should be encoded in a vertical scan (column-major). @@ -124,11 +115,8 @@ class EncryptedVector : public EncryptedTensor { const size_t windows_nb) const { return this->copy()->conv2d_im2col_inplace(kernel, windows_nb); } - encrypted_t conv2d_im2col_inplace(const PlainTensor& kernel, - const size_t windows_nb) { - this->_check_operation_sanity(); - return conv2d_im2col_inplace_impl(kernel, windows_nb); - }; + virtual encrypted_t conv2d_im2col_inplace( + const PlainTensor& kernel, const size_t windows_nb) = 0; /** * Rotate encrypted plaintext cyclically @@ -245,14 +233,6 @@ class EncryptedVector : public EncryptedTensor { protected: size_t _size; Ciphertext _ciphertext; - - virtual encrypted_t replicate_first_slot_inplace_impl(size_t n) = 0; - virtual encrypted_t matmul_plain_inplace_impl( - const PlainTensor& matrix, size_t n_jobs = 0) = 0; - virtual encrypted_t enc_matmul_plain_inplace_impl( - const PlainTensor& plain_vec, size_t row_size) = 0; - virtual encrypted_t conv2d_im2col_inplace_impl( - const PlainTensor& kernel, const size_t windows_nb) = 0; }; } // namespace tenseal From 5f9b08afa5935c037621199a7b7a1c24068d966c Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Mon, 21 Dec 2020 17:59:26 +0200 Subject: [PATCH 3/7] add lazy loading --- tenseal/cpp/tensors/bfvvector.cpp | 7 +++++-- tenseal/cpp/tensors/bfvvector.h | 1 + tenseal/cpp/tensors/ckkstensor.cpp | 7 +++++++ tenseal/cpp/tensors/ckkstensor.h | 1 + tenseal/cpp/tensors/ckksvector.cpp | 7 +++++++ tenseal/cpp/tensors/ckksvector.h | 1 + tenseal/cpp/tensors/encrypted_tensor.h | 5 +++++ tests/cpp/tensors/bfvvector_test.cpp | 19 +++++++++++++++++++ tests/cpp/tensors/ckkstensor_test.cpp | 24 ++++++++++++++++++++++++ tests/cpp/tensors/ckksvector_test.cpp | 22 ++++++++++++++++++++++ 10 files changed, 92 insertions(+), 2 deletions(-) diff --git a/tenseal/cpp/tensors/bfvvector.cpp b/tenseal/cpp/tensors/bfvvector.cpp index 59c25bab..77e717c3 100644 --- a/tenseal/cpp/tensors/bfvvector.cpp +++ b/tenseal/cpp/tensors/bfvvector.cpp @@ -24,6 +24,8 @@ BFVVector::BFVVector(const shared_ptr& ctx, const string& vec) { this->load(vec); } +BFVVector::BFVVector(const string& vec) { this->load(vec); } + BFVVector::BFVVector(const shared_ptr& ctx, const BFVVectorProto& vec) { this->prepare_context(ctx); @@ -321,8 +323,9 @@ BFVVectorProto BFVVector::save_proto() const { } void BFVVector::load(const std::string& vec) { - if (this->tenseal_context() == nullptr) { - throw invalid_argument("context missing for deserialization"); + if (this->_context == nullptr) { + _lazy_buffer = vec; + return; } BFVVectorProto buffer; if (!buffer.ParseFromArray(vec.c_str(), static_cast(vec.size()))) { diff --git a/tenseal/cpp/tensors/bfvvector.h b/tenseal/cpp/tensors/bfvvector.h index f5f6db01..2c507678 100644 --- a/tenseal/cpp/tensors/bfvvector.h +++ b/tenseal/cpp/tensors/bfvvector.h @@ -120,6 +120,7 @@ class BFVVector BFVVector(const shared_ptr& ctx, const plain_t& vec); BFVVector(const shared_ptr&); BFVVector(const shared_ptr& ctx, const string& vec); + BFVVector(const string& vec); BFVVector(const TenSEALContextProto& ctx, const BFVVectorProto& vec); BFVVector(const shared_ptr& ctx, const BFVVectorProto& vec); diff --git a/tenseal/cpp/tensors/ckkstensor.cpp b/tenseal/cpp/tensors/ckkstensor.cpp index 27551e3e..b5096fa8 100644 --- a/tenseal/cpp/tensors/ckkstensor.cpp +++ b/tenseal/cpp/tensors/ckkstensor.cpp @@ -41,6 +41,8 @@ CKKSTensor::CKKSTensor(const shared_ptr& ctx, this->load(tensor); } +CKKSTensor::CKKSTensor(const string& tensor) { this->load(tensor); } + CKKSTensor::CKKSTensor(const TenSEALContextProto& ctx, const CKKSTensorProto& tensor) { this->load_context_proto(ctx); @@ -550,6 +552,11 @@ CKKSTensorProto CKKSTensor::save_proto() const { } void CKKSTensor::load(const std::string& tensor_str) { + if (this->_context == nullptr) { + _lazy_buffer = tensor_str; + return; + } + CKKSTensorProto buffer; if (!buffer.ParseFromArray(tensor_str.c_str(), static_cast(tensor_str.size()))) { diff --git a/tenseal/cpp/tensors/ckkstensor.h b/tenseal/cpp/tensors/ckkstensor.h index 24cda956..5bc124ea 100644 --- a/tenseal/cpp/tensors/ckkstensor.h +++ b/tenseal/cpp/tensors/ckkstensor.h @@ -86,6 +86,7 @@ class CKKSTensor : public EncryptedTensor>, std::optional scale = {}, bool batch = true); CKKSTensor(const TenSEALContextProto& ctx, const CKKSTensorProto& tensor); CKKSTensor(const shared_ptr& ctx, const string& vec); + CKKSTensor(const string& vec); CKKSTensor(const shared_ptr& ctx, const CKKSTensorProto& tensor); CKKSTensor(const shared_ptr& vec); diff --git a/tenseal/cpp/tensors/ckksvector.cpp b/tenseal/cpp/tensors/ckksvector.cpp index 7cd05e2a..27fd2001 100644 --- a/tenseal/cpp/tensors/ckksvector.cpp +++ b/tenseal/cpp/tensors/ckksvector.cpp @@ -26,6 +26,8 @@ CKKSVector::CKKSVector(const shared_ptr& ctx, this->load(vec); } +CKKSVector::CKKSVector(const string& vec) { this->load(vec); } + CKKSVector::CKKSVector(const TenSEALContextProto& ctx, const CKKSVectorProto& vec) { this->load_context_proto(ctx); @@ -461,6 +463,11 @@ CKKSVectorProto CKKSVector::save_proto() const { } void CKKSVector::load(const std::string& vec) { + if (this->_context == nullptr) { + _lazy_buffer = vec; + return; + } + CKKSVectorProto buffer; if (!buffer.ParseFromArray(vec.c_str(), static_cast(vec.size()))) { throw invalid_argument("failed to parse CKKS stream"); diff --git a/tenseal/cpp/tensors/ckksvector.h b/tenseal/cpp/tensors/ckksvector.h index 792de59d..fb0be5d6 100644 --- a/tenseal/cpp/tensors/ckksvector.h +++ b/tenseal/cpp/tensors/ckksvector.h @@ -130,6 +130,7 @@ class CKKSVector CKKSVector(const shared_ptr& ctx, const plain_t& vec, optional scale = {}); CKKSVector(const shared_ptr& ctx, const string& vec); + CKKSVector(const string& vec); CKKSVector(const TenSEALContextProto& ctx, const CKKSVectorProto& vec); CKKSVector(const shared_ptr& ctx, const CKKSVectorProto& vec); diff --git a/tenseal/cpp/tensors/encrypted_tensor.h b/tenseal/cpp/tensors/encrypted_tensor.h index c9fa083a..cfc41b05 100644 --- a/tenseal/cpp/tensors/encrypted_tensor.h +++ b/tenseal/cpp/tensors/encrypted_tensor.h @@ -161,6 +161,10 @@ class EncryptedTensor { **/ void link_tenseal_context(shared_ptr ctx) { this->_context = ctx; + if (_lazy_buffer) { + this->load(*_lazy_buffer); + _lazy_buffer = {}; + } }; void load_context_proto(const TenSEALContextProto& ctx) { this->link_tenseal_context(TenSEALContext::Create(ctx)); @@ -253,6 +257,7 @@ class EncryptedTensor { protected: shared_ptr _context; + optional _lazy_buffer; private: }; diff --git a/tests/cpp/tensors/bfvvector_test.cpp b/tests/cpp/tensors/bfvvector_test.cpp index 3787266f..c9f9aa7f 100644 --- a/tests/cpp/tensors/bfvvector_test.cpp +++ b/tests/cpp/tensors/bfvvector_test.cpp @@ -102,6 +102,25 @@ TEST_P(BFVVectorTest, TestEmptyPlaintext) { std::exception); } +TEST_F(BFVVectorTest, TestBFVLazyLoading) { + auto ctx = TenSEALContext::Create(scheme_type::bfv, 8192, 1032193, {}); + ASSERT_TRUE(ctx != nullptr); + + auto l = BFVVector::Create(ctx, vector({1, 2, 3})); + auto r = BFVVector::Create(ctx, vector({2, 3, 4})); + + auto buffer = l->save(); + auto newl = BFVVector::Create(buffer); + + EXPECT_THROW(newl->add(r), std::exception); + + newl->link_tenseal_context(ctx); + auto res = newl->add(r); + + auto decr = res->decrypt(); + EXPECT_THAT(decr.data(), ElementsAreArray({3, 5, 7})); +} + INSTANTIATE_TEST_CASE_P(TestBFVVector, BFVVectorTest, ::testing::Values(false, true)); diff --git a/tests/cpp/tensors/ckkstensor_test.cpp b/tests/cpp/tensors/ckkstensor_test.cpp index 82126887..53a56eed 100644 --- a/tests/cpp/tensors/ckkstensor_test.cpp +++ b/tests/cpp/tensors/ckkstensor_test.cpp @@ -280,5 +280,29 @@ TEST_P(CKKSTensorTest, TestEmptyPlaintext) { INSTANTIATE_TEST_CASE_P(TestCKKSTensor, CKKSTensorTest, ::testing::Values(false, true)); +TEST_F(CKKSTensorTest, TestCKKSLazyContext) { + auto ctx = + TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60}); + ASSERT_TRUE(ctx != nullptr); + + ctx->global_scale(std::pow(2, 40)); + + auto l = CKKSTensor::Create( + ctx, PlainTensor(std::vector({1, 2, 3, 4}), {2, 2})); + auto r = CKKSTensor::Create( + ctx, PlainTensor(std::vector({5, 6, 7, 8}), {2, 2})); + + auto buffer = l->save(); + auto newl = CKKSTensor::Create(buffer); + + EXPECT_THROW(newl->add(r), std::exception); + + newl->link_tenseal_context(ctx); + auto res = newl->add(r); + + auto decr = res->decrypt(); + ASSERT_TRUE(are_close(decr.data(), {6, 8, 10, 12})); +} + } // namespace } // namespace tenseal diff --git a/tests/cpp/tensors/ckksvector_test.cpp b/tests/cpp/tensors/ckksvector_test.cpp index 81ba2864..2d89ee96 100644 --- a/tests/cpp/tensors/ckksvector_test.cpp +++ b/tests/cpp/tensors/ckksvector_test.cpp @@ -249,5 +249,27 @@ TEST_P(CKKSVectorTest, TestEmptyPlaintext) { INSTANTIATE_TEST_CASE_P(TestCKKSVector, CKKSVectorTest, ::testing::Values(false, true)); +TEST_F(CKKSVectorTest, TestCKKSLazyContext) { + auto ctx = + TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60}); + ASSERT_TRUE(ctx != nullptr); + + ctx->global_scale(std::pow(2, 40)); + + auto l = CKKSVector::Create(ctx, std::vector({1, 2, 3})); + auto r = CKKSVector::Create(ctx, std::vector({3, 4, 4})); + + auto buffer = l->save(); + auto newl = CKKSVector::Create(buffer); + + EXPECT_THROW(newl->add(r), std::exception); + + newl->link_tenseal_context(ctx); + auto res = newl->add(r); + + auto decr = res->decrypt(); + ASSERT_TRUE(are_close(decr.data(), {4, 6, 7})); +} + } // namespace } // namespace tenseal From 345346c15b07c97ede611f538fabe163fec6ccd5 Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Mon, 21 Dec 2020 18:47:48 +0200 Subject: [PATCH 4/7] python bindings for lazy loading --- tenseal/__init__.py | 18 ++++++ tenseal/binding.cpp | 18 ++++-- tenseal/tensors/abstract_tensor.py | 19 ++++++ .../tenseal/tensors/test_serialization.py | 64 +++++++++++++++++++ 4 files changed, 113 insertions(+), 6 deletions(-) diff --git a/tenseal/__init__.py b/tenseal/__init__.py index f5eba85b..7dd85018 100644 --- a/tenseal/__init__.py +++ b/tenseal/__init__.py @@ -87,6 +87,11 @@ def bfv_vector_from(context: Context, data: bytes) -> BFVVector: return BFVVector.load(context, data) +def lazy_bfv_vector_from(data: bytes) -> BFVVector: + """Load a BFVVector from a protocol buffer.""" + return BFVVector.lazy_load(data) + + def ckks_vector(*args, **kwargs) -> CKKSVector: """Constructor function for tenseal.CKKSVector""" return CKKSVector(*args, **kwargs) @@ -98,6 +103,11 @@ def ckks_vector_from(context: Context, data: bytes) -> CKKSVector: return CKKSVector.load(context, data) +def lazy_ckks_vector_from(data: bytes) -> CKKSVector: + """Load a CKKSVector from a protocol buffer.""" + return CKKSVector.lazy_load(data) + + def ckks_tensor(*args, **kwargs) -> CKKSTensor: """Constructor function for tenseal.CKKSTensor""" return CKKSTensor(*args, **kwargs) @@ -109,13 +119,21 @@ def ckks_tensor_from(context: Context, data: bytes) -> CKKSTensor: return CKKSTensor.load(context, data) +def lazy_ckks_tensor_from(data: bytes) -> CKKSTensor: + """Load a CKKSTensor from a protocol buffer""" + return CKKSTensor.lazy_load(data) + + __all__ = [ "bfv_vector", "bfv_vector_from", + "lazy_bfv_vector_from", "ckks_vector", "ckks_vector_from", + "lazy_ckks_vector_from", "ckks_tensor", "ckks_tensor_from", + "lazy_ckks_tensor_from", "context", "context_from", "im2col_encoding", diff --git a/tenseal/binding.cpp b/tenseal/binding.cpp index cbed086a..b545a55a 100644 --- a/tenseal/binding.cpp +++ b/tenseal/binding.cpp @@ -75,6 +75,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { [](const shared_ptr &ctx, const std::string &data) { return BFVVector::Create(ctx, data); })) + .def(py::init( + [](const std::string &data) { return BFVVector::Create(data); })) .def("size", py::overload_cast<>(&BFVVector::size, py::const_)) .def("decrypt", [](shared_ptr obj) { return obj->decrypt().data(); }) @@ -144,8 +146,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { [](shared_ptr obj, const vector &other) { return obj->mul_plain_inplace(other); }) - .def("context", - [](shared_ptr obj) { return obj->tenseal_context(); }) + .def("context", &BFVVector::tenseal_context) + .def("link_context", &BFVVector::link_tenseal_context) .def("serialize", [](shared_ptr &obj) { return py::bytes(obj->save()); }) .def("copy", &BFVVector::deepcopy) @@ -214,6 +216,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { [](const shared_ptr &ctx, const std::string &data) { return CKKSVector::Create(ctx, data); })) + .def(py::init( + [](const std::string &data) { return CKKSVector::Create(data); })) .def("size", py::overload_cast<>(&CKKSVector::size, py::const_)) .def("decrypt", [](shared_ptr obj) { return obj->decrypt().data(); }) @@ -415,8 +419,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { return obj->matmul_plain_inplace(matrix, n_jobs); }, py::arg("matrix"), py::arg("n_jobs") = 0) - .def("context", - [](shared_ptr obj) { return obj->tenseal_context(); }) + .def("context", &CKKSVector::tenseal_context) + .def("link_context", &CKKSVector::link_tenseal_context) .def("serialize", [](shared_ptr obj) { return py::bytes(obj->save()); }) .def("copy", &CKKSVector::deepcopy) @@ -448,6 +452,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { [](const shared_ptr &ctx, const std::string &data) { return CKKSTensor::Create(ctx, data); })) + .def(py::init( + [](const std::string &data) { return CKKSTensor::Create(data); })) .def("decrypt", [](shared_ptr obj) { return obj->decrypt(); }) .def("decrypt", @@ -560,8 +566,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { .def("__neg__", &CKKSTensor::negate) .def("__pow__", &CKKSTensor::power) .def("__ipow__", &CKKSTensor::power_inplace) - .def("context", - [](shared_ptr obj) { return obj->tenseal_context(); }) + .def("context", &CKKSTensor::tenseal_context) + .def("link_context", &CKKSTensor::link_tenseal_context) .def("serialize", [](shared_ptr &obj) { return py::bytes(obj->save()); }) .def("copy", &CKKSTensor::deepcopy) diff --git a/tenseal/tensors/abstract_tensor.py b/tenseal/tensors/abstract_tensor.py index 3d8c34b5..c0d571d5 100644 --- a/tenseal/tensors/abstract_tensor.py +++ b/tenseal/tensors/abstract_tensor.py @@ -28,6 +28,10 @@ def context(self) -> "ts.Context": """Get the context linked to this tensor""" return ts.Context._wrap(self.data.context()) + def link_context(self, ctx: "ts.Context"): + """Set the context linked to this tensor""" + return self.data.link_context(ctx) + @property def shape(self) -> List[int]: return self.data.shape() @@ -50,6 +54,21 @@ def load(cls, context: "ts.Context", data: bytes) -> "AbstractTensor": "Invalid input types context: {} and vector: {}".format(type(context), type(data)) ) + @classmethod + def lazy_load(cls, data: bytes) -> "AbstractTensor": + """ + Constructor method for the tensor object from a serialized protobuffer, without a context. + Args: + data: the serialized protobuffer. + Returns: + Tensor object. + """ + if isinstance(data, bytes): + native_type = getattr(ts._ts_cpp, cls.__name__) + return cls._wrap(native_type(data)) + + raise TypeError("Invalid input types vector: {}".format(type(data))) + def serialize(self) -> bytes: """Serialize the tensor into a stream of bytes""" return self.data.serialize() diff --git a/tests/python/tenseal/tensors/test_serialization.py b/tests/python/tenseal/tensors/test_serialization.py index 121c2a37..8d1a3ce7 100644 --- a/tests/python/tenseal/tensors/test_serialization.py +++ b/tests/python/tenseal/tensors/test_serialization.py @@ -277,6 +277,29 @@ def test_mul_without_global_scale(vec1, vec2, precision, duplicate): assert _almost_equal(first_vec.decrypt(), vec1, precision), "Something went wrong in memory." +def test_ckksvector_lazy_load(precision): + vec1 = [1, 2, 3, 4] + vec2 = [1, 2, 3, 4] + + context = ckks_context() + first_vec = ts.ckks_vector(context, vec1) + second_vec = ts.ckks_vector(context, vec2) + + buff = first_vec.serialize() + newvec = ts.lazy_ckks_vector_from(buff) + newvec.link_context(context) + + result = newvec + second_vec + # Decryption + decrypted_result = result.decrypt() + assert _almost_equal( + decrypted_result, [2, 4, 6, 8], precision + ), "Addition of vectors is incorrect." + assert _almost_equal( + newvec.decrypt(), [1, 2, 3, 4], precision + ), "Something went wrong in memory." + + @pytest.mark.parametrize( "vec1, vec2", [([1], [1]), ([-1], [1]), ([1, 2, 3, 4], [4, 3, 2, 1]),], ) @@ -546,6 +569,25 @@ def test_mul_plain_inplace(vec1, vec2, duplicate): assert decrypted_result == expected, "Multiplication of vectors is incorrect." +def test_bfvvector_lazy_load(): + vec1 = [1, 2, 3, 4] + vec2 = [1, 2, 3, 4] + + context = bfv_context() + first_vec = ts.bfv_vector(context, vec1) + second_vec = ts.bfv_vector(context, vec2) + + buff = first_vec.serialize() + newvec = ts.lazy_bfv_vector_from(buff) + newvec.link_context(context) + + result = newvec + second_vec + # Decryption + decrypted_result = result.decrypt() + assert decrypted_result == [2, 4, 6, 8], "Addition of vectors is incorrect." + assert newvec.decrypt() == [1, 2, 3, 4], "Something went wrong in memory." + + @pytest.mark.parametrize( "plain_vec", [[0], [-1], [1], [21, 81, 90], [-73, -81, -90], [-11, 82, -43, 52]] ) @@ -560,3 +602,25 @@ def test_ckks_tensor_sanity(plain_vec, precision, duplicate): decrypted = ckks_tensor.decrypt().tolist() assert _almost_equal(decrypted, plain_vec, precision), "Decryption of tensor is incorrect" + + +def test_ckks_tensor_lazy_load(): + vec1 = [1, 2, 3, 4] + vec2 = [1, 2, 3, 4] + + context = ckks_context() + first_vec = ts.ckks_tensor(context, ts.plain_tensor(vec1)) + second_vec = ts.ckks_tensor(context, ts.plain_tensor(vec2)) + + buff = first_vec.serialize() + newvec = ts.lazy_ckks_tensor_from(buff) + newvec.link_context(context) + + result = newvec + second_vec + # Decryption + decrypted_result = result.decrypt().tolist() + + assert _almost_equal( + decrypted_result, [2, 4, 6, 8], precision + ), "Decryption of tensor is incorrect" + assert _almost_equal(newvec.decrypt().tolist(), [1, 2, 3, 4], precision), "invalid new tensor" From da41b0dccbe60f1cd41938d4735572e810bd9006 Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Mon, 21 Dec 2020 21:03:43 +0200 Subject: [PATCH 5/7] update tests --- tenseal/tensors/abstract_tensor.py | 2 +- tests/python/tenseal/tensors/test_serialization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tenseal/tensors/abstract_tensor.py b/tenseal/tensors/abstract_tensor.py index c0d571d5..cac1d0a0 100644 --- a/tenseal/tensors/abstract_tensor.py +++ b/tenseal/tensors/abstract_tensor.py @@ -30,7 +30,7 @@ def context(self) -> "ts.Context": def link_context(self, ctx: "ts.Context"): """Set the context linked to this tensor""" - return self.data.link_context(ctx) + return self.data.link_context(ctx.data) @property def shape(self) -> List[int]: diff --git a/tests/python/tenseal/tensors/test_serialization.py b/tests/python/tenseal/tensors/test_serialization.py index 8d1a3ce7..9277022e 100644 --- a/tests/python/tenseal/tensors/test_serialization.py +++ b/tests/python/tenseal/tensors/test_serialization.py @@ -604,7 +604,7 @@ def test_ckks_tensor_sanity(plain_vec, precision, duplicate): assert _almost_equal(decrypted, plain_vec, precision), "Decryption of tensor is incorrect" -def test_ckks_tensor_lazy_load(): +def test_ckks_tensor_lazy_load(precision): vec1 = [1, 2, 3, 4] vec2 = [1, 2, 3, 4] From f9b7b9805b135fbc15aa6763f67f1612cb48db9e Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Tue, 22 Dec 2020 15:11:20 +0200 Subject: [PATCH 6/7] make context private --- tenseal/cpp/tensors/bfvvector.cpp | 2 +- tenseal/cpp/tensors/ckkstensor.cpp | 2 +- tenseal/cpp/tensors/ckksvector.cpp | 2 +- tenseal/cpp/tensors/encrypted_tensor.h | 9 +++++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tenseal/cpp/tensors/bfvvector.cpp b/tenseal/cpp/tensors/bfvvector.cpp index 77e717c3..d20700cd 100644 --- a/tenseal/cpp/tensors/bfvvector.cpp +++ b/tenseal/cpp/tensors/bfvvector.cpp @@ -323,7 +323,7 @@ BFVVectorProto BFVVector::save_proto() const { } void BFVVector::load(const std::string& vec) { - if (this->_context == nullptr) { + if (!this->has_context()) { _lazy_buffer = vec; return; } diff --git a/tenseal/cpp/tensors/ckkstensor.cpp b/tenseal/cpp/tensors/ckkstensor.cpp index b5096fa8..2fd524f1 100644 --- a/tenseal/cpp/tensors/ckkstensor.cpp +++ b/tenseal/cpp/tensors/ckkstensor.cpp @@ -552,7 +552,7 @@ CKKSTensorProto CKKSTensor::save_proto() const { } void CKKSTensor::load(const std::string& tensor_str) { - if (this->_context == nullptr) { + if (!this->has_context()) { _lazy_buffer = tensor_str; return; } diff --git a/tenseal/cpp/tensors/ckksvector.cpp b/tenseal/cpp/tensors/ckksvector.cpp index 27fd2001..d945b8db 100644 --- a/tenseal/cpp/tensors/ckksvector.cpp +++ b/tenseal/cpp/tensors/ckksvector.cpp @@ -463,7 +463,7 @@ CKKSVectorProto CKKSVector::save_proto() const { } void CKKSVector::load(const std::string& vec) { - if (this->_context == nullptr) { + if (!this->has_context()) { _lazy_buffer = vec; return; } diff --git a/tenseal/cpp/tensors/encrypted_tensor.h b/tenseal/cpp/tensors/encrypted_tensor.h index cfc41b05..f7a9c6fd 100644 --- a/tenseal/cpp/tensors/encrypted_tensor.h +++ b/tenseal/cpp/tensors/encrypted_tensor.h @@ -155,7 +155,12 @@ class EncryptedTensor { if (_context == nullptr) throw invalid_argument("missing context"); return _context; }; - + /** + * Check if the context is linked + * **/ + bool has_context() const { + return _context != nullptr; + }; /** * Link to a TenSEAL context. **/ @@ -256,10 +261,10 @@ class EncryptedTensor { virtual ~EncryptedTensor(){}; protected: - shared_ptr _context; optional _lazy_buffer; private: + shared_ptr _context; }; } // namespace tenseal From 5cf834f5ac529d538897698b248a0dadf82a4cd0 Mon Sep 17 00:00:00 2001 From: Cebere Bogdan Date: Tue, 22 Dec 2020 16:01:17 +0200 Subject: [PATCH 7/7] lint --- tenseal/cpp/tensors/encrypted_tensor.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tenseal/cpp/tensors/encrypted_tensor.h b/tenseal/cpp/tensors/encrypted_tensor.h index f7a9c6fd..d94cad56 100644 --- a/tenseal/cpp/tensors/encrypted_tensor.h +++ b/tenseal/cpp/tensors/encrypted_tensor.h @@ -158,9 +158,7 @@ class EncryptedTensor { /** * Check if the context is linked * **/ - bool has_context() const { - return _context != nullptr; - }; + bool has_context() const { return _context != nullptr; }; /** * Link to a TenSEAL context. **/