diff --git a/src/collective/loop.cc b/src/collective/loop.cc
index 10fce0516..5cfb0034d 100644
--- a/src/collective/loop.cc
+++ b/src/collective/loop.cc
@@ -117,11 +117,14 @@ void Loop::Process() {
       break;
     }
 
-    auto unlock_notify = [&](bool is_blocking) {
+    auto unlock_notify = [&](bool is_blocking, bool stop) {
       if (!is_blocking) {
-        return;
+        std::lock_guard guard{mu_};
+        stop_ = stop;
+      } else {
+        stop_ = stop;
+        lock.unlock();
       }
-      lock.unlock();
       cv_.notify_one();
     };
 
@@ -145,13 +148,14 @@ void Loop::Process() {
       auto rc = this->EmptyQueue(&qcopy);
 
       // Handle error
       if (!rc.OK()) {
+        unlock_notify(is_blocking, true);
+        std::lock_guard guard{rc_lock_};
         this->rc_ = std::move(rc);
-        unlock_notify(is_blocking);
         return;
       }
 
       CHECK(qcopy.empty());
-      unlock_notify(is_blocking);
+      unlock_notify(is_blocking, false);
     }
   }
@@ -170,12 +174,21 @@ Result Loop::Stop() {
 }
 
 [[nodiscard]] Result Loop::Block() {
+  {
+    std::lock_guard guard{rc_lock_};
+    if (!rc_.OK()) {
+      return std::move(rc_);
+    }
+  }
   this->Submit(Op{Op::kBlock});
   {
     std::unique_lock lock{mu_};
     cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
   }
-  return std::move(rc_);
+  {
+    std::lock_guard lock{rc_lock_};
+    return std::move(rc_);
+  }
 }
 
 Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
diff --git a/src/collective/loop.h b/src/collective/loop.h
index 4f5cb12b3..0c1fdcbfe 100644
--- a/src/collective/loop.h
+++ b/src/collective/loop.h
@@ -42,7 +42,10 @@ class Loop {
   std::mutex mu_;
   std::queue<Op> queue_;
   std::chrono::seconds timeout_;
+
   Result rc_;
+  std::mutex rc_lock_;  // lock for transferring error info.
+  bool stop_{false};
   std::exception_ptr curr_exce_{nullptr};
   common::Monitor mutable timer_;
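Note on the event-loop change above: the error object `rc_` now lives under its own `rc_lock_`, and the wake-up is paired with an explicit `stop_` flag. A minimal standalone sketch of that two-mutex hand-off (hypothetical `MiniLoop`/`Fail`/`Block` names, not XGBoost's API; the real `Loop` also drains an op queue and handles timeouts):

#include <condition_variable>
#include <mutex>
#include <string>
#include <utility>

class MiniLoop {
  std::mutex mu_;
  std::condition_variable cv_;
  bool stop_{false};

  std::mutex rc_lock_;  // guards rc_ alone, mirroring rc_lock_ in loop.h
  std::string rc_;      // stands in for the collective::Result error state

 public:
  // Worker side: record the error first, then publish stop_ under mu_ and
  // notify.  Setting stop_ under the mutex before notify_one() means a
  // waiter cannot test the predicate between the store and the wake-up.
  void Fail(std::string msg) {
    {
      std::lock_guard guard{rc_lock_};
      rc_ = std::move(msg);
    }
    {
      std::lock_guard guard{mu_};
      stop_ = true;
    }
    cv_.notify_one();
  }

  // Caller side: wait for the stop signal, then move the error out under
  // its own lock, the same shape as the reworked Loop::Block().
  std::string Block() {
    {
      std::unique_lock lock{mu_};
      cv_.wait(lock, [this] { return stop_; });
    }
    std::lock_guard guard{rc_lock_};
    return std::move(rc_);
  }
};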
diff --git a/src/common/algorithm.cuh b/src/common/algorithm.cuh
index 53acc65e1..5f0986d5b 100644
--- a/src/common/algorithm.cuh
+++ b/src/common/algorithm.cuh
@@ -23,8 +23,7 @@
 #include "xgboost/logging.h"  // CHECK
 #include "xgboost/span.h"     // Span,byte
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 namespace detail {
 // Wrapper around cub sort to define is_decending
 template <bool IS_DESCENDING, typename KeyT, typename ValueT>
@@ -127,13 +126,14 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
 template <bool accending, bool per_seg_index, typename U, typename V, typename IdxT>
 void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
                       Span<IdxT> sorted_idx) {
+  auto cuctx = ctx->CUDACtx();
   CHECK_GE(group_ptr.size(), 1ul);
   std::size_t n_groups = group_ptr.size() - 1;
   std::size_t bytes = 0;
   if (per_seg_index) {
     SegmentedSequence(ctx, group_ptr, sorted_idx);
   } else {
-    dh::Iota(sorted_idx);
+    dh::Iota(sorted_idx, cuctx->Stream());
   }
   dh::TemporaryArray<std::remove_const_t<U>> values_out(values.size());
   dh::TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
@@ -141,15 +141,16 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
   detail::DeviceSegmentedRadixSortPair<!accending>(
       nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
       sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
-      group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+      group_ptr.data() + 1, cuctx->Stream());
   dh::TemporaryArray<char> temp_storage(bytes);
   detail::DeviceSegmentedRadixSortPair<!accending>(
       temp_storage.data().get(), bytes, values.data(), values_out.data().get(), sorted_idx.data(),
       sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
-      group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+      group_ptr.data() + 1, cuctx->Stream());
 
   dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
-                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
+                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
+                                cuctx->Stream()));
 }
 
 /**
@@ -159,11 +160,12 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
 template <typename SegIt, typename ValIt>
 void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
                            ValIt val_end, dh::device_vector<std::size_t> *p_sorted_idx) {
+  auto cuctx = ctx->CUDACtx();
   using Tup = thrust::tuple<std::int32_t, float>;
   auto &sorted_idx = *p_sorted_idx;
   std::size_t n = std::distance(val_begin, val_end);
   sorted_idx.resize(n);
-  dh::Iota(dh::ToSpan(sorted_idx));
+  dh::Iota(dh::ToSpan(sorted_idx), cuctx->Stream());
   dh::device_vector<Tup> keys(sorted_idx.size());
   auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
                                                [=] XGBOOST_DEVICE(std::size_t i) -> Tup {
@@ -177,7 +179,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
       return thrust::make_tuple(seg_idx, residue);
     });
   thrust::copy(ctx->CUDACtx()->CTP(), key_it, key_it + keys.size(), keys.begin());
-  thrust::stable_sort_by_key(ctx->CUDACtx()->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
+  thrust::stable_sort_by_key(cuctx->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
                              [=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
                                if (thrust::get<0>(l) != thrust::get<0>(r)) {
                                  return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
@@ -185,6 +187,75 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
                                return thrust::get<1>(l) < thrust::get<1>(r);  // residue
                              });
 }
-}  // namespace common
-}  // namespace xgboost
+
+template <bool accending = true, typename IdxT = std::size_t, typename U>
+void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
+             xgboost::common::Span<IdxT> sorted_idx) {
+  std::size_t bytes = 0;
+  auto cuctx = ctx->CUDACtx();
+  dh::Iota(sorted_idx, cuctx->Stream());
+
+  using KeyT = typename decltype(keys)::value_type;
+  using ValueT = std::remove_const_t<IdxT>;
+
+  dh::TemporaryArray<KeyT> out(keys.size());
+  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()), out.data().get());
+  dh::TemporaryArray<ValueT> sorted_idx_out(sorted_idx.size());
+  cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
+                                     sorted_idx_out.data().get());
+
+  // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
+  using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
+  CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
+  if (accending) {
+    void *d_temp_storage = nullptr;
+#if THRUST_MAJOR_VERSION >= 2
+    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        cuctx->Stream())));
+#else
+    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        nullptr, false)));
+#endif
+    dh::TemporaryArray<char> storage(bytes);
+    d_temp_storage = storage.data().get();
+#if THRUST_MAJOR_VERSION >= 2
+    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        cuctx->Stream())));
+#else
+    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        nullptr, false)));
+#endif
+  } else {
+    void *d_temp_storage = nullptr;
+#if THRUST_MAJOR_VERSION >= 2
+    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        cuctx->Stream())));
+#else
+    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        nullptr, false)));
+#endif
+    dh::TemporaryArray<char> storage(bytes);
+    d_temp_storage = storage.data().get();
+#if THRUST_MAJOR_VERSION >= 2
+    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        cuctx->Stream())));
+#else
+    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
+        nullptr, false)));
+#endif
+  }
+
+  dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
+                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
+                                cuctx->Stream()));
+}
+}  // namespace xgboost::common
 
 #endif  // XGBOOST_COMMON_ALGORITHM_CUH_
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 74336ac61..066f8a3e6 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -313,8 +313,8 @@ inline void LaunchN(size_t n, L lambda) {
 }
 
 template <typename Container>
-void Iota(Container array) {
-  LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
+void Iota(Container array, cudaStream_t stream) {
+  LaunchN(array.size(), stream, [=] __device__(size_t i) { array[i] = i; });
 }
 
 namespace detail {
@@ -597,6 +597,16 @@ class DoubleBuffer {
   T *Other() { return buff.Alternate(); }
 };
 
+template <typename T>
+xgboost::common::Span<T> LazyResize(xgboost::Context const *ctx,
+                                    xgboost::HostDeviceVector<T> *buffer, std::size_t n) {
+  buffer->SetDevice(ctx->Device());
+  if (buffer->Size() < n) {
+    buffer->Resize(n);
+  }
+  return buffer->DeviceSpan().subspan(0, n);
+}
+
 /**
  * \brief Copies device span to std::vector.
  *
@@ -1060,74 +1070,6 @@ void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
   InclusiveScan(d_in, d_out, cub::Sum(), num_items);
 }
 
-template <bool accending = true, typename IdxT = size_t, typename U>
-void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
-  size_t bytes = 0;
-  Iota(sorted_idx);
-
-  using KeyT = typename decltype(keys)::value_type;
-  using ValueT = std::remove_const_t<IdxT>;
-
-  TemporaryArray<KeyT> out(keys.size());
-  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
-                                 out.data().get());
-  TemporaryArray<ValueT> sorted_idx_out(sorted_idx.size());
-  cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
-                                     sorted_idx_out.data().get());
-
-  // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
-  using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
-  CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
-  if (accending) {
-    void *d_temp_storage = nullptr;
-#if THRUST_MAJOR_VERSION >= 2
-    safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr)));
-#else
-    safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr, false)));
-#endif
-    TemporaryArray<char> storage(bytes);
-    d_temp_storage = storage.data().get();
-#if THRUST_MAJOR_VERSION >= 2
-    safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr)));
-#else
-    safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr, false)));
-#endif
-  } else {
-    void *d_temp_storage = nullptr;
-#if THRUST_MAJOR_VERSION >= 2
-    safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr)));
-#else
-    safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr, false)));
-#endif
-    TemporaryArray<char> storage(bytes);
-    d_temp_storage = storage.data().get();
-#if THRUST_MAJOR_VERSION >= 2
-    safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr)));
-#else
-    safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
-        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
-        sizeof(KeyT) * 8, false, nullptr, false)));
-#endif
-  }
-
-  safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
-                            sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
-}
-
 class CUDAStreamView;
 
 class CUDAEvent {
diff --git a/src/common/random.cc b/src/common/random.cc
index d0e75729d..e0d1a2255 100644
--- a/src/common/random.cc
+++ b/src/common/random.cc
@@ -1,32 +1,50 @@
-/*!
- * Copyright 2020 by XGBoost Contributors
- * \file random.cc
+/**
+ * Copyright 2020-2023, XGBoost Contributors
  */
 #include "random.h"
 
-namespace xgboost {
-namespace common {
+#include <algorithm>  // for sort, max, copy
+#include <memory>     // for shared_ptr
+
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+
+namespace xgboost::common {
 std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
     std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
   if (colsample == 1.0f) {
     return p_features;
   }
+
+  int n = std::max(1, static_cast<int>(colsample * p_features->Size()));
+  auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
+
+  if (ctx_->IsCUDA()) {
+#if defined(XGBOOST_USE_CUDA)
+    cuda_impl::SampleFeature(ctx_, n, p_features, p_new_features, this->feature_weights_,
+                             &this->weight_buffer_, &this->idx_buffer_, &rng_);
+    return p_new_features;
+#else
+    AssertGPUSupport();
+    return nullptr;
+#endif  // defined(XGBOOST_USE_CUDA)
+  }
+
   const auto &features = p_features->HostVector();
   CHECK_GT(features.size(), 0);
 
-  int n = std::max(1, static_cast<int>(colsample * features.size()));
-  auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
   auto &new_features = *p_new_features;
 
-  if (feature_weights_.size() != 0) {
+  if (!feature_weights_.Empty()) {
     auto const &h_features = p_features->HostVector();
-    std::vector<float> weights(h_features.size());
+    auto const &h_feature_weight = feature_weights_.ConstHostVector();
+    auto &weight = this->weight_buffer_.HostVector();
+    weight.resize(h_features.size());
     for (size_t i = 0; i < h_features.size(); ++i) {
-      weights[i] = feature_weights_[h_features[i]];
+      weight[i] = h_feature_weight[h_features[i]];
    }
    CHECK(ctx_);
    new_features.HostVector() =
-        WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
+        WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weight, n);
   } else {
     new_features.Resize(features.size());
     std::copy(features.begin(), features.end(), new_features.HostVector().begin());
@@ -36,5 +54,4 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
   std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
   return p_new_features;
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
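Both the CPU sampler above and the GPU sampler in the next file build on the same key construction for weighted sampling without replacement: draw u ~ U(0,1) per item and rank by log(u)/w, the Efraimidis-Spirakis trick. A host-only sketch of the idea — `WeightedSampleIdx` is a made-up name, and the epsilon guard mirrors the `kRtEps` guard in random.cu below:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Rank items by log(u) / w and keep the n largest keys.  log(u) / w is a
// monotone transform of u^(1/w), so this is weighted sampling without
// replacement in a single pass.
std::vector<std::size_t> WeightedSampleIdx(std::vector<float> const &weights, std::size_t n,
                                           std::uint32_t seed) {
  std::mt19937 rng{seed};
  std::uniform_real_distribution<float> dist;  // [0, 1)
  std::vector<float> keys(weights.size());
  for (std::size_t i = 0; i < keys.size(); ++i) {
    auto w = std::max(weights[i], 1e-6f);  // zero-weight guard, like kRtEps in random.cu
    keys[i] = std::log(dist(rng)) / w;     // larger weight -> key closer to 0 -> kept
  }
  std::vector<std::size_t> idx(keys.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Descending by key, matching the descending ArgSort on the device side.
  std::sort(idx.begin(), idx.end(),
            [&](std::size_t l, std::size_t r) { return keys[l] > keys[r]; });
  idx.resize(std::min(n, idx.size()));
  return idx;
}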
diff --git a/src/common/random.cu b/src/common/random.cu
new file mode 100644
index 000000000..f5811d924
--- /dev/null
+++ b/src/common/random.cu
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#include <thrust/shuffle.h>  // for shuffle
+
+#include <memory>  // for shared_ptr
+
+#include "algorithm.cuh"     // for ArgSort
+#include "cuda_context.cuh"  // for CUDAContext
+#include "device_helpers.cuh"
+#include "random.h"
+#include "xgboost/base.h"                // for bst_feature_t
+#include "xgboost/context.h"             // for Context
+#include "xgboost/host_device_vector.h"  // for HostDeviceVector
+
+namespace xgboost::common::cuda_impl {
+// GPU implementation for sampling without replacement, see the CPU version for references.
+void WeightedSamplingWithoutReplacement(Context const *ctx, common::Span<bst_feature_t const> array,
+                                        common::Span<float const> weights,
+                                        common::Span<bst_feature_t> results,
+                                        HostDeviceVector<std::size_t> *sorted_idx,
+                                        GlobalRandomEngine *grng) {
+  CUDAContext const *cuctx = ctx->CUDACtx();
+  CHECK_EQ(array.size(), weights.size());
+  // Sampling keys
+  dh::caching_device_vector<float> keys(weights.size());
+
+  auto d_keys = dh::ToSpan(keys);
+
+  auto seed = (*grng)();
+  constexpr auto kEps = kRtEps;  // avoid CUDA compilation error
+  thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), array.size(),
+                     [=] XGBOOST_DEVICE(std::size_t i) {
+                       thrust::default_random_engine rng;
+                       rng.seed(seed);
+                       rng.discard(i);
+                       thrust::uniform_real_distribution<float> dist;
+
+                       auto w = std::max(weights[i], kEps);
+                       auto u = dist(rng);
+                       auto k = std::log(u) / w;
+
+                       d_keys[i] = k;
+                     });
+  // Allocate buffer for sorted index.
+  auto d_idx = dh::LazyResize(ctx, sorted_idx, keys.size());
+
+  ArgSort<false>(ctx, d_keys, d_idx);
+
+  // Filter the result according to sorted index.
+  auto it = thrust::make_permutation_iterator(dh::tbegin(array), dh::tbegin(d_idx));
+  // |array| == |weights| == |keys| == |sorted_idx| >= |results|
+  for (auto size : {array.size(), weights.size(), keys.size()}) {
+    CHECK_EQ(size, d_idx.size());
+  }
+  CHECK_GE(array.size(), results.size());
+  thrust::copy_n(cuctx->CTP(), it, results.size(), dh::tbegin(results));
+}
+
+void SampleFeature(Context const *ctx, bst_feature_t n_features,
+                   std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
+                   std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
+                   HostDeviceVector<float> const &feature_weights,
+                   HostDeviceVector<float> *weight_buffer,
+                   HostDeviceVector<std::size_t> *idx_buffer, GlobalRandomEngine *grng) {
+  CUDAContext const *cuctx = ctx->CUDACtx();
+  auto &new_features = *p_new_features;
+  new_features.SetDevice(ctx->Device());
+  p_features->SetDevice(ctx->Device());
+  CHECK_LE(n_features, p_features->Size());
+
+  if (!feature_weights.Empty()) {
+    CHECK_LE(p_features->Size(), feature_weights.Size());
+    idx_buffer->SetDevice(ctx->Device());
+    feature_weights.SetDevice(ctx->Device());
+
+    auto d_old_features = p_features->DeviceSpan();
+    auto d_weight_buffer = dh::LazyResize(ctx, weight_buffer, d_old_features.size());
+    // Filter weights according to the existing feature index.
+    auto d_feature_weight = feature_weights.ConstDeviceSpan();
+    auto it = thrust::make_permutation_iterator(dh::tcbegin(d_feature_weight),
+                                                dh::tcbegin(d_old_features));
+    thrust::copy_n(cuctx->CTP(), it, d_old_features.size(), dh::tbegin(d_weight_buffer));
+    new_features.Resize(n_features);
+    WeightedSamplingWithoutReplacement(ctx, d_old_features, d_weight_buffer,
+                                       new_features.DeviceSpan(), idx_buffer, grng);
+  } else {
+    new_features.Resize(p_features->Size());
+    new_features.Copy(*p_features);
+
+    auto d_feat = new_features.DeviceSpan();
+    thrust::default_random_engine rng;
+    rng.seed((*grng)());
+    thrust::shuffle(cuctx->CTP(), dh::tbegin(d_feat), dh::tend(d_feat), rng);
+    new_features.Resize(n_features);
+  }
+
+  auto d_new_features = new_features.DeviceSpan();
+  thrust::sort(cuctx->CTP(), dh::tbegin(d_new_features), dh::tend(d_new_features));
+}
+
+void InitFeatureSet(Context const *ctx,
+                    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features) {
+  CUDAContext const *cuctx = ctx->CUDACtx();
+  auto d_features = p_features->DeviceSpan();
+  thrust::sequence(cuctx->CTP(), dh::tbegin(d_features), dh::tend(d_features), 0);
+}
+}  // namespace xgboost::common::cuda_impl
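The unweighted branch of `SampleFeature` above is just shuffle-truncate-sort. The same logic on the host, for reference (sketch only; on device the patch uses `thrust::shuffle` and `thrust::sort` over the feature span instead):

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

std::vector<int> SampleUnweighted(std::vector<int> features, std::size_t n, std::uint32_t seed) {
  std::mt19937 rng{seed};
  std::shuffle(features.begin(), features.end(), rng);  // uniform random permutation
  features.resize(std::min(n, features.size()));        // keep the first n features
  std::sort(features.begin(), features.end());          // callers expect a sorted feature set
  return features;
}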
diff --git a/src/common/random.h b/src/common/random.h
index 5efdb486d..2a94123a3 100644
--- a/src/common/random.h
+++ b/src/common/random.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2015-2020 by Contributors
+/**
+ * Copyright 2015-2020, XGBoost Contributors
  * \file random.h
  * \brief Utility related to random.
  * \author Tianqi Chen
@@ -25,8 +25,7 @@
 #include "xgboost/context.h"  // Context
 #include "xgboost/host_device_vector.h"
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 /*!
  * \brief Define mt19937 as default type Random Engine.
 */
@@ -113,6 +112,18 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vector<T> const& array,
   return results;
 }
 
+namespace cuda_impl {
+void SampleFeature(Context const* ctx, bst_feature_t n_features,
+                   std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
+                   std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
+                   HostDeviceVector<float> const& feature_weights,
+                   HostDeviceVector<float>* weight_buffer,
+                   HostDeviceVector<std::size_t>* idx_buffer, GlobalRandomEngine* grng);
+
+void InitFeatureSet(Context const* ctx,
+                    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features);
+}  // namespace cuda_impl
+
 /**
  * \class ColumnSampler
 *
@@ -123,46 +134,37 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vector<T> const& array,
 class ColumnSampler {
   std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
   std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
-  std::vector<float> feature_weights_;
+  HostDeviceVector<float> feature_weights_;
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
   float colsample_bynode_{1.0f};
   GlobalRandomEngine rng_;
   Context const* ctx_;
 
+  // Used for weighted sampling.
+  HostDeviceVector<std::size_t> idx_buffer_;
+  HostDeviceVector<float> weight_buffer_;
+
  public:
   std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
       std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
 
   /**
-   * \brief Column sampler constructor.
-   * \note This constructor manually sets the rng seed
+   * @brief Column sampler constructor.
+   * @note This constructor manually sets the rng seed
    */
-  explicit ColumnSampler(uint32_t seed) {
-    rng_.seed(seed);
-  }
+  explicit ColumnSampler(std::uint32_t seed) { rng_.seed(seed); }
 
   /**
-   * \brief Column sampler constructor.
-   * \note This constructor synchronizes the RNG seed across processes.
-   */
-  ColumnSampler() {
-    uint32_t seed = common::GlobalRandom()();
-    collective::Broadcast(&seed, sizeof(seed), 0);
-    rng_.seed(seed);
-  }
-
-  /**
-   * \brief Initialise this object before use.
+   * @brief Initialise this object before use.
    *
-   * \param num_col
-   * \param colsample_bynode
-   * \param colsample_bylevel
-   * \param colsample_bytree
-   * \param skip_index_0 (Optional) True to skip index 0.
+   * @param num_col
+   * @param colsample_bynode  Sampling rate for node.
+   * @param colsample_bylevel Sampling rate for tree level.
+   * @param colsample_bytree  Sampling rate for tree.
    */
   void Init(Context const* ctx, int64_t num_col, std::vector<float> feature_weights,
             float colsample_bynode, float colsample_bylevel, float colsample_bytree) {
-    feature_weights_ = std::move(feature_weights);
+    feature_weights_.HostVector() = std::move(feature_weights);
     colsample_bylevel_ = colsample_bylevel;
     colsample_bytree_ = colsample_bytree;
     colsample_bynode_ = colsample_bynode;
@@ -173,8 +175,17 @@ class ColumnSampler {
     }
 
     Reset();
+    feature_set_tree_->SetDevice(ctx->Device());
     feature_set_tree_->Resize(num_col);
-    std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
+    if (ctx->IsCPU()) {
+      std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
+    } else {
+#if defined(XGBOOST_USE_CUDA)
+      cuda_impl::InitFeatureSet(ctx, feature_set_tree_);
+#else
+      AssertGPUSupport();
+#endif
+    }
 
     feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
   }
@@ -216,6 +227,11 @@ class ColumnSampler {
   }
 };
 
-}  // namespace common
-}  // namespace xgboost
+inline auto MakeColumnSampler(Context const*) {
+  std::uint32_t seed = common::GlobalRandom()();
+  collective::Broadcast(&seed, sizeof(seed), 0);
+  auto cs = std::make_shared<ColumnSampler>(seed);
+  return cs;
+}
+}  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_RANDOM_H_
diff --git a/src/metric/auc.cc b/src/metric/auc.cc
index 2e5c88174..4a8aa8a4b 100644
--- a/src/metric/auc.cc
+++ b/src/metric/auc.cc
@@ -360,7 +360,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
                      common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
       std::tie(fp, tp, auc) =
-          GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
+          GPUBinaryROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_tuple(fp, tp, auc);
   }
@@ -376,8 +376,9 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
     .set_body([](const char*) { return new EvalROCAUC(); });
 
 #if !defined(XGBOOST_USE_CUDA)
-std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const>, MetaInfo const &,
-                                                   DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
+std::tuple<double, double, double> GPUBinaryROCAUC(Context const *, common::Span<float const>,
+                                                   MetaInfo const &,
+                                                   std::shared_ptr<DeviceAUCCache> *) {
   common::AssertGPUSupport();
   return {};
 }
@@ -409,8 +410,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
       BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
                   common::OptionalWeights{info.weights_.ConstHostSpan()});
     } else {
-      std::tie(pr, re, auc) =
-          GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
+      std::tie(pr, re, auc) = GPUBinaryPRAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
     }
     return std::make_tuple(pr, re, auc);
   }
@@ -453,8 +453,9 @@ XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
     .set_body([](char const *) { return new EvalPRAUC{}; });
 
 #if !defined(XGBOOST_USE_CUDA)
-std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const>, MetaInfo const &,
-                                                  DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
+std::tuple<double, double, double> GPUBinaryPRAUC(Context const *, common::Span<float const>,
+                                                  MetaInfo const &,
+                                                  std::shared_ptr<DeviceAUCCache> *) {
   common::AssertGPUSupport();
   return {};
 }
diff --git a/src/metric/auc.cu b/src/metric/auc.cu
index 8b8349e1b..4ce10d094 100644
--- a/src/metric/auc.cu
+++ b/src/metric/auc.cu
@@ -83,13 +83,14 @@ void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCache> *cache) {
 }
 
 template <typename Fn>
-std::tuple<double, double, double>
-GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
-             DeviceOrd device, common::Span<size_t const> d_sorted_idx,
-             Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
-  auto labels = info.labels.View(device);
+std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
+                                                common::Span<float const> predts,
+                                                MetaInfo const &info,
+                                                common::Span<size_t const> d_sorted_idx, Fn area_fn,
+                                                std::shared_ptr<DeviceAUCCache> cache) {
+  auto labels = info.labels.View(ctx->Device());
   auto weights = info.weights_.ConstDeviceSpan();
-  dh::safe_cuda(cudaSetDevice(device.ordinal));
+  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
 
   CHECK_NE(labels.Size(), 0);
   CHECK_EQ(labels.Size(), predts.size());
@@ -115,7 +116,7 @@ std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
 
   dh::XGBDeviceAllocator<char> alloc;
   auto d_unique_idx = dh::ToSpan(cache->unique_idx);
-  dh::Iota(d_unique_idx);
+  dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
 
   auto uni_key = dh::MakeTransformIterator<float>(
       thrust::make_counting_iterator(0),
@@ -167,8 +168,9 @@ std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
   return std::make_tuple(last.first, last.second, auc);
 }
 
-std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
-                                                   MetaInfo const &info, DeviceOrd device,
+std::tuple<double, double, double> GPUBinaryROCAUC(Context const *ctx,
+                                                   common::Span<float const> predts,
+                                                   MetaInfo const &info,
                                                    std::shared_ptr<DeviceAUCCache> *p_cache) {
   auto &cache = *p_cache;
   InitCacheOnce(predts, p_cache);
@@ -177,10 +179,10 @@ std::tuple<double, double, double> GPUBinaryROCAUC(Context const *ctx,
    * Create sorted index for each class
    */
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
-  dh::ArgSort<false>(predts, d_sorted_idx);
+  common::ArgSort<false>(ctx, predts, d_sorted_idx);
   // Create lambda to avoid pass function pointer.
   return GPUBinaryAUC(
-      predts, info, device, d_sorted_idx,
+      ctx, predts, info, d_sorted_idx,
      [] XGBOOST_DEVICE(double x0, double x1, double y0, double y1) -> double {
        return TrapezoidArea(x0, x1, y0, y1);
      },
@@ -361,7 +363,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
    */
   dh::XGBDeviceAllocator<char> alloc;
   auto d_unique_idx = dh::ToSpan(cache->unique_idx);
-  dh::Iota(d_unique_idx);
+  dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
   auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
       thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
         uint32_t class_id = i / n_samples;
@@ -603,8 +605,9 @@ std::pair<double, uint32_t> GPURankingAUC(Context const *ctx, common::Span<float const> predts,
   return std::make_pair(auc, n_valid);
 }
 
-std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
-                                                  MetaInfo const &info, DeviceOrd device,
+std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
+                                                  common::Span<float const> predts,
+                                                  MetaInfo const &info,
                                                   std::shared_ptr<DeviceAUCCache> *p_cache) {
   auto& cache = *p_cache;
   InitCacheOnce(predts, p_cache);
@@ -613,9 +616,9 @@ std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
    * Create sorted index for each class
    */
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
-  dh::ArgSort<false>(predts, d_sorted_idx);
+  common::ArgSort<false>(ctx, predts, d_sorted_idx);
 
-  auto labels = info.labels.View(device);
+  auto labels = info.labels.View(ctx->Device());
   auto d_weights = info.weights_.ConstDeviceSpan();
   auto get_weight = common::OptionalWeights{d_weights};
   auto it = dh::MakeTransformIterator<thrust::pair<double, double>>(
@@ -639,7 +642,7 @@ std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
   };
   double fp, tp, auc;
-  std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache);
+  std::tie(fp, tp, auc) = GPUBinaryAUC(ctx, predts, info, d_sorted_idx, fn, cache);
   return std::make_tuple(1.0, 1.0, auc);
 }
 
@@ -699,16 +702,17 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
 }
 
 template <typename Fn>
-std::pair<double, uint32_t>
-GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
-                    common::Span<uint32_t const> d_group_ptr, DeviceOrd device,
-                    std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
+std::pair<double, uint32_t> GPURankingPRAUCImpl(Context const *ctx,
+                                                common::Span<float const> predts,
+                                                MetaInfo const &info,
+                                                common::Span<uint32_t const> d_group_ptr,
+                                                std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
   /**
    * Sorted idx
    */
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
 
-  auto labels = info.labels.View(device);
+  auto labels = info.labels.View(ctx->Device());
   auto weights = info.weights_.ConstDeviceSpan();
 
   uint32_t n_groups = static_cast<uint32_t>(info.group_ptr_.size() - 1);
@@ -739,7 +743,7 @@ std::pair<double, uint32_t> GPURankingPRAUCImpl(Context const *ctx,
    */
   dh::XGBDeviceAllocator<char> alloc;
   auto d_unique_idx = dh::ToSpan(cache->unique_idx);
-  dh::Iota(d_unique_idx);
+  dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
   auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
       thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
         auto idx = d_sorted_idx[i];
@@ -882,7 +886,7 @@ std::pair<double, uint32_t> GPURankingPRAUC(Context const *ctx,
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, d_totals[group_id].first);
   };
 
-  return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->Device(), cache, fn);
+  return GPURankingPRAUCImpl(ctx, predts, info, d_group_ptr, cache, fn);
 }
 }  // namespace metric
 }  // namespace xgboost
diff --git a/src/metric/auc.h b/src/metric/auc.h
index fce1cc757..4fe2ecec4 100644
--- a/src/metric/auc.h
+++ b/src/metric/auc.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2021 by XGBoost Contributors
+/**
+ * Copyright 2021-2023, XGBoost Contributors
 */
 #ifndef XGBOOST_METRIC_AUC_H_
 #define XGBOOST_METRIC_AUC_H_
@@ -18,8 +18,7 @@
 #include "xgboost/metric.h"
 #include "xgboost/span.h"
 
-namespace xgboost {
-namespace metric {
+namespace xgboost::metric {
 /***********
  * ROC AUC *
  ***********/
@@ -29,8 +28,9 @@ XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, double y1) {
 
 struct DeviceAUCCache;
 
-std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
-                                                   MetaInfo const &info, DeviceOrd,
+std::tuple<double, double, double> GPUBinaryROCAUC(Context const *ctx,
+                                                   common::Span<float const> predts,
+                                                   MetaInfo const &info,
                                                    std::shared_ptr<DeviceAUCCache> *p_cache);
 
 double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
@@ -44,8 +44,9 @@ std::pair<double, uint32_t> GPURankingAUC(Context const *ctx, common::Span<float const> predts,
 /**********
  * PR AUC *
  **********/
-std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
-                                                  MetaInfo const &info, DeviceOrd,
+std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
+                                                  common::Span<float const> predts,
+                                                  MetaInfo const &info,
                                                   std::shared_ptr<DeviceAUCCache> *p_cache);
 
 double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
@@ -111,6 +112,5 @@ struct PRAUCLabelInvalid {
 inline void InvalidLabels() {
   LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank.";
 }
-}  // namespace metric
-}  // namespace xgboost
+}  // namespace xgboost::metric
 #endif  // XGBOOST_METRIC_AUC_H_
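One note before the objective and tree-updater call sites: the AUC functions above only changed how the device context is threaded through; the integration itself is still the trapezoid rule named in auc.h. A tiny host-side check of that identity (assuming `TrapezoidArea` implements the usual |x1 − x0| · (y0 + y1) / 2; the two points below are consecutive (FP, TP) pairs of a perfect ranking):

#include <cassert>
#include <cmath>

int main() {
  auto trapezoid = [](double x0, double x1, double y0, double y1) {
    return std::abs(x1 - x0) * (y0 + y1) * 0.5;  // area between two curve points
  };
  // ROC curve of a perfect ranking: (0,0) -> (0,1) -> (1,1); the area is 1.
  double auc = trapezoid(0.0, 0.0, 0.0, 1.0) + trapezoid(0.0, 1.0, 1.0, 1.0);
  assert(std::abs(auc - 1.0) < 1e-12);
  return 0;
}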
diff --git a/src/objective/adaptive.cu b/src/objective/adaptive.cu
index cea211622..07644146b 100644
--- a/src/objective/adaptive.cu
+++ b/src/objective/adaptive.cu
@@ -13,9 +13,7 @@
 #include "adaptive.h"
 #include "xgboost/context.h"
 
-namespace xgboost {
-namespace obj {
-namespace detail {
+namespace xgboost::obj::detail {
 void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
                           dh::device_vector<size_t>* p_ridx, HostDeviceVector<size_t>* p_nptr,
                           HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
@@ -28,7 +26,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
                                 position.size_bytes(), cudaMemcpyDeviceToDevice, cuctx->Stream()));
 
   p_ridx->resize(position.size());
-  dh::Iota(dh::ToSpan(*p_ridx));
+  dh::Iota(dh::ToSpan(*p_ridx), cuctx->Stream());
   // sort row index according to node index
   thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
                              sorted_position.begin() + n_samples, p_ridx->begin());
@@ -190,6 +188,4 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
   });
   UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
 }
-}  // namespace detail
-}  // namespace obj
-}  // namespace xgboost
+}  // namespace xgboost::obj::detail
diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu
index f862e048e..6eed74c56 100644
--- a/src/tree/gpu_hist/evaluator.cu
+++ b/src/tree/gpu_hist/evaluator.cu
@@ -72,7 +72,7 @@ common::Span GPUHistEvaluator::SortHistogram(
     TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
   dh::XGBCachingDeviceAllocator<char> alloc;
   auto sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
-  dh::Iota(sorted_idx);
+  dh::Iota(sorted_idx, dh::DefaultStream());
   auto data = this->SortInput(d_inputs.size(), shared_inputs.feature_values.size());
   auto it = thrust::make_counting_iterator(0u);
   auto d_feature_idx = dh::ToSpan(feature_idx_);
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index 3c37556e1..94e7547ee 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -248,8 +248,7 @@ class GlobalApproxUpdater : public TreeUpdater {
   std::unique_ptr<GloablApproxBuilder> pimpl_;
   // pointer to the last DMatrix, used for update prediction cache.
   DMatrix *cached_{nullptr};
-  std::shared_ptr<common::ColumnSampler> column_sampler_ =
-      std::make_shared<common::ColumnSampler>();
+  std::shared_ptr<common::ColumnSampler> column_sampler_;
   ObjInfo const *task_;
   HistMakerTrainParam hist_param_;
 
@@ -284,6 +283,9 @@ class GlobalApproxUpdater : public TreeUpdater {
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree *> &trees) override {
     CHECK(hist_param_.GetInitialised());
+    if (!column_sampler_) {
+      column_sampler_ = common::MakeColumnSampler(ctx_);
+    }
     pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
                                                    column_sampler_, task_, &monitor_);
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 7a88bd30e..e366811f7 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -225,9 +225,12 @@ class ColMaker: public TreeUpdater {
       }
     }
     {
-      column_sampler_.Init(ctx_, fmat.Info().num_col_,
-                           fmat.Info().feature_weights.ConstHostVector(), param_.colsample_bynode,
-                           param_.colsample_bylevel, param_.colsample_bytree);
+      if (!column_sampler_) {
+        column_sampler_ = common::MakeColumnSampler(ctx_);
+      }
+      column_sampler_->Init(
+          ctx_, fmat.Info().num_col_, fmat.Info().feature_weights.ConstHostVector(),
+          param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree);
     }
     {
       // setup temp space for each thread
@@ -467,7 +470,7 @@ class ColMaker: public TreeUpdater {
                            RegTree *p_tree) {
       auto evaluator = tree_evaluator_.GetEvaluator();
 
-      auto feat_set = column_sampler_.GetFeatureSet(depth);
+      auto feat_set = column_sampler_->GetFeatureSet(depth);
       for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
         this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
       }
@@ -586,7 +589,7 @@ class ColMaker: public TreeUpdater {
     const ColMakerTrainParam& colmaker_train_param_;
     // number of omp thread used during training
     Context const* ctx_;
-    common::ColumnSampler column_sampler_;
+    std::shared_ptr<common::ColumnSampler> column_sampler_;
     // Instance Data: current node position in the tree of each instance
     std::vector<int> position_;
     // PerThread x PerTreeNode: statistics for per thread construction
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 375b24cfa..2bb5b0b49 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 by XGBoost Contributors
+ * Copyright 2017-2023, XGBoost Contributors
  * \file updater_quantile_hist.cc
  * \brief use quantized feature values to construct a tree
  * \author Philip Cho, Tianqi Checn, Egor Smirnov
@@ -470,8 +470,7 @@ class HistUpdater {
 class QuantileHistMaker : public TreeUpdater {
   std::unique_ptr<HistUpdater<CPUExpandEntry>> p_impl_{nullptr};
   std::unique_ptr<HistUpdater<MultiExpandEntry>> p_mtimpl_{nullptr};
-  std::shared_ptr<common::ColumnSampler> column_sampler_ =
-      std::make_shared<common::ColumnSampler>();
+  std::shared_ptr<common::ColumnSampler> column_sampler_;
   common::Monitor monitor_;
   ObjInfo const *task_{nullptr};
   HistMakerTrainParam hist_param_;
@@ -495,6 +494,10 @@ class QuantileHistMaker : public TreeUpdater {
   void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree *> &trees) override {
+    if (!column_sampler_) {
+      column_sampler_ = common::MakeColumnSampler(ctx_);
+    }
+
     if (trees.front()->IsMultiTarget()) {
       CHECK(hist_param_.GetInitialised());
       CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
diff --git a/tests/cpp/common/test_algorithm.cu b/tests/cpp/common/test_algorithm.cu
index c36073397..8f857ff50 100644
--- a/tests/cpp/common/test_algorithm.cu
+++ b/tests/cpp/common/test_algorithm.cu
@@ -57,13 +57,13 @@ TEST(Algorithm, GpuArgSort) {
   auto ctx = MakeCUDACtx(0);
 
   dh::device_vector<float> values(20);
-  dh::Iota(dh::ToSpan(values));  // accending
+  dh::Iota(dh::ToSpan(values), ctx.CUDACtx()->Stream());  // ascending
   dh::device_vector<size_t> sorted_idx(20);
-  dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx));  // sort to descending
-  ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(), sorted_idx.end(),
+  ArgSort<false>(&ctx, dh::ToSpan(values), dh::ToSpan(sorted_idx));  // sort to descending
+  ASSERT_TRUE(thrust::is_sorted(ctx.CUDACtx()->CTP(), sorted_idx.begin(), sorted_idx.end(),
                                 thrust::greater{}));
 
-  dh::Iota(dh::ToSpan(values));
+  dh::Iota(dh::ToSpan(values), ctx.CUDACtx()->Stream());
   dh::device_vector<bst_group_t> groups(3);
   groups[0] = 0;
   groups[1] = 10;
diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu
index 92d8ff753..c0d5c5ddc 100644
--- a/tests/cpp/common/test_hist_util.cu
+++ b/tests/cpp/common/test_hist_util.cu
@@ -16,6 +16,7 @@
 #include <vector>  // for vector
 
 #include "../../../include/xgboost/logging.h"
+#include "../../../src/common/cuda_context.cuh"
 #include "../../../src/common/device_helpers.cuh"
 #include "../../../src/common/hist_util.cuh"
 #include "../../../src/common/hist_util.h"
@@ -211,7 +212,7 @@ TEST(HistUtil, RemoveDuplicatedCategories) {
   cuts_ptr.SetDevice(DeviceOrd::CUDA(0));
 
   dh::device_vector<float> weight(n_samples * n_features, 0);
-  dh::Iota(dh::ToSpan(weight));
+  dh::Iota(dh::ToSpan(weight), ctx.CUDACtx()->Stream());
 
   dh::caching_device_vector columns_ptr(4);
   for (std::size_t i = 0; i < columns_ptr.size(); ++i) {
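The test changes that follow all apply one refactoring: each test body moves into a helper taking `Context const*`, and a CUDA twin reuses it. The skeleton of that pattern (sketch only; `MakeCUDACtx` comes from tests/cpp/helpers.h):

#include <gtest/gtest.h>

#include "xgboost/context.h"

namespace {
void TestSomething(xgboost::Context const* ctx) {
  // Assertions written against ctx-aware APIs run unchanged on both devices.
  ASSERT_TRUE(ctx->IsCPU() || ctx->IsCUDA());
}
}  // namespace

TEST(Example, CPU) {
  xgboost::Context ctx;
  TestSomething(&ctx);
}

#if defined(XGBOOST_USE_CUDA)
TEST(Example, GPU) {
  auto ctx = MakeCUDACtx(0);
  TestSomething(&ctx);
}
#endif  // defined(XGBOOST_USE_CUDA)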
diff --git a/tests/cpp/common/test_random.cc b/tests/cpp/common/test_random.cc
index e2ecd0990..a51776475 100644
--- a/tests/cpp/common/test_random.cc
+++ b/tests/cpp/common/test_random.cc
@@ -1,19 +1,20 @@
-#include <gtest/gtest.h>
+/**
+ * Copyright 2018-2023, XGBoost Contributors
+ */
 #include "../../../src/common/random.h"
 #include "../helpers.h"
 #include "gtest/gtest.h"
-#include "xgboost/context.h"  // Context
+#include "xgboost/context.h"  // for Context
 
-namespace xgboost {
-namespace common {
-TEST(ColumnSampler, Test) {
-  Context ctx;
+namespace xgboost::common {
+namespace {
+void TestBasic(Context const* ctx) {
   int n = 128;
-  ColumnSampler cs;
+  ColumnSampler cs{1u};
   std::vector<float> feature_weights;
 
   // No node sampling
-  cs.Init(&ctx, n, feature_weights, 1.0f, 0.5f, 0.5f);
+  cs.Init(ctx, n, feature_weights, 1.0f, 0.5f, 0.5f);
   auto set0 = cs.GetFeatureSet(0);
   ASSERT_EQ(set0->Size(), 32);
 
@@ -26,7 +27,7 @@ void TestBasic(Context const* ctx) {
   ASSERT_EQ(set2->Size(), 32);
 
   // Node sampling
-  cs.Init(&ctx, n, feature_weights, 0.5f, 1.0f, 0.5f);
+  cs.Init(ctx, n, feature_weights, 0.5f, 1.0f, 0.5f);
   auto set3 = cs.GetFeatureSet(0);
   ASSERT_EQ(set3->Size(), 32);
 
@@ -36,21 +37,33 @@ void TestBasic(Context const* ctx) {
   ASSERT_EQ(set4->Size(), 32);
 
   // No level or node sampling, should be the same at different depth
-  cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 0.5f);
-  ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
-            cs.GetFeatureSet(1)->HostVector());
+  cs.Init(ctx, n, feature_weights, 1.0f, 1.0f, 0.5f);
+  ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector());
 
-  cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
+  cs.Init(ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
   auto set5 = cs.GetFeatureSet(0);
   ASSERT_EQ(set5->Size(), n);
-  cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
+  cs.Init(ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
   auto set6 = cs.GetFeatureSet(0);
   ASSERT_EQ(set5->HostVector(), set6->HostVector());
 
   // Should always be a minimum of one feature
-  cs.Init(&ctx, n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
+  cs.Init(ctx, n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
   ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
 }
+}  // namespace
+
+TEST(ColumnSampler, Test) {
+  Context ctx;
+  TestBasic(&ctx);
+}
+
+#if defined(XGBOOST_USE_CUDA)
+TEST(ColumnSampler, GPUTest) {
+  auto ctx = MakeCUDACtx(0);
+  TestBasic(&ctx);
+}
+#endif  // defined(XGBOOST_USE_CUDA)
 
 // Test if different threads using the same seed produce the same result
 TEST(ColumnSampler, ThreadSynchronisation) {
@@ -81,16 +94,16 @@ TEST(ColumnSampler, ThreadSynchronisation) {
   ASSERT_TRUE(success);
 }
 
-TEST(ColumnSampler, WeightedSampling) {
-  auto test_basic = [](int first) {
-    Context ctx;
+namespace {
+void TestWeightedSampling(Context const* ctx) {
+  auto test_basic = [ctx](int first) {
     std::vector<float> feature_weights(2);
     feature_weights[0] = std::abs(first - 1.0f);
     feature_weights[1] = first - 0.0f;
     ColumnSampler cs{0};
-    cs.Init(&ctx, 2, feature_weights, 1.0, 1.0, 0.5);
+    cs.Init(ctx, 2, feature_weights, 1.0, 1.0, 0.5);
     auto feature_sets = cs.GetFeatureSet(0);
-    auto const &h_feat_set = feature_sets->HostVector();
+    auto const& h_feat_set = feature_sets->HostVector();
     ASSERT_EQ(h_feat_set.size(), 1);
     ASSERT_EQ(h_feat_set[0], first - 0);
   };
@@ -104,8 +117,7 @@ void TestWeightedSampling(Context const* ctx) {
   SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
   std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
   ColumnSampler cs{0};
-  Context ctx;
-  cs.Init(&ctx, kCols, feature_weights, 0.5f, 1.0f, 1.0f);
+  cs.Init(ctx, kCols, feature_weights, 0.5f, 1.0f, 1.0f);
   std::vector<bst_feature_t> features(kCols);
   std::iota(features.begin(), features.end(), 0);
   std::vector<float> freq(kCols, 0);
@@ -131,8 +143,22 @@ void TestWeightedSampling(Context const* ctx) {
     EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
   }
 }
+}  // namespace
 
-TEST(ColumnSampler, WeightedMultiSampling) {
+TEST(ColumnSampler, WeightedSampling) {
+  Context ctx;
+  TestWeightedSampling(&ctx);
+}
+
+#if defined(XGBOOST_USE_CUDA)
+TEST(ColumnSampler, GPUWeightedSampling) {
+  auto ctx = MakeCUDACtx(0);
+  TestWeightedSampling(&ctx);
+}
+#endif  // defined(XGBOOST_USE_CUDA)
+
+namespace {
+void TestWeightedMultiSampling(Context const* ctx) {
   size_t constexpr kCols = 32;
   std::vector<float> feature_weights(kCols, 0);
   for (size_t i = 0; i < feature_weights.size(); ++i) {
@@ -140,13 +166,24 @@ void TestWeightedMultiSampling(Context const* ctx) {
   }
   ColumnSampler cs{0};
   float bytree{0.5}, bylevel{0.5}, bynode{0.5};
-  Context ctx;
-  cs.Init(&ctx, feature_weights.size(), feature_weights, bytree, bylevel, bynode);
+  cs.Init(ctx, feature_weights.size(), feature_weights, bytree, bylevel, bynode);
   auto feature_set = cs.GetFeatureSet(0);
   size_t n_sampled = kCols * bytree * bylevel * bynode;
   ASSERT_EQ(feature_set->Size(), n_sampled);
   feature_set = cs.GetFeatureSet(1);
   ASSERT_EQ(feature_set->Size(), n_sampled);
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace
+
+TEST(ColumnSampler, WeightedMultiSampling) {
+  Context ctx;
+  TestWeightedMultiSampling(&ctx);
+}
+
+#if defined(XGBOOST_USE_CUDA)
+TEST(ColumnSampler, GPUWeightedMultiSampling) {
+  auto ctx = MakeCUDACtx(0);
+  TestWeightedMultiSampling(&ctx);
+}
+#endif  // defined(XGBOOST_USE_CUDA)
+}  // namespace xgboost::common
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index 78fda5ce5..329379b5b 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -28,7 +28,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
   Context ctx;
   ctx.nthread = 4;
   int static constexpr kRows = 8, kCols = 16;
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   TrainParam param;
   param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
 
@@ -102,7 +102,7 @@ TEST(HistMultiEvaluator, Evaluate) {
   TrainParam param;
   param.Init(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
 
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   std::size_t n_samples = 3;
   bst_feature_t n_features = 2;
 
@@ -166,7 +166,7 @@ TEST(HistEvaluator, Apply) {
   TrainParam param;
   param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
   auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   auto evaluator_ = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
 
   CPUExpandEntry entry{0, 0};
@@ -194,7 +194,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
   Context ctx;
   // check the evaluator is returning the optimal split
   std::vector<FeatureType> ft{FeatureType::kCategorical};
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   HistEvaluator evaluator{&ctx, &param_, info_, sampler};
   evaluator.InitRoot(GradStats{total_gpair_});
   RegTree tree;
@@ -224,7 +224,7 @@ auto CompareOneHotAndPartition(bool onehot) {
   auto dmat =
       RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
 
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   auto evaluator = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
   std::vector<CPUExpandEntry> entries(1);
   HistMakerTrainParam hist_param;
@@ -271,7 +271,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
   ASSERT_EQ(node_hist.size(), feature_histogram_.size());
   std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), node_hist.begin());
 
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   MetaInfo info;
   info.num_col_ = 1;
   info.feature_types = {FeatureType::kCategorical};
diff --git a/tests/cpp/tree/test_constraints.cc b/tests/cpp/tree/test_constraints.cc
index 912d608a3..4f810102d 100644
--- a/tests/cpp/tree/test_constraints.cc
+++ b/tests/cpp/tree/test_constraints.cc
@@ -1,3 +1,6 @@
+/**
+ * Copyright 2019-2023, XGBoost Contributors
+ */
 #include <gtest/gtest.h>
 #include
 #include
@@ -9,9 +12,7 @@
 #include "../../../src/tree/hist/evaluate_splits.h"
 #include "../helpers.h"
 
-namespace xgboost {
-namespace tree {
-
+namespace xgboost::tree {
 TEST(CPUFeatureInteractionConstraint, Empty) {
   TrainParam param;
   param.UpdateAllowUnknown(Args{});
@@ -77,7 +78,7 @@ TEST(CPUMonoConstraint, Basic) {
   param.UpdateAllowUnknown(Args{{"monotone_constraints", str_mono}});
   auto Xy = RandomDataGenerator{kRows, kCols, 0.0}.GenerateDMatrix(true);
 
-  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto sampler = std::make_shared<common::ColumnSampler>(1u);
   HistEvaluator evalutor{&ctx, &param, Xy->Info(), sampler};
   evalutor.InitRoot(GradStats{2.0, 2.0});
 
@@ -90,5 +91,4 @@ TEST(CPUMonoConstraint, Basic) {
   ASSERT_TRUE(evalutor.Evaluator().has_constraint);
 }
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
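Taken together, an updater's view of the new sampler life cycle looks roughly like this (condensed from the updater changes above; names and sampling rates are illustrative, and `MakeColumnSampler`/`ColumnSampler` come from src/common/random.h):

#include <cstdint>
#include <memory>
#include <vector>

#include "xgboost/context.h"

void ExampleUpdate(xgboost::Context const* ctx, std::int64_t n_features) {
  // Lazily construct the sampler inside Update() so the seed broadcast runs
  // with a live communicator, instead of in the updater's constructor.
  auto sampler = xgboost::common::MakeColumnSampler(ctx);
  sampler->Init(ctx, n_features, /*feature_weights=*/{},
                /*colsample_bynode=*/0.5f, /*colsample_bylevel=*/1.0f,
                /*colsample_bytree=*/1.0f);
  // On a CUDA context the feature set is both sampled and kept on device.
  auto feats = sampler->GetFeatureSet(/*depth=*/0);
  (void)feats;
}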