[EM] Improve memory estimation for quantile sketching. (#10843)

I- Add basic estimation for RMM.
- Re-estimate after every sub-batch.
- Some debug logs for memory usage.
- Fix the locking mechanism in the memory allocator logger.
This commit is contained in:
Jiaming Yuan 2024-09-25 03:20:09 +08:00 committed by GitHub
parent f3df0d0eb4
commit bc69a3e877
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 196 additions and 104 deletions

View File

@ -30,6 +30,7 @@
#include <cub/util_device.cuh> // for CurrentDevice #include <cub/util_device.cuh> // for CurrentDevice
#include <map> // for map #include <map> // for map
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <mutex> // for defer_lock
#include "common.h" // for safe_cuda, HumanMemUnit #include "common.h" // for safe_cuda, HumanMemUnit
#include "xgboost/logging.h" #include "xgboost/logging.h"
@ -46,6 +47,12 @@ class MemoryLogger {
size_t num_deallocations{0}; size_t num_deallocations{0};
std::map<void *, size_t> device_allocations; std::map<void *, size_t> device_allocations;
void RegisterAllocation(void *ptr, size_t n) { void RegisterAllocation(void *ptr, size_t n) {
auto itr = device_allocations.find(ptr);
if (itr != device_allocations.cend()) {
LOG(WARNING) << "Attempting to allocate " << n << " bytes."
<< " that was already allocated\nptr:" << ptr << "\n"
<< dmlc::StackTrace();
}
device_allocations[ptr] = n; device_allocations[ptr] = n;
currently_allocated_bytes += n; currently_allocated_bytes += n;
peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes); peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
@ -56,7 +63,7 @@ class MemoryLogger {
auto itr = device_allocations.find(ptr); auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) { if (itr == device_allocations.end()) {
LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
<< " that was never allocated\n" << " that was never allocated\nptr:" << ptr << "\n"
<< dmlc::StackTrace(); << dmlc::StackTrace();
} else { } else {
num_deallocations++; num_deallocations++;
@ -70,18 +77,34 @@ class MemoryLogger {
std::mutex mutex_; std::mutex mutex_;
public: public:
void RegisterAllocation(void *ptr, size_t n) { /**
* @brief Register the allocation for logging.
*
* @param lock Set to false if the allocator has locking machanism.
*/
void RegisterAllocation(void *ptr, size_t n, bool lock) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return; return;
} }
std::lock_guard<std::mutex> guard(mutex_); std::unique_lock guard{mutex_, std::defer_lock};
if (lock) {
guard.lock();
}
stats_.RegisterAllocation(ptr, n); stats_.RegisterAllocation(ptr, n);
} }
void RegisterDeallocation(void *ptr, size_t n) { /**
* @brief Register the deallocation for logging.
*
* @param lock Set to false if the allocator has locking machanism.
*/
void RegisterDeallocation(void *ptr, size_t n, bool lock) {
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
return; return;
} }
std::lock_guard<std::mutex> guard(mutex_); std::unique_lock guard{mutex_, std::defer_lock};
if (lock) {
guard.lock();
}
stats_.RegisterDeallocation(ptr, n, cub::CurrentDevice()); stats_.RegisterDeallocation(ptr, n, cub::CurrentDevice());
} }
size_t PeakMemory() const { return stats_.peak_allocated_bytes; } size_t PeakMemory() const { return stats_.peak_allocated_bytes; }
@ -140,11 +163,12 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
} catch (const std::exception &e) { } catch (const std::exception &e) {
detail::ThrowOOMError(e.what(), n * sizeof(T)); detail::ThrowOOMError(e.what(), n * sizeof(T));
} }
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T)); // We can't place a lock here as template allocator is transient.
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T), true);
return ptr; return ptr;
} }
void deallocate(pointer ptr, size_t n) { // NOLINT void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
SuperT::deallocate(ptr, n); SuperT::deallocate(ptr, n);
} }
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
@ -193,11 +217,12 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
detail::ThrowOOMError(e.what(), n * sizeof(T)); detail::ThrowOOMError(e.what(), n * sizeof(T));
} }
} }
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T)); // We can't place a lock here as template allocator is transient.
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T), true);
return thrust_ptr; return thrust_ptr;
} }
void deallocate(pointer ptr, size_t n) { // NOLINT void deallocate(pointer ptr, size_t n) { // NOLINT
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
if (use_cub_allocator_) { if (use_cub_allocator_) {
GetGlobalCachingAllocator().DeviceFree(ptr.get()); GetGlobalCachingAllocator().DeviceFree(ptr.get());
} else { } else {
@ -239,14 +264,15 @@ using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocato
*/ */
class LoggingResource : public rmm::mr::device_memory_resource { class LoggingResource : public rmm::mr::device_memory_resource {
rmm::mr::device_memory_resource *mr_{rmm::mr::get_current_device_resource()}; rmm::mr::device_memory_resource *mr_{rmm::mr::get_current_device_resource()};
std::mutex lock_;
public: public:
LoggingResource() = default; LoggingResource() = default;
~LoggingResource() override = default; ~LoggingResource() override = default;
LoggingResource(LoggingResource const &) = delete; LoggingResource(LoggingResource const &) = delete;
LoggingResource &operator=(LoggingResource const &) = delete; LoggingResource &operator=(LoggingResource const &) = delete;
LoggingResource(LoggingResource &&) noexcept = default; LoggingResource(LoggingResource &&) noexcept = delete;
LoggingResource &operator=(LoggingResource &&) noexcept = default; LoggingResource &operator=(LoggingResource &&) noexcept = delete;
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept { // NOLINT [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept { // NOLINT
return mr_; return mr_;
@ -256,9 +282,13 @@ class LoggingResource : public rmm::mr::device_memory_resource {
} }
void *do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { // NOLINT void *do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { // NOLINT
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
try { try {
auto const ptr = mr_->allocate(bytes, stream); auto const ptr = mr_->allocate(bytes, stream);
GlobalMemoryLogger().RegisterAllocation(ptr, bytes); GlobalMemoryLogger().RegisterAllocation(ptr, bytes, false);
return ptr; return ptr;
} catch (rmm::bad_alloc const &e) { } catch (rmm::bad_alloc const &e) {
detail::ThrowOOMError(e.what(), bytes); detail::ThrowOOMError(e.what(), bytes);
@ -268,8 +298,12 @@ class LoggingResource : public rmm::mr::device_memory_resource {
void do_deallocate(void *ptr, std::size_t bytes, // NOLINT void do_deallocate(void *ptr, std::size_t bytes, // NOLINT
rmm::cuda_stream_view stream) override { rmm::cuda_stream_view stream) override {
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
mr_->deallocate(ptr, bytes, stream); mr_->deallocate(ptr, bytes, stream);
GlobalMemoryLogger().RegisterDeallocation(ptr, bytes); GlobalMemoryLogger().RegisterDeallocation(ptr, bytes, false);
} }
[[nodiscard]] bool do_is_equal( // NOLINT [[nodiscard]] bool do_is_equal( // NOLINT

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2018~2023 by XGBoost contributors * Copyright 2018~2024, XGBoost contributors
*/ */
#include <thrust/binary_search.h> #include <thrust/binary_search.h>
#include <thrust/copy.h> #include <thrust/copy.h>
@ -32,13 +32,12 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
double eps = 1.0 / (WQSketch::kFactor * max_bins); double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t dummy_nlevel; size_t dummy_nlevel;
size_t num_cuts; size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel( WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(num_rows, eps, &dummy_nlevel, &num_cuts);
num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows); return std::min(num_cuts, num_rows);
} }
size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins,
size_t max_bins, size_t nnz) { bst_idx_t nnz) {
auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows); auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
auto if_dense = num_columns * per_column; auto if_dense = num_columns * per_column;
auto result = std::min(nnz, if_dense); auto result = std::min(nnz, if_dense);
@ -83,23 +82,31 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
return peak; return peak;
} }
size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_rows, bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device,
bst_feature_t columns, size_t nnz, int device, size_t num_cuts, size_t num_cuts, bool has_weight, std::size_t container_bytes) {
bool has_weight) {
auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max()); auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
// device available memory is not accurate when rmm is used. // Device available memory is not accurate when rmm is used.
return std::min(nnz, kIntMax); double total_mem = dh::TotalMemory(device) - container_bytes;
double total_f32 = total_mem / sizeof(float);
double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0); // a quarter
if (shape.nnz > shape.Size()) {
// Unknown nnz
shape.nnz = shape.Size();
}
return std::min(static_cast<bst_idx_t>(n_max_used_f32), shape.nnz);
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
(void)container_bytes; // We known the remaining size when RMM is not used.
if (sketch_batch_num_elements == 0) { if (sketch_batch_num_elements == detail::UnknownSketchNumElements()) {
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight); auto required_memory =
RequiredMemory(shape.n_samples, shape.n_features, shape.nnz, num_cuts, has_weight);
// use up to 80% of available space // use up to 80% of available space
auto avail = dh::AvailableMemory(device) * 0.8; auto avail = dh::AvailableMemory(device) * 0.8;
if (required_memory > avail) { if (required_memory > avail) {
sketch_batch_num_elements = avail / BytesPerElement(has_weight); sketch_batch_num_elements = avail / BytesPerElement(has_weight);
} else { } else {
sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz); sketch_batch_num_elements = std::min(shape.Size(), shape.nnz);
} }
} }
@ -338,8 +345,9 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
// Configure batch size based on available memory // Configure batch size based on available memory
std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_); std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(), sketch_batch_num_elements,
num_cuts_per_feature, has_weight); detail::SketchShape{info.num_row_, info.num_col_, info.num_nonzero_}, ctx->Ordinal(),
num_cuts_per_feature, has_weight, 0);
CUDAContext const* cuctx = ctx->CUDACtx(); CUDAContext const* cuctx = ctx->CUDACtx();

View File

@ -10,7 +10,10 @@
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <thrust/sort.h> // for sort #include <thrust/sort.h> // for sort
#include <cstddef> // for size_t #include <algorithm> // for max
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <limits> // for numeric_limits
#include "../data/adapter.h" // for IsValidFunctor #include "../data/adapter.h" // for IsValidFunctor
#include "algorithm.cuh" // for CopyIf #include "algorithm.cuh" // for CopyIf
@ -186,13 +189,24 @@ inline size_t constexpr BytesPerElement(bool has_weight) {
return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2;
} }
/* \brief Calcuate the length of sliding window. Returns `sketch_batch_num_elements` struct SketchShape {
bst_idx_t n_samples;
bst_feature_t n_features;
bst_idx_t nnz;
template <typename F, std::enable_if_t<std::is_integral_v<F>>* = nullptr>
SketchShape(bst_idx_t n_samples, F n_features, bst_idx_t nnz)
: n_samples{n_samples}, n_features{static_cast<bst_feature_t>(n_features)}, nnz{nnz} {}
[[nodiscard]] bst_idx_t Size() const { return n_samples * n_features; }
};
/**
* @brief Calcuate the length of sliding window. Returns `sketch_batch_num_elements`
* directly if it's not 0. * directly if it's not 0.
*/ */
size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device,
bst_idx_t num_rows, bst_feature_t columns, size_t num_cuts, bool has_weight, std::size_t container_bytes);
size_t nnz, int device,
size_t num_cuts, bool has_weight);
// Compute number of sample cuts needed on local node to maintain accuracy // Compute number of sample cuts needed on local node to maintain accuracy
// We take more cuts than needed and then reduce them later // We take more cuts than needed and then reduce them later
@ -249,6 +263,8 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
dh::device_vector<Entry>* p_sorted_entries, dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights, dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan); dh::caching_device_vector<size_t>* p_column_sizes_scan);
constexpr bst_idx_t UnknownSketchNumElements() { return 0; }
} // namespace detail } // namespace detail
/** /**
@ -264,7 +280,7 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
*/ */
HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
Span<float const> hessian, Span<float const> hessian,
std::size_t sketch_batch_num_elements = 0); std::size_t sketch_batch_num_elements = detail::UnknownSketchNumElements());
/** /**
* @brief Compute sketch on DMatrix with GPU. * @brief Compute sketch on DMatrix with GPU.
@ -276,14 +292,15 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
* *
* @return Quantile cuts * @return Quantile cuts
*/ */
inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin, inline HistogramCuts DeviceSketch(
std::size_t sketch_batch_num_elements = 0) { Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
std::size_t sketch_batch_num_elements = detail::UnknownSketchNumElements()) {
return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements); return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements);
} }
template <typename AdapterBatch> template <typename AdapterBatch>
void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info, void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info,
size_t columns, size_t begin, size_t end, float missing, size_t n_features, size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts) { SketchContainer* sketch_container, int num_cuts) {
// Copy current subset of valid elements into temporary storage and sort // Copy current subset of valid elements into temporary storage and sort
dh::device_vector<Entry> sorted_entries; dh::device_vector<Entry> sorted_entries;
@ -294,8 +311,9 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr; HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
cuts_ptr.SetDevice(ctx->Device()); cuts_ptr.SetDevice(ctx->Device());
CUDAContext const* cuctx = ctx->CUDACtx(); CUDAContext const* cuctx = ctx->CUDACtx();
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, num_cuts, detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, n_features,
ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries); num_cuts, ctx->Device(), &cuts_ptr, &column_sizes_scan,
&sorted_entries);
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp());
if (sketch_container->HasCategorical()) { if (sketch_container->HasCategorical()) {
@ -305,10 +323,11 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
} }
auto d_cuts_ptr = cuts_ptr.DeviceSpan(); auto d_cuts_ptr = cuts_ptr.DeviceSpan();
auto const &h_cuts_ptr = cuts_ptr.HostVector(); auto const& h_cuts_ptr = cuts_ptr.HostVector();
// Extract the cuts from all columns concurrently // Extract the cuts from all columns concurrently
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr, sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back()); h_cuts_ptr.back());
sorted_entries.clear(); sorted_entries.clear();
sorted_entries.shrink_to_fit(); sorted_entries.shrink_to_fit();
} }
@ -316,10 +335,10 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
template <typename Batch> template <typename Batch>
void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info, void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info,
int num_cuts_per_feature, bool is_ranking, float missing, int num_cuts_per_feature, bool is_ranking, float missing,
DeviceOrd device, size_t columns, size_t begin, size_t end, size_t columns, size_t begin, size_t end,
SketchContainer* sketch_container) { SketchContainer* sketch_container) {
dh::safe_cuda(cudaSetDevice(device.ordinal)); SetDevice(ctx->Ordinal());
info.weights_.SetDevice(device); info.weights_.SetDevice(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan(); auto weights = info.weights_.ConstDeviceSpan();
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>( auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
@ -330,7 +349,7 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
dh::caching_device_vector<size_t> column_sizes_scan; dh::caching_device_vector<size_t> column_sizes_scan;
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr; HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns,
num_cuts_per_feature, device, &cuts_ptr, &column_sizes_scan, num_cuts_per_feature, ctx->Device(), &cuts_ptr, &column_sizes_scan,
&sorted_entries); &sorted_entries);
data::IsValidFunctor is_valid(missing); data::IsValidFunctor is_valid(missing);
@ -388,48 +407,59 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
sorted_entries.shrink_to_fit(); sorted_entries.shrink_to_fit();
} }
/* /**
* \brief Perform sketching on GPU. * @brief Perform sketching on GPU.
* *
* \param batch A batch from adapter. * @param batch A batch from adapter.
* \param num_bins Bins per column. * @param num_bins Bins per column.
* \param info Metainfo used for sketching. * @param info Metainfo used for sketching.
* \param missing Floating point value that represents invalid value. * @param missing Floating point value that represents invalid value.
* \param sketch_container Container for output sketch. * @param sketch_container Container for output sketch.
* \param sketch_batch_num_elements Number of element per-sliding window, use it only for * @param sketch_batch_num_elements Number of element per-sliding window, use it only for
* testing. * testing.
*/ */
template <typename Batch> template <typename Batch>
void AdapterDeviceSketch(Context const* ctx, Batch batch, int num_bins, MetaInfo const& info, void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, MetaInfo const& info,
float missing, SketchContainer* sketch_container, float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) { bst_idx_t sketch_batch_num_elements = detail::UnknownSketchNumElements()) {
size_t num_rows = batch.NumRows(); bst_idx_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols(); size_t num_cols = batch.NumCols();
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
auto device = sketch_container->DeviceIdx();
bool weighted = !info.weights_.Empty(); bool weighted = !info.weights_.Empty();
if (weighted) { bst_idx_t const kRemaining = batch.Size();
bst_idx_t begin = 0;
auto shape = detail::SketchShape{num_rows, num_cols, std::numeric_limits<bst_idx_t>::max()};
while (begin < kRemaining) {
// Use total number of samples to estimate the needed cuts first, this doesn't hurt
// accuracy as total number of samples is larger.
auto num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
// Estimate the memory usage based on the current available memory.
sketch_batch_num_elements = detail::SketchBatchNumElements( sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(), sketch_batch_num_elements, shape, ctx->Ordinal(), num_cuts_per_feature, weighted,
device.ordinal, num_cuts_per_feature, true); sketch_container->MemCostBytes());
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { // Re-estimate the needed number of cuts based on the size of the sub-batch.
size_t end = //
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements)); // The estimation of `sketch_batch_num_elements` assumes dense input, so the
// approximation here is reasonably accurate. It doesn't hurt accuracy since the
// estimated n_samples must be greater or equal to the actual n_samples thanks to the
// dense assumption.
auto approx_n_samples = std::max(sketch_batch_num_elements / num_cols, bst_idx_t{1});
num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, approx_n_samples);
bst_idx_t end =
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
if (weighted) {
ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature, ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature,
HostSketchContainer::UseGroup(info), missing, device, num_cols, HostSketchContainer::UseGroup(info), missing, num_cols, begin,
begin, end, sketch_container); end, sketch_container);
} } else {
} else {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
device.ordinal, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end =
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container, ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container,
num_cuts_per_feature); num_cuts_per_feature);
} }
begin += sketch_batch_num_elements;
} }
} }
} // namespace xgboost::common } // namespace xgboost::common

View File

@ -309,7 +309,7 @@ void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr, void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) { common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) {
common::SetDevice(device_.ordinal); common::SetDevice(ctx->Ordinal());
Span<SketchEntry> out; Span<SketchEntry> out;
dh::device_vector<SketchEntry> cuts; dh::device_vector<SketchEntry> cuts;
bool first_window = this->Current().empty(); bool first_window = this->Current().empty();
@ -354,7 +354,7 @@ void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<s
this->FixError(); this->FixError();
} else { } else {
this->Current().resize(n_uniques); this->Current().resize(n_uniques);
this->columns_ptr_.SetDevice(device_); this->columns_ptr_.SetDevice(ctx->Device());
this->columns_ptr_.Resize(cuts_ptr.size()); this->columns_ptr_.Resize(cuts_ptr.size());
auto d_cuts_ptr = this->columns_ptr_.DeviceSpan(); auto d_cuts_ptr = this->columns_ptr_.DeviceSpan();
@ -369,7 +369,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
* pruning or merging. We preserve the first type and remove the second type. * pruning or merging. We preserve the first type and remove the second type.
*/ */
timer_.Start(__func__); timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal)); SetDevice(ctx->Ordinal());
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1); CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
auto key_it = dh::MakeTransformIterator<size_t>( auto key_it = dh::MakeTransformIterator<size_t>(
@ -408,7 +408,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
void SketchContainer::Prune(Context const* ctx, std::size_t to) { void SketchContainer::Prune(Context const* ctx, std::size_t to) {
timer_.Start(__func__); timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal)); SetDevice(ctx->Ordinal());
OffsetT to_total = 0; OffsetT to_total = 0;
auto& h_columns_ptr = columns_ptr_b_.HostVector(); auto& h_columns_ptr = columns_ptr_b_.HostVector();
@ -443,7 +443,12 @@ void SketchContainer::Prune(Context const* ctx, std::size_t to) {
void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_columns_ptr, void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_columns_ptr,
Span<SketchEntry const> that) { Span<SketchEntry const> that) {
common::SetDevice(device_.ordinal); SetDevice(ctx->Ordinal());
auto self = dh::ToSpan(this->Current());
LOG(DEBUG) << "Merge: self:" << HumanMemUnit(self.size_bytes()) << ". "
<< "That:" << HumanMemUnit(that.size_bytes()) << ". "
<< "This capacity:" << HumanMemUnit(this->MemCapacityBytes()) << "." << std::endl;
timer_.Start(__func__); timer_.Start(__func__);
if (this->Current().size() == 0) { if (this->Current().size() == 0) {
CHECK_EQ(this->columns_ptr_.HostVector().back(), 0); CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
@ -478,7 +483,6 @@ void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_colum
} }
void SketchContainer::FixError() { void SketchContainer::FixError() {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
auto in = dh::ToSpan(this->Current()); auto in = dh::ToSpan(this->Current());
dh::LaunchN(in.size(), [=] __device__(size_t idx) { dh::LaunchN(in.size(), [=] __device__(size_t idx) {
@ -503,7 +507,7 @@ void SketchContainer::FixError() {
} }
void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
dh::safe_cuda(cudaSetDevice(device_.ordinal)); SetDevice(ctx->Ordinal());
auto world = collective::GetWorldSize(); auto world = collective::GetWorldSize();
if (world == 1 || is_column_split) { if (world == 1 || is_column_split) {
return; return;
@ -541,7 +545,7 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
std::vector<std::int64_t> recv_lengths; std::vector<std::int64_t> recv_lengths;
HostDeviceVector<std::int8_t> recvbuf; HostDeviceVector<std::int8_t> recvbuf;
rc = collective::AllgatherV( rc = collective::AllgatherV(
ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_), ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), ctx->Device()),
&recv_lengths, &recvbuf); &recv_lengths, &recvbuf);
collective::SafeColl(rc); collective::SafeColl(rc);
for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) { for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) {
@ -563,9 +567,8 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
} }
// Merge them into a new sketch. // Merge them into a new sketch.
SketchContainer new_sketch(this->feature_types_, num_bins_, SketchContainer new_sketch(this->feature_types_, num_bins_, this->num_columns_, global_sum_rows,
this->num_columns_, global_sum_rows, ctx->Device());
this->device_);
for (size_t i = 0; i < allworkers.size(); ++i) { for (size_t i = 0; i < allworkers.size(); ++i) {
auto worker = allworkers[i]; auto worker = allworkers[i];
auto worker_ptr = auto worker_ptr =
@ -593,7 +596,7 @@ struct InvalidCatOp {
void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) { void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) {
timer_.Start(__func__); timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal)); SetDevice(ctx->Ordinal());
p_cuts->min_vals_.Resize(num_columns_); p_cuts->min_vals_.Resize(num_columns_);
// Sync between workers. // Sync between workers.
@ -606,12 +609,12 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
// Set up inputs // Set up inputs
auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
p_cuts->min_vals_.SetDevice(device_); p_cuts->min_vals_.SetDevice(ctx->Device());
auto d_min_values = p_cuts->min_vals_.DeviceSpan(); auto d_min_values = p_cuts->min_vals_.DeviceSpan();
auto const in_cut_values = dh::ToSpan(this->Current()); auto const in_cut_values = dh::ToSpan(this->Current());
// Set up output ptr // Set up output ptr
p_cuts->cut_ptrs_.SetDevice(device_); p_cuts->cut_ptrs_.SetDevice(ctx->Device());
auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector(); auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector();
h_out_columns_ptr.clear(); h_out_columns_ptr.clear();
h_out_columns_ptr.push_back(0); h_out_columns_ptr.push_back(0);
@ -689,7 +692,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan(); auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
size_t total_bins = h_out_columns_ptr.back(); size_t total_bins = h_out_columns_ptr.back();
p_cuts->cut_values_.SetDevice(device_); p_cuts->cut_values_.SetDevice(ctx->Device());
p_cuts->cut_values_.Resize(total_bins); p_cuts->cut_values_.Resize(total_bins);
auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); auto out_cut_values = p_cuts->cut_values_.DeviceSpan();

View File

@ -8,6 +8,7 @@
#include "categorical.h" #include "categorical.h"
#include "cuda_context.cuh" // for CUDAContext #include "cuda_context.cuh" // for CUDAContext
#include "cuda_rt_utils.h" // for SetDevice
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "error_msg.h" // for InvalidMaxBin #include "error_msg.h" // for InvalidMaxBin
#include "quantile.h" #include "quantile.h"
@ -15,9 +16,7 @@
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/span.h" #include "xgboost/span.h"
namespace xgboost { namespace xgboost::common {
namespace common {
class HistogramCuts; class HistogramCuts;
using WQSketch = WQuantileSketch<bst_float, bst_float>; using WQSketch = WQuantileSketch<bst_float, bst_float>;
using SketchEntry = WQSketch::Entry; using SketchEntry = WQSketch::Entry;
@ -46,7 +45,6 @@ class SketchContainer {
bst_idx_t num_rows_; bst_idx_t num_rows_;
bst_feature_t num_columns_; bst_feature_t num_columns_;
int32_t num_bins_; int32_t num_bins_;
DeviceOrd device_;
// Double buffer as neither prune nor merge can be performed inplace. // Double buffer as neither prune nor merge can be performed inplace.
dh::device_vector<SketchEntry> entries_a_; dh::device_vector<SketchEntry> entries_a_;
@ -100,12 +98,12 @@ class SketchContainer {
*/ */
SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin, SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin,
bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device) bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
: num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin} {
CHECK(device.IsCUDA()); CHECK(device.IsCUDA());
// Initialize Sketches for this dmatrix // Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_); this->columns_ptr_.SetDevice(device);
this->columns_ptr_.Resize(num_columns + 1, 0); this->columns_ptr_.Resize(num_columns + 1, 0);
this->columns_ptr_b_.SetDevice(device_); this->columns_ptr_b_.SetDevice(device);
this->columns_ptr_b_.Resize(num_columns + 1, 0); this->columns_ptr_b_.Resize(num_columns + 1, 0);
this->feature_types_.Resize(feature_types.Size()); this->feature_types_.Resize(feature_types.Size());
@ -123,8 +121,25 @@ class SketchContainer {
timer_.Init(__func__); timer_.Init(__func__);
} }
/* \brief Return GPU ID for this container. */ /**
[[nodiscard]] DeviceOrd DeviceIdx() const { return device_; } * @brief Calculate the memory cost of the container.
*/
[[nodiscard]] std::size_t MemCapacityBytes() const {
auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type);
auto n_bytes = (this->entries_a_.capacity() + this->entries_b_.capacity()) * kE;
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT);
n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
return n_bytes;
}
[[nodiscard]] std::size_t MemCostBytes() const {
auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type);
auto n_bytes = (this->entries_a_.size() + this->entries_b_.size()) * kE;
n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT);
n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
return n_bytes;
}
/* \brief Whether the predictor matrix contains categorical features. */ /* \brief Whether the predictor matrix contains categorical features. */
bool HasCategorical() const { return has_categorical_; } bool HasCategorical() const { return has_categorical_; }
/* \brief Accumulate weights of duplicated entries in input. */ /* \brief Accumulate weights of duplicated entries in input. */
@ -166,6 +181,7 @@ class SketchContainer {
this->Current().shrink_to_fit(); this->Current().shrink_to_fit();
this->Other().clear(); this->Other().clear();
this->Other().shrink_to_fit(); this->Other().shrink_to_fit();
LOG(DEBUG) << "Quantile memory cost:" << this->MemCapacityBytes();
} }
/* \brief Merge quantiles from other GPU workers. */ /* \brief Merge quantiles from other GPU workers. */
@ -190,13 +206,13 @@ class SketchContainer {
template <typename KeyComp = thrust::equal_to<size_t>> template <typename KeyComp = thrust::equal_to<size_t>>
size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to<size_t>{}) { size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to<size_t>{}) {
timer_.Start(__func__); timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal)); SetDevice(ctx->Ordinal());
this->columns_ptr_.SetDevice(device_); this->columns_ptr_.SetDevice(ctx->Device());
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan(); Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1); CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current()); Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size()); HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_); scan_out.SetDevice(ctx->Device());
auto d_scan_out = scan_out.DeviceSpan(); auto d_scan_out = scan_out.DeviceSpan();
d_column_scan = this->columns_ptr_.DeviceSpan(); d_column_scan = this->columns_ptr_.DeviceSpan();
@ -212,7 +228,6 @@ class SketchContainer {
return n_uniques; return n_uniques;
} }
}; };
} // namespace common } // namespace xgboost::common
} // namespace xgboost
#endif // XGBOOST_COMMON_QUANTILE_CUH_ #endif // XGBOOST_COMMON_QUANTILE_CUH_

View File

@ -65,7 +65,9 @@ TEST(HistUtil, SketchBatchNumElements) {
auto per_elem = detail::BytesPerElement(false); auto per_elem = detail::BytesPerElement(false);
auto avail_elem = avail / per_elem; auto avail_elem = avail / per_elem;
size_t rows = avail_elem / kCols * 10; size_t rows = avail_elem / kCols * 10;
auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false); auto shape = detail::SketchShape{rows, kCols, rows * kCols};
auto batch = detail::SketchBatchNumElements(detail::UnknownSketchNumElements(), shape, device,
256, false, 0);
ASSERT_EQ(batch, avail_elem); ASSERT_EQ(batch, avail_elem);
} }