Support column major array. (#6765)

This commit is contained in:
Jiaming Yuan 2021-03-20 05:19:46 +08:00 committed by GitHub
parent f6fe15d11f
commit 4ee8340e79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 181 additions and 151 deletions

View File

@ -432,7 +432,6 @@ def _transform_cupy_array(data):
data, '__array__'): data, '__array__'):
import cupy # pylint: disable=import-error import cupy # pylint: disable=import-error
data = cupy.array(data, copy=False) data = cupy.array(data, copy=False)
data = data.astype(dtype=data.dtype, order='C', copy=False)
return data return data

View File

@ -234,6 +234,7 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
class Line { class Line {
ArrayInterface array_interface_; ArrayInterface array_interface_;
size_t ridx_; size_t ridx_;
public: public:
Line(ArrayInterface array_interface, size_t ridx) Line(ArrayInterface array_interface, size_t ridx)
: array_interface_{std::move(array_interface)}, ridx_{ridx} {} : array_interface_{std::move(array_interface)}, ridx_{ridx} {}
@ -241,15 +242,14 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
size_t Size() const { return array_interface_.num_cols; } size_t Size() const { return array_interface_.num_cols; }
COOTuple GetElement(size_t idx) const { COOTuple GetElement(size_t idx) const {
return {ridx_, idx, array_interface_.GetElement(idx)}; return {ridx_, idx, array_interface_.GetElement(ridx_, idx)};
} }
}; };
public: public:
ArrayAdapterBatch() = default; ArrayAdapterBatch() = default;
Line const GetLine(size_t idx) const { Line const GetLine(size_t idx) const {
auto line = array_interface_.SliceRow(idx); return Line{array_interface_, idx};
return Line{line, idx};
} }
explicit ArrayAdapterBatch(ArrayInterface array_interface) explicit ArrayAdapterBatch(ArrayInterface array_interface)
@ -286,14 +286,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indices_; ArrayInterface indices_;
ArrayInterface values_; ArrayInterface values_;
size_t ridx_; size_t ridx_;
size_t offset_;
public: public:
Line(ArrayInterface indices, ArrayInterface values, size_t ridx) Line(ArrayInterface indices, ArrayInterface values, size_t ridx,
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {} size_t offset)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
offset_{offset} {}
COOTuple GetElement(size_t idx) const { COOTuple GetElement(size_t idx) const {
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)}; return {ridx_, indices_.GetElement<size_t>(offset_ + idx, 0),
values_.GetElement(offset_ + idx, 0)};
} }
size_t Size() const { size_t Size() const {
return values_.num_rows * values_.num_cols; return values_.num_rows * values_.num_cols;
} }
@ -304,7 +309,11 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices, CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values) ArrayInterface values)
: indptr_{std::move(indptr)}, indices_{std::move(indices)}, : indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {} values_{std::move(values)} {
indptr_.AsColumnVector();
values_.AsColumnVector();
indices_.AsColumnVector();
}
size_t Size() const { size_t Size() const {
size_t size = indptr_.num_rows * indptr_.num_cols; size_t size = indptr_.num_rows * indptr_.num_cols;
@ -313,15 +322,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
} }
Line const GetLine(size_t idx) const { Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx); auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
auto end_offset = indptr_.GetElement<size_t>(idx + 1); auto end_offset = indptr_.GetElement<size_t>(idx + 1, 0);
auto indices = indices_.SliceOffset(begin_offset);
auto values = values_.SliceOffset(begin_offset); auto indices = indices_;
auto values = values_;
values.num_cols = end_offset - begin_offset; values.num_cols = end_offset - begin_offset;
values.num_rows = 1; values.num_rows = 1;
indices.num_cols = values.num_cols; indices.num_cols = values.num_cols;
indices.num_rows = values.num_rows; indices.num_rows = values.num_rows;
return Line{indices, values, idx};
return Line{indices, values, idx, begin_offset};
} }
}; };

View File

@ -6,6 +6,7 @@
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_ #ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_ #define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <algorithm>
#include <cinttypes> #include <cinttypes>
#include <map> #include <map>
#include <string> #include <string>
@ -40,7 +41,7 @@ struct ArrayInterfaceErrors {
return str.c_str(); return str.c_str();
} }
static char const* Version() { static char const* Version() {
return "Only version 1 of `__cuda_array_interface__' is supported."; return "Only version 1 and 2 of `__cuda_array_interface__' are supported.";
} }
static char const* OfType(std::string const& type) { static char const* OfType(std::string const& type) {
static std::string str; static std::string str;
@ -191,43 +192,46 @@ class ArrayInterfaceHandler {
std::map<std::string, Json> const& column) { std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape")); auto j_shape = get<Array const>(column.at("shape"));
auto typestr = get<String const>(column.at("typestr")); auto typestr = get<String const>(column.at("typestr"));
if (column.find("strides") != column.cend()) {
if (!IsA<Null>(column.at("strides"))) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), j_shape.size())
<< ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
<< ArrayInterfaceErrors::Contigious();
}
}
if (j_shape.size() == 1) { if (j_shape.size() == 1) {
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1}; return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
} else { } else {
CHECK_EQ(j_shape.size(), 2) CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported.";
<< "Only 1D or 2-D arrays currently supported.";
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))}; static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
} }
} }
template <typename T> static void ExtractStride(std::map<std::string, Json> const &column,
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) { size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
auto strides_it = column.find("strides");
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
// default strides
strides[0] = cols;
strides[1] = 1;
} else {
// strides specified by the array interface
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
size_t n = 1;
if (j_strides.size() == 2) {
n = get<Integer const>(j_strides[1]) / itemsize;
}
strides[1] = n;
}
auto valid = (rows - 1) * strides[0] + (cols - 1) * strides[1] == (rows * cols) - 1;
CHECK(valid) << "Invalid strides in array.";
}
static void* ExtractData(std::map<std::string, Json> const &column,
StringView typestr,
std::pair<size_t, size_t> shape) {
Validate(column); Validate(column);
void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
auto typestr = get<String const>(column.at("typestr"));
CHECK_EQ(typestr.at(1), TypeChar<T>())
<< "Input data type and typestr mismatch. typestr: " << typestr;
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
<< "Input data type and typestr mismatch. typestr: " << typestr;
auto shape = ExtractShape(column);
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
if (!p_data) { if (!p_data) {
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape."; CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
} }
return common::Span<T>{p_data, shape.first * shape.second}; return p_data;
} }
}; };
@ -236,11 +240,15 @@ class ArrayInterface {
void Initialize(std::map<std::string, Json> const &column, void Initialize(std::map<std::string, Json> const &column,
bool allow_mask = true) { bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column); ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column); auto typestr = get<String const>(column.at("typestr"));
this->AssignType(StringView{typestr});
auto shape = ArrayInterfaceHandler::ExtractShape(column); auto shape = ArrayInterfaceHandler::ExtractShape(column);
num_rows = shape.first; num_rows = shape.first;
num_cols = shape.second; num_cols = shape.second;
data = ArrayInterfaceHandler::ExtractData(column, StringView{typestr}, shape);
if (allow_mask) { if (allow_mask) {
common::Span<RBitField8::value_type> s_mask; common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask); size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
@ -257,8 +265,8 @@ class ArrayInterface {
<< "Masked array is not yet supported."; << "Masked array is not yet supported.";
} }
auto typestr = get<String const>(column.at("typestr")); ArrayInterfaceHandler::ExtractStride(column, strides, num_rows, num_cols,
this->AssignType(StringView{typestr}); typestr[2] - '0');
} }
public: public:
@ -288,6 +296,15 @@ class ArrayInterface {
} }
} }
// Reinterpret a 1 x n or n x 1 array as an n x 1 column vector, in place.
//
// Aborts through CHECK when neither dimension equals 1.  After the call every element is
// addressed as (i, 0), so only strides[0] is used to locate data.
void AsColumnVector() {
  CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix.";
  if (num_rows == 1) {
    // Row vector: consecutive elements are separated by the column stride.  Taking the
    // larger of the two strides would select the row stride instead (num_cols for a
    // C-contiguous 1 x n array) and index past the end of the buffer.
    strides[0] = strides[1];
  }
  // Otherwise the array already is n x 1 and strides[0] is the distance between
  // consecutive elements; it must be kept even when strides[1] is larger (F-ordered input).
  num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
  num_cols = 1;
  strides[1] = 1;
}
void AssignType(StringView typestr) { void AssignType(StringView typestr) {
if (typestr[1] == 'f' && typestr[2] == '4') { if (typestr[1] == 'f' && typestr[2] == '4') {
type = kF4; type = kF4;
@ -320,95 +337,45 @@ class ArrayInterface {
switch (type) { switch (type) {
case kF4: case kF4:
return func(reinterpret_cast<float *>(data)); return func(reinterpret_cast<float *>(data));
break;
case kF8: case kF8:
return func(reinterpret_cast<double *>(data)); return func(reinterpret_cast<double *>(data));
break;
case kI1: case kI1:
return func(reinterpret_cast<int8_t *>(data)); return func(reinterpret_cast<int8_t *>(data));
break;
case kI2: case kI2:
return func(reinterpret_cast<int16_t *>(data)); return func(reinterpret_cast<int16_t *>(data));
break;
case kI4: case kI4:
return func(reinterpret_cast<int32_t *>(data)); return func(reinterpret_cast<int32_t *>(data));
break;
case kI8: case kI8:
return func(reinterpret_cast<int64_t *>(data)); return func(reinterpret_cast<int64_t *>(data));
break;
case kU1: case kU1:
return func(reinterpret_cast<uint8_t *>(data)); return func(reinterpret_cast<uint8_t *>(data));
break;
case kU2: case kU2:
return func(reinterpret_cast<uint16_t *>(data)); return func(reinterpret_cast<uint16_t *>(data));
break;
case kU4: case kU4:
return func(reinterpret_cast<uint32_t *>(data)); return func(reinterpret_cast<uint32_t *>(data));
break;
case kU8: case kU8:
return func(reinterpret_cast<uint64_t *>(data)); return func(reinterpret_cast<uint64_t *>(data));
break;
} }
SPAN_CHECK(false); SPAN_CHECK(false);
return func(reinterpret_cast<uint64_t *>(data)); return func(reinterpret_cast<uint64_t *>(data));
} }
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
void* p_values{nullptr};
this->DispatchCall([&p_values, offset](auto *ptr) {
p_values = ptr + offset;
});
ArrayInterface ret = *this;
ret.data = p_values;
return ret;
}
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
size_t offset = idx * num_cols;
auto ret = this->SliceOffset(offset);
ret.num_rows = 1;
return ret;
}
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows);
switch (type) {
case kF4:
return reinterpret_cast<float*>(data)[idx];
case kF8:
return reinterpret_cast<double*>(data)[idx];
case kI1:
return reinterpret_cast<int8_t*>(data)[idx];
case kI2:
return reinterpret_cast<int16_t*>(data)[idx];
case kI4:
return reinterpret_cast<int32_t*>(data)[idx];
case kI8:
return reinterpret_cast<int64_t*>(data)[idx];
case kU1:
return reinterpret_cast<uint8_t*>(data)[idx];
case kU2:
return reinterpret_cast<uint16_t*>(data)[idx];
case kU4:
return reinterpret_cast<uint32_t*>(data)[idx];
case kU8:
return reinterpret_cast<uint64_t*>(data)[idx];
}
SPAN_CHECK(false);
return reinterpret_cast<float*>(data)[idx];
}
XGBOOST_DEVICE size_t ElementSize() { XGBOOST_DEVICE size_t ElementSize() {
return this->DispatchCall([](auto* p_values) { return this->DispatchCall([](auto* p_values) {
return sizeof(std::remove_pointer_t<decltype(p_values)>); return sizeof(std::remove_pointer_t<decltype(p_values)>);
}); });
} }
// Fetch the value stored at row `r`, column `c` and return it converted to `T`.
// The flat position is derived from the per-axis strides (kept in units of elements);
// DispatchCall supplies the data pointer cast to the array's actual element type.
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
  size_t const offset = strides[0] * r + strides[1] * c;
  return this->DispatchCall([offset](auto *p_values) -> T { return p_values[offset]; });
}
RBitField8 valid; RBitField8 valid;
bst_row_t num_rows; bst_row_t num_rows;
bst_feature_t num_cols; bst_feature_t num_cols;
size_t strides[2]{0, 0};
void* data; void* data;
Type type; Type type;

View File

@ -30,7 +30,7 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
auto p_dst = thrust::device_pointer_cast(out->DevicePointer()); auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) { dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
p_dst[idx] = column.GetElement(idx); p_dst[idx] = column.GetElement(idx, 0);
}); });
} }
@ -53,7 +53,7 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
auto d_tmp = temp.data(); auto d_tmp = temp.data();
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) { dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
d_tmp[idx] = column.GetElement<size_t>(idx); d_tmp[idx] = column.GetElement<size_t>(idx, 0);
}); });
auto length = column.num_rows; auto length = column.num_rows;
out->resize(length + 1); out->resize(length + 1);
@ -62,6 +62,50 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
std::partial_sum(out->begin(), out->end(), out->begin()); std::partial_sum(out->begin(), out->end(), out->begin());
} }
// Build group boundaries from a `qid` array.
//
// Steps, all driven from the host:
//   1. verify on the device that the ids never decrease from one row to the next,
//   2. run-length encode the ids to obtain the size of every run of equal ids,
//   3. store the inclusive prefix sum of those sizes in *p_group_ptr, after a leading 0.
//
// \param array_interface  ids, read as column 0 of each row and converted to uint32_t.
// \param p_group_ptr      output; cleared and resized to (number of runs + 1).
void CopyQidImpl(ArrayInterface array_interface,
std::vector<bst_group_t> *p_group_ptr) {
auto &group_ptr_ = *p_group_ptr;
// Iterator yielding element (i, 0) of the input as uint32_t; consumed by the encoder below.
auto it = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0ul),
[array_interface] __device__(size_t i) {
return array_interface.GetElement<uint32_t>(i, 0);
});
// One-element device flag: initialised to true by the first launch, then cleared by the
// second launch for any adjacent pair that is out of order.
dh::caching_device_vector<bool> flag(1);
auto d_flag = dh::ToSpan(flag);
auto d = SetDeviceToPtr(array_interface.data);
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
// NOTE(review): `num_rows - 1` presumably wraps around when num_rows is 0 (unsigned type);
// the SetInfo caller returns early on empty input -- confirm for any other caller.
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
if (array_interface.GetElement<uint32_t>(i, 0) >
array_interface.GetElement<uint32_t>(i + 1, 0)) {
d_flag[0] = false;
}
});
// Bring the flag back to the host and abort if any pair was decreasing.
bool non_dec = true;
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
cudaMemcpyDeviceToHost));
CHECK(non_dec) << "`qid` must be sorted in increasing order along with data.";
size_t bytes = 0;
// `out` receives the unique ids and `cnt` the length of each run; both are sized for the
// worst case of one run per row.
dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
HostDeviceVector<int> d_num_runs_out(1, 0, d);
// First call passes a null scratch pointer: per the CUB API this only reports the required
// temporary storage size through `bytes`.
cub::DeviceRunLengthEncode::Encode(
nullptr, bytes, it, out.begin(), cnt.begin(),
d_num_runs_out.DevicePointer(), array_interface.num_rows);
// Second call, with scratch space of the reported size, performs the encoding.
dh::caching_device_vector<char> tmp(bytes);
cub::DeviceRunLengthEncode::Encode(
tmp.data().get(), bytes, it, out.begin(), cnt.begin(),
d_num_runs_out.DevicePointer(), array_interface.num_rows);
auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
group_ptr_.clear();
group_ptr_.resize(h_num_runs_out + 1, 0);
// Convert run lengths into cumulative end offsets in place, then copy them behind the
// leading 0 of the output.
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1);
}
namespace { namespace {
// thrust::all_of tries to copy lambda function. // thrust::all_of tries to copy lambda function.
struct AllOfOp { struct AllOfOp {
@ -78,10 +122,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1); << "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
ArrayInterface array_interface(interface_str); ArrayInterface array_interface(interface_str);
std::string key{c_key}; std::string key{c_key};
array_interface.AsColumnVector();
CHECK(!array_interface.valid.Data()) CHECK(!array_interface.valid.Data())
<< "Meta info " << key << " should be dense, found validity mask"; << "Meta info " << key << " should be dense, found validity mask";
CHECK_EQ(array_interface.num_cols, 1)
<< "Meta info should be a single column.";
if (array_interface.num_rows == 0) { if (array_interface.num_rows == 0) {
return; return;
} }
@ -100,45 +143,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
CopyGroupInfoImpl(array_interface, &group_ptr_); CopyGroupInfoImpl(array_interface, &group_ptr_);
return; return;
} else if (key == "qid") { } else if (key == "qid") {
auto it = dh::MakeTransformIterator<uint32_t>( CopyQidImpl(array_interface, &group_ptr_);
thrust::make_counting_iterator(0ul),
[array_interface] __device__(size_t i) {
return array_interface.GetElement<uint32_t>(i);
});
dh::caching_device_vector<bool> flag(1);
auto d_flag = dh::ToSpan(flag);
auto d = SetDeviceToPtr(array_interface.data);
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
if (array_interface.GetElement<uint32_t>(i) >
array_interface.GetElement<uint32_t>(i + 1)) {
d_flag[0] = false;
}
});
bool non_dec = true;
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
cudaMemcpyDeviceToHost));
CHECK(non_dec)
<< "`qid` must be sorted in increasing order along with data.";
size_t bytes = 0;
dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
HostDeviceVector<int> d_num_runs_out(1, 0, d);
cub::DeviceRunLengthEncode::Encode(nullptr, bytes, it, out.begin(),
cnt.begin(), d_num_runs_out.DevicePointer(),
array_interface.num_rows);
dh::caching_device_vector<char> tmp(bytes);
cub::DeviceRunLengthEncode::Encode(tmp.data().get(), bytes, it, out.begin(),
cnt.begin(), d_num_runs_out.DevicePointer(),
array_interface.num_rows);
auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
group_ptr_.clear(); group_ptr_.resize(h_num_runs_out + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1);
return; return;
} else if (key == "label_lower_bound") { } else if (key == "label_lower_bound") {
CopyInfoImpl(array_interface, &labels_lower_bound_); CopyInfoImpl(array_interface, &labels_lower_bound_);

View File

@ -47,7 +47,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
size_t row_idx = idx / columns_.size(); size_t row_idx = idx / columns_.size();
auto const& column = columns_[column_idx]; auto const& column = columns_[column_idx];
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx) float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
? column.GetElement(row_idx) ? column.GetElement(row_idx, 0)
: std::numeric_limits<float>::quiet_NaN(); : std::numeric_limits<float>::quiet_NaN();
return {row_idx, column_idx, value}; return {row_idx, column_idx, value};
} }
@ -170,7 +170,7 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
__device__ COOTuple GetElement(size_t idx) const { __device__ COOTuple GetElement(size_t idx) const {
size_t column_idx = idx % array_interface_.num_cols; size_t column_idx = idx % array_interface_.num_cols;
size_t row_idx = idx / array_interface_.num_cols; size_t row_idx = idx / array_interface_.num_cols;
float value = array_interface_.GetElement(idx); float value = array_interface_.GetElement(row_idx, column_idx);
return {row_idx, column_idx, value}; return {row_idx, column_idx, value};
} }

View File

@ -138,5 +138,4 @@ TEST(Adapter, IteratorAdapter) {
} }
ASSERT_EQ(num_batch, 1); ASSERT_EQ(num_batch, 1);
} }
} // namespace xgboost } // namespace xgboost

View File

@ -38,17 +38,28 @@ TEST(ArrayInterface, Error) {
Json(Boolean(false))}; Json(Boolean(false))};
auto const& column_obj = get<Object>(column); auto const& column_obj = get<Object>(column);
std::pair<size_t, size_t> shape{kRows, kCols};
std::string typestr{"<f4"};
// missing version // missing version
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error); EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
StringView{typestr}, shape),
dmlc::Error);
column["version"] = Integer(static_cast<Integer::Int>(1)); column["version"] = Integer(static_cast<Integer::Int>(1));
// missing data // missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error); EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
StringView{typestr}, shape),
dmlc::Error);
column["data"] = j_data; column["data"] = j_data;
// missing typestr // missing typestr
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error); EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
StringView{typestr}, shape),
dmlc::Error);
column["typestr"] = String("<f4"); column["typestr"] = String("<f4");
// nullptr is not valid // nullptr is not valid
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error); EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
StringView{typestr}, shape),
dmlc::Error);
HostDeviceVector<float> storage; HostDeviceVector<float> storage;
auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
@ -56,7 +67,23 @@ TEST(ArrayInterface, Error) {
Json(Integer(reinterpret_cast<Integer::Int>(storage.ConstHostPointer()))), Json(Integer(reinterpret_cast<Integer::Int>(storage.ConstHostPointer()))),
Json(Boolean(false))}; Json(Boolean(false))};
column["data"] = j_data; column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj)); EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(
column_obj, StringView{typestr}, shape));
} }
// Every (row, col) lookup through ArrayInterface::GetElement must equal the value found at
// row * n_cols + col in the host storage that backs the generated array.
TEST(ArrayInterface, GetElement) {
  size_t n_rows = 4;
  size_t n_cols = 2;
  HostDeviceVector<float> storage;
  auto interface_str = RandomDataGenerator{n_rows, n_cols, 0}.GenerateArrayInterface(&storage);
  ArrayInterface array_interface{interface_str};

  auto const& expected = storage.ConstHostVector();
  // Walk the buffer by flat index and recover the 2-D coordinates from it.
  for (size_t flat = 0; flat < n_rows * n_cols; ++flat) {
    size_t const row = flat / n_cols;
    size_t const col = flat % n_cols;
    float got = array_interface.GetElement(row, col);
    float want = expected.at(flat);
    ASSERT_EQ(got, want);
  }
}
} // namespace xgboost } // namespace xgboost

View File

@ -210,9 +210,13 @@ class TestGPUPredict:
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix) cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
def predict_df(x): def predict_df(x):
inplace_predt = booster.inplace_predict(x) # column major array
inplace_predt = booster.inplace_predict(x.values)
d = xgb.DMatrix(x) d = xgb.DMatrix(x)
copied_predt = cp.array(booster.predict(d)) copied_predt = cp.array(booster.predict(d))
assert cp.all(copied_predt == inplace_predt)
inplace_predt = booster.inplace_predict(x)
return cp.all(copied_predt == inplace_predt) return cp.all(copied_predt == inplace_predt)
for i in range(10): for i in range(10):

View File

@ -2,7 +2,10 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import numpy as np import numpy as np
from scipy import sparse from scipy import sparse
import pytest
import pandas as pd
import testing as tm
import xgboost as xgb import xgboost as xgb
@ -147,6 +150,19 @@ class TestInplacePredict:
for i in range(10): for i in range(10):
run_threaded_predict(X, self.rows, predict_csr) run_threaded_predict(X, self.rows, predict_csr)
@pytest.mark.skipif(**tm.no_pandas())
def test_predict_pd(self):
    """In-place prediction on a DataFrame must match the ndarray and DMatrix paths.

    The frame is assembled one column at a time from ``self.X``, so it is
    constructed in column major style.
    """
    features = self.X
    # construct it in column major style
    columns = {}
    for col in range(features.shape[1]):
        columns[str(col)] = features[:, col]
    frame = pd.DataFrame(columns)
    model = self.booster

    from_frame = model.inplace_predict(frame)
    from_array = model.inplace_predict(features)
    from_dmatrix = model.predict(xgb.DMatrix(features))

    np.testing.assert_allclose(from_dmatrix, from_array)
    np.testing.assert_allclose(from_frame, from_array)
def test_base_margin(self): def test_base_margin(self):
booster = self.booster booster = self.booster