Implement extend method for meta info. (#5800)
* Implement extend for host device vector.
This commit is contained in:
@@ -163,3 +163,35 @@ TEST(MetaInfo, Validate) {
|
||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
TEST(MetaInfo, HostExtend) {
|
||||
xgboost::MetaInfo lhs, rhs;
|
||||
size_t const kRows = 100;
|
||||
lhs.labels_.Resize(kRows);
|
||||
lhs.num_row_ = kRows;
|
||||
rhs.labels_.Resize(kRows);
|
||||
rhs.num_row_ = kRows;
|
||||
ASSERT_TRUE(lhs.labels_.HostCanRead());
|
||||
ASSERT_TRUE(rhs.labels_.HostCanRead());
|
||||
|
||||
size_t per_group = 10;
|
||||
std::vector<xgboost::bst_group_t> groups;
|
||||
for (size_t g = 0; g < kRows / per_group; ++g) {
|
||||
groups.emplace_back(per_group);
|
||||
}
|
||||
lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||
rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||
|
||||
lhs.Extend(rhs, true);
|
||||
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||
ASSERT_TRUE(lhs.labels_.HostCanRead());
|
||||
ASSERT_TRUE(rhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(lhs.labels_.DeviceCanRead());
|
||||
ASSERT_FALSE(rhs.labels_.DeviceCanRead());
|
||||
|
||||
ASSERT_EQ(lhs.group_ptr_.front(), 0);
|
||||
ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
|
||||
for (size_t i = 0; i < kRows * 2 / per_group; ++i) {
|
||||
ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,6 @@ TEST(MetaInfo, FromInterface) {
|
||||
|
||||
TEST(MetaInfo, Group) {
|
||||
cudaSetDevice(0);
|
||||
|
||||
MetaInfo info;
|
||||
|
||||
thrust::device_vector<uint32_t> d_uint;
|
||||
@@ -105,4 +104,25 @@ TEST(MetaInfo, Group) {
|
||||
info = MetaInfo();
|
||||
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
|
||||
}
|
||||
|
||||
TEST(MetaInfo, DeviceExtend) {
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
size_t const kRows = 100;
|
||||
MetaInfo lhs, rhs;
|
||||
|
||||
thrust::device_vector<float> d_data;
|
||||
std::string str = PrepareData<float>("<f4", &d_data, kRows);
|
||||
lhs.SetInfo("label", str.c_str());
|
||||
rhs.SetInfo("label", str.c_str());
|
||||
ASSERT_FALSE(rhs.labels_.HostCanRead());
|
||||
lhs.num_row_ = kRows;
|
||||
rhs.num_row_ = kRows;
|
||||
|
||||
lhs.Extend(rhs, true);
|
||||
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||
ASSERT_FALSE(lhs.labels_.HostCanRead());
|
||||
|
||||
ASSERT_FALSE(lhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(rhs.labels_.HostCanRead());
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user