[EM] Improve memory estimation for quantile sketching. (#10843)
- Add basic memory estimation for RMM.
- Re-estimate after every sub-batch.
- Add debug logs for memory usage.
- Fix the locking mechanism in the memory allocator logger.
parent f3df0d0eb4
commit bc69a3e877
@@ -30,6 +30,7 @@
 #include <cub/util_device.cuh>  // for CurrentDevice
 #include <map>                  // for map
 #include <memory>               // for unique_ptr
+#include <mutex>                // for defer_lock
 
 #include "common.h"  // for safe_cuda, HumanMemUnit
 #include "xgboost/logging.h"
@@ -46,6 +47,12 @@ class MemoryLogger {
   size_t num_deallocations{0};
   std::map<void *, size_t> device_allocations;
   void RegisterAllocation(void *ptr, size_t n) {
+    auto itr = device_allocations.find(ptr);
+    if (itr != device_allocations.cend()) {
+      LOG(WARNING) << "Attempting to allocate " << n << " bytes."
+                   << " that was already allocated\nptr:" << ptr << "\n"
+                   << dmlc::StackTrace();
+    }
     device_allocations[ptr] = n;
     currently_allocated_bytes += n;
     peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes);
@@ -56,7 +63,7 @@ class MemoryLogger {
     auto itr = device_allocations.find(ptr);
     if (itr == device_allocations.end()) {
       LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device
-                   << " that was never allocated\n"
+                   << " that was never allocated\nptr:" << ptr << "\n"
                    << dmlc::StackTrace();
     } else {
       num_deallocations++;
@@ -70,18 +77,34 @@ class MemoryLogger {
   std::mutex mutex_;
 
  public:
-  void RegisterAllocation(void *ptr, size_t n) {
+  /**
+   * @brief Register the allocation for logging.
+   *
+   * @param lock Set to false if the allocator has its own locking mechanism.
+   */
+  void RegisterAllocation(void *ptr, size_t n, bool lock) {
     if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
       return;
     }
-    std::lock_guard<std::mutex> guard(mutex_);
+    std::unique_lock guard{mutex_, std::defer_lock};
+    if (lock) {
+      guard.lock();
+    }
     stats_.RegisterAllocation(ptr, n);
   }
-  void RegisterDeallocation(void *ptr, size_t n) {
+  /**
+   * @brief Register the deallocation for logging.
+   *
+   * @param lock Set to false if the allocator has its own locking mechanism.
+   */
+  void RegisterDeallocation(void *ptr, size_t n, bool lock) {
     if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
       return;
     }
-    std::lock_guard<std::mutex> guard(mutex_);
+    std::unique_lock guard{mutex_, std::defer_lock};
+    if (lock) {
+      guard.lock();
+    }
     stats_.RegisterDeallocation(ptr, n, cub::CurrentDevice());
   }
   size_t PeakMemory() const { return stats_.peak_allocated_bytes; }
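The locking change above is the commit's fourth bullet: `std::lock_guard` always took the logger's mutex, which double-locks once a caller such as the RMM logging resource holds its own lock. The replacement constructs a `std::unique_lock` with `std::defer_lock` and engages it only when `lock` is true. A minimal standalone sketch of the pattern, using only the standard library; `StatsLogger` and its members are illustrative names, not part of the commit:

#include <cstddef>
#include <mutex>

class StatsLogger {
  std::mutex mutex_;
  std::size_t total_bytes_{0};

 public:
  // `lock` is false when the caller already serializes access with its own
  // mutex; engaging ours as well would be redundant (or deadlock-prone).
  void Register(std::size_t n, bool lock) {
    std::unique_lock guard{mutex_, std::defer_lock};  // constructed unlocked
    if (lock) {
      guard.lock();  // engage only for unsynchronized callers
    }
    total_bytes_ += n;
  }  // the destructor unlocks only if the lock was engaged
};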
@@ -140,11 +163,12 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
     } catch (const std::exception &e) {
       detail::ThrowOOMError(e.what(), n * sizeof(T));
     }
-    GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
+    // We can't place a lock here as the template allocator is transient.
+    GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T), true);
     return ptr;
   }
   void deallocate(pointer ptr, size_t n) {  // NOLINT
-    GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
+    GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
     SuperT::deallocate(ptr, n);
   }
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
@@ -193,11 +217,12 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
       detail::ThrowOOMError(e.what(), n * sizeof(T));
     }
   }
-    GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
+    // We can't place a lock here as the template allocator is transient.
+    GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T), true);
     return thrust_ptr;
   }
   void deallocate(pointer ptr, size_t n) {  // NOLINT
-    GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
+    GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T), true);
     if (use_cub_allocator_) {
       GetGlobalCachingAllocator().DeviceFree(ptr.get());
     } else {
@@ -239,14 +264,15 @@ using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocato
  */
 class LoggingResource : public rmm::mr::device_memory_resource {
   rmm::mr::device_memory_resource *mr_{rmm::mr::get_current_device_resource()};
+  std::mutex lock_;
 
  public:
   LoggingResource() = default;
   ~LoggingResource() override = default;
   LoggingResource(LoggingResource const &) = delete;
   LoggingResource &operator=(LoggingResource const &) = delete;
-  LoggingResource(LoggingResource &&) noexcept = default;
-  LoggingResource &operator=(LoggingResource &&) noexcept = default;
+  LoggingResource(LoggingResource &&) noexcept = delete;
+  LoggingResource &operator=(LoggingResource &&) noexcept = delete;
 
   [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept {  // NOLINT
     return mr_;
@@ -256,9 +282,13 @@ class LoggingResource : public rmm::mr::device_memory_resource {
   }
 
   void *do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {  // NOLINT
+    std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
+    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
+      guard.lock();
+    }
     try {
       auto const ptr = mr_->allocate(bytes, stream);
-      GlobalMemoryLogger().RegisterAllocation(ptr, bytes);
+      GlobalMemoryLogger().RegisterAllocation(ptr, bytes, false);
       return ptr;
     } catch (rmm::bad_alloc const &e) {
       detail::ThrowOOMError(e.what(), bytes);
@@ -268,8 +298,12 @@ class LoggingResource : public rmm::mr::device_memory_resource {
 
   void do_deallocate(void *ptr, std::size_t bytes,  // NOLINT
                      rmm::cuda_stream_view stream) override {
+    std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
+    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
+      guard.lock();
+    }
     mr_->deallocate(ptr, bytes, stream);
-    GlobalMemoryLogger().RegisterDeallocation(ptr, bytes);
+    GlobalMemoryLogger().RegisterDeallocation(ptr, bytes, false);
   }
 
   [[nodiscard]] bool do_is_equal(  // NOLINT
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018~2023 by XGBoost contributors
+ * Copyright 2018~2024, XGBoost contributors
  */
 #include <thrust/binary_search.h>
 #include <thrust/copy.h>
@@ -32,13 +32,12 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
   double eps = 1.0 / (WQSketch::kFactor * max_bins);
   size_t dummy_nlevel;
   size_t num_cuts;
-  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
-      num_rows, eps, &dummy_nlevel, &num_cuts);
+  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(num_rows, eps, &dummy_nlevel, &num_cuts);
   return std::min(num_cuts, num_rows);
 }
 
-size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns,
-                          size_t max_bins, size_t nnz) {
+size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins,
+                          bst_idx_t nnz) {
   auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
   auto if_dense = num_columns * per_column;
   auto result = std::min(nnz, if_dense);
@@ -83,23 +82,31 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
   return peak;
 }
 
-size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_rows,
-                              bst_feature_t columns, size_t nnz, int device, size_t num_cuts,
-                              bool has_weight) {
+bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device,
+                                 size_t num_cuts, bool has_weight, std::size_t container_bytes) {
   auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
-  // device available memory is not accurate when rmm is used.
-  return std::min(nnz, kIntMax);
+  // Device available memory is not accurate when rmm is used.
+  double total_mem = dh::TotalMemory(device) - container_bytes;
+  double total_f32 = total_mem / sizeof(float);
+  double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0);  // a quarter
+  if (shape.nnz > shape.Size()) {
+    // Unknown nnz
+    shape.nnz = shape.Size();
+  }
+  return std::min(static_cast<bst_idx_t>(n_max_used_f32), shape.nnz);
 #endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+  (void)container_bytes;  // We know the remaining size when RMM is not used.
 
-  if (sketch_batch_num_elements == 0) {
-    auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
+  if (sketch_batch_num_elements == detail::UnknownSketchNumElements()) {
+    auto required_memory =
+        RequiredMemory(shape.n_samples, shape.n_features, shape.nnz, num_cuts, has_weight);
     // use up to 80% of available space
     auto avail = dh::AvailableMemory(device) * 0.8;
     if (required_memory > avail) {
       sketch_batch_num_elements = avail / BytesPerElement(has_weight);
     } else {
-      sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
+      sketch_batch_num_elements = std::min(shape.Size(), shape.nnz);
     }
   }
 
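The new RMM branch above estimates rather than queries: `dh::AvailableMemory` is not accurate once RMM pools allocations, so the code takes total device memory, subtracts the sketch container's current footprint (`container_bytes`), converts to a count of 4-byte floats, and caps the batch at one sixteenth of that count. Since each element needs roughly 16 bytes of working storage (two buffered `Entry` values, see `BytesPerElement`), the cap corresponds to about a quarter of the remaining memory, matching the `// a quarter` comment. A back-of-the-envelope check with assumed figures (a 16 GiB device, 1 GiB already held by the container):

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed figures, for illustration only.
  double total_mem = 16.0 * (1ull << 30) - 1.0 * (1ull << 30);  // ~1.61e10 bytes
  double total_f32 = total_mem / sizeof(float);                 // ~4.03e9 floats
  double n_max_used_f32 = std::max(total_f32 / 16.0, 1.0);      // ~2.52e8 elements
  // 2.52e8 elements * 16 bytes/element ~= 3.75 GiB, a quarter of 15 GiB.
  std::printf("batch cap: %.0f elements\n", n_max_used_f32);
}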
@@ -338,8 +345,9 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
   // Configure batch size based on available memory
   std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bin, info.num_row_);
   sketch_batch_num_elements = detail::SketchBatchNumElements(
-      sketch_batch_num_elements, info.num_row_, info.num_col_, info.num_nonzero_, ctx->Ordinal(),
-      num_cuts_per_feature, has_weight);
+      sketch_batch_num_elements,
+      detail::SketchShape{info.num_row_, info.num_col_, info.num_nonzero_}, ctx->Ordinal(),
+      num_cuts_per_feature, has_weight, 0);
 
   CUDAContext const* cuctx = ctx->CUDACtx();
 
@@ -10,7 +10,10 @@
 #include <thrust/host_vector.h>
 #include <thrust/sort.h>  // for sort
 
+#include <algorithm>  // for max
 #include <cstddef>    // for size_t
+#include <cstdint>    // for uint32_t
+#include <limits>     // for numeric_limits
 
 #include "../data/adapter.h"  // for IsValidFunctor
 #include "algorithm.cuh"      // for CopyIf
@@ -186,13 +189,24 @@ inline size_t constexpr BytesPerElement(bool has_weight) {
   return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2;
 }
 
-/* \brief Calculate the length of the sliding window. Returns `sketch_batch_num_elements`
+struct SketchShape {
+  bst_idx_t n_samples;
+  bst_feature_t n_features;
+  bst_idx_t nnz;
+
+  template <typename F, std::enable_if_t<std::is_integral_v<F>>* = nullptr>
+  SketchShape(bst_idx_t n_samples, F n_features, bst_idx_t nnz)
+      : n_samples{n_samples}, n_features{static_cast<bst_feature_t>(n_features)}, nnz{nnz} {}
+
+  [[nodiscard]] bst_idx_t Size() const { return n_samples * n_features; }
+};
+
+/**
+ * @brief Calculate the length of the sliding window. Returns `sketch_batch_num_elements`
  * directly if it's not 0.
  */
-size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
-                              bst_idx_t num_rows, bst_feature_t columns,
-                              size_t nnz, int device,
-                              size_t num_cuts, bool has_weight);
+bst_idx_t SketchBatchNumElements(bst_idx_t sketch_batch_num_elements, SketchShape shape, int device,
+                                 size_t num_cuts, bool has_weight, std::size_t container_bytes);
 
 // Compute number of sample cuts needed on local node to maintain accuracy
 // We take more cuts than needed and then reduce them later
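`SketchShape` replaces the loose `num_rows`/`columns`/`nnz` triple in the estimator's signature. The templated constructor accepts any integral feature count (callers pass `size_t` values from `MetaInfo`) and narrows it to `bst_feature_t`, while `Size()` gives the dense element count used to clamp an unknown `nnz`. A self-contained usage sketch; the two aliases are written out locally and assumed to mirror `xgboost/base.h`:

#include <cstddef>
#include <cstdint>
#include <type_traits>

using bst_idx_t = std::uint64_t;      // assumed to mirror xgboost/base.h
using bst_feature_t = std::uint32_t;  // assumed to mirror xgboost/base.h

struct SketchShape {
  bst_idx_t n_samples;
  bst_feature_t n_features;
  bst_idx_t nnz;

  template <typename F, std::enable_if_t<std::is_integral_v<F>>* = nullptr>
  SketchShape(bst_idx_t n_samples, F n_features, bst_idx_t nnz)
      : n_samples{n_samples}, n_features{static_cast<bst_feature_t>(n_features)}, nnz{nnz} {}

  [[nodiscard]] bst_idx_t Size() const { return n_samples * n_features; }
};

int main() {
  // A size_t feature count narrows cleanly through the constructor template.
  SketchShape shape(1024, std::size_t{16}, 8192);
  return shape.Size() == 16384 && shape.nnz < shape.Size() ? 0 : 1;
}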
@@ -249,6 +263,8 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
                                 dh::device_vector<Entry>* p_sorted_entries,
                                 dh::device_vector<float>* p_sorted_weights,
                                 dh::caching_device_vector<size_t>* p_column_sizes_scan);
+
+constexpr bst_idx_t UnknownSketchNumElements() { return 0; }
 }  // namespace detail
 
 /**
@@ -264,7 +280,7 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
  */
 HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
                                       Span<float const> hessian,
-                                      std::size_t sketch_batch_num_elements = 0);
+                                      std::size_t sketch_batch_num_elements = detail::UnknownSketchNumElements());
 
 /**
  * @brief Compute sketch on DMatrix with GPU.
@@ -276,14 +292,15 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
  *
  * @return Quantile cuts
  */
-inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
-                                  std::size_t sketch_batch_num_elements = 0) {
+inline HistogramCuts DeviceSketch(
+    Context const* ctx, DMatrix* p_fmat, bst_bin_t max_bin,
+    std::size_t sketch_batch_num_elements = detail::UnknownSketchNumElements()) {
   return DeviceSketchWithHessian(ctx, p_fmat, max_bin, {}, sketch_batch_num_elements);
 }
 
 template <typename AdapterBatch>
 void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info,
-                          size_t columns, size_t begin, size_t end, float missing,
+                          size_t n_features, size_t begin, size_t end, float missing,
                           SketchContainer* sketch_container, int num_cuts) {
   // Copy current subset of valid elements into temporary storage and sort
   dh::device_vector<Entry> sorted_entries;
@@ -294,8 +311,9 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
   HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
   cuts_ptr.SetDevice(ctx->Device());
   CUDAContext const* cuctx = ctx->CUDACtx();
-  detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, num_cuts,
-                                 ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries);
+  detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, n_features,
                                  num_cuts, ctx->Device(), &cuts_ptr, &column_sizes_scan,
+                                 &sorted_entries);
   thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp());
 
   if (sketch_container->HasCategorical()) {
@@ -305,10 +323,11 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
   }
 
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
-  auto const &h_cuts_ptr = cuts_ptr.HostVector();
+  auto const& h_cuts_ptr = cuts_ptr.HostVector();
   // Extract the cuts from all columns concurrently
   sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
                          h_cuts_ptr.back());
 
   sorted_entries.clear();
   sorted_entries.shrink_to_fit();
 }
@@ -316,10 +335,10 @@ void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInf
 template <typename Batch>
 void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info,
                                   int num_cuts_per_feature, bool is_ranking, float missing,
-                                  DeviceOrd device, size_t columns, size_t begin, size_t end,
+                                  size_t columns, size_t begin, size_t end,
                                   SketchContainer* sketch_container) {
-  dh::safe_cuda(cudaSetDevice(device.ordinal));
-  info.weights_.SetDevice(device);
+  SetDevice(ctx->Ordinal());
+  info.weights_.SetDevice(ctx->Device());
   auto weights = info.weights_.ConstDeviceSpan();
 
   auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
@@ -330,7 +349,7 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
   dh::caching_device_vector<size_t> column_sizes_scan;
   HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
   detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns,
-                                 num_cuts_per_feature, device, &cuts_ptr, &column_sizes_scan,
+                                 num_cuts_per_feature, ctx->Device(), &cuts_ptr, &column_sizes_scan,
                                  &sorted_entries);
   data::IsValidFunctor is_valid(missing);
 
@@ -388,48 +407,59 @@ void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo cons
   sorted_entries.shrink_to_fit();
 }
 
-/*
- * \brief Perform sketching on GPU.
+/**
+ * @brief Perform sketching on GPU.
  *
- * \param batch A batch from adapter.
- * \param num_bins Bins per column.
- * \param info Metainfo used for sketching.
- * \param missing Floating point value that represents invalid value.
- * \param sketch_container Container for output sketch.
- * \param sketch_batch_num_elements Number of element per-sliding window, use it only for
+ * @param batch A batch from adapter.
+ * @param num_bins Bins per column.
+ * @param info Metainfo used for sketching.
+ * @param missing Floating point value that represents invalid value.
+ * @param sketch_container Container for output sketch.
+ * @param sketch_batch_num_elements Number of elements per sliding window, use it only for
  *   testing.
 */
 template <typename Batch>
-void AdapterDeviceSketch(Context const* ctx, Batch batch, int num_bins, MetaInfo const& info,
+void AdapterDeviceSketch(Context const* ctx, Batch batch, bst_bin_t num_bins, MetaInfo const& info,
                          float missing, SketchContainer* sketch_container,
-                         size_t sketch_batch_num_elements = 0) {
-  size_t num_rows = batch.NumRows();
+                         bst_idx_t sketch_batch_num_elements = detail::UnknownSketchNumElements()) {
+  bst_idx_t num_rows = batch.NumRows();
   size_t num_cols = batch.NumCols();
-  size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
-  auto device = sketch_container->DeviceIdx();
   bool weighted = !info.weights_.Empty();
 
-  if (weighted) {
+  bst_idx_t const kRemaining = batch.Size();
+  bst_idx_t begin = 0;
+
+  auto shape = detail::SketchShape{num_rows, num_cols, std::numeric_limits<bst_idx_t>::max()};
+
+  while (begin < kRemaining) {
+    // Use total number of samples to estimate the needed cuts first, this doesn't hurt
+    // accuracy as total number of samples is larger.
+    auto num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
+    // Estimate the memory usage based on the current available memory.
     sketch_batch_num_elements = detail::SketchBatchNumElements(
-        sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
-        device.ordinal, num_cuts_per_feature, true);
-    for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
-      size_t end =
-          std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
+        sketch_batch_num_elements, shape, ctx->Ordinal(), num_cuts_per_feature, weighted,
+        sketch_container->MemCostBytes());
+    // Re-estimate the needed number of cuts based on the size of the sub-batch.
+    //
+    // The estimation of `sketch_batch_num_elements` assumes dense input, so the
+    // approximation here is reasonably accurate. It doesn't hurt accuracy since the
+    // estimated n_samples must be greater or equal to the actual n_samples thanks to the
+    // dense assumption.
+    auto approx_n_samples = std::max(sketch_batch_num_elements / num_cols, bst_idx_t{1});
+    num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, approx_n_samples);
+    bst_idx_t end =
+        std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
+
+    if (weighted) {
       ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature,
-                                   HostSketchContainer::UseGroup(info), missing, device, num_cols,
-                                   begin, end, sketch_container);
-    }
-  } else {
-    sketch_batch_num_elements = detail::SketchBatchNumElements(
-        sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
-        device.ordinal, num_cuts_per_feature, false);
-    for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
-      size_t end =
-          std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
+                                   HostSketchContainer::UseGroup(info), missing, num_cols, begin,
+                                   end, sketch_container);
+    } else {
       ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container,
                            num_cuts_per_feature);
     }
+    begin += sketch_batch_num_elements;
   }
 }
 }  // namespace xgboost::common
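The rewritten `AdapterDeviceSketch` above implements the commit's second bullet: the two per-branch loops collapse into one `while` loop that re-estimates before every sub-batch, feeding the estimator the sketch's live footprint via `MemCostBytes()` and then shrinking `num_cuts_per_feature` to the rows one window can actually hold. The shrink relies on the dense assumption: a budget of `sketch_batch_num_elements` covers at most that many values, so dividing by the column count bounds the sample count from above. A small illustration of the per-iteration arithmetic with assumed figures:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  using bst_idx_t = std::uint64_t;
  // Assumed figures: a 100-feature stream with a 2e8-element budget.
  bst_idx_t sketch_batch_num_elements = 200000000;
  bst_idx_t num_cols = 100;
  // Dense assumption: one window holds at most budget / n_features rows, so
  // the cut estimate can be recomputed from 2e6 rows instead of the full
  // stream, and the per-window sketch buffers shrink accordingly.
  auto approx_n_samples = std::max(sketch_batch_num_elements / num_cols, bst_idx_t{1});
  std::printf("rows per sub-batch (upper bound): %llu\n",
              static_cast<unsigned long long>(approx_n_samples));
}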
@@ -309,7 +309,7 @@ void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
 
 void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
                            common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) {
-  common::SetDevice(device_.ordinal);
+  common::SetDevice(ctx->Ordinal());
   Span<SketchEntry> out;
   dh::device_vector<SketchEntry> cuts;
   bool first_window = this->Current().empty();
@@ -354,7 +354,7 @@ void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<s
     this->FixError();
   } else {
     this->Current().resize(n_uniques);
-    this->columns_ptr_.SetDevice(device_);
+    this->columns_ptr_.SetDevice(ctx->Device());
     this->columns_ptr_.Resize(cuts_ptr.size());
 
     auto d_cuts_ptr = this->columns_ptr_.DeviceSpan();
@@ -369,7 +369,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
    * pruning or merging. We preserve the first type and remove the second type.
    */
   timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_.ordinal));
+  SetDevice(ctx->Ordinal());
   CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
 
   auto key_it = dh::MakeTransformIterator<size_t>(
@@ -408,7 +408,7 @@ size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
 
 void SketchContainer::Prune(Context const* ctx, std::size_t to) {
   timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_.ordinal));
+  SetDevice(ctx->Ordinal());
 
   OffsetT to_total = 0;
   auto& h_columns_ptr = columns_ptr_b_.HostVector();
@@ -443,7 +443,12 @@ void SketchContainer::Prune(Context const* ctx, std::size_t to) {
 
 void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_columns_ptr,
                             Span<SketchEntry const> that) {
-  common::SetDevice(device_.ordinal);
+  SetDevice(ctx->Ordinal());
+  auto self = dh::ToSpan(this->Current());
+  LOG(DEBUG) << "Merge: self:" << HumanMemUnit(self.size_bytes()) << ". "
+             << "That:" << HumanMemUnit(that.size_bytes()) << ". "
+             << "This capacity:" << HumanMemUnit(this->MemCapacityBytes()) << "." << std::endl;
+
   timer_.Start(__func__);
   if (this->Current().size() == 0) {
     CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
@@ -478,7 +483,6 @@ void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_colum
 }
 
 void SketchContainer::FixError() {
-  dh::safe_cuda(cudaSetDevice(device_.ordinal));
   auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
   auto in = dh::ToSpan(this->Current());
   dh::LaunchN(in.size(), [=] __device__(size_t idx) {
@@ -503,7 +507,7 @@ void SketchContainer::FixError() {
 }
 
 void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
-  dh::safe_cuda(cudaSetDevice(device_.ordinal));
+  SetDevice(ctx->Ordinal());
   auto world = collective::GetWorldSize();
   if (world == 1 || is_column_split) {
     return;
@@ -541,7 +545,7 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
   std::vector<std::int64_t> recv_lengths;
   HostDeviceVector<std::int8_t> recvbuf;
   rc = collective::AllgatherV(
-      ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_),
+      ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), ctx->Device()),
       &recv_lengths, &recvbuf);
   collective::SafeColl(rc);
   for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) {
@@ -563,9 +567,8 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
   }
 
   // Merge them into a new sketch.
-  SketchContainer new_sketch(this->feature_types_, num_bins_,
-                             this->num_columns_, global_sum_rows,
-                             this->device_);
+  SketchContainer new_sketch(this->feature_types_, num_bins_, this->num_columns_, global_sum_rows,
+                             ctx->Device());
   for (size_t i = 0; i < allworkers.size(); ++i) {
     auto worker = allworkers[i];
     auto worker_ptr =
@@ -593,7 +596,7 @@ struct InvalidCatOp {
 
 void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) {
   timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_.ordinal));
+  SetDevice(ctx->Ordinal());
   p_cuts->min_vals_.Resize(num_columns_);
 
   // Sync between workers.
@@ -606,12 +609,12 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
   // Set up inputs
   auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
 
-  p_cuts->min_vals_.SetDevice(device_);
+  p_cuts->min_vals_.SetDevice(ctx->Device());
   auto d_min_values = p_cuts->min_vals_.DeviceSpan();
   auto const in_cut_values = dh::ToSpan(this->Current());
 
   // Set up output ptr
-  p_cuts->cut_ptrs_.SetDevice(device_);
+  p_cuts->cut_ptrs_.SetDevice(ctx->Device());
   auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector();
   h_out_columns_ptr.clear();
   h_out_columns_ptr.push_back(0);
@@ -689,7 +692,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
   auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
 
   size_t total_bins = h_out_columns_ptr.back();
-  p_cuts->cut_values_.SetDevice(device_);
+  p_cuts->cut_values_.SetDevice(ctx->Device());
   p_cuts->cut_values_.Resize(total_bins);
   auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
 
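The `Merge` debug log above routes sizes through `HumanMemUnit` (imported from `common.h` in the first hunk) so the output reads in scaled units rather than raw byte counts. A generic sketch of such a formatter; `HumanMemUnitSketch` is a hypothetical stand-in, and the real helper's rounding and unit breaks may differ:

#include <cstddef>
#include <cstdio>
#include <string>

std::string HumanMemUnitSketch(std::size_t n_bytes) {
  char buf[32];
  double n = static_cast<double>(n_bytes);
  if (n >= (1ull << 30)) {
    std::snprintf(buf, sizeof(buf), "%.2fGB", n / (1ull << 30));
  } else if (n >= (1ull << 20)) {
    std::snprintf(buf, sizeof(buf), "%.2fMB", n / (1ull << 20));
  } else if (n >= (1ull << 10)) {
    std::snprintf(buf, sizeof(buf), "%.2fKB", n / (1ull << 10));
  } else {
    std::snprintf(buf, sizeof(buf), "%zuB", n_bytes);
  }
  return buf;
}

int main() { std::printf("%s\n", HumanMemUnitSketch(3u << 20).c_str()); }  // prints 3.00MB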
@@ -8,6 +8,7 @@
 
 #include "categorical.h"
 #include "cuda_context.cuh"  // for CUDAContext
+#include "cuda_rt_utils.h"   // for SetDevice
 #include "device_helpers.cuh"
 #include "error_msg.h"  // for InvalidMaxBin
 #include "quantile.h"
@@ -15,9 +16,7 @@
 #include "xgboost/data.h"
 #include "xgboost/span.h"
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 
 class HistogramCuts;
 using WQSketch = WQuantileSketch<bst_float, bst_float>;
 using SketchEntry = WQSketch::Entry;
@@ -46,7 +45,6 @@ class SketchContainer {
   bst_idx_t num_rows_;
   bst_feature_t num_columns_;
   int32_t num_bins_;
-  DeviceOrd device_;
 
   // Double buffer as neither prune nor merge can be performed inplace.
   dh::device_vector<SketchEntry> entries_a_;
@@ -100,12 +98,12 @@ class SketchContainer {
   */
  SketchContainer(HostDeviceVector<FeatureType> const& feature_types, bst_bin_t max_bin,
                  bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device)
-      : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
+      : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin} {
    CHECK(device.IsCUDA());
    // Initialize Sketches for this dmatrix
-    this->columns_ptr_.SetDevice(device_);
+    this->columns_ptr_.SetDevice(device);
    this->columns_ptr_.Resize(num_columns + 1, 0);
-    this->columns_ptr_b_.SetDevice(device_);
+    this->columns_ptr_b_.SetDevice(device);
    this->columns_ptr_b_.Resize(num_columns + 1, 0);
 
    this->feature_types_.Resize(feature_types.Size());
@@ -123,8 +121,25 @@ class SketchContainer {
 
    timer_.Init(__func__);
  }
-  /* \brief Return GPU ID for this container. */
-  [[nodiscard]] DeviceOrd DeviceIdx() const { return device_; }
+  /**
+   * @brief Calculate the memory cost of the container.
+   */
+  [[nodiscard]] std::size_t MemCapacityBytes() const {
+    auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type);
+    auto n_bytes = (this->entries_a_.capacity() + this->entries_b_.capacity()) * kE;
+    n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT);
+    n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
+
+    return n_bytes;
+  }
+  [[nodiscard]] std::size_t MemCostBytes() const {
+    auto constexpr kE = sizeof(typename decltype(this->entries_a_)::value_type);
+    auto n_bytes = (this->entries_a_.size() + this->entries_b_.size()) * kE;
+    n_bytes += (this->columns_ptr_.Size() + this->columns_ptr_b_.Size()) * sizeof(OffsetT);
+    n_bytes += this->feature_types_.Size() * sizeof(FeatureType);
+
+    return n_bytes;
+  }
   /* \brief Whether the predictor matrix contains categorical features. */
   bool HasCategorical() const { return has_categorical_; }
   /* \brief Accumulate weights of duplicated entries in input. */
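`MemCapacityBytes` and `MemCostBytes` above differ only in calling `capacity()` versus `size()` on the double buffers: capacity is what the device allocator is actually holding (what the debug logs report), while cost is the live payload the estimator subtracts from device memory. The distinction matters because the buffers shrink logically without releasing storage; a quick host-side reminder, assuming `dh::device_vector` follows the same size/capacity semantics as `std::vector`:

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> buf;
  for (int i = 0; i < 1000; ++i) {
    buf.push_back(1.0f);
  }
  buf.resize(10);  // size shrinks; the reserved allocation does not
  std::printf("size: %zu, capacity: %zu\n", buf.size(), buf.capacity());
  // A container's true memory footprint follows capacity; its live data, size.
}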
@@ -166,6 +181,7 @@ class SketchContainer {
     this->Current().shrink_to_fit();
     this->Other().clear();
     this->Other().shrink_to_fit();
+    LOG(DEBUG) << "Quantile memory cost:" << this->MemCapacityBytes();
   }
 
   /* \brief Merge quantiles from other GPU workers. */
@@ -190,13 +206,13 @@ class SketchContainer {
   template <typename KeyComp = thrust::equal_to<size_t>>
   size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to<size_t>{}) {
     timer_.Start(__func__);
-    dh::safe_cuda(cudaSetDevice(device_.ordinal));
-    this->columns_ptr_.SetDevice(device_);
+    SetDevice(ctx->Ordinal());
+    this->columns_ptr_.SetDevice(ctx->Device());
     Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
     CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
     Span<SketchEntry> entries = dh::ToSpan(this->Current());
     HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
-    scan_out.SetDevice(device_);
+    scan_out.SetDevice(ctx->Device());
     auto d_scan_out = scan_out.DeviceSpan();
 
     d_column_scan = this->columns_ptr_.DeviceSpan();
@@ -212,7 +228,6 @@ class SketchContainer {
     return n_uniques;
   }
 };
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
 
 #endif  // XGBOOST_COMMON_QUANTILE_CUH_
@@ -65,7 +65,9 @@ TEST(HistUtil, SketchBatchNumElements) {
   auto per_elem = detail::BytesPerElement(false);
   auto avail_elem = avail / per_elem;
   size_t rows = avail_elem / kCols * 10;
-  auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false);
+  auto shape = detail::SketchShape{rows, kCols, rows * kCols};
+  auto batch = detail::SketchBatchNumElements(detail::UnknownSketchNumElements(), shape, device,
+                                              256, false, 0);
   ASSERT_EQ(batch, avail_elem);
 }