-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement incremental building of device DMatrix.
* Add new iterative DMatrix. * Add new proxy DMatrix. * Add dask interface.
- Loading branch information
1 parent
67d267f
commit 6c33130
Showing
25 changed files
with
1,117 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
'''A demo for defining data iterator. | ||
The demo that defines a customized iterator for passing batches of data into | ||
`xgboost.DMatrix` and use this `DMatrix` for training. | ||
Aftering going through the demo, one might ask why don't we use more native | ||
Python iterator? That's because XGBoost require a `reset` function, while | ||
using `itertools.tee` might incur significant memory usage according to: | ||
https://docs.python.org/3/library/itertools.html#itertools.tee. | ||
''' | ||
|
||
import xgboost | ||
import cupy | ||
import numpy | ||
|
||
COLS = 64 | ||
ROWS_PER_BATCH = 100 # data is splited by rows | ||
BATCHES = 16 | ||
|
||
|
||
class IterForDMatixDemo(xgboost.core.DataIter): | ||
'''A data iterator for XGBoost DMatrix. | ||
`reset` and `next` are required for any data iterator, other functions here | ||
are utilites for demonstration's purpose. | ||
''' | ||
def __init__(self): | ||
'''Generate some random data for demostration. | ||
Actual data can be anything that is currently supported by XGBoost. | ||
''' | ||
self.rows = ROWS_PER_BATCH | ||
self.cols = COLS | ||
rng = cupy.random.RandomState(1994) | ||
self._data = [rng.randn(self.rows, self.cols)] * BATCHES | ||
self._labels = [rng.randn(self.rows)] * BATCHES | ||
|
||
self.it = 0 # set iterator to 0 | ||
super().__init__() | ||
|
||
def as_array(self): | ||
return cupy.concatenate(self._data) | ||
|
||
def as_array_labels(self): | ||
return cupy.concatenate(self._labels) | ||
|
||
def data(self): | ||
'''Utility function for obtaining current batch of data.''' | ||
return self._data[self.it] | ||
|
||
def labels(self): | ||
'''Utility function for obtaining current batch of label.''' | ||
return self._labels[self.it] | ||
|
||
def reset(self): | ||
'''Reset the iterator''' | ||
self.it = 0 | ||
|
||
def next(self, input_data): | ||
'''Yield next batch of data''' | ||
if self.it == len(self._data): | ||
# Return 0 when there's no more batch. | ||
return 0 | ||
input_data(data=self.data(), label=self.labels()) | ||
self.it += 1 | ||
return 1 | ||
|
||
|
||
def main(): | ||
rounds = 100 | ||
it = IterForDMatixDemo() | ||
|
||
# Use iterator | ||
m = xgboost.DMatrix(it) | ||
reg_with_it = xgboost.train({'tree_method': 'gpu_hist'}, m, | ||
num_boost_round=rounds) | ||
predict_with_it = reg_with_it.predict(m) | ||
|
||
# Without using iterator | ||
m = xgboost.DMatrix(it.as_array(), it.as_array_labels()) | ||
reg = xgboost.train({'tree_method': 'gpu_hist'}, m, | ||
num_boost_round=rounds) | ||
predict = reg.predict(m) | ||
|
||
numpy.testing.assert_allclose(predict_with_it, predict) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/*! | ||
* Copyright 2020 by Contributors | ||
* \file callback.h | ||
*/ | ||
#ifndef XGBOOST_C_CALLBACK_H_ | ||
#define XGBOOST_C_CALLBACK_H_ | ||
|
||
#include <xgboost/c/config.h> | ||
|
||
/*! \brief handle to a external data iterator */ | ||
typedef void *DataIterHandle; // NOLINT(*) | ||
/*! \brief handle to a internal data holder. */ | ||
typedef void *DataHolderHandle; // NOLINT(*) | ||
|
||
|
||
/*! \brief Mini batch used in XGBoost Data Iteration */ | ||
typedef struct { // NOLINT(*) | ||
/*! \brief number of rows in the minibatch */ | ||
size_t size; | ||
/* \brief number of columns in the minibatch. */ | ||
size_t columns; | ||
/*! \brief row pointer to the rows in the data */ | ||
#ifdef __APPLE__ | ||
/* Necessary as Java on MacOS defines jlong as long int | ||
* and gcc defines int64_t as long long int. */ | ||
long* offset; // NOLINT(*) | ||
#else | ||
int64_t* offset; // NOLINT(*) | ||
#endif // __APPLE__ | ||
/*! \brief labels of each instance */ | ||
float* label; | ||
/*! \brief weight of each instance, can be NULL */ | ||
float* weight; | ||
/*! \brief feature index */ | ||
int* index; | ||
/*! \brief feature values */ | ||
float* value; | ||
} XGBoostBatchCSR; | ||
|
||
/*! | ||
* \brief Callback to set the data to handle, | ||
* \param handle The handle to the callback. | ||
* \param batch The data content to be set. | ||
*/ | ||
XGB_EXTERN_C typedef int XGBCallbackSetData( // NOLINT(*) | ||
DataHolderHandle handle, XGBoostBatchCSR batch); | ||
|
||
/*! | ||
* \brief The data reading callback function. | ||
* The iterator will be able to give subset of batch in the data. | ||
* | ||
* If there is data, the function will call set_function to set the data. | ||
* | ||
* \param data_handle The handle to the callback. | ||
* \param set_function The batch returned by the iterator | ||
* \param set_function_handle The handle to be passed to set function. | ||
* \return 0 if we are reaching the end and batch is not returned. | ||
*/ | ||
XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*) | ||
DataIterHandle data_handle, XGBCallbackSetData *set_function, | ||
DataHolderHandle set_function_handle); | ||
|
||
/*! | ||
* \brief Create a DMatrix from a data iterator. | ||
* \param data_handle The handle to the data. | ||
* \param callback The callback to get the data. | ||
* \param cache_info Additional information about cache file, can be null. | ||
* \param out The created DMatrix | ||
* \return 0 when success, -1 when failure happens. | ||
*/ | ||
XGB_DLL int XGDMatrixCreateFromDataIter( | ||
DataIterHandle data_handle, | ||
XGBCallbackDataIterNext* callback, | ||
const char* cache_info, | ||
DMatrixHandle *out); | ||
|
||
/*! | ||
* \brief Callback function for getting next batch of data. | ||
*/ | ||
XGB_EXTERN_C typedef int XGDMatrixCallbackNext( // NOLINT(*) | ||
DataIterHandle iter, DMatrixHandle handle); | ||
|
||
/*! | ||
* \brief Callback function for reseting external iterator | ||
*/ | ||
XGB_EXTERN_C typedef void DataIterResetCallback(DataIterHandle); // NOLINT(*) | ||
#endif // XGBOOST_C_CALLBACK_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/*! | ||
* Copyright 2020 by Contributors | ||
* \file config.h | ||
*/ | ||
#ifndef XGBOOST_C_CONFIG_H_ | ||
#define XGBOOST_C_CONFIG_H_ | ||
|
||
#ifdef __cplusplus | ||
#define XGB_EXTERN_C extern "C" | ||
#include <cstdio> | ||
#include <cstdint> | ||
#else | ||
#define XGB_EXTERN_C | ||
#include <stdio.h> | ||
#include <stdint.h> | ||
#endif // __cplusplus | ||
|
||
#if defined(_MSC_VER) || defined(_WIN32) | ||
#define XGB_DLL XGB_EXTERN_C __declspec(dllexport) | ||
#else | ||
#define XGB_DLL XGB_EXTERN_C __attribute__ ((visibility ("default"))) | ||
#endif // defined(_MSC_VER) || defined(_WIN32) | ||
|
||
// manually define unsigned long | ||
typedef uint64_t bst_ulong; // NOLINT(*) | ||
|
||
/*! \brief handle to DMatrix */ | ||
typedef void *DMatrixHandle; // NOLINT(*) | ||
/*! \brief handle to Booster */ | ||
typedef void *BoosterHandle; // NOLINT(*) | ||
|
||
#endif // XGBOOST_C_CONFIG_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.