Support column major array. (#6765)
This commit is contained in:
parent
f6fe15d11f
commit
4ee8340e79
@ -432,7 +432,6 @@ def _transform_cupy_array(data):
|
||||
data, '__array__'):
|
||||
import cupy # pylint: disable=import-error
|
||||
data = cupy.array(data, copy=False)
|
||||
data = data.astype(dtype=data.dtype, order='C', copy=False)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@ -234,6 +234,7 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
class Line {
|
||||
ArrayInterface array_interface_;
|
||||
size_t ridx_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface array_interface, size_t ridx)
|
||||
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
|
||||
@ -241,15 +242,14 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
size_t Size() const { return array_interface_.num_cols; }
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, idx, array_interface_.GetElement(idx)};
|
||||
return {ridx_, idx, array_interface_.GetElement(ridx_, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ArrayAdapterBatch() = default;
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto line = array_interface_.SliceRow(idx);
|
||||
return Line{line, idx};
|
||||
return Line{array_interface_, idx};
|
||||
}
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
@ -286,14 +286,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t ridx_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx,
|
||||
size_t offset)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
|
||||
offset_{offset} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
|
||||
return {ridx_, indices_.GetElement<size_t>(offset_ + idx, 0),
|
||||
values_.GetElement(offset_ + idx, 0)};
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
return values_.num_rows * values_.num_cols;
|
||||
}
|
||||
@ -304,7 +309,11 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
values_{std::move(values)} {
|
||||
indptr_.AsColumnVector();
|
||||
values_.AsColumnVector();
|
||||
indices_.AsColumnVector();
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||
@ -313,15 +322,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
auto indices = indices_.SliceOffset(begin_offset);
|
||||
auto values = values_.SliceOffset(begin_offset);
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1, 0);
|
||||
|
||||
auto indices = indices_;
|
||||
auto values = values_;
|
||||
|
||||
values.num_cols = end_offset - begin_offset;
|
||||
values.num_rows = 1;
|
||||
|
||||
indices.num_cols = values.num_cols;
|
||||
indices.num_rows = values.num_rows;
|
||||
return Line{indices, values, idx};
|
||||
|
||||
return Line{indices, values, idx, begin_offset};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <map>
|
||||
#include <string>
|
||||
@ -40,7 +41,7 @@ struct ArrayInterfaceErrors {
|
||||
return str.c_str();
|
||||
}
|
||||
static char const* Version() {
|
||||
return "Only version 1 of `__cuda_array_interface__' is supported.";
|
||||
return "Only version 1 and 2 of `__cuda_array_interface__' are supported.";
|
||||
}
|
||||
static char const* OfType(std::string const& type) {
|
||||
static std::string str;
|
||||
@ -191,43 +192,46 @@ class ArrayInterfaceHandler {
|
||||
std::map<std::string, Json> const& column) {
|
||||
auto j_shape = get<Array const>(column.at("shape"));
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
if (column.find("strides") != column.cend()) {
|
||||
if (!IsA<Null>(column.at("strides"))) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
CHECK_EQ(strides.size(), j_shape.size())
|
||||
<< ArrayInterfaceErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
|
||||
<< ArrayInterfaceErrors::Contigious();
|
||||
}
|
||||
}
|
||||
|
||||
if (j_shape.size() == 1) {
|
||||
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
|
||||
} else {
|
||||
CHECK_EQ(j_shape.size(), 2)
|
||||
<< "Only 1D or 2-D arrays currently supported.";
|
||||
CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported.";
|
||||
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
|
||||
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||
static void ExtractStride(std::map<std::string, Json> const &column,
|
||||
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
|
||||
auto strides_it = column.find("strides");
|
||||
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
|
||||
// default strides
|
||||
strides[0] = cols;
|
||||
strides[1] = 1;
|
||||
} else {
|
||||
// strides specified by the array interface
|
||||
auto const &j_strides = get<Array const>(strides_it->second);
|
||||
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
|
||||
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
|
||||
size_t n = 1;
|
||||
if (j_strides.size() == 2) {
|
||||
n = get<Integer const>(j_strides[1]) / itemsize;
|
||||
}
|
||||
strides[1] = n;
|
||||
}
|
||||
auto valid = (rows - 1) * strides[0] + (cols - 1) * strides[1] == (rows * cols) - 1;
|
||||
CHECK(valid) << "Invalid strides in array.";
|
||||
}
|
||||
|
||||
static void* ExtractData(std::map<std::string, Json> const &column,
|
||||
StringView typestr,
|
||||
std::pair<size_t, size_t> shape) {
|
||||
Validate(column);
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
CHECK_EQ(typestr.at(1), TypeChar<T>())
|
||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||
|
||||
auto shape = ExtractShape(column);
|
||||
|
||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||
void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||
if (!p_data) {
|
||||
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
|
||||
}
|
||||
return common::Span<T>{p_data, shape.first * shape.second};
|
||||
return p_data;
|
||||
}
|
||||
};
|
||||
|
||||
@ -236,11 +240,15 @@ class ArrayInterface {
|
||||
void Initialize(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
ArrayInterfaceHandler::Validate(column);
|
||||
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
this->AssignType(StringView{typestr});
|
||||
|
||||
auto shape = ArrayInterfaceHandler::ExtractShape(column);
|
||||
num_rows = shape.first;
|
||||
num_cols = shape.second;
|
||||
|
||||
data = ArrayInterfaceHandler::ExtractData(column, StringView{typestr}, shape);
|
||||
|
||||
if (allow_mask) {
|
||||
common::Span<RBitField8::value_type> s_mask;
|
||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
||||
@ -257,8 +265,8 @@ class ArrayInterface {
|
||||
<< "Masked array is not yet supported.";
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
this->AssignType(StringView{typestr});
|
||||
ArrayInterfaceHandler::ExtractStride(column, strides, num_rows, num_cols,
|
||||
typestr[2] - '0');
|
||||
}
|
||||
|
||||
public:
|
||||
@ -288,6 +296,15 @@ class ArrayInterface {
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Reinterpret a row vector or a column vector as an n x 1 column vector,
 *        so that element i is read with GetElement(i, 0).
 *
 * Fails a CHECK when neither num_rows nor num_cols equals 1.
 */
void AsColumnVector() {
  CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix.";
  if (num_rows == 1) {
    // Row vector (or a single element): consecutive elements are one column step
    // apart, so the column stride becomes the new row stride.  Taking the larger
    // of the two strides would pick the row stride (num_cols by default) and
    // make GetElement(i, 0) read p[num_cols * i].
    num_rows = num_cols;
    strides[0] = strides[1];
  }
  // Otherwise the array already is n x 1 (n may be 0) and the row stride is
  // correct as is; an empty array keeps num_rows == 0.
  num_cols = 1;
  strides[1] = 1;
}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = kF4;
|
||||
@ -320,95 +337,45 @@ class ArrayInterface {
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return func(reinterpret_cast<float *>(data));
|
||||
break;
|
||||
case kF8:
|
||||
return func(reinterpret_cast<double *>(data));
|
||||
break;
|
||||
case kI1:
|
||||
return func(reinterpret_cast<int8_t *>(data));
|
||||
break;
|
||||
case kI2:
|
||||
return func(reinterpret_cast<int16_t *>(data));
|
||||
break;
|
||||
case kI4:
|
||||
return func(reinterpret_cast<int32_t *>(data));
|
||||
break;
|
||||
case kI8:
|
||||
return func(reinterpret_cast<int64_t *>(data));
|
||||
break;
|
||||
case kU1:
|
||||
return func(reinterpret_cast<uint8_t *>(data));
|
||||
break;
|
||||
case kU2:
|
||||
return func(reinterpret_cast<uint16_t *>(data));
|
||||
break;
|
||||
case kU4:
|
||||
return func(reinterpret_cast<uint32_t *>(data));
|
||||
break;
|
||||
case kU8:
|
||||
return func(reinterpret_cast<uint64_t *>(data));
|
||||
break;
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return func(reinterpret_cast<uint64_t *>(data));
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||
void* p_values{nullptr};
|
||||
this->DispatchCall([&p_values, offset](auto *ptr) {
|
||||
p_values = ptr + offset;
|
||||
});
|
||||
|
||||
ArrayInterface ret = *this;
|
||||
ret.data = p_values;
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
|
||||
size_t offset = idx * num_cols;
|
||||
auto ret = this->SliceOffset(offset);
|
||||
ret.num_rows = 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t idx) const {
|
||||
SPAN_CHECK(idx < num_cols * num_rows);
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
case kF8:
|
||||
return reinterpret_cast<double*>(data)[idx];
|
||||
case kI1:
|
||||
return reinterpret_cast<int8_t*>(data)[idx];
|
||||
case kI2:
|
||||
return reinterpret_cast<int16_t*>(data)[idx];
|
||||
case kI4:
|
||||
return reinterpret_cast<int32_t*>(data)[idx];
|
||||
case kI8:
|
||||
return reinterpret_cast<int64_t*>(data)[idx];
|
||||
case kU1:
|
||||
return reinterpret_cast<uint8_t*>(data)[idx];
|
||||
case kU2:
|
||||
return reinterpret_cast<uint16_t*>(data)[idx];
|
||||
case kU4:
|
||||
return reinterpret_cast<uint32_t*>(data)[idx];
|
||||
case kU8:
|
||||
return reinterpret_cast<uint64_t*>(data)[idx];
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
}
|
||||
|
||||
/*! \brief Size in bytes of one element of the stored type, found via DispatchCall. */
XGBOOST_DEVICE size_t ElementSize() {
  // sizeof is unevaluated, so the typed pointer is never dereferenced here.
  return this->DispatchCall([](auto* p_values) { return sizeof(*p_values); });
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
|
||||
return this->DispatchCall(
|
||||
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
size_t strides[2]{0, 0};
|
||||
void* data;
|
||||
|
||||
Type type;
|
||||
|
||||
@ -30,7 +30,7 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
|
||||
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
|
||||
|
||||
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
|
||||
p_dst[idx] = column.GetElement(idx);
|
||||
p_dst[idx] = column.GetElement(idx, 0);
|
||||
});
|
||||
}
|
||||
|
||||
@ -53,7 +53,7 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
auto d_tmp = temp.data();
|
||||
|
||||
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
|
||||
d_tmp[idx] = column.GetElement<size_t>(idx);
|
||||
d_tmp[idx] = column.GetElement<size_t>(idx, 0);
|
||||
});
|
||||
auto length = column.num_rows;
|
||||
out->resize(length + 1);
|
||||
@ -62,6 +62,50 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
std::partial_sum(out->begin(), out->end(), out->begin());
|
||||
}
|
||||
|
||||
/*!
 * \brief Convert a device column of query ids into group pointers.
 *
 * \param array_interface Column of query ids, read as GetElement<uint32_t>(i, 0).
 * \param p_group_ptr     Output.  Cleared, then filled with a leading 0 followed
 *                        by the inclusive scan of the run lengths of equal ids.
 *
 * Fails a CHECK when an id is greater than the id that follows it.
 */
void CopyQidImpl(ArrayInterface array_interface,
                 std::vector<bst_group_t> *p_group_ptr) {
  auto &group_ptr_ = *p_group_ptr;
  if (array_interface.num_rows == 0) {
    // No ids means no runs.  Returning here also keeps `num_rows - 1` below from
    // wrapping around to a huge launch size.
    group_ptr_.clear();
    group_ptr_.resize(1, 0);
    return;
  }

  auto it = dh::MakeTransformIterator<uint32_t>(
      thrust::make_counting_iterator(0ul),
      [array_interface] __device__(size_t i) {
        return array_interface.GetElement<uint32_t>(i, 0);
      });

  // Single device flag: set to true first, then cleared by any adjacent pair
  // that is out of order.
  dh::caching_device_vector<bool> flag(1);
  auto d_flag = dh::ToSpan(flag);
  auto d = SetDeviceToPtr(array_interface.data);
  dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
  dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
    if (array_interface.GetElement<uint32_t>(i, 0) >
        array_interface.GetElement<uint32_t>(i + 1, 0)) {
      d_flag[0] = false;
    }
  });
  bool non_dec = true;
  dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
                           cudaMemcpyDeviceToHost));
  CHECK(non_dec) << "`qid` must be sorted in increasing order along with data.";

  size_t bytes = 0;
  dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
  dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
  HostDeviceVector<int> d_num_runs_out(1, 0, d);
  // First call with a null buffer only reports the temporary storage size.
  cub::DeviceRunLengthEncode::Encode(
      nullptr, bytes, it, out.begin(), cnt.begin(),
      d_num_runs_out.DevicePointer(), array_interface.num_rows);
  dh::caching_device_vector<char> tmp(bytes);
  cub::DeviceRunLengthEncode::Encode(
      tmp.data().get(), bytes, it, out.begin(), cnt.begin(),
      d_num_runs_out.DevicePointer(), array_interface.num_rows);

  auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
  group_ptr_.clear();
  group_ptr_.resize(h_num_runs_out + 1, 0);
  // Turn run lengths into end offsets in place, then copy them after the
  // leading 0.
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
                         cnt.begin() + h_num_runs_out, cnt.begin());
  thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
               group_ptr_.begin() + 1);
}
|
||||
|
||||
namespace {
|
||||
// thrust::all_of tries to copy lambda function.
|
||||
struct AllOfOp {
|
||||
@ -78,10 +122,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
|
||||
ArrayInterface array_interface(interface_str);
|
||||
std::string key{c_key};
|
||||
array_interface.AsColumnVector();
|
||||
CHECK(!array_interface.valid.Data())
|
||||
<< "Meta info " << key << " should be dense, found validity mask";
|
||||
CHECK_EQ(array_interface.num_cols, 1)
|
||||
<< "Meta info should be a single column.";
|
||||
if (array_interface.num_rows == 0) {
|
||||
return;
|
||||
}
|
||||
@ -100,45 +143,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
return;
|
||||
} else if (key == "qid") {
|
||||
auto it = dh::MakeTransformIterator<uint32_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[array_interface] __device__(size_t i) {
|
||||
return array_interface.GetElement<uint32_t>(i);
|
||||
});
|
||||
dh::caching_device_vector<bool> flag(1);
|
||||
auto d_flag = dh::ToSpan(flag);
|
||||
auto d = SetDeviceToPtr(array_interface.data);
|
||||
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
|
||||
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
|
||||
if (array_interface.GetElement<uint32_t>(i) >
|
||||
array_interface.GetElement<uint32_t>(i + 1)) {
|
||||
d_flag[0] = false;
|
||||
}
|
||||
});
|
||||
bool non_dec = true;
|
||||
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
|
||||
cudaMemcpyDeviceToHost));
|
||||
CHECK(non_dec)
|
||||
<< "`qid` must be sorted in increasing order along with data.";
|
||||
size_t bytes = 0;
|
||||
dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
|
||||
dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
|
||||
HostDeviceVector<int> d_num_runs_out(1, 0, d);
|
||||
cub::DeviceRunLengthEncode::Encode(nullptr, bytes, it, out.begin(),
|
||||
cnt.begin(), d_num_runs_out.DevicePointer(),
|
||||
array_interface.num_rows);
|
||||
dh::caching_device_vector<char> tmp(bytes);
|
||||
cub::DeviceRunLengthEncode::Encode(tmp.data().get(), bytes, it, out.begin(),
|
||||
cnt.begin(), d_num_runs_out.DevicePointer(),
|
||||
array_interface.num_rows);
|
||||
|
||||
auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
|
||||
group_ptr_.clear(); group_ptr_.resize(h_num_runs_out + 1, 0);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
|
||||
cnt.begin() + h_num_runs_out, cnt.begin());
|
||||
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
|
||||
group_ptr_.begin() + 1);
|
||||
CopyQidImpl(array_interface, &group_ptr_);
|
||||
return;
|
||||
} else if (key == "label_lower_bound") {
|
||||
CopyInfoImpl(array_interface, &labels_lower_bound_);
|
||||
|
||||
@ -47,7 +47,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
size_t row_idx = idx / columns_.size();
|
||||
auto const& column = columns_[column_idx];
|
||||
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
|
||||
? column.GetElement(row_idx)
|
||||
? column.GetElement(row_idx, 0)
|
||||
: std::numeric_limits<float>::quiet_NaN();
|
||||
return {row_idx, column_idx, value};
|
||||
}
|
||||
@ -170,7 +170,7 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
|
||||
__device__ COOTuple GetElement(size_t idx) const {
|
||||
size_t column_idx = idx % array_interface_.num_cols;
|
||||
size_t row_idx = idx / array_interface_.num_cols;
|
||||
float value = array_interface_.GetElement(idx);
|
||||
float value = array_interface_.GetElement(row_idx, column_idx);
|
||||
return {row_idx, column_idx, value};
|
||||
}
|
||||
|
||||
|
||||
@ -138,5 +138,4 @@ TEST(Adapter, IteratorAdapter) {
|
||||
}
|
||||
ASSERT_EQ(num_batch, 1);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -38,17 +38,28 @@ TEST(ArrayInterface, Error) {
|
||||
Json(Boolean(false))};
|
||||
|
||||
auto const& column_obj = get<Object>(column);
|
||||
std::pair<size_t, size_t> shape{kRows, kCols};
|
||||
std::string typestr{"<f4"};
|
||||
|
||||
// missing version
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
|
||||
StringView{typestr}, shape),
|
||||
dmlc::Error);
|
||||
column["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
// missing data
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
|
||||
StringView{typestr}, shape),
|
||||
dmlc::Error);
|
||||
column["data"] = j_data;
|
||||
// missing typestr
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
|
||||
StringView{typestr}, shape),
|
||||
dmlc::Error);
|
||||
column["typestr"] = String("<f4");
|
||||
// nullptr is not valid
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj,
|
||||
StringView{typestr}, shape),
|
||||
dmlc::Error);
|
||||
|
||||
HostDeviceVector<float> storage;
|
||||
auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
|
||||
@ -56,7 +67,23 @@ TEST(ArrayInterface, Error) {
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(storage.ConstHostPointer()))),
|
||||
Json(Boolean(false))};
|
||||
column["data"] = j_data;
|
||||
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj));
|
||||
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(
|
||||
column_obj, StringView{typestr}, shape));
|
||||
}
|
||||
|
||||
// Every (row, column) read through ArrayInterface::GetElement must equal the
// value stored at row * kCols + column in the generated host buffer.
TEST(ArrayInterface, GetElement) {
  size_t kRows = 4, kCols = 2;
  HostDeviceVector<float> storage;
  auto interface_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
  ArrayInterface array_interface{interface_str};

  auto const& h_storage = storage.ConstHostVector();
  for (size_t flat = 0; flat < kRows * kCols; ++flat) {
    size_t const row = flat / kCols;
    size_t const col = flat % kCols;
    float from_interface = array_interface.GetElement(row, col);
    float from_storage = h_storage.at(flat);
    ASSERT_EQ(from_interface, from_storage);
  }
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -210,9 +210,13 @@ class TestGPUPredict:
|
||||
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
|
||||
|
||||
def predict_df(x):
|
||||
inplace_predt = booster.inplace_predict(x)
|
||||
# column major array
|
||||
inplace_predt = booster.inplace_predict(x.values)
|
||||
d = xgb.DMatrix(x)
|
||||
copied_predt = cp.array(booster.predict(d))
|
||||
assert cp.all(copied_predt == inplace_predt)
|
||||
|
||||
inplace_predt = booster.inplace_predict(x)
|
||||
return cp.all(copied_predt == inplace_predt)
|
||||
|
||||
for i in range(10):
|
||||
|
||||
@ -2,7 +2,10 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
from scipy import sparse
|
||||
import pytest
|
||||
import pandas as pd
|
||||
|
||||
import testing as tm
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
@ -147,6 +150,19 @@ class TestInplacePredict:
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, self.rows, predict_csr)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
def test_predict_pd(self):
    """In-place prediction on a DataFrame must match ndarray and DMatrix input."""
    features = self.X
    model = self.booster

    # Assemble the frame one column at a time so it is built in column major
    # style.
    columns = {}
    for col in range(features.shape[1]):
        columns[str(col)] = features[:, col]
    frame = pd.DataFrame(columns)

    from_frame = model.inplace_predict(frame)
    from_array = model.inplace_predict(features)
    from_dmatrix = model.predict(xgb.DMatrix(features))

    np.testing.assert_allclose(from_dmatrix, from_array)
    np.testing.assert_allclose(from_frame, from_array)
|
||||
|
||||
def test_base_margin(self):
|
||||
booster = self.booster
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user