Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Processor Interface #10073

Closed
wants to merge 13 commits into from
141 changes: 141 additions & 0 deletions src/processing/plugins/dummy_processor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#include <iostream>
#include <cstring>
#include "./dummy_processor.h"

using std::vector;
using std::cout;
using std::endl;

const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1
const int64_t kPrefixLen = 24;

bool ValidDam(void *buffer, size_t size) {
return size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0;
}

void* DummyProcessor::ProcessGHPairs(size_t &size, std::vector<double>& pairs) {
cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl;

size = kPrefixLen + pairs.size()*10*8; // Assume encrypted size is 10x

int64_t buf_size = size;
// This memory needs to be freed
char *buf = static_cast<char *>(calloc(size, 1));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
memcpy(buf + 16, &kDataTypeGHPairs, 8);

// Simulate encryption by duplicating value 10 times
int index = kPrefixLen;
for (auto value : pairs) {
for (int i = 0; i < 10; i++) {
memcpy(buf+index, &value, 8);
index += 8;
}
}

// Save pairs for future operations
this->gh_pairs_ = new vector<double>(pairs);

return buf;
}


void* DummyProcessor::HandleGHPairs(size_t &size, void *buffer, size_t buf_size) {
cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl;

size = buf_size;
if (!ValidDam(buffer, size)) {
cout << "Invalid buffer received" << endl;
return buffer;
}

// For dummy, this call is used to set gh_pairs for passive sites
if (!active_) {
int8_t *ptr = static_cast<int8_t *>(buffer);
ptr += kPrefixLen;
double *pairs = reinterpret_cast<double *>(ptr);
size_t num = (buf_size - kPrefixLen) / 8;
gh_pairs_ = new vector<double>();
for (int i = 0; i < num; i += 10) {
gh_pairs_->push_back(pairs[i]);
}
cout << "GH Pairs saved. Size: " << gh_pairs_->size() << endl;
}

return buffer;
}

void *DummyProcessor::ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) {
auto total_bin_size = cuts_.back();
auto histo_size = total_bin_size*2;
size = kPrefixLen + 8*histo_size*nodes.size();
int64_t buf_size = size;
cout << "ProcessAggregation called with bin size: " << total_bin_size << " Buffer Size: " << buf_size << endl;
std::int8_t *buf = static_cast<std::int8_t *>(calloc(buf_size, 1));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
memcpy(buf + 16, &kDataTypeHisto, 8);

double *histo = reinterpret_cast<double *>(buf + kPrefixLen);
for ( const auto &node : nodes ) {
auto rows = node.second;
for (const auto &row_id : rows) {

auto num = cuts_.size() - 1;
for (std::size_t f = 0; f < num; f++) {
auto slot = slots_[f + num*row_id];
if (slot < 0) {
continue;
}

if (slot >= total_bin_size) {
cout << "Slot too big, ignored: " << slot << endl;
continue;
}

if (row_id >= gh_pairs_->size()/2) {
cout << "Row ID too big: " << row_id << endl;
}

auto g = (*gh_pairs_)[row_id*2];
auto h = (*gh_pairs_)[row_id*2+1];
histo[slot*2] += g;
histo[slot*2+1] += h;
}
}
histo += histo_size;
}

return buf;
}

std::vector<double> DummyProcessor::HandleAggregation(void *buffer, size_t buf_size) {
cout << "HandleAggregation called with buffer size: " << buf_size << endl;
std::vector<double> result = std::vector<double>();

int8_t* ptr = static_cast<int8_t *>(buffer);
auto rest_size = buf_size;

while (rest_size > kPrefixLen) {
if (!ValidDam(ptr, rest_size)) {
cout << "Invalid buffer at offset " << buf_size - rest_size << endl;
continue;
}
std::int64_t *size_ptr = reinterpret_cast<std::int64_t *>(ptr + 8);
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
auto array_size = (*size_ptr - kPrefixLen)/8;
cout << "Histo size for buffer: " << array_size << endl;
result.insert(result.end(), array_start, array_start + array_size);
cout << "Result size: " << result.size() << endl;
rest_size -= *size_ptr;
ptr = ptr + *size_ptr;
}

cout << "Total histo size: " << result.size() << endl;

return result;
}
56 changes: 56 additions & 0 deletions src/processing/plugins/dummy_processor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#pragma once
#include <string>
#include <vector>
#include <map>
#include "../processor.h"

// Data type definition
const int64_t kDataTypeGHPairs = 1;
const int64_t kDataTypeHisto = 2;

class DummyProcessor: public processing::Processor {
private:
bool active_ = false;
const std::map<std::string, std::string> *params_{nullptr};
std::vector<double> *gh_pairs_{nullptr};
std::vector<uint32_t> cuts_;
std::vector<int> slots_;

public:
void Initialize(bool active, std::map<std::string, std::string> params) override {
this->active_ = active;
this->params_ = &params;
}

void Shutdown() override {
this->gh_pairs_ = nullptr;
this->cuts_.clear();
this->slots_.clear();
}

void FreeBuffer(void *buffer) override {
free(buffer);
}

void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) override;

void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) override;

void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) override {
std::cout << "InitAggregationContext called with cuts size: " << cuts.size()-1 <<
" number of slot: " << slots.size() << std::endl;
this->cuts_ = cuts;
if (this->slots_.empty()) {
this->slots_ = slots;
} else {
std::cout << "Multiple calls to InitAggregationContext" << std::endl;
}
}

void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) override;

std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;
};
110 changes: 110 additions & 0 deletions src/processing/processor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#pragma once

#include <map>
#include <any>
#include <string>
#include <vector>

namespace processing {

const char kLibraryPath[] = "LIBRARY_PATH";
const char kDummyProcessor[] = "dummy";
const char kLoadFunc[] = "LoadProcessor";

/*! \brief An processor interface to handle tasks that require external library through plugins */
class Processor {
public:
/*!
* \brief Initialize the processor
*
* \param active If true, this is the active node
* \param params Optional parameters
*/
virtual void Initialize(bool active, std::map<std::string, std::string> params) = 0;

/*!
* \brief Shutdown the processor and free all the resources
*
*/
virtual void Shutdown() = 0;

/*!
* \brief Free buffer
*
* \param buffer Any buffer returned by the calls from the plugin
*/
virtual void FreeBuffer(void* buffer) = 0;

/*!
* \brief Preparing g & h pairs to be sent to other clients by active client
*
* \param size The size of the buffer
* \param pairs g&h pairs in a vector (g1, h1, g2, h2 ...) for every sample
*
* \return The encoded buffer to be sent
*/
virtual void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) = 0;

/*!
* \brief Handle buffers with encoded pairs received from broadcast
*
* \param size Output buffer size
* \param The encoded buffer
* \param The encoded buffer size
*
* \return The encoded buffer
*/
virtual void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) = 0;

/*!
* \brief Initialize aggregation context by providing global GHistIndexMatrix
*
* \param cuts The cut point for each feature
* \param slots The slot assignment in a flattened matrix for each feature/row. The size is num_feature*num_row
*/
virtual void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) = 0;

/*!
* \brief Prepare row set for aggregation
*
* \param size The output buffer size
* \param nodes Map of node and the rows belong to this node
*
* \return The encoded buffer to be sent via AllGather
*/
virtual void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) = 0;

/*!
* \brief Handle all gather result
*
* \param buffer Buffer from all gather, only buffer from active site is needed
* \param buf_size The size of the buffer
*
* \return A flattened vector of histograms for each site, each node in the form of
* site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3
*/
virtual std::vector<double> HandleAggregation(void *buffer, size_t buf_size) = 0;
};

class ProcessorLoader {
private:
std::map<std::string, std::string> params;
void *handle = NULL;


public:
ProcessorLoader(): params{} {}

ProcessorLoader(std::map<std::string, std::string>& params): params(params) {}

Processor* load(const std::string& plugin_name);

void unload();
};

} // namespace processing

extern processing::Processor *processor_instance;
62 changes: 62 additions & 0 deletions src/processing/processor_loader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#include <iostream>
#include <dlfcn.h>

#include "./processor.h"
#include "plugins/dummy_processor.h"

namespace processing {
using LoadFunc = Processor *(const char *);

Processor* ProcessorLoader::load(const std::string& plugin_name) {
// Dummy processor for unit testing without loading a shared library
if (plugin_name == kDummyProcessor) {
return new DummyProcessor();
}

auto lib_name = "libproc_" + plugin_name;

auto extension =
#if defined(__APPLE__) || defined(__MACH__)
".dylib";
#else
".so";
#endif
auto lib_file_name = lib_name + extension;

std::string lib_path;

if (params.find(kLibraryPath) == params.end()) {
lib_path = lib_file_name;
} else {
auto p = params[kLibraryPath];
if (p.back() != '/') {
p += '/';
}
lib_path = p + lib_file_name;
}

handle = dlopen(lib_path.c_str(), RTLD_LAZY);
if (!handle) {
std::cerr << "Failed to load the dynamic library: " << dlerror() << std::endl;
return NULL;
}

void* func_ptr = dlsym(handle, kLoadFunc);

if (!func_ptr) {
std::cerr << "Failed to find loader function: " << dlerror() << std::endl;
return NULL;
}

auto func = reinterpret_cast<LoadFunc *>(func_ptr);

return (*func)(plugin_name.c_str());
}

void ProcessorLoader::unload() {
dlclose(handle);
}
} // namespace processing
Loading