Implement extend method for meta info. (#5800)

* Implement extend for host device vector.
This commit is contained in:
Jiaming Yuan
2020-06-20 03:32:03 +08:00
committed by GitHub
parent a6d9a06b7b
commit c4d721200a
7 changed files with 138 additions and 1 deletions

View File

@@ -163,3 +163,35 @@ TEST(MetaInfo, Validate) {
EXPECT_THROW(info.Validate(1), dmlc::Error);
#endif // defined(XGBOOST_USE_CUDA)
}
TEST(MetaInfo, HostExtend) {
xgboost::MetaInfo lhs, rhs;
size_t const kRows = 100;
lhs.labels_.Resize(kRows);
lhs.num_row_ = kRows;
rhs.labels_.Resize(kRows);
rhs.num_row_ = kRows;
ASSERT_TRUE(lhs.labels_.HostCanRead());
ASSERT_TRUE(rhs.labels_.HostCanRead());
size_t per_group = 10;
std::vector<xgboost::bst_group_t> groups;
for (size_t g = 0; g < kRows / per_group; ++g) {
groups.emplace_back(per_group);
}
lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
lhs.Extend(rhs, true);
ASSERT_EQ(lhs.num_row_, kRows * 2);
ASSERT_TRUE(lhs.labels_.HostCanRead());
ASSERT_TRUE(rhs.labels_.HostCanRead());
ASSERT_FALSE(lhs.labels_.DeviceCanRead());
ASSERT_FALSE(rhs.labels_.DeviceCanRead());
ASSERT_EQ(lhs.group_ptr_.front(), 0);
ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
for (size_t i = 0; i < kRows * 2 / per_group; ++i) {
ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i);
}
}