Implement GK sketching on GPU. (#5846)
* Implement GK sketching on GPU. * Strong tests on quantile building. * Handle sparse dataset by binary searching the column index. * Hypothesis test on dask.
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/binary_search.h>
|
||||
@@ -31,21 +32,20 @@ namespace common {
|
||||
|
||||
constexpr float SketchContainer::kFactor;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Count the entries in each column and exclusive scan
|
||||
void ExtractCuts(int device,
|
||||
size_t num_cuts_per_feature,
|
||||
Span<Entry const> sorted_data,
|
||||
Span<size_t const> column_sizes_scan,
|
||||
Span<SketchEntry> out_cuts) {
|
||||
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry const> sorted_data,
|
||||
Span<size_t const> column_sizes_scan,
|
||||
Span<SketchEntry> out_cuts) {
|
||||
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
|
||||
// Each thread is responsible for obtaining one cut from the sorted input
|
||||
size_t column_idx = idx / num_cuts_per_feature;
|
||||
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
|
||||
size_t column_size =
|
||||
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
|
||||
size_t num_available_cuts =
|
||||
min(static_cast<size_t>(num_cuts_per_feature), column_size);
|
||||
size_t cut_idx = idx % num_cuts_per_feature;
|
||||
if (cut_idx >= num_available_cuts) return;
|
||||
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
|
||||
size_t cut_idx = idx - cuts_ptr[column_idx];
|
||||
Span<Entry const> column_entries =
|
||||
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
|
||||
size_t rank = (column_entries.size() * cut_idx) /
|
||||
@@ -55,31 +55,20 @@ void ExtractCuts(int device,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Extracts the cuts from sorted data, considering weights.
|
||||
*
|
||||
* \param device The device.
|
||||
* \param cuts Output cuts.
|
||||
* \param num_cuts_per_feature Number of cuts per feature.
|
||||
* \param sorted_data Sorted entries in segments of columns.
|
||||
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
|
||||
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
|
||||
*/
|
||||
void ExtractWeightedCuts(int device,
|
||||
size_t num_cuts_per_feature,
|
||||
Span<Entry> sorted_data,
|
||||
Span<float> weights_scan,
|
||||
Span<size_t> column_sizes_scan,
|
||||
Span<SketchEntry> cuts) {
|
||||
void ExtractWeightedCutsSparse(int device,
|
||||
common::Span<SketchContainer::OffsetT const> cuts_ptr,
|
||||
Span<Entry> sorted_data,
|
||||
Span<float> weights_scan,
|
||||
Span<size_t> column_sizes_scan,
|
||||
Span<SketchEntry> cuts) {
|
||||
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
|
||||
// Each thread is responsible for obtaining one cut from the sorted input
|
||||
size_t column_idx = idx / num_cuts_per_feature;
|
||||
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
|
||||
size_t column_size =
|
||||
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
|
||||
size_t num_available_cuts =
|
||||
min(static_cast<size_t>(num_cuts_per_feature), column_size);
|
||||
size_t cut_idx = idx % num_cuts_per_feature;
|
||||
if (cut_idx >= num_available_cuts) return;
|
||||
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
|
||||
size_t cut_idx = idx - cuts_ptr[column_idx];
|
||||
|
||||
Span<Entry> column_entries =
|
||||
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
|
||||
|
||||
@@ -109,7 +98,7 @@ void ExtractWeightedCuts(int device,
|
||||
max(static_cast<size_t>(0),
|
||||
min(sample_idx, column_entries.size() - 1));
|
||||
}
|
||||
// repeated values will be filtered out on the CPU
|
||||
// repeated values will be filtered out later.
|
||||
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
|
||||
bst_float rmax = column_weights_scan[sample_idx];
|
||||
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
|
||||
@@ -117,31 +106,71 @@ void ExtractWeightedCuts(int device,
|
||||
});
|
||||
}
|
||||
|
||||
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
|
||||
SketchContainer* sketch_container, int num_cuts,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), EntryCompareOp());
|
||||
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
|
||||
double eps = 1.0 / (WQSketch::kFactor * max_bins);
|
||||
size_t dummy_nlevel;
|
||||
size_t num_cuts;
|
||||
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
|
||||
num_rows, eps, &dummy_nlevel, &num_cuts);
|
||||
return std::min(num_cuts, num_rows);
|
||||
}
|
||||
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
GetColumnSizesScan(device, &column_sizes_scan,
|
||||
{sorted_entries.data().get(), sorted_entries.size()},
|
||||
num_columns);
|
||||
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
|
||||
size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns,
|
||||
size_t max_bins, size_t nnz) {
|
||||
auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
|
||||
auto if_dense = num_columns * per_column;
|
||||
auto result = std::min(nnz, if_dense);
|
||||
return result;
|
||||
}
|
||||
|
||||
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
|
||||
ExtractCuts(device, num_cuts,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
|
||||
size_t num_bins, bool with_weights) {
|
||||
size_t peak = 0;
|
||||
// 0. Allocate cut pointer in quantile container by increasing: n_columns + 1
|
||||
size_t total = (num_columns + 1) * sizeof(SketchContainer::OffsetT);
|
||||
// 1. Copy and sort: 2 * bytes_per_element * shape
|
||||
total += BytesPerElement(with_weights) * num_rows * num_columns;
|
||||
peak = std::max(peak, total);
|
||||
// 2. Deallocate bytes_per_element * shape due to reusing memory in sort.
|
||||
total -= BytesPerElement(with_weights) * num_rows * num_columns / 2;
|
||||
// 3. Allocate colomn size scan by increasing: n_columns + 1
|
||||
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
|
||||
// 4. Allocate cut pointer by increasing: n_columns + 1
|
||||
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
|
||||
// 5. Allocate cuts: assuming rows is greater than bins: n_columns * limit_size
|
||||
total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry);
|
||||
// 6. Deallocate copied entries by reducing: bytes_per_element * shape.
|
||||
peak = std::max(peak, total);
|
||||
total -= (BytesPerElement(with_weights) * num_rows * num_columns) / 2;
|
||||
// 7. Deallocate column size scan.
|
||||
peak = std::max(peak, total);
|
||||
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
|
||||
// 8. Deallocate cut size scan.
|
||||
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
|
||||
// 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) *
|
||||
// n_columns + n_columns + n_columns + 1
|
||||
total += std::min(num_rows, num_bins) * num_columns * sizeof(float);
|
||||
total += num_columns *
|
||||
sizeof(std::remove_reference_t<decltype(
|
||||
std::declval<HistogramCuts>().MinValues())>::value_type);
|
||||
total += (num_columns + 1) *
|
||||
sizeof(std::remove_reference_t<decltype(
|
||||
std::declval<HistogramCuts>().Ptrs())>::value_type);
|
||||
peak = std::max(peak, total);
|
||||
|
||||
// add cuts into sketches
|
||||
thrust::host_vector<SketchEntry> host_cuts(cuts);
|
||||
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
|
||||
return peak;
|
||||
}
|
||||
|
||||
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||
bst_row_t num_rows, size_t columns, size_t nnz, int device,
|
||||
size_t num_cuts, bool has_weight) {
|
||||
if (sketch_batch_num_elements == 0) {
|
||||
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
|
||||
// use up to 80% of available space
|
||||
sketch_batch_num_elements = (dh::AvailableMemory(device) -
|
||||
required_memory * 0.8);
|
||||
}
|
||||
return sketch_batch_num_elements;
|
||||
}
|
||||
|
||||
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
@@ -150,7 +179,7 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
// Sort both entries and wegihts.
|
||||
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
|
||||
sorted_entries->end(), weights->begin(),
|
||||
EntryCompareOp());
|
||||
detail::EntryCompareOp());
|
||||
|
||||
// Scan weights
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
|
||||
@@ -160,6 +189,46 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
return a.index == b.index;
|
||||
});
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
|
||||
SketchContainer *sketch_container, int num_cuts_per_feature,
|
||||
size_t num_columns) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
|
||||
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
|
||||
sorted_entries.data().get(),
|
||||
[] __device__(Entry const &e) -> data::COOTuple {
|
||||
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
|
||||
});
|
||||
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
|
||||
batch_it, dummy_is_valid,
|
||||
0, sorted_entries.size(),
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
|
||||
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
|
||||
|
||||
// add cuts into sketches
|
||||
sorted_entries.clear();
|
||||
sorted_entries.shrink_to_fit();
|
||||
CHECK_EQ(sorted_entries.capacity(), 0);
|
||||
CHECK_NE(cuts_ptr.Size(), 0);
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
Span<const float> weights, size_t begin, size_t end,
|
||||
@@ -204,40 +273,53 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||
});
|
||||
}
|
||||
SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
||||
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
||||
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
GetColumnSizesScan(device, &column_sizes_scan,
|
||||
{sorted_entries.data().get(), sorted_entries.size()},
|
||||
num_columns);
|
||||
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
|
||||
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
|
||||
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
|
||||
sorted_entries.data().get(),
|
||||
[] __device__(Entry const &e) -> data::COOTuple {
|
||||
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
|
||||
});
|
||||
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
|
||||
batch_it, dummy_is_valid,
|
||||
0, sorted_entries.size(),
|
||||
&cuts_ptr, &column_sizes_scan);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
// Extract cuts
|
||||
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts_per_feature);
|
||||
ExtractWeightedCuts(device, num_cuts_per_feature,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(temp_weights),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
|
||||
dh::ToSpan(sorted_entries),
|
||||
dh::ToSpan(temp_weights),
|
||||
dh::ToSpan(column_sizes_scan),
|
||||
dh::ToSpan(cuts));
|
||||
|
||||
// add cuts into sketches
|
||||
thrust::host_vector<SketchEntry> host_cuts(cuts);
|
||||
sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
|
||||
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
|
||||
}
|
||||
|
||||
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
size_t sketch_batch_num_elements) {
|
||||
// Configure batch size based on available memory
|
||||
bool has_weights = dmat->Info().weights_.Size() > 0;
|
||||
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
|
||||
sketch_batch_num_elements = SketchBatchNumElements(
|
||||
size_t num_cuts_per_feature =
|
||||
detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_);
|
||||
sketch_batch_num_elements = detail::SketchBatchNumElements(
|
||||
sketch_batch_num_elements,
|
||||
dmat->Info().num_col_, device, num_cuts_per_feature, has_weights);
|
||||
dmat->Info().num_row_,
|
||||
dmat->Info().num_col_,
|
||||
dmat->Info().num_nonzero_,
|
||||
device, num_cuts_per_feature, has_weights);
|
||||
|
||||
HistogramCuts cuts;
|
||||
DenseCuts dense_cuts(&cuts);
|
||||
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
|
||||
dmat->Info().num_row_);
|
||||
dmat->Info().num_row_, device);
|
||||
|
||||
dmat->Info().weights_.SetDevice(device);
|
||||
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
|
||||
@@ -261,8 +343,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_);
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
return cuts;
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
Reference in New Issue
Block a user