Specify the number of threads for parallel sort. (#8735)

* Specify the number of threads for parallel sort.

- Pass context object into argsort.
- Replace macros with inline functions.
This commit is contained in:
Jiaming Yuan 2023-02-16 00:20:19 +08:00 committed by GitHub
parent c7c485d052
commit 282b1729da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 254 additions and 143 deletions

View File

@ -48,21 +48,6 @@
#define XGBOOST_ALIGNAS(X)
#endif // defined(__GNUC__) && ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ > 4)
#if defined(__GNUC__) && ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ > 4) && \
!defined(__CUDACC__) && !defined(__sun) && !defined(sun)
#include <parallel/algorithm>
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) \
__gnu_parallel::stable_sort((X), (Y), (Z))
#elif defined(_MSC_VER) && (!__INTEL_COMPILER)
#include <ppl.h>
#define XGBOOST_PARALLEL_SORT(X, Y, Z) concurrency::parallel_sort((X), (Y), (Z))
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
#else
#define XGBOOST_PARALLEL_SORT(X, Y, Z) std::sort((X), (Y), (Z))
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
#endif // GLIBC VERSION
#if defined(__GNUC__)
#define XGBOOST_EXPECT(cond, ret) __builtin_expect((cond), (ret))
#else

View File

@ -124,18 +124,7 @@ class MetaInfo {
return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f;
}
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
inline const std::vector<size_t>& LabelAbsSort() const {
if (label_order_cache_.size() == labels.Size()) {
return label_order_cache_;
}
label_order_cache_.resize(labels.Size());
std::iota(label_order_cache_.begin(), label_order_cache_.end(), 0);
const auto& l = labels.Data()->HostVector();
XGBOOST_PARALLEL_STABLE_SORT(label_order_cache_.begin(), label_order_cache_.end(),
[&l](size_t i1, size_t i2) {return std::abs(l[i1]) < std::abs(l[i2]);});
return label_order_cache_;
}
const std::vector<size_t>& LabelAbsSort(Context const* ctx) const;
/*! \brief clear all the information */
void Clear();
/*!

View File

@ -23,6 +23,10 @@
#include <utility>
#include <vector>
#if defined(_MSC_VER)
#include <intrin.h>
#endif // defined(_MSC_VER)
// decouple it from xgboost.
#ifndef LINALG_HD
#if defined(__CUDA__) || defined(__NVCC__)

View File

@ -1,10 +1,31 @@
/*!
* Copyright 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_ALGORITHM_H_
#define XGBOOST_COMMON_ALGORITHM_H_
#include <algorithm> // std::upper_bound
#include <cinttypes> // std::size_t
#include <algorithm> // upper_bound, stable_sort, sort, max
#include <cinttypes> // size_t
#include <functional> // less
#include <iterator> // iterator_traits, distance
#include <vector> // vector
#include "numeric.h" // Iota
#include "xgboost/context.h" // Context
// clang with libstdc++ works as well
#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && !defined(__APPLE__)
#define GCC_HAS_PARALLEL 1
#endif // GLIC_VERSION
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
#define MSVC_HAS_PARALLEL 1
#endif // MSC
#if defined(GCC_HAS_PARALLEL)
#include <parallel/algorithm>
#elif defined(MSVC_HAS_PARALLEL)
#include <ppl.h>
#endif // GLIBC VERSION
namespace xgboost {
namespace common {
@ -13,6 +34,63 @@ auto SegmentId(It first, It last, Idx idx) {
std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
return segment_id;
}
template <typename Iter, typename Comp>
void StableSort(Context const *ctx, Iter begin, Iter end, Comp &&comp) {
if (ctx->Threads() > 1) {
#if defined(GCC_HAS_PARALLEL)
__gnu_parallel::stable_sort(begin, end, comp,
__gnu_parallel::default_parallel_tag(ctx->Threads()));
#else
// the only stable sort is radix sort for msvc ppl.
std::stable_sort(begin, end, comp);
#endif // GLIBC VERSION
} else {
std::stable_sort(begin, end, comp);
}
}
template <typename Iter, typename Comp>
void Sort(Context const *ctx, Iter begin, Iter end, Comp comp) {
if (ctx->Threads() > 1) {
#if defined(GCC_HAS_PARALLEL)
__gnu_parallel::sort(begin, end, comp, __gnu_parallel::default_parallel_tag(ctx->Threads()));
#elif defined(MSVC_HAS_PARALLEL)
auto n = std::distance(begin, end);
// use chunk size as hint to number of threads. No local policy/scheduler input with the
// concurrency module.
std::size_t chunk_size = n / ctx->Threads();
// 2048 is the default of msvc ppl as of v2022.
chunk_size = std::max(chunk_size, static_cast<std::size_t>(2048));
concurrency::parallel_sort(begin, end, comp, chunk_size);
#else
std::sort(begin, end, comp);
#endif // GLIBC VERSION
} else {
std::sort(begin, end, comp);
}
}
template <typename Idx, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type,
typename Comp = std::less<V>>
std::vector<Idx> ArgSort(Context const *ctx, Iter begin, Iter end, Comp comp = std::less<V>{}) {
CHECK(ctx->IsCPU());
auto n = std::distance(begin, end);
std::vector<Idx> result(n);
Iota(ctx, result.begin(), result.end(), 0);
auto op = [&](Idx const &l, Idx const &r) { return comp(begin[l], begin[r]); };
StableSort(ctx, result.begin(), result.end(), op);
return result;
}
} // namespace common
} // namespace xgboost
#if defined(GCC_HAS_PARALLEL)
#undef GCC_HAS_PARALLEL
#endif // defined(GCC_HAS_PARALLEL)
#if defined(MSVC_HAS_PARALLEL)
#undef MSVC_HAS_PARALLEL
#endif // defined(MSVC_HAS_PARALLEL)
#endif // XGBOOST_COMMON_ALGORITHM_H_

View File

@ -188,17 +188,6 @@ inline void SetDevice(std::int32_t device) {
}
#endif
template <typename Idx, typename Container,
typename V = typename Container::value_type,
typename Comp = std::less<V>>
std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
std::vector<Idx> result(array.size());
std::iota(result.begin(), result.end(), 0);
auto op = [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); };
XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
return result;
}
/**
* Last index of a group in a CSR style of index pointer.
*/

View File

@ -24,8 +24,9 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
for (size_t i = 0; i < h_features.size(); ++i) {
weights[i] = feature_weights_[h_features[i]];
}
CHECK(ctx_);
new_features.HostVector() =
WeightedSamplingWithoutReplacement(p_features->HostVector(), weights, n);
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
} else {
new_features.Resize(features.size());
std::copy(features.begin(), features.end(), new_features.HostVector().begin());

View File

@ -20,7 +20,9 @@
#include <vector>
#include "../collective/communicator-inl.h"
#include "algorithm.h" // ArgSort
#include "common.h"
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h"
namespace xgboost {
@ -87,8 +89,8 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
* https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
*/
template <typename T>
std::vector<T> WeightedSamplingWithoutReplacement(
std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vector<T> const& array,
std::vector<float> const& weights, size_t n) {
// ES sampling.
CHECK_EQ(array.size(), weights.size());
std::vector<float> keys(weights.size());
@ -100,7 +102,7 @@ std::vector<T> WeightedSamplingWithoutReplacement(
auto k = std::log(u) / w;
keys[i] = k;
}
auto ind = ArgSort<size_t>(Span<float>{keys}, std::greater<>{});
auto ind = ArgSort<std::size_t>(ctx, keys.data(), keys.data() + keys.size(), std::greater<>{});
ind.resize(n);
std::vector<T> results(ind.size());
@ -126,6 +128,7 @@ class ColumnSampler {
float colsample_bytree_{1.0f};
float colsample_bynode_{1.0f};
GlobalRandomEngine rng_;
Context const* ctx_;
public:
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
@ -157,12 +160,13 @@ class ColumnSampler {
* \param colsample_bytree
* \param skip_index_0 (Optional) True to skip index 0.
*/
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
float colsample_bylevel, float colsample_bytree) {
void Init(Context const* ctx, int64_t num_col, std::vector<float> feature_weights,
float colsample_bynode, float colsample_bylevel, float colsample_bytree) {
feature_weights_ = std::move(feature_weights);
colsample_bylevel_ = colsample_bylevel;
colsample_bytree_ = colsample_bytree;
colsample_bynode_ = colsample_bynode;
ctx_ = ctx;
if (feature_set_tree_ == nullptr) {
feature_set_tree_ = std::make_shared<HostDeviceVector<bst_feature_t>>();

View File

@ -10,12 +10,13 @@
#include <cstring>
#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // StableSort
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h"
#include "../common/numeric.h" // Iota
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
@ -258,6 +259,19 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
}
}
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
if (label_order_cache_.size() == labels.Size()) {
return label_order_cache_;
}
label_order_cache_.resize(labels.Size());
common::Iota(ctx, label_order_cache_.begin(), label_order_cache_.end(), 0);
const auto& l = labels.Data()->HostVector();
common::StableSort(ctx, label_order_cache_.begin(), label_order_cache_.end(),
[&l](size_t i1, size_t i2) { return std::abs(l[i1]) < std::abs(l[i2]); });
return label_order_cache_;
}
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
auto version = Version::Load(fi);
auto major = std::get<0>(version);

View File

@ -14,9 +14,11 @@
#include <utility>
#include <vector>
#include "../common/algorithm.h" // ArgSort
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "metric_common.h" // MetricNoCache
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"
#include "xgboost/metric.h"
@ -77,9 +79,8 @@ BinaryAUC(common::Span<float const> predts, linalg::VectorView<float const> labe
* Machine Learning Models
*/
template <typename BinaryAUC>
double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads,
BinaryAUC &&binary_auc) {
double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) {
CHECK_NE(n_classes, 0);
auto const labels = info.labels.View(Context::kCpuId);
if (labels.Shape(0) != 0) {
@ -108,7 +109,7 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
}
double fp;
std::tie(fp, tp(c), auc(c)) =
binary_auc(proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
binary_auc(ctx, proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
local_area(c) = fp * tp(c);
});
}
@ -139,23 +140,26 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
return auc_sum;
}
std::tuple<double, double, double> BinaryROCAUC(common::Span<float const> predts,
std::tuple<double, double, double> BinaryROCAUC(Context const *ctx,
common::Span<float const> predts,
linalg::VectorView<float const> labels,
common::OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
auto const sorted_idx =
common::ArgSort<size_t>(ctx, predts.data(), predts.data() + predts.size(), std::greater<>{});
return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea);
}
/**
* Calculate AUC for 1 ranking group;
*/
double GroupRankingROC(common::Span<float const> predts,
double GroupRankingROC(Context const* ctx, common::Span<float const> predts,
linalg::VectorView<float const> labels, float w) {
// on ranking, we just count all pairs.
double auc{0};
// argsort doesn't support tensor input yet.
auto raw_labels = labels.Values().subspan(0, labels.Size());
auto const sorted_idx = common::ArgSort<size_t>(raw_labels, std::greater<>{});
auto const sorted_idx = common::ArgSort<size_t>(
ctx, raw_labels.data(), raw_labels.data() + raw_labels.size(), std::greater<>{});
w = common::Sqr(w);
double sum_w = 0.0f;
@ -185,10 +189,11 @@ double GroupRankingROC(common::Span<float const> predts,
*
* https://doi.org/10.1371/journal.pone.0092209
*/
std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
std::tuple<double, double, double> BinaryPRAUC(Context const *ctx, common::Span<float const> predts,
linalg::VectorView<float const> labels,
common::OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
auto const sorted_idx =
common::ArgSort<size_t>(ctx, predts.data(), predts.data() + predts.size(), std::greater<>{});
double total_pos{0}, total_neg{0};
for (size_t i = 0; i < labels.Size(); ++i) {
auto w = weights[i];
@ -211,9 +216,8 @@ std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
* Cast LTR problem to binary classification problem by comparing pairs.
*/
template <bool is_roc>
std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
MetaInfo const &info,
int32_t n_threads) {
std::pair<double, uint32_t> RankingAUC(Context const *ctx, std::vector<float> const &predts,
MetaInfo const &info, int32_t n_threads) {
CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1;
auto s_predts = common::Span<float const>{predts};
@ -237,9 +241,9 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
auc = 0;
} else {
if (is_roc) {
auc = GroupRankingROC(g_predts, g_labels, w);
auc = GroupRankingROC(ctx, g_predts, g_labels, w);
} else {
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, common::OptionalWeights{w}));
auc = std::get<2>(BinaryPRAUC(ctx, g_predts, g_labels, common::OptionalWeights{w}));
}
if (std::isnan(auc)) {
invalid_groups++;
@ -344,7 +348,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
auto n_threads = ctx_->Threads();
if (ctx_->gpu_id == Context::kCpuId) {
std::tie(auc, valid_groups) =
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
} else {
std::tie(auc, valid_groups) =
GPURankingAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
@ -358,8 +362,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
auto n_threads = ctx_->Threads();
CHECK_NE(n_classes, 0);
if (ctx_->gpu_id == Context::kCpuId) {
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
BinaryROCAUC);
auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
} else {
auc = GPUMultiClassROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_, n_classes);
}
@ -370,9 +373,9 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
double fp, tp, auc;
if (ctx_->gpu_id == Context::kCpuId) {
std::tie(fp, tp, auc) =
BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
std::tie(fp, tp, auc) = BinaryROCAUC(ctx_, predts.ConstHostVector(),
info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
ctx_->gpu_id, &this->d_cache_);
@ -422,7 +425,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
double pr, re, auc;
if (ctx_->gpu_id == Context::kCpuId) {
std::tie(pr, re, auc) =
BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
common::OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
@ -435,8 +438,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
size_t n_classes) {
if (ctx_->gpu_id == Context::kCpuId) {
auto n_threads = this->ctx_->Threads();
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
BinaryPRAUC);
return MultiClassOVR(ctx_, predts.ConstHostSpan(), info, n_classes, n_threads, BinaryPRAUC);
} else {
return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
}
@ -453,7 +455,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
InvalidLabels();
}
std::tie(auc, valid_groups) =
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
RankingAUC<false>(ctx_, predts.ConstHostVector(), info, n_threads);
} else {
std::tie(auc, valid_groups) =
GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);

View File

@ -27,6 +27,7 @@
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // Sort
#include "../common/math.h"
#include "../common/ranking_utils.h" // MakeMetricName
#include "../common/threading_utils.h"
@ -113,7 +114,7 @@ struct EvalAMS : public MetricNoCache {
const auto &h_preds = preds.ConstHostVector();
common::ParallelFor(ndata, ctx_->Threads(),
[&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
common::Sort(ctx_, rec.begin(), rec.end(), common::CmpFirst);
auto ntop = static_cast<unsigned>(ratio_ * ndata);
if (ntop == 0) ntop = ndata;
const double br = 10.0;
@ -330,7 +331,7 @@ struct EvalCox : public MetricNoCache {
using namespace std; // NOLINT(*)
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
const auto &label_order = info.LabelAbsSort();
const auto &label_order = info.LabelAbsSort(ctx_);
// pre-compute a sum for the denominator
double exp_p_sum = 0; // we use double because we might need the precision with large datasets

View File

@ -6,24 +6,25 @@
#include <limits>
#include <vector>
#include "../common/common.h"
#include "../common/numeric.h"
#include "../common/stats.h"
#include "../common/threading_utils.h"
#include "../common/algorithm.h" // ArgSort
#include "../common/numeric.h" // RunLengthEncode
#include "../common/stats.h" // Quantile,WeightedQuantile
#include "../common/threading_utils.h" // ParallelFor
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace obj {
namespace detail {
void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
std::vector<size_t>* p_nptr, std::vector<bst_node_t>* p_nidx,
std::vector<size_t>* p_ridx) {
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_ridx) {
auto& nptr = *p_nptr;
auto& nidx = *p_nidx;
auto& ridx = *p_ridx;
ridx = common::ArgSort<size_t>(position);
ridx = common::ArgSort<size_t>(ctx, position.cbegin(), position.cend());
std::vector<bst_node_t> sorted_pos(position);
// permutation
for (size_t i = 0; i < position.size(); ++i) {
@ -74,7 +75,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
std::vector<bst_node_t> nidx;
std::vector<size_t> nptr;
std::vector<size_t> ridx;
EncodeTreeLeafHost(*p_tree, position, &nptr, &nidx, &ridx);
EncodeTreeLeafHost(ctx, *p_tree, position, &nptr, &nidx, &ridx);
size_t n_leaf = nidx.size();
if (nptr.empty()) {
std::vector<float> quantiles;

View File

@ -1,5 +1,10 @@
/**
* Copyright 2022-2023 by XGBoost contributors
*/
#include "init_estimation.h"
#include <memory> // unique_ptr
#include "../common/stats.h" // Mean
#include "../tree/fit_stump.h" // FitStump
#include "xgboost/base.h" // GradientPair

View File

@ -1,3 +1,8 @@
/**
* Copyright 2022-2023 by XGBoost contributors
*/
#ifndef XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_
#define XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_
#include "xgboost/data.h" // MetaInfo
#include "xgboost/linalg.h" // Tensor
#include "xgboost/objective.h" // ObjFunction
@ -17,3 +22,4 @@ inline void CheckInitInputs(MetaInfo const& info) {
}
} // namespace obj
} // namespace xgboost
#endif // XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_

View File

@ -393,7 +393,7 @@ class CoxRegression : public FitIntercept {
const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds_h.size());
auto& gpair = out_gpair->HostVector();
const std::vector<size_t> &label_order = info.LabelAbsSort();
const std::vector<size_t> &label_order = info.LabelAbsSort(ctx_);
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
const bool is_null_weight = info.weights_.Size() == 0;

View File

@ -34,10 +34,10 @@ class HistEvaluator {
};
private:
Context const* ctx_;
TrainParam param_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_;
int32_t n_threads_ {0};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;
@ -283,6 +283,7 @@ class HistEvaluator {
void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut,
common::Span<FeatureType const> feature_types, const RegTree &tree,
std::vector<ExpandEntry> *p_entries) {
auto n_threads = ctx_->Threads();
auto& entries = *p_entries;
// All nodes are on the same level, so we can store the shared ptr.
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
@ -294,23 +295,23 @@ class HistEvaluator {
}
CHECK(!features.empty());
const size_t grain_size =
std::max<size_t>(1, features.front()->Size() / n_threads_);
std::max<size_t>(1, features.front()->Size() / n_threads);
common::BlockedSpace2d space(entries.size(), [&](size_t nidx_in_set) {
return features[nidx_in_set]->Size();
}, grain_size);
std::vector<ExpandEntry> tloc_candidates(n_threads_ * entries.size());
std::vector<ExpandEntry> tloc_candidates(n_threads * entries.size());
for (size_t i = 0; i < entries.size(); ++i) {
for (decltype(n_threads_) j = 0; j < n_threads_; ++j) {
tloc_candidates[i * n_threads_ + j] = entries[i];
for (decltype(n_threads) j = 0; j < n_threads; ++j) {
tloc_candidates[i * n_threads + j] = entries[i];
}
}
auto evaluator = tree_evaluator_.GetEvaluator();
auto const& cut_ptrs = cut.Ptrs();
common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads_ * nidx_in_set + tidx];
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
@ -349,9 +350,9 @@ class HistEvaluator {
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads_; ++tidx) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(
tloc_candidates[n_threads_ * nidx_in_set + tidx].split);
tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}
@ -424,15 +425,15 @@ class HistEvaluator {
public:
// The column sampler must be constructed by caller since we need to preserve the rng
// for the entire training session.
explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
explicit HistEvaluator(Context const* ctx, TrainParam const &param, MetaInfo const &info,
std::shared_ptr<common::ColumnSampler> sampler)
: param_{param},
: ctx_{ctx}, param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
n_threads_{n_threads} {
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree);
}
};

View File

@ -171,7 +171,7 @@ class GloablApproxBuilder {
common::Monitor *monitor)
: param_{std::move(param)},
col_sampler_{std::move(column_sampler)},
evaluator_{param_, info, ctx->Threads(), col_sampler_},
evaluator_{ctx, param_, info, col_sampler_},
ctx_{ctx},
task_{task},
monitor_{monitor} {}

View File

@ -234,9 +234,9 @@ class ColMaker: public TreeUpdater {
}
}
{
column_sampler_.Init(fmat.Info().num_col_, fmat.Info().feature_weights.ConstHostVector(),
param_.colsample_bynode, param_.colsample_bylevel,
param_.colsample_bytree);
column_sampler_.Init(ctx_, fmat.Info().num_col_,
fmat.Info().feature_weights.ConstHostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
}
{
// setup temp space for each thread

View File

@ -243,7 +243,7 @@ struct GPUHistMakerDevice {
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
auto const& info = dmat->Info();
this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));

View File

@ -290,8 +290,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
// store a pointer to the tree
p_last_tree_ = &tree;
evaluator_.reset(
new HistEvaluator<CPUExpandEntry>{param_, info, this->ctx_->Threads(), column_sampler_});
evaluator_.reset(new HistEvaluator<CPUExpandEntry>{ctx_, param_, info, column_sampler_});
monitor_->Stop(__func__);
}

View File

@ -0,0 +1,35 @@
/**
* Copyright 2020-2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // Context
#include <xgboost/span.h>
#include <algorithm> // is_sorted
#include "../../../src/common/algorithm.h"
namespace xgboost {
namespace common {
TEST(Algorithm, ArgSort) {
Context ctx;
std::vector<float> inputs{3.0, 2.0, 1.0};
auto ret = ArgSort<bst_feature_t>(&ctx, inputs.cbegin(), inputs.cend());
std::vector<bst_feature_t> sol{2, 1, 0};
ASSERT_EQ(ret, sol);
}
TEST(Algorithm, Sort) {
Context ctx;
ctx.Init(Args{{"nthread", "8"}});
std::vector<float> inputs{3.0, 1.0, 2.0};
Sort(&ctx, inputs.begin(), inputs.end(), std::less<>{});
ASSERT_TRUE(std::is_sorted(inputs.cbegin(), inputs.cend()));
inputs = {3.0, 1.0, 2.0};
StableSort(&ctx, inputs.begin(), inputs.end(), std::less<>{});
ASSERT_TRUE(std::is_sorted(inputs.cbegin(), inputs.cend()));
}
} // namespace common
} // namespace xgboost

View File

@ -52,9 +52,9 @@ void TestSegmentedArgSort() {
}
}
TEST(Algorithms, SegmentedArgSort) { TestSegmentedArgSort(); }
TEST(Algorithm, SegmentedArgSort) { TestSegmentedArgSort(); }
TEST(Algorithms, ArgSort) {
TEST(Algorithm, GpuArgSort) {
Context ctx;
ctx.gpu_id = 0;
@ -80,7 +80,7 @@ TEST(Algorithms, ArgSort) {
thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(), thrust::greater<size_t>{}));
}
TEST(Algorithms, SegmentedSequence) {
TEST(Algorithm, SegmentedSequence) {
dh::device_vector<std::size_t> idx(16);
dh::device_vector<std::size_t> ptr(3);
Context ctx = CreateEmptyGenericParam(0);

View File

@ -1,14 +0,0 @@
#include <gtest/gtest.h>
#include <xgboost/span.h>
#include "../../../src/common/common.h"
namespace xgboost {
namespace common {
TEST(ArgSort, Basic) {
std::vector<float> inputs {3.0, 2.0, 1.0};
auto ret = ArgSort<bst_feature_t>(Span<float>{inputs});
std::vector<bst_feature_t> sol{2, 1, 0};
ASSERT_EQ(ret, sol);
}
} // namespace common
} // namespace xgboost

View File

@ -2,16 +2,18 @@
#include "../../../src/common/random.h"
#include "../helpers.h"
#include "gtest/gtest.h"
#include "xgboost/context.h" // Context
namespace xgboost {
namespace common {
TEST(ColumnSampler, Test) {
Context ctx;
int n = 128;
ColumnSampler cs;
std::vector<float> feature_weights;
// No node sampling
cs.Init(n, feature_weights, 1.0f, 0.5f, 0.5f);
cs.Init(&ctx, n, feature_weights, 1.0f, 0.5f, 0.5f);
auto set0 = cs.GetFeatureSet(0);
ASSERT_EQ(set0->Size(), 32);
@ -24,7 +26,7 @@ TEST(ColumnSampler, Test) {
ASSERT_EQ(set2->Size(), 32);
// Node sampling
cs.Init(n, feature_weights, 0.5f, 1.0f, 0.5f);
cs.Init(&ctx, n, feature_weights, 0.5f, 1.0f, 0.5f);
auto set3 = cs.GetFeatureSet(0);
ASSERT_EQ(set3->Size(), 32);
@ -34,24 +36,25 @@ TEST(ColumnSampler, Test) {
ASSERT_EQ(set4->Size(), 32);
// No level or node sampling, should be the same at different depth
cs.Init(n, feature_weights, 1.0f, 1.0f, 0.5f);
cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 0.5f);
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
cs.GetFeatureSet(1)->HostVector());
cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
auto set5 = cs.GetFeatureSet(0);
ASSERT_EQ(set5->Size(), n);
cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
auto set6 = cs.GetFeatureSet(0);
ASSERT_EQ(set5->HostVector(), set6->HostVector());
// Should always be a minimum of one feature
cs.Init(n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
cs.Init(&ctx, n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
}
// Test if different threads using the same seed produce the same result
TEST(ColumnSampler, ThreadSynchronisation) {
Context ctx;
const int64_t num_threads = 100;
int n = 128;
size_t iterations = 10;
@ -63,7 +66,7 @@ TEST(ColumnSampler, ThreadSynchronisation) {
{
for (auto j = 0ull; j < iterations; j++) {
ColumnSampler cs(j);
cs.Init(n, feature_weights, 0.5f, 0.5f, 0.5f);
cs.Init(&ctx, n, feature_weights, 0.5f, 0.5f, 0.5f);
for (auto level = 0ull; level < levels; level++) {
auto result = cs.GetFeatureSet(level)->ConstHostVector();
#pragma omp single
@ -80,11 +83,12 @@ TEST(ColumnSampler, ThreadSynchronisation) {
TEST(ColumnSampler, WeightedSampling) {
auto test_basic = [](int first) {
Context ctx;
std::vector<float> feature_weights(2);
feature_weights[0] = std::abs(first - 1.0f);
feature_weights[1] = first - 0.0f;
ColumnSampler cs{0};
cs.Init(2, feature_weights, 1.0, 1.0, 0.5);
cs.Init(&ctx, 2, feature_weights, 1.0, 1.0, 0.5);
auto feature_sets = cs.GetFeatureSet(0);
auto const &h_feat_set = feature_sets->HostVector();
ASSERT_EQ(h_feat_set.size(), 1);
@ -100,7 +104,8 @@ TEST(ColumnSampler, WeightedSampling) {
SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
ColumnSampler cs{0};
cs.Init(kCols, feature_weights, 0.5f, 1.0f, 1.0f);
Context ctx;
cs.Init(&ctx, kCols, feature_weights, 0.5f, 1.0f, 1.0f);
std::vector<bst_feature_t> features(kCols);
std::iota(features.begin(), features.end(), 0);
std::vector<float> freq(kCols, 0);
@ -135,7 +140,8 @@ TEST(ColumnSampler, WeightedMultiSampling) {
}
ColumnSampler cs{0};
float bytree{0.5}, bylevel{0.5}, bynode{0.5};
cs.Init(feature_weights.size(), feature_weights, bytree, bylevel, bynode);
Context ctx;
cs.Init(&ctx, feature_weights.size(), feature_weights, bytree, bylevel, bynode);
auto feature_set = cs.GetFeatureSet(0);
size_t n_sampled = kCols * bytree * bylevel * bynode;
ASSERT_EQ(feature_set->Size(), n_sampled);

View File

@ -9,12 +9,14 @@
#include "../../../../src/tree/hist/evaluate_splits.h"
#include "../test_evaluate_splits.h"
#include "../../helpers.h"
#include "xgboost/context.h" // Context
namespace xgboost {
namespace tree {
void TestEvaluateSplits(bool force_read_by_column) {
Context ctx;
ctx.nthread = 4;
int static constexpr kRows = 8, kCols = 16;
int32_t n_threads = std::min(omp_get_max_threads(), 4);
auto sampler = std::make_shared<common::ColumnSampler>();
TrainParam param;
@ -22,7 +24,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
auto evaluator = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
common::HistCollection hist;
std::vector<GradientPair> row_gpairs = {
{1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
@ -86,13 +88,15 @@ TEST(HistEvaluator, Evaluate) {
}
TEST(HistEvaluator, Apply) {
Context ctx;
ctx.nthread = 4;
RegTree tree;
int static constexpr kNRows = 8, kNCols = 16;
TrainParam param;
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator_ = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), 4, sampler};
auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
CPUExpandEntry entry{0, 0, 10.0f};
entry.split.left_sum = GradStats{0.4, 0.6f};
@ -115,10 +119,11 @@ TEST(HistEvaluator, Apply) {
}
TEST_F(TestPartitionBasedSplit, CPUHist) {
Context ctx;
// check the evaluator is returning the optimal split
std::vector<FeatureType> ft{FeatureType::kCategorical};
auto sampler = std::make_shared<common::ColumnSampler>();
HistEvaluator<CPUExpandEntry> evaluator{param_, info_, AllThreadsForTest(), sampler};
HistEvaluator<CPUExpandEntry> evaluator{&ctx, param_, info_, sampler};
evaluator.InitRoot(GradStats{total_gpair_});
RegTree tree;
std::vector<CPUExpandEntry> entries(1);
@ -128,6 +133,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
namespace {
auto CompareOneHotAndPartition(bool onehot) {
Context ctx;
int static constexpr kRows = 128, kCols = 1;
std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
@ -147,8 +153,7 @@ auto CompareOneHotAndPartition(bool onehot) {
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator =
HistEvaluator<CPUExpandEntry>{param, dmat->Info(), AllThreadsForTest(), sampler};
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
std::vector<CPUExpandEntry> entries(1);
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
@ -198,8 +203,8 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
MetaInfo info;
info.num_col_ = 1;
info.feature_types = {FeatureType::kCategorical};
auto evaluator =
HistEvaluator<CPUExpandEntry>{param_, info, AllThreadsForTest(), sampler};
Context ctx;
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param_, info, sampler};
evaluator.InitRoot(GradStats{parent_sum_});
std::vector<CPUExpandEntry> entries(1);