Skip to content

Commit

Permalink
Group aware GPU weighted sketching.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 19, 2020
1 parent d6d1035 commit 19abc7b
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 110 deletions.
178 changes: 108 additions & 70 deletions src/common/hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -138,28 +138,62 @@ void GetColumnSizesScan(int device,
* \param column_sizes_scan Describes the boundaries of column segments in
* sorted data
*/
void ExtractCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
Span<size_t> column_sizes_scan) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
void ExtractCuts(int device,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size);
min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;

Span<Entry> column_entries =
Span<Entry const> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);

size_t rank = (column_entries.size() * cut_idx) / num_available_cuts;
auto value = column_entries[rank].fvalue;
cuts[idx] = SketchEntry(rank, rank + 1, 1, value);
size_t rank = (column_entries.size() * cut_idx) /
static_cast<float>(num_available_cuts);
if (cut_idx == num_available_cuts - 1) {
// Because the cut values represent upper bounds of bins, we need to place special
// treatment for last cut. When a data equals to the greates value of cuts, last
// cut value of the same column is retured.
rank = column_entries.size() - 1;
}
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
column_entries[rank].fvalue);
});
}

// Builds unweighted quantile cuts for the entry slice [begin, end) of a
// SparsePage and pushes them into the per-column sketches.
//
// \param device           CUDA device ordinal.
// \param page             Source page; its data is copied from host to device.
// \param begin, end       Half-open range of entry indices to process.
// \param sketch_container Destination sketches, updated on the host.
// \param num_cuts         Number of cut slots per feature/column.
// \param num_columns      Total number of feature columns.
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
                  SketchContainer* sketch_container, int num_cuts,
                  size_t num_columns) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  const auto& host_data = page.data.ConstHostVector();
  // Copy the slice onto the device and sort so that each column becomes a
  // contiguous, value-ordered segment (see EntryCompareOp).
  dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
                                                  host_data.begin() + end);
  thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
               sorted_entries.end(), EntryCompareOp());

  // Segment boundaries of each column within the sorted entries.
  dh::caching_device_vector<size_t> column_sizes_scan;
  GetColumnSizesScan(device, &column_sizes_scan,
                     {sorted_entries.data().get(), sorted_entries.size()},
                     num_columns);
  thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);

  // Pick evenly spaced ranks from every column segment as cut candidates.
  dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
  ExtractCuts(device, num_cuts,
              dh::ToSpan(sorted_entries),
              dh::ToSpan(column_sizes_scan),
              dh::ToSpan(cuts));

  // add cuts into sketches
  thrust::host_vector<SketchEntry> host_cuts(cuts);
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}

/**
* \brief Extracts the cuts from sorted data, considering weights.
*
Expand All @@ -170,26 +204,27 @@ void ExtractCuts(int device, Span<SketchEntry> cuts,
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
size_t num_cuts_per_feature, Span<Entry> sorted_data,
void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan) {
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(size_t(num_cuts_per_feature), column_size);
min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;

Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
Span<float> column_weights =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);

float total_column_weight = column_weights.back();
Span<float> column_weights_scan =
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
float total_column_weight = column_weights_scan.back();
size_t sample_idx = 0;
if (cut_idx == 0) {
// First cut
Expand All @@ -204,70 +239,68 @@ void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
} else {
bst_float rank = (total_column_weight * cut_idx) /
static_cast<float>(num_available_cuts);
sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(),
column_weights.end(), rank) -
column_weights.begin() - 1;
sample_idx = thrust::upper_bound(thrust::seq,
column_weights_scan.begin(),
column_weights_scan.end(),
rank) -
column_weights_scan.begin();
sample_idx =
max(size_t(0), min(sample_idx, column_entries.size() - 1));
max(static_cast<size_t>(0),
min(sample_idx, column_entries.size() - 1));
}
// repeated values will be filtered out on the CPU
bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0;
bst_float rmax = column_weights[sample_idx];
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0;
bst_float rmax = column_weights_scan[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
column_entries[sample_idx].fvalue);
});
}

// Builds unweighted quantile cuts for the entry slice [begin, end) of a
// SparsePage and pushes them into the per-column sketches.
//
// NOTE(review): this copy invoked ExtractCuts with the obsolete argument
// order (output span first); updated to the refactored interface used by the
// other call sites in this file (inputs first, output span last, dh::ToSpan).
//
// \param device           CUDA device ordinal.
// \param page             Source page; its data is copied from host to device.
// \param begin, end       Half-open range of entry indices to process.
// \param sketch_container Destination sketches, updated on the host.
// \param num_cuts         Number of cut slots per feature/column.
// \param num_columns      Total number of feature columns.
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
                  SketchContainer* sketch_container, int num_cuts,
                  size_t num_columns) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  const auto& host_data = page.data.ConstHostVector();
  // Sort the device copy so each column forms a contiguous, value-ordered
  // segment.
  dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
                                                  host_data.begin() + end);
  thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
               sorted_entries.end(), EntryCompareOp());

  dh::caching_device_vector<size_t> column_sizes_scan;
  GetColumnSizesScan(device, &column_sizes_scan,
                     {sorted_entries.data().get(), sorted_entries.size()},
                     num_columns);
  thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);

  dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
  ExtractCuts(device, num_cuts,
              dh::ToSpan(sorted_entries),
              dh::ToSpan(column_sizes_scan),
              dh::ToSpan(cuts));

  // add cuts into sketches
  thrust::host_vector<SketchEntry> host_cuts(cuts);
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}

void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts,
size_t num_columns) {
size_t num_columns,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
host_data.begin() + end);

// Binary search to assign weights to each element
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = temp_weights.data().get();
page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan();
size_t base_rowid = page.base_rowid;
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
d_temp_weights[idx] = weights[ridx + base_rowid];
});
if (is_ranking) {
CHECK_GE(d_group_ptr.size(), 1)
<< "Must have at least 1 group for ranking.";
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
auto it =
thrust::upper_bound(thrust::seq,
d_group_ptr.cbegin(), d_group_ptr.cend(),
ridx + base_rowid) - 1;
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
d_temp_weights[idx] = weights[group];
});
} else {
CHECK_EQ(weights.size(), page.offset.Size() - 1);
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
row_ptrs.end(), element_idx) -
row_ptrs.begin() - 1;
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}

// Sort
// Sort both entries and weights.
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(),
EntryCompareOp());
Expand All @@ -288,11 +321,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,

// Extract cuts
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractWeightedCuts(
device, {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{temp_weights.data().get(), temp_weights.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
ExtractWeightedCuts(device, num_cuts,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));

// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
Expand Down Expand Up @@ -320,13 +353,17 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
auto const& info = dmat->Info();
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.size());
thrust::copy(info.group_ptr_.cbegin(), info.group_ptr_.cend(), groups.begin());
for (auto begin = 0ull; begin < batch_nnz;
begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
if (has_weights) {
bool is_ranking = CutsBuilder::UseGroup(dmat);
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
&sketch_container, num_cuts, dmat->Info().num_col_);
&sketch_container, num_cuts, dmat->Info().num_col_, is_ranking, dh::ToSpan(groups));
} else {
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts,
dmat->Info().num_col_);
Expand Down Expand Up @@ -383,9 +420,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,

// Extract the cuts from all columns concurrently
dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts,
{sorted_entries.data().get(), sorted_entries.size()},
{column_sizes_scan.data().get(), column_sizes_scan.size()});
ExtractCuts(adapter->DeviceIdx(), num_cuts,
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));

// Push cuts into sketches stored in host memory
thrust::host_vector<SketchEntry> host_cuts(cuts);
Expand Down
4 changes: 2 additions & 2 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ class HistogramCuts {
class CutsBuilder {
public:
using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);

protected:
HistogramCuts* p_cuts_;
/* \brief return whether group for ranking is used. */
static bool UseGroup(DMatrix* dmat);

public:
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}
Expand Down
50 changes: 39 additions & 11 deletions tests/cpp/common/test_hist_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@

#include <algorithm>
#include <cmath>


#include <thrust/device_vector.h>

#include "xgboost/c_api.h"
#include <xgboost/data.h>
#include <xgboost/c_api.h>

#include "test_hist_util.h"
#include "../helpers.h"
#include "../data/test_array_interface.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.h"

#include "../helpers.h"
#include <xgboost/data.h>
#include "../../../src/data/device_adapter.cuh"
#include "../data/test_array_interface.h"
#include "../../../src/common/math.h"
#include "../../../src/data/simple_dmatrix.h"
#include "test_hist_util.h"
#include "../../../include/xgboost/logging.h"

namespace xgboost {
Expand Down Expand Up @@ -190,8 +187,7 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
}
}

TEST(HistUtil, AdapterDeviceSketch)
{
TEST(HistUtil, AdapterDeviceSketch) {
int rows = 5;
int cols = 1;
int num_bins = 4;
Expand Down Expand Up @@ -235,7 +231,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
}

TEST(HistUtil, AdapterDeviceSketchCategorical) {
TEST(HistUtil, AdapterDeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256;
int sizes[] = {25, 100, 1000};
Expand Down Expand Up @@ -268,6 +264,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
}
}
}

TEST(HistUtil, AdapterDeviceSketchBatches) {
int num_bins = 256;
int num_rows = 5000;
Expand Down Expand Up @@ -305,7 +302,38 @@ TEST(HistUtil, SketchingEquivalent) {
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
}
}
}

// Verifies that group-aware weighted sketching with uniform unit weights
// produces exactly the same cuts as the unweighted code path.
TEST(HistUtil, DeviceSketchFromGroupWeights) {
  size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
  size_t constexpr kGroups = 10;
  auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix();
  // Uniform weight of 1 per group: the weighted sketch must reduce to the
  // unweighted one.
  auto& h_weights = m->Info().weights_.HostVector();
  h_weights.resize(kRows);
  std::fill(h_weights.begin(), h_weights.end(), 1.0f);
  // Equally sized query groups exercise the ranking (group-aware) path.
  std::vector<bst_group_t> groups(kGroups, kRows / kGroups);
  m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);

  // Clearing the weights switches DeviceSketch to the unweighted path.
  h_weights.clear();
  HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0);

  ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
  ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
  ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size());

  // Use ASSERT_EQ with an index message consistently in all three loops
  // (the original mixed EXPECT_EQ and bare ASSERT_EQ).
  for (size_t i = 0; i < cuts.Values().size(); ++i) {
    ASSERT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:" << i;
  }
  for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
    ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]) << "i:" << i;
  }
  for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
    ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i)) << "i:" << i;
  }
}
} // namespace common
} // namespace xgboost
Loading

0 comments on commit 19abc7b

Please sign in to comment.