Implement devices to devices reshard. (#3721)

* Force clearing device memory before Reshard.
* Remove calculating row_segments for gpu_hist and gpu_sketch.
* Guard against changing device.
This commit is contained in:
trivialfis
2018-09-28 17:40:23 +12:00
committed by Rory Mitchell
parent 0b7fd74138
commit 5a7f7e7d49
11 changed files with 179 additions and 96 deletions

View File

@@ -8,6 +8,8 @@ namespace xgboost {
int AllVisibleImpl::AllVisible() {
int n_visgpus = 0;
try {
// When compiled with CUDA but running on CPU only device,
// cudaGetDeviceCount will fail.
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
} catch(const std::exception& e) {
return 0;

View File

@@ -110,7 +110,7 @@ inline void CheckComputeCapability() {
std::ostringstream oss;
oss << "CUDA Capability Major/Minor version number: " << prop.major << "."
<< prop.minor << " is insufficient. Need >=3.5";
int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
int failed = prop.major < 3 || (prop.major == 3 && prop.minor < 5);
if (failed) LOG(WARNING) << oss.str() << " for device: " << d_idx;
}
}
@@ -129,15 +129,10 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
* than all elements of the array
*/
DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
if (n == 0) {
return 0;
}
if (cuts[n - 1] <= v) {
return n;
}
if (cuts[0] > v) {
return 0;
}
if (n == 0) { return 0; }
if (cuts[n - 1] <= v) { return n; }
if (cuts[0] > v) { return 0; }
int left = 0, right = n - 1;
while (right - left > 1) {
int middle = left + (right - left) / 2;
@@ -145,7 +140,7 @@ DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
right = middle;
} else {
left = middle;
}
}
}
return right;
}
@@ -184,18 +179,6 @@ T1 DivRoundUp(const T1 a, const T2 b) {
return static_cast<T1>(ceil(static_cast<double>(a) / b));
}
inline void RowSegments(size_t n_rows, size_t n_devices, std::vector<size_t>* segments) {
segments->push_back(0);
size_t row_begin = 0;
size_t shard_size = DivRoundUp(n_rows, n_devices);
for (size_t d_idx = 0; d_idx < n_devices; ++d_idx) {
size_t row_end = std::min(row_begin + shard_size, n_rows);
segments->push_back(row_end);
row_begin = row_end;
}
}
template <typename L>
__global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
for (auto i : GridStrideRange(begin, end)) {
@@ -322,8 +305,8 @@ class DVec {
void copy(IterT begin, IterT end) {
safe_cuda(cudaSetDevice(this->DeviceIdx()));
if (end - begin != Size()) {
throw std::runtime_error(
"Cannot copy assign vector to DVec, sizes are different");
LOG(FATAL) << "Cannot copy assign vector to DVec, sizes are different" <<
" vector::Size(): " << end - begin << " DVec::Size(): " << Size();
}
thrust::copy(begin, end, this->tbegin());
}
@@ -961,6 +944,29 @@ class AllReducer {
}
};
class SaveCudaContext {
private:
int saved_device_;
public:
template <typename Functor>
explicit SaveCudaContext (Functor func) : saved_device_{-1} {
// When compiled with CUDA but running on CPU only device,
// cudaGetDevice will fail.
try {
safe_cuda(cudaGetDevice(&saved_device_));
} catch (thrust::system::system_error & err) {
saved_device_ = -1;
}
func();
}
~SaveCudaContext() {
if (saved_device_ != -1) {
safe_cuda(cudaSetDevice(saved_device_));
}
}
};
/**
* \brief Executes some operation on each element of the input vector, using a
* single controlling thread for each element.
@@ -973,10 +979,13 @@ class AllReducer {
template <typename T, typename FunctionT>
void ExecuteShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext {
[&](){
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
f(shards->at(shard));
}
for (int shard = 0; shard < shards->size(); ++shard) {
f(shards->at(shard));
}
}};
}
/**
@@ -992,10 +1001,13 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
template <typename T, typename FunctionT>
void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
SaveCudaContext {
[&](){
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
f(shard, shards->at(shard));
}
for (int shard = 0; shard < shards->size(); ++shard) {
f(shard, shards->at(shard));
}
}};
}
/**
@@ -1011,13 +1023,16 @@ void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
* \return A reduce_t.
*/
template <typename ReduceT,typename T, typename FunctionT>
ReduceT ReduceShards(std::vector<T> *shards, FunctionT f) {
template <typename ReduceT, typename ShardT, typename FunctionT>
ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
std::vector<ReduceT> sums(shards->size());
SaveCudaContext {
[&](){
#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
for (int shard = 0; shard < shards->size(); ++shard) {
sums[shard] = f(shards->at(shard));
}
for (int shard = 0; shard < shards->size(); ++shard) {
sums[shard] = f(shards->at(shard));
}}
};
return std::accumulate(sums.begin(), sums.end(), ReduceT());
}
} // namespace dh

View File

@@ -17,7 +17,6 @@ namespace xgboost {
namespace common {
void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
const MetaInfo& info = p_fmat->Info();
// safe factor for better accuracy

View File

@@ -347,15 +347,13 @@ struct GPUSketcher {
};
void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
// partition input matrix into row segments
std::vector<size_t> row_segments;
dh::RowSegments(info.num_row_, devices_.Size(), &row_segments);
// create device shards
shards_.resize(devices_.Size());
shards_.resize(dist_.Devices().Size());
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
size_t start = dist_.ShardStart(info.num_row_, i);
size_t size = dist_.ShardSize(info.num_row_, i);
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_));
(new DeviceShard(dist_.Devices()[i], start, start + size, param_));
});
// compute sketches for each shard
@@ -381,12 +379,13 @@ struct GPUSketcher {
}
GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
devices_ = GPUSet::All(param_.n_gpus, n_rows).Normalised(param_.gpu_id);
dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus, n_rows).
Normalised(param_.gpu_id));
}
std::vector<std::unique_ptr<DeviceShard>> shards_;
tree::TrainParam param_;
GPUSet devices_;
GPUDistribution dist_;
};
void DeviceSketch

View File

@@ -67,7 +67,7 @@ struct HistCutUnit {
: cut(cut), size(size) {}
};
/*! \brief cut configuration for all the features */
/*! \brief cut configuration for all the features. */
struct HistCutMatrix {
/*! \brief unit pointer to rows by element position */
std::vector<uint32_t> row_ptr;

View File

@@ -289,6 +289,7 @@ struct HostDeviceVectorImpl {
data_h_.size() * sizeof(T),
cudaMemcpyHostToDevice));
} else {
//
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
}
}
@@ -347,7 +348,10 @@ struct HostDeviceVectorImpl {
void Reshard(const GPUDistribution& distribution) {
if (distribution_ == distribution) { return; }
CHECK(distribution_.IsEmpty());
CHECK(distribution_.IsEmpty() || distribution.IsEmpty());
if (distribution.IsEmpty()) {
LazySyncHost(GPUAccess::kWrite);
}
distribution_ = distribution;
InitShards();
}

View File

@@ -243,6 +243,11 @@ class HostDeviceVector {
bool HostCanAccess(GPUAccess access) const;
bool DeviceCanAccess(int device, GPUAccess access) const;
/*!
* \brief Specify memory distribution.
*
* If GPUSet::Empty() is used, all data will be drawn back to CPU.
*/
void Reshard(const GPUDistribution& distribution) const;
void Reshard(GPUSet devices) const;
void Resize(size_t new_size, T v = T());