diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 3d910235ea27..d74d1e33a2e5 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -16,10 +16,11 @@ #include #include #include + #include "./base.h" + #include "../../src/common/span.h" #include "../../src/common/group_data.h" - #include "../../src/common/host_device_vector.h" namespace xgboost { @@ -34,6 +35,17 @@ enum DataType { kUInt64 = 4 }; +typedef unsigned char foreign_valid_type; +typedef int foreign_size_type; + +struct ForeignColumn { + void * data; + foreign_valid_type * valid; + foreign_size_type size; + foreign_size_type num_nonzero; + foreign_size_type null_count; +}; + /*! * \brief Meta information about dataset, always sit in memory. */ @@ -122,6 +134,13 @@ class MetaInfo { * \param num Number of elements in the source array. */ void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num); + /*! + * \brief Set information in the meta info for foreign columns buffer. + * \param key The key of the information. + * \param cols The foreign columns buffer used to set the info. + * \param n_cols The number of foreign columns. + */ + void SetInfo(const char * key, ForeignColumn ** cols, foreign_size_type n_cols); private: /*! \brief argsort of labels */ @@ -151,6 +170,14 @@ struct Entry { } }; +struct ForeignCSR { + Entry * data; + size_t * offsets; + size_t n_nonzero; + size_t n_cols; + size_t n_rows; +}; + /*! * \brief In-memory storage unit of sparse batch, stored in CSR format. */ diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index a2389e7fab2d..ec383ecdb6e1 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -526,6 +526,33 @@ def _init_from_dt(self, data, nthread): nthread)) self.handle = handle + def _init_from_columnar(df): + '''Initialize DMatrix from columnar memory format. For now assuming + it's cudf.DataFrame. + + ''' + print('_init_from_columnar') + col_ptrs = [ctypes.c_void_p(df[col]._column._pointer) for col in + df.columns] + validity_masks = [] + for col in df.columns: + if df[col].has_null_mask: + validity_masks.append(df[col].nullmask) + else: + validity_masks.append(False) + col_pairs = list(zip(col_ptrs, validity_masks)) + interfaces = [] + for pointers in col_pairs: + col = {'data': (pointers[0], False)} + if pointers[1] is not False: + col['mask'] = pointers[1].mem.__cuda_array_interface__ + else: + col['mask'] = '' + interfaces.append(str(col)) + print(col) + interfaces = from_pystr_to_cstr(interfaces) + return interfaces + def __del__(self): if hasattr(self, "handle") and self.handle is not None: _check_call(_LIB.XGDMatrixFree(self.handle)) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 601aad994f46..31021c69ab68 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -189,6 +190,29 @@ int XGDMatrixCreateFromDataIter( API_END(); } +int XGDMatrixCreateFromForeignColumns(ForeignColumn ** cols, + foreign_size_type n_cols, + DMatrixHandle * out) { + API_BEGIN(); + std::unique_ptr source(new data::SimpleCSRSource()); + source->CopyFrom(cols, n_cols); + *out = new std::shared_ptr(DMatrix::Create(std::move(source))); + API_END(); +} + +XGB_DLL int XGDMatrixCreateFromArrayInterface(char** c_json_strs, size_t n_columns, DMatrixHandle* out) { + API_BEGIN(); + std::vector json_strs; + for (size_t i = 0; i < n_columns; ++i) { + json_strs.emplace_back(c_json_strs[i]); + } + std::vector interfaces(n_columns); + for (size_t i = 0; i < n_columns; ++i) { + interfaces[i] = Json::Load({json_strs[i].c_str(), json_strs[i].size()}); + } + API_END(); +} + XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr, const unsigned* indices, const bst_float* data, @@ -689,6 +713,17 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, API_END(); } +XGB_DLL int XGDMatrixSetForeignInfo(DMatrixHandle handle, + const char * field, + ForeignColumn ** cols, + foreign_size_type n_cols) { + API_BEGIN(); + CHECK_HANDLE(); + static_cast*>(handle) + ->get()->Info().SetInfo(field, cols, n_cols); + API_END(); +} + XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char* field, const unsigned* info, diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 06bbb0f24f75..21ef21e3e108 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -470,4 +470,4 @@ size_t DeviceSketch } } // namespace common -} // namespace xgboost +} // namespace xgboost \ No newline at end of file diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index dad344806157..ed0a076e3136 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -623,4 +623,4 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; -} // namespace xgboost +} // namespace xgboost \ No newline at end of file diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h index 0ffe75f5ef9c..47f9da1b1333 100644 --- a/src/common/host_device_vector.h +++ b/src/common/host_device_vector.h @@ -1,7 +1,6 @@ /*! * Copyright 2017-2019 XGBoost contributors */ - /** * @file host_device_vector.h * @brief A device-and-host vector abstraction layer. diff --git a/src/data/data.cc b/src/data/data.cc index dc013c767f54..188c0d0924e2 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -146,7 +146,6 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t } } - DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, diff --git a/src/data/data.cu b/src/data/data.cu new file mode 100644 index 000000000000..28e94057ca73 --- /dev/null +++ b/src/data/data.cu @@ -0,0 +1,67 @@ +/*! + * Copyright 2019 by XGBoost Contributors + * \file data.cuh + * \brief An extension for the data interface to support foreign columnar data buffers + This file adds the necessary functions to fill the meta information for the columnar buffers + * \author Andrey Adinets + * \author Matthew Jones + */ +#include +#include + +#include "../common/device_helpers.cuh" +#include "../common/host_device_vector.h" + +namespace xgboost { + +__global__ void ReadColumn(ForeignColumn * col, + foreign_size_type n_cols, + float * data) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + foreign_size_type n_rows = col->size; + if (n_rows <= tid) { + return; + } else { + float * d = (float *) (col->data); + data[n_cols * tid] = float(d[tid]); + } +} + +void SetInfoFromForeignColumns(MetaInfo * info, + const char * key, + ForeignColumn ** cols, + foreign_size_type n_cols) { + CHECK_GT(n_cols, 0); + foreign_size_type n_rows = cols[0]->size; + for (foreign_size_type i = 0; i < n_cols; ++i) { + CHECK_EQ(n_rows, cols[i]->size) << "all foreign columns must be the same size"; + CHECK_EQ(cols[i]->null_count, 0) << "all labels and weights must be valid"; + } + HostDeviceVector * field = nullptr; + if(!strcmp(key, "label")) { + field = &info->labels_; + } else if (!strcmp(key, "weight")) { + CHECK_EQ(n_cols, 1) << "only one foreign column permitted for weights"; + field = &info->weights_; + } else { + LOG(WARNING) << key << ": invalid key value for MetaInfo field"; + } + + GPUSet devices = GPUSet::Range(0, 1); + field->Reshard(GPUDistribution::Granular(devices, n_cols)); + field->Resize(n_cols * n_rows); + bst_float * data = field->DevicePointer(0); + + int threads = 1024; + int blocks = (n_rows + threads - 1) / threads; + + for (foreign_size_type i = 0; i < n_cols; ++i) { + ReadColumn <<>> (cols[i], n_cols, data + i); + dh::safe_cuda(cudaGetLastError()); + } +} + +void MetaInfo::SetInfo(const char * key, ForeignColumn ** cols, foreign_size_type n_cols) { + SetInfoFromForeignColumns(this, key, cols, n_cols); +} +} // namespace xgboost \ No newline at end of file diff --git a/src/data/simple_csr_source.cu b/src/data/simple_csr_source.cu new file mode 100644 index 000000000000..b12da6a510ff --- /dev/null +++ b/src/data/simple_csr_source.cu @@ -0,0 +1,127 @@ +/*! + * Copyright 2019 by XGBoost Contributors + * \file simple_csr_source.cuh + * \brief An extension for the simple CSR source in-memory data structure to accept + foreign columnar data buffers, and convert them to XGBoost's internal DMatrix + * \author Andrey Adinets + * \author Matthew Jones + */ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "simple_csr_source.h" + +namespace xgboost { +namespace data { + +__device__ int which_bit (int bit) { + return bit % 8; +} +__device__ int which_bitmap (int record) { + return record / 8; +} + +__device__ int check_bit (foreign_valid_type bitmap, int bid) { + foreign_valid_type bitmask[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return bitmap & bitmask[bid]; +} + +__device__ bool is_valid(foreign_valid_type * valid, int tid) { + if (valid == nullptr) { + return true; + } + int bmid = which_bitmap(tid); + int bid = which_bit(tid); + foreign_valid_type bitmap = valid[bmid]; + return check_bit(bitmap, bid); +} + +__global__ void CountValid(foreign_valid_type * valid, + foreign_size_type n_rows, + foreign_size_type n_cols, + size_t * offsets) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + if (n_rows <= tid) { + return; + } else if (is_valid(valid, tid)) { + ++offsets[tid]; + } +} + +__global__ void CreateCSR(ForeignColumn * col, int col_idx, ForeignCSR * csr) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + if (col->size <= tid) { + return; + } else if (is_valid(col->valid, tid)) { + foreign_size_type oid = csr->offsets[tid]; + float * d = (float *) (col->data); + csr->data[oid].fvalue = float(d[tid]); + csr->data[oid].index = col_idx; + ++csr->offsets[tid]; + } +} + +void ForeignColsToCSR(ForeignColumn ** cols, foreign_size_type n_cols, ForeignCSR * csr) { + foreign_size_type n_rows = cols[0]->size; + int threads = 1024; + int blocks = (n_rows + threads - 1) / threads; + + dh::safe_cuda(cudaMemset(csr->offsets, 0, sizeof(foreign_size_type) * (n_rows + 1))); + if (0 < blocks) { + for (foreign_size_type i = 0 ; i < n_cols; ++i) { + CountValid <<>> (cols[i]->valid, n_rows, n_cols, csr->offsets); + dh::safe_cuda(cudaGetLastError()); + dh::safe_cuda(cudaDeviceSynchronize()); + } + + thrust::device_ptr offsets(csr->offsets); + int64_t n_valid = thrust::reduce(offsets, offsets + n_rows, 0ull, thrust::plus()); + thrust::exclusive_scan(offsets, offsets + n_rows + 1, offsets); + + csr->n_nonzero = n_valid; + csr->n_rows = n_rows; + csr->n_cols = n_cols; + + for (foreign_size_type i = 0; i < n_cols; ++i) { + CreateCSR <<>> (cols[i], i, csr); + } + } +} + +void SimpleCSRSource::CopyFrom(ForeignColumn ** cols, foreign_size_type n_cols) { + CHECK_GT(n_cols, 0); + foreign_size_type n_valid = 0; + for (foreign_size_type i = 0; i < n_cols; ++i) { + CHECK_EQ(cols[0]->size, cols[i]->size); + n_valid += cols[i]->size - cols[i]->null_count; + } + + info.num_col_ = n_cols; + info.num_row_ = cols[0]->size; + info.num_nonzero_ = n_valid; + + GPUSet devices = GPUSet::Range(0, 1); + page_.offset.Reshard(GPUDistribution::Overlap(devices, 1)); + page_.offset.Resize(cols[0]->size + 1); + + std::vector device_offsets{0, (size_t) n_valid}; + page_.data.Reshard(GPUDistribution::Explicit(devices, device_offsets)); + page_.data.Reshard(GPUDistribution::Overlap(devices, 1)); + page_.data.Resize(n_valid); + + ForeignCSR csr; + csr.data = page_.data.DevicePointer(0); + csr.offsets = page_.offset.DevicePointer(0); + + ForeignColsToCSR(cols, n_cols, &csr); +} + +} // namespace data +} // namespace xgboost diff --git a/src/data/simple_csr_source.h b/src/data/simple_csr_source.h index 22c87d2681ec..9a657a1c990b 100644 --- a/src/data/simple_csr_source.h +++ b/src/data/simple_csr_source.h @@ -35,7 +35,7 @@ class SimpleCSRSource : public DataSource { /*! \brief destructor */ ~SimpleCSRSource() override = default; /*! \brief clear the data structure */ - void Clear(); + void Clear(); /*! * \brief copy content of data from src * \param src source data iter. @@ -47,6 +47,12 @@ class SimpleCSRSource : public DataSource { * \param info The additional information reflected in the parser. */ void CopyFrom(dmlc::Parser* src); + /*! + * \brief copy content of data from foreign columns buffer. + * \param cols foreign columns data buffer. + * \param n_cols the number of foreign columns. + */ + void CopyFrom(ForeignColumn ** cols, foreign_size_type n_cols); /*! * \brief Load data from binary stream. * \param fi the pointer to load data from. diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index e1e236945627..b51d20aa5991 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -1549,4 +1549,4 @@ XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") #endif // !defined(GTEST_TEST) } // namespace tree -} // namespace xgboost +} // namespace xgboost \ No newline at end of file diff --git a/tests/python-gpu/test_gpu_gdf.py b/tests/python-gpu/test_gpu_gdf.py new file mode 100644 index 000000000000..1a3dbdc61dd3 --- /dev/null +++ b/tests/python-gpu/test_gpu_gdf.py @@ -0,0 +1,44 @@ +import numpy as np +import pandas as pd +try: + import cudf.dataframe as gdf +except ImportError as e: + print("Failed to import cuDF: " + str(e)) + print("Skipping this test") +from sklearn import datasets +import sys +import unittest +import xgboost as xgb + +from regression_test_utilities import run_suite, parameter_combinations, \ + assert_results_non_increasing, Dataset + + +def get_gdf(): + rng = np.random.RandomState(199) + n = 50000 + m = 20 + sparsity = 0.25 + X, y = datasets.make_regression(n, m, random_state=rng) + Xy = (np.ascontiguousarray + (np.transpose(np.concatenate((X, np.expand_dims(y, axis=1)), axis=1)))) + df = gdf.DataFrame(list(zip(['col%d' % i for i in range(m+1)], Xy))) + all_columns = list(df.columns) + cols_X = all_columns[0:len(all_columns)-1] + cols_y = [all_columns[len(all_columns)-1]] + return df[cols_X], df[cols_y] + + +class TestGPU(unittest.TestCase): + + gdf_datasets = [Dataset("GDF", get_gdf, "reg:linear", "rmse")] + + def test_gdf(self): + variable_param = {'n_gpus': [1], 'max_depth': [10], 'max_leaves': [255], + 'max_bin': [255], + 'grow_policy': ['lossguide']} + for param in parameter_combinations(variable_param): + param['tree_method'] = 'gpu_hist' + gpu_results = run_suite(param, num_rounds=20, + select_datasets=self.gdf_datasets) + assert_results_non_increasing(gpu_results, 1e-2)