Implement devices to devices reshard. (#3721)
* Force clearing device memory before Reshard.
* Remove calculating row_segments for gpu_hist and gpu_sketch.
* Guard against changing device.
parent 0b7fd74138
commit 5a7f7e7d49
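
In short: a HostDeviceVector can now be moved from one device distribution to another, but only through the host. Resharding to GPUSet::Empty() pulls the data back to CPU and clears device memory; a new distribution can then be applied. A minimal usage sketch based on the Reshard semantics and the test added in this diff (variable names are illustrative):

    // Types come from xgboost's internal headers (src/common) as shown in the diff below.
    std::vector<int> h_data(1000, 1);
    xgboost::HostDeviceVector<int> vec(h_data);

    vec.Reshard(xgboost::GPUSet::AllVisible());  // host -> all visible GPUs
    // ... device-side work ...

    // Direct device-to-device resharding is guarded against; go through the host:
    vec.Reshard(xgboost::GPUSet::Empty());       // clear device memory, data back on CPU
    vec.Reshard(xgboost::GPUSet::Range(0, 1));   // redistribute, here onto GPU 0 only
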
@@ -8,6 +8,8 @@ namespace xgboost {
 int AllVisibleImpl::AllVisible() {
   int n_visgpus = 0;
   try {
+    // When compiled with CUDA but running on CPU only device,
+    // cudaGetDeviceCount will fail.
     dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
   } catch(const std::exception& e) {
     return 0;
@@ -110,7 +110,7 @@ inline void CheckComputeCapability() {
     std::ostringstream oss;
     oss << "CUDA Capability Major/Minor version number: " << prop.major << "."
         << prop.minor << " is insufficient. Need >=3.5";
-    int failed = prop.major < 3 || prop.major == 3 && prop.minor < 5;
+    int failed = prop.major < 3 || (prop.major == 3 && prop.minor < 5);
     if (failed) LOG(WARNING) << oss.str() << " for device: " << d_idx;
   }
 }
@@ -129,15 +129,10 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
  * than all elements of the array
  */
 DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
-  if (n == 0) {
-    return 0;
-  }
-  if (cuts[n - 1] <= v) {
-    return n;
-  }
-  if (cuts[0] > v) {
-    return 0;
-  }
+  if (n == 0) { return 0; }
+  if (cuts[n - 1] <= v) { return n; }
+  if (cuts[0] > v) { return 0; }

   int left = 0, right = n - 1;
   while (right - left > 1) {
     int middle = left + (right - left) / 2;
@@ -145,7 +140,7 @@ DEV_INLINE int UpperBound(const float* __restrict__ cuts, int n, float v) {
       right = middle;
     } else {
       left = middle;
     }
   }
   }
   return right;
 }
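
For reference, UpperBound above has the same contract as std::upper_bound on an ascending cuts array: it returns the index of the first cut strictly greater than v, or n when v is not smaller than all elements. A host-side analogue, for illustration only (assumes cuts is sorted ascending, as quantile cut points are):

    #include <algorithm>
    #include <vector>

    // Host-side stand-in with the same contract as dh::UpperBound.
    int UpperBoundHost(const std::vector<float>& cuts, float v) {
      return static_cast<int>(
          std::upper_bound(cuts.begin(), cuts.end(), v) - cuts.begin());
    }

    // UpperBoundHost({1.f, 2.f, 4.f}, 2.f) == 2
    // UpperBoundHost({1.f, 2.f, 4.f}, 5.f) == 3
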
@@ -184,18 +179,6 @@ T1 DivRoundUp(const T1 a, const T2 b) {
   return static_cast<T1>(ceil(static_cast<double>(a) / b));
 }

-inline void RowSegments(size_t n_rows, size_t n_devices, std::vector<size_t>* segments) {
-  segments->push_back(0);
-  size_t row_begin = 0;
-  size_t shard_size = DivRoundUp(n_rows, n_devices);
-  for (size_t d_idx = 0; d_idx < n_devices; ++d_idx) {
-    size_t row_end = std::min(row_begin + shard_size, n_rows);
-    segments->push_back(row_end);
-    row_begin = row_end;
-  }
-}
-
-
 template <typename L>
 __global__ void LaunchNKernel(size_t begin, size_t end, L lambda) {
   for (auto i : GridStrideRange(begin, end)) {
@@ -322,8 +305,8 @@ class DVec {
   void copy(IterT begin, IterT end) {
     safe_cuda(cudaSetDevice(this->DeviceIdx()));
     if (end - begin != Size()) {
-      throw std::runtime_error(
-          "Cannot copy assign vector to DVec, sizes are different");
+      LOG(FATAL) << "Cannot copy assign vector to DVec, sizes are different" <<
+        " vector::Size(): " << end - begin << " DVec::Size(): " << Size();
     }
     thrust::copy(begin, end, this->tbegin());
   }
@@ -961,6 +944,29 @@ class AllReducer {
   }
 };

+class SaveCudaContext {
+ private:
+  int saved_device_;
+
+ public:
+  template <typename Functor>
+  explicit SaveCudaContext (Functor func) : saved_device_{-1} {
+    // When compiled with CUDA but running on CPU only device,
+    // cudaGetDevice will fail.
+    try {
+      safe_cuda(cudaGetDevice(&saved_device_));
+    } catch (thrust::system::system_error & err) {
+      saved_device_ = -1;
+    }
+    func();
+  }
+  ~SaveCudaContext() {
+    if (saved_device_ != -1) {
+      safe_cuda(cudaSetDevice(saved_device_));
+    }
+  }
+};
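
SaveCudaContext is an RAII-style guard: it records the current device (if any), runs the functor, and restores that device when the temporary is destroyed at the end of the statement. A usage sketch of the pattern, which the shard helpers below rely on (the device index used inside the lambda is illustrative):

    int before = 0;
    dh::safe_cuda(cudaGetDevice(&before));

    dh::SaveCudaContext{[&]() {
      dh::safe_cuda(cudaSetDevice(1));  // shard work may switch devices
      // ... launch kernels / copies on device 1 ...
    }};
    // The guard has been destroyed here, so device `before` is active again.
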

 /**
  * \brief Executes some operation on each element of the input vector, using a
  * single controlling thread for each element.
@@ -973,10 +979,13 @@ class AllReducer {

 template <typename T, typename FunctionT>
 void ExecuteShards(std::vector<T> *shards, FunctionT f) {
+  SaveCudaContext {
+    [&](){
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-  for (int shard = 0; shard < shards->size(); ++shard) {
-    f(shards->at(shard));
-  }
+      for (int shard = 0; shard < shards->size(); ++shard) {
+        f(shards->at(shard));
+      }
+  }};
 }

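
ExecuteShards runs the functor once per element of shards, one OpenMP thread per shard, and restores the caller's CUDA device afterwards via SaveCudaContext. A usage sketch with a hypothetical shard type (real callers in this diff pass std::unique_ptr<DeviceShard>):

    struct ToyShard { int device_idx; };  // hypothetical, for illustration

    std::vector<ToyShard> shards{{0}, {1}};
    dh::ExecuteShards(&shards, [&](ToyShard& shard) {
      dh::safe_cuda(cudaSetDevice(shard.device_idx));
      // ... per-device work ...
    });
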
 /**
@@ -992,10 +1001,13 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {

 template <typename T, typename FunctionT>
 void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
+  SaveCudaContext {
+    [&](){
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-  for (int shard = 0; shard < shards->size(); ++shard) {
-    f(shard, shards->at(shard));
-  }
+      for (int shard = 0; shard < shards->size(); ++shard) {
+        f(shard, shards->at(shard));
+      }
+  }};
 }

 /**
@@ -1011,13 +1023,16 @@ void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
  * \return A reduce_t.
  */

-template <typename ReduceT,typename T, typename FunctionT>
-ReduceT ReduceShards(std::vector<T> *shards, FunctionT f) {
+template <typename ReduceT, typename ShardT, typename FunctionT>
+ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
   std::vector<ReduceT> sums(shards->size());
+  SaveCudaContext {
+    [&](){
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-  for (int shard = 0; shard < shards->size(); ++shard) {
-    sums[shard] = f(shards->at(shard));
-  }
+      for (int shard = 0; shard < shards->size(); ++shard) {
+        sums[shard] = f(shards->at(shard));
+      }}
+  };
   return std::accumulate(sums.begin(), sums.end(), ReduceT());
 }
 }  // namespace dh

@@ -17,7 +17,6 @@ namespace xgboost {
 namespace common {

 void HistCutMatrix::Init(DMatrix* p_fmat, uint32_t max_num_bins) {
   using WXQSketch = common::WXQuantileSketch<bst_float, bst_float>;
   const MetaInfo& info = p_fmat->Info();

   // safe factor for better accuracy

@@ -347,15 +347,13 @@ struct GPUSketcher {
   };

   void Sketch(const SparsePage& batch, const MetaInfo& info, HistCutMatrix* hmat) {
-    // partition input matrix into row segments
-    std::vector<size_t> row_segments;
-    dh::RowSegments(info.num_row_, devices_.Size(), &row_segments);
-
     // create device shards
-    shards_.resize(devices_.Size());
+    shards_.resize(dist_.Devices().Size());
     dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
+      size_t start = dist_.ShardStart(info.num_row_, i);
+      size_t size = dist_.ShardSize(info.num_row_, i);
       shard = std::unique_ptr<DeviceShard>
-        (new DeviceShard(devices_[i], row_segments[i], row_segments[i + 1], param_));
+        (new DeviceShard(dist_.Devices()[i], start, start + size, param_));
     });

     // compute sketches for each shard
@@ -381,12 +379,13 @@ struct GPUSketcher {
   }

   GPUSketcher(tree::TrainParam param, size_t n_rows) : param_(std::move(param)) {
-    devices_ = GPUSet::All(param_.n_gpus, n_rows).Normalised(param_.gpu_id);
+    dist_ = GPUDistribution::Block(GPUSet::All(param_.n_gpus, n_rows).
+                                   Normalised(param_.gpu_id));
   }

   std::vector<std::unique_ptr<DeviceShard>> shards_;
   tree::TrainParam param_;
-  GPUSet devices_;
+  GPUDistribution dist_;
 };

 void DeviceSketch
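
The dist_.ShardStart() / dist_.ShardSize() calls take over from the removed RowSegments helper. For a block distribution this should amount to splitting the rows into contiguous chunks of roughly DivRoundUp(n_rows, n_devices) each; a sketch of that arithmetic, mirroring the removed dh::RowSegments rather than the actual GPUDistribution implementation:

    #include <algorithm>
    #include <cstddef>

    // Contiguous block partition of n_rows over n_devices (illustrative only).
    size_t BlockShardStart(size_t n_rows, size_t n_devices, size_t d_idx) {
      size_t shard_size = (n_rows + n_devices - 1) / n_devices;  // DivRoundUp
      return std::min(d_idx * shard_size, n_rows);
    }

    size_t BlockShardSize(size_t n_rows, size_t n_devices, size_t d_idx) {
      return BlockShardStart(n_rows, n_devices, d_idx + 1) -
             BlockShardStart(n_rows, n_devices, d_idx);
    }
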
@@ -67,7 +67,7 @@ struct HistCutUnit {
       : cut(cut), size(size) {}
 };

-/*! \brief cut configuration for all the features */
+/*! \brief cut configuration for all the features. */
 struct HistCutMatrix {
   /*! \brief unit pointer to rows by element position */
   std::vector<uint32_t> row_ptr;

@@ -289,6 +289,7 @@ struct HostDeviceVectorImpl {
                            data_h_.size() * sizeof(T),
                            cudaMemcpyHostToDevice));
     } else {
+      //
       dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
     }
   }
@@ -347,7 +348,10 @@ struct HostDeviceVectorImpl {

   void Reshard(const GPUDistribution& distribution) {
     if (distribution_ == distribution) { return; }
-    CHECK(distribution_.IsEmpty());
+    CHECK(distribution_.IsEmpty() || distribution.IsEmpty());
+    if (distribution.IsEmpty()) {
+      LazySyncHost(GPUAccess::kWrite);
+    }
     distribution_ = distribution;
     InitShards();
   }

@@ -243,6 +243,11 @@ class HostDeviceVector {
   bool HostCanAccess(GPUAccess access) const;
   bool DeviceCanAccess(int device, GPUAccess access) const;

+  /*!
+   * \brief Specify memory distribution.
+   *
+   * If GPUSet::Empty() is used, all data will be drawn back to CPU.
+   */
   void Reshard(const GPUDistribution& distribution) const;
   void Reshard(GPUSet devices) const;
   void Resize(size_t new_size, T v = T());

@@ -242,7 +242,7 @@ struct DeviceHistogram {
   }

   /**
-   * \summary Return pointer to histogram memory for a given node.
+   * \summary Return pointer to histogram memory for a given node.
    * \param nidx Tree node index.
    * \return hist pointer.
    */
@@ -372,19 +372,19 @@ struct DeviceShard {

   // TODO(canonizer): do add support multi-batch DMatrix here
   DeviceShard(int device_idx, int normalised_device_idx,
-              bst_uint row_begin, bst_uint row_end, TrainParam param)
-      : device_idx(device_idx),
-        normalised_device_idx(normalised_device_idx),
-        row_begin_idx(row_begin),
-        row_end_idx(row_end),
-        row_stride(0),
-        n_rows(row_end - row_begin),
-        n_bins(0),
-        null_gidx_value(0),
-        param(param),
-        prediction_cache_initialised(false),
-        can_use_smem_atomics(false),
-        tmp_pinned(nullptr) {}
+              bst_uint row_begin, bst_uint row_end, TrainParam param) :
+    device_idx(device_idx),
+    normalised_device_idx(normalised_device_idx),
+    row_begin_idx(row_begin),
+    row_end_idx(row_end),
+    row_stride(0),
+    n_rows(row_end - row_begin),
+    n_bins(0),
+    null_gidx_value(0),
+    param(param),
+    prediction_cache_initialised(false),
+    can_use_smem_atomics(false),
+    tmp_pinned(nullptr) {}

   void InitRowPtrs(const SparsePage& row_batch) {
     dh::safe_cuda(cudaSetDevice(device_idx));
@@ -754,7 +754,9 @@ class GPUHistMaker : public TreeUpdater {
     param_.InitAllowUnknown(args);
     CHECK(param_.n_gpus != 0) << "Must have at least one device";
     n_devices_ = param_.n_gpus;
-    devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
+    dist_ =
+        GPUDistribution::Block(GPUSet::All(param_.n_gpus)
+                               .Normalised(param_.gpu_id));

     dh::CheckComputeCapability();

@@ -769,7 +771,7 @@ class GPUHistMaker : public TreeUpdater {

   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) override {
-    monitor_.Start("Update", devices_);
+    monitor_.Start("Update", dist_.Devices());
     GradStats::CheckInfo(dmat->Info());
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
@@ -785,7 +787,7 @@ class GPUHistMaker : public TreeUpdater {
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
-    monitor_.Stop("Update", devices_);
+    monitor_.Stop("Update", dist_.Devices());
   }

   void InitDataOnce(DMatrix* dmat) {
@@ -801,10 +803,6 @@

     reducer_.Init(device_list_);

-    // Partition input matrix into row segments
-    std::vector<size_t> row_segments;
-    dh::RowSegments(info_->num_row_, n_devices, &row_segments);
-
     dmlc::DataIter<SparsePage>* iter = dmat->RowIterator();
     iter->BeforeFirst();
     CHECK(iter->Next()) << "Empty batches are not supported";
@@ -812,22 +810,24 @@
     // Create device shards
     shards_.resize(n_devices);
     dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
+      size_t start = dist_.ShardStart(info_->num_row_, i);
+      size_t size = dist_.ShardSize(info_->num_row_, i);
       shard = std::unique_ptr<DeviceShard>
-        (new DeviceShard(device_list_[i], i,
-                         row_segments[i], row_segments[i + 1], param_));
+        (new DeviceShard(device_list_.at(i), i,
+                         start, start + size, param_));
       shard->InitRowPtrs(batch);
     });

-    monitor_.Start("Quantiles", devices_);
+    monitor_.Start("Quantiles", dist_.Devices());
     common::DeviceSketch(batch, *info_, param_, &hmat_);
     n_bins_ = hmat_.row_ptr.back();
-    monitor_.Stop("Quantiles", devices_);
+    monitor_.Stop("Quantiles", dist_.Devices());

-    monitor_.Start("BinningCompression", devices_);
+    monitor_.Start("BinningCompression", dist_.Devices());
     dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
-    monitor_.Stop("BinningCompression", devices_);
+    monitor_.Stop("BinningCompression", dist_.Devices());

     CHECK(!iter->Next()) << "External memory not supported";

@@ -837,20 +837,22 @@

   void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
                 const RegTree& tree) {
-    monitor_.Start("InitDataOnce", devices_);
+    monitor_.Start("InitDataOnce", dist_.Devices());
     if (!initialised_) {
       this->InitDataOnce(dmat);
     }
-    monitor_.Stop("InitDataOnce", devices_);
+    monitor_.Stop("InitDataOnce", dist_.Devices());

     column_sampler_.Init(info_->num_col_, param_.colsample_bylevel, param_.colsample_bytree);

     // Copy gpair & reset memory
-    monitor_.Start("InitDataReset", devices_);
+    monitor_.Start("InitDataReset", dist_.Devices());

-    gpair->Reshard(devices_);
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {shard->Reset(gpair); });
-    monitor_.Stop("InitDataReset", devices_);
+    gpair->Reshard(dist_);
+    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+      shard->Reset(gpair);
+    });
+    monitor_.Stop("InitDataReset", dist_.Devices());
   }

   void AllReduceHist(int nidx) {
@@ -1081,12 +1083,12 @@
                  RegTree* p_tree) {
     auto& tree = *p_tree;

-    monitor_.Start("InitData", devices_);
+    monitor_.Start("InitData", dist_.Devices());
     this->InitData(gpair, p_fmat, *p_tree);
-    monitor_.Stop("InitData", devices_);
-    monitor_.Start("InitRoot", devices_);
+    monitor_.Stop("InitData", dist_.Devices());
+    monitor_.Start("InitRoot", dist_.Devices());
     this->InitRoot(p_tree);
-    monitor_.Stop("InitRoot", devices_);
+    monitor_.Stop("InitRoot", dist_.Devices());

     auto timestamp = qexpand_->size();
     auto num_leaves = 1;
@@ -1096,9 +1098,9 @@ class GPUHistMaker : public TreeUpdater {
       qexpand_->pop();
       if (!candidate.IsValid(param_, num_leaves)) continue;
       // std::cout << candidate;
-      monitor_.Start("ApplySplit", devices_);
+      monitor_.Start("ApplySplit", dist_.Devices());
       this->ApplySplit(candidate, p_tree);
-      monitor_.Stop("ApplySplit", devices_);
+      monitor_.Stop("ApplySplit", dist_.Devices());
       num_leaves++;

       auto left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1107,12 +1109,12 @@ class GPUHistMaker : public TreeUpdater {
       // Only create child entries if needed
       if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
                                     num_leaves)) {
-        monitor_.Start("BuildHist", devices_);
+        monitor_.Start("BuildHist", dist_.Devices());
         this->BuildHistLeftRight(candidate.nid, left_child_nidx,
                                  right_child_nidx);
-        monitor_.Stop("BuildHist", devices_);
+        monitor_.Stop("BuildHist", dist_.Devices());

-        monitor_.Start("EvaluateSplits", devices_);
+        monitor_.Start("EvaluateSplits", dist_.Devices());
         auto splits =
             this->EvaluateSplits({left_child_nidx, right_child_nidx}, p_tree);
         qexpand_->push(ExpandEntry(left_child_nidx,
@@ -1121,21 +1123,21 @@ class GPUHistMaker : public TreeUpdater {
         qexpand_->push(ExpandEntry(right_child_nidx,
                                    tree.GetDepth(right_child_nidx), splits[1],
                                    timestamp++));
-        monitor_.Stop("EvaluateSplits", devices_);
+        monitor_.Stop("EvaluateSplits", dist_.Devices());
       }
     }
   }

   bool UpdatePredictionCache(
       const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
-    monitor_.Start("UpdatePredictionCache", devices_);
+    monitor_.Start("UpdatePredictionCache", dist_.Devices());
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
-    p_out_preds->Reshard(devices_);
+    p_out_preds->Reshard(dist_.Devices());
     dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
       shard->UpdatePredictionCache(p_out_preds->DevicePointer(shard->device_idx));
     });
-    monitor_.Stop("UpdatePredictionCache", devices_);
+    monitor_.Stop("UpdatePredictionCache", dist_.Devices());
     return true;
   }

@@ -1208,7 +1210,7 @@ class GPUHistMaker : public TreeUpdater {
   std::vector<int> device_list_;

   DMatrix* p_last_fmat_;
-  GPUSet devices_;
+  GPUDistribution dist_;
 };

 XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist")

@@ -8,6 +8,17 @@
 #include "../../../src/common/timer.h"
 #include "gtest/gtest.h"

+struct Shard { int id; };
+
+TEST(DeviceHelpers, Basic) {
+  std::vector<Shard> shards (4);
+  for (int i = 0; i < 4; ++i) {
+    shards[i].id = i;
+  }
+  int sum = dh::ReduceShards<int>(&shards, [](Shard& s) { return s.id ; });
+  ASSERT_EQ(sum, 6);
+}
+
 void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
                     thrust::host_vector<int> *row_ptr,
                     thrust::host_vector<xgboost::bst_uint> *rows) {
@@ -28,7 +28,7 @@ TEST(gpu_hist_util, TestDeviceSketch) {
   tree::TrainParam p;
   p.max_bin = 20;
   p.gpu_id = 0;
-  p.n_gpus = 1;
+  p.n_gpus = GPUSet::AllVisible().Size();
   // ensure that the exact quantiles are found
   p.gpu_batch_nrows = nrows * 10;

@@ -162,7 +162,7 @@ TEST(HostDeviceVector, TestCopy) {
   std::vector<size_t> starts{0, 501};
   std::vector<size_t> sizes{501, 500};
   SetCudaSetDeviceHandler(SetDevice);

   HostDeviceVector<int> v;
   {
     // a separate scope to ensure that v1 is gone before further checks
@@ -178,6 +178,52 @@ TEST(HostDeviceVector, TestCopy) {
   SetCudaSetDeviceHandler(nullptr);
 }

+// The test is not really useful if n_gpus < 2
+TEST(HostDeviceVector, Reshard) {
+  std::vector<int> h_vec (2345);
+  for (size_t i = 0; i < h_vec.size(); ++i) {
+    h_vec[i] = i;
+  }
+  HostDeviceVector<int> vec (h_vec);
+  auto devices = GPUSet::AllVisible();
+  std::vector<size_t> devices_size (devices.Size());
+
+  // From CPU to GPUs.
+  // Assuming we have > 1 devices.
+  vec.Reshard(devices);
+  size_t total_size = 0;
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    total_size += vec.DeviceSize(i);
+    devices_size[i] = vec.DeviceSize(i);
+  }
+  ASSERT_EQ(total_size, h_vec.size());
+  ASSERT_EQ(total_size, vec.Size());
+  auto h_vec_1 = vec.HostVector();
+
+  ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
+  vec.Reshard(GPUSet::Empty());  // clear out devices memory
+
+  // Shrink down the number of devices.
+  vec.Reshard(GPUSet::Range(0, 1));
+  ASSERT_EQ(vec.Size(), h_vec.size());
+  ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
+  h_vec_1 = vec.HostVector();
+  ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
+  vec.Reshard(GPUSet::Empty());  // clear out devices memory
+
+  // Grow the number of devices.
+  vec.Reshard(devices);
+  total_size = 0;
+  for (size_t i = 0; i < devices.Size(); ++i) {
+    total_size += vec.DeviceSize(i);
+    ASSERT_EQ(devices_size[i], vec.DeviceSize(i));
+  }
+  ASSERT_EQ(total_size, h_vec.size());
+  ASSERT_EQ(total_size, vec.Size());
+  h_vec_1 = vec.HostVector();
+  ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
+}
+
 TEST(HostDeviceVector, Span) {
   HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
   vec.Reshard(GPUSet{0, 1});
