Implement weighted sketching for adapter. (#5760)

* Bounded memory tests.
* Fixed memory estimation.

parent c35be9dc40
commit 3028fa6b42
@@ -140,6 +140,10 @@ void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) {

bool CutsBuilder::UseGroup(DMatrix* dmat) {
  auto& info = dmat->Info();
  return CutsBuilder::UseGroup(info);
}

bool CutsBuilder::UseGroup(MetaInfo const& info) {
  size_t const num_groups = info.group_ptr_.size() == 0 ?
                            0 : info.group_ptr_.size() - 1;
  // Use group index for weights?
@@ -1,5 +1,5 @@
/*!
 * Copyright 2018 XGBoost contributors
 * Copyright 2018~2020 XGBoost contributors
 */

#include <xgboost/logging.h>
@@ -28,24 +28,10 @@

namespace xgboost {
namespace common {
// Count the entries in each column and exclusive scan
void GetColumnSizesScan(int device,
                        dh::caching_device_vector<size_t>* column_sizes_scan,
                        Span<const Entry> entries, size_t num_columns) {
  column_sizes_scan->resize(num_columns + 1, 0);
  auto d_column_sizes_scan = column_sizes_scan->data().get();
  auto d_entries = entries.data();
  dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
    auto& e = d_entries[idx];
    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                  &d_column_sizes_scan[e.index]),
              static_cast<unsigned long long>(1));  // NOLINT
  });
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
                         column_sizes_scan->end(), column_sizes_scan->begin());
}

constexpr float SketchContainer::kFactor;

// Count the entries in each column and exclusive scan
void ExtractCuts(int device,
                 size_t num_cuts_per_feature,
                 Span<Entry const> sorted_data,
@@ -158,6 +144,23 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}

void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
                  dh::caching_device_vector<float>* weights,
                  dh::caching_device_vector<Entry>* sorted_entries) {
  // Sort both entries and weights.
  thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
                      sorted_entries->end(), weights->begin(),
                      EntryCompareOp());

  // Scan weights
  thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
                                sorted_entries->begin(), sorted_entries->end(),
                                weights->begin(), weights->begin(),
                                [=] __device__(const Entry& a, const Entry& b) {
                                  return a.index == b.index;
                                });
}
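SortByWeight first groups the entries by feature (EntryCompareOp orders primarily on the feature index) and then replaces each weight with a per-feature running sum via inclusive_scan_by_key, which is the rank information the weighted quantile summary consumes. A minimal host-side analogue of that scan step, assuming the entries are already grouped by feature (illustrative names only, not part of the diff):

#include <cstddef>
#include <vector>

// Within each feature segment the weights become an inclusive prefix sum;
// a change of feature index restarts the sum. Host-side sketch only.
void InclusiveScanByFeature(const std::vector<std::size_t>& feature_idx,
                            std::vector<float>* weights) {
  for (std::size_t i = 1; i < weights->size(); ++i) {
    if (feature_idx[i] == feature_idx[i - 1]) {
      (*weights)[i] += (*weights)[i - 1];
    }
  }
}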
void ProcessWeightedBatch(int device, const SparsePage& page,
                          Span<const float> weights, size_t begin, size_t end,
                          SketchContainer* sketch_container, int num_cuts_per_feature,
@@ -201,19 +204,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
      d_temp_weights[idx] = weights[ridx + base_rowid];
    });
  }

  // Sort both entries and weights.
  thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
                      sorted_entries.end(), temp_weights.begin(),
                      EntryCompareOp());

  // Scan weights
  thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
                                sorted_entries.begin(), sorted_entries.end(),
                                temp_weights.begin(), temp_weights.begin(),
                                [=] __device__(const Entry& a, const Entry& b) {
                                  return a.index == b.index;
                                });
  SortByWeight(&alloc, &temp_weights, &sorted_entries);

  dh::caching_device_vector<size_t> column_sizes_scan;
  GetColumnSizesScan(device, &column_sizes_scan,
@@ -239,13 +230,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
  // Configure batch size based on available memory
  bool has_weights = dmat->Info().weights_.Size() > 0;
  size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
  if (sketch_batch_num_elements == 0) {
    int bytes_per_element = has_weights ? 24 : 16;
    size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry);
    // use up to 80% of available space
    sketch_batch_num_elements =
        (dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element;
  }
  sketch_batch_num_elements = SketchBatchNumElements(
      sketch_batch_num_elements,
      dmat->Info().num_col_, device, num_cuts_per_feature, has_weights);

  HistogramCuts cuts;
  DenseCuts dense_cuts(&cuts);
@@ -256,12 +243,12 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    size_t batch_nnz = batch.data.Size();
    auto const& info = dmat->Info();
    dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
                                               info.group_ptr_.cend());
    for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) {
      size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
      if (has_weights) {
        bool is_ranking = CutsBuilder::UseGroup(dmat);
        dh::caching_device_vector<uint32_t> groups(info.group_ptr_.cbegin(),
                                                   info.group_ptr_.cend());
        ProcessWeightedBatch(
            device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
            &sketch_container,
@@ -1,9 +1,13 @@
/*!
 * Copyright 2020 XGBoost contributors
 */
#ifndef COMMON_HIST_UTIL_CUH_
#define COMMON_HIST_UTIL_CUH_

#include <thrust/host_vector.h>

#include "hist_util.h"
#include "threading_utils.h"
#include "device_helpers.cuh"
#include "../data/device_adapter.cuh"
@@ -23,6 +27,7 @@ using SketchEntry = WQSketch::Entry;
struct SketchContainer {
  std::vector<DenseCuts::WQSketch> sketches_;  // NOLINT
  static constexpr int kOmpNumColsParallelizeLimit = 1000;
  static constexpr float kFactor = 8;

  SketchContainer(int max_bin, size_t num_columns, size_t num_rows) {
    // Initialize Sketches for this dmatrix
@@ -93,11 +98,71 @@ void ExtractCuts(int device,
                 Span<size_t const> column_sizes_scan,
                 Span<SketchEntry> out_cuts);

// Count the entries in each column and exclusive scan
inline void GetColumnSizesScan(int device,
                               dh::caching_device_vector<size_t>* column_sizes_scan,
                               Span<const Entry> entries, size_t num_columns) {
  column_sizes_scan->resize(num_columns + 1, 0);
  auto d_column_sizes_scan = column_sizes_scan->data().get();
  auto d_entries = entries.data();
  dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
    auto& e = d_entries[idx];
    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                  &d_column_sizes_scan[e.index]),
              static_cast<unsigned long long>(1));  // NOLINT
  });
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
                         column_sizes_scan->end(), column_sizes_scan->begin());
}

// For adapter.
template <typename Iter>
void GetColumnSizesScan(int device, size_t num_columns,
                        Iter batch_iter, data::IsValidFunctor is_valid,
                        size_t begin, size_t end,
                        dh::caching_device_vector<size_t>* column_sizes_scan) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  column_sizes_scan->resize(num_columns + 1, 0);
  auto d_column_sizes_scan = column_sizes_scan->data().get();
  dh::LaunchN(device, end - begin, [=] __device__(size_t idx) {
    auto e = batch_iter[begin + idx];
    if (is_valid(e)) {
      atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                    &d_column_sizes_scan[e.column_idx]),
                static_cast<unsigned long long>(1));  // NOLINT
    }
  });
  thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
                         column_sizes_scan->end(), column_sizes_scan->begin());
}
inline size_t BytesPerElement(bool has_weight) {
  // Double the memory usage for sorting. We need to assign a weight to each
  // element, so sizeof(float) is added to every element in the weighted case.
  return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2;
}

inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
                                     size_t columns, int device,
                                     size_t num_cuts, bool has_weight) {
  if (sketch_batch_num_elements == 0) {
    size_t bytes_per_element = BytesPerElement(has_weight);
    size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry);
    size_t bytes_num_columns = (columns + 1) * sizeof(size_t);
    // use up to 80% of available space
    sketch_batch_num_elements = (dh::AvailableMemory(device) -
                                 bytes_cuts - bytes_num_columns) *
                                0.8 / bytes_per_element;
  }
  return sketch_batch_num_elements;
}
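A rough worked instance of this sizing, with every number below assumed purely for illustration (the free device memory, the number of cuts per feature, and the element sizes are not stated in the diff):

#include <cstddef>
#include <cstdio>

int main() {
  // All inputs here are assumptions, not values taken from the diff.
  std::size_t avail = 8ull << 30;                    // pretend dh::AvailableMemory() == 8 GiB
  std::size_t columns = 100, num_cuts = 8000;        // assumed cuts per feature
  std::size_t bytes_per_element = 2 * 8;             // BytesPerElement(false) with an 8-byte Entry
  std::size_t bytes_cuts = num_cuts * columns * 16;  // assuming a 16-byte SketchEntry
  std::size_t bytes_num_columns = (columns + 1) * sizeof(std::size_t);
  std::size_t batch = static_cast<std::size_t>(
      (avail - bytes_cuts - bytes_num_columns) * 0.8 / bytes_per_element);
  std::printf("sketch_batch_num_elements ~ %zu\n", batch);  // roughly 4.3e8 entries per window
  return 0;
}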
// Compute number of sample cuts needed on local node to maintain accuracy
// We take more cuts than needed and then reduce them later
inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
  constexpr int kFactor = 8;
  double eps = 1.0 / (kFactor * max_bins);
  double eps = 1.0 / (SketchContainer::kFactor * max_bins);
  size_t dummy_nlevel;
  size_t num_cuts;
  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
@@ -109,52 +174,60 @@ inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
                           size_t sketch_batch_num_elements = 0);

template <typename AdapterT>
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
                  SketchContainer* sketch_container, int num_cuts) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  adapter->BeforeFirst();
  adapter->Next();
  auto &batch = adapter->Value();
  // Enforce single batch
  CHECK(!adapter->Next());
  auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
      thrust::make_counting_iterator(0llu),
      [=] __device__(size_t idx) { return batch.GetElement(idx); });

template <typename AdapterBatch, typename BatchIter>
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
                            Range1d range, float missing,
                            size_t columns, int device,
                            thrust::host_vector<size_t>* host_column_sizes_scan,
                            dh::caching_device_vector<size_t>* column_sizes_scan,
                            dh::caching_device_vector<Entry>* sorted_entries) {
  auto entry_iter = dh::MakeTransformIterator<Entry>(
      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
        return Entry(batch.GetElement(idx).column_idx,
                     batch.GetElement(idx).value);
      });
  // Work out how many valid entries we have in each column
  dh::caching_device_vector<size_t> column_sizes_scan(adapter->NumColumns() + 1,
                                                      0);

  auto d_column_sizes_scan = column_sizes_scan.data().get();
  data::IsValidFunctor is_valid(missing);
  dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
    auto e = batch_iter[begin + idx];
    if (is_valid(e)) {
      atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                    &d_column_sizes_scan[e.column_idx]),
                static_cast<unsigned long long>(1));  // NOLINT
    }
  });
  thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(),
                         column_sizes_scan.end(), column_sizes_scan.begin());
  thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
  size_t num_valid = host_column_sizes_scan.back();
  // Work out how many valid entries we have in each column
  GetColumnSizesScan(device, columns,
                     batch_iter, is_valid,
                     range.begin(), range.end(),
                     column_sizes_scan);
  host_column_sizes_scan->resize(column_sizes_scan->size());
  thrust::copy(column_sizes_scan->begin(), column_sizes_scan->end(),
               host_column_sizes_scan->begin());

  size_t num_valid = host_column_sizes_scan->back();

  // Copy current subset of valid elements into temporary storage and sort
  dh::caching_device_vector<Entry> sorted_entries(num_valid);
  thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin,
                  entry_iter + end, sorted_entries.begin(), is_valid);
  sorted_entries->resize(num_valid);
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::copy_if(thrust::cuda::par(alloc), entry_iter + range.begin(),
                  entry_iter + range.end(), sorted_entries->begin(), is_valid);
}

template <typename AdapterBatch>
void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
                          size_t begin, size_t end, float missing,
                          SketchContainer* sketch_container, int num_cuts) {
  // Copy current subset of valid elements into temporary storage and sort
  dh::caching_device_vector<Entry> sorted_entries;
  dh::caching_device_vector<size_t> column_sizes_scan;
  thrust::host_vector<size_t> host_column_sizes_scan;
  auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
      thrust::make_counting_iterator(0llu),
      [=] __device__(size_t idx) { return batch.GetElement(idx); });
  MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, columns, device,
                         &host_column_sizes_scan,
                         &column_sizes_scan,
                         &sorted_entries);
  dh::XGBCachingDeviceAllocator<char> alloc;
  thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
               sorted_entries.end(), EntryCompareOp());

  // Extract the cuts from all columns concurrently
  dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
  ExtractCuts(adapter->DeviceIdx(), num_cuts,
  dh::caching_device_vector<SketchEntry> cuts(columns * num_cuts);
  ExtractCuts(device, num_cuts,
              dh::ToSpan(sorted_entries),
              dh::ToSpan(column_sizes_scan),
              dh::ToSpan(cuts));
@@ -164,27 +237,105 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
  sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
}
void ExtractWeightedCuts(int device,
                         size_t num_cuts_per_feature,
                         Span<Entry> sorted_data,
                         Span<float> weights_scan,
                         Span<size_t> column_sizes_scan,
                         Span<SketchEntry> cuts);

void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
                  dh::caching_device_vector<float>* weights,
                  dh::caching_device_vector<Entry>* sorted_entries);
template <typename Batch>
void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
                                  int num_cuts_per_feature,
                                  bool is_ranking, float missing, int device,
                                  size_t columns, size_t begin, size_t end,
                                  SketchContainer *sketch_container) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  dh::safe_cuda(cudaSetDevice(device));
  info.weights_.SetDevice(device);
  auto weights = info.weights_.ConstDeviceSpan();
  dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
  auto d_group_ptr = dh::ToSpan(group_ptr);

  auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
      thrust::make_counting_iterator(0llu),
      [=] __device__(size_t idx) { return batch.GetElement(idx); });
  dh::caching_device_vector<Entry> sorted_entries;
  dh::caching_device_vector<size_t> column_sizes_scan;
  thrust::host_vector<size_t> host_column_sizes_scan;
  MakeEntriesFromAdapter(batch, batch_iter,
                         {begin, end}, missing, columns, device,
                         &host_column_sizes_scan,
                         &column_sizes_scan,
                         &sorted_entries);
  data::IsValidFunctor is_valid(missing);

  dh::caching_device_vector<float> temp_weights(sorted_entries.size());
  auto d_temp_weights = dh::ToSpan(temp_weights);

  if (is_ranking) {
    auto const weight_iter = dh::MakeTransformIterator<float>(
        thrust::make_constant_iterator(0lu),
        [=]__device__(size_t idx) -> float {
          auto ridx = batch.GetElement(idx).row_idx;
          auto it = thrust::upper_bound(thrust::seq,
                                        d_group_ptr.cbegin(), d_group_ptr.cend(),
                                        ridx) - 1;
          bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
          return weights[group];
        });
    auto retit = thrust::copy_if(thrust::cuda::par(alloc),
                                 weight_iter + begin, weight_iter + end,
                                 batch_iter + begin,
                                 d_temp_weights.data(),  // output
                                 is_valid);
    CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
  } else {
    auto const weight_iter = dh::MakeTransformIterator<float>(
        thrust::make_counting_iterator(0lu),
        [=]__device__(size_t idx) -> float {
          return weights[batch.GetElement(idx).row_idx];
        });
    auto retit = thrust::copy_if(thrust::cuda::par(alloc),
                                 weight_iter + begin, weight_iter + end,
                                 batch_iter + begin,
                                 d_temp_weights.data(),  // output
                                 is_valid);
    CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
  }

  SortByWeight(&alloc, &temp_weights, &sorted_entries);
  // Extract cuts
  dh::caching_device_vector<SketchEntry> cuts(columns * num_cuts_per_feature);
  ExtractWeightedCuts(device, num_cuts_per_feature,
                      dh::ToSpan(sorted_entries),
                      dh::ToSpan(temp_weights),
                      dh::ToSpan(column_sizes_scan),
                      dh::ToSpan(cuts));

  // add cuts into sketches
  thrust::host_vector<SketchEntry> host_cuts(cuts);
  sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
}
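For ranking data the weight lives on the query group rather than on the row, so the lambda above maps each entry's row index back to its group with an upper_bound over the cumulative group_ptr offsets before reading the weight. A small host-side illustration of that lookup (data and names invented for the example):

#include <algorithm>
#include <cstddef>
#include <vector>

// group_ptr holds cumulative row offsets, e.g. {0, 10, 25, 40}:
// rows 0-9 belong to group 0, rows 10-24 to group 1, rows 25-39 to group 2.
std::size_t GroupOfRow(const std::vector<unsigned>& group_ptr, std::size_t row) {
  auto it = std::upper_bound(group_ptr.cbegin(), group_ptr.cend(), row) - 1;
  return static_cast<std::size_t>(it - group_ptr.cbegin());
}
// GroupOfRow({0, 10, 25, 40}, 12) == 1, so row 12 is weighted by weights[1].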
template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
                                  float missing,
                                  size_t sketch_batch_num_elements = 0) {
  size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows());
  if (sketch_batch_num_elements == 0) {
    int bytes_per_element = 16;
    size_t bytes_cuts = num_cuts * adapter->NumColumns() * sizeof(SketchEntry);
    size_t bytes_num_columns = (adapter->NumColumns() + 1) * sizeof(size_t);
    // use up to 80% of available space
    sketch_batch_num_elements = (dh::AvailableMemory(adapter->DeviceIdx()) -
                                 bytes_cuts - bytes_num_columns) *
                                0.8 / bytes_per_element;
  }

  CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
  CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);

  adapter->BeforeFirst();
  adapter->Next();
  auto& batch = adapter->Value();
  sketch_batch_num_elements = SketchBatchNumElements(
      sketch_batch_num_elements,
      adapter->NumColumns(), adapter->DeviceIdx(), num_cuts, false);

  // Enforce single batch
  CHECK(!adapter->Next());
@@ -197,12 +348,54 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
  for (auto begin = 0ull; begin < batch.Size();
       begin += sketch_batch_num_elements) {
    size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
    ProcessBatch(adapter, begin, end, missing, &sketch_container, num_cuts);
    auto const& batch = adapter->Value();
    ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(),
                         begin, end, missing, &sketch_container, num_cuts);
  }

  dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows());
  return cuts;
}
template <typename Batch>
void AdapterDeviceSketch(Batch batch, int num_bins,
                         float missing, int device,
                         SketchContainer* sketch_container,
                         size_t sketch_batch_num_elements = 0) {
  size_t num_rows = batch.NumRows();
  size_t num_cols = batch.NumCols();
  size_t num_cuts = RequiredSampleCuts(num_bins, num_rows);
  sketch_batch_num_elements = SketchBatchNumElements(
      sketch_batch_num_elements,
      num_cols, device, num_cuts, false);
  for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
    size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
    ProcessSlidingWindow(batch, device, num_cols,
                         begin, end, missing, sketch_container, num_cuts);
  }
}
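Unlike the AdapterT* overload above, this overload pushes into a caller-owned SketchContainer and returns nothing, so cuts are only finalized once the caller runs DenseCuts::Init. A caller-side fragment mirroring the ValidateBatchedCuts helper added in the tests further below (adapter, num_bins, num_columns and num_rows are assumed to be in scope; this is a fragment, not a new API):

// Fragment only; mirrors the test helper rather than defining anything new.
common::HistogramCuts cuts;
common::SketchContainer sketch_container(num_bins, num_columns, num_rows);
common::AdapterDeviceSketch(adapter.Value(), num_bins,
                            std::numeric_limits<float>::quiet_NaN(),  // NaN == missing
                            0 /* device */, &sketch_container);
common::DenseCuts dense_cuts(&cuts);
dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows);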
template <typename Batch>
void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
                                 MetaInfo const& info,
                                 float missing,
                                 int device,
                                 SketchContainer* sketch_container,
                                 size_t sketch_batch_num_elements = 0) {
  size_t num_rows = batch.NumRows();
  size_t num_cols = batch.NumCols();
  size_t num_cuts = RequiredSampleCuts(num_bins, num_rows);
  sketch_batch_num_elements = SketchBatchNumElements(
      sketch_batch_num_elements,
      num_cols, device, num_cuts, true);
  for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
    size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
    ProcessWeightedSlidingWindow(batch, info,
                                 num_cuts,
                                 CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
                                 sketch_container);
  }
}
}  // namespace common
}  // namespace xgboost
@@ -129,6 +129,7 @@ class CutsBuilder {
  using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
  /* \brief return whether group for ranking is used. */
  static bool UseGroup(DMatrix* dmat);
  static bool UseGroup(MetaInfo const& info);

 protected:
  HistogramCuts* p_cuts_;
@@ -52,6 +52,9 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
    return {row_idx, column_idx, value};
  }

  XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
  XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }

 private:
  common::Span<ArrayInterface> columns_;
  size_t num_rows_;
@@ -167,6 +170,9 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
    return {row_idx, column_idx, value};
  }

  XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.num_rows; }
  XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.num_cols; }

 private:
  ArrayInterface array_interface_;
};
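The new NumRows()/NumCols() accessors complete the implicit interface the templated sketching code expects from a batch: GetElement, Size, NumRows and NumCols. A hypothetical batch type (not one that exists in XGBoost) showing that shape, with COOTuple brace-initialized the same way the adapters above do it:

// Hypothetical dense row-major batch; illustrates the implicit interface only.
struct ToyDenseBatch {
  const float* data;  // rows * cols values, resident on the device
  size_t rows, cols;
  XGBOOST_DEVICE data::COOTuple GetElement(size_t idx) const {
    return {idx / cols, idx % cols, data[idx]};
  }
  size_t Size() const { return rows * cols; }
  size_t NumRows() const { return rows; }
  size_t NumCols() const { return cols; }
};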
@@ -50,8 +50,7 @@ TEST(HistUtil, DeviceSketch) {
// Duplicate this function from hist_util.cu so we don't have to expose it in
// header
size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
  constexpr int kFactor = 8;
  double eps = 1.0 / (kFactor * max_bins);
  double eps = 1.0 / (SketchContainer::kFactor * max_bins);
  size_t dummy_nlevel;
  size_t num_cuts;
  WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
@@ -59,6 +58,15 @@ size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
  return std::min(num_cuts, num_rows);
}
size_t BytesRequiredForTest(size_t num_rows, size_t num_columns, size_t num_bins,
                            bool with_weights) {
  size_t bytes_num_elements = BytesPerElement(with_weights) * num_rows * num_columns;
  size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
                      sizeof(DenseCuts::WQSketch::Entry);
  // Divide by 2 because the memory quota used in sorting is reused for storing cuts.
  return bytes_num_elements / 2 + bytes_cuts;
}
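As a rough worked instance (with assumed sizes that are not stated in the diff: an 8-byte Entry, a 16-byte WQSketch::Entry, and RequiredSampleCutsTest capping at num_rows), the unweighted case with num_rows = 1000 and num_columns = 100 gives bytes_num_elements = 16 * 1000 * 100 = 1.6 MB, so the expected peak is about 1.6 MB / 2 + 1000 * 100 * 16 B = 0.8 MB + 1.6 MB = 2.4 MB, which is the bound the memory tests below check against.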
TEST(HistUtil, DeviceSketchMemory) {
  int num_columns = 100;
  int num_rows = 1000;
@@ -71,12 +79,10 @@ TEST(HistUtil, DeviceSketchMemory) {
  auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
  ConsoleLogger::Configure({{"verbosity", "0"}});

  size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry);
  size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
                      sizeof(DenseCuts::WQSketch::Entry);
  size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
  size_t bytes_constant = 1000;
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
            bytes_num_elements + bytes_cuts + bytes_constant);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}

TEST(HistUtil, DeviceSketchMemoryWeights) {
@@ -92,12 +98,9 @@ TEST(HistUtil, DeviceSketchMemoryWeights) {
  auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
  ConsoleLogger::Configure({{"verbosity", "0"}});

  size_t bytes_num_elements =
      num_rows * num_columns * (sizeof(Entry) + sizeof(float));
  size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
                      sizeof(DenseCuts::WQSketch::Entry);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
            size_t((bytes_num_elements + bytes_cuts) * 1.05));
  size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
TEST(HistUtil, DeviceSketchDeterminism) {
@@ -192,6 +195,20 @@ TEST(HistUtil, DeviceSketchBatches) {
    auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size);
    ValidateCuts(cuts, dmat.get(), num_bins);
  }

  num_rows = 1000;
  size_t batches = 16;
  auto x = GenerateRandom(num_rows * batches, num_columns);
  auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns);
  auto cuts_with_batches = DeviceSketch(0, dmat.get(), num_bins, num_rows);
  auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0);

  auto const& cut_values_batched = cuts_with_batches.Values();
  auto const& cut_values = cuts.Values();
  CHECK_EQ(cut_values.size(), cut_values_batched.size());
  for (size_t i = 0; i < cut_values.size(); ++i) {
    ASSERT_NEAR(cut_values_batched[i], cut_values[i], 1e5);
  }
}

TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
@@ -210,6 +227,19 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
  }
}
template <typename Adapter>
void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows,
                         DMatrix* dmat) {
  common::HistogramCuts batched_cuts;
  SketchContainer sketch_container(num_bins, num_columns, num_rows);
  AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
                      0, &sketch_container);
  common::DenseCuts dense_cuts(&batched_cuts);
  dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows);
  ValidateCuts(batched_cuts, dmat, num_bins);
}

TEST(HistUtil, AdapterDeviceSketch) {
  int rows = 5;
  int cols = 1;
@@ -244,14 +274,56 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
  auto cuts = AdapterDeviceSketch(&adapter, num_bins,
                                  std::numeric_limits<float>::quiet_NaN());
  ConsoleLogger::Configure({{"verbosity", "0"}});

  size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry);
  size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t);
  size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
                      sizeof(DenseCuts::WQSketch::Entry);
  size_t bytes_constant = 1000;
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(),
            bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
  size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}

TEST(HistUtil, AdapterSketchBatchMemory) {
  int num_columns = 100;
  int num_rows = 1000;
  int num_bins = 256;
  auto x = GenerateRandom(num_rows, num_columns);
  auto x_device = thrust::device_vector<float>(x);
  auto adapter = AdapterFromData(x_device, num_rows, num_columns);

  dh::GlobalMemoryLogger().Clear();
  ConsoleLogger::Configure({{"verbosity", "3"}});
  common::HistogramCuts batched_cuts;
  SketchContainer sketch_container(num_bins, num_columns, num_rows);
  AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
                      0, &sketch_container);
  ConsoleLogger::Configure({{"verbosity", "0"}});
  size_t bytes_constant = 1000;
  size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
TEST(HistUtil, AdapterSketchBatchWeightedMemory) {
  int num_columns = 100;
  int num_rows = 1000;
  int num_bins = 256;
  auto x = GenerateRandom(num_rows, num_columns);
  auto x_device = thrust::device_vector<float>(x);
  auto adapter = AdapterFromData(x_device, num_rows, num_columns);
  MetaInfo info;
  auto& h_weights = info.weights_.HostVector();
  h_weights.resize(num_rows);
  std::fill(h_weights.begin(), h_weights.end(), 1.0f);

  dh::GlobalMemoryLogger().Clear();
  ConsoleLogger::Configure({{"verbosity", "3"}});
  common::HistogramCuts batched_cuts;
  SketchContainer sketch_container(num_bins, num_columns, num_rows);
  AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info,
                              std::numeric_limits<float>::quiet_NaN(), 0,
                              &sketch_container);
  ConsoleLogger::Configure({{"verbosity", "0"}});
  size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true);
  EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
  EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}

TEST(HistUtil, AdapterDeviceSketchCategorical) {
@@ -284,6 +356,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
    auto cuts = AdapterDeviceSketch(&adapter, num_bins,
                                    std::numeric_limits<float>::quiet_NaN());
    ValidateCuts(cuts, dmat.get(), num_bins);
    ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
    }
  }
}
@@ -302,6 +375,7 @@ TEST(HistUtil, AdapterDeviceSketchBatches) {
                                    std::numeric_limits<float>::quiet_NaN(),
                                    batch_size);
    ValidateCuts(cuts, dmat.get(), num_bins);
    ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
  }
}

@@ -323,6 +397,8 @@ TEST(HistUtil, SketchingEquivalent) {
    EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
    EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
    EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());

    ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get());
    }
  }
}
@@ -330,7 +406,7 @@ TEST(HistUtil, SketchingEquivalent) {
TEST(HistUtil, DeviceSketchFromGroupWeights) {
  size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
  size_t constexpr kGroups = 10;
  auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix();
  auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
  auto& h_weights = m->Info().weights_.HostVector();
  h_weights.resize(kRows);
  std::fill(h_weights.begin(), h_weights.end(), 1.0f);
@@ -357,6 +433,71 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) {
  for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
    ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i));
  }
  ValidateCuts(weighted_cuts, m.get(), kBins);
}
void TestAdapterSketchFromWeights(bool with_group) {
  size_t constexpr kRows = 300, kCols = 20, kBins = 256;
  size_t constexpr kGroups = 10;
  HostDeviceVector<float> storage;
  std::string m =
      RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
          &storage);
  MetaInfo info;
  auto& h_weights = info.weights_.HostVector();
  h_weights.resize(kRows);
  std::fill(h_weights.begin(), h_weights.end(), 1.0f);

  std::vector<bst_group_t> groups(kGroups);
  if (with_group) {
    for (size_t i = 0; i < kGroups; ++i) {
      groups[i] = kRows / kGroups;
    }
    info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  }

  info.weights_.SetDevice(0);
  info.num_row_ = kRows;
  info.num_col_ = kCols;

  data::CupyAdapter adapter(m);
  auto const& batch = adapter.Value();
  SketchContainer sketch_container(kBins, kCols, kRows);
  AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
                              0,
                              &sketch_container);
  common::HistogramCuts cuts;
  common::DenseCuts dense_cuts(&cuts);
  dense_cuts.Init(&sketch_container.sketches_, kBins, kRows);

  auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
  if (with_group) {
    dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  }

  dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
  dmat->Info().num_col_ = kCols;
  dmat->Info().num_row_ = kRows;
  ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
  ValidateCuts(cuts, dmat.get(), kBins);

  if (with_group) {
    HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0);
    for (size_t i = 0; i < cuts.Values().size(); ++i) {
      EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
    }
    for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
      ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]);
    }
    for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
      ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i));
    }
  }
}

TEST(HistUtil, AdapterSketchFromWeights) {
  TestAdapterSketchFromWeights(false);
  TestAdapterSketchFromWeights(true);
}
}  // namespace common
}  // namespace xgboost
@@ -151,7 +151,8 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
                           size_t num_bins) {

  // Check the endpoints are correct
  EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front());
  CHECK_GT(sorted_column.size(), 0);
  EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front());
  EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front());
  EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back());

@@ -189,6 +190,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
  // Collect data into columns
  std::vector<std::vector<float>> columns(dmat->Info().num_col_);
  for (auto& batch : dmat->GetBatches<SparsePage>()) {
    CHECK_GT(batch.Size(), 0);
    for (auto i = 0ull; i < batch.Size(); i++) {
      for (auto e : batch[i]) {
        columns[e.index].push_back(e.fvalue);