Support dmatrix construction from cupy array (#5206)
This commit is contained in:
@@ -9,8 +9,7 @@
|
||||
namespace xgboost {
|
||||
|
||||
template <typename T>
|
||||
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out) {
|
||||
constexpr size_t kRows = 16;
|
||||
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, const size_t kRows=16) {
|
||||
out->resize(kRows);
|
||||
auto& d_data = *out;
|
||||
|
||||
@@ -66,7 +65,15 @@ TEST(MetaInfo, FromInterface) {
|
||||
ASSERT_EQ(h_base_margin[i], d_data[i]);
|
||||
}
|
||||
|
||||
EXPECT_ANY_THROW({info.SetInfo("group", str.c_str());});
|
||||
thrust::device_vector<int> d_group_data;
|
||||
std::string group_str = PrepareData<int>("<i4", &d_group_data, 4);
|
||||
d_group_data[0] = 4;
|
||||
d_group_data[1] = 3;
|
||||
d_group_data[2] = 2;
|
||||
d_group_data[3] = 1;
|
||||
info.SetInfo("group", group_str.c_str());
|
||||
std::vector<bst_group_t> expected_group_ptr = {0, 4, 7, 9, 10};
|
||||
EXPECT_EQ(info.group_ptr_, expected_group_ptr);
|
||||
}
|
||||
|
||||
TEST(MetaInfo, Group) {
|
||||
@@ -83,4 +90,4 @@ TEST(MetaInfo, Group) {
|
||||
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user