Refactor DeviceUVector. (#10595)

Create a wrapper instead of using inheritance to avoid an inconsistent interface for the class.
This commit is contained in:
Jiaming Yuan 2024-07-18 03:33:01 +08:00 committed by GitHub
parent 07732e02e5
commit e9fbce9791
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 45 deletions

View File

@ -510,7 +510,7 @@ xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
template <typename T> template <typename T>
xgboost::common::Span<T> ToSpan(DeviceUVector<T> &vec) { xgboost::common::Span<T> ToSpan(DeviceUVector<T> &vec) {
return {thrust::raw_pointer_cast(vec.data()), vec.size()}; return {vec.data(), vec.size()};
} }
// thrust begin, similar to std::begin // thrust begin, similar to std::begin

View File

@ -284,47 +284,64 @@ class LoggingResource : public rmm::mr::device_memory_resource {
LoggingResource *GlobalLoggingResource(); LoggingResource *GlobalLoggingResource();
#endif // defined(XGBOOST_USE_RMM)
/** /**
* @brief Container class that doesn't initialize the data. * @brief Container class that doesn't initialize the data when RMM is used.
*/ */
template <typename T> template <typename T>
class DeviceUVector : public rmm::device_uvector<T> { class DeviceUVector {
using Super = rmm::device_uvector<T>; private:
#if defined(XGBOOST_USE_RMM)
rmm::device_uvector<T> data_{0, rmm::cuda_stream_per_thread, GlobalLoggingResource()};
#else
::dh::device_vector<T> data_;
#endif // defined(XGBOOST_USE_RMM)
public: public:
DeviceUVector() : Super{0, rmm::cuda_stream_per_thread, GlobalLoggingResource()} {} using value_type = T; // NOLINT
using pointer = value_type *; // NOLINT
using const_pointer = value_type const *; // NOLINT
using reference = value_type &; // NOLINT
using const_reference = value_type const &; // NOLINT
void Resize(std::size_t n) { Super::resize(n, rmm::cuda_stream_per_thread); } public:
void Resize(std::size_t n, T const &v) { DeviceUVector() = default;
DeviceUVector(DeviceUVector const &that) = delete;
DeviceUVector &operator=(DeviceUVector const &that) = delete;
DeviceUVector(DeviceUVector &&that) = default;
DeviceUVector &operator=(DeviceUVector &&that) = default;
void resize(std::size_t n) { // NOLINT
#if defined(XGBOOST_USE_RMM)
data_.resize(n, rmm::cuda_stream_per_thread);
#else
data_.resize(n);
#endif
}
void resize(std::size_t n, T const &v) { // NOLINT
#if defined(XGBOOST_USE_RMM)
auto orig = this->size(); auto orig = this->size();
Super::resize(n, rmm::cuda_stream_per_thread); data_.resize(n, rmm::cuda_stream_per_thread);
if (orig < n) { if (orig < n) {
thrust::fill(rmm::exec_policy_nosync{}, this->begin() + orig, this->end(), v); thrust::fill(rmm::exec_policy_nosync{}, this->begin() + orig, this->end(), v);
} }
}
private:
// undefined private, cannot be accessed.
void resize(std::size_t n, rmm::cuda_stream_view stream); // NOLINT
};
#else #else
data_.resize(n, v);
#endif
}
[[nodiscard]] std::size_t size() const { return data_.size(); } // NOLINT
/** [[nodiscard]] auto begin() { return data_.begin(); } // NOLINT
* @brief Without RMM, the initialization will happen. [[nodiscard]] auto end() { return data_.end(); } // NOLINT
*/
template <typename T>
class DeviceUVector : public thrust::device_vector<T, XGBDeviceAllocator<T>> {
using Super = thrust::device_vector<T, XGBDeviceAllocator<T>>;
public: [[nodiscard]] auto begin() const { return this->cbegin(); } // NOLINT
void Resize(std::size_t n) { Super::resize(n); } [[nodiscard]] auto end() const { return this->cend(); } // NOLINT
void Resize(std::size_t n, T const &v) { Super::resize(n, v); }
private: [[nodiscard]] auto cbegin() const { return data_.cbegin(); } // NOLINT
// undefined private, cannot be accessed. [[nodiscard]] auto cend() const { return data_.cend(); } // NOLINT
void resize(std::size_t n, T const &v = T{}); // NOLINT
[[nodiscard]] auto data() { return thrust::raw_pointer_cast(data_.data()); } // NOLINT
[[nodiscard]] auto data() const { return thrust::raw_pointer_cast(data_.data()); } // NOLINT
}; };
#endif // defined(XGBOOST_USE_RMM)
} // namespace dh } // namespace dh

View File

@ -29,7 +29,7 @@ class HostDeviceVectorImpl {
if (device.IsCUDA()) { if (device.IsCUDA()) {
gpu_access_ = GPUAccess::kWrite; gpu_access_ = GPUAccess::kWrite;
SetDevice(); SetDevice();
data_d_->Resize(size, v); data_d_->resize(size, v);
} else { } else {
data_h_.resize(size, v); data_h_.resize(size, v);
} }
@ -67,12 +67,12 @@ class HostDeviceVectorImpl {
T* DevicePointer() { T* DevicePointer() {
LazySyncDevice(GPUAccess::kWrite); LazySyncDevice(GPUAccess::kWrite);
return thrust::raw_pointer_cast(data_d_->data()); return data_d_->data();
} }
const T* ConstDevicePointer() { const T* ConstDevicePointer() {
LazySyncDevice(GPUAccess::kRead); LazySyncDevice(GPUAccess::kRead);
return thrust::raw_pointer_cast(data_d_->data()); return data_d_->data();
} }
common::Span<T> DeviceSpan() { common::Span<T> DeviceSpan() {
@ -181,7 +181,7 @@ class HostDeviceVectorImpl {
gpu_access_ = GPUAccess::kWrite; gpu_access_ = GPUAccess::kWrite;
SetDevice(); SetDevice();
auto old_size = data_d_->size(); auto old_size = data_d_->size();
data_d_->Resize(new_size, std::forward<U>(args)...); data_d_->resize(new_size, std::forward<U>(args)...);
} else { } else {
// resize on host // resize on host
LazySyncHost(GPUAccess::kNone); LazySyncHost(GPUAccess::kNone);
@ -200,8 +200,8 @@ class HostDeviceVectorImpl {
gpu_access_ = access; gpu_access_ = access;
if (data_h_.size() != data_d_->size()) { data_h_.resize(data_d_->size()); } if (data_h_.size() != data_d_->size()) { data_h_.resize(data_d_->size()); }
SetDevice(); SetDevice();
dh::safe_cuda(cudaMemcpy(data_h_.data(), thrust::raw_pointer_cast(data_d_->data()), dh::safe_cuda(cudaMemcpy(data_h_.data(), data_d_->data(), data_d_->size() * sizeof(T),
data_d_->size() * sizeof(T), cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
} }
void LazySyncDevice(GPUAccess access) { void LazySyncDevice(GPUAccess access) {
@ -214,9 +214,8 @@ class HostDeviceVectorImpl {
// data is on the host // data is on the host
LazyResizeDevice(data_h_.size()); LazyResizeDevice(data_h_.size());
SetDevice(); SetDevice();
dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(data_d_->data()), data_h_.data(), dh::safe_cuda(cudaMemcpyAsync(data_d_->data(), data_h_.data(), data_d_->size() * sizeof(T),
data_d_->size() * sizeof(T), cudaMemcpyHostToDevice, cudaMemcpyHostToDevice, dh::DefaultStream()));
dh::DefaultStream()));
gpu_access_ = access; gpu_access_ = access;
} }
@ -241,8 +240,7 @@ class HostDeviceVectorImpl {
LazyResizeDevice(Size()); LazyResizeDevice(Size());
gpu_access_ = GPUAccess::kWrite; gpu_access_ = GPUAccess::kWrite;
SetDevice(); SetDevice();
dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(data_d_->data()), dh::safe_cuda(cudaMemcpyAsync(data_d_->data(), other->data_d_->data(),
thrust::raw_pointer_cast(other->data_d_->data()),
data_d_->size() * sizeof(T), cudaMemcpyDefault, data_d_->size() * sizeof(T), cudaMemcpyDefault,
dh::DefaultStream())); dh::DefaultStream()));
} }
@ -252,15 +250,14 @@ class HostDeviceVectorImpl {
LazyResizeDevice(Size()); LazyResizeDevice(Size());
gpu_access_ = GPUAccess::kWrite; gpu_access_ = GPUAccess::kWrite;
SetDevice(); SetDevice();
dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(data_d_->data()), begin, dh::safe_cuda(cudaMemcpyAsync(data_d_->data(), begin, data_d_->size() * sizeof(T),
data_d_->size() * sizeof(T), cudaMemcpyDefault, cudaMemcpyDefault, dh::DefaultStream()));
dh::DefaultStream()));
} }
void LazyResizeDevice(size_t new_size) { void LazyResizeDevice(size_t new_size) {
if (data_d_ && new_size == data_d_->size()) { return; } if (data_d_ && new_size == data_d_->size()) { return; }
SetDevice(); SetDevice();
data_d_->Resize(new_size); data_d_->resize(new_size);
} }
void SetDevice() { void SetDevice() {

View File

@ -12,7 +12,7 @@ TEST(DeviceUVector, Basic) {
std::int32_t verbosity{3}; std::int32_t verbosity{3};
std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity); std::swap(verbosity, xgboost::GlobalConfigThreadLocalStore::Get()->verbosity);
DeviceUVector<float> uvec; DeviceUVector<float> uvec;
uvec.Resize(12); uvec.resize(12);
auto peak = GlobalMemoryLogger().PeakMemory(); auto peak = GlobalMemoryLogger().PeakMemory();
auto n_bytes = sizeof(decltype(uvec)::value_type) * uvec.size(); auto n_bytes = sizeof(decltype(uvec)::value_type) * uvec.size();
ASSERT_EQ(peak, n_bytes); ASSERT_EQ(peak, n_bytes);