From e471056ec443f3c040abe574f7841e508cf29729 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 17 Jul 2020 08:33:16 +0800 Subject: [PATCH] Fix sketch size calculation. (#5898) --- src/common/hist_util.cu | 13 +++++++++---- src/common/hist_util.cuh | 3 ++- tests/cpp/common/test_hist_util.cu | 12 ++++++++++++ 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index ba7207ad2..6953a556b 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -162,13 +162,18 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, } size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_row_t num_rows, size_t columns, size_t nnz, int device, + bst_row_t num_rows, bst_feature_t columns, + size_t nnz, int device, size_t num_cuts, bool has_weight) { if (sketch_batch_num_elements == 0) { auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight); // use up to 80% of available space - sketch_batch_num_elements = (dh::AvailableMemory(device) - - required_memory * 0.8); + auto avail = dh::AvailableMemory(device) * 0.8; + if (required_memory > avail) { + sketch_batch_num_elements = avail / BytesPerElement(has_weight); + } else { + sketch_batch_num_elements = std::min(num_rows * static_cast(columns), nnz); + } } return sketch_batch_num_elements; } @@ -196,7 +201,7 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, size_t num_columns) { dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); - dh::caching_device_vector sorted_entries(host_data.begin() + begin, + dh::device_vector sorted_entries(host_data.begin() + begin, host_data.begin() + end); thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index c8f0e3f7d..8dca9fdb9 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -100,7 +100,8 @@ inline size_t constexpr BytesPerElement(bool has_weight) { * directly if it's not 0. */ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_row_t num_rows, size_t columns, size_t nnz, int device, + bst_row_t num_rows, bst_feature_t columns, + size_t nnz, int device, size_t num_cuts, bool has_weight); // Compute number of sample cuts needed on local node to maintain accuracy diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 433e91679..83e9595f7 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -48,6 +48,18 @@ TEST(HistUtil, DeviceSketch) { EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); } +TEST(HistUtil, SketchBatchNumElements) { + size_t constexpr kCols = 10000; + int device; + dh::safe_cuda(cudaGetDevice(&device)); + auto avail = static_cast(dh::AvailableMemory(device) * 0.8); + auto per_elem = detail::BytesPerElement(false); + auto avail_elem = avail / per_elem; + size_t rows = avail_elem / kCols * 10; + auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false); + ASSERT_EQ(batch, avail_elem); +} + TEST(HistUtil, DeviceSketchMemory) { int num_columns = 100; int num_rows = 1000;