Implement column sampler in CUDA. (#9785)

- CUDA implementation; a host-side sketch of the sampling scheme appears after the commit metadata below.
- Extract the broadcasting logic; we will need the context parameter after revamping the collective implementation.
- Change the event loop to fix a deadlock observed in CI.
- Move argsort into algorithm.cuh and add support for CUDA streams.
Jiaming Yuan 2023-11-17 04:29:08 +08:00 committed by GitHub
parent 178cfe70a8
commit fedd9674c8
20 changed files with 447 additions and 232 deletions
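The heart of the change is GPU-side weighted sampling without replacement: draw one random key per feature, argsort the keys in descending order, and keep the first n indices. A minimal host-side sketch of the same scheme, under the stated assumptions; every name here is illustrative and nothing is taken verbatim from the diff:

    #include <algorithm>  // for max, sort
    #include <cmath>      // for log
    #include <cstdint>
    #include <numeric>    // for iota
    #include <random>
    #include <vector>

    // One key per item: k_i = log(u_i) / w_i with u_i ~ U(0, 1). Heavier weights
    // pull the (negative) key toward zero, so the n largest keys form a weighted
    // sample without replacement (Efraimidis-Spirakis style keys).
    std::vector<std::uint32_t> WeightedSampleWithoutReplacement(
        std::vector<float> const &weights, std::size_t n, std::uint64_t seed) {
      std::mt19937_64 rng{seed};
      std::uniform_real_distribution<float> dist;
      std::vector<float> keys(weights.size());
      for (std::size_t i = 0; i < keys.size(); ++i) {
        float w = std::max(weights[i], 1e-6f);  // guard against zero weights
        keys[i] = std::log(dist(rng)) / w;
      }
      std::vector<std::uint32_t> idx(weights.size());
      std::iota(idx.begin(), idx.end(), 0u);
      // Descending sort by key; the device code uses ArgSort<false> instead.
      std::sort(idx.begin(), idx.end(), [&](auto l, auto r) { return keys[l] > keys[r]; });
      idx.resize(n);  // assumes n <= weights.size()
      return idx;
    }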


@ -117,11 +117,14 @@ void Loop::Process() {
break;
}
- auto unlock_notify = [&](bool is_blocking) {
+ auto unlock_notify = [&](bool is_blocking, bool stop) {
if (!is_blocking) {
- return;
- }
+ std::lock_guard guard{mu_};
+ stop_ = stop;
+ } else {
+ stop_ = stop;
lock.unlock();
+ }
cv_.notify_one();
};
@ -145,13 +148,14 @@ void Loop::Process() {
auto rc = this->EmptyQueue(&qcopy);
// Handle error
if (!rc.OK()) {
+ unlock_notify(is_blocking, true);
+ std::lock_guard<std::mutex> guard{rc_lock_};
this->rc_ = std::move(rc);
- unlock_notify(is_blocking);
return;
}
CHECK(qcopy.empty());
- unlock_notify(is_blocking);
+ unlock_notify(is_blocking, false);
}
}
@ -170,12 +174,21 @@ Result Loop::Stop() {
}
[[nodiscard]] Result Loop::Block() {
+ {
+ std::lock_guard<std::mutex> guard{rc_lock_};
+ if (!rc_.OK()) {
+ return std::move(rc_);
+ }
+ }
this->Submit(Op{Op::kBlock});
{
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
}
+ {
+ std::lock_guard<std::mutex> lock{rc_lock_};
return std::move(rc_);
+ }
}
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
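The deadlock fix above separates the result `rc_` behind its own mutex (`rc_lock_`) from the queue state behind `mu_`, and releases the queue lock before notifying. A condensed sketch of the resulting handshake with simplified types; this mirrors the pattern, not the exact xgboost code:

    #include <condition_variable>
    #include <mutex>
    #include <queue>

    struct MiniLoop {
      std::mutex mu_;       // guards queue_ and stop_
      std::mutex rc_lock_;  // guards rc_ only
      std::condition_variable cv_;
      std::queue<int> queue_;
      bool stop_{false};
      int rc_{0};  // stands in for collective::Result

      // Worker thread: publish the error first, then flip stop_ and wake waiters.
      void FailWith(int rc) {
        {
          std::lock_guard<std::mutex> guard{rc_lock_};
          rc_ = rc;
        }
        {
          std::lock_guard<std::mutex> guard{mu_};
          stop_ = true;
        }
        cv_.notify_one();
      }

      // Caller thread: check for a prior error, wait for the queue to drain or
      // for stop_, then re-check the error -- never holding both locks at once.
      int Block() {
        {
          std::lock_guard<std::mutex> guard{rc_lock_};
          if (rc_ != 0) {
            return rc_;
          }
        }
        std::unique_lock<std::mutex> lock{mu_};
        cv_.wait(lock, [this] { return queue_.empty() || stop_; });
        lock.unlock();
        std::lock_guard<std::mutex> guard{rc_lock_};
        return rc_;
      }
    };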


@ -42,7 +42,10 @@ class Loop {
std::mutex mu_;
std::queue<Op> queue_;
std::chrono::seconds timeout_;
Result rc_;
std::mutex rc_lock_; // lock for transferring error info.
bool stop_{false};
std::exception_ptr curr_exce_{nullptr};
common::Monitor mutable timer_;


@ -23,8 +23,7 @@
#include "xgboost/logging.h" // CHECK
#include "xgboost/span.h" // Span,byte
- namespace xgboost {
- namespace common {
+ namespace xgboost::common {
namespace detail {
// Wrapper around cub sort to define is_decending
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
@ -127,13 +126,14 @@ inline void SegmentedSortKeys(Context const *ctx, Span<V const> group_ptr,
template <bool accending, bool per_seg_index, typename U, typename V, typename IdxT>
void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
Span<IdxT> sorted_idx) {
+ auto cuctx = ctx->CUDACtx();
CHECK_GE(group_ptr.size(), 1ul);
std::size_t n_groups = group_ptr.size() - 1;
std::size_t bytes = 0;
if (per_seg_index) {
SegmentedSequence(ctx, group_ptr, sorted_idx);
} else {
- dh::Iota(sorted_idx);
+ dh::Iota(sorted_idx, cuctx->Stream());
}
dh::TemporaryArray<std::remove_const_t<U>> values_out(values.size());
dh::TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
@ -141,15 +141,16 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
detail::DeviceSegmentedRadixSortPair<!accending>(
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
- group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+ group_ptr.data() + 1, cuctx->Stream());
dh::TemporaryArray<byte> temp_storage(bytes);
detail::DeviceSegmentedRadixSortPair<!accending>(
temp_storage.data().get(), bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
- group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+ group_ptr.data() + 1, cuctx->Stream());
dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
- sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
+ sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
+ cuctx->Stream()));
}
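For reference, a sketch of calling the segmented variant on two segments; it mirrors the shape of the GPU unit test further below, and the sizes are made up:

    void SortTwoSegmentsDescending(xgboost::Context const *ctx) {
      dh::device_vector<float> values(10);
      dh::Iota(dh::ToSpan(values), ctx->CUDACtx()->Stream());  // 0, 1, ..., 9
      dh::device_vector<std::size_t> groups(3);
      groups[0] = 0; groups[1] = 5; groups[2] = 10;  // CSR-style segment pointers
      dh::device_vector<std::size_t> sorted_idx(10);
      // accending=false, per_seg_index=false: global indices, descending per segment.
      xgboost::common::SegmentedArgSort<false, false>(ctx, dh::ToSpan(values), dh::ToSpan(groups),
                                                      dh::ToSpan(sorted_idx));
    }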
/**
@ -159,11 +160,12 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
template <typename SegIt, typename ValIt>
void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
ValIt val_end, dh::device_vector<std::size_t> *p_sorted_idx) {
+ auto cuctx = ctx->CUDACtx();
using Tup = thrust::tuple<std::int32_t, float>;
auto &sorted_idx = *p_sorted_idx;
std::size_t n = std::distance(val_begin, val_end);
sorted_idx.resize(n);
- dh::Iota(dh::ToSpan(sorted_idx));
+ dh::Iota(dh::ToSpan(sorted_idx), cuctx->Stream());
dh::device_vector<Tup> keys(sorted_idx.size());
auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) -> Tup {
@ -177,7 +179,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
return thrust::make_tuple(seg_idx, residue);
});
thrust::copy(ctx->CUDACtx()->CTP(), key_it, key_it + keys.size(), keys.begin());
- thrust::stable_sort_by_key(ctx->CUDACtx()->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
+ thrust::stable_sort_by_key(cuctx->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
[=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
if (thrust::get<0>(l) != thrust::get<0>(r)) {
return thrust::get<0>(l) < thrust::get<0>(r); // segment index
@ -185,6 +187,75 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
return thrust::get<1>(l) < thrust::get<1>(r); // residue
});
}
- } // namespace common
- } // namespace xgboost
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
xgboost::common::Span<IdxT> sorted_idx) {
std::size_t bytes = 0;
auto cuctx = ctx->CUDACtx();
dh::Iota(sorted_idx, cuctx->Stream());
using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;
dh::TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()), out.data().get());
dh::TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx_out.data().get());
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
} else {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
cuctx->Stream())));
#else
dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
nullptr, false)));
#endif
}
dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
cuctx->Stream()));
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
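Call sites of the relocated ArgSort now pass a Context instead of relying on the legacy default stream. A small usage sketch, matching the updated unit test further below:

    void ArgSortDescending(xgboost::Context const *ctx) {
      dh::device_vector<float> keys(20);
      dh::Iota(dh::ToSpan(keys), ctx->CUDACtx()->Stream());  // 0 .. 19
      dh::device_vector<std::size_t> sorted_idx(20);
      // All work is enqueued on the context's stream.
      xgboost::common::ArgSort<false>(ctx, dh::ToSpan(keys), dh::ToSpan(sorted_idx));
      // sorted_idx now holds 19, 18, ..., 0 on the device.
    }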


@ -313,8 +313,8 @@ inline void LaunchN(size_t n, L lambda) {
}
template <typename Container>
- void Iota(Container array) {
- LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
+ void Iota(Container array, cudaStream_t stream) {
+ LaunchN(array.size(), stream, [=] __device__(size_t i) { array[i] = i; });
}
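The extra stream parameter keeps the fill ordered with subsequent kernels on the same stream instead of the legacy default stream. A sketch of the intended usage; local names are illustrative:

    void FillIndices(xgboost::Context const *ctx, dh::device_vector<std::size_t> *p_idx) {
      auto stream = ctx->CUDACtx()->Stream();
      dh::Iota(dh::ToSpan(*p_idx), stream);
      // Kernels later enqueued on `stream` observe the completed fill without
      // an explicit cudaStreamSynchronize.
    }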
namespace detail {
@ -597,6 +597,16 @@ class DoubleBuffer {
T *Other() { return buff.Alternate(); }
};
template <typename T>
xgboost::common::Span<T> LazyResize(xgboost::Context const *ctx,
xgboost::HostDeviceVector<T> *buffer, std::size_t n) {
buffer->SetDevice(ctx->Device());
if (buffer->Size() < n) {
buffer->Resize(n);
}
return buffer->DeviceSpan().subspan(0, n);
}
/**
* \brief Copies device span to std::vector.
*
@ -1060,74 +1070,6 @@ void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items)
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
}
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
size_t bytes = 0;
Iota(sorted_idx);
using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;
TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
out.data().get());
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx_out.data().get());
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
if (accending) {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
} else {
void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr)));
#else
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
}
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}
class CUDAStreamView;
class CUDAEvent {


@ -1,32 +1,50 @@
- /*!
- * Copyright 2020 by XGBoost Contributors
- * \file random.cc
+ /**
+ * Copyright 2020-2023, XGBoost Contributors
*/
#include "random.h"
- namespace xgboost {
- namespace common {
+ #include <algorithm> // for sort, max, copy
+ #include <memory> // for shared_ptr
+ #include "xgboost/host_device_vector.h" // for HostDeviceVector
+ namespace xgboost::common {
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
if (colsample == 1.0f) {
return p_features;
}
+ int n = std::max(1, static_cast<int>(colsample * p_features->Size()));
+ auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
+ if (ctx_->IsCUDA()) {
+ #if defined(XGBOOST_USE_CUDA)
+ cuda_impl::SampleFeature(ctx_, n, p_features, p_new_features, this->feature_weights_,
+ &this->weight_buffer_, &this->idx_buffer_, &rng_);
+ return p_new_features;
+ #else
+ AssertGPUSupport();
+ return nullptr;
+ #endif // defined(XGBOOST_USE_CUDA)
+ }
const auto &features = p_features->HostVector();
CHECK_GT(features.size(), 0);
- int n = std::max(1, static_cast<int>(colsample * features.size()));
- auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
auto &new_features = *p_new_features;
- if (feature_weights_.size() != 0) {
+ if (!feature_weights_.Empty()) {
auto const &h_features = p_features->HostVector();
- std::vector<float> weights(h_features.size());
+ auto const &h_feature_weight = feature_weights_.ConstHostVector();
+ auto &weight = this->weight_buffer_.HostVector();
+ weight.resize(h_features.size());
for (size_t i = 0; i < h_features.size(); ++i) {
- weights[i] = feature_weights_[h_features[i]];
+ weight[i] = h_feature_weight[h_features[i]];
}
+ CHECK(ctx_);
new_features.HostVector() =
- WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
+ WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weight, n);
} else {
new_features.Resize(features.size());
std::copy(features.begin(), features.end(), new_features.HostVector().begin());
@ -36,5 +54,4 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
return p_new_features;
}
} // namespace common
} // namespace xgboost
} // namespace xgboost::common
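With the dispatch above, the same caller-side sequence works on CPU and, when built with CUDA, on GPU. A hedged usage sketch with made-up sizes and rates:

    void SampleColumns(xgboost::Context const *ctx) {
      xgboost::common::ColumnSampler cs{/*seed=*/42};
      // 16 columns, no explicit weights; colsample_bynode/bylevel = 1, bytree = 0.5.
      cs.Init(ctx, /*num_col=*/16, /*feature_weights=*/{}, 1.0f, 1.0f, 0.5f);
      auto feats = cs.GetFeatureSet(/*depth=*/0);  // 8 features, sorted ascending
      (void)feats;
    }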

src/common/random.cu (new file, 106 lines)

@ -0,0 +1,106 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <thrust/shuffle.h> // for shuffle
#include <memory> // for shared_ptr
#include "algorithm.cuh" // for ArgSort
#include "cuda_context.cuh" // for CUDAContext
#include "device_helpers.cuh"
#include "random.h"
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // for HostDeviceVector
namespace xgboost::common::cuda_impl {
// GPU implementation for sampling without replacement, see the CPU version for references.
void WeightedSamplingWithoutReplacement(Context const *ctx, common::Span<bst_feature_t const> array,
common::Span<float const> weights,
common::Span<bst_feature_t> results,
HostDeviceVector<bst_feature_t> *sorted_idx,
GlobalRandomEngine *grng) {
CUDAContext const *cuctx = ctx->CUDACtx();
CHECK_EQ(array.size(), weights.size());
// Sampling keys
dh::caching_device_vector<float> keys(weights.size());
auto d_keys = dh::ToSpan(keys);
auto seed = (*grng)();
constexpr auto kEps = kRtEps; // avoid CUDA compilation error
thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), array.size(),
[=] XGBOOST_DEVICE(std::size_t i) {
thrust::default_random_engine rng;
rng.seed(seed);
rng.discard(i);
thrust::uniform_real_distribution<float> dist;
auto w = std::max(weights[i], kEps);
auto u = dist(rng);
auto k = std::log(u) / w;
d_keys[i] = k;
});
// Allocate buffer for sorted index.
auto d_idx = dh::LazyResize(ctx, sorted_idx, keys.size());
ArgSort<false>(ctx, d_keys, d_idx);
// Filter the result according to sorted index.
auto it = thrust::make_permutation_iterator(dh::tbegin(array), dh::tbegin(d_idx));
// |array| == |weights| == |keys| == |sorted_idx| >= |results|
for (auto size : {array.size(), weights.size(), keys.size()}) {
CHECK_EQ(size, d_idx.size());
}
CHECK_GE(array.size(), results.size());
thrust::copy_n(cuctx->CTP(), it, results.size(), dh::tbegin(results));
}
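The key written by the kernel above is the logarithm of the Efraimidis-Spirakis reservoir key, which avoids a pow call and float underflow for tiny weights:

    k_i = log(u_i) / w_i,   u_i ~ Uniform(0, 1)
    exp(k_i) = u_i^(1 / w_i)

Because log is monotonic, the n largest k_i (ArgSort<false> sorts descending, and the first results.size() entries are copied out) select exactly the n largest u_i^(1/w_i), i.e. a weighted sample without replacement. Note also rng.discard(i): each index reads a distinct position of one seeded sequence, so the u_i behave as independent draws from a single per-call seed.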
void SampleFeature(Context const *ctx, bst_feature_t n_features,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
HostDeviceVector<float> const &feature_weights,
HostDeviceVector<float> *weight_buffer,
HostDeviceVector<bst_feature_t> *idx_buffer, GlobalRandomEngine *grng) {
CUDAContext const *cuctx = ctx->CUDACtx();
auto &new_features = *p_new_features;
new_features.SetDevice(ctx->Device());
p_features->SetDevice(ctx->Device());
CHECK_LE(n_features, p_features->Size());
if (!feature_weights.Empty()) {
CHECK_LE(p_features->Size(), feature_weights.Size());
idx_buffer->SetDevice(ctx->Device());
feature_weights.SetDevice(ctx->Device());
auto d_old_features = p_features->DeviceSpan();
auto d_weight_buffer = dh::LazyResize(ctx, weight_buffer, d_old_features.size());
// Filter weights according to the existing feature index.
auto d_feature_weight = feature_weights.ConstDeviceSpan();
auto it = thrust::make_permutation_iterator(dh::tcbegin(d_feature_weight),
dh::tcbegin(d_old_features));
thrust::copy_n(cuctx->CTP(), it, d_old_features.size(), dh::tbegin(d_weight_buffer));
new_features.Resize(n_features);
WeightedSamplingWithoutReplacement(ctx, d_old_features, d_weight_buffer,
new_features.DeviceSpan(), idx_buffer, grng);
} else {
new_features.Resize(p_features->Size());
new_features.Copy(*p_features);
auto d_feat = new_features.DeviceSpan();
thrust::default_random_engine rng;
rng.seed((*grng)());
thrust::shuffle(cuctx->CTP(), dh::tbegin(d_feat), dh::tend(d_feat), rng);
new_features.Resize(n_features);
}
auto d_new_features = new_features.DeviceSpan();
thrust::sort(cuctx->CTP(), dh::tbegin(d_new_features), dh::tend(d_new_features));
}
void InitFeatureSet(Context const *ctx,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features) {
CUDAContext const *cuctx = ctx->CUDACtx();
auto d_features = p_features->DeviceSpan();
thrust::sequence(cuctx->CTP(), dh::tbegin(d_features), dh::tend(d_features), 0);
}
} // namespace xgboost::common::cuda_impl


@ -1,5 +1,5 @@
- /*!
- * Copyright 2015-2020 by Contributors
+ /**
+ * Copyright 2015-2020, XGBoost Contributors
* \file random.h
* \brief Utility related to random.
* \author Tianqi Chen
@ -25,8 +25,7 @@
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h"
- namespace xgboost {
- namespace common {
+ namespace xgboost::common {
/*!
* \brief Define mt19937 as default type Random Engine.
*/
@ -113,6 +112,18 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vecto
return results;
}
namespace cuda_impl {
void SampleFeature(Context const* ctx, bst_feature_t n_features,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
HostDeviceVector<float> const& feature_weights,
HostDeviceVector<float>* weight_buffer,
HostDeviceVector<bst_feature_t>* idx_buffer, GlobalRandomEngine* grng);
void InitFeatureSet(Context const* ctx,
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features);
} // namespace cuda_impl
/**
* \class ColumnSampler
*
@ -123,46 +134,37 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vecto
class ColumnSampler {
std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
- std::vector<float> feature_weights_;
+ HostDeviceVector<float> feature_weights_;
float colsample_bylevel_{1.0f};
float colsample_bytree_{1.0f};
float colsample_bynode_{1.0f};
GlobalRandomEngine rng_;
Context const* ctx_;
+ // Used for weighted sampling.
+ HostDeviceVector<bst_feature_t> idx_buffer_;
+ HostDeviceVector<float> weight_buffer_;
public:
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
/**
- * \brief Column sampler constructor.
- * \note This constructor manually sets the rng seed
+ * @brief Column sampler constructor.
+ * @note This constructor manually sets the rng seed
*/
- explicit ColumnSampler(uint32_t seed) {
- rng_.seed(seed);
- }
+ explicit ColumnSampler(std::uint32_t seed) { rng_.seed(seed); }
/**
* \brief Column sampler constructor.
* \note This constructor synchronizes the RNG seed across processes.
*/
ColumnSampler() {
uint32_t seed = common::GlobalRandom()();
collective::Broadcast(&seed, sizeof(seed), 0);
rng_.seed(seed);
}
/**
- * \brief Initialise this object before use.
+ * @brief Initialise this object before use.
*
- * \param num_col
- * \param colsample_bynode
- * \param colsample_bylevel
- * \param colsample_bytree
- * \param skip_index_0 (Optional) True to skip index 0.
+ * @param num_col
+ * @param colsample_bynode Sampling rate for node.
+ * @param colsample_bylevel Sampling rate for tree level.
+ * @param colsample_bytree Sampling rate for tree.
*/
void Init(Context const* ctx, int64_t num_col, std::vector<float> feature_weights,
float colsample_bynode, float colsample_bylevel, float colsample_bytree) {
- feature_weights_ = std::move(feature_weights);
+ feature_weights_.HostVector() = std::move(feature_weights);
colsample_bylevel_ = colsample_bylevel;
colsample_bytree_ = colsample_bytree;
colsample_bynode_ = colsample_bynode;
@ -173,8 +175,17 @@ class ColumnSampler {
}
Reset();
feature_set_tree_->SetDevice(ctx->Device());
feature_set_tree_->Resize(num_col);
+ if (ctx->IsCPU()) {
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
+ } else {
+ #if defined(XGBOOST_USE_CUDA)
+ cuda_impl::InitFeatureSet(ctx, feature_set_tree_);
+ #else
+ AssertGPUSupport();
+ #endif
+ }
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
}
@ -216,6 +227,11 @@ class ColumnSampler {
}
};
- } // namespace common
- } // namespace xgboost
+ inline auto MakeColumnSampler(Context const*) {
+ std::uint32_t seed = common::GlobalRandomEngine()();
+ collective::Broadcast(&seed, sizeof(seed), 0);
+ auto cs = std::make_shared<common::ColumnSampler>(seed);
+ return cs;
+ }
+ } // namespace xgboost::common
#endif // XGBOOST_COMMON_RANDOM_H_
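MakeColumnSampler centralises the seed synchronisation that the default constructor used to perform: rank 0's seed is broadcast so every worker draws identical column subsets. A sketch of how an updater can construct it lazily; the helper name and shape here are illustrative:

    void LazyInitSampler(xgboost::Context const *ctx,
                         std::shared_ptr<xgboost::common::ColumnSampler> *p_sampler) {
      if (!*p_sampler) {
        // The broadcast needs a live communicator, hence first use rather
        // than updater construction time.
        *p_sampler = xgboost::common::MakeColumnSampler(ctx);
      }
    }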


@ -360,7 +360,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(fp, tp, auc) =
- GPUBinaryROCAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
+ GPUBinaryROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
}
return std::make_tuple(fp, tp, auc);
}
@ -376,8 +376,9 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
.set_body([](const char*) { return new EvalROCAUC(); });
#if !defined(XGBOOST_USE_CUDA)
- std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const>, MetaInfo const &,
- DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
+ std::tuple<double, double, double> GPUBinaryROCAUC(Context const *, common::Span<float const>,
+ MetaInfo const &,
+ std::shared_ptr<DeviceAUCCache> *) {
common::AssertGPUSupport();
return {};
}
@ -409,8 +410,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
- std::tie(pr, re, auc) =
- GPUBinaryPRAUC(predts.ConstDeviceSpan(), info, ctx_->Device(), &this->d_cache_);
+ std::tie(pr, re, auc) = GPUBinaryPRAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
}
return std::make_tuple(pr, re, auc);
}
@ -453,8 +453,9 @@ XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
.set_body([](char const *) { return new EvalPRAUC{}; });
#if !defined(XGBOOST_USE_CUDA)
- std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const>, MetaInfo const &,
- DeviceOrd, std::shared_ptr<DeviceAUCCache> *) {
+ std::tuple<double, double, double> GPUBinaryPRAUC(Context const *, common::Span<float const>,
+ MetaInfo const &,
+ std::shared_ptr<DeviceAUCCache> *) {
common::AssertGPUSupport();
return {};
}


@ -83,13 +83,14 @@ void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCa
* - Reduce the scan array into 1 AUC value.
*/
template <typename Fn>
- std::tuple<double, double, double>
- GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
- DeviceOrd device, common::Span<size_t const> d_sorted_idx,
- Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
- auto labels = info.labels.View(device);
+ std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
+ common::Span<float const> predts,
+ MetaInfo const &info,
+ common::Span<size_t const> d_sorted_idx, Fn area_fn,
+ std::shared_ptr<DeviceAUCCache> cache) {
+ auto labels = info.labels.View(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan();
- dh::safe_cuda(cudaSetDevice(device.ordinal));
+ dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
CHECK_NE(labels.Size(), 0);
CHECK_EQ(labels.Size(), predts.size());
@ -115,7 +116,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
- dh::Iota(d_unique_idx);
+ dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
auto uni_key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0),
@ -167,8 +168,9 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
return std::make_tuple(last.first, last.second, auc);
}
- std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
- MetaInfo const &info, DeviceOrd device,
+ std::tuple<double, double, double> GPUBinaryROCAUC(Context const *ctx,
+ common::Span<float const> predts,
+ MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto &cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@ -177,10 +179,10 @@ std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> pre
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
- dh::ArgSort<false>(predts, d_sorted_idx);
+ common::ArgSort<false>(ctx, predts, d_sorted_idx);
// Create lambda to avoid pass function pointer.
return GPUBinaryAUC(
- predts, info, device, d_sorted_idx,
+ ctx, predts, info, d_sorted_idx,
[] XGBOOST_DEVICE(double x0, double x1, double y0, double y1) -> double {
return TrapezoidArea(x0, x1, y0, y1);
},
@ -361,7 +363,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
- dh::Iota(d_unique_idx);
+ dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
uint32_t class_id = i / n_samples;
@ -603,8 +605,9 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
return std::make_pair(auc, n_valid);
}
- std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
- MetaInfo const &info, DeviceOrd device,
+ std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
+ common::Span<float const> predts,
+ MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@ -613,9 +616,9 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
- dh::ArgSort<false>(predts, d_sorted_idx);
+ common::ArgSort<false>(ctx, predts, d_sorted_idx);
- auto labels = info.labels.View(device);
+ auto labels = info.labels.View(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan();
auto get_weight = common::OptionalWeights{d_weights};
auto it = dh::MakeTransformIterator<Pair>(
@ -639,7 +642,7 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
};
double fp, tp, auc;
- std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache);
+ std::tie(fp, tp, auc) = GPUBinaryAUC(ctx, predts, info, d_sorted_idx, fn, cache);
return std::make_tuple(1.0, 1.0, auc);
}
@ -699,16 +702,17 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
}
template <typename Fn>
- std::pair<double, uint32_t>
- GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
- common::Span<uint32_t> d_group_ptr, DeviceOrd device,
+ std::pair<double, uint32_t> GPURankingPRAUCImpl(Context const *ctx,
+ common::Span<float const> predts,
+ MetaInfo const &info,
+ common::Span<uint32_t> d_group_ptr,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
/**
* Sorted idx
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
- auto labels = info.labels.View(device);
+ auto labels = info.labels.View(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan();
uint32_t n_groups = static_cast<uint32_t>(info.group_ptr_.size() - 1);
@ -739,7 +743,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
- dh::Iota(d_unique_idx);
+ dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
auto idx = d_sorted_idx[i];
@ -882,7 +886,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first);
};
- return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->Device(), cache, fn);
+ return GPURankingPRAUCImpl(ctx, predts, info, d_group_ptr, cache, fn);
}
} // namespace metric
} // namespace xgboost


@ -1,5 +1,5 @@
- /*!
- * Copyright 2021 by XGBoost Contributors
+ /**
+ * Copyright 2021-2023, XGBoost Contributors
*/
#ifndef XGBOOST_METRIC_AUC_H_
#define XGBOOST_METRIC_AUC_H_
@ -18,8 +18,7 @@
#include "xgboost/metric.h"
#include "xgboost/span.h"
- namespace xgboost {
- namespace metric {
+ namespace xgboost::metric {
/***********
* ROC AUC *
***********/
@ -29,8 +28,9 @@ XGBOOST_DEVICE inline double TrapezoidArea(double x0, double x1, double y0, doub
struct DeviceAUCCache;
- std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
- MetaInfo const &info, DeviceOrd,
+ std::tuple<double, double, double> GPUBinaryROCAUC(Context const *ctx,
+ common::Span<float const> predts,
+ MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache);
double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
@ -44,8 +44,9 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
/**********
* PR AUC *
**********/
- std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
- MetaInfo const &info, DeviceOrd,
+ std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
+ common::Span<float const> predts,
+ MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache);
double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
@ -111,6 +112,5 @@ struct PRAUCLabelInvalid {
inline void InvalidLabels() {
LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank.";
}
- } // namespace metric
- } // namespace xgboost
+ } // namespace xgboost::metric
#endif // XGBOOST_METRIC_AUC_H_
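The signature changes in this header trade a loose DeviceOrd for the full Context, which also carries the CUDA stream and thrust execution policies. A sketch of what a callee can now derive from the single argument, assuming a CUDA-initialised context:

    void UseContext(xgboost::Context const *ctx) {
      auto device = ctx->Device();             // replaces the DeviceOrd parameter
      auto stream = ctx->CUDACtx()->Stream();  // stream for kernels and async copies
      auto policy = ctx->CUDACtx()->CTP();     // caching thrust policy on that stream
      (void)device; (void)stream; (void)policy;
    }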


@ -13,9 +13,7 @@
#include "adaptive.h"
#include "xgboost/context.h"
- namespace xgboost {
- namespace obj {
- namespace detail {
+ namespace xgboost::obj::detail {
void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
dh::device_vector<size_t>* p_ridx, HostDeviceVector<size_t>* p_nptr,
HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
@ -28,7 +26,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
position.size_bytes(), cudaMemcpyDeviceToDevice, cuctx->Stream()));
p_ridx->resize(position.size());
- dh::Iota(dh::ToSpan(*p_ridx));
+ dh::Iota(dh::ToSpan(*p_ridx), cuctx->Stream());
// sort row index according to node index
thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
sorted_position.begin() + n_samples, p_ridx->begin());
@ -190,6 +188,4 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
});
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
}
- } // namespace detail
- } // namespace obj
- } // namespace xgboost
+ } // namespace xgboost::obj::detail


@ -72,7 +72,7 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
- dh::Iota(sorted_idx);
+ dh::Iota(sorted_idx, dh::DefaultStream());
auto data = this->SortInput(d_inputs.size(), shared_inputs.feature_values.size());
auto it = thrust::make_counting_iterator(0u);
auto d_feature_idx = dh::ToSpan(feature_idx_);


@ -248,8 +248,7 @@ class GlobalApproxUpdater : public TreeUpdater {
std::unique_ptr<GloablApproxBuilder> pimpl_;
// pointer to the last DMatrix, used for update prediction cache.
DMatrix *cached_{nullptr};
- std::shared_ptr<common::ColumnSampler> column_sampler_ =
- std::make_shared<common::ColumnSampler>();
+ std::shared_ptr<common::ColumnSampler> column_sampler_;
ObjInfo const *task_;
HistMakerTrainParam hist_param_;
@ -284,6 +283,9 @@ class GlobalApproxUpdater : public TreeUpdater {
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
CHECK(hist_param_.GetInitialised());
+ if (!column_sampler_) {
+ column_sampler_ = common::MakeColumnSampler(ctx_);
+ }
pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
column_sampler_, task_, &monitor_);


@ -225,9 +225,12 @@ class ColMaker: public TreeUpdater {
}
}
{
- column_sampler_.Init(ctx_, fmat.Info().num_col_,
- fmat.Info().feature_weights.ConstHostVector(), param_.colsample_bynode,
- param_.colsample_bylevel, param_.colsample_bytree);
+ if (!column_sampler_) {
+ column_sampler_ = common::MakeColumnSampler(ctx_);
+ }
+ column_sampler_->Init(
+ ctx_, fmat.Info().num_col_, fmat.Info().feature_weights.ConstHostVector(),
+ param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree);
}
{
// setup temp space for each thread
@ -467,7 +470,7 @@ class ColMaker: public TreeUpdater {
RegTree *p_tree) {
auto evaluator = tree_evaluator_.GetEvaluator();
- auto feat_set = column_sampler_.GetFeatureSet(depth);
+ auto feat_set = column_sampler_->GetFeatureSet(depth);
for (const auto &batch : p_fmat->GetBatches<SortedCSCPage>(ctx_)) {
this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat);
}
@ -586,7 +589,7 @@ class ColMaker: public TreeUpdater {
const ColMakerTrainParam& colmaker_train_param_;
// number of omp thread used during training
Context const* ctx_;
- common::ColumnSampler column_sampler_;
+ std::shared_ptr<common::ColumnSampler> column_sampler_;
// Instance Data: current node position in the tree of each instance
std::vector<int> position_;
// PerThread x PerTreeNode: statistics for per thread construction


@ -1,5 +1,5 @@
/**
- * Copyright 2017-2023 by XGBoost Contributors
+ * Copyright 2017-2023, XGBoost Contributors
* \file updater_quantile_hist.cc
* \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn, Egor Smirnov
@ -470,8 +470,7 @@ class HistUpdater {
class QuantileHistMaker : public TreeUpdater {
std::unique_ptr<HistUpdater> p_impl_{nullptr};
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
- std::shared_ptr<common::ColumnSampler> column_sampler_ =
- std::make_shared<common::ColumnSampler>();
+ std::shared_ptr<common::ColumnSampler> column_sampler_;
common::Monitor monitor_;
ObjInfo const *task_{nullptr};
HistMakerTrainParam hist_param_;
@ -495,6 +494,10 @@ class QuantileHistMaker : public TreeUpdater {
void Update(TrainParam const *param, linalg::Matrix<GradientPair> *gpair, DMatrix *p_fmat,
common::Span<HostDeviceVector<bst_node_t>> out_position,
const std::vector<RegTree *> &trees) override {
+ if (!column_sampler_) {
+ column_sampler_ = common::MakeColumnSampler(ctx_);
+ }
if (trees.front()->IsMultiTarget()) {
CHECK(hist_param_.GetInitialised());
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();


@ -57,13 +57,13 @@ TEST(Algorithm, GpuArgSort) {
auto ctx = MakeCUDACtx(0);
dh::device_vector<float> values(20);
- dh::Iota(dh::ToSpan(values)); // accending
+ dh::Iota(dh::ToSpan(values), ctx.CUDACtx()->Stream()); // accending
dh::device_vector<size_t> sorted_idx(20);
- dh::ArgSort<false>(dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
- ASSERT_TRUE(thrust::is_sorted(thrust::device, sorted_idx.begin(), sorted_idx.end(),
+ ArgSort<false>(&ctx, dh::ToSpan(values), dh::ToSpan(sorted_idx)); // sort to descending
+ ASSERT_TRUE(thrust::is_sorted(ctx.CUDACtx()->CTP(), sorted_idx.begin(), sorted_idx.end(),
thrust::greater<size_t>{}));
- dh::Iota(dh::ToSpan(values));
+ dh::Iota(dh::ToSpan(values), ctx.CUDACtx()->Stream());
dh::device_vector<size_t> groups(3);
groups[0] = 0;
groups[1] = 10;


@ -16,6 +16,7 @@
#include <vector> // for vector
#include "../../../include/xgboost/logging.h"
#include "../../../src/common/cuda_context.cuh"
+ #include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/hist_util.h"
@ -211,7 +212,7 @@ TEST(HistUtil, RemoveDuplicatedCategories) {
cuts_ptr.SetDevice(DeviceOrd::CUDA(0));
dh::device_vector<float> weight(n_samples * n_features, 0);
- dh::Iota(dh::ToSpan(weight));
+ dh::Iota(dh::ToSpan(weight), ctx.CUDACtx()->Stream());
dh::caching_device_vector<bst_row_t> columns_ptr(4);
for (std::size_t i = 0; i < columns_ptr.size(); ++i) {


@ -1,19 +1,20 @@
- #include <valarray>
+ /**
+ * Copyright 2018-2023, XGBoost Contributors
+ */
#include "../../../src/common/random.h"
#include "../helpers.h"
#include "gtest/gtest.h"
- #include "xgboost/context.h" // Context
+ #include "xgboost/context.h" // for Context
- namespace xgboost {
- namespace common {
- TEST(ColumnSampler, Test) {
- Context ctx;
+ namespace xgboost::common {
+ namespace {
+ void TestBasic(Context const* ctx) {
int n = 128;
- ColumnSampler cs;
+ ColumnSampler cs{1u};
std::vector<float> feature_weights;
// No node sampling
- cs.Init(&ctx, n, feature_weights, 1.0f, 0.5f, 0.5f);
+ cs.Init(ctx, n, feature_weights, 1.0f, 0.5f, 0.5f);
auto set0 = cs.GetFeatureSet(0);
ASSERT_EQ(set0->Size(), 32);
@ -26,7 +27,7 @@ TEST(ColumnSampler, Test) {
ASSERT_EQ(set2->Size(), 32);
// Node sampling
- cs.Init(&ctx, n, feature_weights, 0.5f, 1.0f, 0.5f);
+ cs.Init(ctx, n, feature_weights, 0.5f, 1.0f, 0.5f);
auto set3 = cs.GetFeatureSet(0);
ASSERT_EQ(set3->Size(), 32);
@ -36,21 +37,33 @@ TEST(ColumnSampler, Test) {
ASSERT_EQ(set4->Size(), 32);
// No level or node sampling, should be the same at different depth
- cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 0.5f);
- ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
- cs.GetFeatureSet(1)->HostVector());
+ cs.Init(ctx, n, feature_weights, 1.0f, 1.0f, 0.5f);
+ ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(), cs.GetFeatureSet(1)->HostVector());
- cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
+ cs.Init(ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
auto set5 = cs.GetFeatureSet(0);
ASSERT_EQ(set5->Size(), n);
- cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
+ cs.Init(ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
auto set6 = cs.GetFeatureSet(0);
ASSERT_EQ(set5->HostVector(), set6->HostVector());
// Should always be a minimum of one feature
- cs.Init(&ctx, n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
+ cs.Init(ctx, n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
}
+ } // namespace
+ TEST(ColumnSampler, Test) {
+ Context ctx;
+ TestBasic(&ctx);
+ }
+ #if defined(XGBOOST_USE_CUDA)
+ TEST(ColumnSampler, GPUTest) {
+ auto ctx = MakeCUDACtx(0);
+ TestBasic(&ctx);
+ }
+ #endif // defined(XGBOOST_USE_CUDA)
// Test if different threads using the same seed produce the same result
TEST(ColumnSampler, ThreadSynchronisation) {
@ -81,16 +94,16 @@ TEST(ColumnSampler, ThreadSynchronisation) {
ASSERT_TRUE(success);
}
TEST(ColumnSampler, WeightedSampling) {
auto test_basic = [](int first) {
Context ctx;
namespace {
void TestWeightedSampling(Context const* ctx) {
auto test_basic = [ctx](int first) {
std::vector<float> feature_weights(2);
feature_weights[0] = std::abs(first - 1.0f);
feature_weights[1] = first - 0.0f;
ColumnSampler cs{0};
- cs.Init(&ctx, 2, feature_weights, 1.0, 1.0, 0.5);
+ cs.Init(ctx, 2, feature_weights, 1.0, 1.0, 0.5);
auto feature_sets = cs.GetFeatureSet(0);
- auto const &h_feat_set = feature_sets->HostVector();
+ auto const& h_feat_set = feature_sets->HostVector();
ASSERT_EQ(h_feat_set.size(), 1);
ASSERT_EQ(h_feat_set[0], first - 0);
};
@ -104,8 +117,7 @@ TEST(ColumnSampler, WeightedSampling) {
SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
ColumnSampler cs{0};
- Context ctx;
- cs.Init(&ctx, kCols, feature_weights, 0.5f, 1.0f, 1.0f);
+ cs.Init(ctx, kCols, feature_weights, 0.5f, 1.0f, 1.0f);
std::vector<bst_feature_t> features(kCols);
std::iota(features.begin(), features.end(), 0);
std::vector<float> freq(kCols, 0);
@ -131,8 +143,22 @@ TEST(ColumnSampler, WeightedSampling) {
EXPECT_NEAR(freq[i], feature_weights[i], 1e-2);
}
}
+ } // namespace
- TEST(ColumnSampler, WeightedMultiSampling) {
+ TEST(ColumnSampler, WeightedSampling) {
+ Context ctx;
+ TestWeightedSampling(&ctx);
+ }
+ #if defined(XGBOOST_USE_CUDA)
+ TEST(ColumnSampler, GPUWeightedSampling) {
+ auto ctx = MakeCUDACtx(0);
+ TestWeightedSampling(&ctx);
+ }
+ #endif // defined(XGBOOST_USE_CUDA)
+ namespace {
+ void TestWeightedMultiSampling(Context const* ctx) {
size_t constexpr kCols = 32;
std::vector<float> feature_weights(kCols, 0);
for (size_t i = 0; i < feature_weights.size(); ++i) {
@ -140,13 +166,24 @@ TEST(ColumnSampler, WeightedMultiSampling) {
}
ColumnSampler cs{0};
float bytree{0.5}, bylevel{0.5}, bynode{0.5};
- Context ctx;
- cs.Init(&ctx, feature_weights.size(), feature_weights, bytree, bylevel, bynode);
+ cs.Init(ctx, feature_weights.size(), feature_weights, bytree, bylevel, bynode);
auto feature_set = cs.GetFeatureSet(0);
size_t n_sampled = kCols * bytree * bylevel * bynode;
ASSERT_EQ(feature_set->Size(), n_sampled);
feature_set = cs.GetFeatureSet(1);
ASSERT_EQ(feature_set->Size(), n_sampled);
}
- } // namespace common
- } // namespace xgboost
+ } // namespace
+ TEST(ColumnSampler, WeightedMultiSampling) {
+ Context ctx;
+ TestWeightedMultiSampling(&ctx);
+ }
+ #if defined(XGBOOST_USE_CUDA)
+ TEST(ColumnSampler, GPUWeightedMultiSampling) {
+ auto ctx = MakeCUDACtx(0);
+ TestWeightedMultiSampling(&ctx);
+ }
+ #endif // defined(XGBOOST_USE_CUDA)
+ } // namespace xgboost::common


@ -28,7 +28,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
Context ctx;
ctx.nthread = 4;
int static constexpr kRows = 8, kCols = 16;
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
TrainParam param;
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
@ -102,7 +102,7 @@ TEST(HistMultiEvaluator, Evaluate) {
TrainParam param;
param.Init(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
std::size_t n_samples = 3;
bst_feature_t n_features = 2;
@ -166,7 +166,7 @@ TEST(HistEvaluator, Apply) {
TrainParam param;
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
auto evaluator_ = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
CPUExpandEntry entry{0, 0};
@ -194,7 +194,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
Context ctx;
// check the evaluator is returning the optimal split
std::vector<FeatureType> ft{FeatureType::kCategorical};
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
HistEvaluator evaluator{&ctx, &param_, info_, sampler};
evaluator.InitRoot(GradStats{total_gpair_});
RegTree tree;
@ -224,7 +224,7 @@ auto CompareOneHotAndPartition(bool onehot) {
auto dmat =
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
auto evaluator = HistEvaluator{&ctx, &param, dmat->Info(), sampler};
std::vector<CPUExpandEntry> entries(1);
HistMakerTrainParam hist_param;
@ -271,7 +271,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
ASSERT_EQ(node_hist.size(), feature_histogram_.size());
std::copy(feature_histogram_.cbegin(), feature_histogram_.cend(), node_hist.begin());
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
MetaInfo info;
info.num_col_ = 1;
info.feature_types = {FeatureType::kCategorical};


@ -1,3 +1,6 @@
+ /**
+ * Copyright 2019-2023, XGBoost Contributors
+ */
#include <gtest/gtest.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
@ -9,9 +12,7 @@
#include "../../../src/tree/hist/evaluate_splits.h"
#include "../helpers.h"
- namespace xgboost {
- namespace tree {
+ namespace xgboost::tree {
TEST(CPUFeatureInteractionConstraint, Empty) {
TrainParam param;
param.UpdateAllowUnknown(Args{});
@ -77,7 +78,7 @@ TEST(CPUMonoConstraint, Basic) {
param.UpdateAllowUnknown(Args{{"monotone_constraints", str_mono}});
auto Xy = RandomDataGenerator{kRows, kCols, 0.0}.GenerateDMatrix(true);
- auto sampler = std::make_shared<common::ColumnSampler>();
+ auto sampler = std::make_shared<common::ColumnSampler>(1u);
HistEvaluator evalutor{&ctx, &param, Xy->Info(), sampler};
evalutor.InitRoot(GradStats{2.0, 2.0});
@ -90,5 +91,4 @@ TEST(CPUMonoConstraint, Basic) {
ASSERT_TRUE(evalutor.Evaluator().has_constraint);
}
- } // namespace tree
- } // namespace xgboost
+ } // namespace xgboost::tree