diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index cc04e062883f..17cd5f4af36d 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -502,12 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, char const *indices, char const *data, bst_ulong ncol); - /* * ==========================- End data callback APIs ========================== */ +XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema); + +/*! + * \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable + * and subject to change in the future. + * + * \param next Callback function for fetching arrow records. + * \param json_config JSON encoded configuration. Required values are: + * + * - missing + * - nthread + * + * \param out The created DMatrix. + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config, + DMatrixHandle *out); /*! * \brief create a new dmatrix from sliced content of existing matrix diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index adf1cff5c9ab..67a9208fd7ed 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -2,10 +2,11 @@ # pylint: disable=too-many-return-statements, import-error '''Data dispatching for DMatrix.''' import ctypes +from distutils import version import json import warnings import os -from typing import Any, Tuple, Callable, Optional, List, Union +from typing import Any, Tuple, Callable, Optional, List, Union, Iterator import numpy as np @@ -466,6 +467,92 @@ def _from_dt_df( return handle, feature_names, feature_types +def _is_arrow(data) -> bool: + try: + import pyarrow as pa + from pyarrow import dataset as arrow_dataset + return isinstance(data, (pa.Table, arrow_dataset.Dataset)) + except ImportError: + return False + + +def record_batch_data_iter(data_iter: Iterator) -> Callable: + """Data iterator used to ingest Arrow columnar record batches. We are not using + class DataIter because it is only intended for building Device DMatrix and external + memory DMatrix. 
+ + """ + from pyarrow.cffi import ffi + + c_schemas: List[ffi.CData] = [] + c_arrays: List[ffi.CData] = [] + + def _next(data_handle: int) -> int: + from pyarrow.cffi import ffi + + try: + batch = next(data_iter) + c_schemas.append(ffi.new("struct ArrowSchema*")) + c_arrays.append(ffi.new("struct ArrowArray*")) + ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1])) + ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1])) + # pylint: disable=protected-access + batch._export_to_c(ptr_array, ptr_schema) + _check_call( + _LIB.XGImportArrowRecordBatch( + ctypes.c_void_p(data_handle), + ctypes.c_void_p(ptr_array), + ctypes.c_void_p(ptr_schema), + ) + ) + return 1 + except StopIteration: + return 0 + + return _next + + +def _from_arrow( + data, + missing: float, + nthread: int, + feature_names: Optional[List[str]], + feature_types: Optional[List[str]], + enable_categorical: bool, +) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]: + import pyarrow as pa + + if not all( + pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types + ): + raise ValueError( + "Features in dataset can only be integers or floating point number" + ) + if enable_categorical: + raise ValueError("categorical data in arrow is not supported yet.") + + major, _, _ = version.StrictVersion(pa.__version__).version + if major == 4: + rb_iter = iter(data.to_batches()) + else: + # use_async=True to workaround pyarrow 6.0.1 hang, + # see Modin-3982 and ARROW-15362 + rb_iter = iter(data.to_batches(use_async=True)) + it = record_batch_data_iter(rb_iter) + next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it) + handle = ctypes.c_void_p() + + config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8") + _check_call( + _LIB.XGDMatrixCreateFromArrowCallback( + next_callback, + config, + ctypes.byref(handle), + ) + ) + return handle, feature_names, feature_types + + def _is_cudf_df(data) -> bool: return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame") @@ -814,6 +901,9 @@ def dispatch_data_backend( return _from_pandas_series( data, missing, threads, enable_categorical, feature_names, feature_types ) + if _is_arrow(data): + return _from_arrow( + data, missing, threads, feature_names, feature_types, enable_categorical) if _has_array_protocol(data): array = np.asarray(data) return _from_numpy_array(array, missing, threads, feature_names, feature_types) @@ -954,6 +1044,7 @@ def dispatch_meta_backend( _meta_from_numpy(data, name, dtype, handle) return if _has_array_protocol(data): + # pyarrow goes here. 
        array = np.asarray(data)
         _meta_from_numpy(array, name, dtype, handle)
         return
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 25f055f87b5b..86d763a6af1e 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -416,6 +416,27 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
   API_END();
 }
 
+XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
+                                     void *ptr_schema) {
+  API_BEGIN();
+  static_cast<data::RecordBatchesIterAdapter *>(data_handle)
+      ->SetData(static_cast<struct ArrowArray *>(ptr_array),
+                static_cast<struct ArrowSchema *>(ptr_schema));
+  API_END();
+}
+
+XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
+                                             DMatrixHandle *out) {
+  API_BEGIN();
+  auto config = Json::Load(StringView{json_config});
+  auto missing = GetMissing(config);
+  int32_t n_threads = get<Integer const>(config["nthread"]);
+  n_threads = common::OmpGetNumThreads(n_threads);
+  data::RecordBatchesIterAdapter adapter(next, n_threads);
+  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
+  API_END();
+}
+
 XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
                                   const int* idxset,
                                   xgboost::bst_ulong len,
diff --git a/src/data/adapter.h b/src/data/adapter.h
index 4214171e8b7e..4025ccd8e996 100644
--- a/src/data/adapter.h
+++ b/src/data/adapter.h
@@ -13,6 +13,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #include "xgboost/logging.h"
 #include "xgboost/base.h"
@@ -22,6 +24,7 @@
 #include "array_interface.h"
 #include "../c_api/c_api_error.h"
 #include "../common/math.h"
+#include "arrow-cdi.h"
 
 namespace xgboost {
 namespace data {
@@ -676,11 +679,10 @@ class FileAdapter : dmlc::DataIter {
 template
 class IteratorAdapter : public dmlc::DataIter {
  public:
-  IteratorAdapter(DataIterHandle data_handle,
-                  XGBCallbackDataIterNext* next_callback)
-      : columns_{data::kAdapterUnknownSize}, row_offset_{0},
-        at_first_(true),
-        data_handle_(data_handle), next_callback_(next_callback) {}
+  IteratorAdapter(DataIterHandle data_handle, XGBCallbackDataIterNext* next_callback)
+      : columns_{data::kAdapterUnknownSize},
+        data_handle_(data_handle),
+        next_callback_(next_callback) {}
 
   // override functions
   void BeforeFirst() override {
@@ -766,9 +768,9 @@ class IteratorAdapter : public dmlc::DataIter {
   std::vector value_;
   size_t columns_;
-  size_t row_offset_;
+  size_t row_offset_{0};
   // at the beginning.
-  bool at_first_;
+  bool at_first_{true};
   // handle to the iterator,
   DataIterHandle data_handle_;
   // call back to get the data.
@@ -777,6 +779,358 @@ class IteratorAdapter : public dmlc::DataIter { dmlc::RowBlock block_; std::unique_ptr batch_; }; + +enum ColumnDType : uint8_t { + kUnknown, + kInt8, + kUInt8, + kInt16, + kUInt16, + kInt32, + kUInt32, + kInt64, + kUInt64, + kFloat, + kDouble +}; + +class Column { + public: + Column() = default; + + Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap) + : col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {} + + virtual ~Column() = default; + + Column(const Column&) = delete; + Column& operator=(const Column&) = delete; + Column(Column&&) = delete; + Column& operator=(Column&&) = delete; + + // whether the valid bit is set for this element + bool IsValid(size_t row_idx) const { + return (!bitmap_ || (bitmap_[row_idx/8] & (1 << (row_idx%8)))); + } + + virtual COOTuple GetElement(size_t row_idx) const = 0; + + virtual bool IsValidElement(size_t row_idx) const = 0; + + virtual std::vector AsFloatVector() const = 0; + + virtual std::vector AsUint64Vector() const = 0; + + size_t Length() const { return length_; } + + protected: + size_t col_idx_; + size_t length_; + size_t null_count_; + const uint8_t* bitmap_; +}; + +// Only columns of primitive types are supported. An ArrowColumnarBatch is a +// collection of std::shared_ptr. These columns can be of different data types. +// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns +// derive from the abstract class Column. +template +class PrimitiveColumn : public Column { + static constexpr float kNaN = std::numeric_limits::quiet_NaN(); + + public: + PrimitiveColumn(size_t idx, size_t length, size_t null_count, + const uint8_t* bitmap, const T* data, float missing) + : Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {} + + COOTuple GetElement(size_t row_idx) const override { + CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column"; + return { row_idx, col_idx_, IsValidElement(row_idx) ? 
+ static_cast(data_[row_idx]) : kNaN }; + } + + bool IsValidElement(size_t row_idx) const override { + // std::isfinite needs to cast to double to prevent msvc report error + return IsValid(row_idx) + && std::isfinite(static_cast(data_[row_idx])) + && static_cast(data_[row_idx]) != missing_; + } + + std::vector AsFloatVector() const override { + CHECK(data_) << "Column is empty"; + std::vector fv(length_); + std::transform(data_, data_ + length_, fv.begin(), + [](T v) { return static_cast(v); }); + return fv; + } + + std::vector AsUint64Vector() const override { + CHECK(data_) << "Column is empty"; + std::vector iv(length_); + std::transform(data_, data_ + length_, iv.begin(), + [](T v) { return static_cast(v); }); + return iv; + } + + private: + const T* data_; + float missing_; // user specified missing value +}; + +struct ColumnarMetaInfo { + // data type of the column + ColumnDType type{ColumnDType::kUnknown}; + // location of the column in an Arrow record batch + int64_t loc{-1}; +}; + +struct ArrowSchemaImporter { + std::vector columns; + + // map Arrow format strings to types + static ColumnDType FormatMap(char const* format_str) { + CHECK(format_str) << "Format string cannot be empty"; + switch (format_str[0]) { + case 'c': + return ColumnDType::kInt8; + case 'C': + return ColumnDType::kUInt8; + case 's': + return ColumnDType::kInt16; + case 'S': + return ColumnDType::kUInt16; + case 'i': + return ColumnDType::kInt32; + case 'I': + return ColumnDType::kUInt32; + case 'l': + return ColumnDType::kInt64; + case 'L': + return ColumnDType::kUInt64; + case 'f': + return ColumnDType::kFloat; + case 'g': + return ColumnDType::kDouble; + default: + CHECK(false) << "Column data type not supported by XGBoost"; + return ColumnDType::kUnknown; + } + } + + void Import(struct ArrowSchema *schema) { + if (schema) { + CHECK(std::string(schema->format) == "+s"); // NOLINT + CHECK(columns.empty()); + for (auto i = 0; i < schema->n_children; ++i) { + std::string name{schema->children[i]->name}; + ColumnDType type = FormatMap(schema->children[i]->format); + ColumnarMetaInfo col_info{type, i}; + columns.push_back(col_info); + } + if (schema->release) { + schema->release(schema); + } + } + } +}; + +class ArrowColumnarBatch { + public: + ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema) + : rb_{rb}, schema_{schema} { + CHECK(rb_) << "Cannot import non-existent record batch"; + CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema"; + } + + size_t Import(float missing) { + auto& infov = schema_->columns; + for (size_t i = 0; i < infov.size(); ++i) { + columns_.push_back(CreateColumn(i, infov[i], missing)); + } + + // Compute the starting location for every row in this batch + auto batch_size = rb_->length; + auto num_columns = columns_.size(); + row_offsets_.resize(batch_size + 1, 0); + for (auto i = 0; i < batch_size; ++i) { + row_offsets_[i+1] = row_offsets_[i]; + for (size_t j = 0; j < num_columns; ++j) { + if (GetColumn(j).IsValidElement(i)) { + row_offsets_[i+1]++; + } + } + } + // return number of elements in the batch + return row_offsets_.back(); + } + + ArrowColumnarBatch(const ArrowColumnarBatch&) = delete; + ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete; + ArrowColumnarBatch(ArrowColumnarBatch&&) = delete; + ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete; + + virtual ~ArrowColumnarBatch() { + if (rb_ && rb_->release) { + rb_->release(rb_); + rb_ = nullptr; + } + columns_.clear(); + } + + size_t Size() const 
{ return rb_ ? rb_->length : 0; }
+
+  size_t NumColumns() const { return columns_.size(); }
+
+  size_t NumElements() const { return row_offsets_.back(); }
+
+  const Column& GetColumn(size_t col_idx) const {
+    return *columns_[col_idx];
+  }
+
+  void ShiftRowOffsets(size_t batch_offset) {
+    std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
+                   [=](size_t c) { return c + batch_offset; });
+  }
+
+  const std::vector<size_t>& RowOffsets() const { return row_offsets_; }
+
+ private:
+  std::shared_ptr<Column> CreateColumn(size_t idx,
+                                       ColumnarMetaInfo info,
+                                       float missing) const {
+    if (info.loc < 0) {
+      return nullptr;
+    }
+
+    auto loc_in_batch = info.loc;
+    auto length = rb_->length;
+    auto null_count = rb_->null_count;
+    auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
+    auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
+    const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
+    const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;
+
+    // if null_count is not computed, compute it here
+    if (null_count < 0) {
+      if (!bitmap) {
+        null_count = 0;
+      } else {
+        null_count = length;
+        for (auto i = 0; i < length; ++i) {
+          if (bitmap[i/8] & (1 << (i%8))) {
+            null_count--;
+          }
+        }
+      }
+    }
+
+    switch (info.type) {
+      case ColumnDType::kInt8:
+        return std::make_shared<PrimitiveColumn<int8_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const int8_t*>(data), missing);
+      case ColumnDType::kUInt8:
+        return std::make_shared<PrimitiveColumn<uint8_t>>(
+            idx, length, null_count, bitmap, data, missing);
+      case ColumnDType::kInt16:
+        return std::make_shared<PrimitiveColumn<int16_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const int16_t*>(data), missing);
+      case ColumnDType::kUInt16:
+        return std::make_shared<PrimitiveColumn<uint16_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const uint16_t*>(data), missing);
+      case ColumnDType::kInt32:
+        return std::make_shared<PrimitiveColumn<int32_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const int32_t*>(data), missing);
+      case ColumnDType::kUInt32:
+        return std::make_shared<PrimitiveColumn<uint32_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const uint32_t*>(data), missing);
+      case ColumnDType::kInt64:
+        return std::make_shared<PrimitiveColumn<int64_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const int64_t*>(data), missing);
+      case ColumnDType::kUInt64:
+        return std::make_shared<PrimitiveColumn<uint64_t>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const uint64_t*>(data), missing);
+      case ColumnDType::kFloat:
+        return std::make_shared<PrimitiveColumn<float>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const float*>(data), missing);
+      case ColumnDType::kDouble:
+        return std::make_shared<PrimitiveColumn<double>>(
+            idx, length, null_count, bitmap,
+            reinterpret_cast<const double*>(data), missing);
+      default:
+        return nullptr;
+    }
+  }
+
+  struct ArrowArray* rb_;
+  struct ArrowSchemaImporter* schema_;
+  std::vector<std::shared_ptr<Column>> columns_;
+  std::vector<size_t> row_offsets_;
+};
+
+using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
+class RecordBatchesIterAdapter: public dmlc::DataIter<ArrowColumnarBatchVec> {
+ public:
+  RecordBatchesIterAdapter(XGDMatrixCallbackNext *next_callback,
+                           int nthread)
+      : next_callback_{next_callback},
+        nbatches_{nthread} {}
+
+  void BeforeFirst() override {
+    CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
+  }
+
+  bool Next() override {
+    batches_.clear();
+    while (batches_.size() < static_cast<size_t>(nbatches_) && (*next_callback_)(this) != 0) {
+      at_first_ = false;
+    }
+
+    if (batches_.size() > 0) {
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
+    // The schema is only imported once at the beginning, regardless of how many
+    // batches are coming.
+    // But even if the schema is not imported, we still need to release the C data
+    // exported from Arrow.
+ if (at_first_ && schema) { + schema_.Import(schema); + } else { + if (schema && schema->release) { + schema->release(schema); + } + } + if (rb) { + batches_.push_back(std::make_unique(rb, &schema_)); + } + } + + const ArrowColumnarBatchVec& Value() const override { + return batches_; + } + + size_t NumColumns() const { return schema_.columns.size(); } + size_t NumRows() const { return kAdapterUnknownSize; } + + private: + XGDMatrixCallbackNext *next_callback_; + bool at_first_{true}; + int nbatches_; + struct ArrowSchemaImporter schema_; + ArrowColumnarBatchVec batches_; +}; }; // namespace data } // namespace xgboost #endif // XGBOOST_DATA_ADAPTER_H_ diff --git a/src/data/arrow-cdi.h b/src/data/arrow-cdi.h new file mode 100644 index 000000000000..2cb061b3a3cd --- /dev/null +++ b/src/data/arrow-cdi.h @@ -0,0 +1,66 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#ifdef __cplusplus +} +#endif diff --git a/src/data/data.cc b/src/data/data.cc index 3d1e3cc2862d..8f3db4f9f6c0 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1000,6 +1000,8 @@ template DMatrix * DMatrix::Create(data::IteratorAdapter *adapter, float missing, int nthread, const std::string &cache_prefix); +template DMatrix* DMatrix::Create( + data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&); SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { SparsePage transpose; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 754304fb2c5c..7d2ab32c255f 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -249,5 +249,70 @@ template SimpleDMatrix::SimpleDMatrix( IteratorAdapter *adapter, float missing, int nthread); + +template <> +SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) { + auto& offset_vec = sparse_page_->offset.HostVector(); + 
auto& data_vec = sparse_page_->data.HostVector(); + uint64_t total_batch_size = 0; + uint64_t total_elements = 0; + + adapter->BeforeFirst(); + // Iterate over batches of input data + while (adapter->Next()) { + auto& batches = adapter->Value(); + size_t num_elements = 0; + size_t num_rows = 0; + // Import Arrow RecordBatches +#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(nthread) + for (int i = 0; i < static_cast(batches.size()); ++i) { // NOLINT + num_elements += batches[i]->Import(missing); + num_rows += batches[i]->Size(); + } + total_elements += num_elements; + total_batch_size += num_rows; + // Compute global offset for every row and starting row for every batch + std::vector batch_offsets(batches.size()); + for (size_t i = 0; i < batches.size(); ++i) { + if (i == 0) { + batch_offsets[i] = total_batch_size - num_rows; + batches[i]->ShiftRowOffsets(total_elements - num_elements); + } else { + batch_offsets[i] = batch_offsets[i - 1] + batches[i - 1]->Size(); + batches[i]->ShiftRowOffsets(batches[i - 1]->RowOffsets().back()); + } + } + // Pre-allocate DMatrix memory + data_vec.resize(total_elements); + offset_vec.resize(total_batch_size + 1); + // Copy data into DMatrix +#pragma omp parallel num_threads(nthread) + { +#pragma omp for nowait + for (int i = 0; i < static_cast(batches.size()); ++i) { // NOLINT + size_t begin = batches[i]->RowOffsets()[0]; + for (size_t k = 0; k < batches[i]->Size(); ++k) { + for (size_t j = 0; j < batches[i]->NumColumns(); ++j) { + auto element = batches[i]->GetColumn(j).GetElement(k); + if (!std::isnan(element.value)) { + data_vec[begin++] = Entry(element.column_idx, element.value); + } + } + } + } +#pragma omp for nowait + for (int i = 0; i < static_cast(batches.size()); ++i) { + auto& offsets = batches[i]->RowOffsets(); + std::copy(offsets.begin() + 1, offsets.end(), offset_vec.begin() + batch_offsets[i] + 1); + } + } + } + // Synchronise worker columns + info_.num_col_ = adapter->NumColumns(); + rabit::Allreduce(&info_.num_col_, 1); + info_.num_row_ = total_batch_size; + info_.num_nonzero_ = data_vec.size(); + CHECK_EQ(offset_vec.back(), info_.num_nonzero_); +} } // namespace data } // namespace xgboost diff --git a/tests/ci_build/conda_env/aarch64_test.yml b/tests/ci_build/conda_env/aarch64_test.yml index f74833ebf77d..99e8f4840985 100644 --- a/tests/ci_build/conda_env/aarch64_test.yml +++ b/tests/ci_build/conda_env/aarch64_test.yml @@ -26,6 +26,8 @@ dependencies: - awscli - numba - llvmlite +- cffi +- pyarrow - pip: - shap - awscli diff --git a/tests/ci_build/conda_env/cpu_test.yml b/tests/ci_build/conda_env/cpu_test.yml index c82125503678..883471be425c 100644 --- a/tests/ci_build/conda_env/cpu_test.yml +++ b/tests/ci_build/conda_env/cpu_test.yml @@ -33,6 +33,8 @@ dependencies: - numba - llvmlite - py-ubjson +- cffi +- pyarrow - pip: - shap - ipython # required by shap at import time. 
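
Aside: the SimpleDMatrix constructor added in simple_dmatrix.cc above assembles the CSR page in two passes over each group of imported record batches: a counting pass builds the per-row offsets (ArrowColumnarBatch::Import plus ShiftRowOffsets), and a copy pass writes the surviving (column, value) entries into pre-allocated storage so it can run in parallel. The following standalone sketch is not part of the patch and collapses both passes into one serial loop for clarity; `batches` here are plain 2-D NumPy arrays standing in for Arrow record batches, and `assemble_csr` is a hypothetical helper name.

# Illustrative sketch of the CSR assembly logic above; not part of the patch.
import numpy as np

def assemble_csr(batches, missing):
    # Each batch is a 2-D float array (rows x columns); ``missing`` marks absent values.
    n_rows = sum(b.shape[0] for b in batches)
    offsets = np.zeros(n_rows + 1, dtype=np.uint64)
    data = []  # (column_idx, value) pairs, the analogue of xgboost::Entry
    r = 0
    for batch in batches:
        for row in batch:
            # A cell is kept when it is finite and not equal to ``missing``
            # (mirrors PrimitiveColumn::IsValidElement).
            keep = np.isfinite(row) & (row != missing)
            offsets[r + 1] = offsets[r] + keep.sum()
            for col in np.flatnonzero(keep):
                data.append((int(col), float(row[col])))
            r += 1
    return offsets, data

# Example: one 2x2 batch with a missing cell yields offsets [0, 1, 3].
offsets, data = assemble_csr([np.array([[1.0, np.nan], [3.0, 4.0]])], missing=np.nan)

The patch performs the counting and the copy as separate OpenMP loops precisely so the copy can start from a known global offset per batch; the serial version above trades that parallelism for brevity.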
diff --git a/tests/ci_build/conda_env/macos_cpu_test.yml b/tests/ci_build/conda_env/macos_cpu_test.yml index c08d21ca4086..38ac8aa1f421 100644 --- a/tests/ci_build/conda_env/macos_cpu_test.yml +++ b/tests/ci_build/conda_env/macos_cpu_test.yml @@ -33,6 +33,8 @@ dependencies: - boto3 - awscli - py-ubjson +- cffi +- pyarrow - pip: - sphinx_rtd_theme - datatable diff --git a/tests/ci_build/conda_env/win64_cpu_test.yml b/tests/ci_build/conda_env/win64_cpu_test.yml index a8ac47de7ced..7789e94a6fcb 100644 --- a/tests/ci_build/conda_env/win64_cpu_test.yml +++ b/tests/ci_build/conda_env/win64_cpu_test.yml @@ -15,7 +15,8 @@ dependencies: - pytest - jsonschema - hypothesis -- jsonschema - python-graphviz - pip - py-ubjson +- cffi +- pyarrow diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index bf4274ef3d5b..5d761ede8835 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -17,3 +17,5 @@ dependencies: - modin-ray - pip - py-ubjson +- cffi +- pyarrow diff --git a/tests/python/test_with_arrow.py b/tests/python/test_with_arrow.py new file mode 100644 index 000000000000..ad2448294e23 --- /dev/null +++ b/tests/python/test_with_arrow.py @@ -0,0 +1,88 @@ +import unittest +import pytest +import numpy as np +import testing as tm +import xgboost as xgb +import os + +try: + import pyarrow as pa + import pyarrow.csv as pc + import pandas as pd +except ImportError: + pass + +pytestmark = pytest.mark.skipif( + tm.no_arrow()["condition"] or tm.no_pandas()["condition"], + reason=tm.no_arrow()["reason"] + " or " + tm.no_pandas()["reason"], +) + +dpath = "demo/data/" + + +class TestArrowTable(unittest.TestCase): + def test_arrow_table(self): + df = pd.DataFrame( + [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"] + ) + table = pa.Table.from_pandas(df) + dm = xgb.DMatrix(table) + assert dm.num_row() == 2 + assert dm.num_col() == 4 + + def test_arrow_table_with_label(self): + df = pd.DataFrame([[1, 2.0, 3.0], [2, 3.0, 4.0]], columns=["a", "b", "c"]) + table = pa.Table.from_pandas(df) + label = np.array([0, 1]) + dm = xgb.DMatrix(table) + dm.set_label(label) + assert dm.num_row() == 2 + assert dm.num_col() == 3 + np.testing.assert_array_equal(dm.get_label(), np.array([0, 1])) + + def test_arrow_table_from_np(self): + coldata = np.array( + [[1.0, 1.0, 0.0, 0.0], [2.0, 0.0, 1.0, 0.0], [3.0, 0.0, 0.0, 1.0]] + ) + cols = list(map(pa.array, coldata)) + table = pa.Table.from_arrays(cols, ["a", "b", "c"]) + dm = xgb.DMatrix(table) + assert dm.num_row() == 4 + assert dm.num_col() == 3 + + def test_arrow_train(self): + import pandas as pd + + rows = 100 + X = pd.DataFrame( + { + "A": np.random.randint(0, 10, size=rows), + "B": np.random.randn(rows), + "C": np.random.permutation([1, 0] * (rows // 2)), + } + ) + y = pd.Series(np.random.randn(rows)) + table = pa.Table.from_pandas(X) + dtrain1 = xgb.DMatrix(table) + dtrain1.set_label(y) + bst1 = xgb.train({}, dtrain1, num_boost_round=10) + preds1 = bst1.predict(xgb.DMatrix(X)) + dtrain2 = xgb.DMatrix(X, y) + bst2 = xgb.train({}, dtrain2, num_boost_round=10) + preds2 = bst2.predict(xgb.DMatrix(X)) + np.testing.assert_allclose(preds1, preds2) + + def test_arrow_survival(self): + data = os.path.join(tm.PROJECT_ROOT, "demo", "data", "veterans_lung_cancer.csv") + table = pc.read_csv(data) + y_lower_bound = table["Survival_label_lower_bound"] + y_upper_bound = table["Survival_label_upper_bound"] + X = table.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"]) + + 
dtrain = xgb.DMatrix( + X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound + ) + y_np_up = dtrain.get_float_info("label_upper_bound") + y_np_low = dtrain.get_float_info("label_lower_bound") + np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values) + np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values) diff --git a/tests/python/testing.py b/tests/python/testing.py index d2b45bdec30d..64417af42ab9 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -53,6 +53,15 @@ def no_pandas(): 'reason': 'Pandas is not installed.'} +def no_arrow(): + reason = "pyarrow is not installed" + try: + import pyarrow # noqa + return {"condition": False, "reason": reason} + except ImportError: + return {"condition": True, "reason": reason} + + def no_modin(): reason = 'Modin is not installed.' try:
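
For reference, a minimal end-to-end sketch of the new Arrow ingestion path from Python, mirroring the tests above. The column names and random data are illustrative only; the call goes through `_from_arrow` and `XGDMatrixCreateFromArrowCallback` introduced in this patch.

# Minimal usage sketch of the new Arrow path; data and column names are hypothetical.
import numpy as np
import pyarrow as pa
import xgboost as xgb

table = pa.Table.from_arrays(
    [pa.array(np.random.randn(100)), pa.array(np.random.randn(100))],
    names=["f0", "f1"],
)
dtrain = xgb.DMatrix(table)              # dispatched to _from_arrow
dtrain.set_label(np.random.randn(100))   # labels still set separately
booster = xgb.train({}, dtrain, num_boost_round=5)
preds = booster.predict(dtrain)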