Complete cudf support. (#4850)
* Handles missing value. * Accept all floating point and integer types. * Move to cudf 9.0 API. * Remove requirement on `null_count`. * Arbitrary column types support.
This commit is contained in:
@@ -7,11 +7,13 @@
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
TEST(MetaInfo, FromInterface) {
|
||||
cudaSetDevice(0);
|
||||
constexpr size_t kRows = 16;
|
||||
|
||||
thrust::device_vector<float> d_data(kRows);
|
||||
template <typename T>
|
||||
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out) {
|
||||
constexpr size_t kRows = 16;
|
||||
out->resize(kRows);
|
||||
auto& d_data = *out;
|
||||
|
||||
for (size_t i = 0; i < d_data.size(); ++i) {
|
||||
d_data[i] = i * 2.0;
|
||||
}
|
||||
@@ -22,7 +24,7 @@ TEST(MetaInfo, FromInterface) {
|
||||
column["shape"] = Array(j_shape);
|
||||
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
|
||||
column["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
column["typestr"] = String("<f4");
|
||||
column["typestr"] = String(typestr);
|
||||
|
||||
auto p_d_data = dh::Raw(d_data);
|
||||
std::vector<Json> j_data {
|
||||
@@ -34,6 +36,15 @@ TEST(MetaInfo, FromInterface) {
|
||||
Json::Dump(column, &ss);
|
||||
std::string str = ss.str();
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
TEST(MetaInfo, FromInterface) {
|
||||
cudaSetDevice(0);
|
||||
thrust::device_vector<float> d_data;
|
||||
|
||||
std::string str = PrepareData<float>("<f4", &d_data);
|
||||
|
||||
MetaInfo info;
|
||||
info.SetInfo("label", str.c_str());
|
||||
|
||||
@@ -53,5 +64,22 @@ TEST(MetaInfo, FromInterface) {
|
||||
for (size_t i = 0; i < d_data.size(); ++i) {
|
||||
ASSERT_EQ(h_base_margin[i], d_data[i]);
|
||||
}
|
||||
|
||||
EXPECT_ANY_THROW({info.SetInfo("group", str.c_str());});
|
||||
}
|
||||
|
||||
TEST(MetaInfo, Group) {
|
||||
cudaSetDevice(0);
|
||||
thrust::device_vector<uint32_t> d_data;
|
||||
std::string str = PrepareData<uint32_t>("<u4", &d_data);
|
||||
|
||||
MetaInfo info;
|
||||
|
||||
info.SetInfo("group", str.c_str());
|
||||
auto const& h_group = info.group_ptr_;
|
||||
ASSERT_EQ(h_group.size(), d_data.size() + 1);
|
||||
for (size_t i = 1; i < h_group.size(); ++i) {
|
||||
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user