Specify the number of threads for parallel sort. (#8735)
* Specify the number of threads for parallel sort. - Pass context object into argsort. - Replace macros with inline functions.
This commit is contained in:
@@ -1,10 +1,31 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_ALGORITHM_H_
|
||||
#define XGBOOST_COMMON_ALGORITHM_H_
|
||||
#include <algorithm> // std::upper_bound
|
||||
#include <cinttypes> // std::size_t
|
||||
#include <algorithm> // upper_bound, stable_sort, sort, max
|
||||
#include <cinttypes> // size_t
|
||||
#include <functional> // less
|
||||
#include <iterator> // iterator_traits, distance
|
||||
#include <vector> // vector
|
||||
|
||||
#include "numeric.h" // Iota
|
||||
#include "xgboost/context.h" // Context
|
||||
|
||||
// clang with libstdc++ works as well
|
||||
#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && !defined(__APPLE__)
|
||||
#define GCC_HAS_PARALLEL 1
|
||||
#endif  // defined(__GNUC__) && !__sun && !__APPLE__
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
|
||||
#define MSVC_HAS_PARALLEL 1
|
||||
#endif // MSC
|
||||
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
#include <parallel/algorithm>
|
||||
#elif defined(MSVC_HAS_PARALLEL)
|
||||
#include <ppl.h>
|
||||
#endif  // defined(GCC_HAS_PARALLEL) / defined(MSVC_HAS_PARALLEL)
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@@ -13,6 +34,63 @@ auto SegmentId(It first, It last, Idx idx) {
|
||||
std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
|
||||
return segment_id;
|
||||
}
|
||||
|
||||
/**
 * \brief Stable-sort [begin, end) with `comp`, parallelising when the context
 *        requests more than one thread.
 *
 * On GCC-compatible toolchains this dispatches to libstdc++ parallel mode
 * with an explicit thread count taken from `ctx`.  MSVC's PPL has no general
 * parallel *stable* sort (only a radix variant, which needs an unsigned key
 * projection), so every non-GCC build falls back to the serial
 * std::stable_sort regardless of the thread count.
 */
template <typename Iter, typename Comp>
void StableSort(Context const *ctx, Iter begin, Iter end, Comp &&comp) {
  if (ctx->Threads() > 1) {
#if defined(GCC_HAS_PARALLEL)
    __gnu_parallel::stable_sort(begin, end, comp,
                                __gnu_parallel::default_parallel_tag(ctx->Threads()));
#else
    // the only stable sort is radix sort for msvc ppl.
    std::stable_sort(begin, end, comp);
#endif  // defined(GCC_HAS_PARALLEL)
  } else {
    std::stable_sort(begin, end, comp);
  }
}
|
||||
|
||||
template <typename Iter, typename Comp>
|
||||
void Sort(Context const *ctx, Iter begin, Iter end, Comp comp) {
|
||||
if (ctx->Threads() > 1) {
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
__gnu_parallel::sort(begin, end, comp, __gnu_parallel::default_parallel_tag(ctx->Threads()));
|
||||
#elif defined(MSVC_HAS_PARALLEL)
|
||||
auto n = std::distance(begin, end);
|
||||
// use chunk size as hint to number of threads. No local policy/scheduler input with the
|
||||
// concurrency module.
|
||||
std::size_t chunk_size = n / ctx->Threads();
|
||||
// 2048 is the default of msvc ppl as of v2022.
|
||||
chunk_size = std::max(chunk_size, static_cast<std::size_t>(2048));
|
||||
concurrency::parallel_sort(begin, end, comp, chunk_size);
|
||||
#else
|
||||
std::sort(begin, end, comp);
|
||||
#endif // GLIBC VERSION
|
||||
} else {
|
||||
std::sort(begin, end, comp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Idx, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type,
|
||||
typename Comp = std::less<V>>
|
||||
std::vector<Idx> ArgSort(Context const *ctx, Iter begin, Iter end, Comp comp = std::less<V>{}) {
|
||||
CHECK(ctx->IsCPU());
|
||||
auto n = std::distance(begin, end);
|
||||
std::vector<Idx> result(n);
|
||||
Iota(ctx, result.begin(), result.end(), 0);
|
||||
auto op = [&](Idx const &l, Idx const &r) { return comp(begin[l], begin[r]); };
|
||||
StableSort(ctx, result.begin(), result.end(), op);
|
||||
return result;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
#undef GCC_HAS_PARALLEL
|
||||
#endif // defined(GCC_HAS_PARALLEL)
|
||||
|
||||
#if defined(MSVC_HAS_PARALLEL)
|
||||
#undef MSVC_HAS_PARALLEL
|
||||
#endif // defined(MSVC_HAS_PARALLEL)
|
||||
|
||||
#endif // XGBOOST_COMMON_ALGORITHM_H_
|
||||
|
||||
@@ -188,17 +188,6 @@ inline void SetDevice(std::int32_t device) {
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
 * \brief Compute the permutation of indices that stable-sorts `array`.
 *
 * `array` is only read; the returned vector contains the indices
 * [0, array.size()) reordered so that array[result[i]] follows `comp`.
 */
template <typename Idx, typename Container,
          typename V = typename Container::value_type,
          typename Comp = std::less<V>>
std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
  std::vector<Idx> idx(array.size());
  std::iota(idx.begin(), idx.end(), 0);
  // Order indices by the elements they refer to.
  auto order = [&array, comp](Idx const &lhs, Idx const &rhs) {
    return comp(array[lhs], array[rhs]);
  };
  XGBOOST_PARALLEL_STABLE_SORT(idx.begin(), idx.end(), order);
  return idx;
}
|
||||
|
||||
/**
|
||||
* Last index of a group in a CSR style of index pointer.
|
||||
*/
|
||||
|
||||
@@ -24,8 +24,9 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
||||
for (size_t i = 0; i < h_features.size(); ++i) {
|
||||
weights[i] = feature_weights_[h_features[i]];
|
||||
}
|
||||
CHECK(ctx_);
|
||||
new_features.HostVector() =
|
||||
WeightedSamplingWithoutReplacement(p_features->HostVector(), weights, n);
|
||||
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
|
||||
} else {
|
||||
new_features.Resize(features.size());
|
||||
std::copy(features.begin(), features.end(), new_features.HostVector().begin());
|
||||
|
||||
@@ -20,7 +20,9 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "algorithm.h" // ArgSort
|
||||
#include "common.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -87,8 +89,8 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
* https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
|
||||
*/
|
||||
template <typename T>
|
||||
std::vector<T> WeightedSamplingWithoutReplacement(
|
||||
std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
|
||||
std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vector<T> const& array,
|
||||
std::vector<float> const& weights, size_t n) {
|
||||
// ES sampling.
|
||||
CHECK_EQ(array.size(), weights.size());
|
||||
std::vector<float> keys(weights.size());
|
||||
@@ -100,7 +102,7 @@ std::vector<T> WeightedSamplingWithoutReplacement(
|
||||
auto k = std::log(u) / w;
|
||||
keys[i] = k;
|
||||
}
|
||||
auto ind = ArgSort<size_t>(Span<float>{keys}, std::greater<>{});
|
||||
auto ind = ArgSort<std::size_t>(ctx, keys.data(), keys.data() + keys.size(), std::greater<>{});
|
||||
ind.resize(n);
|
||||
|
||||
std::vector<T> results(ind.size());
|
||||
@@ -126,6 +128,7 @@ class ColumnSampler {
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
GlobalRandomEngine rng_;
|
||||
Context const* ctx_;
|
||||
|
||||
public:
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
|
||||
@@ -157,12 +160,13 @@ class ColumnSampler {
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
|
||||
float colsample_bylevel, float colsample_bytree) {
|
||||
void Init(Context const* ctx, int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel, float colsample_bytree) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
ctx_ = ctx;
|
||||
|
||||
if (feature_set_tree_ == nullptr) {
|
||||
feature_set_tree_ = std::make_shared<HostDeviceVector<bst_feature_t>>();
|
||||
|
||||
Reference in New Issue
Block a user