Specify the number of threads for parallel sort. (#8735)

* Specify the number of threads for parallel sort.

- Pass context object into argsort.
- Replace macros with inline functions.
This commit is contained in:
Jiaming Yuan
2023-02-16 00:20:19 +08:00
committed by GitHub
parent c7c485d052
commit 282b1729da
24 changed files with 254 additions and 143 deletions

View File

@@ -10,12 +10,13 @@
#include <cstring>
#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // StableSort
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
#include "../common/math.h"
#include "../common/numeric.h"
#include "../common/numeric.h" // Iota
#include "../common/threading_utils.h"
#include "../common/version.h"
#include "../data/adapter.h"
@@ -258,6 +259,19 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
}
}
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
if (label_order_cache_.size() == labels.Size()) {
return label_order_cache_;
}
label_order_cache_.resize(labels.Size());
common::Iota(ctx, label_order_cache_.begin(), label_order_cache_.end(), 0);
const auto& l = labels.Data()->HostVector();
common::StableSort(ctx, label_order_cache_.begin(), label_order_cache_.end(),
[&l](size_t i1, size_t i2) { return std::abs(l[i1]) < std::abs(l[i2]); });
return label_order_cache_;
}
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
auto version = Version::Load(fi);
auto major = std::get<0>(version);