Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FOLLOW-UP] Support for building DMatrix from Apache Arrow data format #7512

Merged
merged 1 commit into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
char const *indices, char const *data,
bst_ulong ncol);


/*
* ==========================- End data callback APIs ==========================
*/


XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema);

/*!
* \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable
* and subject to change in the future.
*
* \param next Callback function for fetching arrow records.
* \param json_config JSON encoded configuration. Required values are:
*
* - missing
* - nthread
*
* \param out The created DMatrix.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out);

/*!
* \brief create a new dmatrix from sliced content of existing matrix
Expand Down
93 changes: 92 additions & 1 deletion python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# pylint: disable=too-many-return-statements, import-error
'''Data dispatching for DMatrix.'''
import ctypes
from distutils import version
import json
import warnings
import os
from typing import Any, Tuple, Callable, Optional, List, Union
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator

import numpy as np

Expand Down Expand Up @@ -466,6 +467,92 @@ def _from_dt_df(
return handle, feature_names, feature_types


def _is_arrow(data) -> bool:
try:
import pyarrow as pa
from pyarrow import dataset as arrow_dataset
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
except ImportError:
return False


def record_batch_data_iter(data_iter: Iterator) -> Callable:
"""Data iterator used to ingest Arrow columnar record batches. We are not using
class DataIter because it is only intended for building Device DMatrix and external
memory DMatrix.

"""
from pyarrow.cffi import ffi

c_schemas: List[ffi.CData] = []
c_arrays: List[ffi.CData] = []

def _next(data_handle: int) -> int:
from pyarrow.cffi import ffi

try:
batch = next(data_iter)
c_schemas.append(ffi.new("struct ArrowSchema*"))
c_arrays.append(ffi.new("struct ArrowArray*"))
ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
# pylint: disable=protected-access
batch._export_to_c(ptr_array, ptr_schema)
_check_call(
_LIB.XGImportArrowRecordBatch(
ctypes.c_void_p(data_handle),
ctypes.c_void_p(ptr_array),
ctypes.c_void_p(ptr_schema),
)
)
return 1
except StopIteration:
return 0

return _next


def _from_arrow(
data,
missing: float,
nthread: int,
feature_names: Optional[List[str]],
feature_types: Optional[List[str]],
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
import pyarrow as pa

if not all(
pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
):
raise ValueError(
"Features in dataset can only be integers or floating point number"
)
if enable_categorical:
raise ValueError("categorical data in arrow is not supported yet.")

major, _, _ = version.StrictVersion(pa.__version__).version
if major == 4:
rb_iter = iter(data.to_batches())
else:
# use_async=True to workaround pyarrow 6.0.1 hang,
# see Modin-3982 and ARROW-15362
rb_iter = iter(data.to_batches(use_async=True))
it = record_batch_data_iter(rb_iter)
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
handle = ctypes.c_void_p()

config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
_check_call(
_LIB.XGDMatrixCreateFromArrowCallback(
next_callback,
config,
ctypes.byref(handle),
)
)
return handle, feature_names, feature_types


def _is_cudf_df(data) -> bool:
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")

Expand Down Expand Up @@ -814,6 +901,9 @@ def dispatch_data_backend(
return _from_pandas_series(
data, missing, threads, enable_categorical, feature_names, feature_types
)
if _is_arrow(data):
return _from_arrow(
data, missing, threads, feature_names, feature_types, enable_categorical)
if _has_array_protocol(data):
array = np.asarray(data)
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
Expand Down Expand Up @@ -954,6 +1044,7 @@ def dispatch_meta_backend(
_meta_from_numpy(data, name, dtype, handle)
return
if _has_array_protocol(data):
# pyarrow goes here.
array = np.asarray(data)
_meta_from_numpy(array, name, dtype, handle)
return
Expand Down
21 changes: 21 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,27 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
API_END();
}

XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
void *ptr_schema) {
API_BEGIN();
static_cast<data::RecordBatchesIterAdapter *>(data_handle)
->SetData(static_cast<struct ArrowArray *>(ptr_array),
static_cast<struct ArrowSchema *>(ptr_schema));
API_END();
}

XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out) {
API_BEGIN();
auto config = Json::Load(StringView{json_config});
auto missing = GetMissing(config);
int32_t n_threads = get<Integer const>(config["nthread"]);
n_threads = common::OmpGetNumThreads(n_threads);
data::RecordBatchesIterAdapter adapter(next, n_threads);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END();
}

XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
Expand Down
Loading