Optimizations for quantisation on device (#4572)
* Do not create device vectors for the entire sparse page while computing histograms.
   - While creating the compressed histogram indices, the row-pointer vector was created for the
     entire sparse page batch. This is needless, as we only process chunks at a time, sized by a
     slice of the total GPU memory.
   - This PR allocates only as much as is required to store the appropriate row indices and the
     entries (see the batch-allocation sketch after the first file's diff below).

* Do not dereference row_ptrs once the device_vector has been created, to elide host copies of
   those counts.
   - Instead, grab the entry counts directly from the SparsePage (see the sketch just below).
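
To make the second bullet concrete, here is a minimal sketch of why host-side reads of a
thrust::device_vector are costly, and of the host-memory alternative the patch adopts. The
helper functions are illustrative only and are not part of the patch:

    #include <thrust/device_vector.h>
    #include <cstddef>
    #include <vector>

    // Every host-side operator[] on a thrust::device_vector resolves to a
    // synchronous device-to-host cudaMemcpy, so this costs two transfers:
    size_t CountEntriesViaDevice(const thrust::device_vector<size_t>& row_ptrs,
                                 size_t begin, size_t end) {
      return row_ptrs[end] - row_ptrs[begin];  // two hidden cudaMemcpy calls
    }

    // The pattern the patch switches to: read the counts from the SparsePage
    // offset vector, which already resides in host memory.
    size_t CountEntriesViaHost(const std::vector<size_t>& offset_vec,
                               size_t begin, size_t end) {
      return offset_vec[end] - offset_vec[begin];  // plain host memory reads
    }

Each hidden transfer also blocks the calling host thread, so eliding them helps even though the
counts themselves are tiny.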
sriramch authored and RAMitchell committed Jun 18, 2019
1 parent ba1d848 commit 6757654
Showing 2 changed files with 46 additions and 37 deletions.
43 changes: 18 additions & 25 deletions src/tree/updater_gpu_hist.cu
@@ -483,8 +483,6 @@ __global__ void CompressBinEllpackKernel(
     const uint32_t* __restrict__ cut_rows,  // HistCutMatrix::row_ptrs
     size_t base_row,                        // batch_row_begin
     size_t n_rows,
-    // row_ptr_begin: row_offset[base_row], the start position of base_row
-    size_t row_ptr_begin,
     size_t row_stride,
     unsigned int null_gidx_value) {
   size_t irow = threadIdx.x + blockIdx.x * blockDim.x;
@@ -495,7 +493,7 @@ __global__ void CompressBinEllpackKernel(
   int row_length = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
   unsigned int bin = null_gidx_value;
   if (ifeature < row_length) {
-    Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
+    Entry entry = entries[row_ptrs[irow] - row_ptrs[0] + ifeature];
     int feature = entry.index;
     float fvalue = entry.fvalue;
     // {feature_cuts, ncuts} forms the array of cuts of `feature'.
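
The kernel change above works because each GPU batch now receives its own row_ptrs slice, so the
slice's first element doubles as the batch's base entry offset and the separate row_ptr_begin
parameter becomes redundant. A small worked example of the indexing invariant, with invented
values:

    #include <cassert>
    #include <cstddef>

    int main() {
      // A hypothetical batch of three rows holding 3, 1 and 4 entries, whose
      // first entry sits at global position 100 in the SparsePage data array:
      const size_t row_ptrs[] = {100, 103, 104, 108};
      // entries[] is copied to the device starting at global entry row_ptrs[0],
      // so row irow's entries begin at local index row_ptrs[irow] - row_ptrs[0].
      assert(row_ptrs[1] - row_ptrs[0] == 3);  // row 1 starts at local index 3
      assert(row_ptrs[2] - row_ptrs[0] == 4);  // row 2 starts at local index 4
      return 0;
    }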
@@ -697,8 +695,6 @@ struct DeviceShard {
   /*! \brief Sum gradient for each node. */
   std::vector<GradientPair> node_sum_gradients;
   common::Span<GradientPair> node_sum_gradients_d;
-  /*! \brief On-device feature set, only actually used on one of the devices */
-  dh::device_vector<int> feature_set_d;
   dh::device_vector<int64_t>
       left_counts;  // Useful to keep a bunch of zeroed memory for sort position
   /*! The row offset for this shard. */
@@ -1311,14 +1307,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
   size_t row_stride = this->ellpack_matrix.row_stride;
 
   const auto &offset_vec = row_batch.offset.ConstHostVector();
-  /*! \brief row offset in SparsePage (the input data). */
-  CHECK_LE(device_row_state.rows_to_process_from_batch, offset_vec.size());
-  dh::device_vector<size_t> row_ptrs(device_row_state.rows_to_process_from_batch+1);
-  thrust::copy(
-      offset_vec.data() + device_row_state.row_offset_in_current_batch,
-      offset_vec.data() + device_row_state.row_offset_in_current_batch +
-          device_row_state.rows_to_process_from_batch + 1,
-      row_ptrs.begin());
 
   int num_symbols = n_bins + 1;
   // bin and compress entries in batches of rows
@@ -1327,7 +1315,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
       static_cast<size_t>(device_row_state.rows_to_process_from_batch));
   const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
 
-  dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
   size_t gpu_nbatches = dh::DivRoundUp(device_row_state.rows_to_process_from_batch,
                                        gpu_batch_nrows);
 
@@ -1339,35 +1326,41 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
     }
     size_t batch_nrows = batch_row_end - batch_row_begin;
 
+    const auto ent_cnt_begin =
+        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_begin];
+    const auto ent_cnt_end =
+        offset_vec[device_row_state.row_offset_in_current_batch + batch_row_end];
+
+    /*! \brief row offset in SparsePage (the input data). */
+    dh::device_vector<size_t> row_ptrs(batch_nrows+1);
+    thrust::copy(
+        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_begin,
+        offset_vec.data() + device_row_state.row_offset_in_current_batch + batch_row_end + 1,
+        row_ptrs.begin());
+
     // number of entries in this batch.
-    size_t n_entries = row_ptrs[batch_row_end] - row_ptrs[batch_row_begin];
+    size_t n_entries = ent_cnt_end - ent_cnt_begin;
+    dh::device_vector<Entry> entries_d(n_entries);
     // copy data entries to device.
     dh::safe_cuda
         (cudaMemcpy
-         (entries_d.data().get(), data_vec.data() + row_ptrs[batch_row_begin],
+         (entries_d.data().get(), data_vec.data() + ent_cnt_begin,
           n_entries * sizeof(Entry), cudaMemcpyDefault));
     const dim3 block3(32, 8, 1);  // 256 threads
-    const dim3 grid3(dh::DivRoundUp(device_row_state.rows_to_process_from_batch, block3.x),
+    const dim3 grid3(dh::DivRoundUp(batch_nrows, block3.x),
                      dh::DivRoundUp(row_stride, block3.y), 1);
     CompressBinEllpackKernel<<<grid3, block3>>>
         (common::CompressedBufferWriter(num_symbols),
          gidx_buffer.data(),
-         row_ptrs.data().get() + batch_row_begin,
+         row_ptrs.data().get(),
          entries_d.data().get(),
          gidx_fvalue_map.data(),
          feature_segments.data(),
          device_row_state.total_rows_processed + batch_row_begin,
          batch_nrows,
-         row_ptrs[batch_row_begin],
          row_stride,
          null_gidx_value);
   }
-
-  // free the memory that is no longer needed
-  row_ptrs.resize(0);
-  row_ptrs.shrink_to_fit();
-  entries_d.resize(0);
-  entries_d.shrink_to_fit();
 }
 
 // An instance of this type is created which keeps track of total number of rows to process,
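
Taken together, the CreateHistIndices changes above amount to the allocation pattern sketched
below. This is a self-contained approximation rather than the patch itself: Entry is reduced to
a stub, kGpuBatchRows is an invented constant standing in for the memory-derived gpu_batch_nrows,
and the kernel launch is elided:

    #include <thrust/device_vector.h>
    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct Entry { int index; float fvalue; };  // stand-in for xgboost::Entry

    // Process a CSR-style sparse page in fixed-size row chunks, sizing the
    // device buffers per chunk instead of for the whole page.
    void CompressInBatches(const std::vector<size_t>& offset_vec,  // n_rows + 1 offsets
                           const std::vector<Entry>& data_vec) {
      const size_t kGpuBatchRows = 1024;
      const size_t n_rows = offset_vec.size() - 1;
      for (size_t begin = 0; begin < n_rows; begin += kGpuBatchRows) {
        const size_t end = std::min(begin + kGpuBatchRows, n_rows);
        // Entry count for this chunk, computed from host memory only.
        const size_t n_entries = offset_vec[end] - offset_vec[begin];
        // Device buffers sized for this chunk and freed at scope exit, so peak
        // device memory is bounded by the chunk, not by the whole page.
        thrust::device_vector<size_t> row_ptrs(offset_vec.begin() + begin,
                                               offset_vec.begin() + end + 1);
        thrust::device_vector<Entry> entries_d(
            data_vec.begin() + offset_vec[begin],
            data_vec.begin() + offset_vec[begin] + n_entries);
        // ... launch CompressBinEllpackKernel over rows [begin, end) here ...
      }
    }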
40 changes: 28 additions & 12 deletions tests/cpp/tree/test_gpu_hist.cu
@@ -477,7 +477,7 @@ TEST(GpuHist, SortPosition) {
   TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
 }
 
-TEST(GpuHist, TestHistogramIndex) {
+void TestHistogramIndexImpl(int n_gpus) {
   // Test if the compressed histogram index matches when using a sparse
   // dmatrix with and without using external memory
 
@@ -491,31 +491,47 @@ TEST(GpuHist, TestHistogramIndex) {
       CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true));
 
   std::vector<std::pair<std::string, std::string>> training_params = {
-    {"max_depth", "1"},
+    {"max_depth", "10"},
     {"max_leaves", "0"}
   };
 
-  LearnerTrainParam learner_param(CreateEmptyGenericParam(0, 1));
+  LearnerTrainParam learner_param(CreateEmptyGenericParam(0, n_gpus));
   hist_maker.Init(training_params, &learner_param);
   hist_maker.InitDataOnce(hist_maker_dmat.get());
   hist_maker_ext.Init(training_params, &learner_param);
   hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
 
+  ASSERT_EQ(hist_maker.shards_.size(), hist_maker_ext.shards_.size());
+
   // Extract the device shards from the histogram makers and from that its compressed
   // histogram index
-  const auto &dev_shard = hist_maker.shards_[0];
-  std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
+  for (size_t i = 0; i < hist_maker.shards_.size(); ++i) {
+    const auto &dev_shard = hist_maker.shards_[i];
+    std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
+    dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
 
-  const auto &dev_shard_ext = hist_maker_ext.shards_[0];
-  std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
-  dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
+    const auto &dev_shard_ext = hist_maker_ext.shards_[i];
+    std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
+    dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
 
-  ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
-  ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
+    ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
+    ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
 
-  ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
+    ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
+  }
+}
+
+TEST(GpuHist, TestHistogramIndex) {
+  TestHistogramIndexImpl(1);
 }
 
+#if defined(XGBOOST_USE_NCCL)
+TEST(GpuHist, MGPU_TestHistogramIndex) {
+  auto devices = GPUSet::AllVisible();
+  CHECK_GT(devices.Size(), 1);
+  TestHistogramIndexImpl(-1);
+}
+#endif
+
 }  // namespace tree
 }  // namespace xgboost
