Merge generic device helper functions into gpu set. (#3626)

* Remove the use of old NDevices* functions.
* Use GPUSet in timer.h.
Authored by trivialfis on 2018-08-26 14:14:23 +08:00; committed by Rory Mitchell
parent 3261002099
commit 60787ecebc
12 changed files with 299 additions and 199 deletions

src/common/common.h

@@ -6,6 +6,8 @@
#ifndef XGBOOST_COMMON_COMMON_H_
#define XGBOOST_COMMON_COMMON_H_
#include <xgboost/base.h>
#include <vector>
#include <string>
#include <sstream>
@@ -35,6 +37,71 @@ inline std::string ToString(const T& data) {
return os.str();
}
+/*
+* Range iterator
+*/
+class Range {
+public:
+class Iterator {
+friend class Range;
+public:
+using DifferenceType = int64_t;
+XGBOOST_DEVICE int64_t operator*() const { return i_; }
+XGBOOST_DEVICE const Iterator &operator++() {
+i_ += step_;
+return *this;
+}
+XGBOOST_DEVICE Iterator operator++(int) {
+Iterator res {*this};
+i_ += step_;
+return res;
+}
+XGBOOST_DEVICE bool operator==(const Iterator &other) const {
+return i_ >= other.i_;
+}
+XGBOOST_DEVICE bool operator!=(const Iterator &other) const {
+return i_ < other.i_;
+}
+XGBOOST_DEVICE void Step(DifferenceType s) { step_ = s; }
+protected:
+XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {}
+XGBOOST_DEVICE explicit Iterator(int64_t start, int step) :
+i_{start}, step_{step} {}
+public:
+int64_t i_;
+DifferenceType step_ = 1;
+};
+XGBOOST_DEVICE Iterator begin() const { return begin_; }  // NOLINT
+XGBOOST_DEVICE Iterator end() const { return end_; }  // NOLINT
+XGBOOST_DEVICE Range(int64_t begin, int64_t end)
+: begin_(begin), end_(end) {}
+XGBOOST_DEVICE Range(int64_t begin, int64_t end, Iterator::DifferenceType step)
+: begin_(begin, step), end_(end) {}
+XGBOOST_DEVICE bool operator==(const Range& other) const {
+return *begin_ == *other.begin_ && *end_ == *other.end_;
+}
+XGBOOST_DEVICE bool operator!=(const Range& other) const {
+return !(*this == other);
+}
+XGBOOST_DEVICE void Step(Iterator::DifferenceType s) { begin_.Step(s); }
+XGBOOST_DEVICE Iterator::DifferenceType GetStep() const { return begin_.step_; }
+private:
+Iterator begin_;
+Iterator end_;
+};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_COMMON_H_
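
For illustration, a minimal host-side sketch (not part of this commit; the include path is an assumption) of how the relocated Range behaves:

#include <iostream>
#include "common.h"  // assumed path to xgboost::common::Range

int main() {
  xgboost::common::Range r{0, 10, 2};  // begin 0, end 10, step 2
  for (auto i : r) {
    std::cout << i << " ";  // prints: 0 2 4 6 8
  }
  // Termination is safe even though the stepped iterator overshoots end:
  // operator!= compares with i_ < other.i_ rather than strict equality.
  std::cout << std::endl;
  return 0;
}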

src/common/device_helpers.cuh

@@ -7,6 +7,10 @@
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <xgboost/logging.h>
+#include "common.h"
+#include "gpu_set.h"
#include <algorithm>
#include <chrono>
#include <ctime>
@@ -28,25 +32,6 @@ namespace dh {
#define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
#define DEV_INLINE __device__ __forceinline__
-/*
-* Error handling functions
-*/
-#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
-inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
-int line) {
-if (code != cudaSuccess) {
-std::stringstream ss;
-ss << file << "(" << line << ")";
-std::string file_and_line;
-ss >> file_and_line;
-throw thrust::system_error(code, thrust::cuda_category(), file_and_line);
-}
-return code;
-}
#ifdef XGBOOST_USE_NCCL
#define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
@@ -73,45 +58,20 @@ const T *Raw(const thrust::device_vector<T> &v) { // NOLINT
return raw_pointer_cast(v.data());
}
-inline int NVisibleDevices() {
-int n_visgpus = 0;
-dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
-return n_visgpus;
-}
-inline int NDevicesAll(int n_gpus) {
-int n_devices_visible = dh::NVisibleDevices();
-int n_devices = n_gpus < 0 ? n_devices_visible : n_gpus;
-return (n_devices);
-}
-inline int NDevices(int n_gpus, int num_rows) {
-int n_devices = dh::NDevicesAll(n_gpus);
-// fix-up device number to be limited by number of rows
-n_devices = n_devices > num_rows ? num_rows : n_devices;
-return (n_devices);
-}
-// if n_devices=-1, then use all visible devices
-inline void SynchronizeNDevices(int n_devices, std::vector<int> dList) {
-for (int d_idx = 0; d_idx < n_devices; d_idx++) {
-int device_idx = dList[d_idx];
-safe_cuda(cudaSetDevice(device_idx));
-safe_cuda(cudaDeviceSynchronize());
-}
-}
-inline void SynchronizeAll() {
-for (int device_idx = 0; device_idx < NVisibleDevices(); device_idx++) {
-safe_cuda(cudaSetDevice(device_idx));
+inline void SynchronizeNDevices(xgboost::GPUSet devices) {
+devices = devices.IsEmpty() ? xgboost::GPUSet::AllVisible() : devices;
+for (auto const d : devices.Unnormalised()) {
+safe_cuda(cudaSetDevice(d));
safe_cuda(cudaDeviceSynchronize());
}
}
-inline std::string DeviceName(int device_idx) {
-cudaDeviceProp prop;
-dh::safe_cuda(cudaGetDeviceProperties(&prop, device_idx));
-return std::string(prop.name);
+inline void SynchronizeAll() {
+for (int device_idx : xgboost::GPUSet::AllVisible()) {
+safe_cuda(cudaSetDevice(device_idx));
+safe_cuda(cudaDeviceSynchronize());
+}
}
inline size_t AvailableMemory(int device_idx) {
@@ -144,15 +104,8 @@ inline size_t MaxSharedMemory(int device_idx) {
return prop.sharedMemPerBlock;
}
-// ensure gpu_id is correct, so not dependent upon user knowing details
-inline int GetDeviceIdx(int gpu_id) {
-// protect against overrun for gpu_id
-return (std::abs(gpu_id) + 0) % dh::NVisibleDevices();
-}
inline void CheckComputeCapability() {
-int n_devices = NVisibleDevices();
-for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
+for (int d_idx : xgboost::GPUSet::AllVisible()) {
cudaDeviceProp prop;
safe_cuda(cudaGetDeviceProperties(&prop, d_idx));
std::ostringstream oss;
@@ -163,12 +116,11 @@
}
}
DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
}
/*!
* \brief Find the strict upper bound for an element in a sorted array
* using binary search.
* \param cuts pointer to the first element of the sorted array
@@ -199,67 +151,18 @@ DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
return right;
}
-/*
-* Range iterator
-*/
-class Range {
-public:
-class Iterator {
-friend class Range;
-public:
-XGBOOST_DEVICE int64_t operator*() const { return i_; }
-XGBOOST_DEVICE const Iterator &operator++() {
-i_ += step_;
-return *this;
-}
-XGBOOST_DEVICE Iterator operator++(int) {
-Iterator copy(*this);
-i_ += step_;
-return copy;
-}
-XGBOOST_DEVICE bool operator==(const Iterator &other) const {
-return i_ >= other.i_;
-}
-XGBOOST_DEVICE bool operator!=(const Iterator &other) const {
-return i_ < other.i_;
-}
-XGBOOST_DEVICE void Step(int s) { step_ = s; }
-protected:
-XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {}
-public:
-uint64_t i_;
-int step_ = 1;
-};
-XGBOOST_DEVICE Iterator begin() const { return begin_; }  // NOLINT
-XGBOOST_DEVICE Iterator end() const { return end_; }  // NOLINT
-XGBOOST_DEVICE Range(int64_t begin, int64_t end)
-: begin_(begin), end_(end) {}
-XGBOOST_DEVICE void Step(int s) { begin_.Step(s); }
-private:
-Iterator begin_;
-Iterator end_;
-};
template <typename T>
-__device__ Range GridStrideRange(T begin, T end) {
+__device__ xgboost::common::Range GridStrideRange(T begin, T end) {
begin += blockDim.x * blockIdx.x + threadIdx.x;
-Range r(begin, end);
+xgboost::common::Range r(begin, end);
r.Step(gridDim.x * blockDim.x);
return r;
}
template <typename T>
-__device__ Range BlockStrideRange(T begin, T end) {
+__device__ xgboost::common::Range BlockStrideRange(T begin, T end) {
begin += threadIdx.x;
-Range r(begin, end);
+xgboost::common::Range r(begin, end);
r.Step(blockDim.x);
return r;
}
@@ -557,7 +460,7 @@ class BulkAllocator {
BulkAllocator(BulkAllocator<MemoryT>&&) = delete;
void operator=(const BulkAllocator<MemoryT>&) = delete;
void operator=(BulkAllocator<MemoryT>&&) = delete;
~BulkAllocator() {
for (size_t i = 0; i < d_ptr_.size(); i++) {
if (!(d_ptr_[i] == nullptr)) {
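
For context, a brief CUDA sketch (not from this commit; the kernel and launch shape are illustrative) of how GridStrideRange is typically consumed after the move:

#include "device_helpers.cuh"

// Hypothetical kernel: scale n floats by alpha using a grid-stride loop.
__global__ void ScaleKernel(float* data, size_t n, float alpha) {
  for (auto i : dh::GridStrideRange(static_cast<size_t>(0), n)) {
    data[i] *= alpha;
  }
}

// Launched e.g. as: ScaleKernel<<<256, 256>>>(d_data, n, 2.0f);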

src/common/gpu_set.h (new file, 122 lines)

@@ -0,0 +1,122 @@
/*!
* Copyright 2018 XGBoost contributors
*/
#ifndef XGBOOST_COMMON_GPU_SET_H_
#define XGBOOST_COMMON_GPU_SET_H_
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <limits>
#include <string>
#include "common.h"
#include "span.h"
#if defined(__CUDACC__)
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif
namespace dh {
#if defined(__CUDACC__)
/*
* Error handling functions
*/
#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
int line) {
if (code != cudaSuccess) {
throw thrust::system_error(code, thrust::cuda_category(),
std::string{file} + "(" + // NOLINT
std::to_string(line) + ")");
}
return code;
}
#endif
} // namespace dh
namespace xgboost {
/*! \brief Set of devices across which HostDeviceVector can be distributed.
*
* Currently implemented as a range, but can be changed later to something
* else, e.g. a bitset.
*/
class GPUSet {
public:
explicit GPUSet(int start = 0, int ndevices = 0)
: devices_(start, start + ndevices) {}
static GPUSet Empty() { return GPUSet(); }
static GPUSet Range(int start, int ndevices) {
return ndevices <= 0 ? Empty() : GPUSet{start, ndevices};
}
/*! \brief Both ndevices and num_rows are upper bounds on the resulting set size. */
static GPUSet All(int ndevices, int num_rows = std::numeric_limits<int>::max()) {
int n_devices_visible = AllVisible().Size();
ndevices = ndevices < 0 ? n_devices_visible : ndevices;
// fix-up device number to be limited by number of rows
ndevices = ndevices > num_rows ? num_rows : ndevices;
return Range(0, ndevices);
}
static GPUSet AllVisible() {
int n_visgpus = 0;
#if defined(__CUDACC__)
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
#endif
return Range(0, n_visgpus);
}
/*! \brief Ensure gpu_id is valid, so the caller need not know the visible device count. */
static int GetDeviceIdx(int gpu_id) {
return (std::abs(gpu_id) + 0) % AllVisible().Size();
}
/*! \brief The same set of devices, counted from gpu_id. */
GPUSet Normalised(int gpu_id) const {
return Range(gpu_id, *devices_.end() + gpu_id);
}
/*! \brief The same set of devices, counted from 0. */
GPUSet Unnormalised() const {
return Range(0, *devices_.end() - *devices_.begin());
}
int Size() const {
int res = *devices_.end() - *devices_.begin();
return res < 0 ? 0 : res;
}
int operator[](int index) const {
CHECK(index >= 0 && index < *(devices_.end()));
return *devices_.begin() + index;
}
bool IsEmpty() const { return Size() == 0; } // NOLINT
int Index(int device) const {
CHECK(Contains(device));
return device - *devices_.begin();
}
bool Contains(int device) const {
return *devices_.begin() <= device && device < *devices_.end();
}
common::Range::Iterator begin() const { return devices_.begin(); } // NOLINT
common::Range::Iterator end() const { return devices_.end(); } // NOLINT
friend bool operator==(const GPUSet& lhs, const GPUSet& rhs) {
return lhs.devices_ == rhs.devices_;
}
friend bool operator!=(const GPUSet& lhs, const GPUSet& rhs) {
return !(lhs == rhs);
}
private:
common::Range devices_;
};
} // namespace xgboost
#endif // XGBOOST_COMMON_GPU_SET_H_
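
As a rough usage sketch (not from the commit; Illustrate is a hypothetical function), the basic accessors compose like this:

#include "gpu_set.h"  // assumed relative include path

void Illustrate() {
  // Devices 1, 2 and 3: a range of three GPUs starting at ordinal 1.
  xgboost::GPUSet devices = xgboost::GPUSet::Range(1, 3);
  for (int d : devices) {
    // d takes the values 1, 2, 3, e.g. arguments for cudaSetDevice(d).
  }
  bool has = devices.Contains(2);  // true
  int pos = devices.Index(2);      // 1: zero-based position within the set
  int n = devices.Size();          // 3
  (void)has; (void)pos; (void)n;
}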

src/common/hist_util.cu

@@ -271,7 +271,7 @@
find_cuts_k<<<dh::DivRoundUp(n_cuts_cur_[icol], block), block>>>
(cuts_d_.data().get() + icol * n_cuts_, fvalues_cur_.data().get(),
weights2_.data().get(), n_unique, n_cuts_cur_[icol]);
-dh::safe_cuda(cudaGetLastError());
+dh::safe_cuda(cudaGetLastError()); // NOLINT
}
}
@@ -311,14 +311,14 @@
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
gpu_batch_nrows_, num_cols_,
row_batch.offset[row_begin_ + batch_row_begin], batch_nrows);
-dh::safe_cuda(cudaGetLastError());
-dh::safe_cuda(cudaDeviceSynchronize());
+dh::safe_cuda(cudaGetLastError()); // NOLINT
+dh::safe_cuda(cudaDeviceSynchronize()); // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) {
FindColumnCuts(batch_nrows, icol);
}
-dh::safe_cuda(cudaDeviceSynchronize());
+dh::safe_cuda(cudaDeviceSynchronize()); // NOLINT
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
@@ -379,7 +379,7 @@
}
GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
-devices_ = GPUSet::Range(param_.gpu_id, dh::NDevices(param_.n_gpus, n_rows));
+devices_ = GPUSet::All(param_.n_gpus, n_rows).Normalised(param_.gpu_id);
}
std::vector<std::unique_ptr<DeviceShard>> shards_;

src/common/host_device_vector.h

@@ -11,6 +11,7 @@
#include <initializer_list>
#include <vector>
+#include "gpu_set.h"
#include "span.h"
// only include thrust-related files if host_device_vector.h
@@ -23,40 +24,6 @@ namespace xgboost {
template <typename T> struct HostDeviceVectorImpl;
-// set of devices across which HostDeviceVector can be distributed;
-// currently implemented as a range, but can be changed later to something else,
-// e.g. a bitset
-class GPUSet {
-public:
-explicit GPUSet(int start = 0, int ndevices = 0)
-: start_(start), ndevices_(ndevices) {}
-static GPUSet Empty() { return GPUSet(); }
-static GPUSet Range(int start, int ndevices) { return GPUSet(start, ndevices); }
-int Size() const { return ndevices_; }
-int operator[](int index) const {
-CHECK(index >= 0 && index < ndevices_);
-return start_ + index;
-}
-bool IsEmpty() const { return ndevices_ <= 0; }
-int Index(int device) const {
-CHECK(device >= start_ && device < start_ + ndevices_);
-return device - start_;
-}
-bool Contains(int device) const {
-return start_ <= device && device < start_ + ndevices_;
-}
-friend bool operator==(GPUSet a, GPUSet b) {
-return a.start_ == b.start_ && a.ndevices_ == b.ndevices_;
-}
-friend bool operator!=(GPUSet a, GPUSet b) {
-return a.start_ != b.start_ || a.ndevices_ != b.ndevices_;
-}
-private:
-int start_, ndevices_;
-};
/**
* @file host_device_vector.h
* @brief A device-and-host vector abstraction layer.
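
With the class moved out, a rough sketch (not from the diff; ShardVector is a hypothetical helper) of how callers combine GPUSet with HostDeviceVector::Reshard, following the pattern used in the updater files below:

#include "gpu_set.h"
#include "host_device_vector.h"

// Hypothetical helper: distribute a vector across the configured devices.
void ShardVector(xgboost::HostDeviceVector<float>* vec, int n_gpus, int gpu_id) {
  // All() caps the device count at the number of visible GPUs;
  // Normalised() shifts the range so that it starts at gpu_id.
  xgboost::GPUSet devices = xgboost::GPUSet::All(n_gpus).Normalised(gpu_id);
  vec->Reshard(devices);
}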

src/common/timer.h

@@ -7,7 +7,8 @@
#include <iostream>
#include <map>
#include <string>
#include <vector>
+#include "gpu_set.h"
namespace xgboost {
namespace common {
@@ -66,21 +67,21 @@ struct Monitor {
this->label = label;
}
void Start(const std::string &name) { timer_map[name].Start(); }
-void Start(const std::string &name, std::vector<int> dList) {
+void Start(const std::string &name, GPUSet devices) {
if (debug_verbose) {
#ifdef __CUDACC__
#include "device_helpers.cuh"
-dh::SynchronizeNDevices(dList.size(), dList);
+dh::SynchronizeNDevices(devices);
#endif
}
timer_map[name].Start();
}
void Stop(const std::string &name) { timer_map[name].Stop(); }
-void Stop(const std::string &name, std::vector<int> dList) {
+void Stop(const std::string &name, GPUSet devices) {
if (debug_verbose) {
#ifdef __CUDACC__
#include "device_helpers.cuh"
-dh::SynchronizeNDevices(dList.size(), dList);
+dh::SynchronizeNDevices(devices);
#endif
}
timer_map[name].Stop();
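
A short usage sketch of the new overloads (the function and timer name are illustrative, not from this commit):

#include "gpu_set.h"
#include "timer.h"

// Hypothetical helper: time a phase; when the monitor was configured with
// debug_verbose, Start/Stop first synchronise the given devices.
void TimedPhase(xgboost::common::Monitor* monitor, xgboost::GPUSet devices) {
  monitor->Start("BuildHist", devices);
  // ... launch device work here ...
  monitor->Stop("BuildHist", devices);
}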

src/linear/updater_gpu_coordinate.cu

@@ -6,6 +6,7 @@
#include <thrust/execution_policy.h>
#include <thrust/inner_product.h>
#include <xgboost/linear_updater.h>
+#include "../common/gpu_set.h"
#include "../common/device_helpers.cuh"
#include "../common/timer.h"
#include "coordinate_common.h"
@@ -214,14 +215,14 @@ class GPUCoordinateUpdater : public LinearUpdater {
void LazyInitShards(DMatrix *p_fmat,
const gbm::GBLinearModelParam &model_param) {
if (!shards.empty()) return;
-int n_devices = dh::NDevices(param.n_gpus, p_fmat->Info().num_row_);
+int n_devices = GPUSet::All(param.n_gpus, p_fmat->Info().num_row_).Size();
bst_uint row_begin = 0;
bst_uint shard_size =
std::ceil(static_cast<double>(p_fmat->Info().num_row_) / n_devices);
device_list.resize(n_devices);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
-int device_idx = (param.gpu_id + d_idx) % dh::NVisibleDevices();
+int device_idx = GPUSet::GetDeviceIdx(param.gpu_id + d_idx);
device_list[d_idx] = device_idx;
}
// Partition input matrix into row segments

src/objective/regression_obj_gpu.cu

@@ -102,8 +102,8 @@
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
-CHECK(param_.n_gpus != 0) << "Must have at least one device";
-devices_ = GPUSet::Range(param_.gpu_id, dh::NDevicesAll(param_.n_gpus));
+// CHECK(param_.n_gpus != 0) << "Must have at least one device";
+devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
}
void GetGradient(HostDeviceVector<float>* preds,

src/predictor/gpu_predictor.cu

@@ -11,6 +11,7 @@
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
#include <memory>
+#include "../common/gpu_set.h"
#include "../common/device_helpers.cuh"
#include "../common/host_device_vector.h"
@@ -464,7 +465,7 @@ class GPUPredictor : public xgboost::Predictor {
Predictor::Init(cfg, cache);
cpu_predictor->Init(cfg, cache);
param.InitAllowUnknown(cfg);
-devices = GPUSet::Range(param.gpu_id, dh::NDevicesAll(param.n_gpus));
+devices = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
max_shared_memory_bytes = dh::MaxSharedMemory(param.gpu_id);
}

src/tree/updater_gpu.cu

@@ -4,6 +4,7 @@
#include <xgboost/tree_updater.h>
#include <utility>
#include <vector>
+#include "../common/gpu_set.h"
#include "param.h"
#include "updater_gpu_common.cuh"
@@ -375,7 +376,7 @@ void argMaxByKey(ExactSplitCandidate* nodeSplits, const GradientPair* gradScans,
NodeIdT nodeStart, int len, const TrainParam param,
ArgMaxByKeyAlgo algo) {
dh::FillConst<ExactSplitCandidate, BLKDIM, ITEMS_PER_THREAD>(
-dh::GetDeviceIdx(param.gpu_id), nodeSplits, nUniqKeys,
+GPUSet::GetDeviceIdx(param.gpu_id), nodeSplits, nUniqKeys,
ExactSplitCandidate());
int nBlks = dh::DivRoundUp(len, ITEMS_PER_THREAD * BLKDIM);
switch (algo) {
@@ -498,7 +499,7 @@ class GPUMaker : public TreeUpdater {
// devices are only used for resharding the HostDeviceVector passed as a parameter;
// the algorithm works with a single GPU only
-GPUSet devices;
+GPUSet devices_;
dh::CubMemory tmp_mem;
dh::DVec<GradientPair> tmpScanGradBuff;
@@ -516,7 +517,7 @@
maxNodes = (1 << (param.max_depth + 1)) - 1;
maxLeaves = 1 << param.max_depth;
-devices = GPUSet::Range(param.gpu_id, dh::NDevicesAll(param.n_gpus));
+devices_ = GPUSet::All(param.n_gpus).Normalised(param.gpu_id);
}
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
@@ -526,7 +527,7 @@
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
-gpair->Reshard(devices);
+gpair->Reshard(devices_);
try {
// build tree
@@ -624,7 +625,7 @@
void allocateAllData(int offsetSize) {
int tmpBuffSize = ScanTempBufferSize(nVals);
-ba.Allocate(dh::GetDeviceIdx(param.gpu_id), param.silent, &vals, nVals,
+ba.Allocate(GPUSet::GetDeviceIdx(param.gpu_id), param.silent, &vals, nVals,
&vals_cached, nVals, &instIds, nVals, &instIds_cached, nVals,
&colOffsets, offsetSize, &gradsInst, nRows, &nodeAssigns, nVals,
&nodeLocations, nVals, &nodes, maxNodes, &nodeAssignsPerInst,
@@ -634,7 +635,7 @@
}
void setupOneTimeData(DMatrix* dmat) {
-size_t free_memory = dh::AvailableMemory(dh::GetDeviceIdx(param.gpu_id));
+size_t free_memory = dh::AvailableMemory(GPUSet::GetDeviceIdx(param.gpu_id));
if (!dmat->SingleColBlock()) {
throw std::runtime_error("exact::GPUBuilder - must have 1 column block");
}
@@ -730,7 +731,7 @@
nodeAssigns.Current(), instIds.Current(), nodes.Data(),
colOffsets.Data(), vals.Current(), nVals, nCols);
// gather the node assignments across all other columns too
-dh::Gather(dh::GetDeviceIdx(param.gpu_id), nodeAssigns.Current(),
+dh::Gather(GPUSet::GetDeviceIdx(param.gpu_id), nodeAssigns.Current(),
nodeAssignsPerInst.Data(), instIds.Current(), nVals);
sortKeys(level);
}
@@ -741,7 +742,7 @@
// but we don't need more than level+1 bits for sorting!
SegmentedSort(&tmp_mem, &nodeAssigns, &nodeLocations, nVals, nCols,
colOffsets, 0, level + 1);
-dh::Gather<float, int>(dh::GetDeviceIdx(param.gpu_id), vals.other(),
+dh::Gather<float, int>(GPUSet::GetDeviceIdx(param.gpu_id), vals.other(),
vals.Current(), instIds.other(), instIds.Current(),
nodeLocations.Current(), nVals);
vals.buff().selector ^= 1;
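
The repeated GPUSet::GetDeviceIdx calls above simply wrap an arbitrary (possibly negative) gpu_id into the visible range. A small sketch, assuming four visible GPUs:

#include "../common/gpu_set.h"  // path as used above

void DeviceIdxExamples() {
  // With 4 visible devices: 5 % 4 == 1 and std::abs(-3) % 4 == 3.
  int a = xgboost::GPUSet::GetDeviceIdx(5);   // 1
  int b = xgboost::GPUSet::GetDeviceIdx(-3);  // 3
  (void)a; (void)b;
}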

src/tree/updater_gpu_hist.cu

@@ -728,7 +728,7 @@
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device";
n_devices_ = param_.n_gpus;
-devices_ = GPUSet::Range(param_.gpu_id, dh::NDevicesAll(param_.n_gpus));
+devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
dh::CheckComputeCapability();
@@ -743,7 +743,7 @@
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor_.Start("Update", device_list_);
monitor_.Start("Update", devices_);
GradStats::CheckInfo(dmat->Info());
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
@@ -759,17 +759,17 @@
LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
}
param_.learning_rate = lr;
monitor_.Stop("Update", device_list_);
monitor_.Stop("Update", devices_);
}
void InitDataOnce(DMatrix* dmat) {
info_ = &dmat->Info();
-int n_devices = dh::NDevices(param_.n_gpus, info_->num_row_);
+int n_devices = GPUSet::All(param_.n_gpus, info_->num_row_).Size();
device_list_.resize(n_devices);
for (int d_idx = 0; d_idx < n_devices; ++d_idx) {
-int device_idx = (param_.gpu_id + d_idx) % dh::NVisibleDevices();
+int device_idx = GPUSet::GetDeviceIdx(param_.gpu_id + d_idx);
device_list_[d_idx] = device_idx;
}
@@ -792,16 +792,16 @@
shard->InitRowPtrs(batch);
});
monitor_.Start("Quantiles", device_list_);
monitor_.Start("Quantiles", devices_);
common::DeviceSketch(batch, *info_, param_, &hmat_);
n_bins_ = hmat_.row_ptr.back();
monitor_.Stop("Quantiles", device_list_);
monitor_.Stop("Quantiles", devices_);
monitor_.Start("BinningCompression", device_list_);
monitor_.Start("BinningCompression", devices_);
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->InitCompressedData(hmat_, batch);
});
monitor_.Stop("BinningCompression", device_list_);
monitor_.Stop("BinningCompression", devices_);
CHECK(!iter->Next()) << "External memory not supported";
@@ -811,20 +811,20 @@
void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const RegTree& tree) {
monitor_.Start("InitDataOnce", device_list_);
monitor_.Start("InitDataOnce", devices_);
if (!initialised_) {
this->InitDataOnce(dmat);
}
monitor_.Stop("InitDataOnce", device_list_);
monitor_.Stop("InitDataOnce", devices_);
column_sampler_.Init(info_->num_col_, param_);
// Copy gpair & reset memory
monitor_.Start("InitDataReset", device_list_);
monitor_.Start("InitDataReset", devices_);
gpair->Reshard(devices_);
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {shard->Reset(gpair); });
monitor_.Stop("InitDataReset", device_list_);
monitor_.Stop("InitDataReset", devices_);
}
void AllReduceHist(int nidx) {
@@ -1036,12 +1036,12 @@
RegTree* p_tree) {
auto& tree = *p_tree;
monitor_.Start("InitData", device_list_);
monitor_.Start("InitData", devices_);
this->InitData(gpair, p_fmat, *p_tree);
monitor_.Stop("InitData", device_list_);
monitor_.Start("InitRoot", device_list_);
monitor_.Stop("InitData", devices_);
monitor_.Start("InitRoot", devices_);
this->InitRoot(p_tree);
monitor_.Stop("InitRoot", device_list_);
monitor_.Stop("InitRoot", devices_);
auto timestamp = qexpand_->size();
auto num_leaves = 1;
@@ -1051,9 +1051,9 @@
qexpand_->pop();
if (!candidate.IsValid(param_, num_leaves)) continue;
// std::cout << candidate;
monitor_.Start("ApplySplit", device_list_);
monitor_.Start("ApplySplit", devices_);
this->ApplySplit(candidate, p_tree);
monitor_.Stop("ApplySplit", device_list_);
monitor_.Stop("ApplySplit", devices_);
num_leaves++;
auto left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1062,12 +1062,12 @@
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor_.Start("BuildHist", device_list_);
monitor_.Start("BuildHist", devices_);
this->BuildHistLeftRight(candidate.nid, left_child_nidx,
right_child_nidx);
monitor_.Stop("BuildHist", device_list_);
monitor_.Stop("BuildHist", devices_);
monitor_.Start("EvaluateSplits", device_list_);
monitor_.Start("EvaluateSplits", devices_);
auto splits =
this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
qexpand_->push(ExpandEntry(left_child_nidx,
@@ -1076,21 +1076,21 @@
qexpand_->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx), splits[1],
timestamp++));
monitor_.Stop("EvaluateSplits", device_list_);
monitor_.Stop("EvaluateSplits", devices_);
}
}
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
monitor_.Start("UpdatePredictionCache", device_list_);
monitor_.Start("UpdatePredictionCache", devices_);
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
return false;
p_out_preds->Reshard(devices_);
dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
shard->UpdatePredictionCache(p_out_preds->DevicePointer(shard->device_idx));
});
monitor_.Stop("UpdatePredictionCache", device_list_);
monitor_.Stop("UpdatePredictionCache", devices_);
return true;
}

tests/cpp/common/ (new GPUSet test file)

@@ -0,0 +1,37 @@
#include "../../../src/common/gpu_set.h"
#include <gtest/gtest.h>
namespace xgboost {
TEST(GPUSet, Basic) {
GPUSet devices = GPUSet::Empty();
ASSERT_TRUE(devices.IsEmpty());
devices = GPUSet{0, 1};
ASSERT_TRUE(devices != GPUSet::Empty());
EXPECT_EQ(devices.Size(), 1);
EXPECT_ANY_THROW(devices.Index(1));
EXPECT_ANY_THROW(devices.Index(-1));
devices = GPUSet::Range(1, 0);
EXPECT_EQ(devices, GPUSet::Empty());
EXPECT_EQ(devices.Size(), 0);
EXPECT_TRUE(devices.IsEmpty());
EXPECT_FALSE(devices.Contains(1));
devices = GPUSet::Range(2, -1);
EXPECT_EQ(devices, GPUSet::Empty());
EXPECT_EQ(devices.Size(), 0);
EXPECT_TRUE(devices.IsEmpty());
devices = GPUSet::Range(2, 8);
EXPECT_EQ(devices.Size(), 8);
devices = devices.Unnormalised();
EXPECT_EQ(*devices.begin(), 0);
EXPECT_EQ(*devices.end(), devices.Size());
}
} // namespace xgboost
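
A possible extra property test, sketched here in the same style (not part of this commit): Index() should recover zero-based positions for offset ranges.

TEST(GPUSet, Index) {
  xgboost::GPUSet devices = xgboost::GPUSet::Range(2, 3);  // GPUs 2, 3, 4
  EXPECT_TRUE(devices.Contains(4));
  EXPECT_FALSE(devices.Contains(5));
  EXPECT_EQ(devices.Index(2), 0);
  EXPECT_EQ(devices.Index(4), 2);
}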