Convert labels into tensor. (#7456)

* Add a new ctor to tensor for `initilizer_list`.
* Change labels from host device vector to tensor.
* Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
Jiaming Yuan
2021-12-17 00:58:35 +08:00
committed by GitHub
parent 6f8a4633b7
commit 5b1161bb64
35 changed files with 319 additions and 258 deletions

View File

@@ -52,10 +52,10 @@ TEST(MetaInfo, FromInterface) {
MetaInfo info;
info.SetInfo("label", str.c_str());
auto const& h_label = info.labels_.HostVector();
ASSERT_EQ(h_label.size(), d_data.size());
auto const& h_label = info.labels.HostView();
ASSERT_EQ(h_label.Size(), d_data.size());
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_label[i], d_data[i]);
ASSERT_EQ(h_label(i), d_data[i]);
}
info.SetInfo("weight", str.c_str());
@@ -147,15 +147,15 @@ TEST(MetaInfo, DeviceExtend) {
std::string str = PrepareData<float>("<f4", &d_data, kRows);
lhs.SetInfo("label", str.c_str());
rhs.SetInfo("label", str.c_str());
ASSERT_FALSE(rhs.labels_.HostCanRead());
ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
lhs.num_row_ = kRows;
rhs.num_row_ = kRows;
lhs.Extend(rhs, true, true);
ASSERT_EQ(lhs.num_row_, kRows * 2);
ASSERT_FALSE(lhs.labels_.HostCanRead());
ASSERT_FALSE(lhs.labels.Data()->HostCanRead());
ASSERT_FALSE(lhs.labels_.HostCanRead());
ASSERT_FALSE(rhs.labels_.HostCanRead());
ASSERT_FALSE(lhs.labels.Data()->HostCanRead());
ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
}
} // namespace xgboost