diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 4dff55cfcdb5..17cd5f4af36d 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -502,17 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, char const *indices, char const *data, bst_ulong ncol); - -XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, void* ptr_schema); - -XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config, - DMatrixHandle *out); - /* * ==========================- End data callback APIs ========================== */ +XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema); + +/*! + * \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable + * and subject to change in the future. + * + * \param next Callback function for fetching arrow records. + * \param json_config JSON encoded configuration. Required values are: + * + * - missing + * - nthread + * + * \param out The created DMatrix. + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config, + DMatrixHandle *out); /*! * \brief create a new dmatrix from sliced content of existing matrix diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index f864f30f4203..7a07be4d2d20 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -483,16 +483,15 @@ class RecordBatchDataIter: def __init__(self, data_iter: Iterator) -> None: from pyarrow.cffi import ffi + self.data_iter = data_iter # an iterator for Arrow record batches self.c_schemas: List[ffi.CData] = [] self.c_arrays: List[ffi.CData] = [] - def reset(self) -> None: # pylint: disable=missing-function-docstring - raise NotImplementedError() - - # pylint: disable=missing-function-docstring def next(self, data_handle: int) -> int: + "Fetch the next batch." from pyarrow.cffi import ffi + try: batch = next(self.data_iter) self.c_schemas.append(ffi.new("struct ArrowSchema*")) @@ -502,7 +501,7 @@ def next(self, data_handle: int) -> int: # pylint: disable=protected-access batch._export_to_c(ptr_array, ptr_schema) _check_call( - _LIB.XGImportRecordBatch( + _LIB.XGImportArrowRecordBatch( ctypes.c_void_p(data_handle), ctypes.c_void_p(ptr_array), ctypes.c_void_p(ptr_schema), diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 2529d6924d0d..86d763a6af1e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -416,11 +416,12 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes, API_END(); } -XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, void* ptr_schema) { +XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, + void *ptr_schema) { API_BEGIN(); - static_cast(data_handle)->SetData( - static_cast(ptr_array), - static_cast(ptr_schema)); + static_cast(data_handle) + ->SetData(static_cast(ptr_array), + static_cast(ptr_schema)); API_END(); }