Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy context #197

Merged
merged 7 commits into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tenseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def bfv_vector_from(context: Context, data: bytes) -> BFVVector:
return BFVVector.load(context, data)


def lazy_bfv_vector_from(data: bytes) -> BFVVector:
"""Load a BFVVector from a protocol buffer."""
return BFVVector.lazy_load(data)


def ckks_vector(*args, **kwargs) -> CKKSVector:
"""Constructor function for tenseal.CKKSVector"""
return CKKSVector(*args, **kwargs)
Expand All @@ -98,6 +103,11 @@ def ckks_vector_from(context: Context, data: bytes) -> CKKSVector:
return CKKSVector.load(context, data)


def lazy_ckks_vector_from(data: bytes) -> CKKSVector:
"""Load a CKKSVector from a protocol buffer."""
return CKKSVector.lazy_load(data)


def ckks_tensor(*args, **kwargs) -> CKKSTensor:
"""Constructor function for tenseal.CKKSTensor"""
return CKKSTensor(*args, **kwargs)
Expand All @@ -109,13 +119,21 @@ def ckks_tensor_from(context: Context, data: bytes) -> CKKSTensor:
return CKKSTensor.load(context, data)


def lazy_ckks_tensor_from(data: bytes) -> CKKSTensor:
"""Load a CKKSTensor from a protocol buffer"""
return CKKSTensor.lazy_load(data)


__all__ = [
"bfv_vector",
"bfv_vector_from",
"lazy_bfv_vector_from",
"ckks_vector",
"ckks_vector_from",
"lazy_ckks_vector_from",
"ckks_tensor",
"ckks_tensor_from",
"lazy_ckks_tensor_from",
"context",
"context_from",
"im2col_encoding",
Expand Down
18 changes: 12 additions & 6 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
[](const shared_ptr<TenSEALContext> &ctx, const std::string &data) {
return BFVVector::Create(ctx, data);
}))
.def(py::init(
[](const std::string &data) { return BFVVector::Create(data); }))
.def("size", py::overload_cast<>(&BFVVector::size, py::const_))
.def("decrypt",
[](shared_ptr<BFVVector> obj) { return obj->decrypt().data(); })
Expand Down Expand Up @@ -144,8 +146,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
[](shared_ptr<BFVVector> obj, const vector<int64_t> &other) {
return obj->mul_plain_inplace(other);
})
.def("context",
[](shared_ptr<BFVVector> obj) { return obj->tenseal_context(); })
.def("context", &BFVVector::tenseal_context)
.def("link_context", &BFVVector::link_tenseal_context)
.def("serialize",
[](shared_ptr<BFVVector> &obj) { return py::bytes(obj->save()); })
.def("copy", &BFVVector::deepcopy)
Expand Down Expand Up @@ -214,6 +216,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
[](const shared_ptr<TenSEALContext> &ctx, const std::string &data) {
return CKKSVector::Create(ctx, data);
}))
.def(py::init(
[](const std::string &data) { return CKKSVector::Create(data); }))
.def("size", py::overload_cast<>(&CKKSVector::size, py::const_))
.def("decrypt",
[](shared_ptr<CKKSVector> obj) { return obj->decrypt().data(); })
Expand Down Expand Up @@ -415,8 +419,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
return obj->matmul_plain_inplace(matrix, n_jobs);
},
py::arg("matrix"), py::arg("n_jobs") = 0)
.def("context",
[](shared_ptr<CKKSVector> obj) { return obj->tenseal_context(); })
.def("context", &CKKSVector::tenseal_context)
.def("link_context", &CKKSVector::link_tenseal_context)
.def("serialize",
[](shared_ptr<CKKSVector> obj) { return py::bytes(obj->save()); })
.def("copy", &CKKSVector::deepcopy)
Expand Down Expand Up @@ -448,6 +452,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
[](const shared_ptr<TenSEALContext> &ctx, const std::string &data) {
return CKKSTensor::Create(ctx, data);
}))
.def(py::init(
[](const std::string &data) { return CKKSTensor::Create(data); }))
.def("decrypt",
[](shared_ptr<CKKSTensor> obj) { return obj->decrypt(); })
.def("decrypt",
Expand Down Expand Up @@ -560,8 +566,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("__neg__", &CKKSTensor::negate)
.def("__pow__", &CKKSTensor::power)
.def("__ipow__", &CKKSTensor::power_inplace)
.def("context",
[](shared_ptr<CKKSTensor> obj) { return obj->tenseal_context(); })
.def("context", &CKKSTensor::tenseal_context)
.def("link_context", &CKKSTensor::link_tenseal_context)
.def("serialize",
[](shared_ptr<CKKSTensor> &obj) { return py::bytes(obj->save()); })
.def("copy", &CKKSTensor::deepcopy)
Expand Down
7 changes: 5 additions & 2 deletions tenseal/cpp/tensors/bfvvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ BFVVector::BFVVector(const shared_ptr<TenSEALContext>& ctx, const string& vec) {
this->load(vec);
}

BFVVector::BFVVector(const string& vec) { this->load(vec); }

BFVVector::BFVVector(const shared_ptr<TenSEALContext>& ctx,
const BFVVectorProto& vec) {
this->prepare_context(ctx);
Expand Down Expand Up @@ -321,8 +323,9 @@ BFVVectorProto BFVVector::save_proto() const {
}

void BFVVector::load(const std::string& vec) {
if (this->tenseal_context() == nullptr) {
throw invalid_argument("context missing for deserialization");
if (!this->has_context()) {
_lazy_buffer = vec;
return;
}
BFVVectorProto buffer;
if (!buffer.ParseFromArray(vec.c_str(), static_cast<int>(vec.size()))) {
Expand Down
1 change: 1 addition & 0 deletions tenseal/cpp/tensors/bfvvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class BFVVector
BFVVector(const shared_ptr<TenSEALContext>& ctx, const plain_t& vec);
BFVVector(const shared_ptr<const BFVVector>&);
BFVVector(const shared_ptr<TenSEALContext>& ctx, const string& vec);
BFVVector(const string& vec);
BFVVector(const TenSEALContextProto& ctx, const BFVVectorProto& vec);
BFVVector(const shared_ptr<TenSEALContext>& ctx, const BFVVectorProto& vec);

Expand Down
7 changes: 7 additions & 0 deletions tenseal/cpp/tensors/ckkstensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ CKKSTensor::CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
this->load(tensor);
}

CKKSTensor::CKKSTensor(const string& tensor) { this->load(tensor); }

CKKSTensor::CKKSTensor(const TenSEALContextProto& ctx,
const CKKSTensorProto& tensor) {
this->load_context_proto(ctx);
Expand Down Expand Up @@ -550,6 +552,11 @@ CKKSTensorProto CKKSTensor::save_proto() const {
}

void CKKSTensor::load(const std::string& tensor_str) {
if (!this->has_context()) {
_lazy_buffer = tensor_str;
return;
}

CKKSTensorProto buffer;
if (!buffer.ParseFromArray(tensor_str.c_str(),
static_cast<int>(tensor_str.size()))) {
Expand Down
1 change: 1 addition & 0 deletions tenseal/cpp/tensors/ckkstensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
std::optional<double> scale = {}, bool batch = true);
CKKSTensor(const TenSEALContextProto& ctx, const CKKSTensorProto& tensor);
CKKSTensor(const shared_ptr<TenSEALContext>& ctx, const string& vec);
CKKSTensor(const string& vec);
CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
const CKKSTensorProto& tensor);
CKKSTensor(const shared_ptr<const CKKSTensor>& vec);
Expand Down
7 changes: 7 additions & 0 deletions tenseal/cpp/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ CKKSVector::CKKSVector(const shared_ptr<TenSEALContext>& ctx,
this->load(vec);
}

CKKSVector::CKKSVector(const string& vec) { this->load(vec); }

CKKSVector::CKKSVector(const TenSEALContextProto& ctx,
const CKKSVectorProto& vec) {
this->load_context_proto(ctx);
Expand Down Expand Up @@ -461,6 +463,11 @@ CKKSVectorProto CKKSVector::save_proto() const {
}

void CKKSVector::load(const std::string& vec) {
if (!this->has_context()) {
_lazy_buffer = vec;
return;
}

CKKSVectorProto buffer;
if (!buffer.ParseFromArray(vec.c_str(), static_cast<int>(vec.size()))) {
throw invalid_argument("failed to parse CKKS stream");
Expand Down
1 change: 1 addition & 0 deletions tenseal/cpp/tensors/ckksvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class CKKSVector
CKKSVector(const shared_ptr<TenSEALContext>& ctx, const plain_t& vec,
optional<double> scale = {});
CKKSVector(const shared_ptr<TenSEALContext>& ctx, const string& vec);
CKKSVector(const string& vec);
CKKSVector(const TenSEALContextProto& ctx, const CKKSVectorProto& vec);
CKKSVector(const shared_ptr<TenSEALContext>& ctx,
const CKKSVectorProto& vec);
Expand Down
14 changes: 12 additions & 2 deletions tenseal/cpp/tensors/encrypted_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,21 @@ class EncryptedTensor {
if (_context == nullptr) throw invalid_argument("missing context");
return _context;
};

/**
* Check if the context is linked
* **/
bool has_context() const {
return _context != nullptr;
};
/**
* Link to a TenSEAL context.
**/
void link_tenseal_context(shared_ptr<TenSEALContext> ctx) {
this->_context = ctx;
if (_lazy_buffer) {
this->load(*_lazy_buffer);
_lazy_buffer = {};
}
youben11 marked this conversation as resolved.
Show resolved Hide resolved
};
void load_context_proto(const TenSEALContextProto& ctx) {
this->link_tenseal_context(TenSEALContext::Create(ctx));
Expand Down Expand Up @@ -252,9 +261,10 @@ class EncryptedTensor {
virtual ~EncryptedTensor(){};

protected:
shared_ptr<TenSEALContext> _context;
optional<string> _lazy_buffer;

private:
shared_ptr<TenSEALContext> _context;
};

} // namespace tenseal
Expand Down
19 changes: 19 additions & 0 deletions tenseal/tensors/abstract_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def context(self) -> "ts.Context":
"""Get the context linked to this tensor"""
return ts.Context._wrap(self.data.context())

def link_context(self, ctx: "ts.Context"):
"""Set the context linked to this tensor"""
return self.data.link_context(ctx.data)

@property
def shape(self) -> List[int]:
return self.data.shape()
Expand All @@ -50,6 +54,21 @@ def load(cls, context: "ts.Context", data: bytes) -> "AbstractTensor":
"Invalid input types context: {} and vector: {}".format(type(context), type(data))
)

@classmethod
def lazy_load(cls, data: bytes) -> "AbstractTensor":
"""
Constructor method for the tensor object from a serialized protobuffer, without a context.
Args:
data: the serialized protobuffer.
Returns:
Tensor object.
"""
if isinstance(data, bytes):
native_type = getattr(ts._ts_cpp, cls.__name__)
return cls._wrap(native_type(data))

raise TypeError("Invalid input types vector: {}".format(type(data)))

def serialize(self) -> bytes:
"""Serialize the tensor into a stream of bytes"""
return self.data.serialize()
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/tensors/bfvvector_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,25 @@ TEST_P(BFVVectorTest, TestEmptyPlaintext) {
std::exception);
}

TEST_F(BFVVectorTest, TestBFVLazyLoading) {
auto ctx = TenSEALContext::Create(scheme_type::bfv, 8192, 1032193, {});
ASSERT_TRUE(ctx != nullptr);

auto l = BFVVector::Create(ctx, vector<int64_t>({1, 2, 3}));
auto r = BFVVector::Create(ctx, vector<int64_t>({2, 3, 4}));

auto buffer = l->save();
auto newl = BFVVector::Create(buffer);

EXPECT_THROW(newl->add(r), std::exception);

newl->link_tenseal_context(ctx);
auto res = newl->add(r);

auto decr = res->decrypt();
EXPECT_THAT(decr.data(), ElementsAreArray({3, 5, 7}));
}

INSTANTIATE_TEST_CASE_P(TestBFVVector, BFVVectorTest,
::testing::Values(false, true));

Expand Down
24 changes: 24 additions & 0 deletions tests/cpp/tensors/ckkstensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,5 +280,29 @@ TEST_P(CKKSTensorTest, TestEmptyPlaintext) {
INSTANTIATE_TEST_CASE_P(TestCKKSTensor, CKKSTensorTest,
::testing::Values(false, true));

TEST_F(CKKSTensorTest, TestCKKSLazyContext) {
auto ctx =
TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
ASSERT_TRUE(ctx != nullptr);

ctx->global_scale(std::pow(2, 40));

auto l = CKKSTensor::Create(
ctx, PlainTensor(std::vector<double>({1, 2, 3, 4}), {2, 2}));
auto r = CKKSTensor::Create(
ctx, PlainTensor(std::vector<double>({5, 6, 7, 8}), {2, 2}));

auto buffer = l->save();
auto newl = CKKSTensor::Create(buffer);

EXPECT_THROW(newl->add(r), std::exception);

newl->link_tenseal_context(ctx);
auto res = newl->add(r);

auto decr = res->decrypt();
ASSERT_TRUE(are_close(decr.data(), {6, 8, 10, 12}));
}

} // namespace
} // namespace tenseal
22 changes: 22 additions & 0 deletions tests/cpp/tensors/ckksvector_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,5 +249,27 @@ TEST_P(CKKSVectorTest, TestEmptyPlaintext) {
INSTANTIATE_TEST_CASE_P(TestCKKSVector, CKKSVectorTest,
::testing::Values(false, true));

TEST_F(CKKSVectorTest, TestCKKSLazyContext) {
auto ctx =
TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
ASSERT_TRUE(ctx != nullptr);

ctx->global_scale(std::pow(2, 40));

auto l = CKKSVector::Create(ctx, std::vector<double>({1, 2, 3}));
auto r = CKKSVector::Create(ctx, std::vector<double>({3, 4, 4}));

auto buffer = l->save();
auto newl = CKKSVector::Create(buffer);

EXPECT_THROW(newl->add(r), std::exception);

newl->link_tenseal_context(ctx);
auto res = newl->add(r);

auto decr = res->decrypt();
ASSERT_TRUE(are_close(decr.data(), {4, 6, 7}));
}

} // namespace
} // namespace tenseal
Loading