Skip to content

Commit

Permalink
Use a closure-based function instead of a class for the Arrow record-batch data iterator.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 21, 2022
1 parent 7ddbe80 commit 618717d
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,29 +475,26 @@ def _is_arrow(data) -> bool:
return False


class RecordBatchDataIter:
"""Data iterator used to ingest Arrow columnar record batches. We are not
using class DataIter because it is only intended for building Device DMatrix.
def record_batch_data_iter(data_iter: Iterator) -> Callable:
"""Data iterator used to ingest Arrow columnar record batches. We are not using
class DataIter because it is only intended for building Device DMatrix and external
memory DMatrix.
"""
from pyarrow.cffi import ffi

def __init__(self, data_iter: Iterator) -> None:
from pyarrow.cffi import ffi

self.data_iter = data_iter # an iterator for Arrow record batches
self.c_schemas: List[ffi.CData] = []
self.c_arrays: List[ffi.CData] = []
c_schemas: List[ffi.CData] = []
c_arrays: List[ffi.CData] = []

def next(self, data_handle: int) -> int:
"Fetch the next batch."
def _next(data_handle: int) -> int:
from pyarrow.cffi import ffi

try:
batch = next(self.data_iter)
self.c_schemas.append(ffi.new("struct ArrowSchema*"))
self.c_arrays.append(ffi.new("struct ArrowArray*"))
ptr_schema = int(ffi.cast("uintptr_t", self.c_schemas[-1]))
ptr_array = int(ffi.cast("uintptr_t", self.c_arrays[-1]))
batch = next(data_iter)
c_schemas.append(ffi.new("struct ArrowSchema*"))
c_arrays.append(ffi.new("struct ArrowArray*"))
ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
# pylint: disable=protected-access
batch._export_to_c(ptr_array, ptr_schema)
_check_call(
Expand All @@ -511,6 +508,8 @@ def next(self, data_handle: int) -> int:
except StopIteration:
return 0

return _next


def _from_arrow(
data,
Expand Down Expand Up @@ -538,8 +537,8 @@ def _from_arrow(
# use_async=True to workaround pyarrow 6.0.1 hang,
# see Modin-3982 and ARROW-15362
rb_iter = iter(data.to_batches(use_async=True))
it = RecordBatchDataIter(rb_iter)
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it.next)
it = record_batch_data_iter(rb_iter)
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
handle = ctypes.c_void_p()

config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
Expand Down

0 comments on commit 618717d

Please sign in to comment.