Add serialization support for PlainTensor (#221)
* Add serialization support for PlainTensor

* Bazel fixes: drop forked dependencies in favor of upstream releases
bcebere authored Jan 21, 2021
1 parent dc78061 commit 2594817
Showing 19 changed files with 130 additions and 11 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/bazel_tests.yml
@@ -21,6 +21,10 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Setup Bazel
uses: abhinavsingh/setup-bazel@v3
with:
version: 3.7.1
- name: Run gtest
timeout-minutes: 30
run: bazel test --test_output=all --spawn_strategy=standalone --test_timeout=1500 --jobs 1 //tests/cpp/...
3 changes: 3 additions & 0 deletions .gitmodules
@@ -19,3 +19,6 @@
[submodule "third_party/protobuf"]
path = third_party/protobuf
url = https://github.com/protocolbuffers/protobuf
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json
2 changes: 2 additions & 0 deletions cmake/xtensor.cmake
@@ -1,7 +1,9 @@
add_definitions(-DXTENSOR_USE_XSIMD)
set(XTENSOR_USE_XSIMD ON)

include_directories(third_party/json/include/)
include_directories(third_party/xtl/include)

add_subdirectory(third_party/xtl)
set(xtl_DIR
"${CMAKE_CURRENT_BINARY_DIR}/third_party/xtl"
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.0a2
current_version = 0.3.0a3
commit = True
tag = True
files = tenseal/version.py
6 changes: 6 additions & 0 deletions tenseal/__init__.py
@@ -76,6 +76,11 @@ def plain_tensor(*args, **kwargs) -> PlainTensor:
return PlainTensor(*args, **kwargs)


def plain_tensor_from(data: bytes, dtype: str = "float") -> PlainTensor:
"""Load a PlainTensor from a buffer."""
return PlainTensor.load(data, dtype)


def bfv_vector(*args, **kwargs) -> BFVVector:
"""Constructor function for tenseal.BFVVector"""
return BFVVector(*args, **kwargs)
@@ -138,6 +143,7 @@ def lazy_ckks_tensor_from(data: bytes) -> CKKSTensor:
"context_from",
"im2col_encoding",
"plain_tensor",
"plain_tensor_from",
"ENCRYPTION_TYPE",
"__version__",
]
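A minimal round-trip sketch of the new top-level API (illustration only: the sample values are hypothetical, and the shape/raw accessors are assumed to behave as in the existing PlainTensor tests):

import tenseal as ts

# build a plain (unencrypted) tensor and serialize it to a byte buffer
tensor = ts.plain_tensor([[1.0, 2.0], [3.0, 4.0]])
buf = tensor.serialize()

# restore it later (or in another process) from that buffer
restored = ts.plain_tensor_from(buf, dtype="float")
print(restored.shape, restored.raw)  # shape and values should match the original tensor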
4 changes: 3 additions & 1 deletion tenseal/binding.cpp
@@ -20,6 +20,7 @@ void bind_plain_tensor(py::module &m, const std::string &name) {
.def(py::init<const vector<plain_t> &>())
.def(py::init<const vector<vector<plain_t>> &>())
.def(py::init<const vector<plain_t> &, const vector<size_t> &>())
.def(py::init<const string &>())
.def("at", &type::at)
.def("get_diagonal", &type::get_diagonal)
.def("horizontal_scan", &type::horizontal_scan)
@@ -33,7 +34,8 @@
.def("reshape_", &type::reshape_inplace)
.def("__len__", &type::size)
.def("empty", &type::empty)
.def("replicate", &type::replicate);
.def("replicate", &type::replicate)
.def("serialize", [](type &obj) { return py::bytes(obj.save()); });
}

PYBIND11_MODULE(_tenseal_cpp, m) {
9 changes: 9 additions & 0 deletions tenseal/cpp/tensors/plain_tensor.h
@@ -50,6 +50,11 @@ class PlainTensor {
* @param[in] existing storage.
*/
PlainTensor(const TensorStorage<plain_t>& data) : _data(data) {}
/**
* Create a new PlainTensor from a serialized buffer.
* @param[in] serialized buffer.
*/
PlainTensor(const std::string& data) { this->load(data); }
/**
* Reshape the tensor
* **/
@@ -223,6 +228,10 @@
}
PlainTensor<plain_t> copy() { return PlainTensor<plain_t>(this->_data); }

std::string save() { return _data.save(); }

void load(const std::string& buf) { _data = TensorStorage<plain_t>(buf); }

private:
TensorStorage<plain_t> _data;
};
15 changes: 15 additions & 0 deletions tenseal/cpp/tensors/tensor_storage.h
@@ -4,6 +4,7 @@
#include "gsl/span"
#include "xtensor/xadapt.hpp"
#include "xtensor/xarray.hpp"
#include "xtensor/xjson.hpp"

namespace tenseal {

@@ -114,6 +115,11 @@

_data = xt::adapt(flat_data, shape);
}
/**
* Create a new TensorStorage from a serialized buffer.
* @param[in] input buffer.
*/
TensorStorage(const std::string& data) { this->load(data); }
/**
* Reshape the TensorStorage.
* @param[in] new shape.
@@ -336,6 +342,15 @@
return TensorStorage<dtype_t>(this->_data, this->shape());
}

std::string save() {
nlohmann::json buf = _data;
return buf.dump();
}

void load(const std::string& buf) {
xt::from_json(nlohmann::json::parse(buf), _data);
}

private:
xt::xarray<dtype_t> _data;
};
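Since TensorStorage::save() converts the underlying xt::xarray to an nlohmann::json value and returns its dump(), the serialized buffer is plain JSON text, and load() parses it back with xt::from_json. A small sketch of what that implies on the Python side (the exact JSON layout is an xtensor/nlohmann implementation detail and is only assumed here):

import json
import tenseal as ts

buf = ts.plain_tensor([[1.0, 2.0], [3.0, 4.0]]).serialize()

# the bytes returned by serialize() should parse as JSON,
# e.g. nested lists mirroring the tensor values
print(json.loads(buf))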
20 changes: 14 additions & 6 deletions tenseal/deps.bzl
@@ -34,10 +34,10 @@ def tenseal_deps():
if "com_xtensorstack_xtensor" not in native.existing_rules():
http_archive(
name = "com_xtensorstack_xtensor",
sha256 = "c6263cf5e22bff44c3258223ea4af74d885060625dd4b092889eb0e7d2d749b0",
sha256 = "b73aacfdef12422f45b27ac43537bd9371ede092df4c14e20d2b8e41b2b5648e",
build_file = "//third_party:xtensor.BUILD",
strip_prefix = "xtensor-master/include",
urls = ["https://github.com/bcebere/xtensor/archive/master.zip"],
strip_prefix = "xtensor-0.22.0/include",
urls = ["https://github.com/xtensor-stack/xtensor/archive/0.22.0.tar.gz"],
)
if "com_xtensorstack_xtl" not in native.existing_rules():
http_archive(
@@ -47,13 +47,21 @@
strip_prefix = "xtl-0.6.23/include",
urls = ["https://github.com/xtensor-stack/xtl/archive/0.6.23.tar.gz"],
)
if "com_nlohmann_json" not in native.existing_rules():
http_archive(
name = "com_nlohmann_json",
build_file = "//third_party:nlohmann_json.BUILD",
sha256 = "4cf0df69731494668bdd6460ed8cb269b68de9c19ad8c27abc24cd72605b2d5b",
strip_prefix = "json-3.9.1/include",
urls = ["https://github.com/nlohmann/json/archive/v3.9.1.tar.gz"],
)
if "com_microsoft_gsl" not in native.existing_rules():
http_archive(
name = "com_microsoft_gsl",
sha256 = "76269b66d95da27b1253f14aaf01ace9f91db785981bc1553b730231faf23fd6",
sha256 = "d3234d7f94cea4389e3ca70619b82e8fb4c2f33bb3a070799f1e18eef500a083",
build_file = "//third_party:gsl.BUILD",
strip_prefix = "GSL-master/include",
urls = ["https://github.com/bcebere/GSL/archive/master.zip"],
strip_prefix = "GSL-3.1.0/include",
urls = ["https://github.com/microsoft/GSL/archive/v3.1.0.tar.gz"],
)


3 changes: 2 additions & 1 deletion tenseal/preload.bzl
@@ -18,7 +18,7 @@ def tenseal_preload():
name = "rules_foreign_cc",
remote = "https://github.com/bazelbuild/rules_foreign_cc",
init_submodules = True,
commit="04c04fe7d2fa09e46c630c37d8f32e56598527ca",
commit="d54c78ab86b40770ee19f0949db9d74a831ab9f0",
)

if "pybind11_bazel" not in native.existing_rules():
@@ -40,5 +40,6 @@
if "rules_python" not in native.existing_rules():
http_archive(
name = "rules_python",
sha256 = "b6d46438523a3ec0f3cead544190ee13223a52f6a6765a29eae7b7cc24cc83a0",
url = "https://github.com/bazelbuild/rules_python/releases/download/0.1.0/rules_python-0.1.0.tar.gz",
)
31 changes: 31 additions & 0 deletions tenseal/tensors/plaintensor.py
@@ -17,6 +17,13 @@ def __init__(self, tensor, shape: List[int] = None, dtype: str = "float"):
"""
if dtype not in ("float", "int"):
raise ValueError("wrong dtype, must be either 'float' or 'int'")

# wrapping
if isinstance(tensor, (ts._ts_cpp.PlainTensorDouble, ts._ts_cpp.PlainTensorInt64)):
self._dtype = dtype
self.data = tensor
return

try:
t = np.array(tensor, dtype=dtype)
except:
@@ -109,3 +116,27 @@ def reshape_(self, shape: List[int]):
"Changes the internal representation to a new shape"
self.data.reshape_(shape)
return self

@classmethod
def load(cls, data: bytes, dtype: str = "float") -> "PlainTensor":
"""
Constructor method for the tensor object from a serialized string.
Args:
data: the serialized data.
dtype: underlying data type.
Returns:
Tensor object.
"""
if not isinstance(data, bytes):
raise TypeError("Invalid input types: vector: {}".format(type(data)))

if dtype == "float":
return cls(ts._ts_cpp.PlainTensorDouble(data))
elif dtype == "int":
return cls(ts._ts_cpp.PlainTensorInt(data))
else:
raise ValueError("wrong dtype, must be either 'float' or 'int'")

def serialize(self) -> bytes:
"""Serialize the tensor into a stream of bytes"""
return self.data.serialize()
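A short sketch of the validation paths added above (hypothetical inputs; the exception types match the raise statements in load()):

import tenseal as ts

buf = ts.plain_tensor([1.0, 2.0]).serialize()

# load() only accepts a bytes buffer ...
try:
    ts.plain_tensor_from("not-bytes")
except TypeError as err:
    print("rejected:", err)

# ... and only the 'float' / 'int' dtypes
try:
    ts.plain_tensor_from(buf, dtype="double")
except ValueError as err:
    print("rejected:", err)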
2 changes: 1 addition & 1 deletion tenseal/version.py
@@ -1 +1 @@
__version__ = "0.3.0a2"
__version__ = "0.3.0a3"
12 changes: 12 additions & 0 deletions tests/cpp/tensors/plaintensor_test.cpp
@@ -31,6 +31,18 @@ TEST_F(PlainTensorTest, TestCreateFrom2DVector) {
ASSERT_THAT(tensor.strides(), ElementsAreArray({2, 1}));
}

TEST_F(PlainTensorTest, TestCreateFromString) {
vector<vector<double>> data = {{1.1, 2.2}, {3.3, 4.4}};
PlainTensor<double> tensor(data);
auto buf = tensor.save();

auto newtensor = PlainTensor<double>(buf);

ASSERT_THAT(newtensor.data(), ElementsAreArray({1.1, 2.2, 3.3, 4.4}));
ASSERT_THAT(newtensor.shape(), ElementsAreArray({2, 2}));
ASSERT_THAT(newtensor.strides(), ElementsAreArray({2, 1}));
}

TEST_F(PlainTensorTest, TestCreateFrom2DVectorFail) {
vector<vector<double>> data = {{1.1, 2.2}, {3.3}};
EXPECT_THROW(PlainTensor<double> tensor(data), std::exception);
10 changes: 10 additions & 0 deletions tests/python/tenseal/tensors/test_plain_tensor.py
@@ -32,6 +32,16 @@ def test_sanity(data, shape):
assert tensor.empty() == False
assert tensor.strides() == strides

buf = tensor.serialize()
new_tensor = ts.plain_tensor_from(buf)

assert new_tensor.raw == data
assert new_tensor.shape == shape
assert new_tensor.size() == shape[0]
assert len(new_tensor) == shape[0]
assert new_tensor.empty() == False
assert new_tensor.strides() == strides


@pytest.mark.parametrize(
"data, shape, reshape",
1 change: 1 addition & 0 deletions third_party/gsl.BUILD
@@ -1,5 +1,6 @@
cc_library(
name = "gsl",
hdrs = glob(["**"]),
includes = ["."],
visibility = ["//visibility:public"],
)
1 change: 1 addition & 0 deletions third_party/json
Submodule json added at 92fa1d
9 changes: 9 additions & 0 deletions third_party/nlohmann_json.BUILD
@@ -0,0 +1,9 @@
cc_library(
name = "json",
hdrs = glob([
"nlohmann/*.hpp",
"nlohmann/**/*.hpp",
]),
includes = ["."],
visibility = ["//visibility:public"]
)
6 changes: 5 additions & 1 deletion third_party/xtensor.BUILD
@@ -2,5 +2,9 @@ cc_library(
name = "xtensor",
hdrs = glob(["**"]),
visibility = ["//visibility:public"],
deps = ["@com_xtensorstack_xtl//:xtl"]
includes = ["."],
deps = [
"@com_xtensorstack_xtl//:xtl",
"@com_nlohmann_json//:json",
]
)
1 change: 1 addition & 0 deletions third_party/xtl.BUILD
@@ -1,5 +1,6 @@
cc_library(
name = "xtl",
hdrs = glob(["**"]),
includes = ["."],
visibility = ["//visibility:public"],
)
