Support building SimpleDMatrix from Arrow data format (#7512)

* Integrate with Arrow C data API.
* Support Arrow dataset.
* Support Arrow table.
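
A minimal usage sketch of the new code path, mirroring the tests added in this
commit (assumes pyarrow and pandas are installed):

import numpy as np
import pandas as pd
import pyarrow as pa
import xgboost as xgb

# Build an Arrow table and construct a DMatrix directly from it.
df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
table = pa.Table.from_pandas(df)
dtrain = xgb.DMatrix(table)
dtrain.set_label(np.array([0.0, 1.0, 0.0]))
booster = xgb.train({}, dtrain, num_boost_round=10)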

Co-authored-by: Xiaochang Wu <xiaochang.wu@intel.com>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Co-authored-by: Zhang Zhang <zhang.zhang@intel.com>
Xiaochang Wu 2022-03-14 22:25:19 -07:00, committed by GitHub
parent 6b6849b001
commit 613ec36c5a
14 changed files with 732 additions and 10 deletions


@@ -502,12 +502,29 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
char const *indices, char const *data,
bst_ulong ncol);
/*
 * ========================== End data callback APIs ==========================
*/
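/*!
 * \brief Import a single Arrow record batch, passed as Arrow C data interface
 *        pointers, into the iterator adapter behind data_handle. Called by the
 *        language bindings from inside the XGDMatrixCreateFromArrowCallback loop.
 */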
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array, void *ptr_schema);
/*!
 * \brief Construct DMatrix from Arrow data using callbacks. The Arrow-related C API
 *        is not stable and is subject to change in the future.
 *
 * \param next Callback function for fetching Arrow record batches.
* \param json_config JSON encoded configuration. Required values are:
*
* - missing
* - nthread
*
* \param out The created DMatrix.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out);
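
For reference, json_config matches what the Python binding below builds; a minimal
sketch of the expected payload:

import json

# Both keys are required by XGDMatrixCreateFromArrowCallback.
config = json.dumps({"missing": float("nan"), "nthread": 16}).encode("utf-8")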
/*!
* \brief create a new dmatrix from sliced content of existing matrix


@@ -2,10 +2,11 @@
# pylint: disable=too-many-return-statements, import-error
'''Data dispatching for DMatrix.'''
import ctypes
from distutils import version
import json
import warnings
import os
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator
import numpy as np
@@ -466,6 +467,92 @@ def _from_dt_df(
return handle, feature_names, feature_types
def _is_arrow(data) -> bool:
try:
import pyarrow as pa
from pyarrow import dataset as arrow_dataset
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
except ImportError:
return False
def record_batch_data_iter(data_iter: Iterator) -> Callable:
"""Data iterator used to ingest Arrow columnar record batches. We are not using
class DataIter because it is only intended for building Device DMatrix and external
memory DMatrix.
"""
from pyarrow.cffi import ffi
c_schemas: List[ffi.CData] = []
c_arrays: List[ffi.CData] = []
def _next(data_handle: int) -> int:
try:
batch = next(data_iter)
c_schemas.append(ffi.new("struct ArrowSchema*"))
c_arrays.append(ffi.new("struct ArrowArray*"))
ptr_schema = int(ffi.cast("uintptr_t", c_schemas[-1]))
ptr_array = int(ffi.cast("uintptr_t", c_arrays[-1]))
# pylint: disable=protected-access
batch._export_to_c(ptr_array, ptr_schema)
_check_call(
_LIB.XGImportArrowRecordBatch(
ctypes.c_void_p(data_handle),
ctypes.c_void_p(ptr_array),
ctypes.c_void_p(ptr_schema),
)
)
return 1
except StopIteration:
return 0
return _next
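
The callback above hands each record batch to XGBoost through the Arrow C Data
Interface. A standalone sketch of that handoff, using the same private pyarrow
helpers the iterator relies on:

import pyarrow as pa
from pyarrow.cffi import ffi

batch = pa.record_batch([pa.array([1.0, 2.0])], names=["x"])
c_schema = ffi.new("struct ArrowSchema*")
c_array = ffi.new("struct ArrowArray*")
# Move the batch into the two C structs; the consumer takes ownership and
# must invoke the structs' release callbacks when finished.
batch._export_to_c(int(ffi.cast("uintptr_t", c_array)),
                   int(ffi.cast("uintptr_t", c_schema)))
# Here the "consumer" is pyarrow itself, importing the batch back.
roundtrip = pa.RecordBatch._import_from_c(
    int(ffi.cast("uintptr_t", c_array)),
    int(ffi.cast("uintptr_t", c_schema)),
)
assert roundtrip.num_rows == 2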
def _from_arrow(
data,
missing: float,
nthread: int,
feature_names: Optional[List[str]],
feature_types: Optional[List[str]],
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
import pyarrow as pa
if not all(
pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
):
        raise ValueError(
            "Features in the Arrow dataset must be integers or floating point numbers."
        )
if enable_categorical:
raise ValueError("categorical data in arrow is not supported yet.")
major, _, _ = version.StrictVersion(pa.__version__).version
if major == 4:
rb_iter = iter(data.to_batches())
else:
        # use_async=True works around a pyarrow 6.0.1 hang;
        # see Modin-3982 and ARROW-15362.
rb_iter = iter(data.to_batches(use_async=True))
it = record_batch_data_iter(rb_iter)
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it)
handle = ctypes.c_void_p()
config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
_check_call(
_LIB.XGDMatrixCreateFromArrowCallback(
next_callback,
config,
ctypes.byref(handle),
)
)
return handle, feature_names, feature_types
def _is_cudf_df(data) -> bool:
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")
@@ -814,6 +901,9 @@ def dispatch_data_backend(
return _from_pandas_series(
data, missing, threads, enable_categorical, feature_names, feature_types
)
if _is_arrow(data):
        return _from_arrow(
            data, missing, threads, feature_names, feature_types, enable_categorical
        )
if _has_array_protocol(data):
array = np.asarray(data)
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
@@ -954,6 +1044,7 @@ def dispatch_meta_backend(
_meta_from_numpy(data, name, dtype, handle)
return
if _has_array_protocol(data):
        # pyarrow arrays reach this branch through the array protocol.
array = np.asarray(data)
_meta_from_numpy(array, name, dtype, handle)
return


@@ -416,6 +416,27 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
API_END();
}
XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array,
void *ptr_schema) {
API_BEGIN();
static_cast<data::RecordBatchesIterAdapter *>(data_handle)
->SetData(static_cast<struct ArrowArray *>(ptr_array),
static_cast<struct ArrowSchema *>(ptr_schema));
API_END();
}
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
DMatrixHandle *out) {
API_BEGIN();
auto config = Json::Load(StringView{json_config});
auto missing = GetMissing(config);
int32_t n_threads = get<Integer const>(config["nthread"]);
n_threads = common::OmpGetNumThreads(n_threads);
data::RecordBatchesIterAdapter adapter(next, n_threads);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
API_END();
}
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,


@@ -13,6 +13,8 @@
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <algorithm>
#include "xgboost/logging.h"
#include "xgboost/base.h"
@@ -22,6 +24,7 @@
#include "array_interface.h"
#include "../c_api/c_api_error.h"
#include "../common/math.h"
#include "arrow-cdi.h"
namespace xgboost {
namespace data {
@@ -676,11 +679,10 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
public:
  IteratorAdapter(DataIterHandle data_handle, XGBCallbackDataIterNext* next_callback)
      : columns_{data::kAdapterUnknownSize},
        data_handle_(data_handle),
        next_callback_(next_callback) {}
// override functions
void BeforeFirst() override {
@@ -766,9 +768,9 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
std::vector<dmlc::real_t> value_;
size_t columns_;
  size_t row_offset_{0};
  // whether the iterator is at the beginning
  bool at_first_{true};
// handle to the iterator,
DataIterHandle data_handle_;
// call back to get the data.
@@ -777,6 +779,358 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
dmlc::RowBlock<uint32_t> block_;
std::unique_ptr<FileAdapterBatch> batch_;
};
enum ColumnDType : uint8_t {
kUnknown,
kInt8,
kUInt8,
kInt16,
kUInt16,
kInt32,
kUInt32,
kInt64,
kUInt64,
kFloat,
kDouble
};
class Column {
public:
Column() = default;
Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap)
: col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {}
virtual ~Column() = default;
Column(const Column&) = delete;
Column& operator=(const Column&) = delete;
Column(Column&&) = delete;
Column& operator=(Column&&) = delete;
  // whether the validity bit is set for this element; Arrow validity bitmaps are
  // LSB-ordered, and a null bitmap pointer means all elements are valid
  bool IsValid(size_t row_idx) const {
    return !bitmap_ || (bitmap_[row_idx / 8] & (1 << (row_idx % 8)));
  }
virtual COOTuple GetElement(size_t row_idx) const = 0;
virtual bool IsValidElement(size_t row_idx) const = 0;
virtual std::vector<float> AsFloatVector() const = 0;
virtual std::vector<uint64_t> AsUint64Vector() const = 0;
size_t Length() const { return length_; }
protected:
size_t col_idx_;
size_t length_;
size_t null_count_;
const uint8_t* bitmap_;
};
// Only columns of primitive types are supported. An ArrowColumnarBatch is a
// collection of std::shared_ptr<PrimitiveColumn>. These columns can be of different data types.
// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns
// derive from the abstract class Column.
template <typename T>
class PrimitiveColumn : public Column {
static constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
public:
PrimitiveColumn(size_t idx, size_t length, size_t null_count,
const uint8_t* bitmap, const T* data, float missing)
: Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {}
  COOTuple GetElement(size_t row_idx) const override {
    CHECK(data_ && row_idx < length_) << "Column is empty or the row index is out of bounds";
    return {row_idx, col_idx_,
            IsValidElement(row_idx) ? static_cast<float>(data_[row_idx]) : kNaN};
  }
bool IsValidElement(size_t row_idx) const override {
// std::isfinite needs to cast to double to prevent msvc report error
return IsValid(row_idx)
&& std::isfinite(static_cast<double>(data_[row_idx]))
&& static_cast<float>(data_[row_idx]) != missing_;
}
std::vector<float> AsFloatVector() const override {
CHECK(data_) << "Column is empty";
std::vector<float> fv(length_);
std::transform(data_, data_ + length_, fv.begin(),
[](T v) { return static_cast<float>(v); });
return fv;
}
std::vector<uint64_t> AsUint64Vector() const override {
CHECK(data_) << "Column is empty";
std::vector<uint64_t> iv(length_);
std::transform(data_, data_ + length_, iv.begin(),
[](T v) { return static_cast<uint64_t>(v); });
return iv;
}
private:
const T* data_;
float missing_; // user specified missing value
};
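
A quick pyarrow illustration of the two buffers (validity bitmap, then values)
that these column wrappers read, with the LSB-first bit test used by IsValid:

import pyarrow as pa

arr = pa.array([1.0, None, 3.0])
bitmap_buf, data_buf = arr.buffers()  # validity bitmap, then values
bm = bitmap_buf.to_pybytes()
# Bit i (LSB-first) of the bitmap marks row i as valid.
print([bool(bm[i // 8] & (1 << (i % 8))) for i in range(len(arr))])
# -> [True, False, True]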
struct ColumnarMetaInfo {
// data type of the column
ColumnDType type{ColumnDType::kUnknown};
// location of the column in an Arrow record batch
int64_t loc{-1};
};
struct ArrowSchemaImporter {
std::vector<ColumnarMetaInfo> columns;
  // map Arrow C data interface format strings (one character per primitive type)
  // to XGBoost column types
  static ColumnDType FormatMap(char const* format_str) {
    CHECK(format_str) << "Format string cannot be null";
switch (format_str[0]) {
case 'c':
return ColumnDType::kInt8;
case 'C':
return ColumnDType::kUInt8;
case 's':
return ColumnDType::kInt16;
case 'S':
return ColumnDType::kUInt16;
case 'i':
return ColumnDType::kInt32;
case 'I':
return ColumnDType::kUInt32;
case 'l':
return ColumnDType::kInt64;
case 'L':
return ColumnDType::kUInt64;
case 'f':
return ColumnDType::kFloat;
case 'g':
return ColumnDType::kDouble;
default:
CHECK(false) << "Column data type not supported by XGBoost";
return ColumnDType::kUnknown;
}
}
  void Import(struct ArrowSchema *schema) {
    if (schema) {
      // "+s" is the Arrow format string for a struct, i.e. the root of a record batch
      CHECK(std::string(schema->format) == "+s");  // NOLINT
CHECK(columns.empty());
for (auto i = 0; i < schema->n_children; ++i) {
std::string name{schema->children[i]->name};
ColumnDType type = FormatMap(schema->children[i]->format);
ColumnarMetaInfo col_info{type, i};
columns.push_back(col_info);
}
if (schema->release) {
schema->release(schema);
}
}
}
};
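
The single-character format strings handled by FormatMap come from the Arrow C
data interface specification. A hypothetical check of what pyarrow exports for a
float64 column, again using its private C-interface helpers:

import pyarrow as pa
from pyarrow.cffi import ffi

c_schema = ffi.new("struct ArrowSchema*")
pa.float64()._export_to_c(int(ffi.cast("uintptr_t", c_schema)))
print(ffi.string(c_schema.format).decode())  # "g", i.e. ColumnDType::kDouble
c_schema.release(c_schema)  # release the exported C data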
class ArrowColumnarBatch {
public:
ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema)
: rb_{rb}, schema_{schema} {
CHECK(rb_) << "Cannot import non-existent record batch";
CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema";
}
size_t Import(float missing) {
auto& infov = schema_->columns;
for (size_t i = 0; i < infov.size(); ++i) {
columns_.push_back(CreateColumn(i, infov[i], missing));
}
// Compute the starting location for every row in this batch
auto batch_size = rb_->length;
auto num_columns = columns_.size();
row_offsets_.resize(batch_size + 1, 0);
for (auto i = 0; i < batch_size; ++i) {
row_offsets_[i+1] = row_offsets_[i];
for (size_t j = 0; j < num_columns; ++j) {
if (GetColumn(j).IsValidElement(i)) {
row_offsets_[i+1]++;
}
}
}
// return number of elements in the batch
return row_offsets_.back();
}
ArrowColumnarBatch(const ArrowColumnarBatch&) = delete;
ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete;
ArrowColumnarBatch(ArrowColumnarBatch&&) = delete;
ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete;
virtual ~ArrowColumnarBatch() {
if (rb_ && rb_->release) {
rb_->release(rb_);
rb_ = nullptr;
}
columns_.clear();
}
size_t Size() const { return rb_ ? rb_->length : 0; }
size_t NumColumns() const { return columns_.size(); }
size_t NumElements() const { return row_offsets_.back(); }
const Column& GetColumn(size_t col_idx) const {
return *columns_[col_idx];
}
void ShiftRowOffsets(size_t batch_offset) {
std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
[=](size_t c) { return c + batch_offset; });
}
const std::vector<size_t>& RowOffsets() const { return row_offsets_; }
private:
std::shared_ptr<Column> CreateColumn(size_t idx,
ColumnarMetaInfo info,
float missing) const {
if (info.loc < 0) {
return nullptr;
}
auto loc_in_batch = info.loc;
auto length = rb_->length;
auto null_count = rb_->null_count;
auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;
// if null_count is not computed, compute it here
if (null_count < 0) {
if (!bitmap) {
null_count = 0;
} else {
null_count = length;
for (auto i = 0; i < length; ++i) {
if (bitmap[i/8] & (1 << (i%8))) {
null_count--;
}
}
}
}
switch (info.type) {
case ColumnDType::kInt8:
return std::make_shared<PrimitiveColumn<int8_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const int8_t*>(data), missing);
case ColumnDType::kUInt8:
return std::make_shared<PrimitiveColumn<uint8_t>>(
idx, length, null_count, bitmap, data, missing);
case ColumnDType::kInt16:
return std::make_shared<PrimitiveColumn<int16_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const int16_t*>(data), missing);
case ColumnDType::kUInt16:
return std::make_shared<PrimitiveColumn<uint16_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const uint16_t*>(data), missing);
case ColumnDType::kInt32:
return std::make_shared<PrimitiveColumn<int32_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const int32_t*>(data), missing);
case ColumnDType::kUInt32:
return std::make_shared<PrimitiveColumn<uint32_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const uint32_t*>(data), missing);
case ColumnDType::kInt64:
return std::make_shared<PrimitiveColumn<int64_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const int64_t*>(data), missing);
case ColumnDType::kUInt64:
return std::make_shared<PrimitiveColumn<uint64_t>>(
idx, length, null_count, bitmap,
reinterpret_cast<const uint64_t*>(data), missing);
case ColumnDType::kFloat:
return std::make_shared<PrimitiveColumn<float>>(
idx, length, null_count, bitmap,
reinterpret_cast<const float*>(data), missing);
case ColumnDType::kDouble:
return std::make_shared<PrimitiveColumn<double>>(
idx, length, null_count, bitmap,
reinterpret_cast<const double*>(data), missing);
default:
return nullptr;
}
}
struct ArrowArray* rb_;
struct ArrowSchemaImporter* schema_;
std::vector<std::shared_ptr<Column>> columns_;
std::vector<size_t> row_offsets_;
};
using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
class RecordBatchesIterAdapter: public dmlc::DataIter<ArrowColumnarBatchVec> {
public:
RecordBatchesIterAdapter(XGDMatrixCallbackNext *next_callback,
int nthread)
: next_callback_{next_callback},
nbatches_{nthread} {}
void BeforeFirst() override {
CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
}
  bool Next() override {
    batches_.clear();
    // Gather up to nbatches_ (= nthread) record batches per call so the importer
    // can later process one batch per thread in parallel.
    while (batches_.size() < static_cast<size_t>(nbatches_) && (*next_callback_)(this) != 0) {
      at_first_ = false;
    }
    return !batches_.empty();
  }
  void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
    // The schema is imported only once, at the beginning, regardless of how many
    // batches are coming. But even when the schema is not imported, we still need
    // to release its C data exported from Arrow.
    if (at_first_ && schema) {
      schema_.Import(schema);
    } else {
      if (schema && schema->release) {
        schema->release(schema);
      }
    }
if (rb) {
batches_.push_back(std::make_unique<ArrowColumnarBatch>(rb, &schema_));
}
}
const ArrowColumnarBatchVec& Value() const override {
return batches_;
}
size_t NumColumns() const { return schema_.columns.size(); }
size_t NumRows() const { return kAdapterUnknownSize; }
private:
XGDMatrixCallbackNext *next_callback_;
bool at_first_{true};
int nbatches_;
struct ArrowSchemaImporter schema_;
ArrowColumnarBatchVec batches_;
};
}  // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ADAPTER_H_

src/data/arrow-cdi.h (new file, 66 lines)

@@ -0,0 +1,66 @@
/* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
#include <cstdint>
#ifdef __cplusplus
extern "C" {
#endif
#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4
struct ArrowSchema {
// Array type description
const char* format;
const char* name;
const char* metadata;
int64_t flags;
int64_t n_children;
struct ArrowSchema** children;
struct ArrowSchema* dictionary;
// Release callback
void (*release)(struct ArrowSchema*);
// Opaque producer-specific data
void* private_data;
};
struct ArrowArray {
// Array data description
int64_t length;
int64_t null_count;
int64_t offset;
int64_t n_buffers;
int64_t n_children;
const void** buffers;
struct ArrowArray** children;
struct ArrowArray* dictionary;
// Release callback
void (*release)(struct ArrowArray*);
// Opaque producer-specific data
void* private_data;
};
#ifdef __cplusplus
}
#endif


@@ -1000,6 +1000,8 @@ template DMatrix *
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> *adapter,
float missing, int nthread, const std::string &cache_prefix);
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
SparsePage transpose;


@@ -249,5 +249,70 @@ template SimpleDMatrix::SimpleDMatrix(
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
*adapter,
float missing, int nthread);
template <>
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) {
auto& offset_vec = sparse_page_->offset.HostVector();
auto& data_vec = sparse_page_->data.HostVector();
uint64_t total_batch_size = 0;
uint64_t total_elements = 0;
adapter->BeforeFirst();
// Iterate over batches of input data
while (adapter->Next()) {
auto& batches = adapter->Value();
size_t num_elements = 0;
size_t num_rows = 0;
// Import Arrow RecordBatches
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(nthread)
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
num_elements += batches[i]->Import(missing);
num_rows += batches[i]->Size();
}
total_elements += num_elements;
total_batch_size += num_rows;
// Compute global offset for every row and starting row for every batch
std::vector<uint64_t> batch_offsets(batches.size());
for (size_t i = 0; i < batches.size(); ++i) {
if (i == 0) {
batch_offsets[i] = total_batch_size - num_rows;
batches[i]->ShiftRowOffsets(total_elements - num_elements);
} else {
batch_offsets[i] = batch_offsets[i - 1] + batches[i - 1]->Size();
batches[i]->ShiftRowOffsets(batches[i - 1]->RowOffsets().back());
}
}
// Pre-allocate DMatrix memory
data_vec.resize(total_elements);
offset_vec.resize(total_batch_size + 1);
// Copy data into DMatrix
#pragma omp parallel num_threads(nthread)
{
#pragma omp for nowait
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
size_t begin = batches[i]->RowOffsets()[0];
for (size_t k = 0; k < batches[i]->Size(); ++k) {
for (size_t j = 0; j < batches[i]->NumColumns(); ++j) {
auto element = batches[i]->GetColumn(j).GetElement(k);
if (!std::isnan(element.value)) {
data_vec[begin++] = Entry(element.column_idx, element.value);
}
}
}
}
#pragma omp for nowait
for (int i = 0; i < static_cast<int>(batches.size()); ++i) {
auto& offsets = batches[i]->RowOffsets();
std::copy(offsets.begin() + 1, offsets.end(), offset_vec.begin() + batch_offsets[i] + 1);
}
}
}
// Synchronise worker columns
info_.num_col_ = adapter->NumColumns();
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
info_.num_row_ = total_batch_size;
info_.num_nonzero_ = data_vec.size();
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
}
} // namespace data
} // namespace xgboost
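
The offset bookkeeping above is a standard CSR construction; a small worked
sketch of the same arithmetic for two hypothetical batches:

# Per-batch row offsets are local prefix sums over valid elements:
batch0 = [0, 2, 3]  # 2 rows, 3 valid elements
batch1 = [0, 1, 3]  # 2 rows, 3 valid elements
# ShiftRowOffsets adds the number of elements accumulated so far,
# turning batch1 into [3, 4, 6]; copying everything after each batch's
# leading entry then yields the global CSR offset vector.
shifted1 = [x + batch0[-1] for x in batch1]
offset_vec = batch0 + shifted1[1:]
assert offset_vec == [0, 2, 3, 4, 6]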


@@ -26,6 +26,8 @@ dependencies:
- awscli
- numba
- llvmlite
- cffi
- pyarrow
- pip:
- shap
- awscli


@@ -33,6 +33,8 @@ dependencies:
- numba
- llvmlite
- py-ubjson
- cffi
- pyarrow
- pip:
- shap
- ipython # required by shap at import time.


@@ -33,6 +33,8 @@ dependencies:
- boto3
- awscli
- py-ubjson
- cffi
- pyarrow
- pip:
- sphinx_rtd_theme
- datatable


@@ -15,7 +15,8 @@ dependencies:
- pytest
- jsonschema
- hypothesis
- python-graphviz
- pip
- py-ubjson
- cffi
- pyarrow


@@ -17,3 +17,5 @@ dependencies:
- modin-ray
- pip
- py-ubjson
- cffi
- pyarrow


@@ -0,0 +1,88 @@
import unittest
import pytest
import numpy as np
import testing as tm
import xgboost as xgb
import os
try:
import pyarrow as pa
import pyarrow.csv as pc
import pandas as pd
except ImportError:
pass
pytestmark = pytest.mark.skipif(
tm.no_arrow()["condition"] or tm.no_pandas()["condition"],
reason=tm.no_arrow()["reason"] + " or " + tm.no_pandas()["reason"],
)
dpath = "demo/data/"
class TestArrowTable(unittest.TestCase):
def test_arrow_table(self):
df = pd.DataFrame(
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
)
table = pa.Table.from_pandas(df)
dm = xgb.DMatrix(table)
assert dm.num_row() == 2
assert dm.num_col() == 4
def test_arrow_table_with_label(self):
df = pd.DataFrame([[1, 2.0, 3.0], [2, 3.0, 4.0]], columns=["a", "b", "c"])
table = pa.Table.from_pandas(df)
label = np.array([0, 1])
dm = xgb.DMatrix(table)
dm.set_label(label)
assert dm.num_row() == 2
assert dm.num_col() == 3
np.testing.assert_array_equal(dm.get_label(), np.array([0, 1]))
def test_arrow_table_from_np(self):
coldata = np.array(
[[1.0, 1.0, 0.0, 0.0], [2.0, 0.0, 1.0, 0.0], [3.0, 0.0, 0.0, 1.0]]
)
cols = list(map(pa.array, coldata))
table = pa.Table.from_arrays(cols, ["a", "b", "c"])
dm = xgb.DMatrix(table)
assert dm.num_row() == 4
assert dm.num_col() == 3
def test_arrow_train(self):
rows = 100
X = pd.DataFrame(
{
"A": np.random.randint(0, 10, size=rows),
"B": np.random.randn(rows),
"C": np.random.permutation([1, 0] * (rows // 2)),
}
)
y = pd.Series(np.random.randn(rows))
table = pa.Table.from_pandas(X)
dtrain1 = xgb.DMatrix(table)
dtrain1.set_label(y)
bst1 = xgb.train({}, dtrain1, num_boost_round=10)
preds1 = bst1.predict(xgb.DMatrix(X))
dtrain2 = xgb.DMatrix(X, y)
bst2 = xgb.train({}, dtrain2, num_boost_round=10)
preds2 = bst2.predict(xgb.DMatrix(X))
np.testing.assert_allclose(preds1, preds2)
def test_arrow_survival(self):
data = os.path.join(tm.PROJECT_ROOT, "demo", "data", "veterans_lung_cancer.csv")
table = pc.read_csv(data)
y_lower_bound = table["Survival_label_lower_bound"]
y_upper_bound = table["Survival_label_upper_bound"]
X = table.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"])
dtrain = xgb.DMatrix(
X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound
)
y_np_up = dtrain.get_float_info("label_upper_bound")
y_np_low = dtrain.get_float_info("label_lower_bound")
np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)


@@ -53,6 +53,15 @@ def no_pandas():
'reason': 'Pandas is not installed.'}
def no_arrow():
reason = "pyarrow is not installed"
try:
import pyarrow # noqa
return {"condition": False, "reason": reason}
except ImportError:
return {"condition": True, "reason": reason}
def no_modin():
reason = 'Modin is not installed.'
try: