mgpu predictor using explicit offsets (#4438)

* mgpu prediction using explicit sharding
This commit is contained in:
Rong Ou 2019-05-10 14:35:06 -07:00 committed by Rory Mitchell
parent d16d9a9988
commit be0f346ec9
3 changed files with 77 additions and 18 deletions

View File

@ -110,6 +110,9 @@ class GPUDistribution {
return GPUDistribution(devices, granularity, 0, std::vector<size_t>()); return GPUDistribution(devices, granularity, 0, std::vector<size_t>());
} }
// NOTE(rongou): Explicit offsets don't necessarily cover the whole vector. Sections before the
// first shard or after the last shard may be on host only. This windowing is done in the GPU
// predictor for external memory support.
static GPUDistribution Explicit(GPUSet devices, std::vector<size_t> offsets) { static GPUDistribution Explicit(GPUSet devices, std::vector<size_t> offsets) {
return GPUDistribution(devices, 1, 0, std::move(offsets)); return GPUDistribution(devices, 1, 0, std::move(offsets));
} }

View File

@ -221,7 +221,9 @@ class GPUPredictor : public xgboost::Predictor {
}; };
private: private:
void DeviceOffsets(const HostDeviceVector<size_t>& data, std::vector<size_t>* out_offsets) { void DeviceOffsets(const HostDeviceVector<size_t>& data,
size_t total_size,
std::vector<size_t>* out_offsets) {
auto& offsets = *out_offsets; auto& offsets = *out_offsets;
offsets.resize(devices_.Size() + 1); offsets.resize(devices_.Size() + 1);
offsets[0] = 0; offsets[0] = 0;
@ -230,13 +232,35 @@ class GPUPredictor : public xgboost::Predictor {
int device = devices_.DeviceId(shard); int device = devices_.DeviceId(shard);
auto data_span = data.DeviceSpan(device); auto data_span = data.DeviceSpan(device);
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
// copy the last element from every shard if (data_span.size() == 0) {
dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1), offsets[shard + 1] = total_size;
&data_span[data_span.size()-1], } else {
sizeof(size_t), cudaMemcpyDeviceToHost)); // copy the last element from every shard
dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
&data_span[data_span.size()-1],
sizeof(size_t), cudaMemcpyDeviceToHost));
}
} }
} }
// This function populates the explicit offsets that can be used to create a window into the
// underlying host vector. The window starts from the `batch_offset` and has a size of
// `batch_size`, and is sharded across all the devices. Each shard is granular depending on
// the number of output classes `n_classes`.
void PredictionDeviceOffsets(size_t total_size, size_t batch_offset, size_t batch_size,
int n_classes, std::vector<size_t>* out_offsets) {
auto& offsets = *out_offsets;
size_t n_shards = devices_.Size();
offsets.resize(n_shards + 2);
size_t rows_per_shard = dh::DivRoundUp(batch_size, n_shards);
for (size_t shard = 0; shard < devices_.Size(); ++shard) {
size_t n_rows = std::min(batch_size, shard * rows_per_shard);
offsets[shard] = batch_offset + n_rows * n_classes;
}
offsets[n_shards] = batch_offset + batch_size * n_classes;
offsets[n_shards + 1] = total_size;
}
struct DeviceShard { struct DeviceShard {
DeviceShard() : device_{-1} {} DeviceShard() : device_{-1} {}
void Init(int device) { void Init(int device) {
@ -246,11 +270,11 @@ class GPUPredictor : public xgboost::Predictor {
void PredictInternal void PredictInternal
(const SparsePage& batch, const MetaInfo& info, (const SparsePage& batch, const MetaInfo& info,
HostDeviceVector<bst_float>* predictions, HostDeviceVector<bst_float>* predictions,
const size_t batch_offset,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
const thrust::host_vector<size_t>& h_tree_segments, const thrust::host_vector<size_t>& h_tree_segments,
const thrust::host_vector<DevicePredictionNode>& h_nodes, const thrust::host_vector<DevicePredictionNode>& h_nodes,
size_t tree_begin, size_t tree_end) { size_t tree_begin, size_t tree_end) {
if (predictions->DeviceSize(device_) == 0) { return; }
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
nodes_.resize(h_nodes.size()); nodes_.resize(h_nodes.size());
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(), dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
@ -269,8 +293,6 @@ class GPUPredictor : public xgboost::Predictor {
const int BLOCK_THREADS = 128; const int BLOCK_THREADS = 128;
size_t num_rows = batch.offset.DeviceSize(device_) - 1; size_t num_rows = batch.offset.DeviceSize(device_) - 1;
if (num_rows < 1) { return; }
const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS)); const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
int shared_memory_bytes = static_cast<int> int shared_memory_bytes = static_cast<int>
@ -285,8 +307,8 @@ class GPUPredictor : public xgboost::Predictor {
data_distr.Devices().Index(device_)); data_distr.Devices().Index(device_));
PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>> PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
(dh::ToSpan(nodes_), predictions->DeviceSpan(device_).subspan(batch_offset), (dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_),
dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_, batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
num_rows, entry_start, use_shared, model.param.num_output_group); num_rows, entry_start, use_shared, model.param.num_output_group);
} }
@ -324,22 +346,30 @@ class GPUPredictor : public xgboost::Predictor {
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]); h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
} }
size_t i_batch = 0;
size_t batch_offset = 0; size_t batch_offset = 0;
for (auto &batch : dmat->GetRowBatches()) { for (auto &batch : dmat->GetRowBatches()) {
CHECK(i_batch == 0 || devices_.Size() == 1) << "External memory not supported for multi-GPU"; bool is_external_memory = batch.Size() < dmat->Info().num_row_;
// out_preds have been sharded and resized in InitOutPredictions() if (is_external_memory) {
std::vector<size_t> out_preds_offsets;
PredictionDeviceOffsets(out_preds->Size(), batch_offset, batch.Size(),
model.param.num_output_group, &out_preds_offsets);
out_preds->Reshard(GPUDistribution::Explicit(devices_, out_preds_offsets));
}
batch.offset.Shard(GPUDistribution::Overlap(devices_, 1)); batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
std::vector<size_t> device_offsets; std::vector<size_t> device_offsets;
DeviceOffsets(batch.offset, &device_offsets); DeviceOffsets(batch.offset, batch.data.Size(), &device_offsets);
batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets)); batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
// TODO(rongou): only copy the model once for all the batches.
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.PredictInternal(batch, dmat->Info(), out_preds, batch_offset, model, shard.PredictInternal(batch, dmat->Info(), out_preds, model,
h_tree_segments, h_nodes, tree_begin, tree_end); h_tree_segments, h_nodes, tree_begin, tree_end);
}); });
batch_offset += batch.Size() * model.param.num_output_group; batch_offset += batch.Size() * model.param.num_output_group;
i_batch++;
} }
out_preds->Reshard(GPUDistribution::Granular(devices_, model.param.num_output_group));
monitor_.StopCuda("DevicePredictInternal"); monitor_.StopCuda("DevicePredictInternal");
} }

View File

@ -198,9 +198,7 @@ TEST(gpu_predictor, MGPU_PicklingTest) {
CheckCAPICall(XGBoosterFree(bst2)); CheckCAPICall(XGBoosterFree(bst2));
} }
#endif // defined(XGBOOST_USE_NCCL)
#if defined(XGBOOST_USE_NCCL)
// multi-GPU predictor test // multi-GPU predictor test
TEST(gpu_predictor, MGPU_Test) { TEST(gpu_predictor, MGPU_Test) {
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor =
@ -233,6 +231,34 @@ TEST(gpu_predictor, MGPU_Test) {
delete dmat; delete dmat;
} }
} }
// multi-GPU predictor external memory test
TEST(gpu_predictor, MGPU_ExternalMemoryTest) {
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
gpu_predictor->Init({std::pair<std::string, std::string>("n_gpus", "-1")}, {});
gbm::GBTreeModel model = CreateTestModel();
const int n_classes = 3;
model.param.num_output_group = n_classes;
std::vector<std::unique_ptr<DMatrix>> dmats;
dmats.push_back(CreateSparsePageDMatrix(9, 64UL));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL));
for (const auto& dmat: dmats) {
// Test predict batch
HostDeviceVector<float> out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes);
const std::vector<float> &host_vector = out_predictions.ConstHostVector();
for (int i = 0; i < host_vector.size() / n_classes; i++) {
ASSERT_EQ(host_vector[i * n_classes], 1.5);
ASSERT_EQ(host_vector[i * n_classes + 1], 0.);
ASSERT_EQ(host_vector[i * n_classes + 2], 0.);
}
}
}
#endif // defined(XGBOOST_USE_NCCL) #endif // defined(XGBOOST_USE_NCCL)
} // namespace predictor } // namespace predictor
} // namespace xgboost } // namespace xgboost