diff --git a/tenseal/__init__.py b/tenseal/__init__.py index ca553898..a13951f1 100644 --- a/tenseal/__init__.py +++ b/tenseal/__init__.py @@ -18,6 +18,7 @@ # utils im2col_encoding = _ts_cpp.im2col_encoding +enc_matmul_encoding = _ts_cpp.enc_matmul_encoding def context( diff --git a/tenseal/binding.cpp b/tenseal/binding.cpp index 218f2757..9619e115 100644 --- a/tenseal/binding.cpp +++ b/tenseal/binding.cpp @@ -96,6 +96,27 @@ PYBIND11_MODULE(_tenseal_cpp, m) { return make_pair(ckks_vector, windows_nb); }); + m.def("enc_matmul_encoding", [](shared_ptr ctx, + const vector> &input) { + vector final_vector; + vector> padded_matrix; + padded_matrix.reserve(input.size()); + // calculate the next power of 2 + size_t plain_vec_size = + 1 << (static_cast(ceil(log2(input[0].size())))); + + for (size_t i = 0; i < input.size(); i++) { + // pad the row to the next power of 2 + vector row(plain_vec_size, 0); + copy(input[i].begin(), input[i].end(), row.begin()); + padded_matrix.push_back(row); + } + + vertical_scan(padded_matrix, final_vector); + CKKSVector ckks_vector = CKKSVector(ctx, final_vector); + return ckks_vector; + }); + py::class_(m, "CKKSVector") // specifying scale .def(py::init &, vector, double>()) @@ -159,6 +180,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) { py::arg("n_jobs") = 0) .def("conv2d_im2col", &CKKSVector::conv2d_im2col) .def("conv2d_im2col_inplace", &CKKSVector::conv2d_im2col_inplace) + .def("enc_matmul_plain", &CKKSVector::enc_matmul_plain) + .def("enc_matmul_plain_inplace", &CKKSVector::enc_matmul_plain) // python arithmetic .def("__neg__", &CKKSVector::negate) .def("__pow__", &CKKSVector::power) diff --git a/tenseal/cpp/tensors/ckksvector.cpp b/tenseal/cpp/tensors/ckksvector.cpp index 4830f242..c05246a5 100644 --- a/tenseal/cpp/tensors/ckksvector.cpp +++ b/tenseal/cpp/tensors/ckksvector.cpp @@ -587,34 +587,52 @@ CKKSVector& CKKSVector::conv2d_im2col_inplace( vector flatten_kernel; horizontal_scan(kernel, flatten_kernel); + this->enc_matmul_plain_inplace(flatten_kernel, windows_nb); + return *this; +} + +CKKSVector CKKSVector::enc_matmul_plain(const vector& plain_vec, + const size_t rows_nb) { + CKKSVector new_vec = *this; + new_vec.enc_matmul_plain_inplace(plain_vec, rows_nb); + return new_vec; +} + +CKKSVector& CKKSVector::enc_matmul_plain_inplace( + const vector& plain_vec, const size_t rows_nb) { + if (plain_vec.empty()) { + throw invalid_argument("Plain vector can't be empty"); + } + // calculate the next power of 2 - size_t kernel_size = kernel.size() * kernel[0].size(); - kernel_size = 1 << (static_cast(ceil(log2(kernel_size)))); + size_t plain_vec_size = + 1 << (static_cast(ceil(log2(plain_vec.size())))); - // pad the kernel with zeros to the next power of 2 - flatten_kernel.resize(kernel_size, 0); + // pad the vector with zeros to the next power of 2 + vector padded_plain_vec(plain_vec); + padded_plain_vec.resize(plain_vec_size, 0); - size_t chunks_nb = flatten_kernel.size(); + size_t chunks_nb = padded_plain_vec.size(); - if (this->_size / windows_nb != chunks_nb) { + if (this->_size / rows_nb != chunks_nb) { throw invalid_argument("Matrix shape doesn't match with vector size"); } - vector plain_vec; - plain_vec.reserve(this->_size); + vector new_plain_vec; + new_plain_vec.reserve(this->_size); for (size_t i = 0; i < chunks_nb; i++) { - vector tmp(windows_nb, flatten_kernel[i]); - plain_vec.insert(plain_vec.end(), tmp.begin(), tmp.end()); + vector tmp(rows_nb, padded_plain_vec[i]); + new_plain_vec.insert(new_plain_vec.end(), tmp.begin(), tmp.end()); } // replicate the vector in order to be able to do multiple matrix // multiplications size_t slot_count = this->context->slot_count(); - replicate_vector(plain_vec, slot_count); + replicate_vector(new_plain_vec, slot_count); this->_size = slot_count; - this->mul_plain_inplace(plain_vec); + this->mul_plain_inplace(new_plain_vec); auto galois_keys = this->context->galois_keys(); @@ -625,12 +643,12 @@ CKKSVector& CKKSVector::conv2d_im2col_inplace( chunks_nb = static_cast( 1 << (static_cast(ceil(log2(chunks_nb))) - 1)); this->context->evaluator->rotate_vector_inplace( - tmp.ciphertext, static_cast(windows_nb * chunks_nb), + tmp.ciphertext, static_cast(rows_nb * chunks_nb), *galois_keys); this->add_inplace(tmp); } - this->_size = windows_nb; + this->_size = rows_nb; return *this; } diff --git a/tenseal/cpp/tensors/ckksvector.h b/tenseal/cpp/tensors/ckksvector.h index 2e635108..cc3c8139 100644 --- a/tenseal/cpp/tensors/ckksvector.h +++ b/tenseal/cpp/tensors/ckksvector.h @@ -108,13 +108,21 @@ class CKKSVector { CKKSVector& sum_inplace(); /** - * Matrix multiplication operations. + * Encrypted Vector multiplication with plain matrix. **/ CKKSVector matmul_plain(const vector>& matrix, size_t n_jobs = 0); CKKSVector& matmul_plain_inplace(const vector>& matrix, size_t n_jobs = 0); + /** + * Encrypted Matrix multiplication with plain vector. + **/ + CKKSVector enc_matmul_plain(const vector& plain_vec, + size_t row_size); + CKKSVector& enc_matmul_plain_inplace(const vector& plain_vec, + size_t row_size); + /** * Polynomial evaluation with `this` as variable. * p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] * diff --git a/tests/python/tenseal/tensors/test_ckks_vector.py b/tests/python/tenseal/tensors/test_ckks_vector.py index d40160f7..6402aae4 100644 --- a/tests/python/tenseal/tensors/test_ckks_vector.py +++ b/tests/python/tenseal/tensors/test_ckks_vector.py @@ -1008,6 +1008,29 @@ def test_vec_plain_matrix_mul_depth2(context, vec, matrix1, matrix2, precision): ), "Matrix multiplication is incorrect." +@pytest.mark.parametrize( + "matrix_shape, vector_size", + [((1, 1), 1), ((2, 1), 1), ((3, 2), 2), ((4, 4), 4), ((9, 7), 7), ((16, 12), 12),], +) +def test_enc_matmul_plain(context, matrix_shape, vector_size, precision): + def generate_input(matrix_shape, vector_size): + # generated random values + matrix = np.random.randn(*matrix_shape) + vector = np.random.randn(vector_size) + + return matrix, vector + + matrix, vector = generate_input(matrix_shape, vector_size) + expected = matrix @ vector + + context.generate_galois_keys() + ckks_vector = ts.enc_matmul_encoding(context, matrix.tolist()) + result = ckks_vector.enc_matmul_plain(vector.tolist(), matrix_shape[0]) + assert _almost_equal( + result.decrypt(), expected, precision + ), "Matrix multiplication is incorrect." + + @pytest.mark.parametrize( "data, polynom", [