From 60526100e3c064adb68f68ed0c391e9cbdf99c53 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 28 Sep 2023 17:00:16 +0800 Subject: [PATCH] Support arrow through pandas ext types. (#9612) - Use pandas extension type for pyarrow support. - Additional support for QDM. - Additional support for inplace_predict. --- doc/python/python_intro.rst | 3 +- include/xgboost/c_api.h | 18 -- python-package/xgboost/core.py | 4 + python-package/xgboost/data.py | 127 +++++--------- src/c_api/c_api.cc | 29 +-- src/data/adapter.h | 301 -------------------------------- src/data/arrow-cdi.h | 66 ------- src/data/data.cc | 3 - src/data/simple_dmatrix.cc | 74 -------- src/data/simple_dmatrix.h | 10 +- tests/python/test_with_arrow.py | 23 ++- 11 files changed, 74 insertions(+), 584 deletions(-) delete mode 100644 src/data/arrow-cdi.h diff --git a/doc/python/python_intro.rst b/doc/python/python_intro.rst index bb74e7bc3..cc0e461e0 100644 --- a/doc/python/python_intro.rst +++ b/doc/python/python_intro.rst @@ -172,9 +172,8 @@ Support Matrix +-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+ | modin.Series | NPA | FF | NPA | NPA | FF | | +-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+ -| pyarrow.Table | T | F | | NPA | FF | | +| pyarrow.Table | NPA | NPA | NPA | NPA | NPA | NPA | +-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+ -| pyarrow.dataset.Dataset | T | F | | | F | | +-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+ | _\_array\_\_ | NPA | F | NPA | NPA | H | | +-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+ diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 5df62df55..63096cb56 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -552,24 +552,6 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, /** @} */ // End of Streaming -XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema); - -/*! - * \brief Construct DMatrix from arrow using callbacks. Arrow related C API is not stable - * and subject to change in the future. - * - * \param next Callback function for fetching arrow records. - * \param config JSON encoded configuration. Required values are: - * - missing: Which value to represent missing value. - * - nbatch: Number of batches in arrow table. - * - nthread (optional): Number of threads used for initializing DMatrix. - * \param out The created DMatrix. - * - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *config, - DMatrixHandle *out); - /*! * \brief create a new dmatrix from sliced content of existing matrix * \param handle instance of data matrix to be sliced diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index f94e60321..91c6bbd85 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2431,6 +2431,8 @@ class Booster: from .data import ( _array_interface, + _arrow_transform, + _is_arrow, _is_cudf_df, _is_cupy_array, _is_list, @@ -2442,6 +2444,8 @@ class Booster: ) enable_categorical = True + if _is_arrow(data): + data = _arrow_transform(data) if _is_pandas_series(data): import pandas as pd diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 0022a17d4..bfdb21c80 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import ctypes import json import os import warnings -from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, cast +from typing import Any, Callable, List, Optional, Sequence, Tuple, cast import numpy as np @@ -676,86 +676,51 @@ def _from_dt_df( def _is_arrow(data: DataType) -> bool: - try: - import pyarrow as pa - from pyarrow import dataset as arrow_dataset - - return isinstance(data, (pa.Table, arrow_dataset.Dataset)) - except ImportError: - return False + return lazy_isinstance(data, "pyarrow.lib", "Table") or lazy_isinstance( + data, "pyarrow._dataset", "Dataset" + ) -def record_batch_data_iter(data_iter: Iterator) -> Callable: - """Data iterator used to ingest Arrow columnar record batches. We are not using - class DataIter because it is only intended for building Device DMatrix and external - memory DMatrix. - - """ - from pyarrow.cffi import ffi - - c_schemas: List[ffi.CData] = [] - c_arrays: List[ffi.CData] = [] - - def _next(data_handle: int) -> int: - from pyarrow.cffi import ffi - - try: - batch = next(data_iter) - c_schemas.append(ffi.new("struct ArrowSchema*")) - c_arrays.append(ffi.new("struct ArrowArray*")) - ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1])) - ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1])) - # pylint: disable=protected-access - batch._export_to_c(ptr_array, ptr_schema) - _check_call( - _LIB.XGImportArrowRecordBatch( - ctypes.c_void_p(data_handle), - ctypes.c_void_p(ptr_array), - ctypes.c_void_p(ptr_schema), - ) - ) - return 1 - except StopIteration: - return 0 - - return _next - - -def _from_arrow( - data: DataType, - missing: FloatCompatible, - nthread: int, - feature_names: Optional[FeatureNames], - feature_types: Optional[FeatureTypes], - enable_categorical: bool, -) -> DispatchedDataBackendReturnType: +def _arrow_transform(data: DataType) -> Any: + import pandas as pd import pyarrow as pa + from pyarrow.dataset import Dataset - if not all( - pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types - ): - raise ValueError( - "Features in dataset can only be integers or floating point number" - ) - if enable_categorical: - raise ValueError("categorical data in arrow is not supported yet.") + if isinstance(data, Dataset): + raise TypeError("arrow Dataset is not supported.") - batches = data.to_batches() - rb_iter = iter(batches) - it = record_batch_data_iter(rb_iter) - next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it) - handle = ctypes.c_void_p() - config = from_pystr_to_cstr( - json.dumps({"missing": missing, "nthread": nthread, "nbatch": len(batches)}) - ) - _check_call( - _LIB.XGDMatrixCreateFromArrowCallback( - next_callback, - config, - ctypes.byref(handle), - ) - ) - return handle, feature_names, feature_types + data = cast(pa.Table, data) + + def type_mapper(dtype: pa.DataType) -> Optional[str]: + """Maps pyarrow type to pandas arrow extension type.""" + if pa.types.is_int8(dtype): + return pd.ArrowDtype(pa.int8()) + if pa.types.is_int16(dtype): + return pd.ArrowDtype(pa.int16()) + if pa.types.is_int32(dtype): + return pd.ArrowDtype(pa.int32()) + if pa.types.is_int64(dtype): + return pd.ArrowDtype(pa.int64()) + if pa.types.is_uint8(dtype): + return pd.ArrowDtype(pa.uint8()) + if pa.types.is_uint16(dtype): + return pd.ArrowDtype(pa.uint16()) + if pa.types.is_uint32(dtype): + return pd.ArrowDtype(pa.uint32()) + if pa.types.is_uint64(dtype): + return pd.ArrowDtype(pa.uint64()) + if pa.types.is_float16(dtype): + return pd.ArrowDtype(pa.float16()) + if pa.types.is_float32(dtype): + return pd.ArrowDtype(pa.float32()) + if pa.types.is_float64(dtype): + return pd.ArrowDtype(pa.float64()) + if pa.types.is_boolean(dtype): + return pd.ArrowDtype(pa.bool_()) + return None + + df = data.to_pandas(types_mapper=type_mapper) + return df def _is_cudf_df(data: DataType) -> bool: @@ -1081,6 +1046,8 @@ def dispatch_data_backend( return _from_list(data, missing, threads, feature_names, feature_types) if _is_tuple(data): return _from_tuple(data, missing, threads, feature_names, feature_types) + if _is_arrow(data): + data = _arrow_transform(data) if _is_pandas_series(data): import pandas as pd @@ -1114,10 +1081,6 @@ def dispatch_data_backend( return _from_pandas_series( data, missing, threads, enable_categorical, feature_names, feature_types ) - if _is_arrow(data): - return _from_arrow( - data, missing, threads, feature_names, feature_types, enable_categorical - ) if _has_array_protocol(data): array = np.asarray(data) return _from_numpy_array(array, missing, threads, feature_names, feature_types) @@ -1217,6 +1180,8 @@ def dispatch_meta_backend( if _is_np_array_like(data): _meta_from_numpy(data, name, dtype, handle) return + if _is_arrow(data): + data = _arrow_transform(data) if _is_pandas_df(data): data, _, _ = _transform_pandas_df(data, False, meta=name, meta_type=dtype) _meta_from_numpy(data, name, dtype, handle) @@ -1311,6 +1276,8 @@ def _proxy_transform( import pandas as pd data = pd.DataFrame(data) + if _is_arrow(data): + data = _arrow_transform(data) if _is_pandas_df(data): arr, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index f6ab8d4df..f8b0aa3de 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -532,33 +532,8 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes, API_END(); } -XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, - void *ptr_schema) { - API_BEGIN(); - static_cast(data_handle) - ->SetData(static_cast(ptr_array), - static_cast(ptr_schema)); - API_END(); -} - -XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *config, - DMatrixHandle *out) { - API_BEGIN(); - xgboost_CHECK_C_ARG_PTR(config); - auto jconfig = Json::Load(StringView{config}); - auto missing = GetMissing(jconfig); - auto n_batches = RequiredArg(jconfig, "nbatch", __func__); - auto n_threads = OptionalArg(jconfig, "nthread", 0); - data::RecordBatchesIterAdapter adapter(next, n_batches); - xgboost_CHECK_C_ARG_PTR(out); - *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); - API_END(); -} - -XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, - const int* idxset, - xgboost::bst_ulong len, - DMatrixHandle* out) { +XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle, const int *idxset, xgboost::bst_ulong len, + DMatrixHandle *out) { xgboost_CHECK_C_ARG_PTR(out); return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0); } diff --git a/src/data/adapter.h b/src/data/adapter.h index 1463a13a7..e7eaa372f 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -20,7 +20,6 @@ #include "../common/error_msg.h" // for MaxFeatureSize #include "../common/math.h" #include "array_interface.h" -#include "arrow-cdi.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/logging.h" @@ -899,306 +898,6 @@ class Column { const uint8_t* bitmap_; }; -// Only columns of primitive types are supported. An ArrowColumnarBatch is a -// collection of std::shared_ptr. These columns can be of different data types. -// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns -// derive from the abstract class Column. -template -class PrimitiveColumn : public Column { - static constexpr float kNaN = std::numeric_limits::quiet_NaN(); - - public: - PrimitiveColumn(size_t idx, size_t length, size_t null_count, - const uint8_t* bitmap, const T* data, float missing) - : Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {} - - COOTuple GetElement(size_t row_idx) const override { - CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column"; - return { row_idx, col_idx_, IsValidElement(row_idx) ? - static_cast(data_[row_idx]) : kNaN }; - } - - bool IsValidElement(size_t row_idx) const override { - // std::isfinite needs to cast to double to prevent msvc report error - return IsValid(row_idx) - && std::isfinite(static_cast(data_[row_idx])) - && static_cast(data_[row_idx]) != missing_; - } - - std::vector AsFloatVector() const override { - CHECK(data_) << "Column is empty"; - std::vector fv(length_); - std::transform(data_, data_ + length_, fv.begin(), - [](T v) { return static_cast(v); }); - return fv; - } - - std::vector AsUint64Vector() const override { - CHECK(data_) << "Column is empty"; - std::vector iv(length_); - std::transform(data_, data_ + length_, iv.begin(), - [](T v) { return static_cast(v); }); - return iv; - } - - private: - const T* data_; - float missing_; // user specified missing value -}; - -struct ColumnarMetaInfo { - // data type of the column - ColumnDType type{ColumnDType::kUnknown}; - // location of the column in an Arrow record batch - int64_t loc{-1}; -}; - -struct ArrowSchemaImporter { - std::vector columns; - - // map Arrow format strings to types - static ColumnDType FormatMap(char const* format_str) { - CHECK(format_str) << "Format string cannot be empty"; - switch (format_str[0]) { - case 'c': - return ColumnDType::kInt8; - case 'C': - return ColumnDType::kUInt8; - case 's': - return ColumnDType::kInt16; - case 'S': - return ColumnDType::kUInt16; - case 'i': - return ColumnDType::kInt32; - case 'I': - return ColumnDType::kUInt32; - case 'l': - return ColumnDType::kInt64; - case 'L': - return ColumnDType::kUInt64; - case 'f': - return ColumnDType::kFloat; - case 'g': - return ColumnDType::kDouble; - default: - CHECK(false) << "Column data type not supported by XGBoost"; - return ColumnDType::kUnknown; - } - } - - void Import(struct ArrowSchema *schema) { - if (schema) { - CHECK(std::string(schema->format) == "+s"); // NOLINT - CHECK(columns.empty()); - for (auto i = 0; i < schema->n_children; ++i) { - std::string name{schema->children[i]->name}; - ColumnDType type = FormatMap(schema->children[i]->format); - ColumnarMetaInfo col_info{type, i}; - columns.push_back(col_info); - } - if (schema->release) { - schema->release(schema); - } - } - } -}; - -class ArrowColumnarBatch { - public: - ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema) - : rb_{rb}, schema_{schema} { - CHECK(rb_) << "Cannot import non-existent record batch"; - CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema"; - } - - size_t Import(float missing) { - auto& infov = schema_->columns; - for (size_t i = 0; i < infov.size(); ++i) { - columns_.push_back(CreateColumn(i, infov[i], missing)); - } - - // Compute the starting location for every row in this batch - auto batch_size = rb_->length; - auto num_columns = columns_.size(); - row_offsets_.resize(batch_size + 1, 0); - for (auto i = 0; i < batch_size; ++i) { - row_offsets_[i+1] = row_offsets_[i]; - for (size_t j = 0; j < num_columns; ++j) { - if (GetColumn(j).IsValidElement(i)) { - row_offsets_[i+1]++; - } - } - } - // return number of elements in the batch - return row_offsets_.back(); - } - - ArrowColumnarBatch(const ArrowColumnarBatch&) = delete; - ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete; - ArrowColumnarBatch(ArrowColumnarBatch&&) = delete; - ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete; - - virtual ~ArrowColumnarBatch() { - if (rb_ && rb_->release) { - rb_->release(rb_); - rb_ = nullptr; - } - columns_.clear(); - } - - size_t Size() const { return rb_ ? rb_->length : 0; } - - size_t NumColumns() const { return columns_.size(); } - - size_t NumElements() const { return row_offsets_.back(); } - - const Column& GetColumn(size_t col_idx) const { - return *columns_[col_idx]; - } - - void ShiftRowOffsets(size_t batch_offset) { - std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(), - [=](size_t c) { return c + batch_offset; }); - } - - const std::vector& RowOffsets() const { return row_offsets_; } - - private: - std::shared_ptr CreateColumn(size_t idx, - ColumnarMetaInfo info, - float missing) const { - if (info.loc < 0) { - return nullptr; - } - - auto loc_in_batch = info.loc; - auto length = rb_->length; - auto null_count = rb_->null_count; - auto buffers0 = rb_->children[loc_in_batch]->buffers[0]; - auto buffers1 = rb_->children[loc_in_batch]->buffers[1]; - const uint8_t* bitmap = buffers0 ? reinterpret_cast(buffers0) : nullptr; - const uint8_t* data = buffers1 ? reinterpret_cast(buffers1) : nullptr; - - // if null_count is not computed, compute it here - if (null_count < 0) { - if (!bitmap) { - null_count = 0; - } else { - null_count = length; - for (auto i = 0; i < length; ++i) { - if (bitmap[i/8] & (1 << (i%8))) { - null_count--; - } - } - } - } - - switch (info.type) { - case ColumnDType::kInt8: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kUInt8: - return std::make_shared>( - idx, length, null_count, bitmap, data, missing); - case ColumnDType::kInt16: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kUInt16: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kInt32: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kUInt32: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kInt64: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kUInt64: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kFloat: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - case ColumnDType::kDouble: - return std::make_shared>( - idx, length, null_count, bitmap, - reinterpret_cast(data), missing); - default: - return nullptr; - } - } - - struct ArrowArray* rb_; - struct ArrowSchemaImporter* schema_; - std::vector> columns_; - std::vector row_offsets_; -}; - -using ArrowColumnarBatchVec = std::vector>; -class RecordBatchesIterAdapter: public dmlc::DataIter { - public: - RecordBatchesIterAdapter(XGDMatrixCallbackNext* next_callback, int nbatch) - : next_callback_{next_callback}, nbatches_{nbatch} {} - - void BeforeFirst() override { - CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter"; - } - - bool Next() override { - batches_.clear(); - while (batches_.size() < static_cast(nbatches_) && (*next_callback_)(this) != 0) { - at_first_ = false; - } - - if (batches_.size() > 0) { - return true; - } else { - return false; - } - } - - void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) { - // Schema is only imported once at the beginning, regardless how many - // baches are comming. - // But even schema is not imported we still need to release its C data - // exported from Arrow. - if (at_first_ && schema) { - schema_.Import(schema); - } else { - if (schema && schema->release) { - schema->release(schema); - } - } - if (rb) { - batches_.push_back(std::make_unique(rb, &schema_)); - } - } - - const ArrowColumnarBatchVec& Value() const override { - return batches_; - } - - size_t NumColumns() const { return schema_.columns.size(); } - size_t NumRows() const { return kAdapterUnknownSize; } - - private: - XGDMatrixCallbackNext *next_callback_; - bool at_first_{true}; - int nbatches_; - struct ArrowSchemaImporter schema_; - ArrowColumnarBatchVec batches_; -}; - class SparsePageAdapterBatch { HostSparsePageView page_; diff --git a/src/data/arrow-cdi.h b/src/data/arrow-cdi.h deleted file mode 100644 index 2cb061b3a..000000000 --- a/src/data/arrow-cdi.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#pragma once - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -#define ARROW_FLAG_DICTIONARY_ORDERED 1 -#define ARROW_FLAG_NULLABLE 2 -#define ARROW_FLAG_MAP_KEYS_SORTED 4 - -struct ArrowSchema { - // Array type description - const char* format; - const char* name; - const char* metadata; - int64_t flags; - int64_t n_children; - struct ArrowSchema** children; - struct ArrowSchema* dictionary; - - // Release callback - void (*release)(struct ArrowSchema*); - // Opaque producer-specific data - void* private_data; -}; - -struct ArrowArray { - // Array data description - int64_t length; - int64_t null_count; - int64_t offset; - int64_t n_buffers; - int64_t n_children; - const void** buffers; - struct ArrowArray** children; - struct ArrowArray* dictionary; - - // Release callback - void (*release)(struct ArrowArray*); - // Opaque producer-specific data - void* private_data; -}; - -#ifdef __cplusplus -} -#endif diff --git a/src/data/data.cc b/src/data/data.cc index 92547dafd..4a2bef6be 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1011,9 +1011,6 @@ template DMatrix* DMatrix::Create(data::CSCArrayAdapter* template DMatrix* DMatrix::Create( data::IteratorAdapter* adapter, float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode); -template DMatrix* DMatrix::Create( - data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&, - DataSplitMode data_split_mode); SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { SparsePage transpose; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index bf7b27eb7..0adf6b466 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -361,78 +361,4 @@ template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int n template SimpleDMatrix::SimpleDMatrix( IteratorAdapter* adapter, float missing, int nthread, DataSplitMode data_split_mode); - -template <> -SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread, - DataSplitMode data_split_mode) { - Context ctx; - ctx.nthread = nthread; - - auto& offset_vec = sparse_page_->offset.HostVector(); - auto& data_vec = sparse_page_->data.HostVector(); - uint64_t total_batch_size = 0; - uint64_t total_elements = 0; - - adapter->BeforeFirst(); - // Iterate over batches of input data - while (adapter->Next()) { - auto& batches = adapter->Value(); - size_t num_elements = 0; - size_t num_rows = 0; - // Import Arrow RecordBatches -#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx.Threads()) - for (int i = 0; i < static_cast(batches.size()); ++i) { // NOLINT - num_elements += batches[i]->Import(missing); - num_rows += batches[i]->Size(); - } - total_elements += num_elements; - total_batch_size += num_rows; - // Compute global offset for every row and starting row for every batch - std::vector batch_offsets(batches.size()); - for (size_t i = 0; i < batches.size(); ++i) { - if (i == 0) { - batch_offsets[i] = total_batch_size - num_rows; - batches[i]->ShiftRowOffsets(total_elements - num_elements); - } else { - batch_offsets[i] = batch_offsets[i - 1] + batches[i - 1]->Size(); - batches[i]->ShiftRowOffsets(batches[i - 1]->RowOffsets().back()); - } - } - // Pre-allocate DMatrix memory - data_vec.resize(total_elements); - offset_vec.resize(total_batch_size + 1); - // Copy data into DMatrix -#pragma omp parallel num_threads(ctx.Threads()) - { -#pragma omp for nowait - for (int i = 0; i < static_cast(batches.size()); ++i) { // NOLINT - size_t begin = batches[i]->RowOffsets()[0]; - for (size_t k = 0; k < batches[i]->Size(); ++k) { - for (size_t j = 0; j < batches[i]->NumColumns(); ++j) { - auto element = batches[i]->GetColumn(j).GetElement(k); - if (!std::isnan(element.value)) { - data_vec[begin++] = Entry(element.column_idx, element.value); - } - } - } - } -#pragma omp for nowait - for (int i = 0; i < static_cast(batches.size()); ++i) { - auto& offsets = batches[i]->RowOffsets(); - std::copy(offsets.begin() + 1, offsets.end(), offset_vec.begin() + batch_offsets[i] + 1); - } - } - } - // Synchronise worker columns - info_.num_col_ = adapter->NumColumns(); - info_.data_split_mode = data_split_mode; - ReindexFeatures(&ctx); - info_.SynchronizeNumberOfColumns(); - - info_.num_row_ = total_batch_size; - info_.num_nonzero_ = data_vec.size(); - CHECK_EQ(offset_vec.back(), info_.num_nonzero_); - - fmat_ctx_ = ctx; -} } // namespace xgboost::data diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 56685c1e6..d6164894a 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2015-2022 by XGBoost Contributors +/** + * Copyright 2015-2023, XGBoost Contributors * \file simple_dmatrix.h * \brief In-memory version of DMatrix. * \author Tianqi Chen @@ -15,8 +15,7 @@ #include "gradient_index.h" -namespace xgboost { -namespace data { +namespace xgboost::data { // Used for single batch data. class SimpleDMatrix : public DMatrix { public: @@ -75,6 +74,5 @@ class SimpleDMatrix : public DMatrix { // Context used only for DMatrix initialization. Context fmat_ctx_; }; -} // namespace data -} // namespace xgboost +} // namespace xgboost::data #endif // XGBOOST_DATA_SIMPLE_DMATRIX_H_ diff --git a/tests/python/test_with_arrow.py b/tests/python/test_with_arrow.py index 8b7bce9eb..4673a688e 100644 --- a/tests/python/test_with_arrow.py +++ b/tests/python/test_with_arrow.py @@ -22,7 +22,7 @@ pytestmark = pytest.mark.skipif( dpath = "demo/data/" -class TestArrowTable(unittest.TestCase): +class TestArrowTable: def test_arrow_table(self): df = pd.DataFrame( [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"] @@ -52,7 +52,8 @@ class TestArrowTable(unittest.TestCase): assert dm.num_row() == 4 assert dm.num_col() == 3 - def test_arrow_train(self): + @pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix]) + def test_arrow_train(self, DMatrixT): import pandas as pd rows = 100 @@ -64,16 +65,24 @@ class TestArrowTable(unittest.TestCase): } ) y = pd.Series(np.random.randn(rows)) + table = pa.Table.from_pandas(X) - dtrain1 = xgb.DMatrix(table) - dtrain1.set_label(y) + dtrain1 = DMatrixT(table) + dtrain1.set_label(pa.Table.from_pandas(pd.DataFrame(y))) bst1 = xgb.train({}, dtrain1, num_boost_round=10) - preds1 = bst1.predict(xgb.DMatrix(X)) - dtrain2 = xgb.DMatrix(X, y) + preds1 = bst1.predict(DMatrixT(X)) + + dtrain2 = DMatrixT(X, y) bst2 = xgb.train({}, dtrain2, num_boost_round=10) - preds2 = bst2.predict(xgb.DMatrix(X)) + preds2 = bst2.predict(DMatrixT(X)) + np.testing.assert_allclose(preds1, preds2) + preds3 = bst2.inplace_predict(table) + np.testing.assert_allclose(preds1, preds3) + assert bst2.feature_names == ["A", "B", "C"] + assert bst2.feature_types == ["int", "float", "int"] + def test_arrow_survival(self): data = os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv") table = pc.read_csv(data)