Fix dart inplace prediction with GPU input. (#6777)
* Fix dart inplace prediction when the input data resides on the GPU, which could trigger a fatal check on device access rights.
* Avoid copying data whenever possible.
This commit is contained in:
@@ -407,7 +407,6 @@ void EllpackPageImpl::CreateHistIndices(int device,
|
||||
size_t gpu_batch_nrows =
|
||||
std::min(dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
|
||||
static_cast<size_t>(row_batch.Size()));
|
||||
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
|
||||
|
||||
size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);
|
||||
|
||||
@@ -429,9 +428,18 @@ void EllpackPageImpl::CreateHistIndices(int device,
|
||||
size_t n_entries = ent_cnt_end - ent_cnt_begin;
|
||||
dh::device_vector<Entry> entries_d(n_entries);
|
||||
// copy data entries to device.
|
||||
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
|
||||
data_vec.data() + ent_cnt_begin,
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
if (row_batch.data.DeviceCanRead()) {
|
||||
auto const& d_data = row_batch.data.ConstDeviceSpan();
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
entries_d.data().get(), d_data.data() + ent_cnt_begin,
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
} else {
|
||||
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
entries_d.data().get(), data_vec.data() + ent_cnt_begin,
|
||||
n_entries * sizeof(Entry), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
const dim3 block3(32, 8, 1); // 256 threads
|
||||
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
|
||||
common::DivRoundUp(row_stride, block3.y), 1);
|
||||
|
||||
Reference in New Issue
Block a user