Specify the number of threads for parallel sort. (#8735)
* Specify the number of threads for parallel sort. - Pass context object into argsort. - Replace macros with inline functions.
This commit is contained in:
@@ -10,12 +10,13 @@
|
||||
#include <cstring>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/algorithm.h" // StableSort
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/numeric.h" // Iota
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/version.h"
|
||||
#include "../data/adapter.h"
|
||||
@@ -258,6 +259,19 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
|
||||
if (label_order_cache_.size() == labels.Size()) {
|
||||
return label_order_cache_;
|
||||
}
|
||||
label_order_cache_.resize(labels.Size());
|
||||
common::Iota(ctx, label_order_cache_.begin(), label_order_cache_.end(), 0);
|
||||
const auto& l = labels.Data()->HostVector();
|
||||
common::StableSort(ctx, label_order_cache_.begin(), label_order_cache_.end(),
|
||||
[&l](size_t i1, size_t i2) { return std::abs(l[i1]) < std::abs(l[i2]); });
|
||||
|
||||
return label_order_cache_;
|
||||
}
|
||||
|
||||
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
auto version = Version::Load(fi);
|
||||
auto major = std::get<0>(version);
|
||||
|
||||
Reference in New Issue
Block a user