Implement column sampler in CUDA. (#9785)
- CUDA implementation. - Extract the broadcasting logic, we will need the context parameter after revamping the collective implementation. - Some changes to the event loop for fixing a deadlock in CI. - Move argsort into algorithms.cuh, add support for cuda stream.
This commit is contained in:
@@ -23,8 +23,7 @@
|
||||
#include "xgboost/logging.h" // CHECK
|
||||
#include "xgboost/span.h" // Span,byte
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
namespace detail {
|
||||
// Wrapper around cub sort to define is_descending
|
||||
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
|
||||
@@ -127,13 +126,14 @@ inline void SegmentedSortKeys(Context const *ctx, Span<V const> group_ptr,
|
||||
template <bool accending, bool per_seg_index, typename U, typename V, typename IdxT>
|
||||
void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
|
||||
Span<IdxT> sorted_idx) {
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
CHECK_GE(group_ptr.size(), 1ul);
|
||||
std::size_t n_groups = group_ptr.size() - 1;
|
||||
std::size_t bytes = 0;
|
||||
if (per_seg_index) {
|
||||
SegmentedSequence(ctx, group_ptr, sorted_idx);
|
||||
} else {
|
||||
dh::Iota(sorted_idx);
|
||||
dh::Iota(sorted_idx, cuctx->Stream());
|
||||
}
|
||||
dh::TemporaryArray<std::remove_const_t<U>> values_out(values.size());
|
||||
dh::TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());
|
||||
@@ -141,15 +141,16 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
|
||||
detail::DeviceSegmentedRadixSortPair<!accending>(
|
||||
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
|
||||
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
|
||||
group_ptr.data() + 1, ctx->CUDACtx()->Stream());
|
||||
group_ptr.data() + 1, cuctx->Stream());
|
||||
dh::TemporaryArray<byte> temp_storage(bytes);
|
||||
detail::DeviceSegmentedRadixSortPair<!accending>(
|
||||
temp_storage.data().get(), bytes, values.data(), values_out.data().get(), sorted_idx.data(),
|
||||
sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
|
||||
group_ptr.data() + 1, ctx->CUDACtx()->Stream());
|
||||
group_ptr.data() + 1, cuctx->Stream());
|
||||
|
||||
dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
||||
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
|
||||
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
|
||||
cuctx->Stream()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -159,11 +160,12 @@ void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
|
||||
template <typename SegIt, typename ValIt>
|
||||
void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
|
||||
ValIt val_end, dh::device_vector<std::size_t> *p_sorted_idx) {
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
using Tup = thrust::tuple<std::int32_t, float>;
|
||||
auto &sorted_idx = *p_sorted_idx;
|
||||
std::size_t n = std::distance(val_begin, val_end);
|
||||
sorted_idx.resize(n);
|
||||
dh::Iota(dh::ToSpan(sorted_idx));
|
||||
dh::Iota(dh::ToSpan(sorted_idx), cuctx->Stream());
|
||||
dh::device_vector<Tup> keys(sorted_idx.size());
|
||||
auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) -> Tup {
|
||||
@@ -177,7 +179,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
|
||||
return thrust::make_tuple(seg_idx, residue);
|
||||
});
|
||||
thrust::copy(ctx->CUDACtx()->CTP(), key_it, key_it + keys.size(), keys.begin());
|
||||
thrust::stable_sort_by_key(ctx->CUDACtx()->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
|
||||
thrust::stable_sort_by_key(cuctx->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
|
||||
[=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
|
||||
if (thrust::get<0>(l) != thrust::get<0>(r)) {
|
||||
return thrust::get<0>(l) < thrust::get<0>(r); // segment index
|
||||
@@ -185,6 +187,75 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
|
||||
return thrust::get<1>(l) < thrust::get<1>(r); // residue
|
||||
});
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
/**
 * @brief Argsort on device: fill `sorted_idx` with the permutation that orders `keys`.
 *
 * The index initialisation, both CUB radix-sort dispatches and the final copy are all
 * enqueued on the stream of the CUDA context owned by `ctx`.
 *
 * @tparam accending Sort in ascending order when true, in descending order otherwise.
 * @tparam IdxT      Element type of the output index.
 * @tparam U         Element type of the keys.
 *
 * @param ctx        Context that provides the CUDA stream.
 * @param keys       Keys to be ordered. The sort is dispatched with `is_overwrite_okay`
 *                   set to false.
 * @param sorted_idx Output span that receives the sorted index. `keys` must have at least
 *                   as many elements as this span.
 */
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
             xgboost::common::Span<IdxT> sorted_idx) {
  using KeyT = typename decltype(keys)::value_type;
  using ValueT = std::remove_const_t<IdxT>;
  // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
  using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;

  // The sort reads `sorted_idx.size()` keys, guard against reading past the end of `keys`.
  CHECK_GE(keys.size(), sorted_idx.size());
  CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());

  auto cuctx = ctx->CUDACtx();
  cudaStream_t stream = cuctx->Stream();

  // Start from the identity permutation, the radix sort then moves the indices with the keys.
  dh::Iota(sorted_idx, stream);

  dh::TemporaryArray<KeyT> out(keys.size());
  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()), out.data().get());
  dh::TemporaryArray<ValueT> sorted_idx_out(sorted_idx.size());
  cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
                                     sorted_idx_out.data().get());

  std::size_t bytes = 0;
  // One dispatch routine shared by the size query and the actual sort. The direction is a
  // template argument of CUB (IS_DESCENDING), hence `!accending`.
  auto dispatch = [&](void *d_temp_storage) {
#if THRUST_MAJOR_VERSION >= 2
    dh::safe_cuda((cub::DispatchRadixSort<!accending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        stream)));
#else
    // Older CUB takes an additional trailing `debug_synchronous` flag.
    dh::safe_cuda((cub::DispatchRadixSort<!accending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        stream, false)));
#endif
  };

  // First call with a null buffer only reports the required temporary storage in `bytes`.
  dispatch(nullptr);
  dh::TemporaryArray<char> storage(bytes);
  // Second call performs the sort.
  dispatch(storage.data().get());

  dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice, stream));
}
|
||||
} // namespace xgboost::common
|
||||
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
|
||||
|
||||
@@ -313,8 +313,8 @@ inline void LaunchN(size_t n, L lambda) {
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
void Iota(Container array) {
|
||||
LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
|
||||
void Iota(Container array, cudaStream_t stream) {
|
||||
LaunchN(array.size(), stream, [=] __device__(size_t i) { array[i] = i; });
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
@@ -597,6 +597,16 @@ class DoubleBuffer {
|
||||
T *Other() { return buff.Alternate(); }
|
||||
};
|
||||
|
||||
/**
 * @brief Make sure `buffer` lives on the context's device and holds at least `n` elements,
 *        then hand out a device view of its first `n` elements.
 *
 * The buffer is only ever grown here: when it already has `n` or more elements its size is
 * left unchanged.
 *
 * @param ctx    Context that supplies the device ordinal.
 * @param buffer Buffer to (possibly) grow.
 * @param n      Number of elements requested.
 *
 * @return Device span over elements [0, n) of `buffer`.
 */
template <typename T>
xgboost::common::Span<T> LazyResize(xgboost::Context const *ctx,
                                    xgboost::HostDeviceVector<T> *buffer, std::size_t n) {
  buffer->SetDevice(ctx->Device());

  bool const too_small = buffer->Size() < n;
  if (too_small) {
    buffer->Resize(n);
  }

  auto d_whole = buffer->DeviceSpan();
  return d_whole.subspan(0, n);
}
|
||||
|
||||
/**
|
||||
* \brief Copies device span to std::vector.
|
||||
*
|
||||
@@ -1060,74 +1070,6 @@ void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items)
|
||||
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
|
||||
}
|
||||
|
||||
template <bool accending, typename IdxT, typename U>
|
||||
void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
|
||||
size_t bytes = 0;
|
||||
Iota(sorted_idx);
|
||||
|
||||
using KeyT = typename decltype(keys)::value_type;
|
||||
using ValueT = std::remove_const_t<IdxT>;
|
||||
|
||||
TemporaryArray<KeyT> out(keys.size());
|
||||
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
|
||||
out.data().get());
|
||||
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
|
||||
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
|
||||
sorted_idx_out.data().get());
|
||||
|
||||
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
|
||||
using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
|
||||
CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
|
||||
if (accending) {
|
||||
void *d_temp_storage = nullptr;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
} else {
|
||||
void *d_temp_storage = nullptr;
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
#if THRUST_MAJOR_VERSION >= 2
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr)));
|
||||
#else
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
#endif
|
||||
}
|
||||
|
||||
safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
|
||||
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
class CUDAStreamView;
|
||||
|
||||
class CUDAEvent {
|
||||
|
||||
@@ -1,32 +1,50 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
* \file random.cc
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost Contributors
|
||||
*/
|
||||
#include "random.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
#include <algorithm> // for sort, max, copy
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
|
||||
namespace xgboost::common {
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
|
||||
if (colsample == 1.0f) {
|
||||
return p_features;
|
||||
}
|
||||
|
||||
int n = std::max(1, static_cast<int>(colsample * p_features->Size()));
|
||||
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
|
||||
|
||||
if (ctx_->IsCUDA()) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
cuda_impl::SampleFeature(ctx_, n, p_features, p_new_features, this->feature_weights_,
|
||||
&this->weight_buffer_, &this->idx_buffer_, &rng_);
|
||||
return p_new_features;
|
||||
#else
|
||||
AssertGPUSupport();
|
||||
return nullptr;
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
const auto &features = p_features->HostVector();
|
||||
CHECK_GT(features.size(), 0);
|
||||
|
||||
int n = std::max(1, static_cast<int>(colsample * features.size()));
|
||||
auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
|
||||
auto &new_features = *p_new_features;
|
||||
|
||||
if (feature_weights_.size() != 0) {
|
||||
if (!feature_weights_.Empty()) {
|
||||
auto const &h_features = p_features->HostVector();
|
||||
std::vector<float> weights(h_features.size());
|
||||
auto const &h_feature_weight = feature_weights_.ConstHostVector();
|
||||
auto &weight = this->weight_buffer_.HostVector();
|
||||
weight.resize(h_features.size());
|
||||
for (size_t i = 0; i < h_features.size(); ++i) {
|
||||
weights[i] = feature_weights_[h_features[i]];
|
||||
weight[i] = h_feature_weight[h_features[i]];
|
||||
}
|
||||
CHECK(ctx_);
|
||||
new_features.HostVector() =
|
||||
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
|
||||
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weight, n);
|
||||
} else {
|
||||
new_features.Resize(features.size());
|
||||
std::copy(features.begin(), features.end(), new_features.HostVector().begin());
|
||||
@@ -36,5 +54,4 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
||||
std::sort(new_features.HostVector().begin(), new_features.HostVector().end());
|
||||
return p_new_features;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
106
src/common/random.cu
Normal file
106
src/common/random.cu
Normal file
@@ -0,0 +1,106 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <thrust/shuffle.h> // for shuffle
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "algorithm.cuh" // for ArgSort
|
||||
#include "cuda_context.cuh" // for CUDAContext
|
||||
#include "device_helpers.cuh"
|
||||
#include "random.h"
|
||||
#include "xgboost/base.h" // for bst_feature_t
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
|
||||
namespace xgboost::common::cuda_impl {
|
||||
/**
 * @brief GPU implementation for sampling without replacement, see the CPU version for
 *        references.
 *
 * For every element i a key `log(u_i) / max(weights[i], kRtEps)` is generated, where `u_i`
 * is a uniform random number. The keys are argsorted in descending order and the first
 * `results.size()` elements of `array`, taken in that order, are written to `results`.
 *
 * @param ctx        Context providing the CUDA stream / thrust policy.
 * @param array      Candidates to sample from.
 * @param weights    Weight of each candidate, must have the same size as `array`.
 * @param results    Output, must not be larger than `array`.
 * @param sorted_idx Scratch buffer for the sorted index, grown on demand.
 * @param grng       Host random engine, one value is drawn from it to seed the device RNG.
 */
void WeightedSamplingWithoutReplacement(Context const *ctx, common::Span<bst_feature_t const> array,
                                        common::Span<float const> weights,
                                        common::Span<bst_feature_t> results,
                                        HostDeviceVector<bst_feature_t> *sorted_idx,
                                        GlobalRandomEngine *grng) {
  CUDAContext const *cuctx = ctx->CUDACtx();
  CHECK_EQ(array.size(), weights.size());

  // Sampling keys
  dh::caching_device_vector<float> keys(weights.size());
  auto d_keys = dh::ToSpan(keys);

  auto seed = (*grng)();
  constexpr auto kEps = kRtEps;  // avoid CUDA compilation error

  // Every element seeds its own engine with the shared seed and skips ahead by its index.
  auto make_key = [=] XGBOOST_DEVICE(std::size_t i) {
    thrust::default_random_engine rng;
    rng.seed(seed);
    rng.discard(i);
    thrust::uniform_real_distribution<float> dist;

    auto w = std::max(weights[i], kEps);
    auto u = dist(rng);
    d_keys[i] = std::log(u) / w;
  };
  thrust::for_each_n(cuctx->CTP(), thrust::make_counting_iterator(0ul), array.size(), make_key);

  // Allocate buffer for sorted index.
  auto d_idx = dh::LazyResize(ctx, sorted_idx, keys.size());

  // Descending order.
  ArgSort<false>(ctx, d_keys, d_idx);

  // |array| == |weights| == |keys| == |sorted_idx| >= |results|
  for (auto size : {array.size(), weights.size(), keys.size()}) {
    CHECK_EQ(size, d_idx.size());
  }
  CHECK_GE(array.size(), results.size());

  // Filter the result according to sorted index.
  auto permuted = thrust::make_permutation_iterator(dh::tbegin(array), dh::tbegin(d_idx));
  thrust::copy_n(cuctx->CTP(), permuted, results.size(), dh::tbegin(results));
}
|
||||
|
||||
/**
 * @brief Sample `n_features` feature indices from `p_features` on the device and store
 *        them, sorted in ascending order, in `p_new_features`.
 *
 * Without feature weights the input is copied, shuffled and truncated to `n_features`.
 * With feature weights the weights of the candidate features are gathered into
 * `weight_buffer` and weighted sampling without replacement is used instead.
 *
 * @param ctx             Context providing the device and the thrust policy.
 * @param n_features      Number of features to sample, at most `p_features->Size()`.
 * @param p_features      Candidate feature indices.
 * @param p_new_features  Output, resized to `n_features`.
 * @param feature_weights Weight for each feature, empty for unweighted sampling.
 * @param weight_buffer   Scratch buffer for the gathered weights (weighted sampling only).
 * @param idx_buffer      Scratch buffer for the sorted index (weighted sampling only).
 * @param grng            Host random engine used to seed the device side sampling.
 */
void SampleFeature(Context const *ctx, bst_feature_t n_features,
                   std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
                   std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
                   HostDeviceVector<float> const &feature_weights,
                   HostDeviceVector<float> *weight_buffer,
                   HostDeviceVector<bst_feature_t> *idx_buffer, GlobalRandomEngine *grng) {
  CUDAContext const *cuctx = ctx->CUDACtx();
  HostDeviceVector<bst_feature_t> &sampled = *p_new_features;
  sampled.SetDevice(ctx->Device());
  p_features->SetDevice(ctx->Device());
  CHECK_LE(n_features, p_features->Size());

  bool const unweighted = feature_weights.Empty();
  if (unweighted) {
    // Shuffle a full copy of the candidates, then keep the leading `n_features` of them.
    sampled.Resize(p_features->Size());
    sampled.Copy(*p_features);
    auto d_shuffled = sampled.DeviceSpan();
    thrust::default_random_engine rng;
    rng.seed((*grng)());
    thrust::shuffle(cuctx->CTP(), dh::tbegin(d_shuffled), dh::tend(d_shuffled), rng);
    sampled.Resize(n_features);
  } else {
    CHECK_LE(p_features->Size(), feature_weights.Size());
    idx_buffer->SetDevice(ctx->Device());
    feature_weights.SetDevice(ctx->Device());

    auto d_candidates = p_features->DeviceSpan();
    auto d_weights = dh::LazyResize(ctx, weight_buffer, d_candidates.size());
    // Filter weights according to the existing feature index.
    auto d_all_weights = feature_weights.ConstDeviceSpan();
    auto gather_it = thrust::make_permutation_iterator(dh::tcbegin(d_all_weights),
                                                       dh::tcbegin(d_candidates));
    thrust::copy_n(cuctx->CTP(), gather_it, d_candidates.size(), dh::tbegin(d_weights));

    sampled.Resize(n_features);
    WeightedSamplingWithoutReplacement(ctx, d_candidates, d_weights, sampled.DeviceSpan(),
                                       idx_buffer, grng);
  }

  // Both paths return the selected features in ascending order.
  auto d_sampled = sampled.DeviceSpan();
  thrust::sort(cuctx->CTP(), dh::tbegin(d_sampled), dh::tend(d_sampled));
}
|
||||
|
||||
/**
 * @brief Fill the device storage of `p_features` with the sequence 0, 1, 2, ...
 *
 * The size of the vector is not changed here.
 *
 * @param ctx        Context providing the thrust execution policy.
 * @param p_features Feature index vector to initialise.
 */
void InitFeatureSet(Context const *ctx,
                    std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features) {
  CUDAContext const *cuctx = ctx->CUDACtx();
  auto d_out = p_features->DeviceSpan();
  auto first = dh::tbegin(d_out);
  auto last = dh::tend(d_out);
  thrust::sequence(cuctx->CTP(), first, last, 0);
}
|
||||
} // namespace xgboost::common::cuda_impl
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2020 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2020, XGBoost Contributors
|
||||
* \file random.h
|
||||
* \brief Utility related to random.
|
||||
* \author Tianqi Chen
|
||||
@@ -25,8 +25,7 @@
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
/*!
|
||||
* \brief Define mt19937 as default type Random Engine.
|
||||
*/
|
||||
@@ -113,6 +112,18 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vecto
|
||||
return results;
|
||||
}
|
||||
|
||||
namespace cuda_impl {
|
||||
void SampleFeature(Context const* ctx, bst_feature_t n_features,
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features,
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_new_features,
|
||||
HostDeviceVector<float> const& feature_weights,
|
||||
HostDeviceVector<float>* weight_buffer,
|
||||
HostDeviceVector<bst_feature_t>* idx_buffer, GlobalRandomEngine* grng);
|
||||
|
||||
void InitFeatureSet(Context const* ctx,
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features);
|
||||
} // namespace cuda_impl
|
||||
|
||||
/**
|
||||
* \class ColumnSampler
|
||||
*
|
||||
@@ -123,46 +134,37 @@ std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vecto
|
||||
class ColumnSampler {
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
|
||||
std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
|
||||
std::vector<float> feature_weights_;
|
||||
HostDeviceVector<float> feature_weights_;
|
||||
float colsample_bylevel_{1.0f};
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
GlobalRandomEngine rng_;
|
||||
Context const* ctx_;
|
||||
|
||||
// Used for weighted sampling.
|
||||
HostDeviceVector<bst_feature_t> idx_buffer_;
|
||||
HostDeviceVector<float> weight_buffer_;
|
||||
|
||||
public:
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample);
|
||||
/**
|
||||
* \brief Column sampler constructor.
|
||||
* \note This constructor manually sets the rng seed
|
||||
* @brief Column sampler constructor.
|
||||
* @note This constructor manually sets the rng seed
|
||||
*/
|
||||
explicit ColumnSampler(uint32_t seed) {
|
||||
rng_.seed(seed);
|
||||
}
|
||||
explicit ColumnSampler(std::uint32_t seed) { rng_.seed(seed); }
|
||||
|
||||
/**
|
||||
* \brief Column sampler constructor.
|
||||
* \note This constructor synchronizes the RNG seed across processes.
|
||||
*/
|
||||
ColumnSampler() {
|
||||
uint32_t seed = common::GlobalRandom()();
|
||||
collective::Broadcast(&seed, sizeof(seed), 0);
|
||||
rng_.seed(seed);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Initialise this object before use.
|
||||
* @brief Initialise this object before use.
|
||||
*
|
||||
* \param num_col
|
||||
* \param colsample_bynode
|
||||
* \param colsample_bylevel
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
* @param num_col
|
||||
* @param colsample_bynode Sampling rate for node.
|
||||
* @param colsample_bylevel Sampling rate for tree level.
|
||||
* @param colsample_bytree Sampling rate for tree.
|
||||
*/
|
||||
void Init(Context const* ctx, int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel, float colsample_bytree) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
feature_weights_.HostVector() = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
@@ -173,8 +175,17 @@ class ColumnSampler {
|
||||
}
|
||||
Reset();
|
||||
|
||||
feature_set_tree_->SetDevice(ctx->Device());
|
||||
feature_set_tree_->Resize(num_col);
|
||||
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
|
||||
if (ctx->IsCPU()) {
|
||||
std::iota(feature_set_tree_->HostVector().begin(), feature_set_tree_->HostVector().end(), 0);
|
||||
} else {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
cuda_impl::InitFeatureSet(ctx, feature_set_tree_);
|
||||
#else
|
||||
AssertGPUSupport();
|
||||
#endif
|
||||
}
|
||||
|
||||
feature_set_tree_ = ColSample(feature_set_tree_, colsample_bytree_);
|
||||
}
|
||||
@@ -216,6 +227,11 @@ class ColumnSampler {
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
/**
 * @brief Create a column sampler whose RNG seed is synchronized across processes.
 *
 * The seed is drawn from the global random engine (`common::GlobalRandom()`), so it follows
 * the seed configured for that engine. It is then broadcast with root 0 before being handed
 * to the sampler. Value-initialising a fresh `GlobalRandomEngine` here instead would produce
 * the same constant seed on every call.
 *
 * The context parameter is currently unused.
 *
 * @return Shared pointer to the newly created `ColumnSampler`.
 */
inline auto MakeColumnSampler(Context const*) {
  std::uint32_t seed = common::GlobalRandom()();
  collective::Broadcast(&seed, sizeof(seed), 0);
  return std::make_shared<common::ColumnSampler>(seed);
}
|
||||
} // namespace xgboost::common
|
||||
#endif // XGBOOST_COMMON_RANDOM_H_
|
||||
|
||||
Reference in New Issue
Block a user