Convert labels into tensor. (#7456)
* Add a new ctor to tensor for `initilizer_list`. * Change labels from host device vector to tensor. * Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
@@ -52,10 +52,10 @@ TEST(MetaInfo, FromInterface) {
|
||||
MetaInfo info;
|
||||
info.SetInfo("label", str.c_str());
|
||||
|
||||
auto const& h_label = info.labels_.HostVector();
|
||||
ASSERT_EQ(h_label.size(), d_data.size());
|
||||
auto const& h_label = info.labels.HostView();
|
||||
ASSERT_EQ(h_label.Size(), d_data.size());
|
||||
for (size_t i = 0; i < d_data.size(); ++i) {
|
||||
ASSERT_EQ(h_label[i], d_data[i]);
|
||||
ASSERT_EQ(h_label(i), d_data[i]);
|
||||
}
|
||||
|
||||
info.SetInfo("weight", str.c_str());
|
||||
@@ -147,15 +147,15 @@ TEST(MetaInfo, DeviceExtend) {
|
||||
std::string str = PrepareData<float>("<f4", &d_data, kRows);
|
||||
lhs.SetInfo("label", str.c_str());
|
||||
rhs.SetInfo("label", str.c_str());
|
||||
ASSERT_FALSE(rhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
|
||||
lhs.num_row_ = kRows;
|
||||
rhs.num_row_ = kRows;
|
||||
|
||||
lhs.Extend(rhs, true, true);
|
||||
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||
ASSERT_FALSE(lhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(lhs.labels.Data()->HostCanRead());
|
||||
|
||||
ASSERT_FALSE(lhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(rhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(lhs.labels.Data()->HostCanRead());
|
||||
ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user