Fix dart inplace prediction with GPU input. (#6777)

* Fix DART in-place prediction when the input data is on the GPU, which could
trigger a fatal check on device access rights.
* Avoid copying data whenever possible.
This commit is contained in:
Jiaming Yuan
2021-03-25 12:00:32 +08:00
committed by GitHub
parent 1d90577800
commit a7083d3c13
6 changed files with 135 additions and 25 deletions

View File

@@ -407,7 +407,6 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t gpu_batch_nrows =
std::min(dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
static_cast<size_t>(row_batch.Size()));
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);
@@ -429,9 +428,18 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
if (row_batch.data.DeviceCanRead()) {
auto const& d_data = row_batch.data.ConstDeviceSpan();
dh::safe_cuda(cudaMemcpyAsync(
entries_d.data().get(), d_data.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
} else {
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
dh::safe_cuda(cudaMemcpyAsync(
entries_d.data().get(), data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
}
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1);