Support hessian in host sketch container. (#7081)

Prepare for migrating approx onto hist's codebase.
This commit is contained in:
Jiaming Yuan
2021-07-08 16:33:58 +08:00
committed by GitHub
parent 84d359efb8
commit 77f6cf2d13
8 changed files with 238 additions and 53 deletions

View File

@@ -9,6 +9,7 @@
#include <dmlc/common.h>
#include <vector>
#include <algorithm>
#include <type_traits> // std::is_signed
#include "xgboost/logging.h"
namespace xgboost {
@@ -133,19 +134,92 @@ void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
exc.Rethrow();
}
/**
 * \brief OpenMP scheduling policy handed to ParallelFor.
 *
 * `sched` names the policy and `chunk` carries an optional chunk size, where 0 means
 * "no explicit chunk".  The static helpers below build the common combinations;
 * Auto() and Guided() always leave `chunk` at its default of 0.
 */
struct Sched {
  enum {
    kAuto,
    kDynamic,
    kStatic,
    kGuided,
  } sched;
  // Chunk size for the dynamic and static policies; 0 leaves it unspecified.
  size_t chunk{0};

  /** \brief Policy kAuto with no chunk size. */
  static Sched Auto() { return {kAuto}; }
  /** \brief Policy kDynamic with an optional chunk size `n`. */
  static Sched Dyn(size_t n = 0) { return {kDynamic, n}; }
  /** \brief Policy kStatic with an optional chunk size `n`. */
  static Sched Static(size_t n = 0) { return {kStatic, n}; }
  /** \brief Policy kGuided with no chunk size. */
  static Sched Guided() { return {kGuided}; }
};
/**
 * \brief Run `fn` for every index in [0, size) inside an OpenMP `parallel for`.
 *
 * \param size      Number of iterations; cast to OmpInd for use as the loop bound.
 * \param n_threads Value passed to the `num_threads` clause.
 * \param sched     Selects the `schedule` clause: kAuto emits none, kDynamic and
 *                  kStatic add a chunk size only when `sched.chunk` is non-zero, and
 *                  kGuided never passes a chunk.
 * \param fn        Callable invoked through `exc.Run(fn, i)` for each index.
 *
 * Every iteration goes through dmlc::OMPException.  Per dmlc's documentation Run()
 * stores an exception thrown inside the parallel region and Rethrow() raises it again
 * once the loop has finished.
 *
 * NOTE(review): the `nthreads` signature and the `schedule(static)` loop that opens
 * right after `exc` is declared look like the pre-change side of this diff rendered
 * next to the new code -- confirm they are not meant to be compiled together with the
 * `Sched`-based body.
 */
template <typename Index, typename Func>
void ParallelFor(Index size, size_t nthreads, Func fn) {
void ParallelFor(Index size, size_t n_threads, Sched sched, Func fn) {
#if defined(_MSC_VER)
// msvc doesn't support unsigned integer as openmp index.
// Signed index types are kept as they are; anything else is replaced by omp_ulong.
using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
#else
using OmpInd = Index;
#endif
// Convert the bound once so the loop variable and its limit share the same type.
OmpInd length = static_cast<OmpInd>(size);
dmlc::OMPException exc;
#pragma omp parallel for num_threads(nthreads) schedule(static)
for (Index i = 0; i < size; ++i) {
exc.Run(fn, i);
// The schedule kind is spelled literally in each pragma, so every policy (and every
// with/without-chunk variant) gets its own copy of the same loop.
switch (sched.sched) {
// No schedule clause at all: the OpenMP implementation's default applies.
case Sched::kAuto: {
#pragma omp parallel for num_threads(n_threads)
for (OmpInd i = 0; i < length; ++i) {
exc.Run(fn, i);
}
break;
}
// Dynamic schedule; the chunk size is forwarded only when one was requested.
case Sched::kDynamic: {
if (sched.chunk == 0) {
#pragma omp parallel for num_threads(n_threads) schedule(dynamic)
for (OmpInd i = 0; i < length; ++i) {
exc.Run(fn, i);
}
} else {
#pragma omp parallel for num_threads(n_threads) schedule(dynamic, sched.chunk)
for (OmpInd i = 0; i < length; ++i) {
exc.Run(fn, i);
}
}
break;
}
// Static schedule; the chunk size is forwarded only when one was requested.
case Sched::kStatic: {
if (sched.chunk == 0) {
#pragma omp parallel for num_threads(n_threads) schedule(static)
for (OmpInd i = 0; i < length; ++i) {
exc.Run(fn, i);
}
} else {
#pragma omp parallel for num_threads(n_threads) schedule(static, sched.chunk)
for (OmpInd i = 0; i < length; ++i) {
exc.Run(fn, i);
}
}
break;
}
// Guided schedule; `sched.chunk` is not consulted here.
case Sched::kGuided: {
#pragma omp parallel for num_threads(n_threads) schedule(guided)
for (OmpInd i = 0; i < length; ++i) {
exc.Run(fn, i);
}
break;
}
}
// Surface any exception recorded by exc.Run() now that the parallel loop is done.
exc.Rethrow();
}
template <typename Index, typename Func>
void ParallelFor(Index size, size_t n_threads, Func fn) {
ParallelFor(size, n_threads, Sched::Static(), fn);
}
// FIXME(jiamingy): Remove this function to get rid of `omp_set_num_threads`, which sets a
// global variable in runtime and affects other programs in the same process.
//
// Convenience overload: forwards to ParallelFor with omp_get_max_threads() as the thread
// count, so the caller does not pass one explicitly.
// NOTE(review): the two calls below look like the before/after sides of this change
// (implicit default schedule vs. explicit Sched::Static()).  As written both would
// execute and `fn` would be run over the range twice -- confirm that only one of them
// is meant to remain.
template <typename Index, typename Func>
void ParallelFor(Index size, Func fn) {
ParallelFor(size, omp_get_max_threads(), fn);
ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn);
}
/* \brief Configure parallel threads.
@@ -174,6 +248,12 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
return nthread_original;
}
/**
 * \brief Resolve a requested thread count.
 *
 * \param n_threads Requested number of threads; zero or a negative value means
 *                  "not specified".
 * \return `n_threads` when it is positive, otherwise the value reported by
 *         omp_get_num_procs().
 */
inline int32_t OmpGetNumThreads(int32_t n_threads) {
  bool const is_specified = n_threads > 0;
  if (is_specified) {
    return n_threads;
  }
  n_threads = omp_get_num_procs();
  return n_threads;
}
} // namespace common
} // namespace xgboost