Implement devices to devices reshard. (#3721)
* Force clearing device memory before Reshard.
* Remove calculating row_segments for gpu_hist and gpu_sketch.
* Guard against changing device.
This commit is contained in: parent 0b7fd74138, commit 5a7f7e7d49
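In short, a HostDeviceVector can now be moved off the devices and onto a different device set by resharding through the host. A minimal sketch of the intended call sequence, assuming more than one visible GPU; the calls are the same ones exercised by the new HostDeviceVector.Reshard test at the end of this diff:

  std::vector<int> h_vec(2345, 1);
  HostDeviceVector<int> vec(h_vec);
  vec.Reshard(GPUSet::AllVisible());   // host -> all visible devices
  vec.Reshard(GPUSet::Empty());        // draw data back to the CPU, clearing device memory
  vec.Reshard(GPUSet::Range(0, 1));    // host -> a different, smaller device set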
@@ -8,6 +8,8 @@ namespace xgboost {
 int AllVisibleImpl::AllVisible() {
   int n_visgpus = 0;
   try {
+    // When compiled with CUDA but running on CPU only device,
+    // cudaGetDeviceCount will fail.
     dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
   } catch(const std::exception& e) {
     return 0;
@@ -110,7 +110,7 @@ inline void CheckComputeCapability() {
     std::ostringstream oss;
     oss << "CUDA Capability Major/Minor version number: " << prop.major << "."
         << prop.minor << " is insufficient. Need >=3.5";
-    int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
+    int failed = prop.major < 3 || (prop.major == 3 && prop.minor < 5);
     if (failed) LOG(WARNING) << oss.str() << " for device: " << d_idx;
   }
 }
@@ -129,15 +129,10 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
  * than all elements of the array
  */
 DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
-  if (n == 0) {
-    return 0;
-  }
-  if (cuts[n - 1] <= v) {
-    return n;
-  }
-  if (cuts[0] > v) {
-    return 0;
-  }
+  if (n == 0) { return 0; }
+  if (cuts[n - 1] <= v) { return n; }
+  if (cuts[0] > v) { return 0; }
   int left = 0, right = n - 1;
   while (right - left > 1) {
     int middle = left + (right - left) / 2;
@@ -184,18 +179,6 @@ T1 DivRoundUp(const T1 a, const T2 b) {
   return static_cast<T1>(ceil(static_cast<double>(a) / b));
 }
-
-inline void RowSegments(size_t n_rows, size_t n_devices, std::vector<size_t>* segments) {
-  segments->push_back(0);
-  size_t row_begin = 0;
-  size_t shard_size = DivRoundUp(n_rows, n_devices);
-  for (size_t d_idx = 0; d_idx < n_devices; ++d_idx) {
-    size_t row_end = std::min(row_begin + shard_size, n_rows);
-    segments->push_back(row_end);
-    row_begin = row_end;
-  }
-}
-
 
 template <typename L>
 __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
   for (auto i : GridStrideRange(begin, end)) {
@@ -322,8 +305,8 @@ class DVec {
   void copy(IterT begin, IterT end) {
     safe_cuda(cudaSetDevice(this->DeviceIdx()));
     if (end - begin != Size()) {
-      throw std::runtime_error(
-          "Cannot copy assign vector to DVec, sizes are different");
+      LOG(FATAL) << "Cannot copy assign vector to DVec, sizes are different" <<
+        " vector::Size(): " << end - begin << " DVec::Size(): " << Size();
     }
     thrust::copy(begin, end, this->tbegin());
   }
@@ -961,6 +944,29 @@ class AllReducer {
   }
 };
 
+class SaveCudaContext {
+ private:
+  int saved_device_;
+
+ public:
+  template <typename Functor>
+  explicit SaveCudaContext (Functor func) : saved_device_{-1} {
+    // When compiled with CUDA but running on CPU only device,
+    // cudaGetDevice will fail.
+    try {
+      safe_cuda(cudaGetDevice(&saved_device_));
+    } catch (thrust::system::system_error & err) {
+      saved_device_ = -1;
+    }
+    func();
+  }
+  ~SaveCudaContext() {
+    if (saved_device_ != -1) {
+      safe_cuda(cudaSetDevice(saved_device_));
+    }
+  }
+};
+
 /**
  * \brief Executes some operation on each element of the input vector, using a
  *  single controlling thread for each element.
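SaveCudaContext is an RAII guard: the constructor records the current device (or -1 when no device is available), runs the supplied functor, and the destructor switches back to the recorded device. A hedged usage sketch, not taken from the diff; the device index is a hypothetical example:

  // The caller's device is restored when the temporary is destroyed at the
  // end of the statement, even if the lambda switched devices while running.
  dh::SaveCudaContext {
    [&]() {
      dh::safe_cuda(cudaSetDevice(1));  // hypothetical: do work on device 1
      // ... per-shard work ...
    }};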
@@ -973,10 +979,13 @@ class AllReducer {
 
 template <typename T, typename FunctionT>
 void ExecuteShards(std::vector<T> *shards, FunctionT f) {
+  SaveCudaContext {
+    [&](){
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
   for (int shard = 0; shard < shards->size(); ++shard) {
     f(shards->at(shard));
   }
+  }};
 }
 
 /**
@@ -992,10 +1001,13 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
 
 template <typename T, typename FunctionT>
 void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
+  SaveCudaContext {
+    [&](){
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
   for (int shard = 0; shard < shards->size(); ++shard) {
     f(shard, shards->at(shard));
   }
+  }};
 }
 
 /**
@@ -1011,13 +1023,16 @@ void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
  * \return A reduce_t.
  */
 
-template <typename ReduceT,typename T, typename FunctionT>
-ReduceT ReduceShards(std::vector<T> *shards, FunctionT f) {
+template <typename ReduceT, typename ShardT, typename FunctionT>
+ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
   std::vector<ReduceT> sums(shards->size());
+  SaveCudaContext {
+    [&](){
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
   for (int shard = 0; shard < shards->size(); ++shard) {
     sums[shard] = f(shards->at(shard));
-  }
+  }}
+  };
   return std::accumulate(sums.begin(), sums.end(), ReduceT());
 }
 }  // namespace dh
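ReduceShards maps every shard to a ReduceT on its own OpenMP thread and folds the per-shard results on the host with std::accumulate; the loop is wrapped in SaveCudaContext so the calling thread ends up back on its original device. A minimal sketch, equivalent to the DeviceHelpers.Basic test added near the end of this diff:

  struct Shard { int id; };
  std::vector<Shard> shards(4);
  for (int i = 0; i < 4; ++i) { shards[i].id = i; }
  // One OpenMP thread per shard; the per-shard values are summed on the host.
  int sum = dh::ReduceShards<int>(&shards, [](Shard& s) { return s.id; });
  // sum == 0 + 1 + 2 + 3 == 6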
@@ -17,7 +17,6 @@ namespace xgboost {
 namespace common {
 
 void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
-  using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
   const MetaInfo& info = p_fmat->Info();
 
   // safe factor for better accuracy
@@ -347,15 +347,13 @@ struct GPUSketcher {
   };
 
   void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
-    // partition input matrix into row segments
-    std::vector<size_t> row_segments;
-    dh::RowSegments(info.num_row_, devices_.Size(), &row_segments);
-
     // create device shards
-    shards_.resize(devices_.Size());
+    shards_.resize(dist_.Devices().Size());
     dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
+        size_t start = dist_.ShardStart(info.num_row_, i);
+        size_t size = dist_.ShardSize(info.num_row_, i);
         shard = std::unique_ptr<DeviceShard>
-          (new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_));
+          (new DeviceShard(dist_.Devices()[i], start, start + size, param_));
       });
 
     // compute sketches for each shard
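Instead of a precomputed row_segments vector, each shard now asks the block distribution for its own slice of rows. A hedged sketch of the partitioning this relies on; GPUDistribution::Block, ShardStart and ShardSize are the calls used above, while the two-device set and the row count are only examples:

  GPUDistribution dist = GPUDistribution::Block(GPUSet::Range(0, 2));
  size_t num_rows = 1000;                       // example row count
  size_t start = dist.ShardStart(num_rows, 0);  // first row owned by shard 0
  size_t size  = dist.ShardSize(num_rows, 0);   // number of rows assigned to shard 0
  // Shard 0 covers rows [start, start + size); shard 1 begins where shard 0 ends.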
@@ -381,12 +379,13 @@ struct GPUSketcher {
   }
 
   GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
-    devices_ = GPUSet::All(param_.n_gpus, n_rows).Normalised(param_.gpu_id);
+    dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus, n_rows).
+                                   Normalised(param_.gpu_id));
   }
 
   std::vector<std::unique_ptr<DeviceShard>> shards_;
   tree::TrainParam param_;
-  GPUSet devices_;
+  GPUDistribution dist_;
 };
 
 void DeviceSketch
@@ -67,7 +67,7 @@ struct HistCutUnit {
     : cut(cut), size(size) {}
 };
 
-/*! \brief cut configuration for all the features */
+/*! \brief cut configuration for all the features. */
 struct HistCutMatrix {
   /*! \brief unit pointer to rows by element position */
   std::vector<uint32_t> row_ptr;
@@ -289,6 +289,7 @@ struct HostDeviceVectorImpl {
                              data_h_.size() * sizeof(T),
                              cudaMemcpyHostToDevice));
     } else {
+      //
       dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
     }
   }
@@ -347,7 +348,10 @@ struct HostDeviceVectorImpl {
 
   void Reshard(const GPUDistribution& distribution) {
     if (distribution_ == distribution) { return; }
-    CHECK(distribution_.IsEmpty());
+    CHECK(distribution_.IsEmpty() || distribution.IsEmpty());
+    if (distribution.IsEmpty()) {
+      LazySyncHost(GPUAccess::kWrite);
+    }
     distribution_ = distribution;
     InitShards();
   }
@@ -243,6 +243,11 @@ class HostDeviceVector {
   bool HostCanAccess(GPUAccess access) const;
   bool DeviceCanAccess(int device, GPUAccess access) const;
 
+  /*!
+   * \brief Specify memory distribution.
+   *
+   * If GPUSet::Empty() is used, all data will be drawn back to CPU.
+   */
   void Reshard(const GPUDistribution& distribution) const;
   void Reshard(GPUSet devices) const;
   void Resize(size_t new_size, T v = T());
@@ -372,8 +372,8 @@ struct DeviceShard {
 
   // TODO(canonizer): do add support multi-batch DMatrix here
   DeviceShard(int device_idx, int normalised_device_idx,
-              bst_uint row_begin, bst_uint row_end, TrainParam param)
-    : device_idx(device_idx),
+              bst_uint row_begin, bst_uint row_end, TrainParam param) :
+    device_idx(device_idx),
       normalised_device_idx(normalised_device_idx),
       row_begin_idx(row_begin),
       row_end_idx(row_end),
@@ -754,7 +754,9 @@ class GPUHistMaker : public TreeUpdater {
     param_.InitAllowUnknown(args);
     CHECK(param_.n_gpus != 0) << "Must have at least one device";
     n_devices_ = param_.n_gpus;
-    devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
+    dist_ =
+        GPUDistribution::Block(GPUSet::All(param_.n_gpus)
+                               .Normalised(param_.gpu_id));
 
     dh::CheckComputeCapability();
 
@@ -769,7 +771,7 @@ class GPUHistMaker : public TreeUpdater {
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) override {
-    monitor_.Start("Update", devices_);
+    monitor_.Start("Update", dist_.Devices());
     GradStats::CheckInfo(dmat->Info());
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
@@ -785,7 +787,7 @@ class GPUHistMaker : public TreeUpdater {
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
-    monitor_.Stop("Update", devices_);
+    monitor_.Stop("Update", dist_.Devices());
   }
 
   void InitDataOnce(DMatrix* dmat) {
@@ -801,10 +803,6 @@ class GPUHistMaker : public TreeUpdater {
 
     reducer_.Init(device_list_);
 
-    // Partition input matrix into row segments
-    std::vector<size_t> row_segments;
-    dh::RowSegments(info_->num_row_, n_devices, &row_segments);
-
     dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
     iter->BeforeFirst();
     CHECK(iter->Next()) << "Empty batches are not supported";
@@ -812,22 +810,24 @@ class GPUHistMaker : public TreeUpdater {
     // Create device shards
     shards_.resize(n_devices);
     dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
+        size_t start = dist_.ShardStart(info_->num_row_, i);
+        size_t size = dist_.ShardSize(info_->num_row_, i);
         shard = std::unique_ptr<DeviceShard>
-          (new DeviceShard(device_list_[i], i,
-                           row_segments[i], row_segments[i + 1], param_));
+          (new DeviceShard(device_list_.at(i), i,
+                           start, start + size, param_));
         shard->InitRowPtrs(batch);
       });
 
-    monitor_.Start("Quantiles", devices_);
+    monitor_.Start("Quantiles", dist_.Devices());
     common::DeviceSketch(batch, *info_, param_, &hmat_);
     n_bins_ = hmat_.row_ptr.back();
-    monitor_.Stop("Quantiles", devices_);
+    monitor_.Stop("Quantiles", dist_.Devices());
 
-    monitor_.Start("BinningCompression", devices_);
+    monitor_.Start("BinningCompression", dist_.Devices());
     dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
-    monitor_.Stop("BinningCompression", devices_);
+    monitor_.Stop("BinningCompression", dist_.Devices());
 
     CHECK(!iter->Next()) << "External memory not supported";
 
@@ -837,20 +837,22 @@ class GPUHistMaker : public TreeUpdater {
 
   void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
                 const RegTree& tree) {
-    monitor_.Start("InitDataOnce", devices_);
+    monitor_.Start("InitDataOnce", dist_.Devices());
     if (!initialised_) {
       this->InitDataOnce(dmat);
     }
-    monitor_.Stop("InitDataOnce", devices_);
+    monitor_.Stop("InitDataOnce", dist_.Devices());
 
     column_sampler_.Init(info_->num_col_, param_.colsample_bylevel, param_.colsample_bytree);
 
     // Copy gpair & reset memory
-    monitor_.Start("InitDataReset", devices_);
+    monitor_.Start("InitDataReset", dist_.Devices());
 
-    gpair->Reshard(devices_);
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {shard->Reset(gpair); });
-    monitor_.Stop("InitDataReset", devices_);
+    gpair->Reshard(dist_);
+    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+        shard->Reset(gpair);
+      });
+    monitor_.Stop("InitDataReset", dist_.Devices());
   }
 
   void AllReduceHist(int nidx) {
@@ -1081,12 +1083,12 @@ class GPUHistMaker : public TreeUpdater {
                  RegTree* p_tree) {
     auto& tree = *p_tree;
 
-    monitor_.Start("InitData", devices_);
+    monitor_.Start("InitData", dist_.Devices());
     this->InitData(gpair, p_fmat, *p_tree);
-    monitor_.Stop("InitData", devices_);
-    monitor_.Start("InitRoot", devices_);
+    monitor_.Stop("InitData", dist_.Devices());
+    monitor_.Start("InitRoot", dist_.Devices());
     this->InitRoot(p_tree);
-    monitor_.Stop("InitRoot", devices_);
+    monitor_.Stop("InitRoot", dist_.Devices());
 
     auto timestamp = qexpand_->size();
     auto num_leaves = 1;
@@ -1096,9 +1098,9 @@ class GPUHistMaker : public TreeUpdater {
       qexpand_->pop();
       if (!candidate.IsValid(param_, num_leaves)) continue;
       // std::cout << candidate;
-      monitor_.Start("ApplySplit", devices_);
+      monitor_.Start("ApplySplit", dist_.Devices());
       this->ApplySplit(candidate, p_tree);
-      monitor_.Stop("ApplySplit", devices_);
+      monitor_.Stop("ApplySplit", dist_.Devices());
       num_leaves++;
 
       auto left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1107,12 +1109,12 @@ class GPUHistMaker : public TreeUpdater {
       // Only create child entries if needed
       if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
                                     num_leaves)) {
-        monitor_.Start("BuildHist", devices_);
+        monitor_.Start("BuildHist", dist_.Devices());
         this->BuildHistLeftRight(candidate.nid, left_child_nidx,
                                  right_child_nidx);
-        monitor_.Stop("BuildHist", devices_);
+        monitor_.Stop("BuildHist", dist_.Devices());
 
-        monitor_.Start("EvaluateSplits", devices_);
+        monitor_.Start("EvaluateSplits", dist_.Devices());
         auto splits =
             this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
         qexpand_->push(ExpandEntry(left_child_nidx,
@@ -1121,21 +1123,21 @@ class GPUHistMaker : public TreeUpdater {
         qexpand_->push(ExpandEntry(right_child_nidx,
                                    tree.GetDepth(right_child_nidx), splits[1],
                                    timestamp++));
-        monitor_.Stop("EvaluateSplits", devices_);
+        monitor_.Stop("EvaluateSplits", dist_.Devices());
       }
     }
   }
 
   bool UpdatePredictionCache(
       const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
-    monitor_.Start("UpdatePredictionCache", devices_);
+    monitor_.Start("UpdatePredictionCache", dist_.Devices());
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
-    p_out_preds->Reshard(devices_);
+    p_out_preds->Reshard(dist_.Devices());
     dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
       shard->UpdatePredictionCache(p_out_preds->DevicePointer(shard->device_idx));
     });
-    monitor_.Stop("UpdatePredictionCache", devices_);
+    monitor_.Stop("UpdatePredictionCache", dist_.Devices());
     return true;
   }
 
@@ -1208,7 +1210,7 @@ class GPUHistMaker : public TreeUpdater {
   std::vector<int> device_list_;
 
   DMatrix* p_last_fmat_;
-  GPUSet devices_;
+  GPUDistribution dist_;
 };
 
 XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")
@@ -8,6 +8,17 @@
 #include "../../../src/common/timer.h"
 #include "gtest/gtest.h"
 
+struct Shard { int id; };
+
+TEST(DeviceHelpers, Basic) {
+  std::vector<Shard> shards (4);
+  for (int i = 0; i < 4; ++i) {
+    shards[i].id = i;
+  }
+  int sum = dh::ReduceShards<int>(&shards, [](Shard& s) { return s.id ; });
+  ASSERT_EQ(sum, 6);
+}
+
 void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
                     thrust::host_vector<int> *row_ptr,
                     thrust::host_vector<xgboost::bst_uint> *rows) {
@@ -28,7 +28,7 @@ TEST(gpu_hist_util, TestDeviceSketch) {
   tree::TrainParam p;
   p.max_bin = 20;
   p.gpu_id = 0;
-  p.n_gpus = 1;
+  p.n_gpus = GPUSet::AllVisible().Size();
   // ensure that the exact quantiles are found
   p.gpu_batch_nrows = nrows * 10;
 
@@ -178,6 +178,52 @@ TEST(HostDeviceVector, TestCopy) {
   SetCudaSetDeviceHandler(nullptr);
 }
 
+// The test is not really useful if n_gpus < 2
+TEST(HostDeviceVector, Reshard) {
+  std::vector<int> h_vec (2345);
+  for (size_t i = 0; i < h_vec.size(); ++i) {
+    h_vec[i] = i;
+  }
+  HostDeviceVector<int> vec (h_vec);
+  auto devices = GPUSet::AllVisible();
+  std::vector<size_t> devices_size (devices.Size());
+
+  // From CPU to GPUs.
+  // Assuming we have > 1 devices.
+  vec.Reshard(devices);
+  size_t total_size = 0;
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    total_size += vec.DeviceSize(i);
+    devices_size[i] = vec.DeviceSize(i);
+  }
+  ASSERT_EQ(total_size, h_vec.size());
+  ASSERT_EQ(total_size, vec.Size());
+  auto h_vec_1 = vec.HostVector();
+
+  ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
+  vec.Reshard(GPUSet::Empty());  // clear out devices memory
+
+  // Shrink down the number of devices.
+  vec.Reshard(GPUSet::Range(0, 1));
+  ASSERT_EQ(vec.Size(), h_vec.size());
+  ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
+  h_vec_1 = vec.HostVector();
+  ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
+  vec.Reshard(GPUSet::Empty());  // clear out devices memory
+
+  // Grow the number of devices.
+  vec.Reshard(devices);
+  total_size = 0;
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    total_size += vec.DeviceSize(i);
+    ASSERT_EQ(devices_size[i], vec.DeviceSize(i));
+  }
+  ASSERT_EQ(total_size, h_vec.size());
+  ASSERT_EQ(total_size, vec.Size());
+  h_vec_1 = vec.HostVector();
+  ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
+}
+
 TEST(HostDeviceVector, Span) {
   HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
   vec.Reshard(GPUSet{0, 1});