Remove various synchronisations from cuda API calls, instrument monitor (#4205)
* Remove various synchronisations from cuda API calls, instrument monitor with nvtx profiler ranges.
This commit is contained in:
@@ -252,17 +252,17 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
nodes.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes), h_nodes.data(),
|
||||
sizeof(DevicePredictionNode) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments.resize(h_tree_segments.size());
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group.resize(model.tree_info.size());
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
|
||||
@@ -288,9 +288,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
|
||||
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
|
||||
num_rows, entry_start, use_shared, model.param.num_output_group);
|
||||
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
int device_;
|
||||
|
||||
Reference in New Issue
Block a user