Fix sketch size calculation. (#5898)
This commit is contained in:
parent
730866a7bc
commit
e471056ec4
@ -162,13 +162,18 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
}
|
||||
|
||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||
bst_row_t num_rows, size_t columns, size_t nnz, int device,
|
||||
bst_row_t num_rows, bst_feature_t columns,
|
||||
size_t nnz, int device,
|
||||
size_t num_cuts, bool has_weight) {
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements = (dh::AvailableMemory(device) -
|
||||
required_memory * 0.8);
|
||||
auto avail = dh::AvailableMemory(device) * 0.8;
|
||||
if (required_memory > avail) {
|
||||
sketch_batch_num_elements = avail / BytesPerElement(has_weight);
|
||||
} else {
|
||||
sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
|
||||
}
|
||||
}
|
||||
return sketch_batch_num_elements;
|
||||
}
|
||||
@ -196,7 +201,7 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
@ -100,7 +100,8 @@ inline size_t constexpr BytesPerElement(bool has_weight) {
|
||||
* directly if it's not 0.
|
||||
*/
|
||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||
bst_row_t num_rows, size_t columns, size_t nnz, int device,
|
||||
bst_row_t num_rows, bst_feature_t columns,
|
||||
size_t nnz, int device,
|
||||
size_t num_cuts, bool has_weight);
|
||||
|
||||
// Compute number of sample cuts needed on local node to maintain accuracy
|
||||
|
||||
@ -48,6 +48,18 @@ TEST(HistUtil, DeviceSketch) {
|
||||
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
|
||||
}
|
||||
|
||||
TEST(HistUtil, SketchBatchNumElements) {
|
||||
size_t constexpr kCols = 10000;
|
||||
int device;
|
||||
dh::safe_cuda(cudaGetDevice(&device));
|
||||
auto avail = static_cast<size_t>(dh::AvailableMemory(device) * 0.8);
|
||||
auto per_elem = detail::BytesPerElement(false);
|
||||
auto avail_elem = avail / per_elem;
|
||||
size_t rows = avail_elem / kCols * 10;
|
||||
auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false);
|
||||
ASSERT_EQ(batch, avail_elem);
|
||||
}
|
||||
|
||||
TEST(HistUtil, DeviceSketchMemory) {
|
||||
int num_columns = 100;
|
||||
int num_rows = 1000;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user