Specify the number of threads for parallel sort. (#8735)
* Specify the number of threads for parallel sort. - Pass context object into argsort. - Replace macros with inline functions.
This commit is contained in:
parent
c7c485d052
commit
282b1729da
@ -48,21 +48,6 @@
|
||||
#define XGBOOST_ALIGNAS(X)
|
||||
#endif // defined(__GNUC__) && ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ > 4)
|
||||
|
||||
#if defined(__GNUC__) && ((__GNUC__ == 4 && __GNUC_MINOR__ >= 8) || __GNUC__ > 4) && \
|
||||
!defined(__CUDACC__) && !defined(__sun) && !defined(sun)
|
||||
#include <parallel/algorithm>
|
||||
#define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
|
||||
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) \
|
||||
__gnu_parallel::stable_sort((X), (Y), (Z))
|
||||
#elif defined(_MSC_VER) && (!__INTEL_COMPILER)
|
||||
#include <ppl.h>
|
||||
#define XGBOOST_PARALLEL_SORT(X, Y, Z) concurrency::parallel_sort((X), (Y), (Z))
|
||||
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
|
||||
#else
|
||||
#define XGBOOST_PARALLEL_SORT(X, Y, Z) std::sort((X), (Y), (Z))
|
||||
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
|
||||
#endif // GLIBC VERSION
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#define XGBOOST_EXPECT(cond, ret) __builtin_expect((cond), (ret))
|
||||
#else
|
||||
|
||||
@ -124,18 +124,7 @@ class MetaInfo {
|
||||
return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f;
|
||||
}
|
||||
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
|
||||
inline const std::vector<size_t>& LabelAbsSort() const {
|
||||
if (label_order_cache_.size() == labels.Size()) {
|
||||
return label_order_cache_;
|
||||
}
|
||||
label_order_cache_.resize(labels.Size());
|
||||
std::iota(label_order_cache_.begin(), label_order_cache_.end(), 0);
|
||||
const auto& l = labels.Data()->HostVector();
|
||||
XGBOOST_PARALLEL_STABLE_SORT(label_order_cache_.begin(), label_order_cache_.end(),
|
||||
[&l](size_t i1, size_t i2) {return std::abs(l[i1]) < std::abs(l[i2]);});
|
||||
|
||||
return label_order_cache_;
|
||||
}
|
||||
const std::vector<size_t>& LabelAbsSort(Context const* ctx) const;
|
||||
/*! \brief clear all the information */
|
||||
void Clear();
|
||||
/*!
|
||||
|
||||
@ -23,6 +23,10 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include <intrin.h>
|
||||
#endif // defined(_MSC_VER)
|
||||
|
||||
// decouple it from xgboost.
|
||||
#ifndef LINALG_HD
|
||||
#if defined(__CUDA__) || defined(__NVCC__)
|
||||
|
||||
@ -1,10 +1,31 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_ALGORITHM_H_
|
||||
#define XGBOOST_COMMON_ALGORITHM_H_
|
||||
#include <algorithm> // std::upper_bound
|
||||
#include <cinttypes> // std::size_t
|
||||
#include <algorithm> // upper_bound, stable_sort, sort, max
|
||||
#include <cinttypes> // size_t
|
||||
#include <functional> // less
|
||||
#include <iterator> // iterator_traits, distance
|
||||
#include <vector> // vector
|
||||
|
||||
#include "numeric.h" // Iota
|
||||
#include "xgboost/context.h" // Context
|
||||
|
||||
// clang with libstdc++ works as well
|
||||
#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && !defined(__APPLE__)
|
||||
#define GCC_HAS_PARALLEL 1
|
||||
#endif // GLIC_VERSION
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
|
||||
#define MSVC_HAS_PARALLEL 1
|
||||
#endif // MSC
|
||||
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
#include <parallel/algorithm>
|
||||
#elif defined(MSVC_HAS_PARALLEL)
|
||||
#include <ppl.h>
|
||||
#endif // GLIBC VERSION
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@ -13,6 +34,63 @@ auto SegmentId(It first, It last, Idx idx) {
|
||||
std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
|
||||
return segment_id;
|
||||
}
|
||||
|
||||
template <typename Iter, typename Comp>
|
||||
void StableSort(Context const *ctx, Iter begin, Iter end, Comp &&comp) {
|
||||
if (ctx->Threads() > 1) {
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
__gnu_parallel::stable_sort(begin, end, comp,
|
||||
__gnu_parallel::default_parallel_tag(ctx->Threads()));
|
||||
#else
|
||||
// the only stable sort is radix sort for msvc ppl.
|
||||
std::stable_sort(begin, end, comp);
|
||||
#endif // GLIBC VERSION
|
||||
} else {
|
||||
std::stable_sort(begin, end, comp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Iter, typename Comp>
|
||||
void Sort(Context const *ctx, Iter begin, Iter end, Comp comp) {
|
||||
if (ctx->Threads() > 1) {
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
__gnu_parallel::sort(begin, end, comp, __gnu_parallel::default_parallel_tag(ctx->Threads()));
|
||||
#elif defined(MSVC_HAS_PARALLEL)
|
||||
auto n = std::distance(begin, end);
|
||||
// use chunk size as hint to number of threads. No local policy/scheduler input with the
|
||||
// concurrency module.
|
||||
std::size_t chunk_size = n / ctx->Threads();
|
||||
// 2048 is the default of msvc ppl as of v2022.
|
||||
chunk_size = std::max(chunk_size, static_cast<std::size_t>(2048));
|
||||
concurrency::parallel_sort(begin, end, comp, chunk_size);
|
||||
#else
|
||||
std::sort(begin, end, comp);
|
||||
#endif // GLIBC VERSION
|
||||
} else {
|
||||
std::sort(begin, end, comp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Idx, typename Iter, typename V = typename std::iterator_traits<Iter>::value_type,
|
||||
typename Comp = std::less<V>>
|
||||
std::vector<Idx> ArgSort(Context const *ctx, Iter begin, Iter end, Comp comp = std::less<V>{}) {
|
||||
CHECK(ctx->IsCPU());
|
||||
auto n = std::distance(begin, end);
|
||||
std::vector<Idx> result(n);
|
||||
Iota(ctx, result.begin(), result.end(), 0);
|
||||
auto op = [&](Idx const &l, Idx const &r) { return comp(begin[l], begin[r]); };
|
||||
StableSort(ctx, result.begin(), result.end(), op);
|
||||
return result;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
#if defined(GCC_HAS_PARALLEL)
|
||||
#undef GCC_HAS_PARALLEL
|
||||
#endif // defined(GCC_HAS_PARALLEL)
|
||||
|
||||
#if defined(MSVC_HAS_PARALLEL)
|
||||
#undef MSVC_HAS_PARALLEL
|
||||
#endif // defined(MSVC_HAS_PARALLEL)
|
||||
|
||||
#endif // XGBOOST_COMMON_ALGORITHM_H_
|
||||
|
||||
@ -188,17 +188,6 @@ inline void SetDevice(std::int32_t device) {
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename Idx, typename Container,
|
||||
typename V = typename Container::value_type,
|
||||
typename Comp = std::less<V>>
|
||||
std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
|
||||
std::vector<Idx> result(array.size());
|
||||
std::iota(result.begin(), result.end(), 0);
|
||||
auto op = [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); };
|
||||
XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Last index of a group in a CSR style of index pointer.
|
||||
*/
|
||||
|
||||
@ -24,8 +24,9 @@ std::shared_ptr<HostDeviceVector<bst_feature_t>> ColumnSampler::ColSample(
|
||||
for (size_t i = 0; i < h_features.size(); ++i) {
|
||||
weights[i] = feature_weights_[h_features[i]];
|
||||
}
|
||||
CHECK(ctx_);
|
||||
new_features.HostVector() =
|
||||
WeightedSamplingWithoutReplacement(p_features->HostVector(), weights, n);
|
||||
WeightedSamplingWithoutReplacement(ctx_, p_features->HostVector(), weights, n);
|
||||
} else {
|
||||
new_features.Resize(features.size());
|
||||
std::copy(features.begin(), features.end(), new_features.HostVector().begin());
|
||||
|
||||
@ -20,7 +20,9 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "algorithm.h" // ArgSort
|
||||
#include "common.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -87,8 +89,8 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
|
||||
* https://timvieira.github.io/blog/post/2019/09/16/algorithms-for-sampling-without-replacement/
|
||||
*/
|
||||
template <typename T>
|
||||
std::vector<T> WeightedSamplingWithoutReplacement(
|
||||
std::vector<T> const &array, std::vector<float> const &weights, size_t n) {
|
||||
std::vector<T> WeightedSamplingWithoutReplacement(Context const* ctx, std::vector<T> const& array,
|
||||
std::vector<float> const& weights, size_t n) {
|
||||
// ES sampling.
|
||||
CHECK_EQ(array.size(), weights.size());
|
||||
std::vector<float> keys(weights.size());
|
||||
@ -100,7 +102,7 @@ std::vector<T> WeightedSamplingWithoutReplacement(
|
||||
auto k = std::log(u) / w;
|
||||
keys[i] = k;
|
||||
}
|
||||
auto ind = ArgSort<size_t>(Span<float>{keys}, std::greater<>{});
|
||||
auto ind = ArgSort<std::size_t>(ctx, keys.data(), keys.data() + keys.size(), std::greater<>{});
|
||||
ind.resize(n);
|
||||
|
||||
std::vector<T> results(ind.size());
|
||||
@ -126,6 +128,7 @@ class ColumnSampler {
|
||||
float colsample_bytree_{1.0f};
|
||||
float colsample_bynode_{1.0f};
|
||||
GlobalRandomEngine rng_;
|
||||
Context const* ctx_;
|
||||
|
||||
public:
|
||||
std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
|
||||
@ -157,12 +160,13 @@ class ColumnSampler {
|
||||
* \param colsample_bytree
|
||||
* \param skip_index_0 (Optional) True to skip index 0.
|
||||
*/
|
||||
void Init(int64_t num_col, std::vector<float> feature_weights, float colsample_bynode,
|
||||
float colsample_bylevel, float colsample_bytree) {
|
||||
void Init(Context const* ctx, int64_t num_col, std::vector<float> feature_weights,
|
||||
float colsample_bynode, float colsample_bylevel, float colsample_bytree) {
|
||||
feature_weights_ = std::move(feature_weights);
|
||||
colsample_bylevel_ = colsample_bylevel;
|
||||
colsample_bytree_ = colsample_bytree;
|
||||
colsample_bynode_ = colsample_bynode;
|
||||
ctx_ = ctx;
|
||||
|
||||
if (feature_set_tree_ == nullptr) {
|
||||
feature_set_tree_ = std::make_shared<HostDeviceVector<bst_feature_t>>();
|
||||
|
||||
@ -10,12 +10,13 @@
|
||||
#include <cstring>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/algorithm.h" // StableSort
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/numeric.h" // Iota
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/version.h"
|
||||
#include "../data/adapter.h"
|
||||
@ -258,6 +259,19 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
|
||||
if (label_order_cache_.size() == labels.Size()) {
|
||||
return label_order_cache_;
|
||||
}
|
||||
label_order_cache_.resize(labels.Size());
|
||||
common::Iota(ctx, label_order_cache_.begin(), label_order_cache_.end(), 0);
|
||||
const auto& l = labels.Data()->HostVector();
|
||||
common::StableSort(ctx, label_order_cache_.begin(), label_order_cache_.end(),
|
||||
[&l](size_t i1, size_t i2) { return std::abs(l[i1]) < std::abs(l[i2]); });
|
||||
|
||||
return label_order_cache_;
|
||||
}
|
||||
|
||||
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
auto version = Version::Load(fi);
|
||||
auto major = std::get<0>(version);
|
||||
|
||||
@ -14,9 +14,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/algorithm.h" // ArgSort
|
||||
#include "../common/math.h"
|
||||
#include "../common/optional_weight.h" // OptionalWeights
|
||||
#include "metric_common.h" // MetricNoCache
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/metric.h"
|
||||
@ -77,9 +79,8 @@ BinaryAUC(common::Span<float const> predts, linalg::VectorView<float const> labe
|
||||
* Machine Learning Models
|
||||
*/
|
||||
template <typename BinaryAUC>
|
||||
double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
size_t n_classes, int32_t n_threads,
|
||||
BinaryAUC &&binary_auc) {
|
||||
double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaInfo const &info,
|
||||
size_t n_classes, int32_t n_threads, BinaryAUC &&binary_auc) {
|
||||
CHECK_NE(n_classes, 0);
|
||||
auto const labels = info.labels.View(Context::kCpuId);
|
||||
if (labels.Shape(0) != 0) {
|
||||
@ -108,7 +109,7 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
}
|
||||
double fp;
|
||||
std::tie(fp, tp(c), auc(c)) =
|
||||
binary_auc(proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
|
||||
binary_auc(ctx, proba, linalg::MakeVec(response.data(), response.size(), -1), weights);
|
||||
local_area(c) = fp * tp(c);
|
||||
});
|
||||
}
|
||||
@ -139,23 +140,26 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
return auc_sum;
|
||||
}
|
||||
|
||||
std::tuple<double, double, double> BinaryROCAUC(common::Span<float const> predts,
|
||||
std::tuple<double, double, double> BinaryROCAUC(Context const *ctx,
|
||||
common::Span<float const> predts,
|
||||
linalg::VectorView<float const> labels,
|
||||
common::OptionalWeights weights) {
|
||||
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
|
||||
auto const sorted_idx =
|
||||
common::ArgSort<size_t>(ctx, predts.data(), predts.data() + predts.size(), std::greater<>{});
|
||||
return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate AUC for 1 ranking group;
|
||||
*/
|
||||
double GroupRankingROC(common::Span<float const> predts,
|
||||
double GroupRankingROC(Context const* ctx, common::Span<float const> predts,
|
||||
linalg::VectorView<float const> labels, float w) {
|
||||
// on ranking, we just count all pairs.
|
||||
double auc{0};
|
||||
// argsort doesn't support tensor input yet.
|
||||
auto raw_labels = labels.Values().subspan(0, labels.Size());
|
||||
auto const sorted_idx = common::ArgSort<size_t>(raw_labels, std::greater<>{});
|
||||
auto const sorted_idx = common::ArgSort<size_t>(
|
||||
ctx, raw_labels.data(), raw_labels.data() + raw_labels.size(), std::greater<>{});
|
||||
w = common::Sqr(w);
|
||||
|
||||
double sum_w = 0.0f;
|
||||
@ -185,10 +189,11 @@ double GroupRankingROC(common::Span<float const> predts,
|
||||
*
|
||||
* https://doi.org/10.1371/journal.pone.0092209
|
||||
*/
|
||||
std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
|
||||
std::tuple<double, double, double> BinaryPRAUC(Context const *ctx, common::Span<float const> predts,
|
||||
linalg::VectorView<float const> labels,
|
||||
common::OptionalWeights weights) {
|
||||
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
|
||||
auto const sorted_idx =
|
||||
common::ArgSort<size_t>(ctx, predts.data(), predts.data() + predts.size(), std::greater<>{});
|
||||
double total_pos{0}, total_neg{0};
|
||||
for (size_t i = 0; i < labels.Size(); ++i) {
|
||||
auto w = weights[i];
|
||||
@ -211,9 +216,8 @@ std::tuple<double, double, double> BinaryPRAUC(common::Span<float const> predts,
|
||||
* Cast LTR problem to binary classification problem by comparing pairs.
|
||||
*/
|
||||
template <bool is_roc>
|
||||
std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
|
||||
MetaInfo const &info,
|
||||
int32_t n_threads) {
|
||||
std::pair<double, uint32_t> RankingAUC(Context const *ctx, std::vector<float> const &predts,
|
||||
MetaInfo const &info, int32_t n_threads) {
|
||||
CHECK_GE(info.group_ptr_.size(), 2);
|
||||
uint32_t n_groups = info.group_ptr_.size() - 1;
|
||||
auto s_predts = common::Span<float const>{predts};
|
||||
@ -237,9 +241,9 @@ std::pair<double, uint32_t> RankingAUC(std::vector<float> const &predts,
|
||||
auc = 0;
|
||||
} else {
|
||||
if (is_roc) {
|
||||
auc = GroupRankingROC(g_predts, g_labels, w);
|
||||
auc = GroupRankingROC(ctx, g_predts, g_labels, w);
|
||||
} else {
|
||||
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, common::OptionalWeights{w}));
|
||||
auc = std::get<2>(BinaryPRAUC(ctx, g_predts, g_labels, common::OptionalWeights{w}));
|
||||
}
|
||||
if (std::isnan(auc)) {
|
||||
invalid_groups++;
|
||||
@ -344,7 +348,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
||||
auto n_threads = ctx_->Threads();
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
std::tie(auc, valid_groups) =
|
||||
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
|
||||
RankingAUC<true>(ctx_, predts.ConstHostVector(), info, n_threads);
|
||||
} else {
|
||||
std::tie(auc, valid_groups) =
|
||||
GPURankingAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_);
|
||||
@ -358,8 +362,7 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
||||
auto n_threads = ctx_->Threads();
|
||||
CHECK_NE(n_classes, 0);
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
|
||||
BinaryROCAUC);
|
||||
auc = MultiClassOVR(ctx_, predts.ConstHostVector(), info, n_classes, n_threads, BinaryROCAUC);
|
||||
} else {
|
||||
auc = GPUMultiClassROCAUC(ctx_, predts.ConstDeviceSpan(), info, &this->d_cache_, n_classes);
|
||||
}
|
||||
@ -370,9 +373,9 @@ class EvalROCAUC : public EvalAUC<EvalROCAUC> {
|
||||
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
|
||||
double fp, tp, auc;
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
std::tie(fp, tp, auc) =
|
||||
BinaryROCAUC(predts.ConstHostVector(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||
std::tie(fp, tp, auc) = BinaryROCAUC(ctx_, predts.ConstHostVector(),
|
||||
info.labels.HostView().Slice(linalg::All(), 0),
|
||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||
} else {
|
||||
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
|
||||
ctx_->gpu_id, &this->d_cache_);
|
||||
@ -422,7 +425,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
||||
double pr, re, auc;
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
std::tie(pr, re, auc) =
|
||||
BinaryPRAUC(predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||
BinaryPRAUC(ctx_, predts.ConstHostSpan(), info.labels.HostView().Slice(linalg::All(), 0),
|
||||
common::OptionalWeights{info.weights_.ConstHostSpan()});
|
||||
} else {
|
||||
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
|
||||
@ -435,8 +438,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
||||
size_t n_classes) {
|
||||
if (ctx_->gpu_id == Context::kCpuId) {
|
||||
auto n_threads = this->ctx_->Threads();
|
||||
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
|
||||
BinaryPRAUC);
|
||||
return MultiClassOVR(ctx_, predts.ConstHostSpan(), info, n_classes, n_threads, BinaryPRAUC);
|
||||
} else {
|
||||
return GPUMultiClassPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_, n_classes);
|
||||
}
|
||||
@ -453,7 +455,7 @@ class EvalPRAUC : public EvalAUC<EvalPRAUC> {
|
||||
InvalidLabels();
|
||||
}
|
||||
std::tie(auc, valid_groups) =
|
||||
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
|
||||
RankingAUC<false>(ctx_, predts.ConstHostVector(), info, n_threads);
|
||||
} else {
|
||||
std::tie(auc, valid_groups) =
|
||||
GPURankingPRAUC(ctx_, predts.ConstDeviceSpan(), info, &d_cache_);
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/algorithm.h" // Sort
|
||||
#include "../common/math.h"
|
||||
#include "../common/ranking_utils.h" // MakeMetricName
|
||||
#include "../common/threading_utils.h"
|
||||
@ -113,7 +114,7 @@ struct EvalAMS : public MetricNoCache {
|
||||
const auto &h_preds = preds.ConstHostVector();
|
||||
common::ParallelFor(ndata, ctx_->Threads(),
|
||||
[&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); });
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
common::Sort(ctx_, rec.begin(), rec.end(), common::CmpFirst);
|
||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
if (ntop == 0) ntop = ndata;
|
||||
const double br = 10.0;
|
||||
@ -330,7 +331,7 @@ struct EvalCox : public MetricNoCache {
|
||||
using namespace std; // NOLINT(*)
|
||||
|
||||
const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
|
||||
const auto &label_order = info.LabelAbsSort();
|
||||
const auto &label_order = info.LabelAbsSort(ctx_);
|
||||
|
||||
// pre-compute a sum for the denominator
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
|
||||
@ -6,24 +6,25 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/stats.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/algorithm.h" // ArgSort
|
||||
#include "../common/numeric.h" // RunLengthEncode
|
||||
#include "../common/stats.h" // Quantile,WeightedQuantile
|
||||
#include "../common/threading_utils.h" // ParallelFor
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/linalg.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
namespace detail {
|
||||
void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
|
||||
std::vector<size_t>* p_nptr, std::vector<bst_node_t>* p_nidx,
|
||||
std::vector<size_t>* p_ridx) {
|
||||
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
|
||||
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
|
||||
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_ridx) {
|
||||
auto& nptr = *p_nptr;
|
||||
auto& nidx = *p_nidx;
|
||||
auto& ridx = *p_ridx;
|
||||
ridx = common::ArgSort<size_t>(position);
|
||||
ridx = common::ArgSort<size_t>(ctx, position.cbegin(), position.cend());
|
||||
std::vector<bst_node_t> sorted_pos(position);
|
||||
// permutation
|
||||
for (size_t i = 0; i < position.size(); ++i) {
|
||||
@ -74,7 +75,7 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
|
||||
std::vector<bst_node_t> nidx;
|
||||
std::vector<size_t> nptr;
|
||||
std::vector<size_t> ridx;
|
||||
EncodeTreeLeafHost(*p_tree, position, &nptr, &nidx, &ridx);
|
||||
EncodeTreeLeafHost(ctx, *p_tree, position, &nptr, &nidx, &ridx);
|
||||
size_t n_leaf = nidx.size();
|
||||
if (nptr.empty()) {
|
||||
std::vector<float> quantiles;
|
||||
|
||||
@ -1,5 +1,10 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors
|
||||
*/
|
||||
#include "init_estimation.h"
|
||||
|
||||
#include <memory> // unique_ptr
|
||||
|
||||
#include "../common/stats.h" // Mean
|
||||
#include "../tree/fit_stump.h" // FitStump
|
||||
#include "xgboost/base.h" // GradientPair
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_
|
||||
#define XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
#include "xgboost/objective.h" // ObjFunction
|
||||
@ -17,3 +22,4 @@ inline void CheckInitInputs(MetaInfo const& info) {
|
||||
}
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_
|
||||
|
||||
@ -393,7 +393,7 @@ class CoxRegression : public FitIntercept {
|
||||
const auto& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds_h.size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort();
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort(ctx_);
|
||||
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
|
||||
@ -34,10 +34,10 @@ class HistEvaluator {
|
||||
};
|
||||
|
||||
private:
|
||||
Context const* ctx_;
|
||||
TrainParam param_;
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_;
|
||||
TreeEvaluator tree_evaluator_;
|
||||
int32_t n_threads_ {0};
|
||||
FeatureInteractionConstraintHost interaction_constraints_;
|
||||
std::vector<NodeEntry> snode_;
|
||||
|
||||
@ -283,6 +283,7 @@ class HistEvaluator {
|
||||
void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut,
|
||||
common::Span<FeatureType const> feature_types, const RegTree &tree,
|
||||
std::vector<ExpandEntry> *p_entries) {
|
||||
auto n_threads = ctx_->Threads();
|
||||
auto& entries = *p_entries;
|
||||
// All nodes are on the same level, so we can store the shared ptr.
|
||||
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
|
||||
@ -294,23 +295,23 @@ class HistEvaluator {
|
||||
}
|
||||
CHECK(!features.empty());
|
||||
const size_t grain_size =
|
||||
std::max<size_t>(1, features.front()->Size() / n_threads_);
|
||||
std::max<size_t>(1, features.front()->Size() / n_threads);
|
||||
common::BlockedSpace2d space(entries.size(), [&](size_t nidx_in_set) {
|
||||
return features[nidx_in_set]->Size();
|
||||
}, grain_size);
|
||||
|
||||
std::vector<ExpandEntry> tloc_candidates(n_threads_ * entries.size());
|
||||
std::vector<ExpandEntry> tloc_candidates(n_threads * entries.size());
|
||||
for (size_t i = 0; i < entries.size(); ++i) {
|
||||
for (decltype(n_threads_) j = 0; j < n_threads_; ++j) {
|
||||
tloc_candidates[i * n_threads_ + j] = entries[i];
|
||||
for (decltype(n_threads) j = 0; j < n_threads; ++j) {
|
||||
tloc_candidates[i * n_threads + j] = entries[i];
|
||||
}
|
||||
}
|
||||
auto evaluator = tree_evaluator_.GetEvaluator();
|
||||
auto const& cut_ptrs = cut.Ptrs();
|
||||
|
||||
common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
|
||||
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
|
||||
auto tidx = omp_get_thread_num();
|
||||
auto entry = &tloc_candidates[n_threads_ * nidx_in_set + tidx];
|
||||
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
|
||||
auto best = &entry->split;
|
||||
auto nidx = entry->nid;
|
||||
auto histogram = hist[nidx];
|
||||
@ -349,9 +350,9 @@ class HistEvaluator {
|
||||
|
||||
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size();
|
||||
++nidx_in_set) {
|
||||
for (auto tidx = 0; tidx < n_threads_; ++tidx) {
|
||||
for (auto tidx = 0; tidx < n_threads; ++tidx) {
|
||||
entries[nidx_in_set].split.Update(
|
||||
tloc_candidates[n_threads_ * nidx_in_set + tidx].split);
|
||||
tloc_candidates[n_threads * nidx_in_set + tidx].split);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -424,15 +425,15 @@ class HistEvaluator {
|
||||
public:
|
||||
// The column sampler must be constructed by caller since we need to preserve the rng
|
||||
// for the entire training session.
|
||||
explicit HistEvaluator(TrainParam const ¶m, MetaInfo const &info, int32_t n_threads,
|
||||
explicit HistEvaluator(Context const* ctx, TrainParam const ¶m, MetaInfo const &info,
|
||||
std::shared_ptr<common::ColumnSampler> sampler)
|
||||
: param_{param},
|
||||
: ctx_{ctx}, param_{param},
|
||||
column_sampler_{std::move(sampler)},
|
||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
|
||||
n_threads_{n_threads} {
|
||||
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId} {
|
||||
interaction_constraints_.Configure(param, info.num_col_);
|
||||
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -171,7 +171,7 @@ class GloablApproxBuilder {
|
||||
common::Monitor *monitor)
|
||||
: param_{std::move(param)},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{param_, info, ctx->Threads(), col_sampler_},
|
||||
evaluator_{ctx, param_, info, col_sampler_},
|
||||
ctx_{ctx},
|
||||
task_{task},
|
||||
monitor_{monitor} {}
|
||||
|
||||
@ -234,9 +234,9 @@ class ColMaker: public TreeUpdater {
|
||||
}
|
||||
}
|
||||
{
|
||||
column_sampler_.Init(fmat.Info().num_col_, fmat.Info().feature_weights.ConstHostVector(),
|
||||
param_.colsample_bynode, param_.colsample_bylevel,
|
||||
param_.colsample_bytree);
|
||||
column_sampler_.Init(ctx_, fmat.Info().num_col_,
|
||||
fmat.Info().feature_weights.ConstHostVector(), param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
}
|
||||
{
|
||||
// setup temp space for each thread
|
||||
|
||||
@ -243,7 +243,7 @@ struct GPUHistMakerDevice {
|
||||
// thread safe
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
|
||||
auto const& info = dmat->Info();
|
||||
this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
|
||||
this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(),
|
||||
param.colsample_bynode, param.colsample_bylevel,
|
||||
param.colsample_bytree);
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
|
||||
@ -290,8 +290,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
|
||||
|
||||
// store a pointer to the tree
|
||||
p_last_tree_ = &tree;
|
||||
evaluator_.reset(
|
||||
new HistEvaluator<CPUExpandEntry>{param_, info, this->ctx_->Threads(), column_sampler_});
|
||||
evaluator_.reset(new HistEvaluator<CPUExpandEntry>{ctx_, param_, info, column_sampler_});
|
||||
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
35
tests/cpp/common/test_algorithm.cc
Normal file
35
tests/cpp/common/test_algorithm.cc
Normal file
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
#include <xgboost/span.h>
|
||||
|
||||
#include <algorithm> // is_sorted
|
||||
|
||||
#include "../../../src/common/algorithm.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(Algorithm, ArgSort) {
|
||||
Context ctx;
|
||||
std::vector<float> inputs{3.0, 2.0, 1.0};
|
||||
auto ret = ArgSort<bst_feature_t>(&ctx, inputs.cbegin(), inputs.cend());
|
||||
std::vector<bst_feature_t> sol{2, 1, 0};
|
||||
ASSERT_EQ(ret, sol);
|
||||
}
|
||||
|
||||
TEST(Algorithm, Sort) {
|
||||
Context ctx;
|
||||
ctx.Init(Args{{"nthread", "8"}});
|
||||
std::vector<float> inputs{3.0, 1.0, 2.0};
|
||||
|
||||
Sort(&ctx, inputs.begin(), inputs.end(), std::less<>{});
|
||||
ASSERT_TRUE(std::is_sorted(inputs.cbegin(), inputs.cend()));
|
||||
|
||||
inputs = {3.0, 1.0, 2.0};
|
||||
StableSort(&ctx, inputs.begin(), inputs.end(), std::less<>{});
|
||||
ASSERT_TRUE(std::is_sorted(inputs.cbegin(), inputs.cend()));
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -52,9 +52,9 @@ void TestSegmentedArgSort() {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Algorithms, SegmentedArgSort) { TestSegmentedArgSort(); }
|
||||
TEST(Algorithm, SegmentedArgSort) { TestSegmentedArgSort(); }
|
||||
|
||||
TEST(Algorithms, ArgSort) {
|
||||
TEST(Algorithm, GpuArgSort) {
|
||||
Context ctx;
|
||||
ctx.gpu_id = 0;
|
||||
|
||||
@ -80,7 +80,7 @@ TEST(Algorithms, ArgSort) {
|
||||
thrust::is_sorted(sorted_idx.begin() + 10, sorted_idx.end(), thrust::greater<size_t>{}));
|
||||
}
|
||||
|
||||
TEST(Algorithms, SegmentedSequence) {
|
||||
TEST(Algorithm, SegmentedSequence) {
|
||||
dh::device_vector<std::size_t> idx(16);
|
||||
dh::device_vector<std::size_t> ptr(3);
|
||||
Context ctx = CreateEmptyGenericParam(0);
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/span.h>
|
||||
#include "../../../src/common/common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(ArgSort, Basic) {
|
||||
std::vector<float> inputs {3.0, 2.0, 1.0};
|
||||
auto ret = ArgSort<bst_feature_t>(Span<float>{inputs});
|
||||
std::vector<bst_feature_t> sol{2, 1, 0};
|
||||
ASSERT_EQ(ret, sol);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -2,16 +2,18 @@
|
||||
#include "../../../src/common/random.h"
|
||||
#include "../helpers.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(ColumnSampler, Test) {
|
||||
Context ctx;
|
||||
int n = 128;
|
||||
ColumnSampler cs;
|
||||
std::vector<float> feature_weights;
|
||||
|
||||
// No node sampling
|
||||
cs.Init(n, feature_weights, 1.0f, 0.5f, 0.5f);
|
||||
cs.Init(&ctx, n, feature_weights, 1.0f, 0.5f, 0.5f);
|
||||
auto set0 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set0->Size(), 32);
|
||||
|
||||
@ -24,7 +26,7 @@ TEST(ColumnSampler, Test) {
|
||||
ASSERT_EQ(set2->Size(), 32);
|
||||
|
||||
// Node sampling
|
||||
cs.Init(n, feature_weights, 0.5f, 1.0f, 0.5f);
|
||||
cs.Init(&ctx, n, feature_weights, 0.5f, 1.0f, 0.5f);
|
||||
auto set3 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set3->Size(), 32);
|
||||
|
||||
@ -34,24 +36,25 @@ TEST(ColumnSampler, Test) {
|
||||
ASSERT_EQ(set4->Size(), 32);
|
||||
|
||||
// No level or node sampling, should be the same at different depth
|
||||
cs.Init(n, feature_weights, 1.0f, 1.0f, 0.5f);
|
||||
cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 0.5f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->HostVector(),
|
||||
cs.GetFeatureSet(1)->HostVector());
|
||||
|
||||
cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
|
||||
cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
|
||||
auto set5 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set5->Size(), n);
|
||||
cs.Init(n, feature_weights, 1.0f, 1.0f, 1.0f);
|
||||
cs.Init(&ctx, n, feature_weights, 1.0f, 1.0f, 1.0f);
|
||||
auto set6 = cs.GetFeatureSet(0);
|
||||
ASSERT_EQ(set5->HostVector(), set6->HostVector());
|
||||
|
||||
// Should always be a minimum of one feature
|
||||
cs.Init(n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
|
||||
cs.Init(&ctx, n, feature_weights, 1e-16f, 1e-16f, 1e-16f);
|
||||
ASSERT_EQ(cs.GetFeatureSet(0)->Size(), 1);
|
||||
}
|
||||
|
||||
// Test if different threads using the same seed produce the same result
|
||||
TEST(ColumnSampler, ThreadSynchronisation) {
|
||||
Context ctx;
|
||||
const int64_t num_threads = 100;
|
||||
int n = 128;
|
||||
size_t iterations = 10;
|
||||
@ -63,7 +66,7 @@ TEST(ColumnSampler, ThreadSynchronisation) {
|
||||
{
|
||||
for (auto j = 0ull; j < iterations; j++) {
|
||||
ColumnSampler cs(j);
|
||||
cs.Init(n, feature_weights, 0.5f, 0.5f, 0.5f);
|
||||
cs.Init(&ctx, n, feature_weights, 0.5f, 0.5f, 0.5f);
|
||||
for (auto level = 0ull; level < levels; level++) {
|
||||
auto result = cs.GetFeatureSet(level)->ConstHostVector();
|
||||
#pragma omp single
|
||||
@ -80,11 +83,12 @@ TEST(ColumnSampler, ThreadSynchronisation) {
|
||||
|
||||
TEST(ColumnSampler, WeightedSampling) {
|
||||
auto test_basic = [](int first) {
|
||||
Context ctx;
|
||||
std::vector<float> feature_weights(2);
|
||||
feature_weights[0] = std::abs(first - 1.0f);
|
||||
feature_weights[1] = first - 0.0f;
|
||||
ColumnSampler cs{0};
|
||||
cs.Init(2, feature_weights, 1.0, 1.0, 0.5);
|
||||
cs.Init(&ctx, 2, feature_weights, 1.0, 1.0, 0.5);
|
||||
auto feature_sets = cs.GetFeatureSet(0);
|
||||
auto const &h_feat_set = feature_sets->HostVector();
|
||||
ASSERT_EQ(h_feat_set.size(), 1);
|
||||
@ -100,7 +104,8 @@ TEST(ColumnSampler, WeightedSampling) {
|
||||
SimpleRealUniformDistribution<float> dist(.0f, 12.0f);
|
||||
std::generate(feature_weights.begin(), feature_weights.end(), [&]() { return dist(&rng); });
|
||||
ColumnSampler cs{0};
|
||||
cs.Init(kCols, feature_weights, 0.5f, 1.0f, 1.0f);
|
||||
Context ctx;
|
||||
cs.Init(&ctx, kCols, feature_weights, 0.5f, 1.0f, 1.0f);
|
||||
std::vector<bst_feature_t> features(kCols);
|
||||
std::iota(features.begin(), features.end(), 0);
|
||||
std::vector<float> freq(kCols, 0);
|
||||
@ -135,7 +140,8 @@ TEST(ColumnSampler, WeightedMultiSampling) {
|
||||
}
|
||||
ColumnSampler cs{0};
|
||||
float bytree{0.5}, bylevel{0.5}, bynode{0.5};
|
||||
cs.Init(feature_weights.size(), feature_weights, bytree, bylevel, bynode);
|
||||
Context ctx;
|
||||
cs.Init(&ctx, feature_weights.size(), feature_weights, bytree, bylevel, bynode);
|
||||
auto feature_set = cs.GetFeatureSet(0);
|
||||
size_t n_sampled = kCols * bytree * bylevel * bynode;
|
||||
ASSERT_EQ(feature_set->Size(), n_sampled);
|
||||
|
||||
@ -9,12 +9,14 @@
|
||||
#include "../../../../src/tree/hist/evaluate_splits.h"
|
||||
#include "../test_evaluate_splits.h"
|
||||
#include "../../helpers.h"
|
||||
#include "xgboost/context.h" // Context
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
void TestEvaluateSplits(bool force_read_by_column) {
|
||||
Context ctx;
|
||||
ctx.nthread = 4;
|
||||
int static constexpr kRows = 8, kCols = 16;
|
||||
int32_t n_threads = std::min(omp_get_max_threads(), 4);
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
|
||||
TrainParam param;
|
||||
@ -22,7 +24,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
|
||||
|
||||
auto evaluator = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
|
||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
|
||||
common::HistCollection hist;
|
||||
std::vector<GradientPair> row_gpairs = {
|
||||
{1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
||||
@ -86,13 +88,15 @@ TEST(HistEvaluator, Evaluate) {
|
||||
}
|
||||
|
||||
TEST(HistEvaluator, Apply) {
|
||||
Context ctx;
|
||||
ctx.nthread = 4;
|
||||
RegTree tree;
|
||||
int static constexpr kNRows = 8, kNCols = 16;
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}});
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
auto evaluator_ = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), 4, sampler};
|
||||
auto evaluator_ = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
|
||||
|
||||
CPUExpandEntry entry{0, 0, 10.0f};
|
||||
entry.split.left_sum = GradStats{0.4, 0.6f};
|
||||
@ -115,10 +119,11 @@ TEST(HistEvaluator, Apply) {
|
||||
}
|
||||
|
||||
TEST_F(TestPartitionBasedSplit, CPUHist) {
|
||||
Context ctx;
|
||||
// check the evaluator is returning the optimal split
|
||||
std::vector<FeatureType> ft{FeatureType::kCategorical};
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
HistEvaluator<CPUExpandEntry> evaluator{param_, info_, AllThreadsForTest(), sampler};
|
||||
HistEvaluator<CPUExpandEntry> evaluator{&ctx, param_, info_, sampler};
|
||||
evaluator.InitRoot(GradStats{total_gpair_});
|
||||
RegTree tree;
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
@ -128,6 +133,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
|
||||
|
||||
namespace {
|
||||
auto CompareOneHotAndPartition(bool onehot) {
|
||||
Context ctx;
|
||||
int static constexpr kRows = 128, kCols = 1;
|
||||
std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
|
||||
|
||||
@ -147,8 +153,7 @@ auto CompareOneHotAndPartition(bool onehot) {
|
||||
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
|
||||
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
auto evaluator =
|
||||
HistEvaluator<CPUExpandEntry>{param, dmat->Info(), AllThreadsForTest(), sampler};
|
||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param, dmat->Info(), sampler};
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
|
||||
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
|
||||
@ -198,8 +203,8 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
||||
MetaInfo info;
|
||||
info.num_col_ = 1;
|
||||
info.feature_types = {FeatureType::kCategorical};
|
||||
auto evaluator =
|
||||
HistEvaluator<CPUExpandEntry>{param_, info, AllThreadsForTest(), sampler};
|
||||
Context ctx;
|
||||
auto evaluator = HistEvaluator<CPUExpandEntry>{&ctx, param_, info, sampler};
|
||||
evaluator.InitRoot(GradStats{parent_sum_});
|
||||
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user