Fix sketch size calculation. (#5898)

This commit is contained in:
Jiaming Yuan 2020-07-17 08:33:16 +08:00 committed by GitHub
parent 730866a7bc
commit e471056ec4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 5 deletions

View File

@ -162,13 +162,18 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
}
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
bst_row_t num_rows, size_t columns, size_t nnz, int device,
bst_row_t num_rows, bst_feature_t columns,
size_t nnz, int device,
size_t num_cuts, bool has_weight) {
if (sketch_batch_num_elements == 0) {
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
// use up to 80% of available space
sketch_batch_num_elements = (dh::AvailableMemory(device) -
required_memory * 0.8);
auto avail = dh::AvailableMemory(device) * 0.8;
if (required_memory > avail) {
sketch_batch_num_elements = avail / BytesPerElement(has_weight);
} else {
sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
}
}
return sketch_batch_num_elements;
}
@ -196,7 +201,7 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

View File

@ -100,7 +100,8 @@ inline size_t constexpr BytesPerElement(bool has_weight) {
* directly if it's not 0.
*/
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
bst_row_t num_rows, size_t columns, size_t nnz, int device,
bst_row_t num_rows, bst_feature_t columns,
size_t nnz, int device,
size_t num_cuts, bool has_weight);
// Compute number of sample cuts needed on local node to maintain accuracy

View File

@ -48,6 +48,18 @@ TEST(HistUtil, DeviceSketch) {
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
}
TEST(HistUtil, SketchBatchNumElements) {
size_t constexpr kCols = 10000;
int device;
dh::safe_cuda(cudaGetDevice(&device));
auto avail = static_cast<size_t>(dh::AvailableMemory(device) * 0.8);
auto per_elem = detail::BytesPerElement(false);
auto avail_elem = avail / per_elem;
size_t rows = avail_elem / kCols * 10;
auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false);
ASSERT_EQ(batch, avail_elem);
}
TEST(HistUtil, DeviceSketchMemory) {
int num_columns = 100;
int num_rows = 1000;