Skip to content

Commit

Permalink
add CKKSTensor serialization (#168)
Browse files Browse the repository at this point in the history
* add CKKSTensor serialization

* add serialization tests for ckkstensor

* update protobuffers
  • Loading branch information
bcebere authored Nov 27, 2020
1 parent eed2a73 commit 7b6b584
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 22 deletions.
4 changes: 4 additions & 0 deletions tenseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ckks_vector,
ckks_vector_from,
ckks_tensor,
ckks_tensor_from,
plain_tensor,
)
from tenseal.version import __version__
Expand All @@ -26,6 +27,7 @@
# Vectors
BFVVector = _ts_cpp.BFVVector
CKKSVector = _ts_cpp.CKKSVector
CKKSTensor = _ts_cpp.CKKSTensor

# utils
im2col_encoding = _ts_cpp.im2col_encoding
Expand Down Expand Up @@ -93,6 +95,8 @@ def context_from(buff, n_threads=None):
"bfv_vector_from",
"ckks_vector",
"ckks_vector_from",
"ckks_tensor",
"ckks_tensor_from",
"context",
"context_from",
"im2col_encoding",
Expand Down
16 changes: 15 additions & 1 deletion tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,26 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
const PlainTensor<double> &tensor) {
return CKKSTensor::Create(ctx, tensor);
}))
.def(py::init(
[](const shared_ptr<TenSEALContext> &ctx, const std::string &data) {
return CKKSTensor::Create(ctx, data);
}))
.def("decrypt",
[](shared_ptr<CKKSTensor> obj) { return obj->decrypt().data(); })
.def("decrypt",
[](shared_ptr<CKKSTensor> obj, const shared_ptr<SecretKey> &sk) {
return obj->decrypt(sk).data();
});
})
.def("context",
[](shared_ptr<CKKSTensor> obj) { return obj->tenseal_context(); })
.def("serialize",
[](shared_ptr<CKKSTensor> &obj) { return py::bytes(obj->save()); })
.def("copy", &CKKSTensor::deepcopy)
.def("__copy__",
[](shared_ptr<CKKSTensor> &obj) { return obj->deepcopy(); })
.def("__deepcopy__", [](const shared_ptr<CKKSTensor> &obj, py::dict) {
return obj->deepcopy();
});

py::class_<TenSEALContext, std::shared_ptr<TenSEALContext>>(
m, "TenSEALContext")
Expand Down
4 changes: 0 additions & 4 deletions tenseal/cpp/tensors/bfvvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,6 @@ void BFVVector::load_proto(const BFVVectorProto& vec) {
*this->tenseal_context()->seal_context(), vec.ciphertext());
}

void BFVVector::load_context_proto(const TenSEALContextProto& ctx) {
this->link_tenseal_context(TenSEALContext::Create(ctx));
}

BFVVectorProto BFVVector::save_proto() const {
BFVVectorProto buffer;

Expand Down
1 change: 0 additions & 1 deletion tenseal/cpp/tensors/bfvvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ class BFVVector
BFVVectorProto save_proto() const;

void prepare_context(const shared_ptr<TenSEALContext>& ctx);
void load_context_proto(const TenSEALContextProto& buffer);
};

} // namespace tenseal
Expand Down
104 changes: 95 additions & 9 deletions tenseal/cpp/tensors/ckkstensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,32 @@ CKKSTensor::CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
CKKSTensor::encrypt(ctx, this->_init_scale, vector<double>({*it})));
}

CKKSTensor::CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
const string& tensor) {
this->link_tenseal_context(ctx);
this->load(tensor);
}

CKKSTensor::CKKSTensor(const TenSEALContextProto& ctx,
const CKKSTensorProto& tensor) {
this->load_context_proto(ctx);
this->load_proto(tensor);
}

CKKSTensor::CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
const CKKSTensorProto& tensor) {
this->link_tenseal_context(ctx);
this->load_proto(tensor);
}

CKKSTensor::CKKSTensor(const shared_ptr<const CKKSTensor>& tensor) {
this->link_tenseal_context(tensor->tenseal_context());
this->_init_scale = tensor->scale();
this->_shape = tensor->shape();
this->_strides = tensor->strides();
this->_data = tensor->data();
}

Ciphertext CKKSTensor::encrypt(const shared_ptr<TenSEALContext>& ctx,
const double scale, const vector<double>& data) {
if (data.empty()) {
Expand Down Expand Up @@ -148,23 +174,83 @@ shared_ptr<CKKSTensor> CKKSTensor::polyval_inplace(
return shared_from_this();
}

void CKKSTensor::load(const string& vec) {
// TODO
void CKKSTensor::clear() {
this->_shape = vector<size_t>();
this->_strides = vector<size_t>();
this->_data = vector<Ciphertext>();
this->_init_scale = 0;
}

string CKKSTensor::save() const {
// TODO
return "saving";
void CKKSTensor::load_proto(const CKKSTensorProto& tensor_proto) {
if (this->tenseal_context() == nullptr) {
throw invalid_argument("context missing for deserialization");
}
this->clear();

for (int idx = 0; idx < tensor_proto.shape_size(); ++idx) {
this->_shape.push_back(tensor_proto.shape(idx));
}
for (int idx = 0; idx < tensor_proto.strides_size(); ++idx) {
this->_strides.push_back(tensor_proto.strides(idx));
}
for (int idx = 0; idx < tensor_proto.ciphertexts_size(); ++idx)
this->_data.push_back(SEALDeserialize<Ciphertext>(
*this->tenseal_context()->seal_context(),
tensor_proto.ciphertexts(idx)));
this->_init_scale = tensor_proto.scale();
}

CKKSTensorProto CKKSTensor::save_proto() const {
CKKSTensorProto buffer;

for (auto& ct : this->_data) {
buffer.add_ciphertexts(SEALSerialize<Ciphertext>(ct));
}
for (auto& dim : this->_shape) {
buffer.add_shape(dim);
}
for (auto& stride : this->_strides) {
buffer.add_strides(stride);
}
buffer.set_scale(this->_init_scale);

return buffer;
}

void CKKSTensor::load(const std::string& tensor_str) {
CKKSTensorProto buffer;
if (!buffer.ParseFromArray(tensor_str.c_str(),
static_cast<int>(tensor_str.size()))) {
throw invalid_argument("failed to parse CKKS tensor stream");
}
this->load_proto(buffer);
}

std::string CKKSTensor::save() const {
auto buffer = this->save_proto();
std::string output;
output.resize(proto_bytes_size(buffer));

if (!buffer.SerializeToArray((void*)output.c_str(),
static_cast<int>(proto_bytes_size(buffer)))) {
throw invalid_argument("failed to save CKKS tensor proto");
}

return output;
}

shared_ptr<CKKSTensor> CKKSTensor::copy() const {
// TODO
return nullptr;
return shared_ptr<CKKSTensor>(new CKKSTensor(shared_from_this()));
}

shared_ptr<CKKSTensor> CKKSTensor::deepcopy() const {
// TODO
return nullptr;
TenSEALContextProto ctx = this->tenseal_context()->save_proto();
CKKSTensorProto vec = this->save_proto();
return CKKSTensor::Create(ctx, vec);
}

vector<Ciphertext> CKKSTensor::data() const { return _data; }
vector<size_t> CKKSTensor::shape() const { return _shape; }
vector<size_t> CKKSTensor::strides() const { return _strides; }
double CKKSTensor::scale() const { return _init_scale; }
} // namespace tenseal
15 changes: 15 additions & 0 deletions tenseal/cpp/tensors/ckkstensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "tenseal/cpp/tensors/encrypted_tensor.h"
#include "tenseal/cpp/tensors/plain_tensor.h"
#include "tenseal/proto/tensors.pb.h"

namespace tenseal {

Expand Down Expand Up @@ -68,6 +69,11 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
shared_ptr<CKKSTensor> copy() const override;
shared_ptr<CKKSTensor> deepcopy() const override;

vector<Ciphertext> data() const;
vector<size_t> shape() const;
vector<size_t> strides() const;
double scale() const;

private:
vector<Ciphertext> _data;
vector<size_t> _shape;
Expand All @@ -77,11 +83,20 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
const PlainTensor<double>& tensor,
std::optional<double> scale = {});
CKKSTensor(const TenSEALContextProto& ctx, const CKKSTensorProto& tensor);
CKKSTensor(const shared_ptr<TenSEALContext>& ctx, const string& vec);
CKKSTensor(const shared_ptr<TenSEALContext>& ctx,
const CKKSTensorProto& tensor);
CKKSTensor(const shared_ptr<const CKKSTensor>& vec);

static Ciphertext encrypt(const shared_ptr<TenSEALContext>& ctx,
const double scale, const vector<double>& data);
static Ciphertext encrypt(const shared_ptr<TenSEALContext>& ctx,
const double scale, const double data);

void load_proto(const CKKSTensorProto& buffer);
CKKSTensorProto save_proto() const;
void clear();
};

} // namespace tenseal
Expand Down
4 changes: 0 additions & 4 deletions tenseal/cpp/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,6 @@ void CKKSVector::load_proto(const CKKSVectorProto& vec) {
this->_init_scale = vec.scale();
}

void CKKSVector::load_context_proto(const TenSEALContextProto& ctx) {
this->link_tenseal_context(TenSEALContext::Create(ctx));
}

CKKSVectorProto CKKSVector::save_proto() const {
CKKSVectorProto buffer;

Expand Down
2 changes: 0 additions & 2 deletions tenseal/cpp/tensors/ckksvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ class CKKSVector

void load_proto(const CKKSVectorProto& buffer);
CKKSVectorProto save_proto() const;

void load_context_proto(const TenSEALContextProto& buffer);
};

} // namespace tenseal
Expand Down
3 changes: 3 additions & 0 deletions tenseal/cpp/tensors/encrypted_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ class EncryptedTensor {
void link_tenseal_context(shared_ptr<TenSEALContext> ctx) {
this->_context = ctx;
};
void load_context_proto(const TenSEALContextProto& ctx) {
this->link_tenseal_context(TenSEALContext::Create(ctx));
}

virtual ~EncryptedTensor(){};

Expand Down
12 changes: 12 additions & 0 deletions tenseal/proto/tensors.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,15 @@ message CKKSVectorProto {
// Scale value
double scale = 3;
};

//CKKSTensor parameters
message CKKSTensorProto {
// The shape of the encrypted tensor
repeated uint32 shape = 1;
// The strides of the encrypted tensor
repeated uint32 strides = 2;
// The serialized ciphertexts
repeated bytes ciphertexts = 3;
// Scale value
double scale = 4;
};
26 changes: 25 additions & 1 deletion tenseal/tensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,28 @@ def ckks_tensor(context, tensor, scale=None):
)


__all__ = ["bfv_vector", "bfv_vector_from", "ckks_vector", "ckks_vector_from", "ckks_tensor"]
def ckks_tensor_from(context, data):
"""
Constructor method for the CKKSTensor object from a serialized protobuffer.
Args:
context: a TenSEALContext object, holding the encryption parameters and keys.
data: the serialized protobuffer.
Returns:
CKKSTensor object.
"""
if isinstance(context, _ts_cpp.TenSEALContext) and isinstance(data, bytes):
return _ts_cpp.CKKSTensor(context, data)

raise TypeError(
"Invalid CKKS input types context: {} and vector: {}".format(type(context), type(data))
)


__all__ = [
"bfv_vector",
"bfv_vector_from",
"ckks_vector",
"ckks_vector_from",
"ckks_tensor",
"ckks_tensor_from",
]
1 change: 1 addition & 0 deletions tests/cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ cc_test(
"tensealcontext_test.cpp",
"tensors/bfvvector_test.cpp",
"tensors/ckksvector_test.cpp",
"tensors/ckkstensor_test.cpp",
],
copts = TENSEAL_DEFAULT_COPTS,
includes = TENSEAL_DEFAULT_INCLUDES,
Expand Down
59 changes: 59 additions & 0 deletions tests/cpp/tensors/ckkstensor_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "tenseal/cpp/tenseal.h"

namespace tenseal {
namespace {

using namespace ::testing;
using namespace std;

auto duplicate(shared_ptr<CKKSTensor> in) {
auto vec = in->save();

return CKKSTensor::Create(in->tenseal_context(), vec);
}

class CKKSTensorTest : public TestWithParam</*serialize=*/bool> {
protected:
void SetUp() {}
};
TEST_P(CKKSTensorTest, TestCreateCKKS) {
bool should_serialize_first = GetParam();

auto ctx =
TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
ASSERT_TRUE(ctx != nullptr);

auto l = CKKSTensor::Create(ctx, std::vector<double>{1, 2, 3}, 1);

if (should_serialize_first) {
l = duplicate(l);
}

ASSERT_EQ(l->data().size(), 3);
}

TEST_F(CKKSTensorTest, TestCreateCKKSFail) {
auto ctx =
TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
ASSERT_TRUE(ctx != nullptr);

EXPECT_THROW(
auto l = CKKSTensor::Create(ctx, std::vector<double>({1, 2, 3})),
std::exception);
}

TEST_P(CKKSTensorTest, TestEmptyPlaintext) {
auto ctx = TenSEALContext::Create(scheme_type::bfv, 8192, 1032193, {});
ASSERT_TRUE(ctx != nullptr);

EXPECT_THROW(CKKSTensor::Create(ctx, std::vector<double>({})),
std::exception);
}

INSTANTIATE_TEST_CASE_P(TestCKKSTensor, CKKSTensorTest,
::testing::Values(false, true));

} // namespace
} // namespace tenseal
Loading

0 comments on commit 7b6b584

Please sign in to comment.