Remove various synchronisations from cuda API calls, instrument monitor (#4205)

* Remove various synchronisations from cuda API calls, instrument monitor
with nvtx profiler ranges.
This commit is contained in:
Rory Mitchell
2019-03-10 15:01:23 +13:00
committed by GitHub
parent f83e62dca5
commit 4eeeded7d1
9 changed files with 116 additions and 104 deletions

View File

@@ -252,17 +252,17 @@ class GPUPredictor : public xgboost::Predictor {
size_t tree_begin, size_t tree_end) {
dh::safe_cuda(cudaSetDevice(device_));
nodes.resize(h_nodes.size());
dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes), h_nodes.data(),
sizeof(DevicePredictionNode) * h_nodes.size(),
cudaMemcpyHostToDevice));
tree_segments.resize(h_tree_segments.size());
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments), h_tree_segments.data(),
sizeof(size_t) * h_tree_segments.size(),
cudaMemcpyHostToDevice));
tree_group.resize(model.tree_info.size());
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group), model.tree_info.data(),
sizeof(int) * model.tree_info.size(),
cudaMemcpyHostToDevice));
@@ -288,9 +288,6 @@ class GPUPredictor : public xgboost::Predictor {
dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
num_rows, entry_start, use_shared, model.param.num_output_group);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
int device_;