Fix dart inplace prediction with GPU input. (#6777)

* Fix dart inplace predict with data on GPU, which might trigger a fatal check
for device access right.
* Avoid copying data whenever possible.
This commit is contained in:
Jiaming Yuan
2021-03-25 12:00:32 +08:00
committed by GitHub
parent 1d90577800
commit a7083d3c13
6 changed files with 135 additions and 25 deletions

View File

@@ -185,9 +185,15 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
size_t begin, size_t end, SketchContainer *sketch_container,
int num_cuts_per_feature, size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
dh::device_vector<Entry> sorted_entries;
if (page.data.DeviceCanRead()) {
const auto& device_data = page.data.ConstDevicePointer();
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end);
} else {
const auto& host_data = page.data.ConstHostVector();
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin,
host_data.begin() + end);
}
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

View File

@@ -92,7 +92,10 @@ class HostDeviceVectorImpl {
} else {
gpu_access_ = GPUAccess::kWrite;
SetDevice();
thrust::fill(data_d_->begin(), data_d_->end(), v);
auto s_data = dh::ToSpan(*data_d_);
dh::LaunchN(device_, data_d_->size(), [=]XGBOOST_DEVICE(size_t i) {
s_data[i] = v;
});
}
}