Use default allocator in sketching. (#6182)
This commit is contained in:
parent
444131a2e6
commit
f0c63902ff
@ -430,38 +430,18 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
|||||||
return *allocator;
|
return *allocator;
|
||||||
}
|
}
|
||||||
pointer allocate(size_t n) { // NOLINT
|
pointer allocate(size_t n) { // NOLINT
|
||||||
pointer ptr;
|
T* ptr;
|
||||||
if (use_cub_allocator_) {
|
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
||||||
T* raw_ptr{nullptr};
|
n * sizeof(T));
|
||||||
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&raw_ptr), n * sizeof(T));
|
pointer thrust_ptr{ ptr };
|
||||||
ptr = pointer(raw_ptr);
|
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
|
||||||
} else {
|
return thrust_ptr;
|
||||||
ptr = SuperT::allocate(n);
|
|
||||||
}
|
|
||||||
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
|
|
||||||
return ptr;
|
|
||||||
}
|
}
|
||||||
void deallocate(pointer ptr, size_t n) { // NOLINT
|
void deallocate(pointer ptr, size_t n) { // NOLINT
|
||||||
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
|
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
|
||||||
if (use_cub_allocator_) {
|
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
||||||
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
|
||||||
} else {
|
|
||||||
SuperT::deallocate(ptr, n);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
XGBOOST_DEVICE void construct(T *) {} // NOLINT
|
||||||
XGBCachingDeviceAllocatorImpl()
|
|
||||||
: SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{nullptr}) {
|
|
||||||
std::string symbol{typeid(*SuperT::resource()).name()};
|
|
||||||
if (symbol.find("pool_memory_resource") != std::string::npos
|
|
||||||
|| symbol.find("binning_memory_resource") != std::string::npos
|
|
||||||
|| symbol.find("arena_memory_resource") != std::string::npos) {
|
|
||||||
use_cub_allocator_ = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
|
||||||
private:
|
|
||||||
bool use_cub_allocator_{true};
|
|
||||||
};
|
};
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
|
|||||||
@ -106,16 +106,17 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
|||||||
return sketch_batch_num_elements;
|
return sketch_batch_num_elements;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
void SortByWeight(dh::device_vector<float>* weights,
|
||||||
dh::caching_device_vector<float>* weights,
|
dh::device_vector<Entry>* sorted_entries) {
|
||||||
dh::caching_device_vector<Entry>* sorted_entries) {
|
|
||||||
// Sort both entries and weights.
|
// Sort both entries and weights.
|
||||||
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
|
dh::XGBDeviceAllocator<char> alloc;
|
||||||
|
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
|
||||||
sorted_entries->end(), weights->begin(),
|
sorted_entries->end(), weights->begin(),
|
||||||
detail::EntryCompareOp());
|
detail::EntryCompareOp());
|
||||||
|
|
||||||
// Scan weights
|
// Scan weights
|
||||||
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
|
dh::XGBCachingDeviceAllocator<char> caching;
|
||||||
|
thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
|
||||||
sorted_entries->begin(), sorted_entries->end(),
|
sorted_entries->begin(), sorted_entries->end(),
|
||||||
weights->begin(), weights->begin(),
|
weights->begin(), weights->begin(),
|
||||||
[=] __device__(const Entry& a, const Entry& b) {
|
[=] __device__(const Entry& a, const Entry& b) {
|
||||||
@ -216,11 +217,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
|||||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
const auto& host_data = page.data.ConstHostVector();
|
const auto& host_data = page.data.ConstHostVector();
|
||||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||||
host_data.begin() + end);
|
host_data.begin() + end);
|
||||||
|
|
||||||
// Binary search to assign weights to each element
|
// Binary search to assign weights to each element
|
||||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||||
auto d_temp_weights = temp_weights.data().get();
|
auto d_temp_weights = temp_weights.data().get();
|
||||||
page.offset.SetDevice(device);
|
page.offset.SetDevice(device);
|
||||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||||
@ -243,7 +244,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
|||||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||||
|
|
||||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||||
|
|||||||
@ -102,7 +102,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
|
|||||||
size_t columns, size_t cuts_per_feature, int device,
|
size_t columns, size_t cuts_per_feature, int device,
|
||||||
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
|
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
|
||||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||||
dh::caching_device_vector<Entry>* sorted_entries) {
|
dh::device_vector<Entry>* sorted_entries) {
|
||||||
auto entry_iter = dh::MakeTransformIterator<Entry>(
|
auto entry_iter = dh::MakeTransformIterator<Entry>(
|
||||||
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||||
return Entry(batch.GetElement(idx).column_idx,
|
return Entry(batch.GetElement(idx).column_idx,
|
||||||
@ -123,9 +123,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
|
|||||||
entry_iter + range.end(), sorted_entries->begin(), is_valid);
|
entry_iter + range.end(), sorted_entries->begin(), is_valid);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
void SortByWeight(dh::device_vector<float>* weights,
|
||||||
dh::caching_device_vector<float>* weights,
|
dh::device_vector<Entry>* sorted_entries);
|
||||||
dh::caching_device_vector<Entry>* sorted_entries);
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
// Compute sketch on DMatrix.
|
// Compute sketch on DMatrix.
|
||||||
@ -138,7 +137,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
|
|||||||
size_t begin, size_t end, float missing,
|
size_t begin, size_t end, float missing,
|
||||||
SketchContainer* sketch_container, int num_cuts) {
|
SketchContainer* sketch_container, int num_cuts) {
|
||||||
// Copy current subset of valid elements into temporary storage and sort
|
// Copy current subset of valid elements into temporary storage and sort
|
||||||
dh::caching_device_vector<Entry> sorted_entries;
|
dh::device_vector<Entry> sorted_entries;
|
||||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||||
thrust::make_counting_iterator(0llu),
|
thrust::make_counting_iterator(0llu),
|
||||||
@ -149,7 +148,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
|
|||||||
&cuts_ptr,
|
&cuts_ptr,
|
||||||
&column_sizes_scan,
|
&column_sizes_scan,
|
||||||
&sorted_entries);
|
&sorted_entries);
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBDeviceAllocator<char> alloc;
|
||||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||||
sorted_entries.end(), detail::EntryCompareOp());
|
sorted_entries.end(), detail::EntryCompareOp());
|
||||||
|
|
||||||
@ -179,7 +178,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||||
thrust::make_counting_iterator(0llu),
|
thrust::make_counting_iterator(0llu),
|
||||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||||
dh::caching_device_vector<Entry> sorted_entries;
|
dh::device_vector<Entry> sorted_entries;
|
||||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||||
detail::MakeEntriesFromAdapter(batch, batch_iter,
|
detail::MakeEntriesFromAdapter(batch, batch_iter,
|
||||||
@ -190,7 +189,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
&sorted_entries);
|
&sorted_entries);
|
||||||
data::IsValidFunctor is_valid(missing);
|
data::IsValidFunctor is_valid(missing);
|
||||||
|
|
||||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||||
auto d_temp_weights = dh::ToSpan(temp_weights);
|
auto d_temp_weights = dh::ToSpan(temp_weights);
|
||||||
|
|
||||||
if (is_ranking) {
|
if (is_ranking) {
|
||||||
@ -222,7 +221,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
|||||||
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||||
|
|
||||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||||
|
|||||||
@ -310,7 +310,7 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
|||||||
common::Span<OffsetT const> cuts_ptr,
|
common::Span<OffsetT const> cuts_ptr,
|
||||||
size_t total_cuts, Span<float> weights) {
|
size_t total_cuts, Span<float> weights) {
|
||||||
Span<SketchEntry> out;
|
Span<SketchEntry> out;
|
||||||
dh::caching_device_vector<SketchEntry> cuts;
|
dh::device_vector<SketchEntry> cuts;
|
||||||
bool first_window = this->Current().empty();
|
bool first_window = this->Current().empty();
|
||||||
if (!first_window) {
|
if (!first_window) {
|
||||||
cuts.resize(total_cuts);
|
cuts.resize(total_cuts);
|
||||||
|
|||||||
@ -36,31 +36,31 @@ class SketchContainer {
|
|||||||
int32_t device_;
|
int32_t device_;
|
||||||
|
|
||||||
// Double buffer as neither prune nor merge can be performed inplace.
|
// Double buffer as neither prune nor merge can be performed inplace.
|
||||||
dh::caching_device_vector<SketchEntry> entries_a_;
|
dh::device_vector<SketchEntry> entries_a_;
|
||||||
dh::caching_device_vector<SketchEntry> entries_b_;
|
dh::device_vector<SketchEntry> entries_b_;
|
||||||
bool current_buffer_ {true};
|
bool current_buffer_ {true};
|
||||||
// The container is just a CSC matrix.
|
// The container is just a CSC matrix.
|
||||||
HostDeviceVector<OffsetT> columns_ptr_;
|
HostDeviceVector<OffsetT> columns_ptr_;
|
||||||
HostDeviceVector<OffsetT> columns_ptr_b_;
|
HostDeviceVector<OffsetT> columns_ptr_b_;
|
||||||
|
|
||||||
dh::caching_device_vector<SketchEntry>& Current() {
|
dh::device_vector<SketchEntry>& Current() {
|
||||||
if (current_buffer_) {
|
if (current_buffer_) {
|
||||||
return entries_a_;
|
return entries_a_;
|
||||||
} else {
|
} else {
|
||||||
return entries_b_;
|
return entries_b_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dh::caching_device_vector<SketchEntry>& Other() {
|
dh::device_vector<SketchEntry>& Other() {
|
||||||
if (!current_buffer_) {
|
if (!current_buffer_) {
|
||||||
return entries_a_;
|
return entries_a_;
|
||||||
} else {
|
} else {
|
||||||
return entries_b_;
|
return entries_b_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dh::caching_device_vector<SketchEntry> const& Current() const {
|
dh::device_vector<SketchEntry> const& Current() const {
|
||||||
return const_cast<SketchContainer*>(this)->Current();
|
return const_cast<SketchContainer*>(this)->Current();
|
||||||
}
|
}
|
||||||
dh::caching_device_vector<SketchEntry> const& Other() const {
|
dh::device_vector<SketchEntry> const& Other() const {
|
||||||
return const_cast<SketchContainer*>(this)->Other();
|
return const_cast<SketchContainer*>(this)->Other();
|
||||||
}
|
}
|
||||||
void Alternate() {
|
void Alternate() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user