Skip to content

Commit

Permalink
Add doc for C.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 21, 2022
1 parent 1ecebd1 commit 7ddbe80
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
24 changes: 18 additions & 6 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,17 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
char const *indices, char const *data,
bst_ulong ncol);


XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, void* ptr_schema);

XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out);

/*
* ==========================- End data callback APIs ==========================
*/


XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema);

/*!
* \brief Construct a DMatrix from Arrow data using callbacks. The Arrow-related C API is not
* stable and is subject to change in the future.
*
* \param next Callback function for fetching Arrow records.
* \param json_config JSON-encoded configuration. Required values are:
*
* - missing
* - nthread
*
* \param out The created DMatrix.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out);

/*!
* \brief create a new dmatrix from sliced content of existing matrix
Expand Down
9 changes: 4 additions & 5 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,15 @@ class RecordBatchDataIter:

def __init__(self, data_iter: Iterator) -> None:
from pyarrow.cffi import ffi

self.data_iter = data_iter # an iterator for Arrow record batches
self.c_schemas: List[ffi.CData] = []
self.c_arrays: List[ffi.CData] = []

def reset(self) -> None: # pylint: disable=missing-function-docstring
raise NotImplementedError()

# pylint: disable=missing-function-docstring
def next(self, data_handle: int) -> int:
"Fetch the next batch."
from pyarrow.cffi import ffi

try:
batch = next(self.data_iter)
self.c_schemas.append(ffi.new("struct ArrowSchema*"))
Expand All @@ -502,7 +501,7 @@ def next(self, data_handle: int) -> int:
# pylint: disable=protected-access
batch._export_to_c(ptr_array, ptr_schema)
_check_call(
_LIB.XGImportRecordBatch(
_LIB.XGImportArrowRecordBatch(
ctypes.c_void_p(data_handle),
ctypes.c_void_p(ptr_array),
ctypes.c_void_p(ptr_schema),
Expand Down
9 changes: 5 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,12 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
API_END();
}

XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, void* ptr_schema) {
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
void *ptr_schema) {
API_BEGIN();
static_cast<data::RecordBatchesIterAdapter *>(data_handle)->SetData(
static_cast<struct ArrowArray *>(ptr_array),
static_cast<struct ArrowSchema *>(ptr_schema));
static_cast<data::RecordBatchesIterAdapter *>(data_handle)
->SetData(static_cast<struct ArrowArray *>(ptr_array),
static_cast<struct ArrowSchema *>(ptr_schema));
API_END();
}

Expand Down

0 comments on commit 7ddbe80

Please sign in to comment.