Fix sketch size calculation. (#5898)
This commit is contained in:
parent
730866a7bc
commit
e471056ec4
@ -162,13 +162,18 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||||
bst_row_t num_rows, size_t columns, size_t nnz, int device,
|
bst_row_t num_rows, bst_feature_t columns,
|
||||||
|
size_t nnz, int device,
|
||||||
size_t num_cuts, bool has_weight) {
|
size_t num_cuts, bool has_weight) {
|
||||||
if (sketch_batch_num_elements == 0) {
|
if (sketch_batch_num_elements == 0) {
|
||||||
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
|
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
|
||||||
// use up to 80% of available space
|
// use up to 80% of available space
|
||||||
sketch_batch_num_elements = (dh::AvailableMemory(device) -
|
auto avail = dh::AvailableMemory(device) * 0.8;
|
||||||
required_memory * 0.8);
|
if (required_memory > avail) {
|
||||||
|
sketch_batch_num_elements = avail / BytesPerElement(has_weight);
|
||||||
|
} else {
|
||||||
|
sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return sketch_batch_num_elements;
|
return sketch_batch_num_elements;
|
||||||
}
|
}
|
||||||
@ -196,7 +201,7 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
|
|||||||
size_t num_columns) {
|
size_t num_columns) {
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
const auto& host_data = page.data.ConstHostVector();
|
const auto& host_data = page.data.ConstHostVector();
|
||||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||||
host_data.begin() + end);
|
host_data.begin() + end);
|
||||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||||
sorted_entries.end(), detail::EntryCompareOp());
|
sorted_entries.end(), detail::EntryCompareOp());
|
||||||
|
|||||||
@ -100,7 +100,8 @@ inline size_t constexpr BytesPerElement(bool has_weight) {
|
|||||||
* directly if it's not 0.
|
* directly if it's not 0.
|
||||||
*/
|
*/
|
||||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||||
bst_row_t num_rows, size_t columns, size_t nnz, int device,
|
bst_row_t num_rows, bst_feature_t columns,
|
||||||
|
size_t nnz, int device,
|
||||||
size_t num_cuts, bool has_weight);
|
size_t num_cuts, bool has_weight);
|
||||||
|
|
||||||
// Compute number of sample cuts needed on local node to maintain accuracy
|
// Compute number of sample cuts needed on local node to maintain accuracy
|
||||||
|
|||||||
@ -48,6 +48,18 @@ TEST(HistUtil, DeviceSketch) {
|
|||||||
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
|
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(HistUtil, SketchBatchNumElements) {
|
||||||
|
size_t constexpr kCols = 10000;
|
||||||
|
int device;
|
||||||
|
dh::safe_cuda(cudaGetDevice(&device));
|
||||||
|
auto avail = static_cast<size_t>(dh::AvailableMemory(device) * 0.8);
|
||||||
|
auto per_elem = detail::BytesPerElement(false);
|
||||||
|
auto avail_elem = avail / per_elem;
|
||||||
|
size_t rows = avail_elem / kCols * 10;
|
||||||
|
auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false);
|
||||||
|
ASSERT_EQ(batch, avail_elem);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(HistUtil, DeviceSketchMemory) {
|
TEST(HistUtil, DeviceSketchMemory) {
|
||||||
int num_columns = 100;
|
int num_columns = 100;
|
||||||
int num_rows = 1000;
|
int num_rows = 1000;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user