Skip to content

Commit

Permalink
Support building SimpleDMatrix from Arrow data format
Browse files Browse the repository at this point in the history
* Integrate with Arrow C data API.
* Support Arrow dataset.
* Support Arrow table.

Co-authored-by: Xiaochang Wu <xiaochang.wu@intel.com>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
  • Loading branch information
3 people committed Mar 14, 2022
1 parent 4dafb5f commit 2fd751e
Show file tree
Hide file tree
Showing 14 changed files with 732 additions and 10 deletions.
19 changes: 18 additions & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
char const *indices, char const *data,
bst_ulong ncol);


/*
* ==========================- End data callback APIs ==========================
*/


XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema);

/*!
* \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable
* and subject to change in the future.
*
* \param next Callback function for fetching arrow records.
* \param json_config JSON encoded configuration. Required values are:
*
* - missing
* - nthread
*
* \param out The created DMatrix.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out);

/*!
* \brief create a new dmatrix from sliced content of existing matrix
Expand Down
93 changes: 92 additions & 1 deletion python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# pylint: disable=too-many-return-statements, import-error
'''Data dispatching for DMatrix.'''
import ctypes
from distutils import version
import json
import warnings
import os
from typing import Any, Tuple, Callable, Optional, List, Union
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator

import numpy as np

Expand Down Expand Up @@ -466,6 +467,92 @@ def _from_dt_df(
return handle, feature_names, feature_types


def _is_arrow(data) -> bool:
try:
import pyarrow as pa
from pyarrow import dataset as arrow_dataset
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
except ImportError:
return False


def record_batch_data_iter(data_iter: Iterator) -> Callable:
"""Data iterator used to ingest Arrow columnar record batches. We are not using
class DataIter because it is only intended for building Device DMatrix and external
memory DMatrix.
"""
from pyarrow.cffi import ffi

c_schemas: List[ffi.CData] = []
c_arrays: List[ffi.CData] = []

def _next(data_handle: int) -> int:
from pyarrow.cffi import ffi

try:
batch = next(data_iter)
c_schemas.append(ffi.new("struct ArrowSchema*"))
c_arrays.append(ffi.new("struct ArrowArray*"))
ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
# pylint: disable=protected-access
batch._export_to_c(ptr_array, ptr_schema)
_check_call(
_LIB.XGImportArrowRecordBatch(
ctypes.c_void_p(data_handle),
ctypes.c_void_p(ptr_array),
ctypes.c_void_p(ptr_schema),
)
)
return 1
except StopIteration:
return 0

return _next


def _from_arrow(
data,
missing: float,
nthread: int,
feature_names: Optional[List[str]],
feature_types: Optional[List[str]],
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
import pyarrow as pa

if not all(
pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
):
raise ValueError(
"Features in dataset can only be integers or floating point number"
)
if enable_categorical:
raise ValueError("categorical data in arrow is not supported yet.")

major, _, _ = version.StrictVersion(pa.__version__).version
if major == 4:
rb_iter = iter(data.to_batches())
else:
# use_async=True to workaround pyarrow 6.0.1 hang,
# see Modin-3982 and ARROW-15362
rb_iter = iter(data.to_batches(use_async=True))
it = record_batch_data_iter(rb_iter)
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
handle = ctypes.c_void_p()

config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
_check_call(
_LIB.XGDMatrixCreateFromArrowCallback(
next_callback,
config,
ctypes.byref(handle),
)
)
return handle, feature_names, feature_types


def _is_cudf_df(data) -> bool:
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")

Expand Down Expand Up @@ -814,6 +901,9 @@ def dispatch_data_backend(
return _from_pandas_series(
data, missing, threads, enable_categorical, feature_names, feature_types
)
if _is_arrow(data):
return _from_arrow(
data, missing, threads, feature_names, feature_types, enable_categorical)
if _has_array_protocol(data):
array = np.asarray(data)
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
Expand Down Expand Up @@ -954,6 +1044,7 @@ def dispatch_meta_backend(
_meta_from_numpy(data, name, dtype, handle)
return
if _has_array_protocol(data):
# pyarrow goes here.
array = np.asarray(data)
_meta_from_numpy(array, name, dtype, handle)
return
Expand Down
21 changes: 21 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,27 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
API_END();
}

XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
void *ptr_schema) {
API_BEGIN();
static_cast<data::RecordBatchesIterAdapter *>(data_handle)
->SetData(static_cast<struct ArrowArray *>(ptr_array),
static_cast<struct ArrowSchema *>(ptr_schema));
API_END();
}

XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out) {
API_BEGIN();
auto config = Json::Load(StringView{json_config});
auto missing = GetMissing(config);
int32_t n_threads = get<Integer const>(config["nthread"]);
n_threads = common::OmpGetNumThreads(n_threads);
data::RecordBatchesIterAdapter adapter(next, n_threads);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END();
}

XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
Expand Down
Loading

0 comments on commit 2fd751e

Please sign in to comment.