Fix dart inplace prediction with GPU input. (#6777)
* Fix dart inplace predict with data on GPU, which might trigger a fatal check for device access right. * Avoid copying data whenever possible.
This commit is contained in:
@@ -185,9 +185,15 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
|
||||
size_t begin, size_t end, SketchContainer *sketch_container,
|
||||
int num_cuts_per_feature, size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
dh::device_vector<Entry> sorted_entries;
|
||||
if (page.data.DeviceCanRead()) {
|
||||
const auto& device_data = page.data.ConstDevicePointer();
|
||||
sorted_entries = dh::device_vector<Entry>(device_data + begin, device_data + end);
|
||||
} else {
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
sorted_entries = dh::device_vector<Entry>(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
}
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
|
||||
@@ -92,7 +92,10 @@ class HostDeviceVectorImpl {
|
||||
} else {
|
||||
gpu_access_ = GPUAccess::kWrite;
|
||||
SetDevice();
|
||||
thrust::fill(data_d_->begin(), data_d_->end(), v);
|
||||
auto s_data = dh::ToSpan(*data_d_);
|
||||
dh::LaunchN(device_, data_d_->size(), [=]XGBOOST_DEVICE(size_t i) {
|
||||
s_data[i] = v;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user