Threadpool support for matrix multiplication (OpenMined#124)
* add benchmarks for matmul_plain operation

* add threadpool implementation

Co-authored-by: Ayoub Benaissa <ayouben9@gmail.com>
bcebere and youben11 authored Jul 30, 2020
1 parent 0197534 commit 08d63c0
Showing 22 changed files with 438 additions and 139 deletions.
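
In Python terms, the diffs below add an optional `n_threads` argument to context creation and deserialization (sizing an internal threadpool dispatcher) and rename the matrix-multiplication keyword to `n_jobs`. A minimal usage sketch of the new surface; the `ckks_vector(context, data, scale)` construction and the `2 ** 40` scale are illustrative assumptions rather than part of this commit:

```python
import tenseal as ts

# Context creation gains an optional n_threads argument that sizes the
# internal threadpool dispatcher; omitting it (or passing 0) falls back
# to the hardware concurrency.
ctx = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
    n_threads=4,
)

# Illustrative data; the ckks_vector signature and scale are assumptions.
enc = ts.ckks_vector(ctx, [1.0, 2.0, 3.0, 4.0], 2 ** 40)
matrix = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

# The encrypted-vector x plain-matrix product now takes an n_jobs keyword;
# 0 (the default) appears to defer to the context's threadpool.
result = enc.matmul(matrix, n_jobs=0)
print(result.decrypt())  # approximately [50.0, 60.0]
```
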
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
@@ -111,11 +111,11 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Run gtest
timeout-minutes: 30
timeout-minutes: 15
run: bazel test --test_output=all --spawn_strategy=standalone --test_timeout=900 //tests/cpp/...
- name: Run SEALAPI tests
timeout-minutes: 30
timeout-minutes: 15
run: bazel test --test_output=all --spawn_strategy=standalone --test_timeout=900 //tests/python/sealapi/...
- name: Run TenSEAL tests
timeout-minutes: 30
run: bazel test --test_output=all --spawn_strategy=standalone --test_timeout=900 //tests/python/tenseal/...
timeout-minutes: 15
run: bazel test --test_output=all --spawn_strategy=standalone --test_output=streamed --local_sigkill_grace_seconds=30 --test_timeout=900 //tests/python/tenseal/...
5 changes: 5 additions & 0 deletions README.md
@@ -134,6 +134,11 @@ load("@org_openmined_tenseal//tenseal:deps.bzl", "tenseal_deps")
tenseal_deps()
```

## Benchmarks

You can benchmark the implementation at any point by running:
```bash
$ bazel run -c opt --spawn_strategy=standalone //tests/cpp/benchmarks:benchmark
```

## Support

For support in using this library, please join the **#lib_tenseal** Slack channel. If you’d like to follow along with any code changes to the library, please join the **#code_tenseal** Slack channel. [Click here to join our Slack community!](https://slack.openmined.org)
13 changes: 11 additions & 2 deletions tenseal/__init__.py
@@ -17,7 +17,9 @@
GaloisKeys = _ts_cpp.GaloisKeys


def context(scheme, poly_modulus_degree, plain_modulus=None, coeff_mod_bit_sizes=None):
def context(
scheme, poly_modulus_degree, plain_modulus=None, coeff_mod_bit_sizes=None, n_threads=None
):
"""Construct a context that holds keys and parameters needed for operating
encrypted tensors using either BFV or CKKS scheme.
@@ -47,12 +47,17 @@ def context(scheme, poly_modulus_degree, plain_modulus=None, coeff_mod_bit_sizes
raise ValueError("Invalid scheme type, use either SCHEME_TYPE.BFV or SCHEME_TYPE.CKKS")

# We can't pass None here, everything should be set prior to this call
if isinstance(n_threads, int) and n_threads > 0:
return _ts_cpp.TenSEALContext.new(
scheme, poly_modulus_degree, plain_modulus, coeff_mod_bit_sizes, n_threads
)

return _ts_cpp.TenSEALContext.new(
scheme, poly_modulus_degree, plain_modulus, coeff_mod_bit_sizes
)


def context_from(buff):
def context_from(buff, n_threads=None):
"""Construct a context from a serialized buffer.
Args:
@@ -61,6 +68,8 @@ def context_from(buff):
Returns:
A TenSEALContext object.
"""
if n_threads:
return _ts_cpp.TenSEALContext.deserialize(buff, n_threads)
return _ts_cpp.TenSEALContext.deserialize(buff)


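A hedged sketch of the serialization path above: `context_from` now forwards an optional `n_threads` to `TenSEALContext.deserialize`, so the receiving side can pick its own pool size. The CKKS parameters below are illustrative:

```python
import tenseal as ts

# Build a context with an explicit pool size, serialize it, then restore it.
ctx = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
    n_threads=8,
)
buff = ctx.serialize()

# context_from forwards n_threads to TenSEALContext.deserialize; leaving it
# as None keeps the deserializer's default (hardware concurrency).
restored = ts.context_from(buff, n_threads=2)
```
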
26 changes: 15 additions & 11 deletions tenseal/binding.cpp
@@ -135,13 +135,13 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("sum", &CKKSVector::sum)
.def("sum_", &CKKSVector::sum_inplace)
.def("matmul", &CKKSVector::matmul_plain, py::arg("matrix"),
py::arg("n_threads") = 0)
py::arg("n_jobs") = 0)
.def("matmul_", &CKKSVector::matmul_plain_inplace, py::arg("matrix"),
py::arg("n_threads") = 0)
py::arg("n_jobs") = 0)
.def("mm", &CKKSVector::matmul_plain, py::arg("matrix"),
py::arg("n_threads") = 0)
py::arg("n_jobs") = 0)
.def("mm_", &CKKSVector::matmul_plain_inplace, py::arg("matrix"),
py::arg("n_threads") = 0)
py::arg("n_jobs") = 0)
// python arithmetic
.def("__neg__", &CKKSVector::negate)
.def("__pow__", &CKKSVector::power)
@@ -199,9 +199,9 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("__imul__", py::overload_cast<const vector<double> &>(
&CKKSVector::mul_plain_inplace))
.def("__matmul__", &CKKSVector::matmul_plain, py::arg("matrix"),
py::arg("n_threads") = 0)
py::arg("n_jobs") = 0)
.def("__imatmul__", &CKKSVector::matmul_plain_inplace,
py::arg("matrix"), py::arg("n_threads") = 0)
py::arg("matrix"), py::arg("n_jobs") = 0)
.def("context",
[](const CKKSVector &obj) { return obj.tenseal_context(); })
.def("serialize",
@@ -227,18 +227,20 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
py::overload_cast<>(&TenSEALContext::auto_mod_switch),
py::overload_cast<bool>(&TenSEALContext::auto_mod_switch))
.def("new",
py::overload_cast<scheme_type, size_t, uint64_t, vector<int>>(
&TenSEALContext::Create),
py::overload_cast<scheme_type, size_t, uint64_t, vector<int>,
optional<uint>>(&TenSEALContext::Create),
R"(Create a new TenSEALContext object to hold keys and parameters.
Args:
scheme : define the scheme to be used, either SCHEME_TYPE.BFV or SCHEME_TYPE.CKKS.
poly_modulus_degree : The degree of the polynomial modulus, must be a power of two.
plain_modulus : The plaintext modulus. Is not used if scheme is CKKS.
coeff_mod_bit_sizes : List of bit sizes for each coefficient modulus.
Can be an empty list for BFV; a default value will be given.
n_threads : Optional: number of threads to use for multiplications.
)",
py::arg("poly_modulus_degree"), py::arg("plain_modulus"),
py::arg("coeff_mod_bit_sizes") = vector<int>())
py::arg("coeff_mod_bit_sizes") = vector<int>(),
py::arg("n_threads") = get_concurrency())
.def("public_key", &TenSEALContext::public_key)
.def("secret_key", &TenSEALContext::secret_key)
.def("relin_keys", &TenSEALContext::relin_keys)
@@ -266,8 +268,10 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
"Generate Relinearization keys using the secret key")
.def("serialize",
[](const TenSEALContext &obj) { return py::bytes(obj.save()); })
.def_static("deserialize", py::overload_cast<const std::string &>(
&TenSEALContext::Create))
.def_static("deserialize",
py::overload_cast<const std::string &, optional<uint>>(
&TenSEALContext::Create),
py::arg("buffer"), py::arg("n_threads") = get_concurrency())
.def("copy", &TenSEALContext::copy)
.def("__copy__",
[](const std::shared_ptr<TenSEALContext> &self) {
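The binding changes above rename the keyword from `n_threads` to `n_jobs` on `matmul`, `matmul_`, `mm`, `mm_`, `__matmul__`, and `__imatmul__`. A short sketch of the call forms; the vector construction is again an assumption, and the reading of `n_jobs=0` as "defer to the context's pool" is inferred from the default value:

```python
import tenseal as ts

ctx = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60], n_threads=4)
enc = ts.ckks_vector(ctx, [1.0, 2.0, 3.0], 2 ** 40)  # construction details assumed
matrix = [[1.0], [2.0], [3.0]]

out = enc.matmul(matrix, n_jobs=2)   # explicit job count
out = enc.mm(matrix)                 # alias; n_jobs defaults to 0
out = enc @ matrix                   # __matmul__ uses the same default
enc.mm_(matrix, n_jobs=1)            # in-place variant
```
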
1 change: 0 additions & 1 deletion tenseal/cpp/context/sealcontext.cpp
@@ -1,7 +1,6 @@
#include "tenseal/cpp/context/sealcontext.h"

#include <memory>
#include <thread>

#include "seal/seal.h"

47 changes: 34 additions & 13 deletions tenseal/cpp/context/tensealcontext.cpp
@@ -10,17 +10,35 @@ namespace tenseal {
using namespace seal;
using namespace std;

TenSEALContext::TenSEALContext(EncryptionParameters parms) {
TenSEALContext::TenSEALContext(EncryptionParameters parms,
optional<uint> n_threads) {
this->dispatcher_setup(n_threads);
this->base_setup(parms);
this->keys_setup();
}

TenSEALContext::TenSEALContext(istream& stream) { this->load(stream); }
TenSEALContext::TenSEALContext(const std::string& input) { this->load(input); }
TenSEALContext::TenSEALContext(const TenSEALContextProto& input) {
TenSEALContext::TenSEALContext(istream& stream, optional<uint> n_threads) {
this->dispatcher_setup(n_threads);
this->load(stream);
}
TenSEALContext::TenSEALContext(const std::string& input,
optional<uint> n_threads) {
this->dispatcher_setup(n_threads);
this->load(input);
}
TenSEALContext::TenSEALContext(const TenSEALContextProto& input,
optional<uint> n_threads) {
this->dispatcher_setup(n_threads);
this->load_proto(input);
}

void TenSEALContext::dispatcher_setup(optional<uint> n_threads) {
this->_threads = n_threads.value_or(get_concurrency());
if (this->_threads == 0) this->_threads = get_concurrency();

this->_dispatcher = make_shared<sync::ThreadPool>(this->_threads);
}

void TenSEALContext::base_setup(EncryptionParameters parms) {
this->_parms = parms;
this->_context = SEALContext::Create(this->_parms);
@@ -65,7 +83,7 @@ void TenSEALContext::keys_setup(optional<PublicKey> public_key,

shared_ptr<TenSEALContext> TenSEALContext::Create(
scheme_type scheme, size_t poly_modulus_degree, uint64_t plain_modulus,
vector<int> coeff_mod_bit_sizes) {
vector<int> coeff_mod_bit_sizes, optional<uint> n_threads) {
EncryptionParameters parms;
switch (scheme) {
case scheme_type::BFV:
@@ -82,20 +100,22 @@ shared_ptr<TenSEALContext> TenSEALContext::Create(
throw invalid_argument("invalid scheme_type");
}

return shared_ptr<TenSEALContext>(new TenSEALContext(parms));
return shared_ptr<TenSEALContext>(new TenSEALContext(parms, n_threads));
}

shared_ptr<TenSEALContext> TenSEALContext::Create(istream& stream) {
return shared_ptr<TenSEALContext>(new TenSEALContext(stream));
shared_ptr<TenSEALContext> TenSEALContext::Create(istream& stream,
optional<uint> n_threads) {
return shared_ptr<TenSEALContext>(new TenSEALContext(stream, n_threads));
}

shared_ptr<TenSEALContext> TenSEALContext::Create(const std::string& input) {
return shared_ptr<TenSEALContext>(new TenSEALContext(input));
shared_ptr<TenSEALContext> TenSEALContext::Create(const std::string& input,
optional<uint> n_threads) {
return shared_ptr<TenSEALContext>(new TenSEALContext(input, n_threads));
}

shared_ptr<TenSEALContext> TenSEALContext::Create(
const TenSEALContextProto& input) {
return shared_ptr<TenSEALContext>(new TenSEALContext(input));
const TenSEALContextProto& input, optional<uint> n_threads) {
return shared_ptr<TenSEALContext>(new TenSEALContext(input, n_threads));
}

shared_ptr<PublicKey> TenSEALContext::public_key() const {
@@ -329,7 +349,8 @@ TenSEALContextProto TenSEALContext::save_proto() const {

std::shared_ptr<TenSEALContext> TenSEALContext::copy() const {
TenSEALContextProto buffer = this->save_proto();
return shared_ptr<TenSEALContext>(new TenSEALContext(buffer));
return shared_ptr<TenSEALContext>(
new TenSEALContext(buffer, this->_threads));
}

void TenSEALContext::load(std::istream& stream) {
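For reference, the fallback rule implemented by `dispatcher_setup` above (a missing or zero `n_threads` resolves to the hardware concurrency) can be summarized in a few lines of Python; this is an illustrative sketch, not library code:

```python
import multiprocessing

def resolve_pool_size(n_threads=None):
    """Sketch of dispatcher_setup's fallback: a missing or zero n_threads
    resolves to the hardware concurrency (get_concurrency() in the C++ code)."""
    if not n_threads:  # covers both None and 0
        return multiprocessing.cpu_count() or 1
    return n_threads

assert resolve_pool_size(8) == 8
print(resolve_pool_size())   # e.g. 8 on an 8-way machine
print(resolve_pool_size(0))  # same fallback as passing nothing
```
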
42 changes: 32 additions & 10 deletions tenseal/cpp/context/tensealcontext.h
@@ -4,6 +4,7 @@
#include "seal/seal.h"
#include "tenseal/cpp/context/sealcontext.h"
#include "tenseal/cpp/context/tensealencoder.h"
#include "tenseal/cpp/utils/threadpool.h"
#include "tenseal/proto/tensealcontext.pb.h"

namespace tenseal {
@@ -38,32 +39,43 @@ class TenSEALContext {
* @param[in] scheme: BFV or CKKS.
* @param[in] poly_modulus_degree: The polynomial modulus degree.
* @param[in] plain_modulus: The plaintext modulus.
* @param[in] coeff_mod_bit_sizes: The bit-lengths of the primes to be
*generated.
* @param[in] n_threads: Optional parameter for the size of the threadpool
*dispatcher.
* @returns shared_ptr to a new TenSEALContext object.
**/
static shared_ptr<TenSEALContext> Create(scheme_type scheme,
size_t poly_modulus_degree,
uint64_t plain_modulus,
vector<int> coeff_mod_bit_sizes);
vector<int> coeff_mod_bit_sizes,
optional<uint> n_threads = {});
/**
* Create a context from an input stream.
* @param[in] stream
* @param[in] n_threads: Optional parameter for the size of the threadpool
*dispatcher.
* @returns shared_ptr to a new TenSEALContext object.
**/
static shared_ptr<TenSEALContext> Create(istream& stream);
static shared_ptr<TenSEALContext> Create(istream& stream,
optional<uint> n_threads = {});
/**
* Create a context from a serialized protobuffer.
* @param[in] input: Serialized protobuffer.
* @param[in] n_threads: Optional parameter for the size of the threadpool
*dispatcher.
* @returns shared_ptr to a new TenSEALContext object.
**/
static shared_ptr<TenSEALContext> Create(const std::string& input);
static shared_ptr<TenSEALContext> Create(const std::string& input,
optional<uint> n_threads = {});
/**
* Create a context from a protobuffer.
* @param[in] input: The protobuffer.
* @param[in] n_threads: Optional parameter for the size of the threadpool
*dispatcher.
* @returns shared_ptr to a new TenSEALContext object.
**/
static shared_ptr<TenSEALContext> Create(const TenSEALContextProto& input);
static shared_ptr<TenSEALContext> Create(const TenSEALContextProto& input,
optional<uint> n_threads = {});
/**
* @returns a pointer to the public key.
**/
@@ -222,6 +234,11 @@ class TenSEALContext {
* @returns true if the contexts are identical.
**/
bool equals(const std::shared_ptr<TenSEALContext>& other);
/**
* @returns a pointer to the threadpool dispatcher
**/
shared_ptr<sync::ThreadPool> dispatcher() { return _dispatcher; }
size_t dispatcher_size() { return _threads; }

private:
EncryptionParameters _parms;
Expand All @@ -231,6 +248,10 @@ class TenSEALContext {
shared_ptr<RelinKeys> _relin_keys;
shared_ptr<GaloisKeys> _galois_keys;
shared_ptr<TenSEALEncoder> encoder_factory;

shared_ptr<sync::ThreadPool> _dispatcher;
uint _threads;

/**
* Switches for automatic relinearization, rescaling, and modulus switching
**/
@@ -242,12 +263,13 @@
uint8_t _auto_flags =
flag_auto_relin | flag_auto_rescale | flag_auto_mod_switch;

TenSEALContext(EncryptionParameters parms);
TenSEALContext(istream& stream);
TenSEALContext(const std::string& stream);
TenSEALContext(const TenSEALContextProto& proto);
TenSEALContext(EncryptionParameters parms, optional<uint> n_threads);
TenSEALContext(istream& stream, optional<uint> n_threads);
TenSEALContext(const std::string& stream, optional<uint> n_threads);
TenSEALContext(const TenSEALContextProto& proto, optional<uint> n_threads);

void base_setup(EncryptionParameters parms);
void dispatcher_setup(optional<uint> n_threads);
void keys_setup(optional<PublicKey> public_key = {},
optional<SecretKey> secret_key = {},
bool generate_relin_keys = true,
18 changes: 6 additions & 12 deletions tenseal/cpp/tensors/ckksvector.cpp
@@ -463,22 +463,16 @@ CKKSVector& CKKSVector::sum_inplace() {
}

CKKSVector CKKSVector::matmul_plain(const vector<vector<double>>& matrix,
uint n_threads) {
size_t n_jobs) {
CKKSVector new_vector = *this;
return new_vector.matmul_plain_inplace(matrix, n_threads);
return new_vector.matmul_plain_inplace(matrix, n_jobs);
}

CKKSVector& CKKSVector::matmul_plain_inplace(
const vector<vector<double>>& matrix, uint n_threads) {
if (n_threads != 1) {
this->ciphertext =
diagonal_ct_vector_matmul_parallel<double, CKKSEncoder>(
this->tenseal_context(), this->ciphertext, this->size(), matrix,
n_threads);
} else {
this->ciphertext = diagonal_ct_vector_matmul<double, CKKSEncoder>(
this->tenseal_context(), this->ciphertext, this->size(), matrix);
}
const vector<vector<double>>& matrix, size_t n_jobs) {
this->ciphertext = diagonal_ct_vector_matmul<double, CKKSEncoder>(
this->tenseal_context(), this->ciphertext, this->size(), matrix,
n_jobs);

this->_size = matrix[0].size();

4 changes: 2 additions & 2 deletions tenseal/cpp/tensors/ckksvector.h
@@ -110,9 +110,9 @@ class CKKSVector {
* Matrix multiplication operations.
**/
CKKSVector matmul_plain(const vector<vector<double>>& matrix,
uint n_threads = 0);
size_t n_jobs = 0);
CKKSVector& matmul_plain_inplace(const vector<vector<double>>& matrix,
uint n_threads = 0);
size_t n_jobs = 0);

/**
* Polynomial evaluation with `this` as variable.