further cleanup of single process multi-GPU code (#4810)
* use subspan in gpu predictor instead of copying * Revise `HostDeviceVector`
This commit is contained in:
parent
0184eb5d02
commit
733ed24dd9
@ -238,8 +238,7 @@ class MemoryLogger {
|
||||
device_allocations.erase(itr);
|
||||
}
|
||||
};
|
||||
std::map<int, DeviceStats>
|
||||
stats_; // Map device ordinal to memory information
|
||||
DeviceStats stats_;
|
||||
std::mutex mutex_;
|
||||
|
||||
public:
|
||||
@ -249,8 +248,8 @@ public:
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
int current_device;
|
||||
safe_cuda(cudaGetDevice(¤t_device));
|
||||
stats_[current_device].RegisterAllocation(ptr, n);
|
||||
CHECK_LE(stats_[current_device].peak_allocated_bytes, dh::TotalMemory(current_device));
|
||||
stats_.RegisterAllocation(ptr, n);
|
||||
CHECK_LE(stats_.peak_allocated_bytes, dh::TotalMemory(current_device));
|
||||
}
|
||||
void RegisterDeallocation(void *ptr, size_t n) {
|
||||
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
|
||||
@ -258,19 +257,19 @@ public:
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
int current_device;
|
||||
safe_cuda(cudaGetDevice(¤t_device));
|
||||
stats_[current_device].RegisterDeallocation(ptr, n, current_device);
|
||||
stats_.RegisterDeallocation(ptr, n, current_device);
|
||||
}
|
||||
void Log() {
|
||||
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
|
||||
return;
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
for (const auto &kv : stats_) {
|
||||
LOG(CONSOLE) << "======== Device " << kv.first << " Memory Allocations: "
|
||||
<< " ========";
|
||||
LOG(CONSOLE) << "Peak memory usage: "
|
||||
<< kv.second.peak_allocated_bytes / 1000000 << "mb";
|
||||
LOG(CONSOLE) << "Number of allocations: " << kv.second.num_allocations;
|
||||
}
|
||||
int current_device;
|
||||
safe_cuda(cudaGetDevice(¤t_device));
|
||||
LOG(CONSOLE) << "======== Device " << current_device << " Memory Allocations: "
|
||||
<< " ========";
|
||||
LOG(CONSOLE) << "Peak memory usage: "
|
||||
<< stats_.peak_allocated_bytes / 1000000 << "mb";
|
||||
LOG(CONSOLE) << "Number of allocations: " << stats_.num_allocations;
|
||||
}
|
||||
};
|
||||
};
|
||||
@ -940,10 +939,9 @@ class AllReducer {
|
||||
size_t allreduce_calls_; // Keep statistics of the number of reduce calls
|
||||
std::vector<size_t> host_data; // Used for all reduce on host
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
std::vector<ncclComm_t> comms;
|
||||
std::vector<cudaStream_t> streams;
|
||||
std::vector<int> device_ordinals; // device id from CUDA
|
||||
std::vector<int> device_counts; // device count from CUDA
|
||||
ncclComm_t comm;
|
||||
cudaStream_t stream;
|
||||
int device_ordinal;
|
||||
ncclUniqueId id;
|
||||
#endif
|
||||
|
||||
@ -952,79 +950,28 @@ class AllReducer {
|
||||
allreduce_calls_(0) {}
|
||||
|
||||
/**
|
||||
* \brief If we are using a single GPU only
|
||||
*/
|
||||
bool IsSingleGPU() {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
CHECK(device_counts.size() > 0) << "AllReducer not initialised.";
|
||||
return device_counts.size() <= 1 && device_counts.at(0) == 1;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Initialise with the desired device ordinals for this communication
|
||||
* \brief Initialise with the desired device ordinal for this communication
|
||||
* group.
|
||||
*
|
||||
* \param device_ordinals The device ordinals.
|
||||
* \param device_ordinal The device ordinal.
|
||||
*/
|
||||
|
||||
void Init(const std::vector<int> &device_ordinals) {
|
||||
void Init(int _device_ordinal) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
/** \brief this >monitor . init. */
|
||||
this->device_ordinals = device_ordinals;
|
||||
this->device_counts.resize(rabit::GetWorldSize());
|
||||
this->comms.resize(device_ordinals.size());
|
||||
this->streams.resize(device_ordinals.size());
|
||||
this->id = GetUniqueId();
|
||||
|
||||
device_counts.at(rabit::GetRank()) = device_ordinals.size();
|
||||
for (size_t i = 0; i < device_counts.size(); i++) {
|
||||
int dev_count = device_counts.at(i);
|
||||
rabit::Allreduce<rabit::op::Sum, int>(&dev_count, 1);
|
||||
device_counts.at(i) = dev_count;
|
||||
}
|
||||
|
||||
int nccl_rank = 0;
|
||||
int nccl_rank_offset = std::accumulate(device_counts.begin(),
|
||||
device_counts.begin() + rabit::GetRank(), 0);
|
||||
int nccl_nranks = std::accumulate(device_counts.begin(),
|
||||
device_counts.end(), 0);
|
||||
nccl_rank += nccl_rank_offset;
|
||||
|
||||
GroupStart();
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
int dev = device_ordinals.at(i);
|
||||
dh::safe_cuda(cudaSetDevice(dev));
|
||||
dh::safe_nccl(ncclCommInitRank(
|
||||
&comms.at(i),
|
||||
nccl_nranks, id,
|
||||
nccl_rank));
|
||||
|
||||
nccl_rank++;
|
||||
}
|
||||
GroupEnd();
|
||||
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
safe_cuda(cudaSetDevice(device_ordinals.at(i)));
|
||||
safe_cuda(cudaStreamCreate(&streams.at(i)));
|
||||
}
|
||||
device_ordinal = _device_ordinal;
|
||||
id = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal));
|
||||
dh::safe_nccl(ncclCommInitRank(&comm, rabit::GetWorldSize(), id, rabit::GetRank()));
|
||||
safe_cuda(cudaStreamCreate(&stream));
|
||||
initialised_ = true;
|
||||
#else
|
||||
CHECK_EQ(device_ordinals.size(), 1)
|
||||
<< "XGBoost must be compiled with NCCL to use more than one GPU.";
|
||||
#endif
|
||||
}
|
||||
~AllReducer() {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
if (initialised_) {
|
||||
for (auto &stream : streams) {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream));
|
||||
}
|
||||
for (auto &comm : comms) {
|
||||
ncclCommDestroy(comm);
|
||||
}
|
||||
dh::safe_cuda(cudaStreamDestroy(stream));
|
||||
ncclCommDestroy(comm);
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
@ -1035,20 +982,21 @@ class AllReducer {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Use in exactly the same way as ncclGroupStart
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
void GroupStart() {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Use in exactly the same way as ncclGroupEnd
|
||||
*/
|
||||
void GroupEnd() {
|
||||
void AllReduceSum(const double *sendbuff, double *recvbuff, int count) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum, comm, stream));
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1056,51 +1004,18 @@ class AllReducer {
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param communication_group_idx Zero-based index of the communication group.
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
|
||||
void AllReduceSum(int communication_group_idx, const double *sendbuff,
|
||||
double *recvbuff, int count) {
|
||||
void AllReduceSum(const float *sendbuff, float *recvbuff, int count) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
|
||||
comms.at(communication_group_idx),
|
||||
streams.at(communication_group_idx)));
|
||||
if(communication_group_idx == 0)
|
||||
{
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
|
||||
* streams or comms.
|
||||
*
|
||||
* \param communication_group_idx Zero-based index of the communication group.
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of elements.
|
||||
*/
|
||||
|
||||
void AllReduceSum(int communication_group_idx, const float *sendbuff,
|
||||
float *recvbuff, int count) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
CHECK(initialised_);
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum,
|
||||
comms.at(communication_group_idx),
|
||||
streams.at(communication_group_idx)));
|
||||
if(communication_group_idx == 0)
|
||||
{
|
||||
allreduce_bytes_ += count * sizeof(float);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm, stream));
|
||||
allreduce_bytes_ += count * sizeof(float);
|
||||
allreduce_calls_ += 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1109,21 +1024,17 @@ class AllReducer {
|
||||
*
|
||||
* \param count Number of.
|
||||
*
|
||||
* \param communication_group_idx Zero-based index of the communication group. \param sendbuff.
|
||||
* \param sendbuff The sendbuff.
|
||||
* \param recvbuff The recvbuff.
|
||||
* \param count Number of.
|
||||
*/
|
||||
|
||||
void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
|
||||
int64_t *recvbuff, int count) {
|
||||
void AllReduceSum(const int64_t *sendbuff, int64_t *recvbuff, int count) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
CHECK(initialised_);
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
|
||||
comms[communication_group_idx],
|
||||
streams[communication_group_idx]));
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal));
|
||||
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum, comm, stream));
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1134,26 +1045,8 @@ class AllReducer {
|
||||
*/
|
||||
void Synchronize() {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
for (size_t i = 0; i < device_ordinals.size(); i++) {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinals[i]));
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[i]));
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Synchronizes the device
|
||||
*
|
||||
* \param device_id Identifier for the device.
|
||||
*/
|
||||
void Synchronize(int device_id) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
SaveCudaContext([&]() {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
int idx = std::find(device_ordinals.begin(), device_ordinals.end(), device_id) - device_ordinals.begin();
|
||||
CHECK(idx < device_ordinals.size());
|
||||
dh::safe_cuda(cudaStreamSynchronize(streams[idx]));
|
||||
});
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal));
|
||||
dh::safe_cuda(cudaStreamSynchronize(stream));
|
||||
#endif
|
||||
};
|
||||
|
||||
@ -1219,58 +1112,6 @@ class AllReducer {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Executes some operation on each element of the input vector, using a
|
||||
* single controlling thread for each element. In addition, passes the shard index
|
||||
* into the function.
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
* \tparam FunctionT Type of the function t.
|
||||
* \param shards The shards.
|
||||
* \param f The func_t to process.
|
||||
*/
|
||||
|
||||
template <typename T, typename FunctionT>
|
||||
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
|
||||
SaveCudaContext{[&]() {
|
||||
// Temporarily turn off dynamic so we have a guaranteed number of threads
|
||||
bool dynamic = omp_get_dynamic();
|
||||
omp_set_dynamic(false);
|
||||
const long shards_size = static_cast<long>(shards->size());
|
||||
#pragma omp parallel for schedule(static, 1) if (shards_size > 1) num_threads(shards_size)
|
||||
for (long shard = 0; shard < shards_size; ++shard) {
|
||||
f(shard, shards->at(shard));
|
||||
}
|
||||
omp_set_dynamic(dynamic);
|
||||
}};
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Executes some operation on each element of the input vector, using a single controlling
|
||||
* thread for each element, returns the sum of the results.
|
||||
*
|
||||
* \tparam ReduceT Type of the reduce t.
|
||||
* \tparam T Generic type parameter.
|
||||
* \tparam FunctionT Type of the function t.
|
||||
* \param shards The shards.
|
||||
* \param f The func_t to process.
|
||||
*
|
||||
* \return A reduce_t.
|
||||
*/
|
||||
|
||||
template <typename ReduceT, typename ShardT, typename FunctionT>
|
||||
ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
|
||||
std::vector<ReduceT> sums(shards->size());
|
||||
SaveCudaContext {
|
||||
[&](){
|
||||
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
|
||||
for (int shard = 0; shard < shards->size(); ++shard) {
|
||||
sums[shard] = f(shards->at(shard));
|
||||
}}
|
||||
};
|
||||
return std::accumulate(sums.begin(), sums.end(), ReduceT());
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename IndexT = typename xgboost::common::Span<T>::index_type>
|
||||
xgboost::common::Span<T> ToSpan(
|
||||
|
||||
@ -108,9 +108,6 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
|
||||
impl_->Vec().resize(new_size, v);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t HostDeviceVector<T>::DeviceSize() const { return 0; }
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::Fill(T v) {
|
||||
std::fill(HostVector().begin(), HostVector().end(), v);
|
||||
@ -135,12 +132,22 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
|
||||
bool HostDeviceVector<T>::HostCanRead() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::DeviceCanAccess(GPUAccess access) const {
|
||||
bool HostDeviceVector<T>::HostCanWrite() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::DeviceCanRead() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::DeviceCanWrite() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@ -19,33 +19,12 @@ void SetCudaSetDeviceHandler(void (*handler)(int)) {
|
||||
cudaSetDeviceHandler = handler;
|
||||
}
|
||||
|
||||
// wrapper over access with useful methods
|
||||
class Permissions {
|
||||
GPUAccess access_;
|
||||
explicit Permissions(GPUAccess access) : access_{access} {}
|
||||
|
||||
public:
|
||||
Permissions() : access_{GPUAccess::kNone} {}
|
||||
explicit Permissions(bool perm)
|
||||
: access_(perm ? GPUAccess::kWrite : GPUAccess::kNone) {}
|
||||
|
||||
bool CanRead() const { return access_ >= kRead; }
|
||||
bool CanWrite() const { return access_ == kWrite; }
|
||||
bool CanAccess(GPUAccess access) const { return access_ >= access; }
|
||||
void Grant(GPUAccess access) { access_ = std::max(access_, access); }
|
||||
void DenyComplementary(GPUAccess compl_access) {
|
||||
access_ = std::min(access_, GPUAccess::kWrite - compl_access);
|
||||
}
|
||||
Permissions Complementary() const {
|
||||
return Permissions(GPUAccess::kWrite - access_);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class HostDeviceVectorImpl {
|
||||
public:
|
||||
HostDeviceVectorImpl(size_t size, T v, int device) : device_(device), perm_h_(device < 0) {
|
||||
HostDeviceVectorImpl(size_t size, T v, int device) : device_(device) {
|
||||
if (device >= 0) {
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
SetDevice();
|
||||
data_d_.resize(size, v);
|
||||
} else {
|
||||
@ -53,19 +32,11 @@ class HostDeviceVectorImpl {
|
||||
}
|
||||
}
|
||||
|
||||
// required, as a new std::mutex has to be created
|
||||
HostDeviceVectorImpl(const HostDeviceVectorImpl<T>& other)
|
||||
: device_(other.device_), data_h_(other.data_h_), perm_h_(other.perm_h_), mutex_() {
|
||||
if (device_ >= 0) {
|
||||
SetDevice();
|
||||
data_d_ = other.data_d_;
|
||||
}
|
||||
}
|
||||
|
||||
// Initializer can be std::vector<T> or std::initializer_list<T>
|
||||
template <class Initializer>
|
||||
HostDeviceVectorImpl(const Initializer& init, int device) : device_(device), perm_h_(device < 0) {
|
||||
HostDeviceVectorImpl(const Initializer& init, int device) : device_(device) {
|
||||
if (device >= 0) {
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
LazyResizeDevice(init.size());
|
||||
Copy(init);
|
||||
} else {
|
||||
@ -79,7 +50,7 @@ class HostDeviceVectorImpl {
|
||||
}
|
||||
}
|
||||
|
||||
size_t Size() const { return perm_h_.CanRead() ? data_h_.size() : data_d_.size(); }
|
||||
size_t Size() const { return HostCanRead() ? data_h_.size() : data_d_.size(); }
|
||||
|
||||
int DeviceIdx() const { return device_; }
|
||||
|
||||
@ -95,18 +66,13 @@ class HostDeviceVectorImpl {
|
||||
|
||||
common::Span<T> DeviceSpan() {
|
||||
LazySyncDevice(GPUAccess::kWrite);
|
||||
return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(DeviceSize())};
|
||||
return {data_d_.data().get(), static_cast<typename common::Span<T>::index_type>(Size())};
|
||||
}
|
||||
|
||||
common::Span<const T> ConstDeviceSpan() {
|
||||
LazySyncDevice(GPUAccess::kRead);
|
||||
using SpanInd = typename common::Span<const T>::index_type;
|
||||
return {data_d_.data().get(), static_cast<SpanInd>(DeviceSize())};
|
||||
}
|
||||
|
||||
size_t DeviceSize() {
|
||||
LazySyncDevice(GPUAccess::kRead);
|
||||
return data_d_.size();
|
||||
return {data_d_.data().get(), static_cast<SpanInd>(Size())};
|
||||
}
|
||||
|
||||
thrust::device_ptr<T> tbegin() { // NOLINT
|
||||
@ -118,55 +84,53 @@ class HostDeviceVectorImpl {
|
||||
}
|
||||
|
||||
thrust::device_ptr<T> tend() { // NOLINT
|
||||
return tbegin() + DeviceSize();
|
||||
return tbegin() + Size();
|
||||
}
|
||||
|
||||
thrust::device_ptr<const T> tcend() { // NOLINT
|
||||
return tcbegin() + DeviceSize();
|
||||
return tcbegin() + Size();
|
||||
}
|
||||
|
||||
void Fill(T v) { // NOLINT
|
||||
if (perm_h_.CanWrite()) {
|
||||
if (HostCanWrite()) {
|
||||
std::fill(data_h_.begin(), data_h_.end(), v);
|
||||
} else {
|
||||
DeviceFill(v);
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
SetDevice();
|
||||
thrust::fill(data_d_.begin(), data_d_.end(), v);
|
||||
}
|
||||
}
|
||||
|
||||
void Copy(HostDeviceVectorImpl<T>* other) {
|
||||
CHECK_EQ(Size(), other->Size());
|
||||
// Data is on host.
|
||||
if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) {
|
||||
if (HostCanWrite() && other->HostCanWrite()) {
|
||||
std::copy(other->data_h_.begin(), other->data_h_.end(), data_h_.begin());
|
||||
return;
|
||||
}
|
||||
// Data is on device;
|
||||
if (device_ != other->device_) {
|
||||
SetDevice(other->device_);
|
||||
}
|
||||
DeviceCopy(other);
|
||||
CopyToDevice(other);
|
||||
}
|
||||
|
||||
void Copy(const std::vector<T>& other) {
|
||||
CHECK_EQ(Size(), other.size());
|
||||
if (perm_h_.CanWrite()) {
|
||||
if (HostCanWrite()) {
|
||||
std::copy(other.begin(), other.end(), data_h_.begin());
|
||||
} else {
|
||||
DeviceCopy(other.data());
|
||||
CopyToDevice(other.data());
|
||||
}
|
||||
}
|
||||
|
||||
void Copy(std::initializer_list<T> other) {
|
||||
CHECK_EQ(Size(), other.size());
|
||||
if (perm_h_.CanWrite()) {
|
||||
if (HostCanWrite()) {
|
||||
std::copy(other.begin(), other.end(), data_h_.begin());
|
||||
} else {
|
||||
DeviceCopy(other.begin());
|
||||
CopyToDevice(other.begin());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<T>& HostVector() {
|
||||
LazySyncHost(GPUAccess::kWrite);
|
||||
LazySyncHost(GPUAccess::kNone);
|
||||
return data_h_;
|
||||
}
|
||||
|
||||
@ -178,7 +142,7 @@ class HostDeviceVectorImpl {
|
||||
void SetDevice(int device) {
|
||||
if (device_ == device) { return; }
|
||||
if (device_ >= 0) {
|
||||
LazySyncHost(GPUAccess::kWrite);
|
||||
LazySyncHost(GPUAccess::kNone);
|
||||
}
|
||||
device_ = device;
|
||||
if (device_ >= 0) {
|
||||
@ -190,38 +154,37 @@ class HostDeviceVectorImpl {
|
||||
if (new_size == Size()) { return; }
|
||||
if (Size() == 0 && device_ >= 0) {
|
||||
// fast on-device resize
|
||||
perm_h_ = Permissions(false);
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
SetDevice();
|
||||
data_d_.resize(new_size, v);
|
||||
} else {
|
||||
// resize on host
|
||||
LazySyncHost(GPUAccess::kWrite);
|
||||
LazySyncHost(GPUAccess::kNone);
|
||||
data_h_.resize(new_size, v);
|
||||
}
|
||||
}
|
||||
|
||||
void LazySyncHost(GPUAccess access) {
|
||||
if (perm_h_.CanAccess(access)) { return; }
|
||||
if (perm_h_.CanRead()) {
|
||||
if (HostCanAccess(access)) { return; }
|
||||
if (HostCanRead()) {
|
||||
// data is present, just need to deny access to the device
|
||||
perm_h_.Grant(access);
|
||||
gpu_access_ = access;
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
gpu_access_ = access;
|
||||
if (data_h_.size() != data_d_.size()) { data_h_.resize(data_d_.size()); }
|
||||
SetDevice();
|
||||
dh::safe_cuda(cudaMemcpy(data_h_.data(),
|
||||
data_d_.data().get(),
|
||||
data_d_.size() * sizeof(T),
|
||||
cudaMemcpyDeviceToHost));
|
||||
perm_h_.Grant(access);
|
||||
}
|
||||
|
||||
void LazySyncDevice(GPUAccess access) {
|
||||
if (DevicePerm().CanAccess(access)) { return; }
|
||||
if (DevicePerm().CanRead()) {
|
||||
if (DeviceCanAccess(access)) { return; }
|
||||
if (DeviceCanRead()) {
|
||||
// deny read to the host
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
perm_h_.DenyComplementary(access);
|
||||
gpu_access_ = access;
|
||||
return;
|
||||
}
|
||||
// data is on the host
|
||||
@ -231,41 +194,37 @@ class HostDeviceVectorImpl {
|
||||
data_h_.data(),
|
||||
data_d_.size() * sizeof(T),
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
perm_h_.DenyComplementary(access);
|
||||
gpu_access_ = access;
|
||||
}
|
||||
|
||||
bool HostCanAccess(GPUAccess access) { return perm_h_.CanAccess(access); }
|
||||
bool DeviceCanAccess(GPUAccess access) { return DevicePerm().CanAccess(access); }
|
||||
bool HostCanAccess(GPUAccess access) const { return gpu_access_ <= access; }
|
||||
bool HostCanRead() const { return HostCanAccess(GPUAccess::kRead); }
|
||||
bool HostCanWrite() const { return HostCanAccess(GPUAccess::kNone); }
|
||||
bool DeviceCanAccess(GPUAccess access) const { return gpu_access_ >= access; }
|
||||
bool DeviceCanRead() const { return DeviceCanAccess(GPUAccess::kRead); }
|
||||
bool DeviceCanWrite() const { return DeviceCanAccess(GPUAccess::kWrite); }
|
||||
|
||||
private:
|
||||
int device_{-1};
|
||||
std::vector<T> data_h_{};
|
||||
dh::device_vector<T> data_d_{};
|
||||
Permissions perm_h_{false};
|
||||
// protects size_d_ and perm_h_ when updated from multiple threads
|
||||
std::mutex mutex_{};
|
||||
GPUAccess gpu_access_{GPUAccess::kNone};
|
||||
|
||||
void DeviceFill(T v) {
|
||||
// TODO(canonizer): avoid full copy of host data
|
||||
LazySyncDevice(GPUAccess::kWrite);
|
||||
SetDevice();
|
||||
thrust::fill(data_d_.begin(), data_d_.end(), v);
|
||||
void CopyToDevice(HostDeviceVectorImpl* other) {
|
||||
if (other->HostCanWrite()) {
|
||||
CopyToDevice(other->data_h_.data());
|
||||
} else {
|
||||
LazyResizeDevice(Size());
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
SetDevice();
|
||||
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), other->data_d_.data().get(),
|
||||
data_d_.size() * sizeof(T), cudaMemcpyDefault));
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceCopy(HostDeviceVectorImpl* other) {
|
||||
// TODO(canonizer): avoid full copy of host data for this (but not for other)
|
||||
LazySyncDevice(GPUAccess::kWrite);
|
||||
other->LazySyncDevice(GPUAccess::kRead);
|
||||
SetDevice();
|
||||
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), other->data_d_.data().get(),
|
||||
data_d_.size() * sizeof(T), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void DeviceCopy(const T* begin) {
|
||||
// TODO(canonizer): avoid full copy of host data
|
||||
LazySyncDevice(GPUAccess::kWrite);
|
||||
void CopyToDevice(const T* begin) {
|
||||
LazyResizeDevice(Size());
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
SetDevice();
|
||||
dh::safe_cuda(cudaMemcpyAsync(data_d_.data().get(), begin,
|
||||
data_d_.size() * sizeof(T), cudaMemcpyDefault));
|
||||
@ -285,8 +244,6 @@ class HostDeviceVectorImpl {
|
||||
(*cudaSetDeviceHandler)(device_);
|
||||
}
|
||||
}
|
||||
|
||||
Permissions DevicePerm() const { return perm_h_.Complementary(); }
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
@ -347,11 +304,6 @@ common::Span<const T> HostDeviceVector<T>::ConstDeviceSpan() const {
|
||||
return impl_->ConstDeviceSpan();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t HostDeviceVector<T>::DeviceSize() const {
|
||||
return impl_->DeviceSize();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
thrust::device_ptr<T> HostDeviceVector<T>::tbegin() { // NOLINT
|
||||
return impl_->tbegin();
|
||||
@ -401,13 +353,23 @@ const std::vector<T>& HostDeviceVector<T>::ConstHostVector() const {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::HostCanAccess(GPUAccess access) const {
|
||||
return impl_->HostCanAccess(access);
|
||||
bool HostDeviceVector<T>::HostCanRead() const {
|
||||
return impl_->HostCanRead();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::DeviceCanAccess(GPUAccess access) const {
|
||||
return impl_->DeviceCanAccess(access);
|
||||
bool HostDeviceVector<T>::HostCanWrite() const {
|
||||
return impl_->HostCanWrite();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::DeviceCanRead() const {
|
||||
return impl_->DeviceCanRead();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::DeviceCanWrite() const {
|
||||
return impl_->DeviceCanWrite();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -79,16 +79,23 @@ void SetCudaSetDeviceHandler(void (*handler)(int));
|
||||
|
||||
template <typename T> struct HostDeviceVectorImpl;
|
||||
|
||||
/*!
|
||||
* \brief Controls data access from the GPU.
|
||||
*
|
||||
* Since a `HostDeviceVector` can have data on both the host and device, access control needs to be
|
||||
* maintained to keep the data consistent.
|
||||
*
|
||||
* There are 3 scenarios supported:
|
||||
* - Data is being manipulated on device. GPU has write access, host doesn't have access.
|
||||
* - Data is read-only on both the host and device.
|
||||
* - Data is being manipulated on the host. Host has write access, device doesn't have access.
|
||||
*/
|
||||
enum GPUAccess {
|
||||
kNone, kRead,
|
||||
// write implies read
|
||||
kWrite
|
||||
};
|
||||
|
||||
inline GPUAccess operator-(GPUAccess a, GPUAccess b) {
|
||||
return static_cast<GPUAccess>(static_cast<int>(a) - static_cast<int>(b));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class HostDeviceVector {
|
||||
public:
|
||||
@ -111,8 +118,6 @@ class HostDeviceVector {
|
||||
const T* ConstHostPointer() const { return ConstHostVector().data(); }
|
||||
const T* HostPointer() const { return ConstHostPointer(); }
|
||||
|
||||
size_t DeviceSize() const;
|
||||
|
||||
// only define functions returning device_ptr
|
||||
// if HostDeviceVector.h is included from a .cu file
|
||||
#ifdef __CUDACC__
|
||||
@ -135,8 +140,10 @@ class HostDeviceVector {
|
||||
const std::vector<T>& ConstHostVector() const;
|
||||
const std::vector<T>& HostVector() const {return ConstHostVector(); }
|
||||
|
||||
bool HostCanAccess(GPUAccess access) const;
|
||||
bool DeviceCanAccess(GPUAccess access) const;
|
||||
bool HostCanRead() const;
|
||||
bool HostCanWrite() const;
|
||||
bool DeviceCanRead() const;
|
||||
bool DeviceCanWrite() const;
|
||||
|
||||
void SetDevice(int device) const;
|
||||
|
||||
|
||||
@ -68,7 +68,7 @@ class ElementWiseMetricsReduction {
|
||||
const HostDeviceVector<bst_float>& weights,
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds) {
|
||||
size_t n_data = preds.DeviceSize();
|
||||
size_t n_data = preds.Size();
|
||||
|
||||
thrust::counting_iterator<size_t> begin(0);
|
||||
thrust::counting_iterator<size_t> end = begin + n_data;
|
||||
|
||||
@ -85,7 +85,7 @@ class MultiClassMetricsReduction {
|
||||
const HostDeviceVector<bst_float>& labels,
|
||||
const HostDeviceVector<bst_float>& preds,
|
||||
const size_t n_class) {
|
||||
size_t n_data = labels.DeviceSize();
|
||||
size_t n_data = labels.Size();
|
||||
|
||||
thrust::counting_iterator<size_t> begin(0);
|
||||
thrust::counting_iterator<size_t> end = begin + n_data;
|
||||
|
||||
@ -231,12 +231,13 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
this->num_group_ = model.param.num_output_group;
|
||||
}
|
||||
|
||||
void PredictInternal
|
||||
(const SparsePage& batch, size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions) {
|
||||
void PredictInternal(const SparsePage& batch,
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
const int BLOCK_THREADS = 128;
|
||||
size_t num_rows = batch.offset.DeviceSize() - 1;
|
||||
size_t num_rows = batch.Size();
|
||||
const int GRID_SIZE = static_cast<int>(common::DivRoundUp(num_rows, BLOCK_THREADS));
|
||||
|
||||
int shared_memory_bytes = static_cast<int>
|
||||
@ -249,10 +250,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t entry_start = 0;
|
||||
|
||||
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
|
||||
(dh::ToSpan(nodes_), predictions->DeviceSpan(), dh::ToSpan(tree_segments_),
|
||||
dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
|
||||
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features,
|
||||
num_rows, entry_start, use_shared, this->num_group_);
|
||||
(dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
|
||||
batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, this->num_group_);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -297,28 +298,10 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
InitModel(model, tree_begin, tree_end);
|
||||
|
||||
size_t batch_offset = 0;
|
||||
auto* preds = out_preds;
|
||||
std::unique_ptr<HostDeviceVector<bst_float>> batch_preds{nullptr};
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
bool is_external_memory = batch.Size() < dmat->Info().num_row_;
|
||||
if (is_external_memory) {
|
||||
batch_preds.reset(new HostDeviceVector<bst_float>);
|
||||
batch_preds->Resize(batch.Size() * model.param.num_output_group);
|
||||
std::copy(out_preds->ConstHostVector().begin() + batch_offset,
|
||||
out_preds->ConstHostVector().begin() + batch_offset + batch_preds->Size(),
|
||||
batch_preds->HostVector().begin());
|
||||
preds = batch_preds.get();
|
||||
}
|
||||
|
||||
batch.offset.SetDevice(device_);
|
||||
batch.data.SetDevice(device_);
|
||||
preds->SetDevice(device_);
|
||||
shard_.PredictInternal(batch, model.param.num_feature, preds);
|
||||
|
||||
if (is_external_memory) {
|
||||
auto h_preds = preds->ConstHostVector();
|
||||
std::copy(h_preds.begin(), h_preds.end(), out_preds->HostVector().begin() + batch_offset);
|
||||
}
|
||||
shard_.PredictInternal(batch, model.param.num_feature, out_preds, batch_offset);
|
||||
batch_offset += batch.Size() * model.param.num_output_group;
|
||||
}
|
||||
|
||||
@ -356,6 +339,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t n_classes = model.param.num_output_group;
|
||||
size_t n = n_classes * info.num_row_;
|
||||
const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
|
||||
out_preds->SetDevice(device_);
|
||||
out_preds->Resize(n);
|
||||
if (base_margin.Size() != 0) {
|
||||
CHECK_EQ(base_margin.Size(), n);
|
||||
@ -454,7 +438,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Re configure shards when GPUSet is changed. */
|
||||
/*! \brief Reconfigure the shard when GPU is changed. */
|
||||
void ConfigureShard(int device) {
|
||||
if (device_ == device) return;
|
||||
|
||||
|
||||
@ -93,14 +93,14 @@ struct ExpandEntry {
|
||||
}
|
||||
};
|
||||
|
||||
inline static bool DepthWise(ExpandEntry lhs, ExpandEntry rhs) {
|
||||
inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
|
||||
if (lhs.depth == rhs.depth) {
|
||||
return lhs.timestamp > rhs.timestamp; // favor small timestamp
|
||||
} else {
|
||||
return lhs.depth > rhs.depth; // favor small depth
|
||||
}
|
||||
}
|
||||
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
|
||||
inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
|
||||
if (lhs.split.loss_chg == rhs.split.loss_chg) {
|
||||
return lhs.timestamp > rhs.timestamp; // favor small timestamp
|
||||
} else {
|
||||
@ -553,7 +553,7 @@ __global__ void SharedMemHistKernel(ELLPackMatrix matrix,
|
||||
// of rows to process from a batch and the position from which to process on each device.
|
||||
struct RowStateOnDevice {
|
||||
// Number of rows assigned to this device
|
||||
const size_t total_rows_assigned_to_device;
|
||||
size_t total_rows_assigned_to_device;
|
||||
// Number of rows processed thus far
|
||||
size_t total_rows_processed;
|
||||
// Number of rows to process from the current sparse page batch
|
||||
@ -584,14 +584,13 @@ template <typename GradientSumT>
|
||||
struct DeviceShard {
|
||||
int n_bins;
|
||||
int device_id;
|
||||
int shard_idx; // Position in the local array of shards
|
||||
|
||||
dh::BulkAllocator ba;
|
||||
|
||||
ELLPackMatrix ellpack_matrix;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogram<GradientSumT> hist;
|
||||
DeviceHistogram<GradientSumT> hist{};
|
||||
|
||||
/*! \brief row_ptr form HistogramCuts. */
|
||||
common::Span<uint32_t> feature_segments;
|
||||
@ -611,9 +610,6 @@ struct DeviceShard {
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
common::Span<GradientPair> node_sum_gradients_d;
|
||||
/*! The row offset for this shard. */
|
||||
bst_uint row_begin_idx;
|
||||
bst_uint row_end_idx;
|
||||
bst_uint n_rows;
|
||||
|
||||
TrainParam param;
|
||||
@ -623,7 +619,7 @@ struct DeviceShard {
|
||||
dh::CubMemory temp_memory;
|
||||
dh::PinnedMemory pinned_memory;
|
||||
|
||||
std::vector<cudaStream_t> streams;
|
||||
std::vector<cudaStream_t> streams{};
|
||||
|
||||
common::Monitor monitor;
|
||||
std::vector<ValueConstraint> node_value_constraints;
|
||||
@ -635,14 +631,10 @@ struct DeviceShard {
|
||||
std::function<bool(ExpandEntry, ExpandEntry)>>;
|
||||
std::unique_ptr<ExpandQueue> qexpand;
|
||||
|
||||
DeviceShard(int _device_id, int shard_idx, bst_uint row_begin,
|
||||
bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed,
|
||||
DeviceShard(int _device_id, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed,
|
||||
uint32_t n_features)
|
||||
: device_id(_device_id),
|
||||
shard_idx(shard_idx),
|
||||
row_begin_idx(row_begin),
|
||||
row_end_idx(row_end),
|
||||
n_rows(row_end - row_begin),
|
||||
n_rows(_n_rows),
|
||||
n_bins(0),
|
||||
param(std::move(_param)),
|
||||
prediction_cache_initialised(false),
|
||||
@ -658,7 +650,7 @@ struct DeviceShard {
|
||||
const SparsePage &row_batch, const common::HistogramCuts &hmat,
|
||||
const RowStateOnDevice &device_row_state, int rows_per_batch);
|
||||
|
||||
~DeviceShard() {
|
||||
~DeviceShard() { // NOLINT
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
for (auto& stream : streams) {
|
||||
dh::safe_cuda(cudaStreamDestroy(stream));
|
||||
@ -704,7 +696,7 @@ struct DeviceShard {
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair.data(), dh_gpair->ConstDevicePointer(),
|
||||
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
|
||||
SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx);
|
||||
SubsampleGradientPair(device_id, gpair, param.subsample);
|
||||
hist.Reset();
|
||||
}
|
||||
|
||||
@ -755,7 +747,7 @@ struct DeviceShard {
|
||||
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
|
||||
|
||||
auto d_result = d_result_all.subspan(i, 1);
|
||||
if (d_feature_set.size() == 0) {
|
||||
if (d_feature_set.empty()) {
|
||||
// Acting as a device side constructor for DeviceSplitCandidate.
|
||||
// DeviceSplitCandidate::IsValid is false so that ApplySplit can reject this
|
||||
// candidate.
|
||||
@ -927,12 +919,11 @@ struct DeviceShard {
|
||||
monitor.StartCuda("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
reducer->AllReduceSum(
|
||||
shard_idx,
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
ellpack_matrix.BinCount() *
|
||||
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
|
||||
reducer->Synchronize(device_id);
|
||||
reducer->Synchronize();
|
||||
|
||||
monitor.StopCuda("AllReduce");
|
||||
}
|
||||
@ -979,11 +970,11 @@ struct DeviceShard {
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
|
||||
GradStats left_stats;
|
||||
GradStats left_stats{};
|
||||
left_stats.Add(candidate.split.left_sum);
|
||||
GradStats right_stats;
|
||||
GradStats right_stats{};
|
||||
right_stats.Add(candidate.split.right_sum);
|
||||
GradStats parent_sum;
|
||||
GradStats parent_sum{};
|
||||
parent_sum.Add(left_stats);
|
||||
parent_sum.Add(right_stats);
|
||||
node_value_constraints.resize(tree.GetNodes().size());
|
||||
@ -1021,9 +1012,9 @@ struct DeviceShard {
|
||||
dh::SumReduction(temp_memory, gpair, node_sum_gradients_d,
|
||||
gpair.size());
|
||||
reducer->AllReduceSum(
|
||||
shard_idx, reinterpret_cast<float*>(node_sum_gradients_d.data()),
|
||||
reinterpret_cast<float*>(node_sum_gradients_d.data()),
|
||||
reinterpret_cast<float*>(node_sum_gradients_d.data()), 2);
|
||||
reducer->Synchronize(device_id);
|
||||
reducer->Synchronize();
|
||||
dh::safe_cuda(cudaMemcpy(node_sum_gradients.data(),
|
||||
node_sum_gradients_d.data(), sizeof(GradientPair),
|
||||
cudaMemcpyDeviceToHost));
|
||||
@ -1238,52 +1229,44 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
|
||||
class DeviceHistogramBuilderState {
|
||||
public:
|
||||
template <typename GradientSumT>
|
||||
explicit DeviceHistogramBuilderState(
|
||||
const std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> &shards) {
|
||||
device_row_states_.reserve(shards.size());
|
||||
for (const auto &shard : shards) {
|
||||
device_row_states_.push_back(RowStateOnDevice(shard->n_rows));
|
||||
}
|
||||
}
|
||||
explicit DeviceHistogramBuilderState(const std::unique_ptr<DeviceShard<GradientSumT>>& shard)
|
||||
: device_row_state_(shard->n_rows) {}
|
||||
|
||||
const RowStateOnDevice &GetRowStateOnDevice(int idx) const {
|
||||
return device_row_states_[idx];
|
||||
const RowStateOnDevice& GetRowStateOnDevice() const {
|
||||
return device_row_state_;
|
||||
}
|
||||
|
||||
// This method is invoked at the beginning of each sparse page batch. This distributes
|
||||
// the rows in the sparse page to the different devices.
|
||||
// the rows in the sparse page to the device.
|
||||
// TODO(sriramch): Think of a way to utilize *all* the GPUs to build the compressed bins.
|
||||
void BeginBatch(const SparsePage &batch) {
|
||||
size_t rem_rows = batch.Size();
|
||||
size_t row_offset_in_current_batch = 0;
|
||||
for (auto &device_row_state : device_row_states_) {
|
||||
// Do we have anymore left to process from this batch on this device?
|
||||
if (device_row_state.total_rows_assigned_to_device > device_row_state.total_rows_processed) {
|
||||
// There are still some rows that needs to be assigned to this device
|
||||
device_row_state.rows_to_process_from_batch =
|
||||
std::min(
|
||||
device_row_state.total_rows_assigned_to_device - device_row_state.total_rows_processed,
|
||||
rem_rows);
|
||||
} else {
|
||||
// All rows have been assigned to this device
|
||||
device_row_state.rows_to_process_from_batch = 0;
|
||||
}
|
||||
|
||||
device_row_state.row_offset_in_current_batch = row_offset_in_current_batch;
|
||||
row_offset_in_current_batch += device_row_state.rows_to_process_from_batch;
|
||||
rem_rows -= device_row_state.rows_to_process_from_batch;
|
||||
// Do we have anymore left to process from this batch on this device?
|
||||
if (device_row_state_.total_rows_assigned_to_device > device_row_state_.total_rows_processed) {
|
||||
// There are still some rows that needs to be assigned to this device
|
||||
device_row_state_.rows_to_process_from_batch =
|
||||
std::min(
|
||||
device_row_state_.total_rows_assigned_to_device - device_row_state_.total_rows_processed,
|
||||
rem_rows);
|
||||
} else {
|
||||
// All rows have been assigned to this device
|
||||
device_row_state_.rows_to_process_from_batch = 0;
|
||||
}
|
||||
|
||||
device_row_state_.row_offset_in_current_batch = row_offset_in_current_batch;
|
||||
row_offset_in_current_batch += device_row_state_.rows_to_process_from_batch;
|
||||
rem_rows -= device_row_state_.rows_to_process_from_batch;
|
||||
}
|
||||
|
||||
// This method is invoked after completion of each sparse page batch
|
||||
void EndBatch() {
|
||||
for (auto &rs : device_row_states_) {
|
||||
rs.Advance();
|
||||
}
|
||||
device_row_state_.Advance();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<RowStateOnDevice> device_row_states_;
|
||||
RowStateOnDevice device_row_state_{0};
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
@ -1302,7 +1285,9 @@ class GPUHistMakerSpecialised {
|
||||
monitor_.Init("updater_gpu_hist");
|
||||
}
|
||||
|
||||
~GPUHistMakerSpecialised() { dh::GlobalMemoryLogger().Log(); }
|
||||
~GPUHistMakerSpecialised() { // NOLINT
|
||||
dh::GlobalMemoryLogger().Log();
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
@ -1333,20 +1318,13 @@ class GPUHistMakerSpecialised {
|
||||
uint32_t column_sampling_seed = common::GlobalRandom()();
|
||||
rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
|
||||
|
||||
// Create device shards
|
||||
shards_.resize(1);
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
size_t start = 0;
|
||||
size_t size = info_->num_row_;
|
||||
shard = std::unique_ptr<DeviceShard<GradientSumT>>(
|
||||
new DeviceShard<GradientSumT>(device_, idx,
|
||||
start, start + size, param_,
|
||||
column_sampling_seed,
|
||||
info_->num_col_));
|
||||
});
|
||||
// Create device shard
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
shard_.reset(new DeviceShard<GradientSumT>(device_,
|
||||
info_->num_row_,
|
||||
param_,
|
||||
column_sampling_seed,
|
||||
info_->num_col_));
|
||||
|
||||
monitor_.StartCuda("Quantiles");
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts
|
||||
@ -1355,32 +1333,22 @@ class GPUHistMakerSpecialised {
|
||||
dmat, &hmat_);
|
||||
monitor_.StopCuda("Quantiles");
|
||||
|
||||
n_bins_ = hmat_.Ptrs().back();
|
||||
|
||||
auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
|
||||
|
||||
// Init global data for each shard
|
||||
monitor_.StartCuda("InitCompressedData");
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id));
|
||||
shard->InitCompressedData(hmat_, row_stride, is_dense);
|
||||
});
|
||||
dh::safe_cuda(cudaSetDevice(shard_->device_id));
|
||||
shard_->InitCompressedData(hmat_, row_stride, is_dense);
|
||||
monitor_.StopCuda("InitCompressedData");
|
||||
|
||||
monitor_.StartCuda("BinningCompression");
|
||||
DeviceHistogramBuilderState hist_builder_row_state(shards_);
|
||||
DeviceHistogramBuilderState hist_builder_row_state(shard_);
|
||||
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
hist_builder_row_state.BeginBatch(batch);
|
||||
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id));
|
||||
shard->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(idx),
|
||||
hist_maker_param_.gpu_batch_nrows);
|
||||
});
|
||||
dh::safe_cuda(cudaSetDevice(shard_->device_id));
|
||||
shard_->CreateHistIndices(batch, hmat_, hist_builder_row_state.GetRowStateOnDevice(),
|
||||
hist_maker_param_.gpu_batch_nrows);
|
||||
|
||||
hist_builder_row_state.EndBatch();
|
||||
}
|
||||
@ -1408,7 +1376,7 @@ class GPUHistMakerSpecialised {
|
||||
}
|
||||
fs.Seek(0);
|
||||
rabit::Broadcast(&s_model, 0);
|
||||
RegTree reference_tree;
|
||||
RegTree reference_tree{};
|
||||
reference_tree.Load(&fs);
|
||||
for (const auto& tree : local_trees) {
|
||||
CHECK(tree == reference_tree);
|
||||
@ -1421,66 +1389,39 @@ class GPUHistMakerSpecialised {
|
||||
this->InitData(gpair, p_fmat);
|
||||
monitor_.StopCuda("InitData");
|
||||
|
||||
std::vector<RegTree> trees(shards_.size());
|
||||
for (auto& tree : trees) {
|
||||
tree = *p_tree;
|
||||
}
|
||||
gpair->SetDevice(device_);
|
||||
|
||||
// Launch one thread for each device "shard" containing a subset of rows.
|
||||
// Threads will cooperatively build the tree, synchronising over histograms.
|
||||
// Each thread will redundantly build its own copy of the tree
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
shard->UpdateTree(gpair, p_fmat, &trees.at(idx), &reducer_);
|
||||
});
|
||||
|
||||
// All trees are expected to be identical
|
||||
if (hist_maker_param_.debug_synchronize) {
|
||||
this->CheckTreesSynchronized(trees);
|
||||
}
|
||||
|
||||
// Write the output tree
|
||||
*p_tree = trees.front();
|
||||
shard_->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
if (shard_ == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) {
|
||||
return false;
|
||||
}
|
||||
monitor_.StartCuda("UpdatePredictionCache");
|
||||
p_out_preds->SetDevice(device_);
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id));
|
||||
shard->UpdatePredictionCache(
|
||||
p_out_preds->DevicePointer());
|
||||
});
|
||||
dh::safe_cuda(cudaSetDevice(shard_->device_id));
|
||||
shard_->UpdatePredictionCache(p_out_preds->DevicePointer());
|
||||
monitor_.StopCuda("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
TrainParam param_; // NOLINT
|
||||
common::HistogramCuts hmat_; // NOLINT
|
||||
MetaInfo* info_; // NOLINT
|
||||
MetaInfo* info_{}; // NOLINT
|
||||
|
||||
std::vector<std::unique_ptr<DeviceShard<GradientSumT>>> shards_; // NOLINT
|
||||
std::unique_ptr<DeviceShard<GradientSumT>> shard_; // NOLINT
|
||||
|
||||
private:
|
||||
bool initialised_;
|
||||
|
||||
int n_bins_;
|
||||
|
||||
GPUHistMakerTrainParam hist_maker_param_;
|
||||
GenericParameter const* generic_param_;
|
||||
|
||||
dh::AllReducer reducer_;
|
||||
|
||||
DMatrix* p_last_fmat_;
|
||||
int device_;
|
||||
int device_{-1};
|
||||
|
||||
common::Monitor monitor_;
|
||||
};
|
||||
|
||||
@ -10,17 +10,6 @@
|
||||
|
||||
using xgboost::common::Span;
|
||||
|
||||
struct Shard { int id; };
|
||||
|
||||
TEST(DeviceHelpers, Basic) {
|
||||
std::vector<Shard> shards (4);
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
shards[i].id = i;
|
||||
}
|
||||
int sum = dh::ReduceShards<int>(&shards, [](Shard& s) { return s.id ; });
|
||||
ASSERT_EQ(sum, 6);
|
||||
}
|
||||
|
||||
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
|
||||
thrust::host_vector<int> *row_ptr,
|
||||
thrust::host_vector<xgboost::bst_uint> *rows) {
|
||||
|
||||
@ -38,19 +38,19 @@ void InitHostDeviceVector(size_t n, int device, HostDeviceVector<int> *v) {
|
||||
ASSERT_EQ(v->Size(), n);
|
||||
ASSERT_EQ(v->DeviceIdx(), device);
|
||||
// ensure that the device have read-write access
|
||||
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
|
||||
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kWrite));
|
||||
ASSERT_TRUE(v->DeviceCanRead());
|
||||
ASSERT_TRUE(v->DeviceCanWrite());
|
||||
// ensure that the host has no access
|
||||
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
|
||||
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
|
||||
ASSERT_FALSE(v->HostCanRead());
|
||||
ASSERT_FALSE(v->HostCanWrite());
|
||||
|
||||
// fill in the data on the host
|
||||
std::vector<int>& data_h = v->HostVector();
|
||||
// ensure that the host has full access, while the device have none
|
||||
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
|
||||
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kWrite));
|
||||
ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kRead));
|
||||
ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kWrite));
|
||||
ASSERT_TRUE(v->HostCanRead());
|
||||
ASSERT_TRUE(v->HostCanWrite());
|
||||
ASSERT_FALSE(v->DeviceCanRead());
|
||||
ASSERT_FALSE(v->DeviceCanWrite());
|
||||
ASSERT_EQ(data_h.size(), n);
|
||||
std::copy_n(thrust::make_counting_iterator(0), n, data_h.begin());
|
||||
}
|
||||
@ -62,76 +62,62 @@ void PlusOne(HostDeviceVector<int> *v) {
|
||||
[=]__device__(unsigned int a){ return a + 1; });
|
||||
}
|
||||
|
||||
void CheckDevice(HostDeviceVector<int> *v,
|
||||
const std::vector<size_t>& starts,
|
||||
const std::vector<size_t>& sizes,
|
||||
unsigned int first, GPUAccess access) {
|
||||
int n_devices = sizes.size();
|
||||
ASSERT_EQ(n_devices, 1);
|
||||
for (int i = 0; i < n_devices; ++i) {
|
||||
ASSERT_EQ(v->DeviceSize(), sizes.at(i));
|
||||
SetDevice(i);
|
||||
ASSERT_TRUE(thrust::equal(v->tcbegin(), v->tcend(),
|
||||
thrust::make_counting_iterator(first + starts[i])));
|
||||
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
|
||||
// ensure that the device has at most the access specified by access
|
||||
ASSERT_EQ(v->DeviceCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
|
||||
}
|
||||
ASSERT_EQ(v->HostCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
|
||||
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
|
||||
for (int i = 0; i < n_devices; ++i) {
|
||||
SetDevice(i);
|
||||
ASSERT_TRUE(thrust::equal(v->tbegin(), v->tend(),
|
||||
thrust::make_counting_iterator(first + starts[i])));
|
||||
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kRead));
|
||||
ASSERT_TRUE(v->DeviceCanAccess(GPUAccess::kWrite));
|
||||
}
|
||||
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kRead));
|
||||
ASSERT_FALSE(v->HostCanAccess(GPUAccess::kWrite));
|
||||
void CheckDevice(HostDeviceVector<int>* v,
|
||||
size_t size,
|
||||
unsigned int first,
|
||||
GPUAccess access) {
|
||||
ASSERT_EQ(v->Size(), size);
|
||||
SetDevice(v->DeviceIdx());
|
||||
|
||||
ASSERT_TRUE(thrust::equal(v->tcbegin(), v->tcend(),
|
||||
thrust::make_counting_iterator(first)));
|
||||
ASSERT_TRUE(v->DeviceCanRead());
|
||||
// ensure that the device has at most the access specified by access
|
||||
ASSERT_EQ(v->DeviceCanWrite(), access == GPUAccess::kWrite);
|
||||
ASSERT_EQ(v->HostCanRead(), access == GPUAccess::kRead);
|
||||
ASSERT_FALSE(v->HostCanWrite());
|
||||
|
||||
ASSERT_TRUE(thrust::equal(v->tbegin(), v->tend(),
|
||||
thrust::make_counting_iterator(first)));
|
||||
ASSERT_TRUE(v->DeviceCanRead());
|
||||
ASSERT_TRUE(v->DeviceCanWrite());
|
||||
ASSERT_FALSE(v->HostCanRead());
|
||||
ASSERT_FALSE(v->HostCanWrite());
|
||||
}
|
||||
|
||||
void CheckHost(HostDeviceVector<int> *v, GPUAccess access) {
|
||||
const std::vector<int>& data_h = access == GPUAccess::kWrite ?
|
||||
const std::vector<int>& data_h = access == GPUAccess::kNone ?
|
||||
v->HostVector() : v->ConstHostVector();
|
||||
for (size_t i = 0; i < v->Size(); ++i) {
|
||||
ASSERT_EQ(data_h.at(i), i + 1);
|
||||
}
|
||||
ASSERT_TRUE(v->HostCanAccess(GPUAccess::kRead));
|
||||
ASSERT_EQ(v->HostCanAccess(GPUAccess::kWrite), access == GPUAccess::kWrite);
|
||||
size_t n_devices = 1;
|
||||
for (int i = 0; i < n_devices; ++i) {
|
||||
ASSERT_EQ(v->DeviceCanAccess(GPUAccess::kRead), access == GPUAccess::kRead);
|
||||
// the devices should have no write access
|
||||
ASSERT_FALSE(v->DeviceCanAccess(GPUAccess::kWrite));
|
||||
}
|
||||
ASSERT_TRUE(v->HostCanRead());
|
||||
ASSERT_EQ(v->HostCanWrite(), access == GPUAccess::kNone);
|
||||
ASSERT_EQ(v->DeviceCanRead(), access == GPUAccess::kRead);
|
||||
// the devices should have no write access
|
||||
ASSERT_FALSE(v->DeviceCanWrite());
|
||||
}
|
||||
|
||||
void TestHostDeviceVector
|
||||
(size_t n, int device,
|
||||
const std::vector<size_t>& starts, const std::vector<size_t>& sizes) {
|
||||
void TestHostDeviceVector(size_t n, int device) {
|
||||
HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice);
|
||||
HostDeviceVector<int> v;
|
||||
InitHostDeviceVector(n, device, &v);
|
||||
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
|
||||
CheckDevice(&v, n, 0, GPUAccess::kRead);
|
||||
PlusOne(&v);
|
||||
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
|
||||
CheckDevice(&v, n, 1, GPUAccess::kWrite);
|
||||
CheckHost(&v, GPUAccess::kRead);
|
||||
CheckHost(&v, GPUAccess::kWrite);
|
||||
CheckHost(&v, GPUAccess::kNone);
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, TestBlock) {
|
||||
TEST(HostDeviceVector, Basic) {
|
||||
size_t n = 1001;
|
||||
int device = 0;
|
||||
std::vector<size_t> starts{0};
|
||||
std::vector<size_t> sizes{1001};
|
||||
TestHostDeviceVector(n, device, starts, sizes);
|
||||
TestHostDeviceVector(n, device);
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, TestCopy) {
|
||||
TEST(HostDeviceVector, Copy) {
|
||||
size_t n = 1001;
|
||||
int device = 0;
|
||||
std::vector<size_t> starts{0};
|
||||
std::vector<size_t> sizes{1001};
|
||||
HostDeviceVectorSetDeviceHandler hdvec_dev_hndlr(SetDevice);
|
||||
|
||||
HostDeviceVector<int> v;
|
||||
@ -141,14 +127,14 @@ TEST(HostDeviceVector, TestCopy) {
|
||||
InitHostDeviceVector(n, device, &v1);
|
||||
v = v1;
|
||||
}
|
||||
CheckDevice(&v, starts, sizes, 0, GPUAccess::kRead);
|
||||
CheckDevice(&v, n, 0, GPUAccess::kRead);
|
||||
PlusOne(&v);
|
||||
CheckDevice(&v, starts, sizes, 1, GPUAccess::kWrite);
|
||||
CheckDevice(&v, n, 1, GPUAccess::kWrite);
|
||||
CheckHost(&v, GPUAccess::kRead);
|
||||
CheckHost(&v, GPUAccess::kWrite);
|
||||
CheckHost(&v, GPUAccess::kNone);
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, Shard) {
|
||||
TEST(HostDeviceVector, SetDevice) {
|
||||
std::vector<int> h_vec (2345);
|
||||
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||
h_vec[i] = i;
|
||||
@ -157,7 +143,6 @@ TEST(HostDeviceVector, Shard) {
|
||||
auto device = 0;
|
||||
|
||||
vec.SetDevice(device);
|
||||
ASSERT_EQ(vec.DeviceSize(), h_vec.size());
|
||||
ASSERT_EQ(vec.Size(), h_vec.size());
|
||||
auto span = vec.DeviceSpan(); // sync to device
|
||||
|
||||
@ -169,39 +154,26 @@ TEST(HostDeviceVector, Shard) {
|
||||
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, Reshard) {
|
||||
std::vector<int> h_vec (2345);
|
||||
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||
h_vec[i] = i;
|
||||
}
|
||||
HostDeviceVector<int> vec (h_vec);
|
||||
auto device = 0;
|
||||
|
||||
vec.SetDevice(device);
|
||||
ASSERT_EQ(vec.DeviceSize(), h_vec.size());
|
||||
ASSERT_EQ(vec.Size(), h_vec.size());
|
||||
PlusOne(&vec);
|
||||
|
||||
vec.SetDevice(-1);
|
||||
ASSERT_EQ(vec.Size(), h_vec.size());
|
||||
ASSERT_EQ(vec.DeviceIdx(), -1);
|
||||
|
||||
auto h_vec_1 = vec.HostVector();
|
||||
for (size_t i = 0; i < h_vec_1.size(); ++i) {
|
||||
ASSERT_EQ(h_vec_1.at(i), i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, Span) {
|
||||
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
vec.SetDevice(0);
|
||||
auto span = vec.DeviceSpan();
|
||||
ASSERT_EQ(vec.DeviceSize(), span.size());
|
||||
ASSERT_EQ(vec.Size(), span.size());
|
||||
ASSERT_EQ(vec.DevicePointer(), span.data());
|
||||
auto const_span = vec.ConstDeviceSpan();
|
||||
ASSERT_EQ(vec.DeviceSize(), span.size());
|
||||
ASSERT_EQ(vec.ConstDevicePointer(), span.data());
|
||||
ASSERT_EQ(vec.Size(), const_span.size());
|
||||
ASSERT_EQ(vec.ConstDevicePointer(), const_span.data());
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, MGPU_Basic) {
|
||||
if (AllVisibleGPUs() < 2) {
|
||||
LOG(WARNING) << "Not testing in multi-gpu environment.";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t n = 1001;
|
||||
int device = 1;
|
||||
TestHostDeviceVector(n, device);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -83,8 +83,8 @@ TEST(gpu_predictor, ExternalMemoryTest) {
|
||||
std::string file1 = tmpdir.path + "/big_1.libsvm";
|
||||
std::string file2 = tmpdir.path + "/big_2.libsvm";
|
||||
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
|
||||
// dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
|
||||
// dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
|
||||
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
|
||||
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
|
||||
|
||||
for (const auto& dmat: dmats) {
|
||||
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
|
||||
|
||||
@ -113,7 +113,7 @@ TEST(GpuHist, BuildGidxDense) {
|
||||
{"max_leaves", "0"},
|
||||
};
|
||||
param.Init(args);
|
||||
DeviceShard<GradientPairPrecise> shard(0, 0, 0, kNRows, param, kNCols, kNCols);
|
||||
DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
|
||||
BuildGidx(&shard, kNRows, kNCols);
|
||||
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
|
||||
@ -154,8 +154,7 @@ TEST(GpuHist, BuildGidxSparse) {
|
||||
};
|
||||
param.Init(args);
|
||||
|
||||
DeviceShard<GradientPairPrecise> shard(0, 0, 0, kNRows, param, kNCols,
|
||||
kNCols);
|
||||
DeviceShard<GradientPairPrecise> shard(0, kNRows, param, kNCols, kNCols);
|
||||
BuildGidx(&shard, kNRows, kNCols, 0.9f);
|
||||
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer(shard.gidx_buffer.size());
|
||||
@ -200,8 +199,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
{"max_leaves", "0"},
|
||||
};
|
||||
param.Init(args);
|
||||
DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols,
|
||||
kNCols);
|
||||
DeviceShard<GradientSumT> shard(0, kNRows, param, kNCols, kNCols);
|
||||
BuildGidx(&shard, kNRows, kNCols);
|
||||
|
||||
xgboost::SimpleLCG gen;
|
||||
@ -303,8 +301,7 @@ TEST(GpuHist, EvaluateSplits) {
|
||||
|
||||
// Initialize DeviceShard
|
||||
std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
|
||||
new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols,
|
||||
kNCols)};
|
||||
new DeviceShard<GradientPairPrecise>(0, kNRows, param, kNCols, kNCols)};
|
||||
// Initialize DeviceShard::node_sum_gradients
|
||||
shard->node_sum_gradients = {{6.4f, 12.8f}};
|
||||
|
||||
@ -391,24 +388,20 @@ void TestHistogramIndexImpl() {
|
||||
hist_maker_ext.Configure(training_params, &generic_param);
|
||||
hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get());
|
||||
|
||||
ASSERT_EQ(hist_maker.shards_.size(), hist_maker_ext.shards_.size());
|
||||
|
||||
// Extract the device shards from the histogram makers and from that its compressed
|
||||
// Extract the device shard from the histogram makers and from that its compressed
|
||||
// histogram index
|
||||
for (size_t i = 0; i < hist_maker.shards_.size(); ++i) {
|
||||
const auto &dev_shard = hist_maker.shards_[i];
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
|
||||
const auto &dev_shard = hist_maker.shard_;
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer(dev_shard->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer, dev_shard->gidx_buffer);
|
||||
|
||||
const auto &dev_shard_ext = hist_maker_ext.shards_[i];
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
|
||||
const auto &dev_shard_ext = hist_maker_ext.shard_;
|
||||
std::vector<common::CompressedByteT> h_gidx_buffer_ext(dev_shard_ext->gidx_buffer.size());
|
||||
dh::CopyDeviceSpanToVector(&h_gidx_buffer_ext, dev_shard_ext->gidx_buffer);
|
||||
|
||||
ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
|
||||
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
|
||||
ASSERT_EQ(dev_shard->n_bins, dev_shard_ext->n_bins);
|
||||
ASSERT_EQ(dev_shard->gidx_buffer.size(), dev_shard_ext->gidx_buffer.size());
|
||||
|
||||
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
|
||||
}
|
||||
ASSERT_EQ(h_gidx_buffer, h_gidx_buffer_ext);
|
||||
}
|
||||
|
||||
TEST(GpuHist, TestHistogramIndex) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user