From 34d4ab455e25687f087119af1c74ee763a721cb3 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 30 Aug 2024 12:33:24 +0800 Subject: [PATCH] [EM] Avoid stream sync in quantile sketching. (#10765) . --- src/common/algorithm.cuh | 16 ++- src/common/device_helpers.cuh | 36 +---- src/common/hist_util.cu | 73 +++++----- src/common/hist_util.cuh | 127 ++++++++--------- src/common/quantile.cu | 77 +++++----- src/common/quantile.cuh | 22 ++- src/data/proxy_dmatrix.h | 3 + src/data/quantile_dmatrix.cu | 18 ++- src/data/simple_dmatrix.cuh | 14 +- tests/cpp/common/test_device_helpers.cu | 25 ++-- tests/cpp/common/test_hist_util.cu | 34 +++-- tests/cpp/common/test_quantile.cu | 181 ++++++++++++------------ 12 files changed, 313 insertions(+), 313 deletions(-) diff --git a/src/common/algorithm.cuh b/src/common/algorithm.cuh index 137832def..b0bec3488 100644 --- a/src/common/algorithm.cuh +++ b/src/common/algorithm.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #ifndef XGBOOST_COMMON_ALGORITHM_CUH_ #define XGBOOST_COMMON_ALGORITHM_CUH_ @@ -258,5 +258,19 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span keys, sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice, cuctx->Stream())); } + +template +void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_first, + Predicate pred) { + // We loop over batches because thrust::copy_if can't deal with sizes > 2^31 + // See thrust issue #1302, XGBoost #6822 + size_t constexpr kMaxCopySize = std::numeric_limits::max() / 2; + size_t length = std::distance(in_first, in_second); + for (size_t offset = 0; offset < length; offset += kMaxCopySize) { + auto begin_input = in_first + offset; + auto end_input = in_first + std::min(offset + kMaxCopySize, length); + out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred); + } +} } // namespace xgboost::common #endif // XGBOOST_COMMON_ALGORITHM_CUH_ diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 7d35beb72..2e5fb5cd9 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -637,12 +637,11 @@ struct SegmentedUniqueReduceOp { * \return Number of unique values in total. */ template -size_t -SegmentedUnique(const thrust::detail::execution_policy_base &exec, - KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first, - ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out, - CompValue comp, CompKey comp_key=thrust::equal_to{}) { + typename ValOutIt, typename CompValue, typename CompKey = thrust::equal_to> +size_t SegmentedUnique(const thrust::detail::execution_policy_base &exec, + KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first, + ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out, + CompValue comp, CompKey comp_key = thrust::equal_to{}) { using Key = thrust::pair::value_type>; auto unique_key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(static_cast(0)), @@ -676,16 +675,6 @@ SegmentedUnique(const thrust::detail::execution_policy_base &exec return n_uniques; } -template >::value == 7> - * = nullptr> -size_t SegmentedUnique(Inputs &&...inputs) { - dh::XGBCachingDeviceAllocator alloc; - return SegmentedUnique(thrust::cuda::par(alloc), - std::forward(inputs)..., - thrust::equal_to{}); -} - /** * \brief Unique by key for many groups of data. Has same constraint as `SegmentedUnique`. 
 *
@@ -793,21 +782,6 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
 #endif
 }
 
-template 
-void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
-  // We loop over batches because thrust::copy_if can't deal with sizes > 2^31
-  // See thrust issue #1302, XGBoost #6822
-  size_t constexpr kMaxCopySize = std::numeric_limits::max() / 2;
-  size_t length = std::distance(in_first, in_second);
-  XGBCachingDeviceAllocator alloc;
-  for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
-    auto begin_input = in_first + offset;
-    auto end_input = in_first + std::min(offset + kMaxCopySize, length);
-    out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
-                                end_input, out_first, pred);
-  }
-}
-
 template 
 void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
   InclusiveScan(d_in, d_out, cub::Sum(), num_items);
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 3bf4047e2..f81e2116c 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -106,26 +106,27 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_ro
   return std::min(sketch_batch_num_elements, kIntMax);
 }
 
-void SortByWeight(dh::device_vector* weights, dh::device_vector* sorted_entries) {
+void SortByWeight(Context const* ctx, dh::device_vector* weights,
+                  dh::device_vector* sorted_entries) {
   // Sort both entries and weights.
-  dh::XGBDeviceAllocator alloc;
+  auto cuctx = ctx->CUDACtx();
   CHECK_EQ(weights->size(), sorted_entries->size());
-  thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
-                      weights->begin(), detail::EntryCompareOp());
+  thrust::sort_by_key(cuctx->TP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
+                      detail::EntryCompareOp());
 
   // Scan weights
-  dh::XGBCachingDeviceAllocator caching;
   thrust::inclusive_scan_by_key(
-      thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
+      cuctx->CTP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
       weights->begin(),
       [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
 }
 
-void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span d_cuts_ptr,
+void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
+                                Span d_cuts_ptr,
                                 dh::device_vector* p_sorted_entries,
                                 dh::device_vector* p_sorted_weights,
                                 dh::caching_device_vector* p_column_sizes_scan) {
-  info.feature_types.SetDevice(device);
+  info.feature_types.SetDevice(ctx->Device());
   auto d_feature_types = info.feature_types.ConstDeviceSpan();
   CHECK(!d_feature_types.empty());
   auto& column_sizes_scan = *p_column_sizes_scan;
@@ -142,30 +143,32 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span
-    n_uniques = dh::SegmentedUnique(
-        column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
-        val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
-        [=] __device__(Pair const& l, Pair const& r) {
-          Entry const& le = thrust::get<0>(l);
-          Entry const& re = thrust::get<0>(r);
-          if (le.index == re.index && IsCat(d_feature_types, le.index)) {
-            return le.fvalue == re.fvalue;
-          }
-          return false;
-        });
+    n_uniques =
+        dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
+                            column_sizes_scan.data().get() + column_sizes_scan.size(), val_in_it,
+                            val_in_it + sorted_entries.size(), new_column_scan.data().get(),
+                            val_out_it, [=] __device__(Pair const& l, Pair const& r) {
+                              Entry const& le = thrust::get<0>(l);
+                              Entry const& re = thrust::get<0>(r);
+                              if (le.index == re.index && IsCat(d_feature_types, le.index)) {
+                                return le.fvalue == re.fvalue;
+                              }
+                              return false;
+                            });
     p_sorted_weights->resize(n_uniques);
   } else {
-    n_uniques = dh::SegmentedUnique(
-        column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
-        sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
-        sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
-          if (l.index == r.index) {
-            if (IsCat(d_feature_types, l.index)) {
-              return l.fvalue == r.fvalue;
-            }
-          }
-          return false;
-        });
+    n_uniques = dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
+                                    column_sizes_scan.data().get() + column_sizes_scan.size(),
+                                    sorted_entries.begin(), sorted_entries.end(),
+                                    new_column_scan.data().get(), sorted_entries.begin(),
+                                    [=] __device__(Entry const& l, Entry const& r) {
+                                      if (l.index == r.index) {
+                                        if (IsCat(d_feature_types, l.index)) {
+                                          return l.fvalue == r.fvalue;
+                                        }
+                                      }
+                                      return false;
+                                    });
   }
 
   sorted_entries.resize(n_uniques);
@@ -189,7 +192,7 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span
+  thrust::exclusive_scan(ctx->CUDACtx()->CTP(), new_cuts_size.cbegin(), new_cuts_size.cend(),
                          d_cuts_ptr.data());
 }
 }  // namespace detail
@@ -225,7 +228,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
       std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
       d_temp_weight[idx] = sample_weight[ridx + base_rowid];
     });
-    detail::SortByWeight(&entry_weight, &sorted_entries);
+    detail::SortByWeight(ctx, &entry_weight, &sorted_entries);
   } else {
     thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
                  detail::EntryCompareOp());
@@ -238,13 +241,13 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
       sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
         return {0, e.index, e.fvalue};  // row_idx is not needed for scanning column size.
       });
-  detail::GetColumnSizesScan(ctx->Device(), info.num_col_, num_cuts_per_feature,
+  detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_, num_cuts_per_feature,
                              IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
                              &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
   if (sketch_container->HasCategorical()) {
     auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
-    detail::RemoveDuplicatedCategories(ctx->Device(), info, d_cuts_ptr, &sorted_entries, p_weight,
+    detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
                                        &column_sizes_scan);
   }
 
   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
 
   // Add cuts into sketches
-  sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
+  sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
                          h_cuts_ptr.back(), dh::ToSpan(entry_weight));
 
   sorted_entries.clear();
diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh
index cf1043ddb..416a0be9e 100644
--- a/src/common/hist_util.cuh
+++ b/src/common/hist_util.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2023 by XGBoost contributors
+ * Copyright 2020-2024, XGBoost contributors
  *
  * \brief Front end and utilities for GPU based sketching. Works on sliding window
  * instead of stream.
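The change pattern in the hunks above is the heart of this patch: the old code built ad-hoc `thrust::cuda::par(alloc)` policies, which run on the legacy default stream and therefore synchronize with every other blocking stream on the device, while `cuctx->TP()` and `cuctx->CTP()` are policies bound to the context's CUDA stream (with a plain and a caching device allocator, respectively, judging by how the patch swaps them in), so the algorithms stay asynchronous and synchronization happens only where a result is actually consumed. A minimal standalone sketch of the same idea, using only Thrust and the CUDA runtime; all names here are illustrative, none are XGBoost's:

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/system/cuda/execution_policy.h>

#include <cuda_runtime.h>

#include <vector>

struct GreaterThanThree {
  __host__ __device__ bool operator()(int v) const { return v > 3; }
};

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  std::vector<int> h_in{3, 1, 4, 1, 5, 9, 2, 6};
  thrust::device_vector<int> in = h_in;
  thrust::device_vector<int> out(in.size());

  // thrust::cuda::par.on(stream) plays the role of cuctx->CTP(): the algorithm
  // is enqueued on an explicit stream instead of the legacy default stream, so
  // it introduces no device-wide synchronization point.
  auto ret = thrust::copy_if(thrust::cuda::par.on(stream), in.begin(), in.end(),
                             out.begin(), GreaterThanThree{});

  // One explicit synchronization, at the point where the host needs the result.
  cudaStreamSynchronize(stream);
  out.resize(ret - out.begin());

  cudaStreamDestroy(stream);
  return 0;
}
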
@@ -13,6 +13,8 @@ #include // for size_t #include "../data/adapter.h" // for IsValidFunctor +#include "algorithm.cuh" // for CopyIf +#include "cuda_context.cuh" // for CUDAContext #include "device_helpers.cuh" #include "hist_util.h" #include "quantile.cuh" @@ -107,9 +109,10 @@ std::uint32_t EstimateGridSize(DeviceOrd device, Kernel kernel, std::size_t shar * \param out_column_size Output buffer for the size of each column. */ template -void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan batch_iter, - data::IsValidFunctor is_valid, Span out_column_size) { - thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0); +void LaunchGetColumnSizeKernel(CUDAContext const* cuctx, DeviceOrd device, + IterSpan batch_iter, data::IsValidFunctor is_valid, + Span out_column_size) { + thrust::fill_n(cuctx->CTP(), dh::tbegin(out_column_size), out_column_size.size(), 0); std::size_t max_shared_memory = dh::MaxSharedMemory(device.ordinal); // Not strictly correct as we should use number of samples to determine the type of @@ -135,17 +138,17 @@ void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan batch_iter, CHECK(!force_use_u64); auto kernel = GetColumnSizeSharedMemKernel; auto grid_size = EstimateGridSize(device, kernel, required_shared_memory); - dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}( + dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, cuctx->Stream()}( kernel, batch_iter, is_valid, out_column_size); } else { auto kernel = GetColumnSizeSharedMemKernel; auto grid_size = EstimateGridSize(device, kernel, required_shared_memory); - dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}( + dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, cuctx->Stream()}( kernel, batch_iter, is_valid, out_column_size); } } else { auto d_out_column_size = out_column_size; - dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) { + dh::LaunchN(batch_iter.size(), cuctx->Stream(), [=] __device__(size_t idx) { auto e = batch_iter[idx]; if (is_valid(e)) { atomicAdd(&d_out_column_size[e.column_idx], static_cast(1)); @@ -155,26 +158,26 @@ void LaunchGetColumnSizeKernel(DeviceOrd device, IterSpan batch_iter, } template -void GetColumnSizesScan(DeviceOrd device, size_t num_columns, std::size_t num_cuts_per_feature, - IterSpan batch_iter, data::IsValidFunctor is_valid, +void GetColumnSizesScan(CUDAContext const* cuctx, DeviceOrd device, size_t num_columns, + std::size_t num_cuts_per_feature, IterSpan batch_iter, + data::IsValidFunctor is_valid, HostDeviceVector* cuts_ptr, dh::caching_device_vector* column_sizes_scan) { column_sizes_scan->resize(num_columns + 1); cuts_ptr->SetDevice(device); cuts_ptr->Resize(num_columns + 1, 0); - dh::XGBCachingDeviceAllocator alloc; auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan); - LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan); + LaunchGetColumnSizeKernel(cuctx, device, batch_iter, is_valid, d_column_sizes_scan); // Calculate cuts CSC pointer auto cut_ptr_it = dh::MakeTransformIterator( column_sizes_scan->begin(), [=] __device__(size_t column_size) { return thrust::min(num_cuts_per_feature, column_size); }); - thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it, + thrust::exclusive_scan(cuctx->CTP(), cut_ptr_it, cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer()); - thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), - column_sizes_scan->end(), column_sizes_scan->begin()); + 
thrust::exclusive_scan(cuctx->CTP(), column_sizes_scan->begin(), column_sizes_scan->end(), + column_sizes_scan->begin()); } inline size_t constexpr BytesPerElement(bool has_weight) { @@ -215,9 +218,9 @@ size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz, // Count the valid entries in each column and copy them out. template -void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range, - float missing, size_t columns, size_t cuts_per_feature, - DeviceOrd device, +void MakeEntriesFromAdapter(CUDAContext const* cuctx, AdapterBatch const& batch, + BatchIter batch_iter, Range1d range, float missing, size_t columns, + size_t cuts_per_feature, DeviceOrd device, HostDeviceVector* cut_sizes_scan, dh::caching_device_vector* column_sizes_scan, dh::device_vector* sorted_entries) { @@ -229,19 +232,20 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran auto span = IterSpan{batch_iter + range.begin(), n}; data::IsValidFunctor is_valid(missing); // Work out how many valid entries we have in each column - GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan, + GetColumnSizesScan(cuctx, device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan, column_sizes_scan); size_t num_valid = column_sizes_scan->back(); // Copy current subset of valid elements into temporary storage and sort sorted_entries->resize(num_valid); - dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(), - is_valid); + CopyIf(cuctx, entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(), + is_valid); } -void SortByWeight(dh::device_vector* weights, +void SortByWeight(Context const* ctx, dh::device_vector* weights, dh::device_vector* sorted_entries); -void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span d_cuts_ptr, +void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info, + Span d_cuts_ptr, dh::device_vector* p_sorted_entries, dh::device_vector* p_sorted_weights, dh::caching_device_vector* p_column_sizes_scan); @@ -278,10 +282,9 @@ inline HistogramCuts DeviceSketch(Context const* ctx, DMatrix* p_fmat, bst_bin_t } template -void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, - DeviceOrd device, size_t columns, size_t begin, size_t end, - float missing, SketchContainer *sketch_container, - int num_cuts) { +void ProcessSlidingWindow(Context const* ctx, AdapterBatch const& batch, MetaInfo const& info, + size_t columns, size_t begin, size_t end, float missing, + SketchContainer* sketch_container, int num_cuts) { // Copy current subset of valid elements into temporary storage and sort dh::device_vector sorted_entries; dh::caching_device_vector column_sizes_scan; @@ -289,54 +292,45 @@ void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return batch.GetElement(idx); }); HostDeviceVector cuts_ptr; - cuts_ptr.SetDevice(device); - detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, - columns, num_cuts, device, - &cuts_ptr, - &column_sizes_scan, - &sorted_entries); - dh::XGBDeviceAllocator alloc; - thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), detail::EntryCompareOp()); + cuts_ptr.SetDevice(ctx->Device()); + CUDAContext const* cuctx = ctx->CUDACtx(); + detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, num_cuts, + 
ctx->Device(), &cuts_ptr, &column_sizes_scan, &sorted_entries); + thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); if (sketch_container->HasCategorical()) { auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, nullptr, + detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, nullptr, &column_sizes_scan); } auto d_cuts_ptr = cuts_ptr.DeviceSpan(); auto const &h_cuts_ptr = cuts_ptr.HostVector(); // Extract the cuts from all columns concurrently - sketch_container->Push(dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), d_cuts_ptr, + sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr, h_cuts_ptr.back()); sorted_entries.clear(); sorted_entries.shrink_to_fit(); } template -void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, - int num_cuts_per_feature, - bool is_ranking, float missing, DeviceOrd device, - size_t columns, size_t begin, size_t end, - SketchContainer *sketch_container) { - dh::XGBCachingDeviceAllocator alloc; +void ProcessWeightedSlidingWindow(Context const* ctx, Batch batch, MetaInfo const& info, + int num_cuts_per_feature, bool is_ranking, float missing, + DeviceOrd device, size_t columns, size_t begin, size_t end, + SketchContainer* sketch_container) { dh::safe_cuda(cudaSetDevice(device.ordinal)); info.weights_.SetDevice(device); auto weights = info.weights_.ConstDeviceSpan(); auto batch_iter = dh::MakeTransformIterator( - thrust::make_counting_iterator(0llu), - [=] __device__(size_t idx) { return batch.GetElement(idx); }); + thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) { return batch.GetElement(idx); }); + auto cuctx = ctx->CUDACtx(); dh::device_vector sorted_entries; dh::caching_device_vector column_sizes_scan; HostDeviceVector cuts_ptr; - detail::MakeEntriesFromAdapter(batch, batch_iter, - {begin, end}, missing, - columns, num_cuts_per_feature, device, - &cuts_ptr, - &column_sizes_scan, + detail::MakeEntriesFromAdapter(cuctx, batch, batch_iter, {begin, end}, missing, columns, + num_cuts_per_feature, device, &cuts_ptr, &column_sizes_scan, &sorted_entries); data::IsValidFunctor is_valid(missing); @@ -355,7 +349,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx); return weights[group_idx]; }); - auto retit = thrust::copy_if(thrust::cuda::par(alloc), + auto retit = thrust::copy_if(cuctx->CTP(), weight_iter + begin, weight_iter + end, batch_iter + begin, d_temp_weights.data(), // output @@ -368,7 +362,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, [=]__device__(size_t idx) -> float { return weights[batch.GetElement(idx).row_idx]; }); - auto retit = thrust::copy_if(thrust::cuda::par(alloc), + auto retit = thrust::copy_if(cuctx->CTP(), weight_iter + begin, weight_iter + end, batch_iter + begin, d_temp_weights.data(), // output @@ -376,11 +370,11 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); } - detail::SortByWeight(&temp_weights, &sorted_entries); + detail::SortByWeight(ctx, &temp_weights, &sorted_entries); if (sketch_container->HasCategorical()) { auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, &temp_weights, + detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, 
&temp_weights, &column_sizes_scan); } @@ -388,8 +382,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, auto d_cuts_ptr = cuts_ptr.DeviceSpan(); // Extract cuts - sketch_container->Push(dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), d_cuts_ptr, + sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr, h_cuts_ptr.back(), dh::ToSpan(temp_weights)); sorted_entries.clear(); sorted_entries.shrink_to_fit(); @@ -407,8 +400,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, * testing. */ template -void AdapterDeviceSketch(Batch batch, int num_bins, - MetaInfo const& info, +void AdapterDeviceSketch(Context const* ctx, Batch batch, int num_bins, MetaInfo const& info, float missing, SketchContainer* sketch_container, size_t sketch_batch_num_elements = 0) { size_t num_rows = batch.NumRows(); @@ -419,27 +411,24 @@ void AdapterDeviceSketch(Batch batch, int num_bins, if (weighted) { sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - num_rows, num_cols, std::numeric_limits::max(), + sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits::max(), device.ordinal, num_cuts_per_feature, true); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); - ProcessWeightedSlidingWindow(batch, info, - num_cuts_per_feature, - HostSketchContainer::UseGroup(info), missing, device, num_cols, begin, end, - sketch_container); + ProcessWeightedSlidingWindow(ctx, batch, info, num_cuts_per_feature, + HostSketchContainer::UseGroup(info), missing, device, num_cols, + begin, end, sketch_container); } } else { sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - num_rows, num_cols, std::numeric_limits::max(), + sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits::max(), device.ordinal, num_cuts_per_feature, false); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); - ProcessSlidingWindow(batch, info, device, num_cols, begin, end, missing, - sketch_container, num_cuts_per_feature); + ProcessSlidingWindow(ctx, batch, info, num_cols, begin, end, missing, sketch_container, + num_cuts_per_feature); } } } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index d807bd7af..295206f0a 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -18,6 +18,8 @@ #include "../collective/communicator-inl.h" // for GetWorldSize, GetRank #include "categorical.h" #include "common.h" +#include "cuda_context.cuh" // for CUDAContext +#include "cuda_rt_utils.h" // for SetDevice #include "device_helpers.cuh" #include "hist_util.h" #include "quantile.cuh" @@ -117,6 +119,7 @@ void CopyTo(Span out, Span src) { // Compute the merge path. common::Span> MergePath( + Context const* ctx, Span const &d_x, Span const &x_ptr, Span const &d_y, Span const &y_ptr, Span out, Span out_ptr) { @@ -142,13 +145,12 @@ common::Span> MergePath( auto y_merge_val_it = thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder)); - dh::XGBCachingDeviceAllocator alloc; static_assert(sizeof(Tuple) == sizeof(SketchEntry)); // We reuse the memory for storing merge path. common::Span merge_path{reinterpret_cast(out.data()), out.size()}; // Determine the merge path, 0 if element is from x, 1 if it's from y. 
thrust::merge_by_key( - thrust::cuda::par(alloc), x_merge_key_it, x_merge_key_it + d_x.size(), + ctx->CUDACtx()->CTP(), x_merge_key_it, x_merge_key_it + d_x.size(), y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it, y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(), [=] __device__(auto const &l, auto const &r) -> bool { @@ -163,10 +165,9 @@ common::Span> MergePath( // Compute output ptr auto transform_it = thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data())); - thrust::transform( - thrust::cuda::par(alloc), transform_it, transform_it + x_ptr.size(), - out_ptr.data(), - [] __device__(auto const& t) { return thrust::get<0>(t) + thrust::get<1>(t); }); + thrust::transform(ctx->CUDACtx()->CTP(), transform_it, transform_it + x_ptr.size(), + out_ptr.data(), + [] __device__(auto const &t) { return thrust::get<0>(t) + thrust::get<1>(t); }); // 0^th is the indicator, 1^th is placeholder auto get_ind = []XGBOOST_DEVICE(Tuple const& t) { return thrust::get<0>(t); }; @@ -194,7 +195,7 @@ common::Span> MergePath( // is landed into output as the first element in merge result. The scan result is the // subscript of x and y. thrust::exclusive_scan_by_key( - thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(), + ctx->CUDACtx()->CTP(), scan_key_it, scan_key_it + merge_path.size(), scan_val_it, merge_path.data(), thrust::make_tuple(0ul, 0ul), thrust::equal_to{}, @@ -209,18 +210,17 @@ common::Span> MergePath( // summary does the output element come from) result by definition of merged rank. So we // run it in 2 passes to obtain the merge path and then customize the standard merge // algorithm. -void MergeImpl(DeviceOrd device, Span const &d_x, +void MergeImpl(Context const *ctx, Span const &d_x, Span const &x_ptr, Span const &d_y, Span const &y_ptr, Span out, Span out_ptr) { - dh::safe_cuda(cudaSetDevice(device.ordinal)); CHECK_EQ(d_x.size() + d_y.size(), out.size()); CHECK_EQ(x_ptr.size(), out_ptr.size()); CHECK_EQ(y_ptr.size(), out_ptr.size()); - auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr); + auto d_merge_path = MergePath(ctx, d_x, x_ptr, d_y, y_ptr, out, out_ptr); auto d_out = out; - dh::LaunchN(d_out.size(), [=] __device__(size_t idx) { + dh::LaunchN(d_out.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { auto column_id = dh::SegmentId(out_ptr, idx); idx -= out_ptr[column_id]; @@ -307,10 +307,9 @@ void MergeImpl(DeviceOrd device, Span const &d_x, }); } -void SketchContainer::Push(Span entries, Span columns_ptr, - common::Span cuts_ptr, - size_t total_cuts, Span weights) { - dh::safe_cuda(cudaSetDevice(device_.ordinal)); +void SketchContainer::Push(Context const *ctx, Span entries, Span columns_ptr, + common::Span cuts_ptr, size_t total_cuts, Span weights) { + common::SetDevice(device_.ordinal); Span out; dh::device_vector cuts; bool first_window = this->Current().empty(); @@ -346,12 +345,12 @@ void SketchContainer::Push(Span entries, Span columns_ptr, }; // NOLINT PruneImpl(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry); } - auto n_uniques = this->ScanInput(out, cuts_ptr); + auto n_uniques = this->ScanInput(ctx, out, cuts_ptr); if (!first_window) { CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size()); out = out.subspan(0, n_uniques); - this->Merge(cuts_ptr, out); + this->Merge(ctx, cuts_ptr, out); this->FixError(); } else { this->Current().resize(n_uniques); @@ -363,7 +362,8 @@ void SketchContainer::Push(Span entries, Span columns_ptr, } } -size_t SketchContainer::ScanInput(Span entries, Span 
d_columns_ptr_in) { +size_t SketchContainer::ScanInput(Context const *ctx, Span entries, + Span d_columns_ptr_in) { /* There are 2 types of duplication. First is duplicated feature values, which comes * from user input data. Second is duplicated sketching entries, which is generated by * pruning or merging. We preserve the first type and remove the second type. @@ -371,7 +371,6 @@ size_t SketchContainer::ScanInput(Span entries, Span d_col timer_.Start(__func__); dh::safe_cuda(cudaSetDevice(device_.ordinal)); CHECK_EQ(d_columns_ptr_in.size(), num_columns_ + 1); - dh::XGBCachingDeviceAllocator alloc; auto key_it = dh::MakeTransformIterator( thrust::make_reverse_iterator(thrust::make_counting_iterator(entries.size())), @@ -381,7 +380,7 @@ size_t SketchContainer::ScanInput(Span entries, Span d_col // Reverse scan to accumulate weights into first duplicated element on left. auto val_it = thrust::make_reverse_iterator(dh::tend(entries)); thrust::inclusive_scan_by_key( - thrust::cuda::par(alloc), key_it, key_it + entries.size(), + ctx->CUDACtx()->CTP(), key_it, key_it + entries.size(), val_it, val_it, thrust::equal_to{}, [] __device__(SketchEntry const &r, SketchEntry const &l) { @@ -396,18 +395,18 @@ size_t SketchContainer::ScanInput(Span entries, Span d_col auto d_columns_ptr_out = columns_ptr_b_.DeviceSpan(); // thrust unique_by_key preserves the first element. - auto n_uniques = dh::SegmentedUnique( - d_columns_ptr_in.data(), - d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(), - entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(), - detail::SketchUnique{}); + auto n_uniques = + dh::SegmentedUnique(ctx->CUDACtx()->CTP(), d_columns_ptr_in.data(), + d_columns_ptr_in.data() + d_columns_ptr_in.size(), entries.data(), + entries.data() + entries.size(), d_columns_ptr_out.data(), entries.data(), + detail::SketchUnique{}); CopyTo(d_columns_ptr_in, d_columns_ptr_out); timer_.Stop(__func__); return n_uniques; } -void SketchContainer::Prune(size_t to) { +void SketchContainer::Prune(Context const* ctx, std::size_t to) { timer_.Start(__func__); dh::safe_cuda(cudaSetDevice(device_.ordinal)); @@ -438,19 +437,19 @@ void SketchContainer::Prune(size_t to) { this->columns_ptr_.Copy(columns_ptr_b_); this->Alternate(); - this->Unique(); + this->Unique(ctx); timer_.Stop(__func__); } -void SketchContainer::Merge(Span d_that_columns_ptr, +void SketchContainer::Merge(Context const *ctx, Span d_that_columns_ptr, Span that) { - dh::safe_cuda(cudaSetDevice(device_.ordinal)); + common::SetDevice(device_.ordinal); timer_.Start(__func__); if (this->Current().size() == 0) { CHECK_EQ(this->columns_ptr_.HostVector().back(), 0); CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size()); CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1); - thrust::copy(thrust::device, d_that_columns_ptr.data(), + thrust::copy(ctx->CUDACtx()->CTP(), d_that_columns_ptr.data(), d_that_columns_ptr.data() + d_that_columns_ptr.size(), this->columns_ptr_.DevicePointer()); auto total = this->columns_ptr_.HostVector().back(); @@ -463,7 +462,7 @@ void SketchContainer::Merge(Span d_that_columns_ptr, this->Other().resize(this->Current().size() + that.size()); CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size()); - MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr, + MergeImpl(ctx, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr, dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan()); this->columns_ptr_.Copy(columns_ptr_b_); 
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1); @@ -471,7 +470,7 @@ void SketchContainer::Merge(Span d_that_columns_ptr, if (this->HasCategorical()) { auto d_feature_types = this->FeatureTypes().ConstDeviceSpan(); - this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) { + this->Unique(ctx, [d_feature_types] __device__(size_t l_fidx, size_t r_fidx) { return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx); }); } @@ -517,7 +516,7 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { SafeColl(rc); bst_idx_t intermediate_num_cuts = std::min(global_sum_rows, static_cast(num_bins_ * kFactor)); - this->Prune(intermediate_num_cuts); + this->Prune(ctx, intermediate_num_cuts); auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1); @@ -570,9 +569,8 @@ void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { for (size_t i = 0; i < allworkers.size(); ++i) { auto worker = allworkers[i]; auto worker_ptr = - dh::ToSpan(gathered_ptrs) - .subspan(i * d_columns_ptr.size(), d_columns_ptr.size()); - new_sketch.Merge(worker_ptr, worker); + dh::ToSpan(gathered_ptrs).subspan(i * d_columns_ptr.size(), d_columns_ptr.size()); + new_sketch.Merge(ctx, worker_ptr, worker); new_sketch.FixError(); } @@ -602,7 +600,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i this->AllReduce(ctx, is_column_split); // Prune to final number of bins. - this->Prune(num_bins_ + 1); + this->Prune(ctx, num_bins_ + 1); this->FixError(); // Set up inputs @@ -624,7 +622,6 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i std::vector max_values; float max_cat{-1.f}; if (has_categorical_) { - dh::XGBCachingDeviceAllocator alloc; auto key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) -> bst_feature_t { return dh::SegmentId(d_in_columns_ptr, i); @@ -651,7 +648,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i dh::caching_device_vector d_max_keys(d_in_columns_ptr.size() - 1); dh::caching_device_vector d_max_values(d_in_columns_ptr.size() - 1); auto new_end = thrust::reduce_by_key( - thrust::cuda::par(alloc), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(), + ctx->CUDACtx()->CTP(), key_it, key_it + in_cut_values.size(), val_it, d_max_keys.begin(), d_max_values.begin(), thrust::equal_to{}, [] __device__(auto l, auto r) { return l.value > r.value ? 
l : r; }); d_max_keys.erase(new_end.first, d_max_keys.end()); @@ -661,7 +658,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i SketchEntry default_entry{}; dh::caching_device_vector d_max_results(d_in_columns_ptr.size() - 1, default_entry); - thrust::scatter(thrust::cuda::par(alloc), d_max_values.begin(), d_max_values.end(), + thrust::scatter(ctx->CUDACtx()->CTP(), d_max_values.begin(), d_max_values.end(), d_max_keys.begin(), d_max_results.begin()); dh::CopyDeviceSpanToVector(&max_values, dh::ToSpan(d_max_results)); auto max_it = MakeIndexTransformIter([&](auto i) { diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index ae286c3b3..1bd1672eb 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -7,6 +7,7 @@ #include // for any_of #include "categorical.h" +#include "cuda_context.cuh" // for CUDAContext #include "device_helpers.cuh" #include "error_msg.h" // for InvalidMaxBin #include "quantile.h" @@ -127,7 +128,7 @@ class SketchContainer { /* \brief Whether the predictor matrix contains categorical features. */ bool HasCategorical() const { return has_categorical_; } /* \brief Accumulate weights of duplicated entries in input. */ - size_t ScanInput(Span entries, Span d_columns_ptr_in); + size_t ScanInput(Context const* ctx, Span entries, Span d_columns_ptr_in); /* Fix rounding error and re-establish invariance. The error is mostly generated by the * addition inside `RMinNext` and subtraction in `RMaxPrev`. */ void FixError(); @@ -140,19 +141,18 @@ class SketchContainer { * \param total_cuts Total number of cuts, equal to the back of cuts_ptr. * \param weights (optional) data weights. */ - void Push(Span entries, Span columns_ptr, - common::Span cuts_ptr, size_t total_cuts, - Span weights = {}); + void Push(Context const* ctx, Span entries, Span columns_ptr, + common::Span cuts_ptr, size_t total_cuts, Span weights = {}); /* \brief Prune the quantile structure. * * \param to The maximum size of pruned quantile. If the size of quantile * structure is already less than `to`, then no operation is performed. */ - void Prune(size_t to); + void Prune(Context const* ctx, size_t to); /* \brief Merge another set of sketch. * \param that columns of other. */ - void Merge(Span that_columns_ptr, + void Merge(Context const* ctx, Span that_columns_ptr, Span that); /* \brief Merge quantiles from other GPU workers. */ @@ -175,7 +175,7 @@ class SketchContainer { /* \brief Removes all the duplicated elements in quantile structure. 
*/ template > - size_t Unique(KeyComp key_comp = thrust::equal_to{}) { + size_t Unique(Context const* ctx, KeyComp key_comp = thrust::equal_to{}) { timer_.Start(__func__); dh::safe_cuda(cudaSetDevice(device_.ordinal)); this->columns_ptr_.SetDevice(device_); @@ -185,14 +185,12 @@ class SketchContainer { HostDeviceVector scan_out(d_column_scan.size()); scan_out.SetDevice(device_); auto d_scan_out = scan_out.DeviceSpan(); - dh::XGBCachingDeviceAllocator alloc; d_column_scan = this->columns_ptr_.DeviceSpan(); size_t n_uniques = dh::SegmentedUnique( - thrust::cuda::par(alloc), d_column_scan.data(), - d_column_scan.data() + d_column_scan.size(), entries.data(), - entries.data() + entries.size(), scan_out.DevicePointer(), - entries.data(), detail::SketchUnique{}, key_comp); + ctx->CUDACtx()->CTP(), d_column_scan.data(), d_column_scan.data() + d_column_scan.size(), + entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(), entries.data(), + detail::SketchUnique{}, key_comp); this->columns_ptr_.Copy(scan_out); CHECK(!this->columns_ptr_.HostCanRead()); diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 221e13fb3..97a339cac 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -11,6 +11,7 @@ #include // for invoke_result_t, declval #include // for vector +#include "../common/cuda_rt_utils.h" // for xgboost_NVTX_FN_RANGE #include "adapter.h" #include "xgboost/c_api.h" #include "xgboost/context.h" @@ -36,6 +37,8 @@ class DataIterProxy { DataIterProxy& operator=(DataIterProxy const& that) = default; [[nodiscard]] bool Next() { + xgboost_NVTX_FN_RANGE(); + bool ret = !!next_(iter_); if (!ret) { return ret; diff --git a/src/data/quantile_dmatrix.cu b/src/data/quantile_dmatrix.cu index f90ca882f..04db88405 100644 --- a/src/data/quantile_dmatrix.cu +++ b/src/data/quantile_dmatrix.cu @@ -30,14 +30,13 @@ void MakeSketches(Context const* ctx, ExternalDataInfo* p_ext_info) { xgboost_NVTX_FN_RANGE(); - CUDAContext const* cuctx = ctx->CUDACtx(); std::unique_ptr sketch; auto& ext_info = *p_ext_info; do { // We use do while here as the first batch is fetched in ctor CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs()); - dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); + common::SetDevice(dh::GetDevice(ctx).ordinal); if (ext_info.n_features == 0) { ext_info.n_features = data::BatchColumns(proxy); auto rc = collective::Allreduce(ctx, linalg::MakeVec(&ext_info.n_features, 1), @@ -55,7 +54,16 @@ void MakeSketches(Context const* ctx, } proxy->Info().weights_.SetDevice(dh::GetDevice(ctx)); cuda_impl::Dispatch(proxy, [&](auto const& value) { - common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, sketch.get()); + // Workaround empty input with CPU ctx. 
+ Context new_ctx; + Context const* p_ctx; + if (ctx->IsCUDA()) { + p_ctx = ctx; + } else { + new_ctx.UpdateAllowUnknown(Args{{"device", dh::GetDevice(ctx).Name()}}); + p_ctx = &new_ctx; + } + common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing, sketch.get()); }); } auto batch_rows = data::BatchSamples(proxy); @@ -66,7 +74,7 @@ void MakeSketches(Context const* ctx, std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) { return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); })); - ext_info.nnz += thrust::reduce(cuctx->CTP(), row_counts.begin(), row_counts.end()); + ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end()); ext_info.n_batches++; ext_info.base_rows.push_back(batch_rows); } while (iter->Next()); @@ -77,7 +85,7 @@ void MakeSketches(Context const* ctx, ext_info.base_rows.begin()); // Get reference - dh::safe_cuda(cudaSetDevice(dh::GetDevice(ctx).ordinal)); + common::SetDevice(dh::GetDevice(ctx).ordinal); if (!ref) { sketch->MakeCuts(ctx, cuts.get(), info.IsColumnSplit()); } else { diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh index e3c241886..0b34be44d 100644 --- a/src/data/simple_dmatrix.cuh +++ b/src/data/simple_dmatrix.cuh @@ -11,6 +11,7 @@ #include "../common/device_helpers.cuh" #include "../common/error_msg.h" // for InfInData +#include "../common/algorithm.cuh" // for CopyIf #include "device_adapter.cuh" // for NoInfInData namespace xgboost::data { @@ -27,16 +28,15 @@ struct COOToEntryOp { // Here the data is already correctly ordered and simply needs to be compacted // to remove missing data template -void CopyDataToDMatrix(AdapterBatchT batch, common::Span data, - float missing) { +void CopyDataToDMatrix(AdapterBatchT batch, common::Span data, float missing) { auto counting = thrust::make_counting_iterator(0llu); - dh::XGBCachingDeviceAllocator alloc; COOToEntryOp transform_op{batch}; - thrust::transform_iterator - transform_iter(counting, transform_op); + thrust::transform_iterator transform_iter( + counting, transform_op); auto begin_output = thrust::device_pointer_cast(data.data()); - dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output, - IsValidFunctor(missing)); + auto ctx = Context{}.MakeCUDA(dh::CurrentDevice()); + common::CopyIf(ctx.CUDACtx(), transform_iter, transform_iter + batch.Size(), begin_output, + IsValidFunctor(missing)); } template diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 4178e55d8..169516c67 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -9,8 +9,10 @@ #include #include +#include "../../../src/common/cuda_context.cuh" #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/quantile.h" +#include "../helpers.h" #include "gtest/gtest.h" TEST(SumReduce, Test) { @@ -61,11 +63,11 @@ TEST(SegmentedUnique, Basic) { thrust::device_vector d_segs_out(d_segments.size()); thrust::device_vector d_vals_out(d_values.size()); + auto ctx = xgboost::MakeCUDACtx(0); size_t n_uniques = dh::SegmentedUnique( - d_segments.data().get(), d_segments.data().get() + d_segments.size(), - d_values.data().get(), d_values.data().get() + d_values.size(), - d_segs_out.data().get(), d_vals_out.data().get(), - thrust::equal_to{}); + ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(), + d_values.data().get(), d_values.data().get() + d_values.size(), 
d_segs_out.data().get(), + d_vals_out.data().get(), thrust::equal_to{}); CHECK_EQ(n_uniques, 5); std::vector values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f}; @@ -81,10 +83,9 @@ TEST(SegmentedUnique, Basic) { d_segments[1] = 4; d_segments[2] = 6; n_uniques = dh::SegmentedUnique( - d_segments.data().get(), d_segments.data().get() + d_segments.size(), - d_values.data().get(), d_values.data().get() + d_values.size(), - d_segs_out.data().get(), d_vals_out.data().get(), - thrust::equal_to{}); + ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(), + d_values.data().get(), d_values.data().get() + d_values.size(), d_segs_out.data().get(), + d_vals_out.data().get(), thrust::equal_to{}); ASSERT_EQ(n_uniques, values.size()); for (size_t i = 0 ; i < values.size(); i ++) { ASSERT_EQ(d_vals_out[i], values[i]); @@ -113,10 +114,12 @@ void TestSegmentedUniqueRegression(std::vector values, size_t n_dup thrust::device_vector d_segments(segments); thrust::device_vector d_segments_out(segments.size()); + auto ctx = xgboost::MakeCUDACtx(0); + size_t n_uniques = dh::SegmentedUnique( - d_segments.data().get(), d_segments.data().get() + d_segments.size(), d_values.data().get(), - d_values.data().get() + d_values.size(), d_segments_out.data().get(), d_values.data().get(), - SketchUnique{}); + ctx.CUDACtx()->CTP(), d_segments.data().get(), d_segments.data().get() + d_segments.size(), + d_values.data().get(), d_values.data().get() + d_values.size(), d_segments_out.data().get(), + d_values.data().get(), SketchUnique{}); ASSERT_EQ(n_uniques, values.size() - n_duplicated); ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(), d_values.begin() + n_uniques, IsSorted{})); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index df5ed9004..b3b77694c 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -221,8 +221,8 @@ TEST(HistUtil, RemoveDuplicatedCategories) { thrust::sort_by_key(sorted_entries.begin(), sorted_entries.end(), weight.begin(), detail::EntryCompareOp()); - detail::RemoveDuplicatedCategories(ctx.Device(), info, cuts_ptr.DeviceSpan(), &sorted_entries, - &weight, &columns_ptr); + detail::RemoveDuplicatedCategories(&ctx, info, cuts_ptr.DeviceSpan(), &sorted_entries, &weight, + &columns_ptr); auto const& h_cptr = cuts_ptr.ConstHostVector(); ASSERT_EQ(h_cptr.back(), n_samples * 2 + n_categories); @@ -367,7 +367,7 @@ auto MakeUnweightedCutsForTest(Context const* ctx, Adapter adapter, int32_t num_ SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), DeviceOrd::CUDA(0)); MetaInfo info; - AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size); + AdapterDeviceSketch(ctx, adapter.Value(), num_bins, info, missing, &sketch_container, batch_size); sketch_container.MakeCuts(ctx, &batched_cuts, info.IsColumnSplit()); return batched_cuts; } @@ -437,8 +437,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) { common::HistogramCuts batched_cuts; HostDeviceVector ft; SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0)); - AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), - &sketch_container); + AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info, + std::numeric_limits::quiet_NaN(), &sketch_container); HistogramCuts cuts; sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit()); size_t bytes_required = detail::RequiredMemory( @@ -466,9 +466,8 @@ 
TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { common::HistogramCuts batched_cuts; HostDeviceVector ft; SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, DeviceOrd::CUDA(0)); - AdapterDeviceSketch(adapter.Value(), num_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_container); + AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info, + std::numeric_limits::quiet_NaN(), &sketch_container); HistogramCuts cuts; sketch_container.MakeCuts(&ctx, &cuts, info.IsColumnSplit()); @@ -502,7 +501,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories, ASSERT_EQ(info.feature_types.Size(), 1); SketchContainer container(info.feature_types, num_bins, 1, n, DeviceOrd::CUDA(0)); - AdapterDeviceSketch(adapter.Value(), num_bins, info, + AdapterDeviceSketch(&ctx, adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &container); HistogramCuts cuts; container.MakeCuts(&ctx, &cuts, info.IsColumnSplit()); @@ -616,22 +615,27 @@ void TestGetColumnSize(std::size_t n_samples) { std::vector h_column_size(column_sizes_scan.size()); std::vector h_column_size_1(column_sizes_scan.size()); + auto cuctx = ctx.CUDACtx(); detail::LaunchGetColumnSizeKernel( - ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, + dh::ToSpan(column_sizes_scan)); thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size.begin()); detail::LaunchGetColumnSizeKernel( - ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, + dh::ToSpan(column_sizes_scan)); thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin()); ASSERT_EQ(h_column_size, h_column_size_1); detail::LaunchGetColumnSizeKernel( - ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, + dh::ToSpan(column_sizes_scan)); thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin()); ASSERT_EQ(h_column_size, h_column_size_1); detail::LaunchGetColumnSizeKernel( - ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + cuctx, ctx.Device(), IterSpan{batch_iter, batch.Size()}, is_valid, + dh::ToSpan(column_sizes_scan)); thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin()); ASSERT_EQ(h_column_size, h_column_size_1); } @@ -737,7 +741,7 @@ void TestAdapterSketchFromWeights(bool with_group) { auto const& batch = adapter.Value(); HostDeviceVector ft; SketchContainer sketch_container(ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)); - AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(&ctx, adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), &sketch_container); common::HistogramCuts cuts; @@ -780,7 +784,7 @@ void TestAdapterSketchFromWeights(bool with_group) { h_weights[i] = (i % 2 == 0 ? 
1 : 2) / static_cast(kGroups); } SketchContainer sketch_container{ft, kBins, kCols, kRows, DeviceOrd::CUDA(0)}; - AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(&ctx, adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), &sketch_container); sketch_container.MakeCuts(&ctx, &weighted, info.IsColumnSplit()); ValidateCuts(weighted, dmat.get(), kBins); diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 80c9c5c71..7be12ac9c 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -24,14 +24,15 @@ namespace common { class MGPUQuantileTest : public collective::BaseMGPUTest {}; TEST(GPUQuantile, Basic) { + auto ctx = MakeCUDACtx(0); constexpr size_t kRows = 1000, kCols = 100, kBins = 256; HostDeviceVector ft; - SketchContainer sketch(ft, kBins, kCols, kRows, FstCU()); + SketchContainer sketch(ft, kBins, kCols, kRows, ctx.Device()); dh::caching_device_vector entries; dh::device_vector cuts_ptr(kCols+1); thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); // Push empty - sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0); + sketch.Push(&ctx, dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0); ASSERT_EQ(sketch.Data().size(), 0); } @@ -39,16 +40,17 @@ void TestSketchUnique(float sparsity) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { + auto ctx = MakeCUDACtx(0); HostDeviceVector ft; - SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU()); + SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device()); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity} .Seed(seed) - .Device(FstCU()) + .Device(ctx.Device()) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketch(adapter.Value(), n_bins, info, + AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch); auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows); @@ -60,8 +62,9 @@ void TestSketchUnique(float sparsity) { thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return batch.GetElement(idx); }); auto end = kCols * kRows; - detail::GetColumnSizesScan(FstCU(), kCols, n_cuts, IterSpan{batch_iter, end}, is_valid, - &cut_sizes_scan, &column_sizes_scan); + detail::GetColumnSizesScan(ctx.CUDACtx(), ctx.Device(), kCols, n_cuts, + IterSpan{batch_iter, end}, is_valid, &cut_sizes_scan, + &column_sizes_scan); auto const& cut_sizes = cut_sizes_scan.HostVector(); ASSERT_LE(sketch.Data().size(), cut_sizes.back()); @@ -69,7 +72,7 @@ void TestSketchUnique(float sparsity) { dh::CopyDeviceSpanToVector(&h_columns_ptr, sketch.ColumnsPtr()); ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back()); - sketch.Unique(); + sketch.Unique(&ctx); std::vector h_data(sketch.Data().size()); thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin()); @@ -124,44 +127,46 @@ void TestQuantileElemRank(DeviceOrd device, Span in, TEST(GPUQuantile, Prune) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { + auto ctx = MakeCUDACtx(0); HostDeviceVector ft; - SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU()); + SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device()); HostDeviceVector storage; - std::string interface_str 
=
-        RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
-            &storage);
+    std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
+                                    .Device(ctx.Device())
+                                    .Seed(seed)
+                                    .GenerateArrayInterface(&storage);
     data::CupyAdapter adapter(interface_str);
-    AdapterDeviceSketch(adapter.Value(), n_bins, info,
+    AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
                         std::numeric_limits<float>::quiet_NaN(), &sketch);
     auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
 
     // LE because kRows * kCols is pushed into sketch, after removing
     // duplicated entries we might not have that many inputs for prune.
     ASSERT_LE(sketch.Data().size(), n_cuts * kCols);
-    sketch.Prune(n_bins);
+    sketch.Prune(&ctx, n_bins);
     ASSERT_LE(sketch.Data().size(), kRows * kCols);
     // This is not necessarily true for all inputs without calling unique after
     // prune.
     ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
                                   sketch.Data().data() + sketch.Data().size(),
                                   detail::SketchUnique{}));
-    TestQuantileElemRank(FstCU(), sketch.Data(), sketch.ColumnsPtr());
+    TestQuantileElemRank(ctx.Device(), sketch.Data(), sketch.ColumnsPtr());
   });
 }
 
 TEST(GPUQuantile, MergeEmpty) {
   constexpr size_t kRows = 1000, kCols = 100;
   size_t n_bins = 10;
+  auto ctx = MakeCUDACtx(0);
   HostDeviceVector<FeatureType> ft;
-  SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU());
+  SketchContainer sketch_0(ft, n_bins, kCols, kRows, ctx.Device());
   HostDeviceVector<float> storage_0;
   std::string interface_str_0 =
-      RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).GenerateArrayInterface(
-          &storage_0);
+      RandomDataGenerator{kRows, kCols, 0}.Device(ctx.Device()).GenerateArrayInterface(&storage_0);
   data::CupyAdapter adapter_0(interface_str_0);
   MetaInfo info;
-  AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
+  AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
                       std::numeric_limits<float>::quiet_NaN(), &sketch_0);
 
   std::vector<SketchEntry> entries_before(sketch_0.Data().size());
@@ -170,7 +175,7 @@ TEST(GPUQuantile, MergeEmpty) {
   dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr());
   thrust::device_vector<size_t> columns_ptr(kCols + 1);
   // Merge an empty sketch
-  sketch_0.Merge(dh::ToSpan(columns_ptr), Span<SketchEntry>{});
+  sketch_0.Merge(&ctx, dh::ToSpan(columns_ptr), Span<SketchEntry>{});
 
   std::vector<SketchEntry> entries_after(sketch_0.Data().size());
   dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data());
@@ -193,34 +198,36 @@ TEST(GPUQuantile, MergeBasic) {
   constexpr size_t kRows = 1000, kCols = 100;
   RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
+    auto ctx = MakeCUDACtx(0);
     HostDeviceVector<FeatureType> ft;
-    SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU());
+    SketchContainer sketch_0(ft, n_bins, kCols, kRows, ctx.Device());
     HostDeviceVector<float> storage_0;
     std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}
-                                      .Device(FstCU())
+                                      .Device(ctx.Device())
                                       .Seed(seed)
                                       .GenerateArrayInterface(&storage_0);
     data::CupyAdapter adapter_0(interface_str_0);
-    AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
+    AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
                         std::numeric_limits<float>::quiet_NaN(), &sketch_0);
 
-    SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, FstCU());
+    SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, ctx.Device());
     HostDeviceVector<float> storage_1;
-    std::string interface_str_1 =
-        RandomDataGenerator{kRows, kCols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
-            &storage_1);
+    std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0}
+                                      .Device(ctx.Device())
+                                      .Seed(seed)
+                                      .GenerateArrayInterface(&storage_1);
     data::CupyAdapter adapter_1(interface_str_1);
-    AdapterDeviceSketch(adapter_1.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
-                        &sketch_1);
+    AdapterDeviceSketch(&ctx, adapter_1.Value(), n_bins, info,
+                        std::numeric_limits<float>::quiet_NaN(), &sketch_1);
 
     size_t size_before_merge = sketch_0.Data().size();
-    sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
+    sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
     if (info.weights_.Size() != 0) {
-      TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr(), true);
+      TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr(), true);
       sketch_0.FixError();
-      TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr(), false);
+      TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr(), false);
     } else {
-      TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr());
+      TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
     }
 
     auto columns_ptr = sketch_0.ColumnsPtr();
@@ -228,7 +235,7 @@ TEST(GPUQuantile, MergeBasic) {
     dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
     ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
 
-    sketch_0.Unique();
+    sketch_0.Unique(&ctx);
     ASSERT_TRUE(
         thrust::is_sorted(thrust::device, sketch_0.Data().data(),
                           sketch_0.Data().data() + sketch_0.Data().size(),
@@ -237,25 +244,27 @@ TEST(GPUQuantile, MergeBasic) {
 }
 
 void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
+  auto ctx = MakeCUDACtx(0);
   MetaInfo info;
   int32_t seed = 0;
   HostDeviceVector<FeatureType> ft;
-  SketchContainer sketch_0(ft, n_bins, cols, rows, FstCU());
+  SketchContainer sketch_0(ft, n_bins, cols, rows, ctx.Device());
   HostDeviceVector<float> storage_0;
-  std::string interface_str_0 =
-      RandomDataGenerator{rows, cols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
-          &storage_0);
+  std::string interface_str_0 = RandomDataGenerator{rows, cols, 0}
+                                    .Device(ctx.Device())
+                                    .Seed(seed)
+                                    .GenerateArrayInterface(&storage_0);
   data::CupyAdapter adapter_0(interface_str_0);
-  AdapterDeviceSketch(adapter_0.Value(), n_bins, info,
-                      std::numeric_limits<float>::quiet_NaN(),
-                      &sketch_0);
+  AdapterDeviceSketch(&ctx, adapter_0.Value(), n_bins, info,
+                      std::numeric_limits<float>::quiet_NaN(), &sketch_0);
 
   size_t f_rows = rows * frac;
-  SketchContainer sketch_1(ft, n_bins, cols, f_rows, FstCU());
+  SketchContainer sketch_1(ft, n_bins, cols, f_rows, ctx.Device());
   HostDeviceVector<float> storage_1;
-  std::string interface_str_1 =
-      RandomDataGenerator{f_rows, cols, 0}.Device(FstCU()).Seed(seed).GenerateArrayInterface(
-          &storage_1);
+  std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0}
+                                    .Device(ctx.Device())
+                                    .Seed(seed)
+                                    .GenerateArrayInterface(&storage_1);
   auto data_1 = storage_1.DeviceSpan();
   auto tuple_it = thrust::make_tuple(
       thrust::make_counting_iterator(0ul), data_1.data());
@@ -271,20 +280,19 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
     }
   });
   data::CupyAdapter adapter_1(interface_str_1);
-  AdapterDeviceSketch(adapter_1.Value(), n_bins, info,
-                      std::numeric_limits<float>::quiet_NaN(),
-                      &sketch_1);
+  AdapterDeviceSketch(&ctx, adapter_1.Value(), n_bins, info,
+                      std::numeric_limits<float>::quiet_NaN(), &sketch_1);
 
   size_t size_before_merge = sketch_0.Data().size();
-  sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
-  TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr());
+  sketch_0.Merge(&ctx, sketch_1.ColumnsPtr(), sketch_1.Data());
+  TestQuantileElemRank(ctx.Device(), sketch_0.Data(), sketch_0.ColumnsPtr());
 
   auto columns_ptr = sketch_0.ColumnsPtr();
   std::vector<bst_idx_t> h_columns_ptr(columns_ptr.size());
   dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
   ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
 
-  sketch_0.Unique();
+  sketch_0.Unique(&ctx);
 
   columns_ptr = sketch_0.ColumnsPtr();
   dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
@@ -311,7 +319,8 @@ TEST(GPUQuantile, MultiMerge) {
   RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) {
     // Set up single node version
    HostDeviceVector<FeatureType> ft;
-    SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, FstCU());
+    auto ctx = MakeCUDACtx(0);
+    SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, ctx.Device());
     size_t intermediate_num_cuts = std::min(
         kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
 
     for (auto rank = 0; rank < world; ++rank) {
       HostDeviceVector<float> storage;
       std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
-                                      .Device(FstCU())
+                                      .Device(ctx.Device())
                                       .Seed(rank + seed)
                                       .GenerateArrayInterface(&storage);
       data::CupyAdapter adapter(interface_str);
       HostDeviceVector<FeatureType> ft;
-      containers.emplace_back(ft, n_bins, kCols, kRows, FstCU());
-      AdapterDeviceSketch(adapter.Value(), n_bins, info,
-                          std::numeric_limits<float>::quiet_NaN(),
-                          &containers.back());
+      containers.emplace_back(ft, n_bins, kCols, kRows, ctx.Device());
+      AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
+                          std::numeric_limits<float>::quiet_NaN(), &containers.back());
     }
     for (auto &sketch : containers) {
-      sketch.Prune(intermediate_num_cuts);
-      sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
+      sketch.Prune(&ctx, intermediate_num_cuts);
+      sketch_on_single_node.Merge(&ctx, sketch.ColumnsPtr(), sketch.Data());
      sketch_on_single_node.FixError();
    }
-    TestQuantileElemRank(FstCU(), sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr());
+    TestQuantileElemRank(ctx.Device(), sketch_on_single_node.Data(),
+                         sketch_on_single_node.ColumnsPtr());
 
-    sketch_on_single_node.Unique();
-    TestQuantileElemRank(FstCU(), sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr());
+    sketch_on_single_node.Unique(&ctx);
+    TestQuantileElemRank(ctx.Device(), sketch_on_single_node.Data(),
+                         sketch_on_single_node.ColumnsPtr());
   });
 }
 
@@ -392,15 +402,15 @@ void TestAllReduceBasic() {
       data::CupyAdapter adapter(interface_str);
       HostDeviceVector<FeatureType> ft({}, device);
      containers.emplace_back(ft, n_bins, kCols, kRows, device);
-      AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(),
-                          &containers.back());
+      AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
+                          std::numeric_limits<float>::quiet_NaN(), &containers.back());
     }
     for (auto& sketch : containers) {
-      sketch.Prune(intermediate_num_cuts);
-      sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
+      sketch.Prune(&ctx, intermediate_num_cuts);
+      sketch_on_single_node.Merge(&ctx, sketch.ColumnsPtr(), sketch.Data());
       sketch_on_single_node.FixError();
     }
-    sketch_on_single_node.Unique();
+    sketch_on_single_node.Unique(&ctx);
 
     TestQuantileElemRank(device, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr(),
                          true);
@@ -416,16 +426,16 @@ void TestAllReduceBasic() {
                                        .Seed(rank + seed)
                                        .GenerateArrayInterface(&storage);
       data::CupyAdapter adapter(interface_str);
-      AdapterDeviceSketch(adapter.Value(), n_bins, info,
-                          std::numeric_limits<float>::quiet_NaN(),
-                          &sketch_distributed);
+      AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
+                          std::numeric_limits<float>::quiet_NaN(), &sketch_distributed);
 
       if (world == 1) {
        auto n_samples_global = kRows * world;
        intermediate_num_cuts =
            std::min(n_samples_global, static_cast<size_t>(n_bins * SketchContainer::kFactor));
-        sketch_distributed.Prune(intermediate_num_cuts);
+        sketch_distributed.Prune(&ctx, intermediate_num_cuts);
      }
      sketch_distributed.AllReduce(&ctx, false);
-      sketch_distributed.Unique();
+      sketch_distributed.Unique(&ctx);
 
      ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size());
      ASSERT_EQ(sketch_distributed.Data().size(), sketch_on_single_node.Data().size());
@@ -535,11 +545,10 @@ void TestSameOnAllWorkers() {
                                      .Seed(rank + seed)
                                      .GenerateArrayInterface(&storage);
    data::CupyAdapter adapter(interface_str);
-    AdapterDeviceSketch(adapter.Value(), n_bins, info,
-                        std::numeric_limits<float>::quiet_NaN(),
-                        &sketch_distributed);
+    AdapterDeviceSketch(&ctx, adapter.Value(), n_bins, info,
+                        std::numeric_limits<float>::quiet_NaN(), &sketch_distributed);
    sketch_distributed.AllReduce(&ctx, false);
-    sketch_distributed.Unique();
+    sketch_distributed.Unique(&ctx);
 
    TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
 
    // Test for all workers having the same sketch.
    auto n_data = sketch_distributed.Data().size();
    auto rc = collective::Allreduce(&ctx, linalg::MakeVec(&n_data, 1), collective::Op::kMax);
    SafeColl(rc);
    ASSERT_EQ(n_data, sketch_distributed.Data().size());
-    size_t size_as_float =
-        sketch_distributed.Data().size_bytes() / sizeof(float);
+    size_t size_as_float = sketch_distributed.Data().size_bytes() / sizeof(float);
    auto local_data = Span<float const>{
-        reinterpret_cast<float const*>(sketch_distributed.Data().data()),
-        size_as_float};
+        reinterpret_cast<float const*>(sketch_distributed.Data().data()), size_as_float};
 
    dh::caching_device_vector<float> all_workers(size_as_float * world);
    thrust::fill(all_workers.begin(), all_workers.end(), 0);
-    thrust::copy(thrust::device, local_data.data(),
-                 local_data.data() + local_data.size(),
+    thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(),
                 all_workers.begin() + local_data.size() * rank);
    rc = collective::Allreduce(
        &ctx, linalg::MakeVec(all_workers.data().get(), all_workers.size(), ctx.Device()),
@@ -590,6 +596,7 @@ TEST_F(MGPUQuantileTest, SameOnAllWorkers) {
 TEST(GPUQuantile, Push) {
   size_t constexpr kRows = 100;
   std::vector<float> data(kRows);
+  auto ctx = MakeCUDACtx(0);
   std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
   std::fill(data.begin() + (data.size() / 2), data.end(), 0.5f);
@@ -608,8 +615,8 @@ TEST(GPUQuantile, Push) {
   columns_ptr[1] = kRows;
 
   HostDeviceVector<FeatureType> ft;
-  SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
-  sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
+  SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
+  sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {});
 
   auto sketch_data = sketch.Data();
@@ -633,9 +640,9 @@ TEST(GPUQuantile, Push) {
 TEST(GPUQuantile, MultiColPush) {
   size_t constexpr kRows = 100, kCols = 4;
   std::vector<float> data(kRows * kCols);
-
   std::fill(data.begin(), data.begin() + (data.size() / 2), 0.3f);
+  auto ctx = MakeCUDACtx(0);
 
   std::vector<Entry> entries(kRows * kCols);
 
   for (bst_feature_t c = 0; c < kCols; ++c) {
@@ -648,7 +655,7 @@ TEST(GPUQuantile, MultiColPush) {
   int32_t n_bins = 16;
   HostDeviceVector<FeatureType> ft;
-  SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU());
+  SketchContainer sketch(ft, n_bins, kCols, kRows, ctx.Device());
 
   dh::device_vector<Entry> d_entries {entries};
 
   dh::device_vector<size_t> columns_ptr(kCols + 1, 0);
@@ -659,8 +666,8 @@ TEST(GPUQuantile, MultiColPush) {
                          columns_ptr.begin());
   dh::device_vector<size_t> cuts_ptr(columns_ptr);
 
-  sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr),
-              dh::ToSpan(cuts_ptr), kRows * kCols, {});
+  sketch.Push(&ctx, dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr),
+              kRows * kCols, {});
 
   auto sketch_data = sketch.Data();
   ASSERT_EQ(sketch_data.size(), kCols * 2);
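
Note on the pattern exercised by the updated tests: each test now creates a single CUDA context with MakeCUDACtx(0) and passes &ctx through AdapterDeviceSketch, Push, Prune, Merge, and Unique, so all device work for a sketch is issued on the stream owned by that context instead of on the legacy default stream. The snippet below is a minimal, self-contained sketch of that idea and is not part of this patch: SortOnStream and the buffer size are invented for illustration, and only the CUDA runtime and Thrust calls are real API. XGBoost's context-bound execution policies additionally bundle a caching allocator so that Thrust's temporary-storage allocations do not block either; that part is omitted here.

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>

#include <cuda_runtime_api.h>

// Hypothetical helper mirroring the context-threading pattern: all work is
// enqueued on the caller's stream and no synchronization happens here.
void SortOnStream(thrust::device_vector<float>* data, cudaStream_t stream) {
  // Default-stream version: blocks until the sort completes and serializes
  // against every other user of the default stream.
  //   thrust::sort(thrust::device, data->begin(), data->end());

  // Stream-ordered version: par_nosync (Thrust >= 1.16) also skips the
  // implicit end-of-algorithm synchronization, leaving the caller in charge
  // of deciding when to wait.
  thrust::sort(thrust::cuda::par_nosync.on(stream), data->begin(), data->end());
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  {
    thrust::device_vector<float> values(1 << 20, 1.0f);
    SortOnStream(&values, stream);
    cudaStreamSynchronize(stream);  // single explicit synchronization point
  }
  cudaStreamDestroy(stream);
  return 0;
}

Threading one context through the whole sketching path is what lets the quantile code queue many Thrust calls back to back without a host round trip between them, which is the stream-synchronization cost this patch removes.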