diff --git a/src/processing/plugins/dummy_processor.cc b/src/processing/plugins/dummy_processor.cc index fbbf523141c5..0e060f05b7d5 100644 --- a/src/processing/plugins/dummy_processor.cc +++ b/src/processing/plugins/dummy_processor.cc @@ -8,7 +8,11 @@ using std::cout; using std::endl; const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 -const int kPrefixLen = 24; +const int64_t kPrefixLen = 24; + +bool ValidDam(std::int8_t *buffer) { + return memcmp(buffer, kSignature, strlen(kSignature)) == 0; +} xgboost::common::Span DummyProcessor::ProcessGHPairs(vector &pairs) { cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl; @@ -31,13 +35,18 @@ xgboost::common::Span DummyProcessor::ProcessGHPairs(vector &pai } // Save pairs for future operations - this->gh_pairs_ = &pairs; + this->gh_pairs_ = new vector(pairs); return xgboost::common::Span(reinterpret_cast(buf), buf_size); } xgboost::common::Span DummyProcessor::HandleGHPairs(xgboost::common::Span buffer) { - cout << "HandleGHPairs called with buffer size: " << buffer.size() << endl; + cout << "HandleGHPairs called with buffer size: " << buffer.size() << " Active: " << active_ << endl; + + if (!ValidDam(buffer.data())) { + cout << "Invalid buffer received" << endl; + return buffer; + } // For dummy, this call is used to set gh_pairs for passive sites if (!active_) { @@ -48,6 +57,7 @@ xgboost::common::Span DummyProcessor::HandleGHPairs(xgboost::common::Spa for (int i = 0; i < num; i += 10) { gh_pairs_->push_back(pairs[i]); } + cout << "GH Pairs saved. Size: " << gh_pairs_->size() << endl; } return buffer; @@ -58,6 +68,7 @@ xgboost::common::Span DummyProcessor::ProcessAggregation( auto total_bin_size = gidx_->Cuts().Values().size(); auto histo_size = total_bin_size*2; auto buf_size = kPrefixLen + 8*histo_size*nodes_to_build.size(); + cout << "ProcessAggregation called with bin size: " << total_bin_size << " Buffer Size: " << buf_size << endl; std::int8_t *buf = static_cast(calloc(buf_size, 1)); memcpy(buf, kSignature, strlen(kSignature)); memcpy(buf + 8, &buf_size, 8); @@ -74,6 +85,15 @@ xgboost::common::Span DummyProcessor::ProcessAggregation( continue; } + if (slot >= total_bin_size) { + cout << "Slot too big, ignored: " << slot << endl; + continue; + } + + if (row_id >= gh_pairs_->size()/2) { + cout << "Row ID too big: " << row_id << endl; + } + auto g = (*gh_pairs_)[row_id*2]; auto h = (*gh_pairs_)[row_id*2+1]; histo[slot*2] += g; @@ -86,17 +106,29 @@ xgboost::common::Span DummyProcessor::ProcessAggregation( return xgboost::common::Span(reinterpret_cast(buf), buf_size); } -std::vector DummyProcessor::HandleAggregation(std::vector> buffers) { +std::vector DummyProcessor::HandleAggregation(xgboost::common::Span buffer) { + cout << "HandleAggregation called with buffer size: " << buffer.size() << endl; std::vector result = std::vector(); - for (auto buf : buffers) { - int8_t *ptr = buf.data(); + int8_t* ptr = buffer.data(); + auto rest_size = buffer.size(); + + while (rest_size > kPrefixLen) { + if (!ValidDam(ptr)) { + cout << "Invalid buffer at offset " << buffer.size() - rest_size << endl; + continue; + } std::int64_t *size_ptr = reinterpret_cast(ptr + 8); double *array_start = reinterpret_cast(ptr + kPrefixLen); - auto array_size = (*size_ptr - kPrefixLen) / 8; + auto array_size = (*size_ptr - kPrefixLen)/8; + cout << "Histo size for buffer: " << array_size << endl; result.insert(result.end(), array_start, array_start + array_size); + cout << "Result size: " << result.size() << endl; + rest_size -= *size_ptr; + ptr = ptr + *size_ptr; } + + cout << "Total histo size: " << result.size() << endl; return result; } - diff --git a/src/processing/plugins/dummy_processor.h b/src/processing/plugins/dummy_processor.h index dc1d937ba4d0..9511cf7f56f6 100644 --- a/src/processing/plugins/dummy_processor.h +++ b/src/processing/plugins/dummy_processor.h @@ -40,5 +40,5 @@ class DummyProcessor: public xgboost::processing::Processor { xgboost::common::Span ProcessAggregation(std::vector const &nodes_to_build, xgboost::common::RowSetCollection const &row_set) override; - std::vector HandleAggregation(std::vector> buffers) override; + std::vector HandleAggregation(xgboost::common::Span buffer) override; }; diff --git a/src/processing/processor.h b/src/processing/processor.h index effc52c6ff7d..beaabb89c5bd 100644 --- a/src/processing/processor.h +++ b/src/processing/processor.h @@ -17,8 +17,8 @@ const char kDummyProcessor[] = "dummy"; const char kLoadFunc[] = "LoadProcessor"; // Data type definition -const int kDataTypeGHPairs = 1; -const int kDataTypeHisto = 2; +const int64_t kDataTypeGHPairs = 1; +const int64_t kDataTypeHisto = 2; /*! \brief An processor interface to handle tasks that require external library through plugins */ class Processor { @@ -82,12 +82,12 @@ class Processor { /*! * \brief Handle all gather result * - * \param buffers Buffer from all gather, only buffer from active site is needed + * \param buffer Buffer from all gather, only buffer from active site is needed * * \return A flattened vector of histograms for each site, each node in the form of * site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3 */ - virtual std::vector HandleAggregation(std::vector> buffers) = 0; + virtual std::vector HandleAggregation(xgboost::common::Span buffer) = 0; }; class ProcessorLoader {