This commit is contained in:
Hendrik Groove 2024-10-22 00:58:01 +02:00
parent 55f995fc50
commit 2bbb8b3786

View File

@ -423,9 +423,9 @@ void GPUHistEvaluator::CopyToHost(const std::vector<bst_node_t> &nidx) {
for (auto idx : nidx) { for (auto idx : nidx) {
copy_stream_.View().Wait(event); copy_stream_.View().Wait(event);
dh::safe_cuda(hipMemcpyAsync( dh::safe_cuda(cudaMemcpyAsync(
h_cats.GetNodeCatStorage(idx).data(), d_cats.GetNodeCatStorage(idx).data(), h_cats.GetNodeCatStorage(idx).data(), d_cats.GetNodeCatStorage(idx).data(),
d_cats.GetNodeCatStorage(idx).size_bytes(), hipMemcpyDeviceToHost, copy_stream_.View())); d_cats.GetNodeCatStorage(idx).size_bytes(), cudaMemcpyDeviceToHost, copy_stream_.View()));
} }
} }
@ -507,8 +507,8 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit(Context const *ctx, Evaluat
shared_inputs, dh::ToSpan(out_entries)); shared_inputs, dh::ToSpan(out_entries));
GPUExpandEntry root_entry; GPUExpandEntry root_entry;
dh::safe_cuda(hipMemcpy(&root_entry, out_entries.data().get(), sizeof(GPUExpandEntry), dh::safe_cuda(cudaMemcpyAsync(&root_entry, out_entries.data().get(), sizeof(GPUExpandEntry),
hipMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
return root_entry; return root_entry;
} }
} // namespace xgboost::tree } // namespace xgboost::tree