From c4d721200ae134539f1c77c2d8d9dea30744309c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 20 Jun 2020 03:32:03 +0800 Subject: [PATCH] Implement extend method for meta info. (#5800) * Implement extend for host device vector. --- include/xgboost/data.h | 10 ++++++++ include/xgboost/host_device_vector.h | 5 ++++ src/common/host_device_vector.cc | 8 ++++++ src/common/host_device_vector.cu | 24 ++++++++++++++++++ src/data/data.cc | 38 ++++++++++++++++++++++++++++ tests/cpp/data/test_metainfo.cc | 32 +++++++++++++++++++++++ tests/cpp/data/test_metainfo.cu | 22 +++++++++++++++- 7 files changed, 138 insertions(+), 1 deletion(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index d7ab895a8..a0674f96b 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -158,6 +158,16 @@ class MetaInfo { */ void SetInfo(const char* key, std::string const& interface_str); + /*! + * \brief Extend with other MetaInfo. + * + * \param that The other MetaInfo object. + * + * \param accumulate_rows Whether rows need to be accumulated in this function. If + * client code knows the number of rows in advance, set this parameter to false. + */ + void Extend(MetaInfo const& that, bool accumulate_rows); + private: /*! 
\brief argsort of labels */
   mutable std::vector label_order_cache_;
diff --git a/include/xgboost/host_device_vector.h b/include/xgboost/host_device_vector.h
index 5bcceddcd..b9fb15104 100644
--- a/include/xgboost/host_device_vector.h
+++ b/include/xgboost/host_device_vector.h
@@ -51,6 +51,7 @@
 #include
 #include
+#include
 #include "span.h"
@@ -83,6 +84,8 @@ enum GPUAccess {
 template
 class HostDeviceVector {
+  static_assert(std::is_standard_layout::value, "HostDeviceVector admits only POD types");
+
  public:
   explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
   HostDeviceVector(std::initializer_list init, int device = -1);
@@ -117,6 +120,8 @@ class HostDeviceVector {
   void Copy(const std::vector& other);
   void Copy(std::initializer_list other);
+  void Extend(const HostDeviceVector& other);
+
   std::vector& HostVector();
   const std::vector& ConstHostVector() const;
   const std::vector& HostVector() const {return ConstHostVector(); }
diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc
index 1dd9997fa..a6ee30e1f 100644
--- a/src/common/host_device_vector.cc
+++ b/src/common/host_device_vector.cc
@@ -136,6 +136,18 @@ void HostDeviceVector::Copy(std::initializer_list other) {
   std::copy(other.begin(), other.end(), HostVector().begin());
 }
+// Append a copy of the elements of `other` after the current contents.
+template
+void HostDeviceVector::Extend(HostDeviceVector const& other) {
+  auto ori_size = this->Size();
+  // Read the source length before resizing: when `other` is this object, the resized
+  // vector would otherwise be copied in full and overrun the destination.
+  auto other_size = other.Size();
+  this->HostVector().resize(ori_size + other_size);
+  auto first = other.ConstHostVector().cbegin();
+  std::copy(first, first + other_size, this->HostVector().begin() + ori_size);
+}
+
 template
 bool HostDeviceVector::HostCanRead() const {
   return true;
diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 6afdadf39..7950096ca 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -125,6 +125,38 @@ class HostDeviceVectorImpl {
     }
   }
+  /*
+   * \brief Append the elements of `other` after the current contents.
+   *
+   * The copy runs on the host when HostCanWrite() and other->HostCanRead() both hold;
+   * otherwise it is a device-to-device copy, which CHECKs that both vectors report the
+   * same device index.
+   */
+  void Extend(HostDeviceVectorImpl* other) {
+    auto ori_size = this->Size();
+    // Read the source length once, before growing: when `other` is this object its
+    // size changes after Resize(), and the copy would run past the destination.
+    auto other_size = other->Size();
+    this->Resize(ori_size + other_size, T());
+    if (HostCanWrite() && other->HostCanRead()) {
+      auto& h_vec = this->HostVector();
+      // NOTE(review): HostVector() calls LazySyncHost(GPUAccess::kNone) on the source;
+      // confirm that is intended for an argument the public API takes by const reference.
+      auto& other_vec = other->HostVector();
+      CHECK_EQ(h_vec.size(), ori_size + other_size);
+      std::copy(other_vec.cbegin(), other_vec.cbegin() + other_size,
+                h_vec.begin() + ori_size);
+    } else {
+      auto ptr = other->ConstDevicePointer();
+      SetDevice();
+      CHECK_EQ(this->DeviceIdx(), other->DeviceIdx());
+      dh::safe_cuda(cudaMemcpyAsync(this->DevicePointer() + ori_size,
+                                    ptr,
+                                    other_size * sizeof(T),
+                                    cudaMemcpyDeviceToDevice));
+    }
+  }
+
   std::vector& HostVector() {
     LazySyncHost(GPUAccess::kNone);
     return data_h_;
@@ -326,6 +358,11 @@ void HostDeviceVector::Copy(std::initializer_list other) {
   impl_->Copy(other);
 }
+template
+void HostDeviceVector::Extend(HostDeviceVector const& other) {
+  impl_->Extend(other.impl_);
+}
+
 template
 std::vector& HostDeviceVector::HostVector() { return impl_->HostVector(); }
diff --git a/src/data/data.cc b/src/data/data.cc
index c70f6e9a1..f24753e31 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -343,6 +343,52 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
   }
 }
+/*
+ * Append the meta info of `that` to this object: labels, weights, label bounds, base
+ * margin and group pointers, plus the row count when `accumulate_rows` is true. The
+ * column counts are CHECKed to be equal unless this object's `num_col_` is 0.
+ */
+void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
+  if (accumulate_rows) {
+    this->num_row_ += that.num_row_;
+  }
+  if (this->num_col_ != 0) {
+    CHECK_EQ(this->num_col_, that.num_col_)
+        << "Number of columns must be consistent across batches.";
+  }
+  this->num_col_ = that.num_col_;
+
+  this->labels_.SetDevice(that.labels_.DeviceIdx());
+  this->labels_.Extend(that.labels_);
+
+  this->weights_.SetDevice(that.weights_.DeviceIdx());
+  this->weights_.Extend(that.weights_);
+
+  this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
+  this->labels_lower_bound_.Extend(that.labels_lower_bound_);
+
+  this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
+  this->labels_upper_bound_.Extend(that.labels_upper_bound_);
+
+  this->base_margin_.SetDevice(that.base_margin_.DeviceIdx());
+  this->base_margin_.Extend(that.base_margin_);
+
+  if (this->group_ptr_.size() == 0) {
+    this->group_ptr_ = that.group_ptr_;
+  } else {
+    CHECK_NE(that.group_ptr_.size(), 0);
+    // Rebase the incoming offsets onto the end of the existing ones. Index 0 of the
+    // incoming pointers is neither shifted nor inserted.
+    auto const offset = this->group_ptr_.back();
+    auto group_ptr = that.group_ptr_;
+    for (size_t i = 1; i < group_ptr.size(); ++i) {
+      group_ptr[i] += offset;
+    }
+    this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
+                            group_ptr.end());
+  }
+}
+
 void MetaInfo::Validate(int32_t device) const {
   if (group_ptr_.size() != 0 && weights_.Size() != 0) {
     CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc
index 31ec7fe44..74002b75a 100644
--- a/tests/cpp/data/test_metainfo.cc
+++ b/tests/cpp/data/test_metainfo.cc
@@ -163,3 +163,37 @@ TEST(MetaInfo, Validate) {
   EXPECT_THROW(info.Validate(1), dmlc::Error);
 #endif  // defined(XGBOOST_USE_CUDA)
 }
+
+TEST(MetaInfo, HostExtend) {
+  xgboost::MetaInfo lhs, rhs;
+  size_t const kRows = 100;
+  lhs.labels_.Resize(kRows);
+  lhs.num_row_ = kRows;
+  rhs.labels_.Resize(kRows);
+  rhs.num_row_ = kRows;
+  ASSERT_TRUE(lhs.labels_.HostCanRead());
+  ASSERT_TRUE(rhs.labels_.HostCanRead());
+
+  size_t per_group = 10;
+  std::vector groups;
+  for (size_t g = 0; g < kRows / per_group; ++g) {
+    groups.emplace_back(per_group);
+  }
+  lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
+  rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
+
+  lhs.Extend(rhs, true);
+  ASSERT_EQ(lhs.num_row_, kRows * 2);
+  ASSERT_EQ(lhs.labels_.Size(), kRows * 2);
+  ASSERT_TRUE(lhs.labels_.HostCanRead());
+  ASSERT_TRUE(rhs.labels_.HostCanRead());
+  ASSERT_FALSE(lhs.labels_.DeviceCanRead());
+  ASSERT_FALSE(rhs.labels_.DeviceCanRead());
+
+  ASSERT_EQ(lhs.group_ptr_.front(), 0);
+  ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
+  ASSERT_EQ(lhs.group_ptr_.size(), kRows * 2 / per_group + 1);
+  for (size_t i = 0; i < kRows * 2 / per_group; ++i) {
+    ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i);
+  }
+}
diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu
index 
23cb0f243..ca688dcab 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -77,7 +77,6 @@ TEST(MetaInfo, FromInterface) { TEST(MetaInfo, Group) { cudaSetDevice(0); - MetaInfo info; thrust::device_vector d_uint; @@ -105,4 +104,25 @@ TEST(MetaInfo, Group) { info = MetaInfo(); EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str())); } + +TEST(MetaInfo, DeviceExtend) { + dh::safe_cuda(cudaSetDevice(0)); + size_t const kRows = 100; + MetaInfo lhs, rhs; + + thrust::device_vector d_data; + std::string str = PrepareData("