Fix memory usage of device sketching (#5407)
This commit is contained in:
@@ -378,6 +378,11 @@ public:
|
||||
{
|
||||
return stats_.peak_allocated_bytes;
|
||||
}
|
||||
void Clear()
|
||||
{
|
||||
stats_ = DeviceStats();
|
||||
}
|
||||
|
||||
void Log() {
|
||||
if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
|
||||
return;
|
||||
@@ -475,7 +480,8 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
template <typename T>
|
||||
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
|
||||
/*! Be careful that the initialization constructor is a no-op, which means calling
|
||||
* `vec.resize(n, 1)` won't initialize the memory region to 1. */
|
||||
* `vec.resize(n)` won't initialize the memory region to 0. Instead use
|
||||
* `vec.resize(n, 0)`*/
|
||||
template <typename T>
|
||||
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
|
||||
/** \brief Specialisation of thrust device vector using custom allocator. */
|
||||
|
||||
@@ -97,6 +97,19 @@ struct EntryCompareOp {
|
||||
}
|
||||
};
|
||||
|
||||
// Compute number of sample cuts needed on local node to maintain accuracy
|
||||
// We take more cuts than needed and then reduce them later
|
||||
size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
num_rows, eps, &dummy_nlevel, &num_cuts);
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void GetColumnSizesScan(int device,
|
||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||
@@ -210,7 +223,7 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), EntryCompareOp());
|
||||
@@ -237,11 +250,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
|
||||
// Binary search to assign weights to each element
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = temp_weights.data().get();
|
||||
page.offset.SetDevice(device);
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
@@ -288,28 +301,29 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements) {
|
||||
// Configure batch size based on available memory
|
||||
bool has_weights = dmat->Info().weights_.Size() > 0;
|
||||
size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = has_weights ? 24 : 16;
|
||||
size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements =
|
||||
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
|
||||
}
|
||||
|
||||
HistogramCuts cuts;
|
||||
DenseCuts dense_cuts(&cuts);
|
||||
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
|
||||
dmat->Info().num_row_);
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
dmat->Info().num_row_, eps, &dummy_nlevel, &num_cuts);
|
||||
num_cuts = std::min(num_cuts, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
sketch_batch_num_elements = dmat->Info().num_nonzero_;
|
||||
}
|
||||
dmat->Info().weights_.SetDevice(device);
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
size_t batch_nnz = batch.data.Size();
|
||||
for (auto begin = 0ull; begin < batch_nnz;
|
||||
begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
|
||||
if (dmat->Info().weights_.Size() > 0) {
|
||||
if (has_weights) {
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
||||
&sketch_container, num_cuts, dmat->Info().num_col_);
|
||||
@@ -369,6 +383,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
// Work out how many valid entries we have in each column
|
||||
dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
|
||||
0);
|
||||
|
||||
auto d_column_sizes_scan = column_sizes_scan.data().get();
|
||||
IsValidFunctor is_valid(missing);
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
@@ -385,7 +400,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
size_t num_valid = host_column_sizes_scan.back();
|
||||
|
||||
// Copy current subset of valid elements into temporary storage and sort
|
||||
thrust::device_vector<Entry> sorted_entries(num_valid);
|
||||
dh::caching_device_vector<Entry> sorted_entries(num_valid);
|
||||
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
|
||||
entry_iter + end, sorted_entries.begin(), is_valid);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
@@ -406,6 +421,17 @@ template <typename AdapterT>
|
||||
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
float missing,
|
||||
size_t sketch_batch_num_elements) {
|
||||
size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows());
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = 16;
|
||||
size_t bytes_cuts = num_cuts * adapter->NumColumns() * sizeof(SketchEntry);
|
||||
size_t bytes_num_columns = (adapter->NumColumns() + 1) * sizeof(size_t);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements = (dh::AvailableMemory(adapter->DeviceIdx()) -
|
||||
bytes_cuts - bytes_num_columns) *
|
||||
0.8 / bytes_per_element;
|
||||
}
|
||||
|
||||
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
|
||||
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
|
||||
|
||||
@@ -421,16 +447,6 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
|
||||
adapter->NumRows());
|
||||
|
||||
constexpr int kFactor = 8;
|
||||
double eps = 1.0 / (kFactor * num_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
adapter->NumRows(), eps, &dummy_nlevel, &num_cuts);
|
||||
num_cuts = std::min(num_cuts, adapter->NumRows());
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
sketch_batch_num_elements = batch.Size();
|
||||
}
|
||||
for (auto begin = 0ull; begin < batch.Size();
|
||||
begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
|
||||
|
||||
@@ -199,14 +199,15 @@ class DenseCuts : public CutsBuilder {
|
||||
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
|
||||
};
|
||||
|
||||
|
||||
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements = 10000000);
|
||||
size_t sketch_batch_num_elements = 0);
|
||||
|
||||
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
|
||||
template <typename AdapterT>
|
||||
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
|
||||
float missing,
|
||||
size_t sketch_batch_num_elements = 10000000);
|
||||
size_t sketch_batch_num_elements = 0);
|
||||
|
||||
/*!
|
||||
* \brief preprocessed global index matrix, in CSR format
|
||||
|
||||
@@ -101,8 +101,7 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
|
||||
monitor_.StartCuda("Quantiles");
|
||||
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
|
||||
row_stride = GetRowStride(dmat);
|
||||
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
|
||||
param.gpu_batch_nrows);
|
||||
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
|
||||
monitor_.StopCuda("Quantiles");
|
||||
|
||||
monitor_.StartCuda("InitCompressedData");
|
||||
|
||||
@@ -45,8 +45,7 @@ EllpackPageSource::EllpackPageSource(DMatrix* dmat,
|
||||
|
||||
monitor_.StartCuda("Quantiles");
|
||||
size_t row_stride = GetRowStride(dmat);
|
||||
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
|
||||
param.gpu_batch_nrows);
|
||||
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
|
||||
monitor_.StopCuda("Quantiles");
|
||||
|
||||
monitor_.StartCuda("WriteEllpackPages");
|
||||
|
||||
@@ -44,8 +44,6 @@ struct GPUHistMakerTrainParam
|
||||
: public XGBoostParameter<GPUHistMakerTrainParam> {
|
||||
bool single_precision_histogram;
|
||||
bool deterministic_histogram;
|
||||
// number of rows in a single GPU batch
|
||||
int gpu_batch_nrows;
|
||||
bool debug_synchronize;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
|
||||
@@ -53,11 +51,6 @@ struct GPUHistMakerTrainParam
|
||||
"Use single precision to build histograms.");
|
||||
DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
|
||||
"Pre-round the gradient for obtaining deterministic gradient histogram.");
|
||||
DMLC_DECLARE_FIELD(gpu_batch_nrows)
|
||||
.set_lower_bound(-1)
|
||||
.set_default(0)
|
||||
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
|
||||
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
|
||||
DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
|
||||
"Check if all distributed tree are identical after tree construction.");
|
||||
}
|
||||
@@ -1018,7 +1011,6 @@ class GPUHistMakerSpecialised {
|
||||
BatchParam batch_param{
|
||||
device_,
|
||||
param_.max_bin,
|
||||
hist_maker_param_.gpu_batch_nrows,
|
||||
generic_param_->gpu_page_size
|
||||
};
|
||||
auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
|
||||
|
||||
Reference in New Issue
Block a user