Skip to content

Commit

Permalink
Encrypted matrix multiplication with plain vector (OpenMined#137)
Browse files Browse the repository at this point in the history
* enc_matmul_encoding python binding

* python tests for enc_matmul_plain

* remove duplicated code from conv2d_im2col function
  • Loading branch information
philomath213 authored Aug 15, 2020
1 parent 8eab52b commit 6425b7e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 15 deletions.
1 change: 1 addition & 0 deletions tenseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# utils
im2col_encoding = _ts_cpp.im2col_encoding
enc_matmul_encoding = _ts_cpp.enc_matmul_encoding


def context(
Expand Down
23 changes: 23 additions & 0 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,27 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
return make_pair(ckks_vector, windows_nb);
});

m.def("enc_matmul_encoding", [](shared_ptr<TenSEALContext> ctx,
const vector<vector<double>> &input) {
vector<double> final_vector;
vector<vector<double>> padded_matrix;
padded_matrix.reserve(input.size());
// calculate the next power of 2
size_t plain_vec_size =
1 << (static_cast<size_t>(ceil(log2(input[0].size()))));

for (size_t i = 0; i < input.size(); i++) {
// pad the row to the next power of 2
vector<double> row(plain_vec_size, 0);
copy(input[i].begin(), input[i].end(), row.begin());
padded_matrix.push_back(row);
}

vertical_scan(padded_matrix, final_vector);
CKKSVector ckks_vector = CKKSVector(ctx, final_vector);
return ckks_vector;
});

py::class_<CKKSVector>(m, "CKKSVector")
// specifying scale
.def(py::init<shared_ptr<TenSEALContext> &, vector<double>, double>())
Expand Down Expand Up @@ -159,6 +180,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
py::arg("n_jobs") = 0)
.def("conv2d_im2col", &CKKSVector::conv2d_im2col)
.def("conv2d_im2col_inplace", &CKKSVector::conv2d_im2col_inplace)
.def("enc_matmul_plain", &CKKSVector::enc_matmul_plain)
.def("enc_matmul_plain_inplace", &CKKSVector::enc_matmul_plain)
// python arithmetic
.def("__neg__", &CKKSVector::negate)
.def("__pow__", &CKKSVector::power)
Expand Down
46 changes: 32 additions & 14 deletions tenseal/cpp/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,34 +587,52 @@ CKKSVector& CKKSVector::conv2d_im2col_inplace(
vector<double> flatten_kernel;
horizontal_scan(kernel, flatten_kernel);

this->enc_matmul_plain_inplace(flatten_kernel, windows_nb);
return *this;
}

CKKSVector CKKSVector::enc_matmul_plain(const vector<double>& plain_vec,
const size_t rows_nb) {
CKKSVector new_vec = *this;
new_vec.enc_matmul_plain_inplace(plain_vec, rows_nb);
return new_vec;
}

CKKSVector& CKKSVector::enc_matmul_plain_inplace(
const vector<double>& plain_vec, const size_t rows_nb) {
if (plain_vec.empty()) {
throw invalid_argument("Plain vector can't be empty");
}

// calculate the next power of 2
size_t kernel_size = kernel.size() * kernel[0].size();
kernel_size = 1 << (static_cast<size_t>(ceil(log2(kernel_size))));
size_t plain_vec_size =
1 << (static_cast<size_t>(ceil(log2(plain_vec.size()))));

// pad the kernel with zeros to the next power of 2
flatten_kernel.resize(kernel_size, 0);
// pad the vector with zeros to the next power of 2
vector<double> padded_plain_vec(plain_vec);
padded_plain_vec.resize(plain_vec_size, 0);

size_t chunks_nb = flatten_kernel.size();
size_t chunks_nb = padded_plain_vec.size();

if (this->_size / windows_nb != chunks_nb) {
if (this->_size / rows_nb != chunks_nb) {
throw invalid_argument("Matrix shape doesn't match with vector size");
}

vector<double> plain_vec;
plain_vec.reserve(this->_size);
vector<double> new_plain_vec;
new_plain_vec.reserve(this->_size);

for (size_t i = 0; i < chunks_nb; i++) {
vector<double> tmp(windows_nb, flatten_kernel[i]);
plain_vec.insert(plain_vec.end(), tmp.begin(), tmp.end());
vector<double> tmp(rows_nb, padded_plain_vec[i]);
new_plain_vec.insert(new_plain_vec.end(), tmp.begin(), tmp.end());
}

// replicate the vector in order to be able to do multiple matrix
// multiplications
size_t slot_count = this->context->slot_count<CKKSEncoder>();
replicate_vector(plain_vec, slot_count);
replicate_vector(new_plain_vec, slot_count);
this->_size = slot_count;

this->mul_plain_inplace(plain_vec);
this->mul_plain_inplace(new_plain_vec);

auto galois_keys = this->context->galois_keys();

Expand All @@ -625,12 +643,12 @@ CKKSVector& CKKSVector::conv2d_im2col_inplace(
chunks_nb = static_cast<int>(
1 << (static_cast<size_t>(ceil(log2(chunks_nb))) - 1));
this->context->evaluator->rotate_vector_inplace(
tmp.ciphertext, static_cast<int>(windows_nb * chunks_nb),
tmp.ciphertext, static_cast<int>(rows_nb * chunks_nb),
*galois_keys);
this->add_inplace(tmp);
}

this->_size = windows_nb;
this->_size = rows_nb;

return *this;
}
Expand Down
10 changes: 9 additions & 1 deletion tenseal/cpp/tensors/ckksvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,21 @@ class CKKSVector {
CKKSVector& sum_inplace();

/**
* Matrix multiplication operations.
* Encrypted Vector multiplication with plain matrix.
**/
CKKSVector matmul_plain(const vector<vector<double>>& matrix,
size_t n_jobs = 0);
CKKSVector& matmul_plain_inplace(const vector<vector<double>>& matrix,
size_t n_jobs = 0);

/**
* Encrypted Matrix multiplication with plain vector.
**/
CKKSVector enc_matmul_plain(const vector<double>& plain_vec,
size_t row_size);
CKKSVector& enc_matmul_plain_inplace(const vector<double>& plain_vec,
size_t row_size);

/**
* Polynomial evaluation with `this` as variable.
* p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] *
Expand Down
23 changes: 23 additions & 0 deletions tests/python/tenseal/tensors/test_ckks_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,29 @@ def test_vec_plain_matrix_mul_depth2(context, vec, matrix1, matrix2, precision):
), "Matrix multiplication is incorrect."


@pytest.mark.parametrize(
"matrix_shape, vector_size",
[((1, 1), 1), ((2, 1), 1), ((3, 2), 2), ((4, 4), 4), ((9, 7), 7), ((16, 12), 12),],
)
def test_enc_matmul_plain(context, matrix_shape, vector_size, precision):
def generate_input(matrix_shape, vector_size):
# generated random values
matrix = np.random.randn(*matrix_shape)
vector = np.random.randn(vector_size)

return matrix, vector

matrix, vector = generate_input(matrix_shape, vector_size)
expected = matrix @ vector

context.generate_galois_keys()
ckks_vector = ts.enc_matmul_encoding(context, matrix.tolist())
result = ckks_vector.enc_matmul_plain(vector.tolist(), matrix_shape[0])
assert _almost_equal(
result.decrypt(), expected, precision
), "Matrix multiplication is incorrect."


@pytest.mark.parametrize(
"data, polynom",
[
Expand Down

0 comments on commit 6425b7e

Please sign in to comment.