Skip to content

Commit

Permalink
CKKSTensor Dot operation (#196)
Browse files Browse the repository at this point in the history
* first skeleton for dot operation

* matmul

* lint

* tests

* fix: create acc ct based on result param_id

* lint

* rename test func

* fix indexing

* redef API for dot

dot_product* is now just dot*

* lint

* plain dot

* comments

* bind plain dot

* parallel matmul

* fix: to_sum should be scoped in the worker_func only
  • Loading branch information
youben11 authored Dec 23, 2020
1 parent 6232c06 commit ea8f04f
Show file tree
Hide file tree
Showing 10 changed files with 332 additions and 36 deletions.
12 changes: 8 additions & 4 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("polyval_", &CKKSVector::polyval_inplace)
// because dot doesn't have a magic function like __add__
// we prefer to overload it instead of having dot_plain functions
.def("dot", &CKKSVector::dot_product)
.def("dot", &CKKSVector::dot_product_plain)
.def("dot_", &CKKSVector::dot_product_inplace)
.def("dot_", &CKKSVector::dot_product_plain_inplace)
.def("dot", &CKKSVector::dot)
.def("dot", &CKKSVector::dot_plain)
.def("dot_", &CKKSVector::dot_inplace)
.def("dot_", &CKKSVector::dot_plain_inplace)
.def("sum", &CKKSVector::sum, py::arg("axis") = 0)
.def("sum_", &CKKSVector::sum_inplace, py::arg("axis") = 0)
.def(
Expand Down Expand Up @@ -502,6 +502,10 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
&CKKSTensor::mul_plain_inplace))
.def("polyval", &CKKSTensor::polyval)
.def("polyval_", &CKKSTensor::polyval_inplace)
.def("dot", &CKKSTensor::dot)
.def("dot_", &CKKSTensor::dot_inplace)
.def("dot", &CKKSTensor::dot_plain)
.def("dot_", &CKKSTensor::dot_plain_inplace)
// python arithmetic
.def("__add__", &CKKSTensor::add)
.def("__add__", py::overload_cast<const double &>(
Expand Down
4 changes: 2 additions & 2 deletions tenseal/cpp/tensors/bfvvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ shared_ptr<BFVVector> BFVVector::mul_inplace(
return shared_from_this();
}

shared_ptr<BFVVector> BFVVector::dot_product_inplace(
shared_ptr<BFVVector> BFVVector::dot_inplace(
    const shared_ptr<BFVVector>& to_mul) {
    // A dot product is an elementwise multiplication followed by a global
    // sum; both helpers mutate *this and return shared_from_this(), so the
    // calls can simply be chained.
    return this->mul_inplace(to_mul)->sum_inplace();
}

shared_ptr<BFVVector> BFVVector::dot_product_plain_inplace(
shared_ptr<BFVVector> BFVVector::dot_plain_inplace(
const BFVVector::plain_t& to_mul) {
this->mul_plain_inplace(to_mul);
this->sum_inplace();
Expand Down
4 changes: 2 additions & 2 deletions tenseal/cpp/tensors/bfvvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class BFVVector
encrypted_t add_inplace(const encrypted_t& to_add) override;
encrypted_t sub_inplace(const encrypted_t& to_sub) override;
encrypted_t mul_inplace(const encrypted_t& to_mul) override;
encrypted_t dot_product_inplace(const encrypted_t& to_mul) override;
encrypted_t dot_product_plain_inplace(const plain_t& to_mul) override;
encrypted_t dot_inplace(const encrypted_t& to_mul) override;
encrypted_t dot_plain_inplace(const plain_t& to_mul) override;
encrypted_t sum_inplace(size_t axis = 0) override;

/**
Expand Down
229 changes: 217 additions & 12 deletions tenseal/cpp/tensors/ckkstensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,6 @@ shared_ptr<CKKSTensor> CKKSTensor::mul_inplace(
return this->op_inplace(to_mul, OP::MUL);
}

shared_ptr<CKKSTensor> CKKSTensor::dot_product_inplace(
const shared_ptr<CKKSTensor>& to_mul) {
// TODO
return shared_from_this();
}

shared_ptr<CKKSTensor> CKKSTensor::add_plain_inplace(
const PlainTensor<double>& to_add) {
return this->op_plain_inplace(to_add, OP::ADD);
Expand All @@ -387,12 +381,6 @@ shared_ptr<CKKSTensor> CKKSTensor::mul_plain_inplace(
return this->op_plain_inplace(to_mul, OP::MUL);
}

shared_ptr<CKKSTensor> CKKSTensor::dot_product_plain_inplace(
const PlainTensor<double>& to_mul) {
// TODO
return shared_from_this();
}

shared_ptr<CKKSTensor> CKKSTensor::add_plain_inplace(const double& to_add) {
return this->op_plain_inplace(to_add, OP::ADD);
}
Expand Down Expand Up @@ -508,6 +496,223 @@ shared_ptr<CKKSTensor> CKKSTensor::polyval_inplace(
return shared_from_this();
}

shared_ptr<CKKSTensor> CKKSTensor::dot_inplace(
    const shared_ptr<CKKSTensor>& other) {
    // Dispatch on the ranks of the two operands, numpy-style:
    //   1D . 1D -> inner product (elementwise mul + sum)
    //   2D . 2D -> matrix multiplication
    // Mixed 1D/2D cases and ranks above 2 are not supported yet.
    const size_t lhs_rank = this->shape().size();
    const size_t rhs_rank = other->shape().size();

    if (lhs_rank == 1 && rhs_rank == 1) {
        // inner product
        this->mul_inplace(other);
        this->sum_inplace();
        return shared_from_this();
    }
    if (lhs_rank == 2 && rhs_rank == 2) {
        this->matmul_inplace(other);
        return shared_from_this();
    }
    if (lhs_rank == 1 && rhs_rank == 2) {
        // TODO: better implement broadcasting for mul first then would be
        // implemented similar to 1D-1D
        throw invalid_argument("1D-2D dot isn't implemented yet");
    }
    if (lhs_rank == 2 && rhs_rank == 1) {
        // TODO: better implement broadcasting for mul first then would be
        // implemented similar to 1D-1D
        throw invalid_argument("2D-1D dot isn't implemented yet");
    }
    throw invalid_argument(
        "don't support dot operations of more than 2 dimensions");
}

shared_ptr<CKKSTensor> CKKSTensor::dot_plain_inplace(
    const PlainTensor<double>& other) {
    // Plaintext counterpart of dot_inplace: dispatch on operand ranks.
    //   1D . 1D -> inner product (elementwise mul + sum)
    //   2D . 2D -> matrix multiplication
    // Mixed 1D/2D cases and ranks above 2 are not supported yet.
    const size_t lhs_rank = this->shape().size();
    const size_t rhs_rank = other.shape().size();

    if (lhs_rank == 1 && rhs_rank == 1) {
        // inner product
        this->mul_plain_inplace(other);
        this->sum_inplace();
        return shared_from_this();
    }
    if (lhs_rank == 2 && rhs_rank == 2) {
        this->matmul_plain_inplace(other);
        return shared_from_this();
    }
    if (lhs_rank == 1 && rhs_rank == 2) {
        // TODO: better implement broadcasting for mul first then would be
        // implemented similar to 1D-1D
        throw invalid_argument("1D-2D dot isn't implemented yet");
    }
    if (lhs_rank == 2 && rhs_rank == 1) {
        // TODO: better implement broadcasting for mul first then would be
        // implemented similar to 1D-1D
        throw invalid_argument("2D-1D dot isn't implemented yet");
    }
    throw invalid_argument(
        "don't support dot operations of more than 2 dimensions");
}

// Encrypted matrix-matrix product: replaces this->_data with the
// [this_shape[0] x other_shape[1]] result. Each output element is computed
// independently, so the flattened output range is split across the context's
// worker pool when one is available.
shared_ptr<CKKSTensor> CKKSTensor::matmul_inplace(
    const shared_ptr<CKKSTensor> other) {
    auto this_shape = this->shape();
    auto other_shape = other->shape();

    if (this_shape.size() != 2)
        throw invalid_argument("this tensor isn't a matrix");
    if (other_shape.size() != 2)
        throw invalid_argument("operand tensor isn't a matrix");
    if (this_shape[1] != other_shape[0])
        throw invalid_argument("can't multiply matrices");  // TODO(review):
        // include the two mismatching shapes in the message

    vector<size_t> new_shape = vector({this_shape[0], other_shape[1]});
    size_t new_size = new_shape[0] * new_shape[1];
    vector<Ciphertext> new_data;
    new_data.resize(new_shape[0] * new_shape[1]);

    size_t n_jobs = this->tenseal_context()->dispatcher_size();

    // Computes output elements [start, end) of the row-major flattened
    // result. to_sum is deliberately local to the worker so concurrent
    // workers never share scratch ciphertexts.
    // NOTE(review): assumes the inner dimension this_shape[1] > 0 —
    // to_sum[0] below is read unconditionally; confirm 0-sized tensors
    // can't reach this point.
    auto worker_func = [&](size_t start, size_t end) -> bool {
        vector<Ciphertext> to_sum;
        to_sum.resize(this_shape[1]);
        for (size_t i = start; i < end; i++) {
            auto evaluator = this->tenseal_context()->evaluator;
            size_t row = i / new_shape[1];
            size_t col = i % new_shape[1];
            // inner product
            for (size_t j = 0; j < this_shape[1]; j++) {
                to_sum[j] = this->_data.at({row, j});
                this->perform_op(to_sum[j], other->_data.at({j, col}), OP::MUL);
            }
            // Build the accumulator at the parms_id of the products so
            // add_many operates at matching encryption parameters
            // (perform_op may have switched levels during multiplication).
            Ciphertext acc(*this->tenseal_context()->seal_context(),
                           to_sum[0].parms_id());
            evaluator->add_many(to_sum, acc);
            // set element[row, col] to the computed inner product
            new_data[i] = acc;
        }
        return true;
    };

    if (n_jobs == 1) {
        // No pool: do all the work synchronously on this thread.
        worker_func(0, new_size);
    } else {
        // Ceil-divide the output across n_jobs batches; min() clamps the
        // last batch so we never index past new_size.
        size_t batch_size = (new_size + n_jobs - 1) / n_jobs;
        vector<future<bool>> futures;
        for (size_t i = 0; i < n_jobs; i++) {
            futures.push_back(
                this->tenseal_context()->dispatcher()->enqueue_task(
                    worker_func, i * batch_size,
                    std::min((i + 1) * batch_size, new_size)));
        }
        // waiting — drain every future even after a failure so no task is
        // left running against soon-to-be-replaced state; only the last
        // caught message is kept and rethrown.
        optional<string> fail;
        for (size_t i = 0; i < futures.size(); i++) {
            try {
                futures[i].get();
            } catch (std::exception& e) {
                fail = e.what();
            }
        }

        if (fail) {
            throw invalid_argument(fail.value());
        }
    }

    // Commit the result: this tensor now holds the product's data and shape.
    this->_data = TensorStorage(new_data, new_shape);
    return shared_from_this();
}

// Encrypted-by-plaintext matrix product: same structure as matmul_inplace,
// except each right-hand element is freshly encoded to a Plaintext at this
// tensor's scale before the multiplication.
shared_ptr<CKKSTensor> CKKSTensor::matmul_plain_inplace(
    const PlainTensor<double>& other) {
    auto this_shape = this->shape();
    auto other_shape = other.shape();

    if (this_shape.size() != 2)
        throw invalid_argument("this tensor isn't a matrix");
    if (other_shape.size() != 2)
        throw invalid_argument("operand tensor isn't a matrix");
    if (this_shape[1] != other_shape[0])
        throw invalid_argument("can't multiply matrices");  // TODO(review):
        // include the two mismatching shapes in the message

    vector<size_t> new_shape = vector({this_shape[0], other_shape[1]});
    size_t new_size = new_shape[0] * new_shape[1];
    vector<Ciphertext> new_data;
    new_data.resize(new_shape[0] * new_shape[1]);

    size_t n_jobs = this->tenseal_context()->dispatcher_size();

    // Computes output elements [start, end) of the row-major flattened
    // result; to_sum is worker-local scratch (see matmul_inplace).
    // NOTE(review): assumes the inner dimension this_shape[1] > 0 since
    // to_sum[0] is read unconditionally below.
    auto worker_func = [&](size_t start, size_t end) -> bool {
        vector<Ciphertext> to_sum;
        to_sum.resize(this_shape[1]);
        for (size_t i = start; i < end; i++) {
            auto evaluator = this->tenseal_context()->evaluator;
            size_t row = i / new_shape[1];
            size_t col = i % new_shape[1];
            // inner product
            for (size_t j = 0; j < this_shape[1]; j++) {
                to_sum[j] = this->_data.at({row, j});
                // Encode the scalar operand at this tensor's scale so the
                // plain multiplication is performed at matching scales.
                Plaintext pt;
                this->tenseal_context()->encode<CKKSEncoder>(
                    other.at({j, col}), pt, this->_init_scale);
                this->perform_plain_op(to_sum[j], pt, OP::MUL);
            }
            // Accumulator built at the products' parms_id so add_many sums
            // ciphertexts at matching encryption parameters.
            Ciphertext acc(*this->tenseal_context()->seal_context(),
                           to_sum[0].parms_id());
            evaluator->add_many(to_sum, acc);
            // set element[row, col] to the computed inner product
            new_data[i] = acc;
        }
        return true;
    };

    if (n_jobs == 1) {
        // No pool: compute everything synchronously.
        worker_func(0, new_size);
    } else {
        // Ceil-divide the output range into n_jobs batches; min() clamps
        // the final batch to new_size.
        size_t batch_size = (new_size + n_jobs - 1) / n_jobs;
        vector<future<bool>> futures;
        for (size_t i = 0; i < n_jobs; i++) {
            futures.push_back(
                this->tenseal_context()->dispatcher()->enqueue_task(
                    worker_func, i * batch_size,
                    std::min((i + 1) * batch_size, new_size)));
        }
        // waiting — every future is drained even on failure; the last
        // caught message wins and is rethrown below.
        optional<string> fail;
        for (size_t i = 0; i < futures.size(); i++) {
            try {
                futures[i].get();
            } catch (std::exception& e) {
                fail = e.what();
            }
        }

        if (fail) {
            throw invalid_argument(fail.value());
        }
    }

    // Commit the result: this tensor now holds the product's data and shape.
    this->_data = TensorStorage(new_data, new_shape);
    return shared_from_this();
}

void CKKSTensor::clear() {
this->_data = TensorStorage<Ciphertext>();
this->_batch_size = optional<double>();
Expand Down
19 changes: 15 additions & 4 deletions tenseal/cpp/tensors/ckkstensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
const shared_ptr<CKKSTensor>& to_sub) override;
shared_ptr<CKKSTensor> mul_inplace(
const shared_ptr<CKKSTensor>& to_mul) override;
shared_ptr<CKKSTensor> dot_product_inplace(
const shared_ptr<CKKSTensor>& to_mul) override;

shared_ptr<CKKSTensor> add_plain_inplace(const double& to_add) override;
shared_ptr<CKKSTensor> sub_plain_inplace(const double& to_sub) override;
Expand All @@ -51,8 +49,6 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
const PlainTensor<double>& to_sub) override;
shared_ptr<CKKSTensor> mul_plain_inplace(
const PlainTensor<double>& to_mul) override;
shared_ptr<CKKSTensor> dot_product_plain_inplace(
const PlainTensor<double>& to_mul) override;

shared_ptr<CKKSTensor> sum_inplace(size_t axis = 0) override;
shared_ptr<CKKSTensor> sum_batch() {
Expand All @@ -63,6 +59,21 @@ class CKKSTensor : public EncryptedTensor<double, shared_ptr<CKKSTensor>>,
shared_ptr<CKKSTensor> polyval_inplace(
const vector<double>& coefficients) override;

// numpy-style dot: 1D-1D inner product or 2D-2D matmul, in place
// (mixed-rank and >2D combinations throw; see ckkstensor.cpp).
shared_ptr<CKKSTensor> dot_inplace(
    const shared_ptr<CKKSTensor>& to_mul) override;
shared_ptr<CKKSTensor> dot_plain_inplace(
    const PlainTensor<double>& to_mul) override;

// Matrix multiplication; the non-inplace variants operate on a copy.
shared_ptr<CKKSTensor> matmul(const shared_ptr<CKKSTensor> other) {
    return this->copy()->matmul_inplace(other);
}
shared_ptr<CKKSTensor> matmul_inplace(const shared_ptr<CKKSTensor> other);
shared_ptr<CKKSTensor> matmul_plain(const PlainTensor<double>& other) {
    return this->copy()->matmul_plain_inplace(other);
}
shared_ptr<CKKSTensor> matmul_plain_inplace(
    const PlainTensor<double>& other);

void load(const string& vec) override;
string save() const override;

Expand Down
5 changes: 2 additions & 3 deletions tenseal/cpp/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,15 @@ shared_ptr<CKKSVector> CKKSVector::mul_inplace(
return shared_from_this();
}

shared_ptr<CKKSVector> CKKSVector::dot_product_inplace(
shared_ptr<CKKSVector> CKKSVector::dot_inplace(
    const shared_ptr<CKKSVector>& to_mul) {
    // A dot product is an elementwise multiplication followed by a global
    // sum; both helpers mutate *this and return shared_from_this(), so the
    // calls can simply be chained.
    return this->mul_inplace(to_mul)->sum_inplace();
}

shared_ptr<CKKSVector> CKKSVector::dot_product_plain_inplace(
const plain_t& to_mul) {
shared_ptr<CKKSVector> CKKSVector::dot_plain_inplace(const plain_t& to_mul) {
this->mul_plain_inplace(to_mul);
this->sum_inplace();

Expand Down
4 changes: 2 additions & 2 deletions tenseal/cpp/tensors/ckksvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class CKKSVector
encrypted_t add_inplace(const encrypted_t& to_add) override;
encrypted_t sub_inplace(const encrypted_t& to_sub) override;
encrypted_t mul_inplace(const encrypted_t& to_mul) override;
encrypted_t dot_product_inplace(const encrypted_t& to_mul) override;
encrypted_t dot_product_plain_inplace(const plain_t& to_mul) override;
encrypted_t dot_inplace(const encrypted_t& to_mul) override;
encrypted_t dot_plain_inplace(const plain_t& to_mul) override;
encrypted_t sum_inplace(size_t axis = 0) override;

/**
Expand Down
Loading

0 comments on commit ea8f04f

Please sign in to comment.