Implement extend method for meta info. (#5800)
* Implement extend for host device vector.
This commit is contained in:
parent
a6d9a06b7b
commit
c4d721200a
@ -158,6 +158,16 @@ class MetaInfo {
|
|||||||
*/
|
*/
|
||||||
void SetInfo(const char* key, std::string const& interface_str);
|
void SetInfo(const char* key, std::string const& interface_str);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* \brief Extend with other MetaInfo.
|
||||||
|
*
|
||||||
|
* \param that The other MetaInfo object.
|
||||||
|
*
|
||||||
|
* \param accumulate_rows Whether rows need to be accumulated in this function. If
|
||||||
|
* client code knows number of rows in advance, set this parameter to false.
|
||||||
|
*/
|
||||||
|
void Extend(MetaInfo const& that, bool accumulate_rows);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief argsort of labels */
|
/*! \brief argsort of labels */
|
||||||
mutable std::vector<size_t> label_order_cache_;
|
mutable std::vector<size_t> label_order_cache_;
|
||||||
|
|||||||
@ -51,6 +51,7 @@
|
|||||||
|
|
||||||
#include <initializer_list>
|
#include <initializer_list>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "span.h"
|
#include "span.h"
|
||||||
|
|
||||||
@ -83,6 +84,8 @@ enum GPUAccess {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class HostDeviceVector {
|
class HostDeviceVector {
|
||||||
|
static_assert(std::is_standard_layout<T>::value, "HostDeviceVector admits only POD types");
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
|
explicit HostDeviceVector(size_t size = 0, T v = T(), int device = -1);
|
||||||
HostDeviceVector(std::initializer_list<T> init, int device = -1);
|
HostDeviceVector(std::initializer_list<T> init, int device = -1);
|
||||||
@ -117,6 +120,8 @@ class HostDeviceVector {
|
|||||||
void Copy(const std::vector<T>& other);
|
void Copy(const std::vector<T>& other);
|
||||||
void Copy(std::initializer_list<T> other);
|
void Copy(std::initializer_list<T> other);
|
||||||
|
|
||||||
|
void Extend(const HostDeviceVector<T>& other);
|
||||||
|
|
||||||
std::vector<T>& HostVector();
|
std::vector<T>& HostVector();
|
||||||
const std::vector<T>& ConstHostVector() const;
|
const std::vector<T>& ConstHostVector() const;
|
||||||
const std::vector<T>& HostVector() const {return ConstHostVector(); }
|
const std::vector<T>& HostVector() const {return ConstHostVector(); }
|
||||||
|
|||||||
@ -136,6 +136,14 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
|
|||||||
std::copy(other.begin(), other.end(), HostVector().begin());
|
std::copy(other.begin(), other.end(), HostVector().begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
|
||||||
|
auto ori_size = this->Size();
|
||||||
|
this->HostVector().resize(ori_size + other.Size());
|
||||||
|
std::copy(other.ConstHostVector().cbegin(), other.ConstHostVector().cend(),
|
||||||
|
this->HostVector().begin() + ori_size);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool HostDeviceVector<T>::HostCanRead() const {
|
bool HostDeviceVector<T>::HostCanRead() const {
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
@ -125,6 +125,25 @@ class HostDeviceVectorImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Extend(HostDeviceVectorImpl* other) {
|
||||||
|
auto ori_size = this->Size();
|
||||||
|
this->Resize(ori_size + other->Size(), T());
|
||||||
|
if (HostCanWrite() && other->HostCanRead()) {
|
||||||
|
auto& h_vec = this->HostVector();
|
||||||
|
auto& other_vec = other->HostVector();
|
||||||
|
CHECK_EQ(h_vec.size(), ori_size + other->Size());
|
||||||
|
std::copy(other_vec.cbegin(), other_vec.cend(), h_vec.begin() + ori_size);
|
||||||
|
} else {
|
||||||
|
auto ptr = other->ConstDevicePointer();
|
||||||
|
SetDevice();
|
||||||
|
CHECK_EQ(this->DeviceIdx(), other->DeviceIdx());
|
||||||
|
dh::safe_cuda(cudaMemcpyAsync(this->DevicePointer() + ori_size,
|
||||||
|
ptr,
|
||||||
|
other->Size() * sizeof(T),
|
||||||
|
cudaMemcpyDeviceToDevice));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<T>& HostVector() {
|
std::vector<T>& HostVector() {
|
||||||
LazySyncHost(GPUAccess::kNone);
|
LazySyncHost(GPUAccess::kNone);
|
||||||
return data_h_;
|
return data_h_;
|
||||||
@ -326,6 +345,11 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
|
|||||||
impl_->Copy(other);
|
impl_->Copy(other);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
|
||||||
|
impl_->Extend(other.impl_);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); }
|
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); }
|
||||||
|
|
||||||
|
|||||||
@ -343,6 +343,44 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
|
||||||
|
if (accumulate_rows) {
|
||||||
|
this->num_row_ += that.num_row_;
|
||||||
|
}
|
||||||
|
if (this->num_col_ != 0) {
|
||||||
|
CHECK_EQ(this->num_col_, that.num_col_)
|
||||||
|
<< "Number of columns must be consistent across batches.";
|
||||||
|
}
|
||||||
|
this->num_col_ = that.num_col_;
|
||||||
|
|
||||||
|
this->labels_.SetDevice(that.labels_.DeviceIdx());
|
||||||
|
this->labels_.Extend(that.labels_);
|
||||||
|
|
||||||
|
this->weights_.SetDevice(that.weights_.DeviceIdx());
|
||||||
|
this->weights_.Extend(that.weights_);
|
||||||
|
|
||||||
|
this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
|
||||||
|
this->labels_lower_bound_.Extend(that.labels_lower_bound_);
|
||||||
|
|
||||||
|
this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
|
||||||
|
this->labels_upper_bound_.Extend(that.labels_upper_bound_);
|
||||||
|
|
||||||
|
this->base_margin_.SetDevice(that.base_margin_.DeviceIdx());
|
||||||
|
this->base_margin_.Extend(that.base_margin_);
|
||||||
|
|
||||||
|
if (this->group_ptr_.size() == 0) {
|
||||||
|
this->group_ptr_ = that.group_ptr_;
|
||||||
|
} else {
|
||||||
|
CHECK_NE(that.group_ptr_.size(), 0);
|
||||||
|
auto group_ptr = that.group_ptr_;
|
||||||
|
for (size_t i = 1; i < group_ptr.size(); ++i) {
|
||||||
|
group_ptr[i] += this->group_ptr_.back();
|
||||||
|
}
|
||||||
|
this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
|
||||||
|
group_ptr.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void MetaInfo::Validate(int32_t device) const {
|
void MetaInfo::Validate(int32_t device) const {
|
||||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
||||||
|
|||||||
@ -163,3 +163,35 @@ TEST(MetaInfo, Validate) {
|
|||||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(MetaInfo, HostExtend) {
|
||||||
|
xgboost::MetaInfo lhs, rhs;
|
||||||
|
size_t const kRows = 100;
|
||||||
|
lhs.labels_.Resize(kRows);
|
||||||
|
lhs.num_row_ = kRows;
|
||||||
|
rhs.labels_.Resize(kRows);
|
||||||
|
rhs.num_row_ = kRows;
|
||||||
|
ASSERT_TRUE(lhs.labels_.HostCanRead());
|
||||||
|
ASSERT_TRUE(rhs.labels_.HostCanRead());
|
||||||
|
|
||||||
|
size_t per_group = 10;
|
||||||
|
std::vector<xgboost::bst_group_t> groups;
|
||||||
|
for (size_t g = 0; g < kRows / per_group; ++g) {
|
||||||
|
groups.emplace_back(per_group);
|
||||||
|
}
|
||||||
|
lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||||
|
rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||||
|
|
||||||
|
lhs.Extend(rhs, true);
|
||||||
|
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||||
|
ASSERT_TRUE(lhs.labels_.HostCanRead());
|
||||||
|
ASSERT_TRUE(rhs.labels_.HostCanRead());
|
||||||
|
ASSERT_FALSE(lhs.labels_.DeviceCanRead());
|
||||||
|
ASSERT_FALSE(rhs.labels_.DeviceCanRead());
|
||||||
|
|
||||||
|
ASSERT_EQ(lhs.group_ptr_.front(), 0);
|
||||||
|
ASSERT_EQ(lhs.group_ptr_.back(), kRows * 2);
|
||||||
|
for (size_t i = 0; i < kRows * 2 / per_group; ++i) {
|
||||||
|
ASSERT_EQ(lhs.group_ptr_.at(i), per_group * i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -77,7 +77,6 @@ TEST(MetaInfo, FromInterface) {
|
|||||||
|
|
||||||
TEST(MetaInfo, Group) {
|
TEST(MetaInfo, Group) {
|
||||||
cudaSetDevice(0);
|
cudaSetDevice(0);
|
||||||
|
|
||||||
MetaInfo info;
|
MetaInfo info;
|
||||||
|
|
||||||
thrust::device_vector<uint32_t> d_uint;
|
thrust::device_vector<uint32_t> d_uint;
|
||||||
@ -105,4 +104,25 @@ TEST(MetaInfo, Group) {
|
|||||||
info = MetaInfo();
|
info = MetaInfo();
|
||||||
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
|
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(MetaInfo, DeviceExtend) {
|
||||||
|
dh::safe_cuda(cudaSetDevice(0));
|
||||||
|
size_t const kRows = 100;
|
||||||
|
MetaInfo lhs, rhs;
|
||||||
|
|
||||||
|
thrust::device_vector<float> d_data;
|
||||||
|
std::string str = PrepareData<float>("<f4", &d_data, kRows);
|
||||||
|
lhs.SetInfo("label", str.c_str());
|
||||||
|
rhs.SetInfo("label", str.c_str());
|
||||||
|
ASSERT_FALSE(rhs.labels_.HostCanRead());
|
||||||
|
lhs.num_row_ = kRows;
|
||||||
|
rhs.num_row_ = kRows;
|
||||||
|
|
||||||
|
lhs.Extend(rhs, true);
|
||||||
|
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||||
|
ASSERT_FALSE(lhs.labels_.HostCanRead());
|
||||||
|
|
||||||
|
ASSERT_FALSE(lhs.labels_.HostCanRead());
|
||||||
|
ASSERT_FALSE(rhs.labels_.HostCanRead());
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user