diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 7a07be4d2d20..687c1117877c 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -475,29 +475,26 @@ def _is_arrow(data) -> bool:
     return False
 
 
-class RecordBatchDataIter:
-    """Data iterator used to ingest Arrow columnar record batches. We are not
-    using class DataIter because it is only intended for building Device DMatrix.
+def record_batch_data_iter(data_iter: Iterator) -> Callable:
+    """Data iterator used to ingest Arrow columnar record batches. We are not using
+    class DataIter because it is only intended for building Device DMatrix and external
+    memory DMatrix.
 
     """
+    from pyarrow.cffi import ffi
 
-    def __init__(self, data_iter: Iterator) -> None:
-        from pyarrow.cffi import ffi
-
-        self.data_iter = data_iter  # an iterator for Arrow record batches
-        self.c_schemas: List[ffi.CData] = []
-        self.c_arrays: List[ffi.CData] = []
+    c_schemas: List[ffi.CData] = []
+    c_arrays: List[ffi.CData] = []
 
-    def next(self, data_handle: int) -> int:
-        "Fetch the next batch."
+    def _next(data_handle: int) -> int:
         from pyarrow.cffi import ffi
 
         try:
-            batch = next(self.data_iter)
-            self.c_schemas.append(ffi.new("struct ArrowSchema*"))
-            self.c_arrays.append(ffi.new("struct ArrowArray*"))
-            ptr_schema = int(ffi.cast("uintptr_t", self.c_schemas[-1]))
-            ptr_array = int(ffi.cast("uintptr_t", self.c_arrays[-1]))
+            batch = next(data_iter)
+            c_schemas.append(ffi.new("struct ArrowSchema*"))
+            c_arrays.append(ffi.new("struct ArrowArray*"))
+            ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
+            ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
             # pylint: disable=protected-access
             batch._export_to_c(ptr_array, ptr_schema)
             _check_call(
@@ -511,6 +508,8 @@ def next(self, data_handle: int) -> int:
         except StopIteration:
             return 0
 
+    return _next
+
 
 def _from_arrow(
     data,
@@ -538,8 +537,8 @@ def _from_arrow(
     # use_async=True to workaround pyarrow 6.0.1 hang,
     # see Modin-3982 and ARROW-15362
     rb_iter = iter(data.to_batches(use_async=True))
-    it = RecordBatchDataIter(rb_iter)
-    next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it.next)
+    it = record_batch_data_iter(rb_iter)
+    next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
     handle = ctypes.c_void_p()
    config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
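
The patch swaps the RecordBatchDataIter class for a factory function returning a closure, so the object handed to ctypes.CFUNCTYPE is a plain callable rather than a bound method, while the captured c_schemas/c_arrays lists keep the exported Arrow structures alive across calls. The sketch below illustrates that callback pattern in isolation; it is not part of the patch, and make_next_callback, keep_alive, and the simplified batch handling are illustrative assumptions — the real _next exports each record batch through the Arrow C data interface and then calls into the XGBoost C API.

import ctypes
from typing import Callable, Iterator, List


def make_next_callback(data_iter: Iterator) -> Callable[[int], int]:
    """Return a closure suitable for wrapping in ctypes.CFUNCTYPE (illustrative)."""
    keep_alive: List[object] = []  # plays the role of c_schemas/c_arrays in the patch

    def _next(data_handle: int) -> int:
        try:
            batch = next(data_iter)
            # The real implementation exports `batch` via the Arrow C data interface
            # and passes the resulting pointers to the XGBoost C API at this point.
            keep_alive.append(batch)
            return 1  # non-zero: a batch was consumed
        except StopIteration:
            return 0  # zero: iteration is finished

    return _next


# Usage mirroring the patched _from_arrow: wrap the closure as a C callback.
callback_type = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)
next_callback = callback_type(make_next_callback(iter([])))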