Convert labels into tensor. (#7456)
* Add a new ctor to tensor for `initilizer_list`. * Change labels from host device vector to tensor. * Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
@@ -16,9 +16,9 @@ TEST(MetaInfo, GetSet) {
|
||||
|
||||
double double2[2] = {1.0, 2.0};
|
||||
|
||||
EXPECT_EQ(info.labels_.Size(), 0);
|
||||
EXPECT_EQ(info.labels.Size(), 0);
|
||||
info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2);
|
||||
EXPECT_EQ(info.labels_.Size(), 2);
|
||||
EXPECT_EQ(info.labels.Size(), 2);
|
||||
|
||||
float float2[2] = {1.0f, 2.0f};
|
||||
EXPECT_EQ(info.GetWeight(1), 1.0f)
|
||||
@@ -120,8 +120,8 @@ TEST(MetaInfo, SaveLoadBinary) {
|
||||
EXPECT_EQ(inforead.num_col_, info.num_col_);
|
||||
EXPECT_EQ(inforead.num_nonzero_, info.num_nonzero_);
|
||||
|
||||
ASSERT_EQ(inforead.labels_.HostVector(), values);
|
||||
EXPECT_EQ(inforead.labels_.HostVector(), info.labels_.HostVector());
|
||||
ASSERT_EQ(inforead.labels.Data()->HostVector(), values);
|
||||
EXPECT_EQ(inforead.labels.Data()->HostVector(), info.labels.Data()->HostVector());
|
||||
EXPECT_EQ(inforead.group_ptr_, info.group_ptr_);
|
||||
EXPECT_EQ(inforead.weights_.HostVector(), info.weights_.HostVector());
|
||||
|
||||
@@ -236,8 +236,9 @@ TEST(MetaInfo, Validate) {
|
||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||
|
||||
std::vector<float> labels(info.num_row_ + 1);
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
|
||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||
EXPECT_THROW(
|
||||
{ info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); },
|
||||
dmlc::Error);
|
||||
|
||||
// Make overflow data, which can happen when users pass group structure as int
|
||||
// or float.
|
||||
@@ -254,7 +255,7 @@ TEST(MetaInfo, Validate) {
|
||||
info.group_ptr_.clear();
|
||||
labels.resize(info.num_row_);
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
||||
info.labels_.SetDevice(0);
|
||||
info.labels.SetDevice(0);
|
||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||
|
||||
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
|
||||
@@ -269,12 +270,12 @@ TEST(MetaInfo, Validate) {
|
||||
TEST(MetaInfo, HostExtend) {
|
||||
xgboost::MetaInfo lhs, rhs;
|
||||
size_t const kRows = 100;
|
||||
lhs.labels_.Resize(kRows);
|
||||
lhs.labels.Reshape(kRows);
|
||||
lhs.num_row_ = kRows;
|
||||
rhs.labels_.Resize(kRows);
|
||||
rhs.labels.Reshape(kRows);
|
||||
rhs.num_row_ = kRows;
|
||||
ASSERT_TRUE(lhs.labels_.HostCanRead());
|
||||
ASSERT_TRUE(rhs.labels_.HostCanRead());
|
||||
ASSERT_TRUE(lhs.labels.Data()->HostCanRead());
|
||||
ASSERT_TRUE(rhs.labels.Data()->HostCanRead());
|
||||
|
||||
size_t per_group = 10;
|
||||
std::vector<xgboost::bst_group_t> groups;
|
||||
@@ -286,10 +287,10 @@ TEST(MetaInfo, HostExtend) {
|
||||
|
||||
lhs.Extend(rhs, true, true);
|
||||
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||
ASSERT_TRUE(lhs.labels_.HostCanRead());
|
||||
ASSERT_TRUE(rhs.labels_.HostCanRead());
|
||||
ASSERT_FALSE(lhs.labels_.DeviceCanRead());
|
||||
ASSERT_FALSE(rhs.labels_.DeviceCanRead());
|
||||
ASSERT_TRUE(lhs.labels.Data()->HostCanRead());
|
||||
ASSERT_TRUE(rhs.labels.Data()->HostCanRead());
|
||||
ASSERT_FALSE(lhs.labels.Data()->DeviceCanRead());
|
||||
ASSERT_FALSE(rhs.labels.Data()->DeviceCanRead());
|
||||
|
||||
ASSERT_EQ(lhs.group_ptr_.front(), 0);
|
||||
ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
|
||||
|
||||
Reference in New Issue
Block a user