further cleanup of single process multi-GPU code (#4810)

* use subspan in gpu predictor instead of copying
* Revise `HostDeviceVector`
This commit is contained in:
Rong Ou 2019-08-30 02:27:23 -07:00 committed by Jiaming Yuan
parent 0184eb5d02
commit 733ed24dd9
12 changed files with 289 additions and 593 deletions

View File

@ -238,8 +238,7 @@ class MemoryLogger {
device_allocations.erase(itr);
}
};
std::map<int, DeviceStats>
stats_; // Map device ordinal to memory information
DeviceStats stats_;
std::mutex mutex_;
public:
@ -249,8 +248,8 @@ public:
std::lock_guard<std::mutex> guard(mutex_);
int current_device;
safe_cuda(cudaGetDevice(&current_device));
stats_[current_device].RegisterAllocation(ptr, n);
CHECK_LE(stats_[current_device].peak_allocated_bytes, dh::TotalMemory(current_device));
stats_.RegisterAllocation(ptr, n);
CHECK_LE(stats_.peak_allocated_bytes, dh::TotalMemory(current_device));
}
void RegisterDeallocation(void *ptr, size_t n) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
@ -258,19 +257,19 @@ public:
std::lock_guard<std::mutex> guard(mutex_);
int current_device;
safe_cuda(cudaGetDevice(&current_device));
stats_[current_device].RegisterDeallocation(ptr, n, current_device);
stats_.RegisterDeallocation(ptr, n, current_device);
}
void Log() {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
return;
std::lock_guard<std::mutex> guard(mutex_);
for (const auto &kv : stats_) {
LOG(CONSOLE) << "======== Device " << kv.first << " Memory Allocations: "
<< " ========";
LOG(CONSOLE) << "Peak memory usage: "
<< kv.second.peak_allocated_bytes / 1000000 << "mb";
LOG(CONSOLE) << "Number of allocations: " << kv.second.num_allocations;
}
int current_device;
safe_cuda(cudaGetDevice(&current_device));
LOG(CONSOLE) << "======== Device " << current_device << " Memory Allocations: "
<< " ========";
LOG(CONSOLE) << "Peak memory usage: "
<< stats_.peak_allocated_bytes / 1000000 << "mb";
LOG(CONSOLE) << "Number of allocations: " << stats_.num_allocations;
}
};
};
@ -940,10 +939,9 @@ class AllReducer {
size_t allreduce_calls_; // Keep statistics of the number of reduce calls
std::vector<size_t> host_data; // Used for all reduce on host
#ifdef XGBOOST_USE_NCCL
std::vector<ncclComm_t> comms;
std::vector<cudaStream_t> streams;
std::vector<int> device_ordinals; // device id from CUDA
std::vector<int> device_counts; // device count from CUDA
ncclComm_t comm;
cudaStream_t stream;
int device_ordinal;
ncclUniqueId id;
#endif
@ -952,79 +950,28 @@ class AllReducer {
allreduce_calls_(0) {}
/**
* \brief If we are using a single GPU only
*/
bool IsSingleGPU() {
#ifdef XGBOOST_USE_NCCL
CHECK(device_counts.size() > 0) << "AllReducer not initialised.";
return device_counts.size() <= 1 && device_counts.at(0) == 1;
#else
return true;
#endif
}
/**
* \brief Initialise with the desired device ordinals for this communication
* \brief Initialise with the desired device ordinal for this communication
* group.
*
* \param device_ordinals The device ordinals.
* \param device_ordinal The device ordinal.
*/
void Init(const std::vector<int> &device_ordinals) {
void Init(int _device_ordinal) {
#ifdef XGBOOST_USE_NCCL
/** \brief this >monitor . init. */
this->device_ordinals = device_ordinals;
this->device_counts.resize(rabit::GetWorldSize());
this->comms.resize(device_ordinals.size());
this->streams.resize(device_ordinals.size());
this->id = GetUniqueId();
device_counts.at(rabit::GetRank()) = device_ordinals.size();
for (size_t i = 0; i < device_counts.size(); i++) {
int dev_count = device_counts.at(i);
rabit::Allreduce<rabit::op::Sum, int>(&dev_count, 1);
device_counts.at(i) = dev_count;
}
int nccl_rank = 0;
int nccl_rank_offset = std::accumulate(device_counts.begin(),
device_counts.begin() + rabit::GetRank(), 0);
int nccl_nranks = std::accumulate(device_counts.begin(),
device_counts.end(), 0);
nccl_rank += nccl_rank_offset;
GroupStart();
for (size_t i = 0; i < device_ordinals.size(); i++) {
int dev = device_ordinals.at(i);
dh::safe_cuda(cudaSetDevice(dev));
dh::safe_nccl(ncclCommInitRank(
&comms.at(i),
nccl_nranks, id,
nccl_rank));
nccl_rank++;
}
GroupEnd();
for (size_t i = 0; i < device_ordinals.size(); i++) {
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
safe_cuda(cudaStreamCreate(&streams.at(i)));
}
device_ordinal = _device_ordinal;
id = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank()));
safe_cuda(cudaStreamCreate(&stream));
initialised_ = true;
#else
CHECK_EQ(device_ordinals.size(), 1)
<< "XGBoost must be compiled with NCCL to use more than one GPU.";
#endif
}
~AllReducer() {
#ifdef XGBOOST_USE_NCCL
if (initialised_) {
for (auto &stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
for (auto &comm : comms) {
ncclCommDestroy(comm);
}
dh::safe_cuda(cudaStreamDestroy(stream));
ncclCommDestroy(comm);
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
@ -1035,20 +982,21 @@ class AllReducer {
}
/**
* \brief Use in exactly the same way as ncclGroupStart
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void GroupStart() {
#ifdef XGBOOST_USE_NCCL
dh::safe_nccl(ncclGroupStart());
#endif
}
/**
* \brief Use in exactly the same way as ncclGroupEnd
*/
void GroupEnd() {
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
dh::safe_nccl(ncclGroupEnd());
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm, stream));
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
#endif
}
@ -1056,51 +1004,18 @@ class AllReducer {
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param communication_group_idx Zero-based index of the communication group.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(int communication_group_idx, const double *sendbuff,
double *recvbuff, int count) {
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
comms.at(communication_group_idx),
streams.at(communication_group_idx)));
if(communication_group_idx == 0)
{
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
}
#endif
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
*
* \param communication_group_idx Zero-based index of the communication group.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of elements.
*/
void AllReduceSum(int communication_group_idx, const float *sendbuff,
float *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum,
comms.at(communication_group_idx),
streams.at(communication_group_idx)));
if(communication_group_idx == 0)
{
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
}
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm, stream));
allreduce_bytes_ += count * sizeof(float);
allreduce_calls_ += 1;
#endif
}
@ -1109,21 +1024,17 @@ class AllReducer {
*
* \param count Number of.
*
* \param communication_group_idx Zero-based index of the communication group. \param sendbuff.
* \param sendbuff The sendbuff.
* \param recvbuff The recvbuff.
* \param count Number of.
*/
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
int64_t *recvbuff, int count) {
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
comms[communication_group_idx],
streams[communication_group_idx]));
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm, stream));
#endif
}
@ -1134,26 +1045,8 @@ class AllReducer {
*/
void Synchronize() {
#ifdef XGBOOST_USE_NCCL
for (size_t i = 0; i < device_ordinals.size(); i++) {
dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
}
#endif
};
/**
* \brief Synchronizes the device
*
* \param device_id Identifier for the device.
*/
void Synchronize(int device_id) {
#ifdef XGBOOST_USE_NCCL
SaveCudaContext([&]() {
dh::safe_cuda(cudaSetDevice(device_id));
int idx = std::find(device_ordinals.begin(), device_ordinals.end(), device_id) - device_ordinals.begin();
CHECK(idx < device_ordinals.size());
dh::safe_cuda(cudaStreamSynchronize(streams[idx]));
});
dh::safe_cuda(cudaSetDevice(device_ordinal));
dh::safe_cuda(cudaStreamSynchronize(stream));
#endif
};
@ -1219,58 +1112,6 @@ class AllReducer {
}
};
/**
* \brief Executes some operation on each element of the input vector, using a
* single controlling thread for each element. In addition, passes the shard index
* into the function.
*
* \tparam T Generic type parameter.
* \tparam FunctionT Type of the function t.
* \param shards The shards.
* \param f The func_t to process.
*/
template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext{[&]() {
// Temporarily turn off dynamic so we have a guaranteed number of threads
bool dynamic = omp_get_dynamic();
omp_set_dynamic(false);
const long shards_size = static_cast<long>(shards->size());
#pragma omp parallel for schedule(static, 1) if (shards_size > 1) num_threads(shards_size)
for (long shard = 0; shard < shards_size; ++shard) {
f(shard, shards->at(shard));
}
omp_set_dynamic(dynamic);
}};
}
/**
* \brief Executes some operation on each element of the input vector, using a single controlling
* thread for each element, returns the sum of the results.
*
* \tparam ReduceT Type of the reduce t.
* \tparam T Generic type parameter.
* \tparam FunctionT Type of the function t.
* \param shards The shards.
* \param f The func_t to process.
*
* \return A reduce_t.
*/
template <typename ReduceT, typename ShardT, typename FunctionT>
ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
std::vector<ReduceT> sums(shards->size());
SaveCudaContext {
[&](){
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
sums[shard] = f(shards->at(shard));
}}
};
return std::accumulate(sums.begin(), sums.end(), ReduceT());
}
template <typename T,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(

View File

@ -108,9 +108,6 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
impl_->Vec().resize(new_size, v);
}
template <typename T>
size_t HostDeviceVector<T>::DeviceSize() const { return 0; }
template <typename T>
void HostDeviceVector<T>::Fill(T v) {
std::fill(HostVector().begin(), HostVector().end(), v);
@ -135,12 +132,22 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
}
template <typename T>
bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
bool HostDeviceVector<T>::HostCanRead() const {
return true;
}
template <typename T>
bool HostDeviceVector<T>::DeviceCanAccess(GPUAccess access) const {
bool HostDeviceVector<T>::HostCanWrite() const {
return true;
}
template <typename T>
bool HostDeviceVector<T>::DeviceCanRead() const {
return false;
}
template <typename T>
bool HostDeviceVector<T>::DeviceCanWrite() const {
return false;
}

View File

@ -19,33 +19,12 @@ void SetCudaSetDeviceHandler(void (*handler)(int)) {
cudaSetDeviceHandler = handler;
}
// wrapper over access with useful methods
class Permissions {
GPUAccess access_;
explicit Permissions(GPUAccess access) : access_{access} {}
public:
Permissions() : access_{GPUAccess::kNone} {}
explicit Permissions(bool perm)
: access_(perm ? GPUAccess::kWrite : GPUAccess::kNone) {}
bool CanRead() const { return access_ >= kRead; }
bool CanWrite() const { return access_ == kWrite; }
bool CanAccess(GPUAccess access) const { return access_ >= access; }
void Grant(GPUAccess access) { access_ = std::max(access_, access); }
void DenyComplementary(GPUAccess compl_access) {
access_ = std::min(access_, GPUAccess::kWrite - compl_access);
}
Permissions Complementary() const {
return Permissions(GPUAccess::kWrite - access_);
}
};
template <typename T>
class HostDeviceVectorImpl {
public:
HostDeviceVectorImpl(size_t size, T v, int device) : device_(device), perm_h_(device < 0) {
HostDeviceVectorImpl(size_t size, T v, int device) : device_(device) {
if (device >= 0) {
gpu_access_ = GPUAccess::kWrite;
SetDevice();
data_d_.resize(size, v);
} else {
@ -53,19 +32,11 @@ class HostDeviceVectorImpl {
}
}
// required, as a new std::mutex has to be created
HostDeviceVectorImpl(const HostDeviceVectorImpl<T>& other)
: device_(other.device_), data_h_(other.data_h_), perm_h_(other.perm_h_), mutex_() {
if (device_ >= 0) {
SetDevice();
data_d_ = other.data_d_;
}
}
// Initializer can be std::vector<T> or std::initializer_list<T>
template <class Initializer>
HostDeviceVectorImpl(const Initializer& init, int device) : device_(device), perm_h_(device < 0) {
HostDeviceVectorImpl(const Initializer& init, int device) : device_(device) {
if (device >= 0) {
gpu_access_ = GPUAccess::kWrite;
LazyResizeDevice(init.size());
Copy(init);
} else {
@ -79,7 +50,7 @@ class HostDeviceVectorImpl {
}
}
size_t Size() const { return perm_h_.CanRead() ? data_h_.size() : data_d_.size(); }
size_t Size() const { return HostCanRead() ? data_h_.size() : data_d_.size(); }
int DeviceIdx() const { return device_; }
@ -95,18 +66,13 @@ class HostDeviceVectorImpl {
common::Span<T> DeviceSpan() {
LazySyncDevice(GPUAccess::kWrite);
return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(DeviceSize())};
return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(Size())};
}
common::Span<const T> ConstDeviceSpan() {
LazySyncDevice(GPUAccess::kRead);
using SpanInd = typename common::Span<const T>::index_type;
return {data_d_.data().get(), static_cast<SpanInd>(DeviceSize())};
}
size_t DeviceSize() {
LazySyncDevice(GPUAccess::kRead);
return data_d_.size();
return {data_d_.data().get(), static_cast<SpanInd>(Size())};
}
thrust::device_ptr<T> tbegin() { // NOLINT
@ -118,55 +84,53 @@ class HostDeviceVectorImpl {
}
thrust::device_ptr<T> tend() { // NOLINT
return tbegin() + DeviceSize();
return tbegin() + Size();
}
thrust::device_ptr<const T> tcend() { // NOLINT
return tcbegin() + DeviceSize();
return tcbegin() + Size();
}
void Fill(T v) { // NOLINT
if (perm_h_.CanWrite()) {
if (HostCanWrite()) {
std::fill(data_h_.begin(), data_h_.end(), v);
} else {
DeviceFill(v);
gpu_access_ = GPUAccess::kWrite;
SetDevice();
thrust::fill(data_d_.begin(), data_d_.end(), v);
}
}
void Copy(HostDeviceVectorImpl<T>* other) {
CHECK_EQ(Size(), other->Size());
// Data is on host.
if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) {
if (HostCanWrite() && other->HostCanWrite()) {
std::copy(other->data_h_.begin(), other->data_h_.end(), data_h_.begin());
return;
}
// Data is on device;
if (device_ != other->device_) {
SetDevice(other->device_);
}
DeviceCopy(other);
CopyToDevice(other);
}
void Copy(const std::vector<T>& other) {
CHECK_EQ(Size(), other.size());
if (perm_h_.CanWrite()) {
if (HostCanWrite()) {
std::copy(other.begin(), other.end(), data_h_.begin());
} else {
DeviceCopy(other.data());
CopyToDevice(other.data());
}
}
void Copy(std::initializer_list<T> other) {
CHECK_EQ(Size(), other.size());
if (perm_h_.CanWrite()) {
if (HostCanWrite()) {
std::copy(other.begin(), other.end(), data_h_.begin());
} else {
DeviceCopy(other.begin());
CopyToDevice(other.begin());
}
}
std::vector<T>& HostVector() {
LazySyncHost(GPUAccess::kWrite);
LazySyncHost(GPUAccess::kNone);
return data_h_;
}
@ -178,7 +142,7 @@ class HostDeviceVectorImpl {
void SetDevice(int device) {
if (device_ == device) { return; }
if (device_ >= 0) {
LazySyncHost(GPUAccess::kWrite);
LazySyncHost(GPUAccess::kNone);
}
device_ = device;
if (device_ >= 0) {
@ -190,38 +154,37 @@ class HostDeviceVectorImpl {
if (new_size == Size()) { return; }
if (Size() == 0 && device_ >= 0) {
// fast on-device resize
perm_h_ = Permissions(false);
gpu_access_ = GPUAccess::kWrite;
SetDevice();
data_d_.resize(new_size, v);
} else {
// resize on host
LazySyncHost(GPUAccess::kWrite);
LazySyncHost(GPUAccess::kNone);
data_h_.resize(new_size, v);
}
}
void LazySyncHost(GPUAccess access) {
if (perm_h_.CanAccess(access)) { return; }
if (perm_h_.CanRead()) {
if (HostCanAccess(access)) { return; }
if (HostCanRead()) {
// data is present, just need to deny access to the device
perm_h_.Grant(access);
gpu_access_ = access;
return;
}
std::lock_guard<std::mutex> lock(mutex_);
gpu_access_ = access;
if (data_h_.size() != data_d_.size()) { data_h_.resize(data_d_.size()); }
SetDevice();
dh::safe_cuda(cudaMemcpy(data_h_.data(),
data_d_.data().get(),
data_d_.size() * sizeof(T),
cudaMemcpyDeviceToHost));
perm_h_.Grant(access);
}
void LazySyncDevice(GPUAccess access) {
if (DevicePerm().CanAccess(access)) { return; }
if (DevicePerm().CanRead()) {
if (DeviceCanAccess(access)) { return; }
if (DeviceCanRead()) {
// deny read to the host
std::lock_guard<std::mutex> lock(mutex_);
perm_h_.DenyComplementary(access);
gpu_access_ = access;
return;
}
// data is on the host
@ -231,41 +194,37 @@ class HostDeviceVectorImpl {
data_h_.data(),
data_d_.size() * sizeof(T),
cudaMemcpyHostToDevice));
std::lock_guard<std::mutex> lock(mutex_);
perm_h_.DenyComplementary(access);
gpu_access_ = access;
}
bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); }
bool DeviceCanAccess(GPUAccess access) { return DevicePerm().CanAccess(access); }
bool HostCanAccess(GPUAccess access) const { return gpu_access_ <= access; }
bool HostCanRead() const { return HostCanAccess(GPUAccess::kRead); }
bool HostCanWrite() const { return HostCanAccess(GPUAccess::kNone); }
bool DeviceCanAccess(GPUAccess access) const { return gpu_access_ >= access; }
bool DeviceCanRead() const { return DeviceCanAccess(GPUAccess::kRead); }
bool DeviceCanWrite() const { return DeviceCanAccess(GPUAccess::kWrite); }
private:
int device_{-1};
std::vector<T> data_h_{};
dh::device_vector<T> data_d_{};
Permissions perm_h_{false};
// protects size_d_ and perm_h_ when updated from multiple threads
std::mutex mutex_{};
GPUAccess gpu_access_{GPUAccess::kNone};
void DeviceFill(T v) {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
SetDevice();
thrust::fill(data_d_.begin(), data_d_.end(), v);
void CopyToDevice(HostDeviceVectorImpl* other) {
if (other->HostCanWrite()) {
CopyToDevice(other->data_h_.data());
} else {
LazyResizeDevice(Size());
gpu_access_ = GPUAccess::kWrite;
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), other->data_d_.data().get(),
data_d_.size() * sizeof(T), cudaMemcpyDefault));
}
}
void DeviceCopy(HostDeviceVectorImpl* other) {
// TODO(canonizer): avoid full copy of host data for this (but not for other)
LazySyncDevice(GPUAccess::kWrite);
other->LazySyncDevice(GPUAccess::kRead);
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), other->data_d_.data().get(),
data_d_.size() * sizeof(T), cudaMemcpyDefault));
}
void DeviceCopy(const T* begin) {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
void CopyToDevice(const T* begin) {
LazyResizeDevice(Size());
gpu_access_ = GPUAccess::kWrite;
SetDevice();
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), begin,
data_d_.size() * sizeof(T), cudaMemcpyDefault));
@ -285,8 +244,6 @@ class HostDeviceVectorImpl {
(*cudaSetDeviceHandler)(device_);
}
}
Permissions DevicePerm() const { return perm_h_.Complementary(); }
};
template<typename T>
@ -347,11 +304,6 @@ common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan() const {
return impl_->ConstDeviceSpan();
}
template <typename T>
size_t HostDeviceVector<T>::DeviceSize() const {
return impl_->DeviceSize();
}
template <typename T>
thrust::device_ptr<T> HostDeviceVector<T>::tbegin() { // NOLINT
return impl_->tbegin();
@ -401,13 +353,23 @@ const std::vector<T>& HostDeviceVector<T>::ConstHostVector() const {
}
template <typename T>
bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
return impl_->HostCanAccess(access);
bool HostDeviceVector<T>::HostCanRead() const {
return impl_->HostCanRead();
}
template <typename T>
bool HostDeviceVector<T>::DeviceCanAccess(GPUAccess access) const {
return impl_->DeviceCanAccess(access);
bool HostDeviceVector<T>::HostCanWrite() const {
return impl_->HostCanWrite();
}
template <typename T>
bool HostDeviceVector<T>::DeviceCanRead() const {
return impl_->DeviceCanRead();
}
template <typename T>
bool HostDeviceVector<T>::DeviceCanWrite() const {
return impl_->DeviceCanWrite();
}
template <typename T>

View File

@ -79,16 +79,23 @@ void SetCudaSetDeviceHandler(void (*handler)(int));
template <typename T> struct HostDeviceVectorImpl;
/*!
* \brief Controls data access from the GPU.
*
* Since a `HostDeviceVector` can have data on both the host and device, access control needs to be
* maintained to keep the data consistent.
*
* There are 3 scenarios supported:
* - Data is being manipulated on device. GPU has write access, host doesn't have access.
* - Data is read-only on both the host and device.
* - Data is being manipulated on the host. Host has write access, device doesn't have access.
*/
enum GPUAccess {
kNone, kRead,
// write implies read
kWrite
};
inline GPUAccess operator-(GPUAccess a, GPUAccess b) {
return static_cast<GPUAccess>(static_cast<int>(a) - static_cast<int>(b));
}
template <typename T>
class HostDeviceVector {
public:
@ -111,8 +118,6 @@ class HostDeviceVector {
const T* ConstHostPointer() const { return ConstHostVector().data(); }
const T* HostPointer() const { return ConstHostPointer(); }
size_t DeviceSize() const;
// only define functions returning device_ptr
// if HostDeviceVector.h is included from a .cu file
#ifdef __CUDACC__
@ -135,8 +140,10 @@ class HostDeviceVector {
const std::vector<T>& ConstHostVector() const;
const std::vector<T>& HostVector() const {return ConstHostVector(); }
bool HostCanAccess(GPUAccess access) const;
bool DeviceCanAccess(GPUAccess access) const;
bool HostCanRead() const;
bool HostCanWrite() const;
bool DeviceCanRead() const;
bool DeviceCanWrite() const;
void SetDevice(int device) const;

View File

@ -68,7 +68,7 @@ class ElementWiseMetricsReduction {
const HostDeviceVector<bst_float>& weights,
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds) {
size_t n_data = preds.DeviceSize();
size_t n_data = preds.Size();
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + n_data;

View File

@ -85,7 +85,7 @@ class MultiClassMetricsReduction {
const HostDeviceVector<bst_float>& labels,
const HostDeviceVector<bst_float>& preds,
const size_t n_class) {
size_t n_data = labels.DeviceSize();
size_t n_data = labels.Size();
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + n_data;

View File

@ -231,12 +231,13 @@ class GPUPredictor : public xgboost::Predictor {
this->num_group_ = model.param.num_output_group;
}
void PredictInternal
(const SparsePage& batch, size_t num_features,
HostDeviceVector<bst_float>* predictions) {
void PredictInternal(const SparsePage& batch,
size_t num_features,
HostDeviceVector<bst_float>* predictions,
size_t batch_offset) {
dh::safe_cuda(cudaSetDevice(device_));
const int BLOCK_THREADS = 128;
size_t num_rows = batch.offset.DeviceSize() - 1;
size_t num_rows = batch.Size();
const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
int shared_memory_bytes = static_cast<int>
@ -249,10 +250,10 @@ class GPUPredictor : public xgboost::Predictor {
size_t entry_start = 0;
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
(dh::ToSpan(nodes_), predictions->DeviceSpan(), dh::ToSpan(tree_segments_),
dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features,
num_rows, entry_start, use_shared, this->num_group_);
(dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
entry_start, use_shared, this->num_group_);
}
private:
@ -297,28 +298,10 @@ class GPUPredictor : public xgboost::Predictor {
InitModel(model, tree_begin, tree_end);
size_t batch_offset = 0;
auto* preds = out_preds;
std::unique_ptr<HostDeviceVector<bst_float>> batch_preds{nullptr};
for (auto &batch : dmat->GetBatches<SparsePage>()) {
bool is_external_memory = batch.Size() < dmat->Info().num_row_;
if (is_external_memory) {
batch_preds.reset(new HostDeviceVector<bst_float>);
batch_preds->Resize(batch.Size() * model.param.num_output_group);
std::copy(out_preds->ConstHostVector().begin() + batch_offset,
out_preds->ConstHostVector().begin() + batch_offset + batch_preds->Size(),
batch_preds->HostVector().begin());
preds = batch_preds.get();
}
batch.offset.SetDevice(device_);
batch.data.SetDevice(device_);
preds->SetDevice(device_);
shard_.PredictInternal(batch, model.param.num_feature, preds);
if (is_external_memory) {
auto h_preds = preds->ConstHostVector();
std::copy(h_preds.begin(), h_preds.end(), out_preds->HostVector().begin() + batch_offset);
}
shard_.PredictInternal(batch, model.param.num_feature, out_preds, batch_offset);
batch_offset += batch.Size() * model.param.num_output_group;
}
@ -356,6 +339,7 @@ class GPUPredictor : public xgboost::Predictor {
size_t n_classes = model.param.num_output_group;
size_t n = n_classes * info.num_row_;
const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
out_preds->SetDevice(device_);
out_preds->Resize(n);
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Size(), n);
@ -454,7 +438,7 @@ class GPUPredictor : public xgboost::Predictor {
}
private:
/*! \brief Re configure shards when GPUSet is changed. */
/*! \brief Reconfigure the shard when GPU is changed. */
void ConfigureShard(int device) {
if (device_ == device) return;

View File

@ -93,14 +93,14 @@ struct ExpandEntry {
}
};
inline static bool DepthWise(ExpandEntry lhs, ExpandEntry rhs) {
inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.depth > rhs.depth; // favor small depth
}
}
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
@ -553,7 +553,7 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
// of rows to process from a batch and the position from which to process on each device.
struct RowStateOnDevice {
// Number of rows assigned to this device
const size_t total_rows_assigned_to_device;
size_t total_rows_assigned_to_device;
// Number of rows processed thus far
size_t total_rows_processed;
// Number of rows to process from the current sparse page batch
@ -584,14 +584,13 @@ template <typename GradientSumT>
struct DeviceShard {
int n_bins;
int device_id;
int shard_idx; // Position in the local array of shards
dh::BulkAllocator ba;
ELLPackMatrix ellpack_matrix;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist;
DeviceHistogram<GradientSumT> hist{};
/*! \brief row_ptr form HistogramCuts. */
common::Span<uint32_t> feature_segments;
@ -611,9 +610,6 @@ struct DeviceShard {
/*! \brief Sum gradient for each node. */
std::vector<GradientPair> node_sum_gradients;
common::Span<GradientPair> node_sum_gradients_d;
/*! The row offset for this shard. */
bst_uint row_begin_idx;
bst_uint row_end_idx;
bst_uint n_rows;
TrainParam param;
@ -623,7 +619,7 @@ struct DeviceShard {
dh::CubMemory temp_memory;
dh::PinnedMemory pinned_memory;
std::vector<cudaStream_t> streams;
std::vector<cudaStream_t> streams{};
common::Monitor monitor;
std::vector<ValueConstraint> node_value_constraints;
@ -635,14 +631,10 @@ struct DeviceShard {
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
DeviceShard(int _device_id, int shard_idx, bst_uint row_begin,
bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed,
DeviceShard(int _device_id, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed,
uint32_t n_features)
: device_id(_device_id),
shard_idx(shard_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_rows(_n_rows),
n_bins(0),
param(std::move(_param)),
prediction_cache_initialised(false),
@ -658,7 +650,7 @@ struct DeviceShard {
const SparsePage &row_batch, const common::HistogramCuts &hmat,
const RowStateOnDevice &device_row_state, int rows_per_batch);
~DeviceShard() {
~DeviceShard() { // NOLINT
dh::safe_cuda(cudaSetDevice(device_id));
for (auto& stream : streams) {
dh::safe_cuda(cudaStreamDestroy(stream));
@ -704,7 +696,7 @@ struct DeviceShard {
dh::safe_cuda(cudaMemcpyAsync(
gpair.data(), dh_gpair->ConstDevicePointer(),
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx);
SubsampleGradientPair(device_id, gpair, param.subsample);
hist.Reset();
}
@ -755,7 +747,7 @@ struct DeviceShard {
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
auto d_result = d_result_all.subspan(i, 1);
if (d_feature_set.size() == 0) {
if (d_feature_set.empty()) {
// Acting as a device side constructor for DeviceSplitCandidate.
// DeviceSplitCandidate::IsValid is false so that ApplySplit can reject this
// candidate.
@ -927,12 +919,11 @@ struct DeviceShard {
monitor.StartCuda("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
reducer->AllReduceSum(
shard_idx,
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
ellpack_matrix.BinCount() *
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
reducer->Synchronize(device_id);
reducer->Synchronize();
monitor.StopCuda("AllReduce");
}
@ -979,11 +970,11 @@ struct DeviceShard {
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;
GradStats left_stats;
GradStats left_stats{};
left_stats.Add(candidate.split.left_sum);
GradStats right_stats;
GradStats right_stats{};
right_stats.Add(candidate.split.right_sum);
GradStats parent_sum;
GradStats parent_sum{};
parent_sum.Add(left_stats);
parent_sum.Add(right_stats);
node_value_constraints.resize(tree.GetNodes().size());
@ -1021,9 +1012,9 @@ struct DeviceShard {
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
gpair.size());
reducer->AllReduceSum(
shard_idx, reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()),
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
reducer->Synchronize(device_id);
reducer->Synchronize();
dh::safe_cuda(cudaMemcpy(node_sum_gradients.data(),
node_sum_gradients_d.data(), sizeof(GradientPair),
cudaMemcpyDeviceToHost));
@ -1238,52 +1229,44 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
class DeviceHistogramBuilderState {
public:
template <typename GradientSumT>
explicit DeviceHistogramBuilderState(
const std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> &shards) {
device_row_states_.reserve(shards.size());
for (const auto &shard : shards) {
device_row_states_.push_back(RowStateOnDevice(shard->n_rows));
}
}
explicit DeviceHistogramBuilderState(const std::unique_ptr<DeviceShard<GradientSumT>>& shard)
: device_row_state_(shard->n_rows) {}
const RowStateOnDevice &GetRowStateOnDevice(int idx) const {
return device_row_states_[idx];
const RowStateOnDevice& GetRowStateOnDevice() const {
return device_row_state_;
}
// This method is invoked at the beginning of each sparse page batch. This distributes
// the rows in the sparse page to the different devices.
// the rows in the sparse page to the device.
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
void BeginBatch(const SparsePage &batch) {
size_t rem_rows = batch.Size();
size_t row_offset_in_current_batch = 0;
for (auto &device_row_state : device_row_states_) {
// Do we have anymore left to process from this batch on this device?
if (device_row_state.total_rows_assigned_to_device > device_row_state.total_rows_processed) {
// There are still some rows that needs to be assigned to this device
device_row_state.rows_to_process_from_batch =
std::min(
device_row_state.total_rows_assigned_to_device - device_row_state.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state.rows_to_process_from_batch = 0;
}
device_row_state.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state.rows_to_process_from_batch;
rem_rows -= device_row_state.rows_to_process_from_batch;
// Do we have anymore left to process from this batch on this device?
if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
// There are still some rows that needs to be assigned to this device
device_row_state_.rows_to_process_from_batch =
std::min(
device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
rem_rows);
} else {
// All rows have been assigned to this device
device_row_state_.rows_to_process_from_batch = 0;
}
device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
rem_rows -= device_row_state_.rows_to_process_from_batch;
}
// This method is invoked after completion of each sparse page batch
void EndBatch() {
for (auto &rs : device_row_states_) {
rs.Advance();
}
device_row_state_.Advance();
}
private:
std::vector<RowStateOnDevice> device_row_states_;
RowStateOnDevice device_row_state_{0};
};
template <typename GradientSumT>
@ -1302,7 +1285,9 @@ class GPUHistMakerSpecialised {
monitor_.Init("updater_gpu_hist");
}
~GPUHistMakerSpecialised() { dh::GlobalMemoryLogger().Log(); }
~GPUHistMakerSpecialised() { // NOLINT
dh::GlobalMemoryLogger().Log();
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
@ -1333,20 +1318,13 @@ class GPUHistMakerSpecialised {
uint32_t column_sampling_seed = common::GlobalRandom()();
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
// Create device shards
shards_.resize(1);
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(device_));
size_t start = 0;
size_t size = info_->num_row_;
shard = std::unique_ptr<DeviceShard<GradientSumT>>(
new DeviceShard<GradientSumT>(device_, idx,
start, start + size, param_,
column_sampling_seed,
info_->num_col_));
});
// Create device shard
dh::safe_cuda(cudaSetDevice(device_));
shard_.reset(new DeviceShard<GradientSumT>(device_,
info_->num_row_,
param_,
column_sampling_seed,
info_->num_col_));
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts
@ -1355,32 +1333,22 @@ class GPUHistMakerSpecialised {
dmat, &hmat_);
monitor_.StopCuda("Quantiles");
n_bins_ = hmat_.Ptrs().back();
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
// Init global data for each shard
monitor_.StartCuda("InitCompressedData");
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->InitCompressedData(hmat_, row_stride, is_dense);
});
dh::safe_cuda(cudaSetDevice(shard_->device_id));
shard_->InitCompressedData(hmat_, row_stride, is_dense);
monitor_.StopCuda("InitCompressedData");
monitor_.StartCuda("BinningCompression");
DeviceHistogramBuilderState hist_builder_row_state(shards_);
DeviceHistogramBuilderState hist_builder_row_state(shard_);
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
hist_builder_row_state.BeginBatch(batch);
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(idx),
hist_maker_param_.gpu_batch_nrows);
});
dh::safe_cuda(cudaSetDevice(shard_->device_id));
shard_->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(),
hist_maker_param_.gpu_batch_nrows);
hist_builder_row_state.EndBatch();
}
@ -1408,7 +1376,7 @@ class GPUHistMakerSpecialised {
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
RegTree reference_tree;
RegTree reference_tree{};
reference_tree.Load(&fs);
for (const auto& tree : local_trees) {
CHECK(tree == reference_tree);
@ -1421,66 +1389,39 @@ class GPUHistMakerSpecialised {
this->InitData(gpair, p_fmat);
monitor_.StopCuda("InitData");
std::vector<RegTree> trees(shards_.size());
for (auto& tree : trees) {
tree = *p_tree;
}
gpair->SetDevice(device_);
// Launch one thread for each device "shard" containing a subset of rows.
// Threads will cooperatively build the tree, synchronising over histograms.
// Each thread will redundantly build its own copy of the tree
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->UpdateTree(gpair, p_fmat, &trees.at(idx), &reducer_);
});
// All trees are expected to be identical
if (hist_maker_param_.debug_synchronize) {
this->CheckTreesSynchronized(trees);
}
// Write the output tree
*p_tree = trees.front();
shard_->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
if (shard_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
return false;
}
monitor_.StartCuda("UpdatePredictionCache");
p_out_preds->SetDevice(device_);
dh::ExecuteIndexShards(
&shards_,
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
dh::safe_cuda(cudaSetDevice(shard->device_id));
shard->UpdatePredictionCache(
p_out_preds->DevicePointer());
});
dh::safe_cuda(cudaSetDevice(shard_->device_id));
shard_->UpdatePredictionCache(p_out_preds->DevicePointer());
monitor_.StopCuda("UpdatePredictionCache");
return true;
}
TrainParam param_; // NOLINT
common::HistogramCuts hmat_; // NOLINT
MetaInfo* info_; // NOLINT
MetaInfo* info_{}; // NOLINT
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT
std::unique_ptr<DeviceShard<GradientSumT>> shard_; // NOLINT
private:
bool initialised_;
int n_bins_;
GPUHistMakerTrainParam hist_maker_param_;
GenericParameter const* generic_param_;
dh::AllReducer reducer_;
DMatrix* p_last_fmat_;
int device_;
int device_{-1};
common::Monitor monitor_;
};

View File

@ -10,17 +10,6 @@
using xgboost::common::Span;
struct Shard { int id; };
TEST(DeviceHelpers, Basic) {
std::vector<Shard> shards (4);
for (int i = 0; i < 4; ++i) {
shards[i].id = i;
}
int sum = dh::ReduceShards<int>(&shards, [](Shard& s) { return s.id ; });
ASSERT_EQ(sum, 6);
}
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
thrust::host_vector<int> *row_ptr,
thrust::host_vector<xgboost::bst_uint> *rows) {

View File

@ -38,19 +38,19 @@ void InitHostDeviceVector(size_t n, int device, HostDeviceVector<int> *v) {
ASSERT_EQ(v->Size(), n);
ASSERT_EQ(v->DeviceIdx(), device);
// ensure that the device have read-write access
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kWrite));
ASSERT_TRUE(v->DeviceCanRead());
ASSERT_TRUE(v->DeviceCanWrite());
// ensure that the host has no access
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_FALSE(v->HostCanRead());
ASSERT_FALSE(v->HostCanWrite());
// fill in the data on the host
std::vector<int>& data_h = v->HostVector();
// ensure that the host has full access, while the device have none
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kWrite));
ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kRead));
ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kWrite));
ASSERT_TRUE(v->HostCanRead());
ASSERT_TRUE(v->HostCanWrite());
ASSERT_FALSE(v->DeviceCanRead());
ASSERT_FALSE(v->DeviceCanWrite());
ASSERT_EQ(data_h.size(), n);
std::copy_n(thrust::make_counting_iterator(0), n, data_h.begin());
}
@ -62,76 +62,62 @@ void PlusOne(HostDeviceVector<int> *v) {
[=]__device__(unsigned int a){ return a + 1; });
}
void CheckDevice(HostDeviceVector<int> *v,
const std::vector<size_t>& starts,
const std::vector<size_t>& sizes,
unsigned int first, GPUAccess access) {
int n_devices = sizes.size();
ASSERT_EQ(n_devices, 1);
for (int i = 0; i < n_devices; ++i) {
ASSERT_EQ(v->DeviceSize(), sizes.at(i));
SetDevice(i);
ASSERT_TRUE(thrust::equal(v->tcbegin(), v->tcend(),
thrust::make_counting_iterator(first + starts[i])));
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
// ensure that the device has at most the access specified by access
ASSERT_EQ(v->DeviceCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
}
ASSERT_EQ(v->HostCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
for (int i = 0; i < n_devices; ++i) {
SetDevice(i);
ASSERT_TRUE(thrust::equal(v->tbegin(), v->tend(),
thrust::make_counting_iterator(first + starts[i])));
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kWrite));
}
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
void CheckDevice(HostDeviceVector<int>* v,
size_t size,
unsigned int first,
GPUAccess access) {
ASSERT_EQ(v->Size(), size);
SetDevice(v->DeviceIdx());
ASSERT_TRUE(thrust::equal(v->tcbegin(), v->tcend(),
thrust::make_counting_iterator(first)));
ASSERT_TRUE(v->DeviceCanRead());
// ensure that the device has at most the access specified by access
ASSERT_EQ(v->DeviceCanWrite(), access == GPUAccess::kWrite);
ASSERT_EQ(v->HostCanRead(), access == GPUAccess::kRead);
ASSERT_FALSE(v->HostCanWrite());
ASSERT_TRUE(thrust::equal(v->tbegin(), v->tend(),
thrust::make_counting_iterator(first)));
ASSERT_TRUE(v->DeviceCanRead());
ASSERT_TRUE(v->DeviceCanWrite());
ASSERT_FALSE(v->HostCanRead());
ASSERT_FALSE(v->HostCanWrite());
}
void CheckHost(HostDeviceVector<int> *v, GPUAccess access) {
const std::vector<int>& data_h = access == GPUAccess::kWrite ?
const std::vector<int>& data_h = access == GPUAccess::kNone ?
v->HostVector() : v->ConstHostVector();
for (size_t i = 0; i < v->Size(); ++i) {
ASSERT_EQ(data_h.at(i), i + 1);
}
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
ASSERT_EQ(v->HostCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
size_t n_devices = 1;
for (int i = 0; i < n_devices; ++i) {
ASSERT_EQ(v->DeviceCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
// the devices should have no write access
ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kWrite));
}
ASSERT_TRUE(v->HostCanRead());
ASSERT_EQ(v->HostCanWrite(), access == GPUAccess::kNone);
ASSERT_EQ(v->DeviceCanRead(), access == GPUAccess::kRead);
// the devices should have no write access
ASSERT_FALSE(v->DeviceCanWrite());
}
void TestHostDeviceVector
(size_t n, int device,
const std::vector<size_t>& starts, const std::vector<size_t>& sizes) {
void TestHostDeviceVector(size_t n, int device) {
HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice);
HostDeviceVector<int> v;
InitHostDeviceVector(n, device, &v);
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
CheckDevice(&v, n, 0, GPUAccess::kRead);
PlusOne(&v);
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
CheckDevice(&v, n, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kNone);
}
TEST(HostDeviceVector, TestBlock) {
TEST(HostDeviceVector, Basic) {
size_t n = 1001;
int device = 0;
std::vector<size_t> starts{0};
std::vector<size_t> sizes{1001};
TestHostDeviceVector(n, device, starts, sizes);
TestHostDeviceVector(n, device);
}
TEST(HostDeviceVector, TestCopy) {
TEST(HostDeviceVector, Copy) {
size_t n = 1001;
int device = 0;
std::vector<size_t> starts{0};
std::vector<size_t> sizes{1001};
HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice);
HostDeviceVector<int> v;
@ -141,14 +127,14 @@ TEST(HostDeviceVector, TestCopy) {
InitHostDeviceVector(n, device, &v1);
v = v1;
}
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
CheckDevice(&v, n, 0, GPUAccess::kRead);
PlusOne(&v);
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
CheckDevice(&v, n, 1, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kRead);
CheckHost(&v, GPUAccess::kWrite);
CheckHost(&v, GPUAccess::kNone);
}
TEST(HostDeviceVector, Shard) {
TEST(HostDeviceVector, SetDevice) {
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
@ -157,7 +143,6 @@ TEST(HostDeviceVector, Shard) {
auto device = 0;
vec.SetDevice(device);
ASSERT_EQ(vec.DeviceSize(), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size());
auto span = vec.DeviceSpan(); // sync to device
@ -169,39 +154,26 @@ TEST(HostDeviceVector, Shard) {
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
}
TEST(HostDeviceVector, Reshard) {
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
}
HostDeviceVector<int> vec (h_vec);
auto device = 0;
vec.SetDevice(device);
ASSERT_EQ(vec.DeviceSize(), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size());
PlusOne(&vec);
vec.SetDevice(-1);
ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_EQ(vec.DeviceIdx(), -1);
auto h_vec_1 = vec.HostVector();
for (size_t i = 0; i < h_vec_1.size(); ++i) {
ASSERT_EQ(h_vec_1.at(i), i + 1);
}
}
TEST(HostDeviceVector, Span) {
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
vec.SetDevice(0);
auto span = vec.DeviceSpan();
ASSERT_EQ(vec.DeviceSize(), span.size());
ASSERT_EQ(vec.Size(), span.size());
ASSERT_EQ(vec.DevicePointer(), span.data());
auto const_span = vec.ConstDeviceSpan();
ASSERT_EQ(vec.DeviceSize(), span.size());
ASSERT_EQ(vec.ConstDevicePointer(), span.data());
ASSERT_EQ(vec.Size(), const_span.size());
ASSERT_EQ(vec.ConstDevicePointer(), const_span.data());
}
TEST(HostDeviceVector, MGPU_Basic) {
if (AllVisibleGPUs() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
size_t n = 1001;
int device = 1;
TestHostDeviceVector(n, device);
}
} // namespace common
} // namespace xgboost

View File

@ -83,8 +83,8 @@ TEST(gpu_predictor, ExternalMemoryTest) {
std::string file1 = tmpdir.path + "/big_1.libsvm";
std::string file2 = tmpdir.path + "/big_2.libsvm";
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
// dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
// dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
for (const auto& dmat: dmats) {
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);

View File

@ -113,7 +113,7 @@ TEST(GpuHist, BuildGidxDense) {
{"max_leaves", "0"},
};
param.Init(args);
DeviceShard<GradientPairPrecise> shard(0, 0, 0, kNRows, param, kNCols, kNCols);
DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
BuildGidx(&shard, kNRows, kNCols);
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
@ -154,8 +154,7 @@ TEST(GpuHist, BuildGidxSparse) {
};
param.Init(args);
DeviceShard<GradientPairPrecise> shard(0, 0, 0, kNRows, param, kNCols,
kNCols);
DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
BuildGidx(&shard, kNRows, kNCols, 0.9f);
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
@ -200,8 +199,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
{"max_leaves", "0"},
};
param.Init(args);
DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols,
kNCols);
DeviceShard<GradientSumT> shard(0, kNRows, param, kNCols, kNCols);
BuildGidx(&shard, kNRows, kNCols);
xgboost::SimpleLCG gen;
@ -303,8 +301,7 @@ TEST(GpuHist, EvaluateSplits) {
// Initialize DeviceShard
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols,
kNCols)};
new DeviceShard<GradientPairPrecise>(0, kNRows, param, kNCols, kNCols)};
// Initialize DeviceShard::node_sum_gradients
shard->node_sum_gradients = {{6.4f, 12.8f}};
@ -391,24 +388,20 @@ void TestHistogramIndexImpl() {
hist_maker_ext.Configure(training_params, &generic_param);
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
ASSERT_EQ(hist_maker.shards_.size(), hist_maker_ext.shards_.size());
// Extract the device shards from the histogram makers and from that its compressed
// Extract the device shard from the histogram makers and from that its compressed
// histogram index
for (size_t i = 0; i < hist_maker.shards_.size(); ++i) {
const auto &dev_shard = hist_maker.shards_[i];
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
const auto &dev_shard = hist_maker.shard_;
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
const auto &dev_shard_ext = hist_maker_ext.shards_[i];
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
const auto &dev_shard_ext = hist_maker_ext.shard_;
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
}
TEST(GpuHist, TestHistogramIndex) {