Fix memory usage of device sketching (#5407)
This commit is contained in:
@@ -97,6 +97,19 @@ struct EntryCompareOp {
|
||||
}
|
||||
};
|
||||
|
||||
// Compute number of sample cuts needed on local node to maintain accuracy
|
||||
// We take more cuts than needed and then reduce them later
|
||||
size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
num_rows, eps, &dummy_nlevel, &num_cuts);
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void GetColumnSizesScan(int device,
|
||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||
@@ -210,7 +223,7 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), EntryCompareOp());
|
||||
@@ -237,11 +250,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
|
||||
// Binary search to assign weights to each element
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = temp_weights.data().get();
|
||||
page.offset.SetDevice(device);
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
@@ -288,28 +301,29 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements) {
|
||||
// Configure batch size based on available memory
|
||||
bool has_weights = dmat->Info().weights_.Size() > 0;
|
||||
size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = has_weights ? 24 : 16;
|
||||
size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements =
|
||||
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
|
||||
}
|
||||
|
||||
HistogramCuts cuts;
|
||||
DenseCuts dense_cuts(&cuts);
|
||||
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
|
||||
dmat->Info().num_row_);
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts);
|
||||
num_cuts = std::min(num_cuts, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
sketch_batch_num_elements = dmat->Info().num_nonzero_;
|
||||
}
|
||||
dmat->Info().weights_.SetDevice(device);
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
size_t batch_nnz = batch.data.Size();
|
||||
for (auto begin = 0ull; begin < batch_nnz;
|
||||
begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
|
||||
if (dmat->Info().weights_.Size() > 0) {
|
||||
if (has_weights) {
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
||||
&sketch_container, num_cuts, dmat->Info().num_col_);
|
||||
@@ -369,6 +383,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
// Work out how many valid entries we have in each column
|
||||
dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
|
||||
0);
|
||||
|
||||
auto d_column_sizes_scan = column_sizes_scan.data().get();
|
||||
IsValidFunctor is_valid(missing);
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
@@ -385,7 +400,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
size_t num_valid = host_column_sizes_scan.back();
|
||||
|
||||
// Copy current subset of valid elements into temporary storage and sort
|
||||
thrust::device_vector<Entry> sorted_entries(num_valid);
|
||||
dh::caching_device_vector<Entry> sorted_entries(num_valid);
|
||||
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
|
||||
entry_iter + end, sorted_entries.begin(), is_valid);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
@@ -406,6 +421,17 @@ template <typename AdapterT>
|
||||
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
float missing,
|
||||
size_t sketch_batch_num_elements) {
|
||||
size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows());
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = 16;
|
||||
size_t bytes_cuts = num_cuts * adapter->NumColumns() * sizeof(SketchEntry);
|
||||
size_t bytes_num_columns = (adapter->NumColumns() + 1) * sizeof(size_t);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements = (dh::AvailableMemory(adapter->DeviceIdx()) -
|
||||
bytes_cuts - bytes_num_columns) *
|
||||
0.8 / bytes_per_element;
|
||||
}
|
||||
|
||||
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
|
||||
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
|
||||
|
||||
@@ -421,16 +447,6 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
|
||||
adapter->NumRows());
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * num_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
adapter->NumRows(), eps, &dummy_nlevel, &num_cuts);
|
||||
num_cuts = std::min(num_cuts, adapter->NumRows());
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
sketch_batch_num_elements = batch.Size();
|
||||
}
|
||||
for (auto begin = 0ull; begin < batch.Size();
|
||||
begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
|
||||
|
||||
Reference in New Issue
Block a user