Use default allocator in sketching. (#6182)

Jiaming Yuan 2020-09-30 14:55:59 +08:00 committed by GitHub
parent 444131a2e6
commit f0c63902ff
5 changed files with 33 additions and 53 deletions
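Summary of the change: the large temporary buffers used by GPU quantile sketching switch from the cub-backed caching allocator (dh::caching_device_vector, dh::XGBCachingDeviceAllocator) to the default device allocator (dh::device_vector, dh::XGBDeviceAllocator), keeping the caching allocator for smaller temporaries (e.g. column_sizes_scan) and for some Thrust scratch allocations. For orientation, the two vector types look roughly like the sketch below; the allocator stand-ins are assumptions, the real definitions live in XGBoost's device helpers.

#include <thrust/device_vector.h>
#include <thrust/device_malloc_allocator.h>

namespace dh {
// Stand-ins (assumed, simplified): the real allocators add memory logging, and
// the caching variant serves allocations from a cub::CachingDeviceAllocator pool.
template <typename T>
using XGBDeviceAllocator = thrust::device_malloc_allocator<T>;         // plain cudaMalloc/cudaFree
template <typename T>
using XGBCachingDeviceAllocator = thrust::device_malloc_allocator<T>;  // imagine a cub pool behind it

// The two vector types this diff moves between:
template <typename T>
using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
template <typename T>
using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;
}  // namespace dh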


@@ -430,38 +430,18 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
-    pointer ptr;
-    if (use_cub_allocator_) {
-      T* raw_ptr{nullptr};
-      GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&raw_ptr), n * sizeof(T));
-      ptr = pointer(raw_ptr);
-    } else {
-      ptr = SuperT::allocate(n);
-    }
-    GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
-    return ptr;
+    T* ptr;
+    GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
+                                               n * sizeof(T));
+    pointer thrust_ptr{ ptr };
+    GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
+    return thrust_ptr;
   }
   void deallocate(pointer ptr, size_t n) {  // NOLINT
     GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
-    if (use_cub_allocator_) {
-      GetGlobalCachingAllocator().DeviceFree(ptr.get());
-    } else {
-      SuperT::deallocate(ptr, n);
-    }
+    GetGlobalCachingAllocator().DeviceFree(ptr.get());
   }
-#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
-  XGBCachingDeviceAllocatorImpl()
-      : SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{nullptr}) {
-    std::string symbol{typeid(*SuperT::resource()).name()};
-    if (symbol.find("pool_memory_resource") != std::string::npos
-        || symbol.find("binning_memory_resource") != std::string::npos
-        || symbol.find("arena_memory_resource") != std::string::npos) {
-      use_cub_allocator_ = false;
-    }
-  }
-#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
- private:
-  bool use_cub_allocator_{true};
   XGBOOST_DEVICE void construct(T *) {}  // NOLINT
 };
 }  // namespace detail
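
The new allocate()/deallocate() above route every request through a process-wide cub caching pool, with the RMM-conditional fallback removed. A minimal sketch of that wrapper pattern, assuming only cub and Thrust (not XGBoost's actual code; memory logging and error checking omitted):

#include <thrust/device_malloc_allocator.h>
#include <cub/util_allocator.cuh>
#include <cstddef>

// Thrust-compatible allocator whose storage comes from one shared
// cub::CachingDeviceAllocator, mirroring the allocate/deallocate pair above.
template <typename T>
struct CachingAllocatorSketch : thrust::device_malloc_allocator<T> {
  using Super   = thrust::device_malloc_allocator<T>;
  using pointer = typename Super::pointer;  // thrust::device_ptr<T>

  static cub::CachingDeviceAllocator& Global() {
    // One pool shared by every instance; intentionally never destroyed.
    static auto* pool = new cub::CachingDeviceAllocator;
    return *pool;
  }

  pointer allocate(std::size_t n) {            // NOLINT
    T* raw{nullptr};
    Global().DeviceAllocate(reinterpret_cast<void**>(&raw), n * sizeof(T));
    return pointer{raw};
  }

  void deallocate(pointer ptr, std::size_t) {  // NOLINT
    Global().DeviceFree(ptr.get());
  }
};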


@@ -106,16 +106,17 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
   return sketch_batch_num_elements;
 }
-void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
-                  dh::caching_device_vector<float>* weights,
-                  dh::caching_device_vector<Entry>* sorted_entries) {
+void SortByWeight(dh::device_vector<float>* weights,
+                  dh::device_vector<Entry>* sorted_entries) {
   // Sort both entries and weights.
-  thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
+  dh::XGBDeviceAllocator<char> alloc;
+  thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
                       sorted_entries->end(), weights->begin(),
                       detail::EntryCompareOp());
   // Scan weights
-  thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
+  dh::XGBCachingDeviceAllocator<char> caching;
+  thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
                                 sorted_entries->begin(), sorted_entries->end(),
                                 weights->begin(), weights->begin(),
                                 [=] __device__(const Entry& a, const Entry& b) {

@@ -216,11 +217,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
                           bool is_ranking, Span<bst_group_t const> d_group_ptr) {
   dh::XGBCachingDeviceAllocator<char> alloc;
   const auto& host_data = page.data.ConstHostVector();
-  dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
-                                                  host_data.begin() + end);
+  dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
+                                          host_data.begin() + end);
   // Binary search to assign weights to each element
-  dh::caching_device_vector<float> temp_weights(sorted_entries.size());
+  dh::device_vector<float> temp_weights(sorted_entries.size());
   auto d_temp_weights = temp_weights.data().get();
   page.offset.SetDevice(device);
   auto row_ptrs = page.offset.ConstDeviceSpan();

@@ -243,7 +244,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
       d_temp_weights[idx] = weights[ridx + base_rowid];
     });
   }
-  detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
+  detail::SortByWeight(&temp_weights, &sorted_entries);
   HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
   dh::caching_device_vector<size_t> column_sizes_scan;
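
SortByWeight now creates its execution policies locally instead of taking an allocator parameter: the sort's temporary storage comes from the default allocator, while the scan keeps the caching one. Below is a stand-alone sketch of the thrust::cuda::par(alloc) pattern; the allocator and element types are illustrative stand-ins for dh::XGBDeviceAllocator<char> / dh::XGBCachingDeviceAllocator<char> and Entry.

#include <thrust/device_vector.h>
#include <thrust/device_malloc_allocator.h>
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/scan.h>

// Sort values by key, then prefix-sum values within each run of equal keys.
// thrust::cuda::par(alloc) only controls where Thrust's *temporary* buffers
// come from; the vectors themselves keep their own allocators.
void SortAndScanByKey(thrust::device_vector<int>* keys,
                      thrust::device_vector<float>* values) {
  thrust::device_malloc_allocator<char> alloc;    // stand-in default allocator
  thrust::sort_by_key(thrust::cuda::par(alloc), keys->begin(), keys->end(),
                      values->begin());

  thrust::device_malloc_allocator<char> caching;  // stand-in caching allocator
  thrust::inclusive_scan_by_key(thrust::cuda::par(caching), keys->begin(),
                                keys->end(), values->begin(), values->begin());
}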


@@ -102,7 +102,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
                             size_t columns, size_t cuts_per_feature, int device,
                             HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
                             dh::caching_device_vector<size_t>* column_sizes_scan,
-                            dh::caching_device_vector<Entry>* sorted_entries) {
+                            dh::device_vector<Entry>* sorted_entries) {
   auto entry_iter = dh::MakeTransformIterator<Entry>(
       thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
         return Entry(batch.GetElement(idx).column_idx,

@@ -123,9 +123,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
                   entry_iter + range.end(), sorted_entries->begin(), is_valid);
 }
-void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
-                  dh::caching_device_vector<float>* weights,
-                  dh::caching_device_vector<Entry>* sorted_entries);
+void SortByWeight(dh::device_vector<float>* weights,
+                  dh::device_vector<Entry>* sorted_entries);
 }  // namespace detail
 // Compute sketch on DMatrix.

@@ -138,7 +137,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
                           size_t begin, size_t end, float missing,
                           SketchContainer* sketch_container, int num_cuts) {
   // Copy current subset of valid elements into temporary storage and sort
-  dh::caching_device_vector<Entry> sorted_entries;
+  dh::device_vector<Entry> sorted_entries;
   dh::caching_device_vector<size_t> column_sizes_scan;
   auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
       thrust::make_counting_iterator(0llu),

@@ -149,7 +148,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
                                  &cuts_ptr,
                                  &column_sizes_scan,
                                  &sorted_entries);
-  dh::XGBCachingDeviceAllocator<char> alloc;
+  dh::XGBDeviceAllocator<char> alloc;
   thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
                sorted_entries.end(), detail::EntryCompareOp());

@@ -179,7 +178,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
   auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
       thrust::make_counting_iterator(0llu),
       [=] __device__(size_t idx) { return batch.GetElement(idx); });
-  dh::caching_device_vector<Entry> sorted_entries;
+  dh::device_vector<Entry> sorted_entries;
   dh::caching_device_vector<size_t> column_sizes_scan;
   HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
   detail::MakeEntriesFromAdapter(batch, batch_iter,

@@ -190,7 +189,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
                                  &sorted_entries);
   data::IsValidFunctor is_valid(missing);
-  dh::caching_device_vector<float> temp_weights(sorted_entries.size());
+  dh::device_vector<float> temp_weights(sorted_entries.size());
   auto d_temp_weights = dh::ToSpan(temp_weights);
   if (is_ranking) {

@@ -222,7 +221,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
     CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
   }
-  detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
+  detail::SortByWeight(&temp_weights, &sorted_entries);
   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
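
For reference, MakeEntriesFromAdapter (whose signature changes above) never materializes the raw batch: a counting iterator wrapped in a transform iterator generates elements on the fly, and copy_if keeps only the valid ones. A self-contained sketch of that iterator pattern with illustrative names, using NaN as the stand-in for missing values:

#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/copy.h>
#include <cstddef>

// Produces the idx-th element on demand instead of storing it.
struct MakeElement {
  const float* data;
  __host__ __device__ float operator()(std::size_t idx) const { return data[idx]; }
};

// Missing values are encoded as NaN; NaN != NaN filters them out.
struct IsValid {
  __host__ __device__ bool operator()(float v) const { return v == v; }
};

void CopyValid(const thrust::device_vector<float>& raw,
               thrust::device_vector<float>* out) {
  auto it = thrust::make_transform_iterator(
      thrust::make_counting_iterator<std::size_t>(0),
      MakeElement{raw.data().get()});
  out->resize(raw.size());
  auto end = thrust::copy_if(it, it + raw.size(), out->begin(), IsValid{});
  out->resize(end - out->begin());  // shrink to the number of valid elements
}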


@@ -310,7 +310,7 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
                            common::Span<OffsetT const> cuts_ptr,
                            size_t total_cuts, Span<float> weights) {
   Span<SketchEntry> out;
-  dh::caching_device_vector<SketchEntry> cuts;
+  dh::device_vector<SketchEntry> cuts;
   bool first_window = this->Current().empty();
   if (!first_window) {
     cuts.resize(total_cuts);


@@ -36,31 +36,31 @@ class SketchContainer {
   int32_t device_;
   // Double buffer as neither prune nor merge can be performed inplace.
-  dh::caching_device_vector<SketchEntry> entries_a_;
-  dh::caching_device_vector<SketchEntry> entries_b_;
+  dh::device_vector<SketchEntry> entries_a_;
+  dh::device_vector<SketchEntry> entries_b_;
   bool current_buffer_ {true};
   // The container is just a CSC matrix.
   HostDeviceVector<OffsetT> columns_ptr_;
   HostDeviceVector<OffsetT> columns_ptr_b_;
-  dh::caching_device_vector<SketchEntry>& Current() {
+  dh::device_vector<SketchEntry>& Current() {
     if (current_buffer_) {
       return entries_a_;
     } else {
       return entries_b_;
     }
   }
-  dh::caching_device_vector<SketchEntry>& Other() {
+  dh::device_vector<SketchEntry>& Other() {
     if (!current_buffer_) {
       return entries_a_;
     } else {
       return entries_b_;
     }
   }
-  dh::caching_device_vector<SketchEntry> const& Current() const {
+  dh::device_vector<SketchEntry> const& Current() const {
     return const_cast<SketchContainer*>(this)->Current();
   }
-  dh::caching_device_vector<SketchEntry> const& Other() const {
+  dh::device_vector<SketchEntry> const& Other() const {
     return const_cast<SketchContainer*>(this)->Other();
   }
   void Alternate() {
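
Current(), Other() and Alternate() above implement a plain double buffer: prune and merge read from one buffer, write into the other, and then the roles swap. The same idea in stand-alone form (std::vector used purely for illustration):

#include <vector>

template <typename T>
class DoubleBuffer {
  std::vector<T> a_, b_;
  bool current_is_a_{true};

 public:
  // Buffer holding the live data.
  std::vector<T>& Current() { return current_is_a_ ? a_ : b_; }
  // Scratch buffer that the next out-of-place operation writes into.
  std::vector<T>& Other() { return current_is_a_ ? b_ : a_; }
  // Swap roles once Other() holds the new result.
  void Alternate() { current_is_a_ = !current_is_a_; }
};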