Use lambda function in ParallelFor2D. (#9441)

This commit is contained in:
Jiaming Yuan 2023-08-08 14:04:46 +08:00 committed by GitHub
parent 54029a59af
commit 97fd5207dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,7 +13,7 @@
#include <cstdlib> // for malloc, free #include <cstdlib> // for malloc, free
#include <functional> // for function #include <functional> // for function
#include <new> // for bad_alloc #include <new> // for bad_alloc
#include <type_traits> // for is_signed, conditional_t #include <type_traits> // for is_signed, conditional_t, is_integral_v, invoke_result_t
#include <vector> // for vector #include <vector> // for vector
#include "xgboost/logging.h" #include "xgboost/logging.h"
@ -87,8 +87,9 @@ class BlockedSpace2d {
// dim1 - size of the first dimension in the space // dim1 - size of the first dimension in the space
// getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index // getter_size_dim2 - functor to get the second dimensions for each 'row' by row-index
// grain_size - max size of produced blocks // grain_size - max size of produced blocks
BlockedSpace2d(std::size_t dim1, std::function<std::size_t(std::size_t)> getter_size_dim2, template <typename Getter>
std::size_t grain_size) { BlockedSpace2d(std::size_t dim1, Getter&& getter_size_dim2, std::size_t grain_size) {
static_assert(std::is_integral_v<std::invoke_result_t<Getter, std::size_t>>);
for (std::size_t i = 0; i < dim1; ++i) { for (std::size_t i = 0; i < dim1; ++i) {
std::size_t size = getter_size_dim2(i); std::size_t size = getter_size_dim2(i);
// Each row (second dim) is divided into n_blocks // Each row (second dim) is divided into n_blocks
@ -137,8 +138,9 @@ class BlockedSpace2d {
// Wrapper to implement nested parallelism with simple omp parallel for // Wrapper to implement nested parallelism with simple omp parallel for
inline void ParallelFor2d(BlockedSpace2d const& space, std::int32_t n_threads, template <typename Func>
std::function<void(std::size_t, Range1d)> func) { void ParallelFor2d(const BlockedSpace2d& space, int n_threads, Func&& func) {
static_assert(std::is_void_v<std::invoke_result_t<Func, std::size_t, Range1d>>);
std::size_t n_blocks_in_space = space.Size(); std::size_t n_blocks_in_space = space.Size();
CHECK_GE(n_threads, 1); CHECK_GE(n_threads, 1);