Threadpool support for matrix multiplication #124

Merged · 25 commits · Jul 30, 2020
Changes from 3 commits
17 changes: 11 additions & 6 deletions tenseal/binding.cpp
@@ -134,10 +134,14 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("dot_", &CKKSVector::dot_product_plain_inplace)
.def("sum", &CKKSVector::sum)
.def("sum_", &CKKSVector::sum_inplace)
.def("matmul", &CKKSVector::matmul_plain, py::arg("matrix"))
.def("matmul_", &CKKSVector::matmul_plain_inplace, py::arg("matrix"))
.def("mm", &CKKSVector::matmul_plain, py::arg("matrix"))
.def("mm_", &CKKSVector::matmul_plain_inplace, py::arg("matrix"))
.def("matmul", &CKKSVector::matmul_plain, py::arg("matrix"),
py::arg("n_batches") = 0)
.def("matmul_", &CKKSVector::matmul_plain_inplace, py::arg("matrix"),
py::arg("n_batches") = 0)
.def("mm", &CKKSVector::matmul_plain, py::arg("matrix"),
py::arg("n_batches") = 0)
.def("mm_", &CKKSVector::matmul_plain_inplace, py::arg("matrix"),
py::arg("n_batches") = 0)
// python arithmetic
.def("__neg__", &CKKSVector::negate)
.def("__pow__", &CKKSVector::power)
@@ -194,9 +198,10 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
py::overload_cast<double>(&CKKSVector::mul_plain_inplace))
.def("__imul__", py::overload_cast<const vector<double> &>(
&CKKSVector::mul_plain_inplace))
.def("__matmul__", &CKKSVector::matmul_plain, py::arg("matrix"))
.def("__matmul__", &CKKSVector::matmul_plain, py::arg("matrix"),
py::arg("n_batches") = 0)
.def("__imatmul__", &CKKSVector::matmul_plain_inplace,
py::arg("matrix"))
py::arg("matrix"), py::arg("n_batches") = 0)
.def("context",
[](const CKKSVector &obj) { return obj.tenseal_context(); })
.def("serialize",
10 changes: 6 additions & 4 deletions tenseal/cpp/tensors/ckksvector.cpp
@@ -462,15 +462,17 @@ CKKSVector& CKKSVector::sum_inplace() {
return *this;
}

-CKKSVector CKKSVector::matmul_plain(const vector<vector<double>>& matrix) {
+CKKSVector CKKSVector::matmul_plain(const vector<vector<double>>& matrix,
+                                    size_t n_batches) {
     CKKSVector new_vector = *this;
-    return new_vector.matmul_plain_inplace(matrix);
+    return new_vector.matmul_plain_inplace(matrix, n_batches);
}

CKKSVector& CKKSVector::matmul_plain_inplace(
-    const vector<vector<double>>& matrix) {
+    const vector<vector<double>>& matrix, size_t n_batches) {
     this->ciphertext = diagonal_ct_vector_matmul<double, CKKSEncoder>(
-        this->tenseal_context(), this->ciphertext, this->size(), matrix);
+        this->tenseal_context(), this->ciphertext, this->size(), matrix,
+        n_batches);

this->_size = matrix[0].size();

6 changes: 4 additions & 2 deletions tenseal/cpp/tensors/ckksvector.h
@@ -109,8 +109,10 @@ class CKKSVector {
/**
* Matrix multiplication operations.
**/
-    CKKSVector matmul_plain(const vector<vector<double>>& matrix);
-    CKKSVector& matmul_plain_inplace(const vector<vector<double>>& matrix);
+    CKKSVector matmul_plain(const vector<vector<double>>& matrix,
+                            size_t n_batches = 0);
+    CKKSVector& matmul_plain_inplace(const vector<vector<double>>& matrix,
+                                     size_t n_batches = 0);

/**
* Polynomial evaluation with `this` as variable.
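For orientation, here is a minimal sketch of how the updated C++ API is meant to be called. The `run_matmul` wrapper and its arguments are illustrative, not part of the PR, and assume an already-encrypted vector and the `tenseal` namespace in scope:

// Illustrative sketch of the new n_batches parameter; `run_matmul` and its
// inputs are assumptions, not code from this PR.
#include <vector>
#include "tenseal/cpp/tensors/ckksvector.h"

using std::vector;
using tenseal::CKKSVector;

void run_matmul(CKKSVector& vec, const vector<vector<double>>& matrix) {
    // n_batches == 0 (the default) defers to the context's dispatcher size.
    CKKSVector out = vec.matmul_plain(matrix);

    // An explicit count splits the diagonal loop into that many tasks.
    CKKSVector out4 = vec.matmul_plain(matrix, /*n_batches=*/4);

    // The in-place variant mutates `vec` and returns a reference to it.
    vec.matmul_plain_inplace(matrix, /*n_batches=*/2);
}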
13 changes: 7 additions & 6 deletions tenseal/cpp/tensors/utils/matrix_ops.h
@@ -57,7 +57,8 @@ Cryptology Conference (pp. 554-571). Springer, Berlin, Heidelberg.
template <typename T, class Encoder>
Ciphertext diagonal_ct_vector_matmul(shared_ptr<TenSEALContext> tenseal_context,
Ciphertext& vec, const size_t vector_size,
-                                     const vector<vector<T>>& matrix) {
+                                     const vector<vector<T>>& matrix,
+                                     size_t n_batches) {
// matrix is organized by rows
// _check_matrix(matrix, this->size())

@@ -108,20 +109,20 @@ Ciphertext diagonal_ct_vector_matmul(shared_ptr<TenSEALContext> tenseal_context,
return thread_result;
};

-    size_t worker_cnt = tenseal_context->dispatcher_size();
+    if (n_batches == 0) n_batches = tenseal_context->dispatcher_size();

-    if (worker_cnt == 1) return worker_func(0, vector_size);
+    if (n_batches == 1) return worker_func(0, vector_size);

std::vector<std::future<Ciphertext>> future_results;
-    size_t batch_size = (vector_size + worker_cnt - 1) / worker_cnt;
+    size_t batch_size = (vector_size + n_batches - 1) / n_batches;

-    for (size_t i = 0; i < worker_cnt; i++) {
+    for (size_t i = 0; i < n_batches; i++) {
future_results.push_back(tenseal_context->dispatcher()->enqueue_task(
worker_func, i * batch_size,
std::min((i + 1) * batch_size, vector_size)));
}

-    for (size_t i = 0; i < worker_cnt; i++) {
+    for (size_t i = 0; i < n_batches; i++) {
tenseal_context->evaluator->add_inplace(result,
future_results[i].get());
}
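The batching arithmetic above is a ceiling division, so every index in [0, vector_size) is covered even when n_batches does not divide vector_size evenly, with the final range clamped by std::min. A standalone sketch of just that partitioning (the `partition` helper is illustrative):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative helper: reproduces the (start, end) ranges the enqueue loop
// above hands to worker_func. Assumes n_batches >= 1, which the
// dispatcher-size fallback guarantees.
std::vector<std::pair<std::size_t, std::size_t>> partition(
    std::size_t vector_size, std::size_t n_batches) {
    // Ceiling division: (a + b - 1) / b.
    std::size_t batch_size = (vector_size + n_batches - 1) / n_batches;
    std::vector<std::pair<std::size_t, std::size_t>> ranges;
    for (std::size_t i = 0; i < n_batches; i++) {
        ranges.emplace_back(i * batch_size,
                            std::min((i + 1) * batch_size, vector_size));
    }
    return ranges;
}
// partition(10, 4) yields (0,3), (3,6), (6,9), (9,10): the last range is
// clamped to vector_size, matching the std::min in the enqueue loop.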
24 changes: 20 additions & 4 deletions tenseal/cpp/utils/queue.h
@@ -13,9 +13,16 @@
namespace tenseal {
namespace sync {

+/**
+ * Thread-safe queue.
+ **/
template <class T>
class blocking_queue {
public:
+    /**
+     * push() appends a new item to the queue and notifies the "pop"
+     * listener about the event.
+     **/
template <class... Args>
void push(Args&&... args) {
{
@@ -24,7 +31,10 @@ class blocking_queue {
}
ready_.notify_one();
}

+    /**
+     * pop() waits until an item is available in the queue, pops it out,
+     * and assigns it to the "out" parameter.
+     **/
[[nodiscard]] bool pop(T& out) {
std::unique_lock lock{mutex_};
ready_.wait(lock, [this] { return !queue_.empty() || done_; });
@@ -35,20 +45,26 @@

return true;
}

+    /**
+     * done() notifies all listeners to shut down.
+     **/
void done() noexcept {
{
std::scoped_lock lock{mutex_};
done_ = true;
}
ready_.notify_all();
}

+    /**
+     * empty() returns whether the queue is empty.
+     **/
[[nodiscard]] bool empty() const noexcept {
std::scoped_lock lock{mutex_};
return queue_.empty();
}

+    /**
+     * size() returns the number of items in the queue.
+     **/
[[nodiscard]] unsigned int size() const noexcept {
std::scoped_lock lock{mutex_};
return queue_.size();
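Taken together, the documented API supports the producer/consumer pattern the thread pool relies on. A minimal sketch, assuming the tenseal::sync namespace above; the payload type and item names are illustrative, and the shutdown behaviour (pop() returning false after done() once the queue drains) matches how the pool's workers exit:

#include <string>
#include <thread>

void queue_demo() {
    tenseal::sync::blocking_queue<std::string> queue;

    // Consumer: pop() blocks until an item arrives and, once done() has
    // been signalled and the queue drained, returns false to end the loop.
    std::thread consumer([&queue] {
        std::string item;
        while (queue.pop(item)) {
            // ... process item ...
        }
    });

    queue.push("task-1");  // push() forwards its arguments and notifies
    queue.push("task-2");  // the waiting consumer via ready_.notify_one()
    queue.done();          // wake all listeners for shutdown
    consumer.join();
}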
27 changes: 20 additions & 7 deletions tenseal/cpp/utils/threadpool.h
@@ -24,27 +24,39 @@ inline uint get_concurrency() {

namespace sync {

+/**
+ * A ThreadPool class for managing and dispatching tasks to a number of
+ * threads.
+ **/
class ThreadPool {
public:
-    ThreadPool(unsigned int threads = get_concurrency())
-        : m_queues(threads), m_count(threads) {
-        assert(threads != 0);
+    /**
+     * Create "n_threads" workers, each with a dedicated task queue, and
+     * execute tasks as they arrive in the queues.
+     **/
+    ThreadPool(unsigned int n_threads = get_concurrency())
+        : m_queues(n_threads), m_count(n_threads) {
+        assert(n_threads != 0);
auto worker = [&](unsigned int i) {
while (true) {
Proc f;
if (!m_queues[i].pop(f)) break;
f();
}
};
-        for (unsigned int i = 0; i < threads; ++i)
-            m_threads.emplace_back(worker, i);
+        for (unsigned int i = 0; i < n_threads; ++i)
+            m_workers.emplace_back(worker, i);
}

~ThreadPool() noexcept {
for (auto& queue : m_queues) queue.done();
-        for (auto& thread : m_threads) thread.join();
+        for (auto& worker : m_workers) worker.join();
}

+    /**
+     * enqueue_task() assigns tasks to worker queues using round-robin
+     * scheduling.
+     * @returns a std::future object with the result of the task.
+     **/
template <typename F, typename... Args>
auto enqueue_task(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
@@ -63,11 +75,12 @@

private:
using Proc = std::function<void(void)>;

using Queues = std::vector<blocking_queue<Proc>>;
Queues m_queues;

using Threads = std::vector<std::thread>;
-    Threads m_threads;
+    Threads m_workers;

const unsigned int m_count;
std::atomic_uint m_index = 0;
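A short usage sketch of the pool itself (the `add` task and the pool size of 4 are illustrative): enqueue_task() round-robins work across the per-worker queues and hands back a std::future per call, while the destructor shuts the pool down by signalling done() on every queue before joining the workers:

#include <future>
#include <iostream>
#include <vector>

int add(int a, int b) { return a + b; }

void pool_demo() {
    tenseal::sync::ThreadPool pool(4);  // four workers, four queues

    // Tasks are spread over the queues round-robin; each call yields a
    // std::future<int> because add() returns int.
    std::vector<std::future<int>> results;
    for (int i = 0; i < 8; i++) {
        results.push_back(pool.enqueue_task(add, i, i));
    }

    for (auto& fut : results) std::cout << fut.get() << ' ';
    std::cout << '\n';
}   // pool destructor: done() on every queue, then join every worker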
10 changes: 6 additions & 4 deletions tests/python/tenseal/tensors/test_ckks_vector.py
@@ -957,11 +957,12 @@ def test_mul_without_global_scale(vec1, vec2, precision):
],
)
@pytest.mark.parametrize("n_threads", [0, 1, 2, 4])
-def test_vec_plain_matrix_mul(vec, matrix, n_threads, precision):
+@pytest.mark.parametrize("n_batches", [0, 1, 2, 4])
+def test_vec_plain_matrix_mul(vec, matrix, n_threads, n_batches, precision):
context = parallel_context(n_threads)
context.generate_galois_keys()
ct = ts.ckks_vector(context, vec)
-result = ct.mm(matrix)
+result = ct.mm(matrix, n_batches)
expected = (np.array(vec) @ np.array(matrix)).tolist()
assert _almost_equal(
result.decrypt(), expected, precision
@@ -981,11 +982,12 @@ def test_vec_plain_matrix_mul(vec, matrix, n_threads, precision):
],
)
@pytest.mark.parametrize("n_threads", [0, 1, 2, 4])
-def test_vec_plain_matrix_mul_inplace(vec, matrix, n_threads, precision):
+@pytest.mark.parametrize("n_batches", [0, 1, 2, 4])
+def test_vec_plain_matrix_mul_inplace(vec, matrix, n_threads, n_batches, precision):
context = parallel_context(n_threads)
context.generate_galois_keys()
ct = ts.ckks_vector(context, vec)
-ct.mm_(matrix)
+ct.mm_(matrix, n_batches)
expected = (np.array(vec) @ np.array(matrix)).tolist()
assert _almost_equal(ct.decrypt(), expected, precision), "Matrix multiplication is incorrect."
