/*!
 * Copyright 2017 XGBoost contributors
 */

#include "./host_device_vector.h"

#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <xgboost/data.h>

#include <algorithm>
#include <memory>
#include <mutex>

#include "./device_helpers.cuh"

namespace xgboost {

// the handler to call instead of cudaSetDevice; only used for testing
static void (*cudaSetDeviceHandler)(int) = nullptr;  // NOLINT

void SetCudaSetDeviceHandler(void (*handler)(int)) {
  cudaSetDeviceHandler = handler;
}

// wrapper over access with useful methods
class Permissions {
  GPUAccess access_;
  explicit Permissions(GPUAccess access) : access_{access} {}

 public:
  Permissions() : access_{GPUAccess::kNone} {}
  explicit Permissions(bool perm)
    : access_(perm ? GPUAccess::kWrite : GPUAccess::kNone) {}

  bool CanRead() const { return access_ >= kRead; }
  bool CanWrite() const { return access_ == kWrite; }
  bool CanAccess(GPUAccess access) const { return access_ >= access; }
  void Grant(GPUAccess access) { access_ = std::max(access_, access); }
  void DenyComplementary(GPUAccess compl_access) {
    access_ = std::min(access_, GPUAccess::kWrite - compl_access);
  }
  Permissions Complementary() const {
    return Permissions(GPUAccess::kWrite - access_);
  }
};
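// HostDeviceVectorImpl<T> keeps a host copy (data_h_) plus one device shard
// per GPU, and synchronizes them lazily: host and device permissions are
// complementary, so requesting write access on one side denies the other side
// and triggers a copy on its next access. A rough usage sketch (here
// `devices` stands for whatever GPUSet the caller already has):
//
//   HostDeviceVector<int> v(100, 0, GPUDistribution::Block(devices));
//   int* d_ptr = v.DevicePointer(0);        // write access on device 0;
//                                           // any host copy becomes stale
//   std::vector<int>& h = v.HostVector();   // write access on the host;
//                                           // device shards become stale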
template <typename T>
struct HostDeviceVectorImpl {
  struct DeviceShard {
    DeviceShard()
      : proper_size_{0}, device_{-1}, start_{0}, perm_d_{false},
        cached_size_{static_cast<size_t>(~0)}, vec_{nullptr} {}

    ~DeviceShard() {
      SetDevice();
    }

    void Init(HostDeviceVectorImpl<T>* vec, int device) {
      if (vec_ == nullptr) { vec_ = vec; }
      CHECK_EQ(vec, vec_);
      device_ = device;
      LazyResize(vec_->Size());
      perm_d_ = vec_->perm_h_.Complementary();
    }

    void Init(HostDeviceVectorImpl<T>* vec, const DeviceShard& other) {
      if (vec_ == nullptr) { vec_ = vec; }
      CHECK_EQ(vec, vec_);
      device_ = other.device_;
      cached_size_ = other.cached_size_;
      start_ = other.start_;
      proper_size_ = other.proper_size_;
      SetDevice();
      data_.resize(other.data_.size());
      perm_d_ = other.perm_d_;
    }

    void ScatterFrom(const T* begin) {
      // TODO(canonizer): avoid full copy of host data
      LazySyncDevice(GPUAccess::kWrite);
      SetDevice();
      dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), begin + start_,
                                    data_.size() * sizeof(T),
                                    cudaMemcpyDefault));
    }

    void GatherTo(thrust::device_ptr<T> begin) {
      LazySyncDevice(GPUAccess::kRead);
      SetDevice();
      dh::safe_cuda(cudaMemcpyAsync(begin.get() + start_, data_.data().get(),
                                    proper_size_ * sizeof(T),
                                    cudaMemcpyDefault));
    }

    void Fill(T v) {
      // TODO(canonizer): avoid full copy of host data
      LazySyncDevice(GPUAccess::kWrite);
      SetDevice();
      thrust::fill(data_.begin(), data_.end(), v);
    }

    void Copy(DeviceShard* other) {
      // TODO(canonizer): avoid full copy of host data for this (but not for other)
      LazySyncDevice(GPUAccess::kWrite);
      other->LazySyncDevice(GPUAccess::kRead);
      SetDevice();
      dh::safe_cuda(cudaMemcpyAsync(data_.data().get(),
                                    other->data_.data().get(),
                                    data_.size() * sizeof(T),
                                    cudaMemcpyDefault));
    }

    void LazySyncHost(GPUAccess access) {
      SetDevice();
      dh::safe_cuda(cudaMemcpy(vec_->data_h_.data() + start_,
                               data_.data().get(), proper_size_ * sizeof(T),
                               cudaMemcpyDeviceToHost));
      perm_d_.DenyComplementary(access);
    }

    void LazyResize(size_t new_size) {
      if (new_size == cached_size_) { return; }
      // resize is required
      int ndevices = vec_->distribution_.devices_.Size();
      int device_index = vec_->distribution_.devices_.Index(device_);
      start_ = vec_->distribution_.ShardStart(new_size, device_index);
      proper_size_ = vec_->distribution_.ShardProperSize(new_size, device_index);
      // The size on this device.
      size_t size_d = vec_->distribution_.ShardSize(new_size, device_index);
      SetDevice();
      data_.resize(size_d);
      cached_size_ = new_size;
    }

    void LazySyncDevice(GPUAccess access) {
      if (perm_d_.CanAccess(access)) { return; }
      if (perm_d_.CanRead()) {
        // deny read to the host
        perm_d_.Grant(access);
        std::lock_guard<std::mutex> lock(vec_->mutex_);
        vec_->perm_h_.DenyComplementary(access);
        return;
      }
      // data is on the host
      size_t size_h = vec_->data_h_.size();
      LazyResize(size_h);
      SetDevice();
      dh::safe_cuda(
          cudaMemcpy(data_.data().get(), vec_->data_h_.data() + start_,
                     data_.size() * sizeof(T), cudaMemcpyHostToDevice));
      perm_d_.Grant(access);

      std::lock_guard<std::mutex> lock(vec_->mutex_);
      vec_->perm_h_.DenyComplementary(access);
      vec_->size_d_ = size_h;
    }

    void SetDevice() {
      if (cudaSetDeviceHandler == nullptr) {
        dh::safe_cuda(cudaSetDevice(device_));
      } else {
        (*cudaSetDeviceHandler)(device_);
      }
    }

    T* Raw() { return data_.data().get(); }
    size_t Start() const { return start_; }
    size_t DataSize() const { return data_.size(); }
    Permissions& Perm() { return perm_d_; }
    Permissions const& Perm() const { return perm_d_; }

   private:
    int device_;
    dh::device_vector<T> data_;
    // cached vector size
    size_t cached_size_;
    size_t start_;
    // size of the portion to copy back to the host
    size_t proper_size_;
    Permissions perm_d_;
    HostDeviceVectorImpl<T>* vec_;
  };
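  // When constructed with a non-empty distribution, the data starts out on
  // the devices and the host copy is created only on demand; with an empty
  // distribution it lives in data_h_ until the vector is sharded.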
  HostDeviceVectorImpl(size_t size, T v, const GPUDistribution& distribution)
    : distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) {
    if (!distribution_.IsEmpty()) {
      size_d_ = size;
      InitShards();
      Fill(v);
    } else {
      data_h_.resize(size, v);
    }
  }

  // required, as a new std::mutex has to be created
  HostDeviceVectorImpl(const HostDeviceVectorImpl<T>& other)
    : data_h_(other.data_h_), perm_h_(other.perm_h_), size_d_(other.size_d_),
      distribution_(other.distribution_), mutex_() {
    shards_.resize(other.shards_.size());
    dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
        shard.Init(this, other.shards_.at(i));
      });
  }

  // Initializer can be std::vector<T> or std::initializer_list<T>
  template <class Initializer>
  HostDeviceVectorImpl(const Initializer& init, const GPUDistribution& distribution)
    : distribution_(distribution), perm_h_(distribution.IsEmpty()), size_d_(0) {
    if (!distribution_.IsEmpty()) {
      size_d_ = init.size();
      InitShards();
      Copy(init);
    } else {
      data_h_ = init;
    }
  }

  void InitShards() {
    int ndevices = distribution_.devices_.Size();
    shards_.resize(ndevices);
    dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
        shard.Init(this, distribution_.devices_.DeviceId(i));
      });
  }

  size_t Size() const { return perm_h_.CanRead() ? data_h_.size() : size_d_; }

  GPUSet Devices() const { return distribution_.devices_; }

  const GPUDistribution& Distribution() const { return distribution_; }

  T* DevicePointer(int device) {
    CHECK(distribution_.devices_.Contains(device));
    LazySyncDevice(device, GPUAccess::kWrite);
    return shards_.at(distribution_.devices_.Index(device)).Raw();
  }

  const T* ConstDevicePointer(int device) {
    CHECK(distribution_.devices_.Contains(device));
    LazySyncDevice(device, GPUAccess::kRead);
    return shards_.at(distribution_.devices_.Index(device)).Raw();
  }

  common::Span<T> DeviceSpan(int device) {
    GPUSet devices = distribution_.devices_;
    CHECK(devices.Contains(device));
    LazySyncDevice(device, GPUAccess::kWrite);
    return {shards_.at(devices.Index(device)).Raw(),
            static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
  }

  common::Span<const T> ConstDeviceSpan(int device) {
    GPUSet devices = distribution_.devices_;
    CHECK(devices.Contains(device));
    LazySyncDevice(device, GPUAccess::kRead);
    using SpanInd = typename common::Span<const T>::index_type;
    return {shards_.at(devices.Index(device)).Raw(),
            static_cast<SpanInd>(DeviceSize(device))};
  }

  size_t DeviceSize(int device) {
    CHECK(distribution_.devices_.Contains(device));
    LazySyncDevice(device, GPUAccess::kRead);
    return shards_.at(distribution_.devices_.Index(device)).DataSize();
  }

  size_t DeviceStart(int device) {
    CHECK(distribution_.devices_.Contains(device));
    LazySyncDevice(device, GPUAccess::kRead);
    return shards_.at(distribution_.devices_.Index(device)).Start();
  }

  thrust::device_ptr<T> tbegin(int device) {  // NOLINT
    return thrust::device_ptr<T>(DevicePointer(device));
  }

  thrust::device_ptr<const T> tcbegin(int device) {  // NOLINT
    return thrust::device_ptr<const T>(ConstDevicePointer(device));
  }

  thrust::device_ptr<T> tend(int device) {  // NOLINT
    return tbegin(device) + DeviceSize(device);
  }

  thrust::device_ptr<const T> tcend(int device) {  // NOLINT
    return tcbegin(device) + DeviceSize(device);
  }

  void ScatterFrom(thrust::device_ptr<const T> begin,
                   thrust::device_ptr<const T> end) {
    CHECK_EQ(end - begin, Size());
    if (perm_h_.CanWrite()) {
      dh::safe_cuda(cudaMemcpy(data_h_.data(), begin.get(),
                               (end - begin) * sizeof(T),
                               cudaMemcpyDeviceToHost));
    } else {
      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
          shard.ScatterFrom(begin.get());
        });
    }
  }

  void GatherTo(thrust::device_ptr<T> begin, thrust::device_ptr<T> end) {
    CHECK_EQ(end - begin, Size());
    if (perm_h_.CanWrite()) {
      dh::safe_cuda(cudaMemcpy(begin.get(), data_h_.data(),
                               data_h_.size() * sizeof(T),
                               cudaMemcpyHostToDevice));
    } else {
      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
          shard.GatherTo(begin);
        });
    }
  }

  void Fill(T v) {  // NOLINT
    if (perm_h_.CanWrite()) {
      std::fill(data_h_.begin(), data_h_.end(), v);
    } else {
      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
          shard.Fill(v);
        });
    }
  }
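  // Copying from another HostDeviceVectorImpl stays on the host when both
  // sides hold a writable host copy; otherwise the distributions are aligned
  // first and the copy is performed shard by shard on the devices.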
  void Copy(HostDeviceVectorImpl<T>* other) {
    CHECK_EQ(Size(), other->Size());
    // Data is on host.
    if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) {
      std::copy(other->data_h_.begin(), other->data_h_.end(),
                data_h_.begin());
      return;
    }
    // Data is on device.
    if (distribution_ != other->distribution_) {
      distribution_ = GPUDistribution();
      Shard(other->Distribution());
      size_d_ = other->size_d_;
    }
    dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
        shard.Copy(&other->shards_.at(i));
      });
  }

  void Copy(const std::vector<T>& other) {
    CHECK_EQ(Size(), other.size());
    if (perm_h_.CanWrite()) {
      std::copy(other.begin(), other.end(), data_h_.begin());
    } else {
      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
          shard.ScatterFrom(other.data());
        });
    }
  }

  void Copy(std::initializer_list<T> other) {
    CHECK_EQ(Size(), other.size());
    if (perm_h_.CanWrite()) {
      std::copy(other.begin(), other.end(), data_h_.begin());
    } else {
      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
          shard.ScatterFrom(other.begin());
        });
    }
  }

  std::vector<T>& HostVector() {
    LazySyncHost(GPUAccess::kWrite);
    return data_h_;
  }

  const std::vector<T>& ConstHostVector() {
    LazySyncHost(GPUAccess::kRead);
    return data_h_;
  }

  void Shard(const GPUDistribution& distribution) {
    if (distribution_ == distribution) { return; }
    CHECK(distribution_.IsEmpty())
      << "This: " << distribution_.Devices().Size() << ", "
      << "Others: " << distribution.Devices().Size();
    distribution_ = distribution;
    InitShards();
  }

  void Shard(GPUSet new_devices) {
    if (distribution_.Devices() == new_devices) { return; }
    Shard(GPUDistribution::Block(new_devices));
  }

  void Reshard(const GPUDistribution& distribution) {
    if (distribution_ == distribution) { return; }
    LazySyncHost(GPUAccess::kWrite);
    distribution_ = distribution;
    shards_.clear();
    InitShards();
  }

  void Resize(size_t new_size, T v) {
    if (new_size == Size()) { return; }
    if (distribution_.IsFixedSize()) {
      CHECK_EQ(new_size, distribution_.offsets_.back());
    }
    if (Size() == 0 && !distribution_.IsEmpty()) {
      // fast on-device resize
      perm_h_ = Permissions(false);
      size_d_ = new_size;
      InitShards();
      Fill(v);
    } else {
      // resize on host
      LazySyncHost(GPUAccess::kWrite);
      data_h_.resize(new_size, v);
    }
  }

  void LazySyncHost(GPUAccess access) {
    if (perm_h_.CanAccess(access)) { return; }
    if (perm_h_.CanRead()) {
      // data is present, just need to deny access to the device
      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
          shard.Perm().DenyComplementary(access);
        });
      perm_h_.Grant(access);
      return;
    }
    if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
    dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
        shard.LazySyncHost(access);
      });
    perm_h_.Grant(access);
  }

  void LazySyncDevice(int device, GPUAccess access) {
    GPUSet devices = distribution_.Devices();
    CHECK(devices.Contains(device));
    shards_.at(devices.Index(device)).LazySyncDevice(access);
  }

  bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); }

  bool DeviceCanAccess(int device, GPUAccess access) {
    GPUSet devices = distribution_.Devices();
    if (!devices.Contains(device)) { return false; }
    return shards_.at(devices.Index(device)).Perm().CanAccess(access);
  }

 private:
  std::vector<T> data_h_;
  Permissions perm_h_;
  // the total size of the data stored on the devices
  size_t size_d_;
  GPUDistribution distribution_;
  // protects size_d_ and perm_h_ when updated from multiple threads
  std::mutex mutex_;
  std::vector<DeviceShard> shards_;
};
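// HostDeviceVector<T> is a thin pimpl wrapper: every method below forwards to
// the HostDeviceVectorImpl<T> instance owned through impl_.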
template <typename T>
HostDeviceVector<T>::HostDeviceVector(size_t size, T v,
                                      const GPUDistribution& distribution)
  : impl_(nullptr) {
  impl_ = new HostDeviceVectorImpl<T>(size, v, distribution);
}

template <typename T>
HostDeviceVector<T>::HostDeviceVector(std::initializer_list<T> init,
                                      const GPUDistribution& distribution)
  : impl_(nullptr) {
  impl_ = new HostDeviceVectorImpl<T>(init, distribution);
}

template <typename T>
HostDeviceVector<T>::HostDeviceVector(const std::vector<T>& init,
                                      const GPUDistribution& distribution)
  : impl_(nullptr) {
  impl_ = new HostDeviceVectorImpl<T>(init, distribution);
}

template <typename T>
HostDeviceVector<T>::HostDeviceVector(const HostDeviceVector<T>& other)
  : impl_(nullptr) {
  impl_ = new HostDeviceVectorImpl<T>(*other.impl_);
}

template <typename T>
HostDeviceVector<T>& HostDeviceVector<T>::operator=(const HostDeviceVector<T>& other) {
  if (this == &other) { return *this; }
  std::unique_ptr<HostDeviceVectorImpl<T>> newImpl(
      new HostDeviceVectorImpl<T>(*other.impl_));
  delete impl_;
  impl_ = newImpl.release();
  return *this;
}

template <typename T>
HostDeviceVector<T>::~HostDeviceVector() {
  delete impl_;
  impl_ = nullptr;
}

template <typename T>
size_t HostDeviceVector<T>::Size() const { return impl_->Size(); }

template <typename T>
GPUSet HostDeviceVector<T>::Devices() const { return impl_->Devices(); }

template <typename T>
const GPUDistribution& HostDeviceVector<T>::Distribution() const {
  return impl_->Distribution();
}

template <typename T>
T* HostDeviceVector<T>::DevicePointer(int device) {
  return impl_->DevicePointer(device);
}

template <typename T>
const T* HostDeviceVector<T>::ConstDevicePointer(int device) const {
  return impl_->ConstDevicePointer(device);
}

template <typename T>
common::Span<T> HostDeviceVector<T>::DeviceSpan(int device) {
  return impl_->DeviceSpan(device);
}

template <typename T>
common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan(int device) const {
  return impl_->ConstDeviceSpan(device);
}

template <typename T>
size_t HostDeviceVector<T>::DeviceStart(int device) const {
  return impl_->DeviceStart(device);
}

template <typename T>
size_t HostDeviceVector<T>::DeviceSize(int device) const {
  return impl_->DeviceSize(device);
}

template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tbegin(int device) {  // NOLINT
  return impl_->tbegin(device);
}

template <typename T>
thrust::device_ptr<const T> HostDeviceVector<T>::tcbegin(int device) const {  // NOLINT
  return impl_->tcbegin(device);
}

template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tend(int device) {  // NOLINT
  return impl_->tend(device);
}

template <typename T>
thrust::device_ptr<const T> HostDeviceVector<T>::tcend(int device) const {  // NOLINT
  return impl_->tcend(device);
}

template <typename T>
void HostDeviceVector<T>::ScatterFrom(thrust::device_ptr<const T> begin,
                                      thrust::device_ptr<const T> end) {
  impl_->ScatterFrom(begin, end);
}

template <typename T>
void HostDeviceVector<T>::GatherTo(thrust::device_ptr<T> begin,
                                   thrust::device_ptr<T> end) const {
  impl_->GatherTo(begin, end);
}

template <typename T>
void HostDeviceVector<T>::Fill(T v) {
  impl_->Fill(v);
}

template <typename T>
void HostDeviceVector<T>::Copy(const HostDeviceVector<T>& other) {
  impl_->Copy(other.impl_);
}

template <typename T>
void HostDeviceVector<T>::Copy(const std::vector<T>& other) {
  impl_->Copy(other);
}

template <typename T>
void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
  impl_->Copy(other);
}

template <typename T>
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); }

template <typename T>
const std::vector<T>& HostDeviceVector<T>::ConstHostVector() const {
  return impl_->ConstHostVector();
}

template <typename T>
bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
  return impl_->HostCanAccess(access);
}

template <typename T>
bool HostDeviceVector<T>::DeviceCanAccess(int device, GPUAccess access) const {
  return impl_->DeviceCanAccess(device, access);
}

template <typename T>
void HostDeviceVector<T>::Shard(GPUSet new_devices) const {
  impl_->Shard(new_devices);
}

template <typename T>
void HostDeviceVector<T>::Shard(const GPUDistribution& distribution) const {
  impl_->Shard(distribution);
}

template <typename T>
void HostDeviceVector<T>::Reshard(const GPUDistribution& distribution) {
  impl_->Reshard(distribution);
}
template <typename T>
void HostDeviceVector<T>::Resize(size_t new_size, T v) {
  impl_->Resize(new_size, v);
}

// explicit instantiations are required, as HostDeviceVector<T> isn't header-only;
// the element types below are the ones used elsewhere in XGBoost
template class HostDeviceVector<bst_float>;
template class HostDeviceVector<GradientPair>;
template class HostDeviceVector<int>;
template class HostDeviceVector<Entry>;
template class HostDeviceVector<size_t>;

}  // namespace xgboost