[EM] Avoid stream sync in quantile sketching. (#10765)

Jiaming Yuan 2024-08-30 12:33:24 +08:00 committed by GitHub
parent 61dd854a52
commit 34d4ab455e
12 changed files with 313 additions and 313 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
#define XGBOOST_COMMON_ALGORITHM_CUH_
@ -258,5 +258,19 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
cuctx->Stream()));
}
template <typename InIt, typename OutIt, typename Predicate>
void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_first,
Predicate pred) {
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
// See thrust issue #1302, XGBoost #6822
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
size_t length = std::distance(in_first, in_second);
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
auto begin_input = in_first + offset;
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred);
}
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
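
The batching loop above exists because thrust::copy_if indexes its input with a 32-bit counter (thrust issue #1302, XGBoost #6822); what this commit changes is that the execution policy is now taken from the CUDAContext's stream instead of an allocator bound to the legacy default stream. Below is a minimal, self-contained sketch of the same pattern in plain Thrust; the name BatchedCopyIf and the raw cudaStream_t parameter are illustrative stand-ins for XGBoost's CUDAContext, not part of this commit.

#include <thrust/copy.h>
#include <thrust/execution_policy.h>

#include <algorithm>  // for std::min
#include <cstddef>    // for std::size_t
#include <iterator>   // for std::distance
#include <limits>     // for std::numeric_limits

// Copy elements satisfying pred in chunks small enough for thrust::copy_if's
// 32-bit internal index, issuing all work on the caller's stream.
template <typename InIt, typename OutIt, typename Pred>
OutIt BatchedCopyIf(cudaStream_t stream, InIt first, InIt last, OutIt out, Pred pred) {
  std::size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
  std::size_t length = std::distance(first, last);
  for (std::size_t offset = 0; offset < length; offset += kMaxCopySize) {
    auto begin = first + offset;
    auto end = first + std::min(offset + kMaxCopySize, length);
    out = thrust::copy_if(thrust::cuda::par.on(stream), begin, end, out, pred);
  }
  return out;
}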

View File

@ -637,12 +637,11 @@ struct SegmentedUniqueReduceOp {
* \return Number of unique values in total.
*/
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
typename ValOutIt, typename CompValue, typename CompKey>
size_t
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
CompValue comp, CompKey comp_key=thrust::equal_to<size_t>{}) {
typename ValOutIt, typename CompValue, typename CompKey = thrust::equal_to<size_t>>
size_t SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
@ -676,16 +675,6 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
return n_uniques;
}
template <typename... Inputs,
std::enable_if_t<std::tuple_size<std::tuple<Inputs...>>::value == 7>
* = nullptr>
size_t SegmentedUnique(Inputs &&...inputs) {
dh::XGBCachingDeviceAllocator<char> alloc;
return SegmentedUnique(thrust::cuda::par(alloc),
std::forward<Inputs &&>(inputs)...,
thrust::equal_to<size_t>{});
}
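
An aside on the removal above: this variadic forwarding overload existed only to supply a caching-allocator policy plus the default thrust::equal_to<size_t> key comparator. With this commit, callers pass a stream-bound policy explicitly (ctx->CUDACtx()->CTP() elsewhere in the diff), and the comparator default now comes from the CompKey template parameter added to the primary template. A generic, XGBoost-independent illustration of the defaulted-comparator idiom, with hypothetical names:

#include <algorithm>   // for std::unique
#include <cstddef>     // for std::size_t
#include <functional>  // for std::equal_to
#include <vector>

// One template with a defaulted comparator type and argument replaces a separate
// forwarding overload whose only job was to supply that default.
template <typename T, typename Comp = std::equal_to<T>>
std::size_t DedupAdjacent(std::vector<T>* v, Comp comp = Comp{}) {
  auto last = std::unique(v->begin(), v->end(), comp);
  v->erase(last, v->end());
  return v->size();
}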
/**
* \brief Unique by key for many groups of data. Has same constraint as `SegmentedUnique`.
*
@ -793,21 +782,6 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
#endif
}
template <typename InIt, typename OutIt, typename Predicate>
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
// See thrust issue #1302, XGBoost #6822
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
size_t length = std::distance(in_first, in_second);
XGBCachingDeviceAllocator<char> alloc;
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
auto begin_input = in_first + offset;
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
end_input, out_first, pred);
}
}
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
InclusiveScan(d_in, d_out, cub::Sum(), num_items);

View File

@ -106,26 +106,27 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_ro
return std::min(sketch_batch_num_elements, kIntMax);
}
void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries) {
// Sort both entries and weights.
dh::XGBDeviceAllocator<char> alloc;
auto cuctx = ctx->CUDACtx();
CHECK_EQ(weights->size(), sorted_entries->size());
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
weights->begin(), detail::EntryCompareOp());
thrust::sort_by_key(cuctx->TP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
detail::EntryCompareOp());
// Scan weights
dh::XGBCachingDeviceAllocator<char> caching;
thrust::inclusive_scan_by_key(
thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
cuctx->CTP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
weights->begin(),
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
}
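
SortByWeight now issues both the key sort and the follow-up weight accumulation through the context's stream-ordered policies (TP()/CTP()) instead of per-call allocators on the default stream. The second step is a segmented prefix sum: weights accumulate while the feature index of adjacent sorted entries stays the same and restart when it changes. A standalone sketch of that step, assuming plain int keys and a caller-provided stream as stand-ins for the Entry type and CUDAContext used above:

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/scan.h>

// Accumulate values within runs of equal keys; the running sum restarts at each new key.
void ScanWeightsByKey(cudaStream_t stream, thrust::device_vector<int> const& keys,
                      thrust::device_vector<float>* weights) {
  thrust::inclusive_scan_by_key(thrust::cuda::par.on(stream), keys.begin(), keys.end(),
                                weights->begin(), weights->begin(),
                                thrust::equal_to<int>{});
}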
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
Span<bst_idx_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
info.feature_types.SetDevice(device);
info.feature_types.SetDevice(ctx->Device());
auto d_feature_types = info.feature_types.ConstDeviceSpan();
CHECK(!d_feature_types.empty());
auto& column_sizes_scan = *p_column_sizes_scan;
@ -142,30 +143,32 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
[=] __device__(Pair const& l, Pair const& r) {
Entry const& le = thrust::get<0>(l);
Entry const& re = thrust::get<0>(r);
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
return le.fvalue == re.fvalue;
}
return false;
});
n_uniques =
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(), val_in_it,
val_in_it + sorted_entries.size(), new_column_scan.data().get(),
val_out_it, [=] __device__(Pair const& l, Pair const& r) {
Entry const& le = thrust::get<0>(l);
Entry const& re = thrust::get<0>(r);
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
return le.fvalue == re.fvalue;
}
return false;
});
p_sorted_weights->resize(n_uniques);
} else {
n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
n_uniques = dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const& l, Entry const& r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
}
sorted_entries.resize(n_uniques);
@ -189,7 +192,7 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
}
});
// Turn size into ptr.
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), new_cuts_size.cend(),
thrust::exclusive_scan(ctx->CUDACtx()->CTP(), new_cuts_size.cbegin(), new_cuts_size.cend(),
d_cuts_ptr.data());
}
} // namespace detail
@ -225,7 +228,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
});
detail::SortByWeight(&entry_weight, &sorted_entries);
detail::SortByWeight(ctx, &entry_weight, &sorted_entries);
} else {
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
detail::EntryCompareOp());
@ -238,13 +241,13 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(ctx->Device(), info.num_col_, num_cuts_per_feature,
detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
detail::RemoveDuplicatedCategories(ctx->Device(), info, d_cuts_ptr, &sorted_entries, p_weight,
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
&column_sizes_scan);
}
@ -252,7 +255,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
// Add cuts into sketches
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(entry_weight));
sorted_entries.clear();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2023 by XGBoost contributors
* Copyright 2020-2024, XGBoost contributors
*
* \brief Front end and utilities for GPU based sketching. Works on sliding window
* instead of stream.
@ -13,6 +13,8 @@
#include <cstddef> // for size_t
#include "../data/adapter.h" // for IsValidFunctor
#include "algorithm.cuh" // for CopyIf
#include "cuda_context.cuh" // for CUDAContext
#include "device_helpers.cuh"
#include "hist_util.h"
#include "quantile.cuh"
@ -107,9 +109,10 @@ std::uint32_t EstimateGridSize(DeviceOrd device, Kernel kernel, std::size_t shar
* \param out_column_size Output buffer for the size of each column.
*/
template <typename BatchIt, bool force_use_global_memory = false, bool force_use_u64 = false>
void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan<BatchIt> batch_iter,
data::IsValidFunctor is_valid, Span<std::size_t> out_column_size) {
thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0);
void LaunchGetColumnSizeKernel(CUDAContext const* cuctx, DeviceOrd device,
IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
Span<std::size_t> out_column_size) {
thrust::fill_n(cuctx->CTP(), dh::tbegin(out_column_size), out_column_size.size(), 0);
std::size_t max_shared_memory = dh::MaxSharedMemory(device.ordinal);
// Not strictly correct as we should use number of samples to determine the type of
@ -135,17 +138,17 @@ void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan<BatchIt> batch_iter,
CHECK(!force_use_u64);
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, cuctx->Stream()}(
kernel, batch_iter, is_valid, out_column_size);
} else {
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, cuctx->Stream()}(
kernel, batch_iter, is_valid, out_column_size);
}
} else {
auto d_out_column_size = out_column_size;
dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) {
dh::LaunchN(batch_iter.size(), cuctx->Stream(), [=] __device__(size_t idx) {
auto e = batch_iter[idx];
if (is_valid(e)) {
atomicAdd(&d_out_column_size[e.column_idx], static_cast<size_t>(1));
@ -155,26 +158,26 @@ void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan<BatchIt> batch_iter,
}
template <typename BatchIt>
void GetColumnSizesScan(DeviceOrd device, size_t num_columns, std::size_t num_cuts_per_feature,
IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
void GetColumnSizesScan(CUDAContext const* cuctx, DeviceOrd device, size_t num_columns,
std::size_t num_cuts_per_feature, IterSpan<BatchIt> batch_iter,
data::IsValidFunctor is_valid,
HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
dh::caching_device_vector<size_t>* column_sizes_scan) {
column_sizes_scan->resize(num_columns + 1);
cuts_ptr->SetDevice(device);
cuts_ptr->Resize(num_columns + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan);
LaunchGetColumnSizeKernel(cuctx, device, batch_iter, is_valid, d_column_sizes_scan);
// Calculate cuts CSC pointer
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
return thrust::min(num_cuts_per_feature, column_size);
});
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
thrust::exclusive_scan(cuctx->CTP(), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
thrust::exclusive_scan(cuctx->CTP(), column_sizes_scan->begin(), column_sizes_scan->end(),
column_sizes_scan->begin());
}
inline size_t constexpr BytesPerElement(bool has_weight) {
@ -215,9 +218,9 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz,
// Count the valid entries in each column and copy them out.
template <typename AdapterBatch, typename BatchIter>
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range,
float missing, size_t columns, size_t cuts_per_feature,
DeviceOrd device,
void MakeEntriesFromAdapter(CUDAContext const* cuctx, AdapterBatch const& batch,
BatchIter batch_iter, Range1d range, float missing, size_t columns,
size_t cuts_per_feature, DeviceOrd device,
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
dh::caching_device_vector<size_t>* column_sizes_scan,
dh::device_vector<Entry>* sorted_entries) {
@ -229,19 +232,20 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran
auto span = IterSpan{batch_iter + range.begin(), n};
data::IsValidFunctor is_valid(missing);
// Work out how many valid entries we have in each column
GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
GetColumnSizesScan(cuctx, device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
column_sizes_scan);
size_t num_valid = column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid);
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
is_valid);
CopyIf(cuctx, entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
is_valid);
}
void SortByWeight(dh::device_vector<float>* weights,
void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries);
void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
Span<bst_idx_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan);
@ -278,10 +282,9 @@ inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t
}
template <typename AdapterBatch>
void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
DeviceOrd device, size_t columns, size_t begin, size_t end,
float missing, SketchContainer *sketch_container,
int num_cuts) {
void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info,
size_t columns, size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts) {
// Copy current subset of valid elements into temporary storage and sort
dh::device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
@ -289,54 +292,45 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info,
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
cuts_ptr.SetDevice(device);
detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing,
columns, num_cuts, device,
&cuts_ptr,
&column_sizes_scan,
&sorted_entries);
dh::XGBDeviceAllocator<char> alloc;
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
cuts_ptr.SetDevice(ctx->Device());
CUDAContext const* cuctx = ctx->CUDACtx();
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, num_cuts,
ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries);
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp());
if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr,
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, nullptr,
&column_sizes_scan);
}
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
auto const &h_cuts_ptr = cuts_ptr.HostVector();
// Extract the cuts from all columns concurrently
sketch_container->Push(dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back());
sorted_entries.clear();
sorted_entries.shrink_to_fit();
}
template <typename Batch>
void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
int num_cuts_per_feature,
bool is_ranking, float missing, DeviceOrd device,
size_t columns, size_t begin, size_t end,
SketchContainer *sketch_container) {
dh::XGBCachingDeviceAllocator<char> alloc;
void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info,
int num_cuts_per_feature, bool is_ranking, float missing,
DeviceOrd device, size_t columns, size_t begin, size_t end,
SketchContainer* sketch_container) {
dh::safe_cuda(cudaSetDevice(device.ordinal));
info.weights_.SetDevice(device);
auto weights = info.weights_.ConstDeviceSpan();
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto cuctx = ctx->CUDACtx();
dh::device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
detail::MakeEntriesFromAdapter(batch, batch_iter,
{begin, end}, missing,
columns, num_cuts_per_feature, device,
&cuts_ptr,
&column_sizes_scan,
detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns,
num_cuts_per_feature, device, &cuts_ptr, &column_sizes_scan,
&sorted_entries);
data::IsValidFunctor is_valid(missing);
@ -355,7 +349,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx);
return weights[group_idx];
});
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
auto retit = thrust::copy_if(cuctx->CTP(),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
@ -368,7 +362,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
[=]__device__(size_t idx) -> float {
return weights[batch.GetElement(idx).row_idx];
});
auto retit = thrust::copy_if(thrust::cuda::par(alloc),
auto retit = thrust::copy_if(cuctx->CTP(),
weight_iter + begin, weight_iter + end,
batch_iter + begin,
d_temp_weights.data(), // output
@ -376,11 +370,11 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
}
detail::SortByWeight(&temp_weights, &sorted_entries);
detail::SortByWeight(ctx, &temp_weights, &sorted_entries);
if (sketch_container->HasCategorical()) {
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights,
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, &temp_weights,
&column_sizes_scan);
}
@ -388,8 +382,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
// Extract cuts
sketch_container->Push(dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), d_cuts_ptr,
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(temp_weights));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
@ -407,8 +400,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
* testing.
*/
template <typename Batch>
void AdapterDeviceSketch(Batch batch, int num_bins,
MetaInfo const& info,
void AdapterDeviceSketch(Context const* ctx, Batch batch, int num_bins, MetaInfo const& info,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
size_t num_rows = batch.NumRows();
@ -419,27 +411,24 @@ void AdapterDeviceSketch(Batch batch, int num_bins,
if (weighted) {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
device.ordinal, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end =
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info,
num_cuts_per_feature,
HostSketchContainer::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container);
ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature,
HostSketchContainer::UseGroup(info), missing, device, num_cols,
begin, end, sketch_container);
}
} else {
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_rows, num_cols, std::numeric_limits<size_t>::max(),
sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits<size_t>::max(),
device.ordinal, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end =
std::min(batch.Size(), static_cast<std::size_t>(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, info, device, num_cols, begin, end, missing,
sketch_container, num_cuts_per_feature);
ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container,
num_cuts_per_feature);
}
}
}
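
For reference, the updated calling convention after this change, mirroring the test updates near the end of this diff. This is a fragment rather than a complete program; it assumes the repository's test helpers (MakeCUDACtx, a CuPy __cuda_array_interface__ string in interface_str) and placeholder sizes.

auto ctx = MakeCUDACtx(0);
MetaInfo info;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, /*num_bins=*/256, /*n_columns=*/100, /*n_rows=*/1000, ctx.Device());
data::CupyAdapter adapter(interface_str);
// The context now comes first and supplies the stream for every downstream kernel.
AdapterDeviceSketch(&ctx, adapter.Value(), /*num_bins=*/256, info,
                    std::numeric_limits<float>::quiet_NaN(), &sketch);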

View File

@ -18,6 +18,8 @@
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank
#include "categorical.h"
#include "common.h"
#include "cuda_context.cuh" // for CUDAContext
#include "cuda_rt_utils.h" // for SetDevice
#include "device_helpers.cuh"
#include "hist_util.h"
#include "quantile.cuh"
@ -117,6 +119,7 @@ void CopyTo(Span<T> out, Span<U> src) {
// Compute the merge path.
common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
Context const* ctx,
Span<SketchEntry const> const &d_x, Span<bst_idx_t const> const &x_ptr,
Span<SketchEntry const> const &d_y, Span<bst_idx_t const> const &y_ptr,
Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
@ -142,13 +145,12 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
auto y_merge_val_it =
thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder));
dh::XGBCachingDeviceAllocator<Tuple> alloc;
static_assert(sizeof(Tuple) == sizeof(SketchEntry));
// We reuse the memory for storing merge path.
common::Span<Tuple> merge_path{reinterpret_cast<Tuple *>(out.data()), out.size()};
// Determine the merge path, 0 if element is from x, 1 if it's from y.
thrust::merge_by_key(
thrust::cuda::par(alloc), x_merge_key_it, x_merge_key_it + d_x.size(),
ctx->CUDACtx()->CTP(), x_merge_key_it, x_merge_key_it + d_x.size(),
y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it,
y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(),
[=] __device__(auto const &l, auto const &r) -> bool {
@ -163,10 +165,9 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
// Compute output ptr
auto transform_it =
thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data()));
thrust::transform(
thrust::cuda::par(alloc), transform_it, transform_it + x_ptr.size(),
out_ptr.data(),
[] __device__(auto const& t) { return thrust::get<0>(t) + thrust::get<1>(t); });
thrust::transform(ctx->CUDACtx()->CTP(), transform_it, transform_it + x_ptr.size(),
out_ptr.data(),
[] __device__(auto const &t) { return thrust::get<0>(t) + thrust::get<1>(t); });
// 0^th is the indicator, 1^th is placeholder
auto get_ind = []XGBOOST_DEVICE(Tuple const& t) { return thrust::get<0>(t); };
@ -194,7 +195,7 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
// is landed into output as the first element in merge result. The scan result is the
// subscript of x and y.
thrust::exclusive_scan_by_key(
thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(),
ctx->CUDACtx()->CTP(), scan_key_it, scan_key_it + merge_path.size(),
scan_val_it, merge_path.data(),
thrust::make_tuple<uint64_t, uint64_t>(0ul, 0ul),
thrust::equal_to<size_t>{},
@ -209,18 +210,17 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
// summary does the output element come from) result by definition of merged rank. So we
// run it in 2 passes to obtain the merge path and then customize the standard merge
// algorithm.
void MergeImpl(DeviceOrd device, Span<SketchEntry const> const &d_x,
void MergeImpl(Context const *ctx, Span<SketchEntry const> const &d_x,
Span<bst_idx_t const> const &x_ptr, Span<SketchEntry const> const &d_y,
Span<bst_idx_t const> const &y_ptr, Span<SketchEntry> out, Span<bst_idx_t> out_ptr) {
dh::safe_cuda(cudaSetDevice(device.ordinal));
CHECK_EQ(d_x.size() + d_y.size(), out.size());
CHECK_EQ(x_ptr.size(), out_ptr.size());
CHECK_EQ(y_ptr.size(), out_ptr.size());
auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr);
auto d_merge_path = MergePath(ctx, d_x, x_ptr, d_y, y_ptr, out, out_ptr);
auto d_out = out;
dh::LaunchN(d_out.size(), [=] __device__(size_t idx) {
dh::LaunchN(d_out.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(out_ptr, idx);
idx -= out_ptr[column_id];
@ -307,10 +307,9 @@ void MergeImpl(DeviceOrd device, Span<SketchEntry const> const &d_x,
});
}
void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
common::Span<OffsetT> cuts_ptr,
size_t total_cuts, Span<float> weights) {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
void SketchContainer::Push(Context const *ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights) {
common::SetDevice(device_.ordinal);
Span<SketchEntry> out;
dh::device_vector<SketchEntry> cuts;
bool first_window = this->Current().empty();
@ -346,12 +345,12 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
}; // NOLINT
PruneImpl<Entry>(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry);
}
auto n_uniques = this->ScanInput(out, cuts_ptr);
auto n_uniques = this->ScanInput(ctx, out, cuts_ptr);
if (!first_window) {
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
out = out.subspan(0, n_uniques);
this->Merge(cuts_ptr, out);
this->Merge(ctx, cuts_ptr, out);
this->FixError();
} else {
this->Current().resize(n_uniques);
@ -363,7 +362,8 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
}
}
size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in) {
size_t SketchContainer::ScanInput(Context const *ctx, Span<SketchEntry> entries,
Span<OffsetT> d_columns_ptr_in) {
/* There are 2 types of duplication. First is duplicated feature values, which comes
* from user input data. Second is duplicated sketching entries, which is generated by
* pruning or merging. We preserve the first type and remove the second type.
@ -371,7 +371,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal));
CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1);
dh::XGBCachingDeviceAllocator<char> alloc;
auto key_it = dh::MakeTransformIterator<size_t>(
thrust::make_reverse_iterator(thrust::make_counting_iterator(entries.size())),
@ -381,7 +380,7 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
// Reverse scan to accumulate weights into first duplicated element on left.
auto val_it = thrust::make_reverse_iterator(dh::tend(entries));
thrust::inclusive_scan_by_key(
thrust::cuda::par(alloc), key_it, key_it + entries.size(),
ctx->CUDACtx()->CTP(), key_it, key_it + entries.size(),
val_it, val_it,
thrust::equal_to<size_t>{},
[] __device__(SketchEntry const &r, SketchEntry const &l) {
@ -396,18 +395,18 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
auto d_columns_ptr_out = columns_ptr_b_.DeviceSpan();
// thrust unique_by_key preserves the first element.
auto n_uniques = dh::SegmentedUnique(
d_columns_ptr_in.data(),
d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(),
entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(),
detail::SketchUnique{});
auto n_uniques =
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), d_columns_ptr_in.data(),
d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(),
entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(),
detail::SketchUnique{});
CopyTo(d_columns_ptr_in, d_columns_ptr_out);
timer_.Stop(__func__);
return n_uniques;
}
void SketchContainer::Prune(size_t to) {
void SketchContainer::Prune(Context const* ctx, std::size_t to) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal));
@ -438,19 +437,19 @@ void SketchContainer::Prune(size_t to) {
this->columns_ptr_.Copy(columns_ptr_b_);
this->Alternate();
this->Unique();
this->Unique(ctx);
timer_.Stop(__func__);
}
void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
void SketchContainer::Merge(Context const *ctx, Span<OffsetT const> d_that_columns_ptr,
Span<SketchEntry const> that) {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
common::SetDevice(device_.ordinal);
timer_.Start(__func__);
if (this->Current().size() == 0) {
CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size());
CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1);
thrust::copy(thrust::device, d_that_columns_ptr.data(),
thrust::copy(ctx->CUDACtx()->CTP(), d_that_columns_ptr.data(),
d_that_columns_ptr.data() + d_that_columns_ptr.size(),
this->columns_ptr_.DevicePointer());
auto total = this->columns_ptr_.HostVector().back();
@ -463,7 +462,7 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
this->Other().resize(this->Current().size() + that.size());
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
MergeImpl(ctx, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr,
dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan());
this->columns_ptr_.Copy(columns_ptr_b_);
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
@ -471,7 +470,7 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
if (this->HasCategorical()) {
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
this->Unique(ctx, [d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
});
}
@ -517,7 +516,7 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
SafeColl(rc);
bst_idx_t intermediate_num_cuts =
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
this->Prune(intermediate_num_cuts);
this->Prune(ctx, intermediate_num_cuts);
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
@ -570,9 +569,8 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) {
for (size_t i = 0; i < allworkers.size(); ++i) {
auto worker = allworkers[i];
auto worker_ptr =
dh::ToSpan(gathered_ptrs)
.subspan(i * d_columns_ptr.size(), d_columns_ptr.size());
new_sketch.Merge(worker_ptr, worker);
dh::ToSpan(gathered_ptrs).subspan(i * d_columns_ptr.size(), d_columns_ptr.size());
new_sketch.Merge(ctx, worker_ptr, worker);
new_sketch.FixError();
}
@ -602,7 +600,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
this->AllReduce(ctx, is_column_split);
// Prune to final number of bins.
this->Prune(num_bins_ + 1);
this->Prune(ctx, num_bins_ + 1);
this->FixError();
// Set up inputs
@ -624,7 +622,6 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
std::vector<SketchEntry> max_values;
float max_cat{-1.f};
if (has_categorical_) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto key_it = dh::MakeTransformIterator<bst_feature_t>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t {
return dh::SegmentId(d_in_columns_ptr, i);
@ -651,7 +648,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
dh::caching_device_vector<size_t> d_max_keys(d_in_columns_ptr.size() - 1);
dh::caching_device_vector<SketchEntry> d_max_values(d_in_columns_ptr.size() - 1);
auto new_end = thrust::reduce_by_key(
thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
ctx->CUDACtx()->CTP(), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(),
d_max_values.begin(), thrust::equal_to<bst_feature_t>{},
[] __device__(auto l, auto r) { return l.value > r.value ? l : r; });
d_max_keys.erase(new_end.first, d_max_keys.end());
@ -661,7 +658,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
SketchEntry default_entry{};
dh::caching_device_vector<SketchEntry> d_max_results(d_in_columns_ptr.size() - 1,
default_entry);
thrust::scatter(thrust::cuda::par(alloc), d_max_values.begin(), d_max_values.end(),
thrust::scatter(ctx->CUDACtx()->CTP(), d_max_values.begin(), d_max_values.end(),
d_max_keys.begin(), d_max_results.begin());
dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results));
auto max_it = MakeIndexTransformIter([&](auto i) {

View File

@ -7,6 +7,7 @@
#include <thrust/logical.h> // for any_of
#include "categorical.h"
#include "cuda_context.cuh" // for CUDAContext
#include "device_helpers.cuh"
#include "error_msg.h" // for InvalidMaxBin
#include "quantile.h"
@ -127,7 +128,7 @@ class SketchContainer {
/* \brief Whether the predictor matrix contains categorical features. */
bool HasCategorical() const { return has_categorical_; }
/* \brief Accumulate weights of duplicated entries in input. */
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
size_t ScanInput(Context const* ctx, Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
void FixError();
@ -140,19 +141,18 @@ class SketchContainer {
* \param total_cuts Total number of cuts, equal to the back of cuts_ptr.
* \param weights (optional) data weights.
*/
void Push(Span<Entry const> entries, Span<size_t> columns_ptr,
common::Span<OffsetT> cuts_ptr, size_t total_cuts,
Span<float> weights = {});
void Push(Context const* ctx, Span<Entry const> entries, Span<size_t> columns_ptr,
common::Span<OffsetT> cuts_ptr, size_t total_cuts, Span<float> weights = {});
/* \brief Prune the quantile structure.
*
* \param to The maximum size of pruned quantile. If the size of quantile
* structure is already less than `to`, then no operation is performed.
*/
void Prune(size_t to);
void Prune(Context const* ctx, size_t to);
/* \brief Merge another set of sketch.
* \param that columns of other.
*/
void Merge(Span<OffsetT const> that_columns_ptr,
void Merge(Context const* ctx, Span<OffsetT const> that_columns_ptr,
Span<SketchEntry const> that);
/* \brief Merge quantiles from other GPU workers. */
@ -175,7 +175,7 @@ class SketchContainer {
/* \brief Removes all the duplicated elements in quantile structure. */
template <typename KeyComp = thrust::equal_to<size_t>>
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to<size_t>{}) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_.ordinal));
this->columns_ptr_.SetDevice(device_);
@ -185,14 +185,12 @@ class SketchContainer {
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;
d_column_scan = this->columns_ptr_.DeviceSpan();
size_t n_uniques = dh::SegmentedUnique(
thrust::cuda::par(alloc), d_column_scan.data(),
d_column_scan.data() + d_column_scan.size(), entries.data(),
entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(), detail::SketchUnique{}, key_comp);
ctx->CUDACtx()->CTP(), d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(), entries.data(),
detail::SketchUnique{}, key_comp);
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());

View File

@ -11,6 +11,7 @@
#include <type_traits> // for invoke_result_t, declval
#include <vector> // for vector
#include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE
#include "adapter.h"
#include "xgboost/c_api.h"
#include "xgboost/context.h"
@ -36,6 +37,8 @@ class DataIterProxy {
DataIterProxy& operator=(DataIterProxy const& that) = default;
[[nodiscard]] bool Next() {
xgboost_NVTX_FN_RANGE();
bool ret = !!next_(iter_);
if (!ret) {
return ret;

View File

@ -30,14 +30,13 @@ void MakeSketches(Context const* ctx,
ExternalDataInfo* p_ext_info) {
xgboost_NVTX_FN_RANGE();
CUDAContext const* cuctx = ctx->CUDACtx();
std::unique_ptr<common::SketchContainer> sketch;
auto& ext_info = *p_ext_info;
do {
// We use do while here as the first batch is fetched in ctor
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
common::SetDevice(dh::GetDevice(ctx).ordinal);
if (ext_info.n_features == 0) {
ext_info.n_features = data::BatchColumns(proxy);
auto rc = collective::Allreduce(ctx, linalg::MakeVec(&ext_info.n_features, 1),
@ -55,7 +54,16 @@ void MakeSketches(Context const* ctx,
}
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
cuda_impl::Dispatch(proxy, [&](auto const& value) {
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get());
// Workaround empty input with CPU ctx.
Context new_ctx;
Context const* p_ctx;
if (ctx->IsCUDA()) {
p_ctx = ctx;
} else {
new_ctx.UpdateAllowUnknown(Args{{"device", dh::GetDevice(ctx).Name()}});
p_ctx = &new_ctx;
}
common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing, sketch.get());
});
}
auto batch_rows = data::BatchSamples(proxy);
@ -66,7 +74,7 @@ void MakeSketches(Context const* ctx,
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing);
}));
ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end());
ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end());
ext_info.n_batches++;
ext_info.base_rows.push_back(batch_rows);
} while (iter->Next());
@ -77,7 +85,7 @@ void MakeSketches(Context const* ctx,
ext_info.base_rows.begin());
// Get reference
dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal));
common::SetDevice(dh::GetDevice(ctx).ordinal);
if (!ref) {
sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit());
} else {

View File

@ -11,6 +11,7 @@
#include "../common/device_helpers.cuh"
#include "../common/error_msg.h" // for InfInData
#include "../common/algorithm.cuh" // for CopyIf
#include "device_adapter.cuh" // for NoInfInData
namespace xgboost::data {
@ -27,16 +28,15 @@ struct COOToEntryOp {
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
float missing) {
void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data, float missing) {
auto counting = thrust::make_counting_iterator(0llu);
dh::XGBCachingDeviceAllocator<char> alloc;
COOToEntryOp<decltype(batch)> transform_op{batch};
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
transform_iter(counting, transform_op);
thrust::transform_iterator<decltype(transform_op), decltype(counting)> transform_iter(
counting, transform_op);
auto begin_output = thrust::device_pointer_cast(data.data());
dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
IsValidFunctor(missing));
auto ctx = Context{}.MakeCUDA(dh::CurrentDevice());
common::CopyIf(ctx.CUDACtx(), transform_iter, transform_iter + batch.Size(), begin_output,
IsValidFunctor(missing));
}
template <typename AdapterBatchT>

View File

@ -9,8 +9,10 @@
#include <cstdint>
#include <vector>
#include "../../../src/common/cuda_context.cuh"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
#include "../helpers.h"
#include "gtest/gtest.h"
TEST(SumReduce, Test) {
@ -61,11 +63,11 @@ TEST(SegmentedUnique, Basic) {
thrust::device_vector<xgboost::bst_feature_t> d_segs_out(d_segments.size());
thrust::device_vector<float> d_vals_out(d_values.size());
auto ctx = xgboost::MakeCUDACtx(0);
size_t n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(),
d_segs_out.data().get(), d_vals_out.data().get(),
thrust::equal_to<float>{});
ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(), d_segs_out.data().get(),
d_vals_out.data().get(), thrust::equal_to<float>{});
CHECK_EQ(n_uniques, 5);
std::vector<float> values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f};
@ -81,10 +83,9 @@ TEST(SegmentedUnique, Basic) {
d_segments[1] = 4;
d_segments[2] = 6;
n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(),
d_segs_out.data().get(), d_vals_out.data().get(),
thrust::equal_to<float>{});
ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(), d_segs_out.data().get(),
d_vals_out.data().get(), thrust::equal_to<float>{});
ASSERT_EQ(n_uniques, values.size());
for (size_t i = 0 ; i < values.size(); i ++) {
ASSERT_EQ(d_vals_out[i], values[i]);
@ -113,10 +114,12 @@ void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_dup
thrust::device_vector<bst_feature_t> d_segments(segments);
thrust::device_vector<bst_feature_t> d_segments_out(segments.size());
auto ctx = xgboost::MakeCUDACtx(0);
size_t n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(), d_values.data().get(),
d_values.data().get() + d_values.size(), d_segments_out.data().get(), d_values.data().get(),
SketchUnique{});
ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(), d_segments_out.data().get(),
d_values.data().get(), SketchUnique{});
ASSERT_EQ(n_uniques, values.size() - n_duplicated);
ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(),
d_values.begin() + n_uniques, IsSorted{}));

View File

@ -221,8 +221,8 @@ TEST(HistUtil, RemoveDuplicatedCategories) {
thrust::sort_by_key(sorted_entries.begin(), sorted_entries.end(), weight.begin(),
detail::EntryCompareOp());
detail::RemoveDuplicatedCategories(ctx.Device(), info, cuts_ptr.DeviceSpan(), &sorted_entries,
&weight, &columns_ptr);
detail::RemoveDuplicatedCategories(&ctx, info, cuts_ptr.DeviceSpan(), &sorted_entries, &weight,
&columns_ptr);
auto const& h_cptr = cuts_ptr.ConstHostVector();
ASSERT_EQ(h_cptr.back(), n_samples * 2 + n_categories);
@ -367,7 +367,7 @@ auto MakeUnweightedCutsForTest(Context const* ctx, Adapter adapter, int32_t num_
SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(),
DeviceOrd::CUDA(0));
MetaInfo info;
AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
AdapterDeviceSketch(ctx, adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
sketch_container.MakeCuts(ctx, &batched_cuts, info.IsColumnSplit());
return batched_cuts;
}
@ -437,8 +437,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
common::HistogramCuts batched_cuts;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0));
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
size_t bytes_required = detail::RequiredMemory(
@ -466,9 +466,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
common::HistogramCuts batched_cuts;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0));
AdapterDeviceSketch(adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
@ -502,7 +501,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
ASSERT_EQ(info.feature_types.Size(), 1);
SketchContainer container(info.feature_types, num_bins, 1, n, DeviceOrd::CUDA(0));
AdapterDeviceSketch(adapter.Value(), num_bins, info,
AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), &container);
HistogramCuts cuts;
container.MakeCuts(&ctx, &cuts, info.IsColumnSplit());
@ -616,22 +615,27 @@ void TestGetColumnSize(std::size_t n_samples) {
std::vector<std::size_t> h_column_size(column_sizes_scan.size());
std::vector<std::size_t> h_column_size_1(column_sizes_scan.size());
auto cuctx = ctx.CUDACtx();
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, true>(
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
dh::ToSpan(column_sizes_scan));
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size.begin());
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, false>(
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
dh::ToSpan(column_sizes_scan));
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
ASSERT_EQ(h_column_size, h_column_size_1);
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, true>(
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
dh::ToSpan(column_sizes_scan));
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
ASSERT_EQ(h_column_size, h_column_size_1);
detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, false>(
ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid,
dh::ToSpan(column_sizes_scan));
thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
ASSERT_EQ(h_column_size, h_column_size_1);
}
@ -737,7 +741,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
auto const& batch = adapter.Value();
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_container(ft, kBins, kCols, kRows, DeviceOrd::CUDA(0));
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
AdapterDeviceSketch(&ctx, adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
common::HistogramCuts cuts;
@ -780,7 +784,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast<float>(kGroups);
}
SketchContainer sketch_container{ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)};
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
AdapterDeviceSketch(&ctx, adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
sketch_container.MakeCuts(&ctx, &weighted, info.IsColumnSplit());
ValidateCuts(weighted, dmat.get(), kBins);

View File

@ -24,14 +24,15 @@ namespace common {
class MGPUQuantileTest : public collective::BaseMGPUTest {};
TEST(GPUQuantile, Basic) {
auto ctx = MakeCUDACtx(0);
constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, kBins, kCols, kRows, FstCU());
SketchContainer sketch(ft, kBins, kCols, kRows, ctx.Device());
dh::caching_device_vector<Entry> entries;
dh::device_vector<bst_idx_t> cuts_ptr(kCols+1);
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
// Push empty
sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0);
sketch.Push(&ctx, dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0);
ASSERT_EQ(sketch.Data().size(), 0);
}
@ -39,16 +40,17 @@ void TestSketchUnique(float sparsity) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](std::int32_t seed, bst_bin_t n_bins,
MetaInfo const& info) {
auto ctx = MakeCUDACtx(0);
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity}
.Seed(seed)
.Device(FstCU())
.Device(ctx.Device())
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketch(adapter.Value(), n_bins, info,
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch);
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
@ -60,8 +62,9 @@ void TestSketchUnique(float sparsity) {
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto end = kCols * kRows;
detail::GetColumnSizesScan(FstCU(), kCols, n_cuts, IterSpan{batch_iter, end}, is_valid,
&cut_sizes_scan, &column_sizes_scan);
detail::GetColumnSizesScan(ctx.CUDACtx(), ctx.Device(), kCols, n_cuts,
IterSpan{batch_iter, end}, is_valid, &cut_sizes_scan,
&column_sizes_scan);
auto const& cut_sizes = cut_sizes_scan.HostVector();
ASSERT_LE(sketch.Data().size(), cut_sizes.back());
@ -69,7 +72,7 @@ void TestSketchUnique(float sparsity) {
dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch.ColumnsPtr());
ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back());
sketch.Unique();
sketch.Unique(&ctx);
std::vector<SketchEntry> h_data(sketch.Data().size());
thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin());
@ -124,44 +127,46 @@ void TestQuantileElemRank(DeviceOrd device, Span<SketchEntry const> in,
TEST(GPUQuantile, Prune) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
auto ctx = MakeCUDACtx(0);
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
HostDeviceVector<float> storage;
std::string interface_str =
RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
&storage);
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(ctx.Device())
.Seed(seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketch(adapter.Value(), n_bins, info,
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch);
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
// LE because kRows * kCols is pushed into sketch, after removing
// duplicated entries we might not have that much inputs for prune.
ASSERT_LE(sketch.Data().size(), n_cuts * kCols);
sketch.Prune(n_bins);
sketch.Prune(&ctx, n_bins);
ASSERT_LE(sketch.Data().size(), kRows * kCols);
// This is not necessarily true for all inputs without calling unique after
// prune.
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
sketch.Data().data() + sketch.Data().size(),
detail::SketchUnique{}));
TestQuantileElemRank(FstCU(), sketch.Data(), sketch.ColumnsPtr());
TestQuantileElemRank(ctx.Device(), sketch.Data(), sketch.ColumnsPtr());
});
}
TEST(GPUQuantile, MergeEmpty) {
constexpr size_t kRows = 1000, kCols = 100;
size_t n_bins = 10;
auto ctx = MakeCUDACtx(0);
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU());
SketchContainer sketch_0(ft, n_bins, kCols, kRows, ctx.Device());
HostDeviceVector<float> storage_0;
std::string interface_str_0 =
RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).GenerateArrayInterface(
&storage_0);
RandomDataGenerator{kRows, kCols, 0}.Device(ctx.Device()).GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
MetaInfo info;
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
std::vector<SketchEntry> entries_before(sketch_0.Data().size());
@ -170,7 +175,7 @@ TEST(GPUQuantile, MergeEmpty) {
dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr());
thrust::device_vector<size_t> columns_ptr(kCols + 1);
// Merge an empty sketch
sketch_0.Merge(dh::ToSpan(columns_ptr), Span<SketchEntry>{});
sketch_0.Merge(&ctx, dh::ToSpan(columns_ptr), Span<SketchEntry>{});
std::vector<SketchEntry> entries_after(sketch_0.Data().size());
dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data());
@ -193,34 +198,36 @@ TEST(GPUQuantile, MergeEmpty) {
TEST(GPUQuantile, MergeBasic) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
auto ctx = MakeCUDACtx(0);
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU());
SketchContainer sketch_0(ft, n_bins, kCols, kRows, ctx.Device());
HostDeviceVector<float> storage_0;
std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}
.Device(FstCU())
.Device(ctx.Device())
.Seed(seed)
.GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, FstCU());
SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, ctx.Device());
HostDeviceVector<float> storage_1;
std::string interface_str_1 =
RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
&storage_1);
std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0}
.Device(ctx.Device())
.Seed(seed)
.GenerateArrayInterface(&storage_1);
data::CupyAdapter adapter_1(interface_str_1);
AdapterDeviceSketch(adapter_1.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_1);
AdapterDeviceSketch(&ctx, adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
size_t size_before_merge = sketch_0.Data().size();
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
if (info.weights_.Size() != 0) {
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr(), true);
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr(), true);
sketch_0.FixError();
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr(), false);
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr(), false);
} else {
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr());
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
}
auto columns_ptr = sketch_0.ColumnsPtr();
@@ -228,7 +235,7 @@ TEST(GPUQuantile, MergeBasic) {
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
sketch_0.Unique();
sketch_0.Unique(&ctx);
ASSERT_TRUE(
thrust::is_sorted(thrust::device, sketch_0.Data().data(),
sketch_0.Data().data() + sketch_0.Data().size(),
@@ -237,25 +244,27 @@ TEST(GPUQuantile, MergeBasic) {
}
void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
auto ctx = MakeCUDACtx(0);
MetaInfo info;
int32_t seed = 0;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_0(ft, n_bins, cols, rows, FstCU());
SketchContainer sketch_0(ft, n_bins, cols, rows, ctx.Device());
HostDeviceVector<float> storage_0;
std::string interface_str_0 =
RandomDataGenerator{rows, cols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
&storage_0);
std::string interface_str_0 = RandomDataGenerator{rows, cols, 0}
.Device(ctx.Device())
.Seed(seed)
.GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_0);
AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
size_t f_rows = rows * frac;
SketchContainer sketch_1(ft, n_bins, cols, f_rows, FstCU());
SketchContainer sketch_1(ft, n_bins, cols, f_rows, ctx.Device());
HostDeviceVector<float> storage_1;
std::string interface_str_1 =
RandomDataGenerator{f_rows, cols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
&storage_1);
std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0}
.Device(ctx.Device())
.Seed(seed)
.GenerateArrayInterface(&storage_1);
auto data_1 = storage_1.DeviceSpan();
auto tuple_it = thrust::make_tuple(
thrust::make_counting_iterator<size_t>(0ul), data_1.data());
@@ -271,20 +280,19 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
}
});
data::CupyAdapter adapter_1(interface_str_1);
AdapterDeviceSketch(adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_1);
AdapterDeviceSketch(&ctx, adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
size_t size_before_merge = sketch_0.Data().size();
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr());
sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
auto columns_ptr = sketch_0.ColumnsPtr();
std::vector<bst_idx_t> h_columns_ptr(columns_ptr.size());
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
sketch_0.Unique();
sketch_0.Unique(&ctx);
columns_ptr = sketch_0.ColumnsPtr();
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
@@ -311,7 +319,8 @@ TEST(GPUQuantile, MultiMerge) {
RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
// Set up single node version
HostDeviceVector<FeatureType> ft;
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, FstCU());
auto ctx = MakeCUDACtx(0);
SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, ctx.Device());
size_t intermediate_num_cuts = std::min(
kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
@@ -319,25 +328,26 @@
for (auto rank = 0; rank < world; ++rank) {
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(FstCU())
.Device(ctx.Device())
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
HostDeviceVector<FeatureType> ft;
containers.emplace_back(ft, n_bins, kCols, kRows, FstCU());
AdapterDeviceSketch(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&containers.back());
containers.emplace_back(ft, n_bins, kCols, kRows, ctx.Device());
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &containers.back());
}
for (auto &sketch : containers) {
sketch.Prune(intermediate_num_cuts);
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
sketch.Prune(&ctx, intermediate_num_cuts);
sketch_on_single_node.Merge(&ctx, sketch.ColumnsPtr(), sketch.Data());
sketch_on_single_node.FixError();
}
TestQuantileElemRank(FstCU(), sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr());
TestQuantileElemRank(ctx.Device(), sketch_on_single_node.Data(),
sketch_on_single_node.ColumnsPtr());
sketch_on_single_node.Unique();
TestQuantileElemRank(FstCU(), sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr());
sketch_on_single_node.Unique(&ctx);
TestQuantileElemRank(ctx.Device(), sketch_on_single_node.Data(),
sketch_on_single_node.ColumnsPtr());
});
}
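The multi-sketch tests (MultiMerge above, TestAllReduceBasic below) share the same merge flow; the following is a condensed, hedged reading of it, with variable names borrowed from those tests and the Context const* parameter type assumed from the &ctx call sites rather than confirmed against the headers.

// Editorial sketch of the prune -> merge -> FixError -> Unique flow; signatures
// are assumed from the call sites in this file, not taken from the library headers.
void MergeFlowExample(Context const* ctx, std::vector<SketchContainer>* containers,
                      SketchContainer* combined, size_t intermediate_num_cuts) {
  for (auto& sketch : *containers) {
    sketch.Prune(ctx, intermediate_num_cuts);                  // shrink each worker's sketch first
    combined->Merge(ctx, sketch.ColumnsPtr(), sketch.Data());  // append into the single-node sketch
    combined->FixError();                                      // the tests call this after every merge
  }
  combined->Unique(ctx);                                       // deduplicate once at the end
}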
@@ -392,15 +402,15 @@ void TestAllReduceBasic() {
data::CupyAdapter adapter(interface_str);
HostDeviceVector<FeatureType> ft({}, device);
containers.emplace_back(ft, n_bins, kCols, kRows, device);
AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
&containers.back());
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &containers.back());
}
for (auto& sketch : containers) {
sketch.Prune(intermediate_num_cuts);
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
sketch.Prune(&ctx, intermediate_num_cuts);
sketch_on_single_node.Merge(&ctx, sketch.ColumnsPtr(), sketch.Data());
sketch_on_single_node.FixError();
}
sketch_on_single_node.Unique();
sketch_on_single_node.Unique(&ctx);
TestQuantileElemRank(device, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr(),
true);
@@ -416,16 +426,16 @@ void TestAllReduceBasic() {
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_distributed);
if (world == 1) {
auto n_samples_global = kRows * world;
intermediate_num_cuts =
std::min(n_samples_global, static_cast<size_t>(n_bins * SketchContainer::kFactor));
sketch_distributed.Prune(intermediate_num_cuts);
sketch_distributed.Prune(&ctx, intermediate_num_cuts);
}
sketch_distributed.AllReduce(&ctx, false);
sketch_distributed.Unique();
sketch_distributed.Unique(&ctx);
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size());
ASSERT_EQ(sketch_distributed.Data().size(), sketch_on_single_node.Data().size());
@@ -535,11 +545,10 @@ void TestSameOnAllWorkers() {
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketch(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_distributed);
sketch_distributed.AllReduce(&ctx, false);
sketch_distributed.Unique();
sketch_distributed.Unique(&ctx);
TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
// Test for all workers having the same sketch.
@@ -547,16 +556,13 @@ void TestSameOnAllWorkers() {
auto rc = collective::Allreduce(&ctx, linalg::MakeVec(&n_data, 1), collective::Op::kMax);
SafeColl(rc);
ASSERT_EQ(n_data, sketch_distributed.Data().size());
size_t size_as_float =
sketch_distributed.Data().size_bytes() / sizeof(float);
size_t size_as_float = sketch_distributed.Data().size_bytes() / sizeof(float);
auto local_data = Span<float const>{
reinterpret_cast<float const *>(sketch_distributed.Data().data()),
size_as_float};
reinterpret_cast<float const*>(sketch_distributed.Data().data()), size_as_float};
dh::caching_device_vector<float> all_workers(size_as_float * world);
thrust::fill(all_workers.begin(), all_workers.end(), 0);
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
rc = collective::Allreduce(
&ctx, linalg::MakeVec(all_workers.data().get(), all_workers.size(), ctx.Device()),
@@ -590,6 +596,7 @@ TEST_F(MGPUQuantileTest, SameOnAllWorkers) {
TEST(GPUQuantile, Push) {
size_t constexpr kRows = 100;
std::vector<float> data(kRows);
auto ctx = MakeCUDACtx(0);
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
std::fill(data.begin() + (data.size() / 2), data.end(), 0.5f);
@@ -608,8 +615,8 @@ TEST(GPUQuantile, Push) {
columns_ptr[1] = kRows;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
auto sketch_data = sketch.Data();
@@ -633,9 +640,9 @@ TEST(GPUQuantile, Push) {
TEST(GPUQuantile, MultiColPush) {
size_t constexpr kRows = 100, kCols = 4;
std::vector<float> data(kRows * kCols);
std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
auto ctx = MakeCUDACtx(0);
std::vector<Entry> entries(kRows * kCols);
for (bst_feature_t c = 0; c < kCols; ++c) {
@@ -648,7 +655,7 @@ TEST(GPUQuantile, MultiColPush) {
int32_t n_bins = 16;
HostDeviceVector<FeatureType> ft;
SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
dh::device_vector<Entry> d_entries {entries};
dh::device_vector<size_t> columns_ptr(kCols + 1, 0);
@@ -659,8 +666,8 @@ TEST(GPUQuantile, MultiColPush) {
columns_ptr.begin());
dh::device_vector<size_t> cuts_ptr(columns_ptr);
sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr),
dh::ToSpan(cuts_ptr), kRows * kCols, {});
sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr),
kRows * kCols, {});
auto sketch_data = sketch.Data();
ASSERT_EQ(sketch_data.size(), kCols * 2);