Skip to content

Commit

Permalink
Add packed bit field support to Row Structure (#493)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #493

# Background:

Currently, in order to successfully use UDP, you must write some carefully crafted code that takes all the rows of metadata for one side and packages them into a collection of bytes. Afterwards, the caller gets back a `SecString` object, which is a bit representation of all the bytes they passed in, minus the filtered-out rows. The user must then extract the corresponding bits for each column into separate MPC types. This is a cumbersome, error-prone process: the two steps must be carefully matched up, and any change to one without the other can cause a bug.

# This Diff

Adds support to RowDefinition for packed bits. Right now you directly pass in the names of the columns that are to be packed, and the interface combines them into a single column. The output contains the columns in vector form. I will refactor that in a future diff to make it a bit easier to use.

Reviewed By: haochenuw

Differential Revision: D43366173

fbshipit-source-id: 6604bad668dbd4acaf7e59d3b60dfc0be217a091
  • Loading branch information
Tal Davidi authored and facebook-github-bot committed Feb 23, 2023
1 parent 4f8705f commit 2f51aa1
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class RowStructureDefinition : public IRowStructureDefinition<schedulerId> {
auto columnType = columnDefinition->getColumnType();

switch (columnType) {
case IColumnDefinition<
schedulerId>::SupportedColumnTypes::PackedBitField:
serializePackedBitFieldColumn(
columnPointer, data, writeBuffers, numRows, byteOffset);
break;
case IColumnDefinition<schedulerId>::SupportedColumnTypes::UInt32:
serializeIntegerColumn<false, 32>(
columnPointer, data, writeBuffers, numRows, byteOffset);
Expand Down Expand Up @@ -179,6 +184,55 @@ class RowStructureDefinition : public IRowStructureDefinition<schedulerId> {
return result;
}

void serializePackedBitFieldColumn(
const IColumnDefinition<schedulerId>* columnPointer,
const std::unordered_map<
std::string,
typename IRowStructureDefinition<schedulerId>::InputColumnDataType>&
inputData,
std::vector<std::vector<unsigned char>>& writeBuffers,
int numRows,
size_t byteOffset) const {
const PackedBitFieldColumn<schedulerId>* packedBitCol =
dynamic_cast<const PackedBitFieldColumn<schedulerId>*>(columnPointer);

if (packedBitCol == nullptr) {
throw std::runtime_error("Failed to cast to PackedBitFieldColumn");
}
std::vector<std::vector<bool>> bitPack(
numRows, std::vector<bool>(packedBitCol->getSubColumnNames().size()));

for (int i = 0; i < packedBitCol->getSubColumnNames().size(); i++) {
std::string colName = packedBitCol->getSubColumnNames()[i];
if (!inputData.contains(colName)) {
throw std::runtime_error(
"Column: " + colName +
" which was defined in the structure was not included in the input data map.");
}

const std::vector<bool> bitVals =
std::get<std::vector<bool>>(inputData.at(colName));

if (bitVals.size() != numRows) {
std::string err = folly::sformat(
"Invalid number of values for column {} .Got {} values but number of rows should be {} ",
colName,
bitVals.size(),
numRows);
throw std::runtime_error(err);
}

for (int j = 0; j < numRows; j++) {
bitPack[j][i] = bitVals[j];
}
}

for (int i = 0; i < numRows; i++) {
packedBitCol->serializeColumnAsPlaintextBytes(
bitPack.data() + i, writeBuffers[i].data() + byteOffset);
}
}

template <bool isSigned, int8_t width>
void serializeIntegerColumn(
const IColumnDefinition<schedulerId>* columnPointer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#include "fbpcf/scheduler/SchedulerHelper.h"
#include "fbpcf/test/TestHelper.h"

#include "fbpcf/mpc_std_lib/unified_data_process/serialization/FixedSizeArrayColumn.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IRowStructureDefinition.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/PackedBitFieldColumn.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/RowStructureDefinition.h"

namespace fbpcf::mpc_std_lib::unified_data_process::serialization {
Expand All @@ -31,6 +33,12 @@ std::unique_ptr<IRowStructureDefinition<schedulerId>> createRowDefinition() {
std::make_unique<IntegerColumn<schedulerId, true, 64>>("int64Column"));
columnDefs->push_back(
std::make_unique<IntegerColumn<schedulerId, false, 32>>("uint32Column"));

std::vector<std::string> bitColumnNames = {
"boolColumn1", "boolColumn2", "boolColumn3", "boolColumn4"};
columnDefs->push_back(std::make_unique<PackedBitFieldColumn<schedulerId>>(
"packedBits", bitColumnNames));

auto serializer = std::make_unique<RowStructureDefinition<schedulerId>>(
std::move(columnDefs));
return std::move(serializer);
Expand Down Expand Up @@ -112,6 +120,16 @@ deserializeAndRevealAllColumns(
[](uint64_t data) { return data; });
rst.emplace("uint32Column", uint32Data);

std::vector<typename frontend::MPCTypes<schedulerId>::SecBool>
packedBitsMPCValue = std::get<
std::vector<typename frontend::MPCTypes<schedulerId>::SecBool>>(
deserialization.at("packedBits"));

rst.emplace("boolColumn1", packedBitsMPCValue[0].openToParty(0).getValue());
rst.emplace("boolColumn2", packedBitsMPCValue[1].openToParty(0).getValue());
rst.emplace("boolColumn3", packedBitsMPCValue[2].openToParty(0).getValue());
rst.emplace("boolColumn4", packedBitsMPCValue[3].openToParty(0).getValue());

return rst;
}

Expand All @@ -130,6 +148,7 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) {

std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<> boolDist(0, 1);
std::uniform_int_distribution<int64_t> uint32Dist(
std::numeric_limits<uint32_t>().min(),
std::numeric_limits<uint32_t>().max());
Expand All @@ -145,17 +164,26 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) {
std::unique_ptr<IRowStructureDefinition<1>> serializer1 =
createRowDefinition<1>();

EXPECT_EQ(serializer0->getRowSizeBytes(), 16);
EXPECT_EQ(serializer1->getRowSizeBytes(), 16);
EXPECT_EQ(serializer0->getRowSizeBytes(), 17);
EXPECT_EQ(serializer1->getRowSizeBytes(), 17);

std::vector<int32_t> int32Data(0);
std::vector<int64_t> int64Data(0);
std::vector<uint32_t> uint32Data(0);
std::vector<bool> boolData1(0);
std::vector<bool> boolData2(0);
std::vector<bool> boolData3(0);
std::vector<bool> boolData4(0);

for (int i = 0; i < batchSize; i++) {
int32Data.push_back(int32Dist(e));
int64Data.push_back(int64Dist(e));
uint32Data.push_back(uint32Dist(e));

boolData1.push_back(boolDist(e));
boolData2.push_back(boolDist(e));
boolData3.push_back(boolDist(e));
boolData4.push_back(boolDist(e));
}

std::unordered_map<
Expand All @@ -164,7 +192,11 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) {
inputData{
{"int32Column", int32Data},
{"int64Column", int64Data},
{"uint32Column", uint32Data}};
{"uint32Column", uint32Data},
{"boolColumn1", boolData1},
{"boolColumn2", boolData2},
{"boolColumn3", boolData3},
{"boolColumn4", boolData4}};

auto serializedBytes =
serializer0->serializeDataAsBytesForUDP(inputData, batchSize);
Expand All @@ -189,9 +221,17 @@ TEST(RowSerializationTest, RowWithMultipleColumnsTest) {
auto int32Rst = std::get<std::vector<int32_t>>(rst.at("int32Column"));
auto int64Rst = std::get<std::vector<int64_t>>(rst.at("int64Column"));
auto uint32Rst = std::get<std::vector<uint32_t>>(rst.at("uint32Column"));
auto boolRst1 = std::get<std::vector<bool>>(rst.at("boolColumn1"));
auto boolRst2 = std::get<std::vector<bool>>(rst.at("boolColumn2"));
auto boolRst3 = std::get<std::vector<bool>>(rst.at("boolColumn3"));
auto boolRst4 = std::get<std::vector<bool>>(rst.at("boolColumn4"));

testVectorEq(int32Data, int32Rst);
testVectorEq(int64Data, int64Rst);
testVectorEq(uint32Data, uint32Rst);
testVectorEq(boolData1, boolRst1);
testVectorEq(boolData2, boolRst2);
testVectorEq(boolData3, boolRst3);
testVectorEq(boolData4, boolRst4);
}
} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization

0 comments on commit 2f51aa1

Please sign in to comment.