Implement weighted sketching for adapter. (#5760)

* Bounded memory tests.
* Fixed memory estimation.
This commit is contained in:
Jiaming Yuan
2020-06-12 06:20:39 +08:00
committed by GitHub
parent c35be9dc40
commit 3028fa6b42
7 changed files with 443 additions and 109 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2018 XGBoost contributors
* Copyright 2018~2020 XGBoost contributors
*/
#include <xgboost/logging.h>
@@ -28,24 +28,10 @@
namespace xgboost {
namespace common {
// Count the entries in each column and exclusive scan
void GetColumnSizesScan(int device,
dh::caching_device_vector<size_t>* column_sizes_scan,
Span<const Entry> entries, size_t num_columns) {
column_sizes_scan->resize(num_columns + 1, 0);
auto d_column_sizes_scan = column_sizes_scan->data().get();
auto d_entries = entries.data();
dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
auto& e = d_entries[idx];
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes_scan[e.index]),
static_cast<unsigned long long>(1)); // NOLINT
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
}
constexpr float SketchContainer::kFactor;
// Count the entries in each column and exclusive scan
void ExtractCuts(int device,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data,
@@ -158,6 +144,23 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries) {
// Sort both entries and wegihts.
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
sorted_entries->end(), weights->begin(),
EntryCompareOp());
// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
sorted_entries->begin(), sorted_entries->end(),
weights->begin(), weights->begin(),
[=] __device__(const Entry& a, const Entry& b) {
return a.index == b.index;
});
}
void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts_per_feature,
@@ -201,19 +204,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}
// Sort both entries and wegihts.
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), temp_weights.begin(),
EntryCompareOp());
// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
sorted_entries.begin(), sorted_entries.end(),
temp_weights.begin(), temp_weights.begin(),
[=] __device__(const Entry& a, const Entry& b) {
return a.index == b.index;
});
SortByWeight(&alloc, &temp_weights, &sorted_entries);
dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
@@ -239,13 +230,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
// Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
if (sketch_batch_num_elements == 0) {
int bytes_per_element = has_weights ? 24 : 16;
size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry);
// use up to 80% of available space
sketch_batch_num_elements =
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
}
sketch_batch_num_elements = SketchBatchNumElements(
sketch_batch_num_elements,
dmat->Info().num_col_, device, num_cuts_per_feature, has_weights);
HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
@@ -256,12 +243,12 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
size_t batch_nnz = batch.data.Size();
auto const& info = dmat->Info();
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
if (has_weights) {
bool is_ranking = CutsBuilder::UseGroup(dmat);
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
info.group_ptr_.cend());
ProcessWeightedBatch(
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
&sketch_container,