From be0f346ec907b9b8c76597c2ee308fcd53ba311a Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Fri, 10 May 2019 14:35:06 -0700
Subject: [PATCH] mgpu predictor using explicit offsets (#4438)

* mgpu prediction using explicit sharding
---
 src/common/host_device_vector.h           |  3 ++
 src/predictor/gpu_predictor.cu            | 62 +++++++++++++++++------
 tests/cpp/predictor/test_gpu_predictor.cu | 30 ++++++++++-
 3 files changed, 77 insertions(+), 18 deletions(-)

diff --git a/src/common/host_device_vector.h b/src/common/host_device_vector.h
index db27832b6..eff9f9933 100644
--- a/src/common/host_device_vector.h
+++ b/src/common/host_device_vector.h
@@ -110,6 +110,9 @@ class GPUDistribution {
     return GPUDistribution(devices, granularity, 0, std::vector<size_t>());
   }
 
+  // NOTE(rongou): Explicit offsets don't necessarily cover the whole vector. Sections before the
+  // first shard or after the last shard may be on host only. This windowing is done in the GPU
+  // predictor for external memory support.
   static GPUDistribution Explicit(GPUSet devices, std::vector<size_t> offsets) {
     return GPUDistribution(devices, 1, 0, std::move(offsets));
   }
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 821014a7e..e6114549f 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -221,7 +221,9 @@ class GPUPredictor : public xgboost::Predictor {
   };
 
  private:
-  void DeviceOffsets(const HostDeviceVector<size_t>& data, std::vector<size_t>* out_offsets) {
+  void DeviceOffsets(const HostDeviceVector<size_t>& data,
+                     size_t total_size,
+                     std::vector<size_t>* out_offsets) {
     auto& offsets = *out_offsets;
     offsets.resize(devices_.Size() + 1);
     offsets[0] = 0;
@@ -230,13 +232,35 @@ class GPUPredictor : public xgboost::Predictor {
       int device = devices_.DeviceId(shard);
       auto data_span = data.DeviceSpan(device);
       dh::safe_cuda(cudaSetDevice(device));
-      // copy the last element from every shard
-      dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
-                               &data_span[data_span.size()-1],
-                               sizeof(size_t), cudaMemcpyDeviceToHost));
+      if (data_span.size() == 0) {
+        offsets[shard + 1] = total_size;
+      } else {
+        // copy the last element from every shard
+        dh::safe_cuda(cudaMemcpy(&offsets.at(shard + 1),
+                                 &data_span[data_span.size()-1],
+                                 sizeof(size_t), cudaMemcpyDeviceToHost));
+      }
     }
   }
 
+  // This function populates the explicit offsets that can be used to create a window into the
+  // underlying host vector. The window starts from the `batch_offset` and has a size of
+  // `batch_size`, and is sharded across all the devices. Each shard is granular depending on
+  // the number of output classes `n_classes`.
+  void PredictionDeviceOffsets(size_t total_size, size_t batch_offset, size_t batch_size,
+                               int n_classes, std::vector<size_t>* out_offsets) {
+    auto& offsets = *out_offsets;
+    size_t n_shards = devices_.Size();
+    offsets.resize(n_shards + 2);
+    size_t rows_per_shard = dh::DivRoundUp(batch_size, n_shards);
+    for (size_t shard = 0; shard < devices_.Size(); ++shard) {
+      size_t n_rows = std::min(batch_size, shard * rows_per_shard);
+      offsets[shard] = batch_offset + n_rows * n_classes;
+    }
+    offsets[n_shards] = batch_offset + batch_size * n_classes;
+    offsets[n_shards + 1] = total_size;
+  }
+
   struct DeviceShard {
     DeviceShard() : device_{-1} {}
     void Init(int device) {
@@ -246,11 +270,11 @@ class GPUPredictor : public xgboost::Predictor {
     void PredictInternal
     (const SparsePage& batch, const MetaInfo& info,
      HostDeviceVector<bst_float>* predictions,
-     const size_t batch_offset,
      const gbm::GBTreeModel& model,
      const thrust::host_vector<size_t>& h_tree_segments,
      const thrust::host_vector<DevicePredictionNode>& h_nodes,
      size_t tree_begin, size_t tree_end) {
+      if (predictions->DeviceSize(device_) == 0) { return; }
       dh::safe_cuda(cudaSetDevice(device_));
       nodes_.resize(h_nodes.size());
       dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
@@ -269,8 +293,6 @@ class GPUPredictor : public xgboost::Predictor {
 
       const int BLOCK_THREADS = 128;
       size_t num_rows = batch.offset.DeviceSize(device_) - 1;
-      if (num_rows < 1) { return; }
-
       const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
 
       int shared_memory_bytes = static_cast<int>
@@ -285,8 +307,8 @@ class GPUPredictor : public xgboost::Predictor {
                                                  data_distr.Devices().Index(device_));
 
       PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
-        (dh::ToSpan(nodes_), predictions->DeviceSpan(device_).subspan(batch_offset),
-         dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
+        (dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_),
+         dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
          batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
          num_rows, entry_start, use_shared, model.param.num_output_group);
     }
@@ -324,22 +346,30 @@ class GPUPredictor : public xgboost::Predictor {
                 h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
     }
 
-    size_t i_batch = 0;
     size_t batch_offset = 0;
     for (auto &batch : dmat->GetRowBatches()) {
-      CHECK(i_batch == 0 || devices_.Size() == 1) << "External memory not supported for multi-GPU";
-      // out_preds have been sharded and resized in InitOutPredictions()
+      bool is_external_memory = batch.Size() < dmat->Info().num_row_;
+      if (is_external_memory) {
+        std::vector<size_t> out_preds_offsets;
+        PredictionDeviceOffsets(out_preds->Size(), batch_offset, batch.Size(),
+                                model.param.num_output_group, &out_preds_offsets);
+        out_preds->Reshard(GPUDistribution::Explicit(devices_, out_preds_offsets));
+      }
+
       batch.offset.Shard(GPUDistribution::Overlap(devices_, 1));
       std::vector<size_t> device_offsets;
-      DeviceOffsets(batch.offset, &device_offsets);
+      DeviceOffsets(batch.offset, batch.data.Size(), &device_offsets);
       batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
+
+      // TODO(rongou): only copy the model once for all the batches.
       dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
-        shard.PredictInternal(batch, dmat->Info(), out_preds, batch_offset, model,
+        shard.PredictInternal(batch, dmat->Info(), out_preds, model,
                               h_tree_segments, h_nodes, tree_begin, tree_end);
       });
       batch_offset += batch.Size() * model.param.num_output_group;
-      i_batch++;
     }
 
+    out_preds->Reshard(GPUDistribution::Granular(devices_, model.param.num_output_group));
+
     monitor_.StopCuda("DevicePredictInternal");
   }
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index f992b8509..61722ad9c 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -198,9 +198,7 @@ TEST(gpu_predictor, MGPU_PicklingTest) {
 
   CheckCAPICall(XGBoosterFree(bst2));
 }
-#endif  // defined(XGBOOST_USE_NCCL)
 
-#if defined(XGBOOST_USE_NCCL)
 // multi-GPU predictor test
 TEST(gpu_predictor, MGPU_Test) {
   std::unique_ptr<Predictor> gpu_predictor =
@@ -233,6 +231,34 @@ TEST(gpu_predictor, MGPU_Test) {
     delete dmat;
   }
 }
+
+// multi-GPU predictor external memory test
+TEST(gpu_predictor, MGPU_ExternalMemoryTest) {
+  std::unique_ptr<Predictor> gpu_predictor =
+      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor"));
+  gpu_predictor->Init({std::pair<std::string, std::string>("n_gpus", "-1")}, {});
+
+  gbm::GBTreeModel model = CreateTestModel();
+  const int n_classes = 3;
+  model.param.num_output_group = n_classes;
+  std::vector<std::unique_ptr<DMatrix>> dmats;
+  dmats.push_back(CreateSparsePageDMatrix(9, 64UL));
+  dmats.push_back(CreateSparsePageDMatrix(128, 128UL));
+  dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL));
+
+  for (const auto& dmat: dmats) {
+    // Test predict batch
+    HostDeviceVector<float> out_predictions;
+    gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
+    EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes);
+    const std::vector<float> &host_vector = out_predictions.ConstHostVector();
+    for (int i = 0; i < host_vector.size() / n_classes; i++) {
+      ASSERT_EQ(host_vector[i * n_classes], 1.5);
+      ASSERT_EQ(host_vector[i * n_classes + 1], 0.);
+      ASSERT_EQ(host_vector[i * n_classes + 2], 0.);
+    }
+  }
+}
 #endif  // defined(XGBOOST_USE_NCCL)
 }  // namespace predictor
 }  // namespace xgboost
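
Note on the sharding arithmetic (illustrative, not part of the patch): the window that PredictionDeviceOffsets() builds can be reproduced on the host with plain C++. In the sketch below, the local DivRoundUp helper stands in for dh::DivRoundUp, and the sizes (1024 rows, 3 classes, a 256-row batch starting at row 256, 2 devices) are made-up example inputs.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for dh::DivRoundUp.
std::size_t DivRoundUp(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// Mirrors the loop in PredictionDeviceOffsets(): returns n_shards + 2 offsets,
// i.e. each shard's start inside the batch window, the end of the window, and
// the total size of the underlying prediction vector.
std::vector<std::size_t> PredictionOffsets(std::size_t total_size,
                                           std::size_t batch_offset,
                                           std::size_t batch_size,
                                           int n_classes,
                                           std::size_t n_shards) {
  std::vector<std::size_t> offsets(n_shards + 2);
  std::size_t rows_per_shard = DivRoundUp(batch_size, n_shards);
  for (std::size_t shard = 0; shard < n_shards; ++shard) {
    std::size_t n_rows = std::min(batch_size, shard * rows_per_shard);
    offsets[shard] = batch_offset + n_rows * n_classes;
  }
  offsets[n_shards] = batch_offset + batch_size * n_classes;
  offsets[n_shards + 1] = total_size;
  return offsets;
}

int main() {
  // 1024 total rows * 3 classes; the current external-memory batch covers
  // rows [256, 512) and is split across 2 GPUs.
  for (std::size_t o : PredictionOffsets(/*total_size=*/1024 * 3,
                                         /*batch_offset=*/256 * 3,
                                         /*batch_size=*/256,
                                         /*n_classes=*/3,
                                         /*n_shards=*/2)) {
    std::cout << o << " ";  // prints: 768 1152 1536 3072
  }
  std::cout << std::endl;
  return 0;
}

The first n_shards entries mark where each device's slice of the current batch starts in the output vector, the next entry closes the window, and the last entry pins the total size, so predictions outside the current batch stay on the host, as described in the NOTE added to host_device_vector.h.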