diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index e6114549f..40c9c5374 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -263,34 +263,39 @@ class GPUPredictor : public xgboost::Predictor {
   struct DeviceShard {
     DeviceShard() : device_{-1} {}
+
     void Init(int device) {
       this->device_ = device;
       max_shared_memory_bytes_ = dh::MaxSharedMemory(this->device_);
     }
-    void PredictInternal
-    (const SparsePage& batch, const MetaInfo& info,
-     HostDeviceVector<bst_float>* predictions,
-     const gbm::GBTreeModel& model,
+
+    void InitModel(const gbm::GBTreeModel& model,
      const thrust::host_vector<size_t>& h_tree_segments,
      const thrust::host_vector<DevicePredictionNode>& h_nodes,
      size_t tree_begin, size_t tree_end) {
-      if (predictions->DeviceSize(device_) == 0) { return; }
       dh::safe_cuda(cudaSetDevice(device_));
       nodes_.resize(h_nodes.size());
       dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
                                     sizeof(DevicePredictionNode) * h_nodes.size(),
                                     cudaMemcpyHostToDevice));
       tree_segments_.resize(h_tree_segments.size());
       dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments_), h_tree_segments.data(),
                                     sizeof(size_t) * h_tree_segments.size(),
                                     cudaMemcpyHostToDevice));
       tree_group_.resize(model.tree_info.size());
       dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group_), model.tree_info.data(),
                                     sizeof(int) * model.tree_info.size(),
                                     cudaMemcpyHostToDevice));
+      this->tree_begin_ = tree_begin;
+      this->tree_end_ = tree_end;
+      this->num_group_ = model.param.num_output_group;
+    }
+
+    void PredictInternal
+    (const SparsePage& batch, const MetaInfo& info,
+     HostDeviceVector<bst_float>* predictions) {
+      if (predictions->DeviceSize(device_) == 0) { return; }
+      dh::safe_cuda(cudaSetDevice(device_));
       const int BLOCK_THREADS = 128;
       size_t num_rows = batch.offset.DeviceSize(device_) - 1;
       const int GRID_SIZE = static_cast<int>(dh::DivRoundUp(num_rows, BLOCK_THREADS));
@@ -309,8 +314,8 @@ class GPUPredictor : public xgboost::Predictor {
       PredictKernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS, shared_memory_bytes>>>
         (dh::ToSpan(nodes_), predictions->DeviceSpan(device_), dh::ToSpan(tree_segments_),
          dh::ToSpan(tree_group_), batch.offset.DeviceSpan(device_),
-         batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
-         num_rows, entry_start, use_shared, model.param.num_output_group);
+         batch.data.DeviceSpan(device_), this->tree_begin_, this->tree_end_, info.num_col_,
+         num_rows, entry_start, use_shared, this->num_group_);
     }
 
    private:
@@ -319,15 +324,12 @@ class GPUPredictor : public xgboost::Predictor {
     thrust::device_vector<size_t> tree_segments_;
     thrust::device_vector<int> tree_group_;
     size_t max_shared_memory_bytes_;
+    size_t tree_begin_;
+    size_t tree_end_;
+    int num_group_;
   };
 
-  void DevicePredictInternal(DMatrix* dmat,
-                             HostDeviceVector<bst_float>* out_preds,
-                             const gbm::GBTreeModel& model, size_t tree_begin,
-                             size_t tree_end) {
-    if (tree_end - tree_begin == 0) { return; }
-    monitor_.StartCuda("DevicePredictInternal");
-
+  void InitModel(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
     CHECK_EQ(model.param.size_leaf_vector, 0);
     // Copy decision trees to device
     thrust::host_vector<size_t> h_tree_segments;
@@ -345,6 +347,19 @@ class GPUPredictor : public xgboost::Predictor {
       std::copy(src_nodes.begin(), src_nodes.end(),
                 h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
     }
+    dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard &shard) {
+      shard.InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
+    });
+  }
+
+  void DevicePredictInternal(DMatrix* dmat,
+                             HostDeviceVector<bst_float>* out_preds,
+                             const gbm::GBTreeModel& model, size_t tree_begin,
+                             size_t tree_end) {
+    if (tree_end - tree_begin == 0) { return; }
+    monitor_.StartCuda("DevicePredictInternal");
+
+    InitModel(model, tree_begin, tree_end);
 
     size_t batch_offset = 0;
     for (auto &batch : dmat->GetRowBatches()) {
@@ -361,10 +376,8 @@ class GPUPredictor : public xgboost::Predictor {
       DeviceOffsets(batch.offset, batch.data.Size(), &device_offsets);
       batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
 
-      // TODO(rongou): only copy the model once for all the batches.
       dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
-        shard.PredictInternal(batch, dmat->Info(), out_preds, model,
-                              h_tree_segments, h_nodes, tree_begin, tree_end);
+        shard.PredictInternal(batch, dmat->Info(), out_preds);
       });
       batch_offset += batch.Size() * model.param.num_output_group;
     }
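Note (not part of the patch): the diff splits the old DeviceShard::PredictInternal into InitModel, which copies the tree nodes, tree segments, and tree-group info to each device shard once per prediction call, and a slimmer PredictInternal that only launches the prediction kernel for each SparsePage batch. This removes the per-batch host-to-device copy of the model flagged by the deleted TODO(rongou) comment. The standalone sketch below illustrates the same hoist-the-model-copy-out-of-the-batch-loop pattern in plain CUDA; the toy Model layout, ToyPredictKernel signature, and batch contents are invented for illustration and are unrelated to the real XGBoost types.

```cuda
// predict_once.cu -- hypothetical standalone sketch, not XGBoost code.
// Upload the model once, then reuse it for every batch instead of
// re-copying it inside the batch loop.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Toy "model": one weight per feature; prediction = dot(weights, row).
__global__ void ToyPredictKernel(const float* weights, int num_features,
                                 const float* rows, float* out, int num_rows) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num_rows) { return; }
  float sum = 0.0f;
  for (int j = 0; j < num_features; ++j) {
    sum += weights[j] * rows[i * num_features + j];
  }
  out[i] = sum;
}

int main() {
  const int kFeatures = 3;
  std::vector<float> h_weights = {0.5f, -1.0f, 2.0f};

  // "InitModel": a single host-to-device copy, hoisted out of the batch loop.
  float* d_weights = nullptr;
  cudaMalloc(&d_weights, kFeatures * sizeof(float));
  cudaMemcpy(d_weights, h_weights.data(), kFeatures * sizeof(float),
             cudaMemcpyHostToDevice);

  // "PredictInternal": each batch only moves its own rows and launches the kernel.
  for (int batch = 0; batch < 2; ++batch) {
    const int kRows = 2;
    std::vector<float> h_rows = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // 2 rows x 3 features
    float *d_rows = nullptr, *d_out = nullptr;
    cudaMalloc(&d_rows, h_rows.size() * sizeof(float));
    cudaMalloc(&d_out, kRows * sizeof(float));
    cudaMemcpy(d_rows, h_rows.data(), h_rows.size() * sizeof(float),
               cudaMemcpyHostToDevice);

    const int kBlockThreads = 128;
    const int kGridSize = (kRows + kBlockThreads - 1) / kBlockThreads;
    ToyPredictKernel<<<kGridSize, kBlockThreads>>>(d_weights, kFeatures,
                                                   d_rows, d_out, kRows);

    std::vector<float> h_out(kRows);
    cudaMemcpy(h_out.data(), d_out, kRows * sizeof(float), cudaMemcpyDeviceToHost);
    printf("batch %d: %f %f\n", batch, h_out[0], h_out[1]);
    cudaFree(d_rows);
    cudaFree(d_out);
  }
  cudaFree(d_weights);
  return 0;
}
```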