Implement extend method for meta info. (#5800)

* Implement extend for host device vector.
This commit is contained in:
Jiaming Yuan 2020-06-20 03:32:03 +08:00 committed by GitHub
parent a6d9a06b7b
commit c4d721200a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 138 additions and 1 deletions

View File

@ -158,6 +158,16 @@ class MetaInfo {
*/ */
void SetInfo(const char* key, std::string const& interface_str); void SetInfo(const char* key, std::string const& interface_str);
/*
* \brief Extend with other MetaInfo.
*
* \param that The other MetaInfo object.
*
* \param accumulate_rows Whether rows need to be accumulated in this function. If
* client code knows number of rows in advance, set this parameter to false.
*/
void Extend(MetaInfo const& that, bool accumulate_rows);
private: private:
/*! \brief argsort of labels */ /*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_; mutable std::vector<size_t> label_order_cache_;

View File

@ -51,6 +51,7 @@
#include <initializer_list> #include <initializer_list>
#include <vector> #include <vector>
#include <type_traits>
#include "span.h" #include "span.h"
@ -83,6 +84,8 @@ enum GPUAccess {
template <typename T> template <typename T>
class HostDeviceVector { class HostDeviceVector {
static_assert(std::is_standard_layout<T>::value, "HostDeviceVector admits only POD types");
public: public:
explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1); explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
HostDeviceVector(std::initializer_list<T> init, int device = -1); HostDeviceVector(std::initializer_list<T> init, int device = -1);
@ -117,6 +120,8 @@ class HostDeviceVector {
void Copy(const std::vector<T>& other); void Copy(const std::vector<T>& other);
void Copy(std::initializer_list<T> other); void Copy(std::initializer_list<T> other);
void Extend(const HostDeviceVector<T>& other);
std::vector<T>& HostVector(); std::vector<T>& HostVector();
const std::vector<T>& ConstHostVector() const; const std::vector<T>& ConstHostVector() const;
const std::vector<T>& HostVector() const {return ConstHostVector(); } const std::vector<T>& HostVector() const {return ConstHostVector(); }

View File

@ -136,6 +136,14 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
std::copy(other.begin(), other.end(), HostVector().begin()); std::copy(other.begin(), other.end(), HostVector().begin());
} }
template <typename T>
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
auto ori_size = this->Size();
this->HostVector().resize(ori_size + other.Size());
std::copy(other.ConstHostVector().cbegin(), other.ConstHostVector().cend(),
this->HostVector().begin() + ori_size);
}
template <typename T> template <typename T>
bool HostDeviceVector<T>::HostCanRead() const { bool HostDeviceVector<T>::HostCanRead() const {
return true; return true;

View File

@ -125,6 +125,25 @@ class HostDeviceVectorImpl {
} }
} }
void Extend(HostDeviceVectorImpl* other) {
auto ori_size = this->Size();
this->Resize(ori_size + other->Size(), T());
if (HostCanWrite() && other->HostCanRead()) {
auto& h_vec = this->HostVector();
auto& other_vec = other->HostVector();
CHECK_EQ(h_vec.size(), ori_size + other->Size());
std::copy(other_vec.cbegin(), other_vec.cend(), h_vec.begin() + ori_size);
} else {
auto ptr = other->ConstDevicePointer();
SetDevice();
CHECK_EQ(this->DeviceIdx(), other->DeviceIdx());
dh::safe_cuda(cudaMemcpyAsync(this->DevicePointer() + ori_size,
ptr,
other->Size() * sizeof(T),
cudaMemcpyDeviceToDevice));
}
}
std::vector<T>& HostVector() { std::vector<T>& HostVector() {
LazySyncHost(GPUAccess::kNone); LazySyncHost(GPUAccess::kNone);
return data_h_; return data_h_;
@ -326,6 +345,11 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
impl_->Copy(other); impl_->Copy(other);
} }
template <typename T>
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
impl_->Extend(other.impl_);
}
template <typename T> template <typename T>
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); } std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); }

View File

@ -343,6 +343,44 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
} }
} }
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
if (accumulate_rows) {
this->num_row_ += that.num_row_;
}
if (this->num_col_ != 0) {
CHECK_EQ(this->num_col_, that.num_col_)
<< "Number of columns must be consistent across batches.";
}
this->num_col_ = that.num_col_;
this->labels_.SetDevice(that.labels_.DeviceIdx());
this->labels_.Extend(that.labels_);
this->weights_.SetDevice(that.weights_.DeviceIdx());
this->weights_.Extend(that.weights_);
this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
this->labels_lower_bound_.Extend(that.labels_lower_bound_);
this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
this->labels_upper_bound_.Extend(that.labels_upper_bound_);
this->base_margin_.SetDevice(that.base_margin_.DeviceIdx());
this->base_margin_.Extend(that.base_margin_);
if (this->group_ptr_.size() == 0) {
this->group_ptr_ = that.group_ptr_;
} else {
CHECK_NE(that.group_ptr_.size(), 0);
auto group_ptr = that.group_ptr_;
for (size_t i = 1; i < group_ptr.size(); ++i) {
group_ptr[i] += this->group_ptr_.back();
}
this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
group_ptr.end());
}
}
void MetaInfo::Validate(int32_t device) const { void MetaInfo::Validate(int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) { if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)

View File

@ -163,3 +163,35 @@ TEST(MetaInfo, Validate) {
EXPECT_THROW(info.Validate(1), dmlc::Error); EXPECT_THROW(info.Validate(1), dmlc::Error);
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)
} }
TEST(MetaInfo, HostExtend) {
xgboost::MetaInfo lhs, rhs;
size_t const kRows = 100;
lhs.labels_.Resize(kRows);
lhs.num_row_ = kRows;
rhs.labels_.Resize(kRows);
rhs.num_row_ = kRows;
ASSERT_TRUE(lhs.labels_.HostCanRead());
ASSERT_TRUE(rhs.labels_.HostCanRead());
size_t per_group = 10;
std::vector<xgboost::bst_group_t> groups;
for (size_t g = 0; g < kRows / per_group; ++g) {
groups.emplace_back(per_group);
}
lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
lhs.Extend(rhs, true);
ASSERT_EQ(lhs.num_row_, kRows * 2);
ASSERT_TRUE(lhs.labels_.HostCanRead());
ASSERT_TRUE(rhs.labels_.HostCanRead());
ASSERT_FALSE(lhs.labels_.DeviceCanRead());
ASSERT_FALSE(rhs.labels_.DeviceCanRead());
ASSERT_EQ(lhs.group_ptr_.front(), 0);
ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
for (size_t i = 0; i < kRows * 2 / per_group; ++i) {
ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i);
}
}

View File

@ -77,7 +77,6 @@ TEST(MetaInfo, FromInterface) {
TEST(MetaInfo, Group) { TEST(MetaInfo, Group) {
cudaSetDevice(0); cudaSetDevice(0);
MetaInfo info; MetaInfo info;
thrust::device_vector<uint32_t> d_uint; thrust::device_vector<uint32_t> d_uint;
@ -105,4 +104,25 @@ TEST(MetaInfo, Group) {
info = MetaInfo(); info = MetaInfo();
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str())); EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
} }
TEST(MetaInfo, DeviceExtend) {
dh::safe_cuda(cudaSetDevice(0));
size_t const kRows = 100;
MetaInfo lhs, rhs;
thrust::device_vector<float> d_data;
std::string str = PrepareData<float>("<f4", &d_data, kRows);
lhs.SetInfo("label", str.c_str());
rhs.SetInfo("label", str.c_str());
ASSERT_FALSE(rhs.labels_.HostCanRead());
lhs.num_row_ = kRows;
rhs.num_row_ = kRows;
lhs.Extend(rhs, true);
ASSERT_EQ(lhs.num_row_, kRows * 2);
ASSERT_FALSE(lhs.labels_.HostCanRead());
ASSERT_FALSE(lhs.labels_.HostCanRead());
ASSERT_FALSE(rhs.labels_.HostCanRead());
}
} // namespace xgboost } // namespace xgboost