Implement weighted sketching for adapter. (#5760)
* Bounded memory tests. * Fixed memory estimation.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2018 XGBoost contributors
|
||||
* Copyright 2018~2020 XGBoost contributors
|
||||
*/
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
@@ -28,24 +28,10 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
// Count the entries in each column, then turn the counts into offsets with an
// exclusive scan.
//
// device            - device id forwarded to dh::LaunchN for the counting kernel.
// column_sizes_scan - output vector; resized to num_columns + 1 and overwritten
//                     with the scanned counts.
// entries           - entries to count; each entry's `index` field selects the
//                     counter that is incremented.
// num_columns       - number of columns; determines the output size.
//
// NOTE(review): resize() is given a fill value of 0, which presumably only
// initialises newly added elements -- this appears to assume the vector is
// empty on entry; confirm against callers.
|
||||
void GetColumnSizesScan(int device,
|
||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||
Span<const Entry> entries, size_t num_columns) {
|
||||
// One counter per column plus one trailing slot, so that after the exclusive
// scan the last element can hold the total number of entries.
column_sizes_scan->resize(num_columns + 1, 0);
|
||||
auto d_column_sizes_scan = column_sizes_scan->data().get();
|
||||
auto d_entries = entries.data();
|
||||
// One work item per entry: atomically increment the counter addressed by the
// entry's index.
dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
|
||||
auto& e = d_entries[idx];
|
||||
// The size_t counter is accessed through an unsigned long long pointer for
// the atomicAdd call.
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes_scan[e.index]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
// In-place exclusive scan: the counts are replaced by the running sum of all
// preceding counts.
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
|
||||
column_sizes_scan->end(), column_sizes_scan->begin());
|
||||
}
|
||||
|
||||
constexpr float SketchContainer::kFactor;
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void ExtractCuts(int device,
|
||||
size_t num_cuts_per_feature,
|
||||
Span<Entry const> sorted_data,
|
||||
@@ -158,6 +144,23 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
|
||||
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
|
||||
}
|
||||
|
||||
// Sort entries together with their weights, then replace each weight by the
// running sum of weights over consecutive entries that share the same index.
//
// alloc          - allocator passed to the thrust execution policy.
// weights        - one weight per entry; permuted alongside the entries and
//                  then overwritten in place with the per-index running sums.
// sorted_entries - entries used as the sort keys; sorted in place using
//                  EntryCompareOp (ordering defined elsewhere).
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
dh::caching_device_vector<float>* weights,
|
||||
dh::caching_device_vector<Entry>* sorted_entries) {
|
||||
// Sort both entries and weights.
|
||||
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
|
||||
sorted_entries->end(), weights->begin(),
|
||||
EntryCompareOp());
|
||||
|
||||
// Scan weights: inclusive prefix sum that restarts whenever the entry index
// changes (adjacent keys are treated as equal when their indices match).
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
|
||||
sorted_entries->begin(), sorted_entries->end(),
|
||||
weights->begin(), weights->begin(),
|
||||
[=] __device__(const Entry& a, const Entry& b) {
|
||||
return a.index == b.index;
|
||||
});
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
Span<const float> weights, size_t begin, size_t end,
|
||||
SketchContainer* sketch_container, int num_cuts_per_feature,
|
||||
@@ -201,19 +204,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||
});
|
||||
}
|
||||
|
||||
// Sort both entries and weights.
|
||||
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), temp_weights.begin(),
|
||||
EntryCompareOp());
|
||||
|
||||
// Scan weights
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
|
||||
sorted_entries.begin(), sorted_entries.end(),
|
||||
temp_weights.begin(), temp_weights.begin(),
|
||||
[=] __device__(const Entry& a, const Entry& b) {
|
||||
return a.index == b.index;
|
||||
});
|
||||
SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
||||
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
GetColumnSizesScan(device, &column_sizes_scan,
|
||||
@@ -239,13 +230,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
// Configure batch size based on available memory
|
||||
bool has_weights = dmat->Info().weights_.Size() > 0;
|
||||
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
int bytes_per_element = has_weights ? 24 : 16;
|
||||
size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements =
|
||||
(dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
|
||||
}
|
||||
sketch_batch_num_elements = SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
dmat->Info().num_col_, device, num_cuts_per_feature, has_weights);
|
||||
|
||||
HistogramCuts cuts;
|
||||
DenseCuts dense_cuts(&cuts);
|
||||
@@ -256,12 +243,12 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
size_t batch_nnz = batch.data.Size();
|
||||
auto const& info = dmat->Info();
|
||||
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
||||
info.group_ptr_.cend());
|
||||
for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
|
||||
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
|
||||
if (has_weights) {
|
||||
bool is_ranking = CutsBuilder::UseGroup(dmat);
|
||||
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
|
||||
info.group_ptr_.cend());
|
||||
ProcessWeightedBatch(
|
||||
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
|
||||
&sketch_container,
|
||||
|
||||
Reference in New Issue
Block a user