From 3760cede0ff14f53b32dbc3f4f1e56716cfd3a9b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 30 Jan 2023 15:25:31 +0800 Subject: [PATCH] Consistent use of context to specify number of threads. (#8733) - Use context in all tests. - Use context in R. - Use context in C API DMatrix initialization. (0 threads is used as dft). --- R-package/src/xgboost_R.cc | 11 +++-- src/c_api/c_api.cc | 18 ++++---- src/c_api/c_api.cu | 10 ++--- src/common/threading_utils.cc | 27 +++++++----- src/common/threading_utils.h | 33 ++++++++------ src/context.cc | 3 -- src/data/data.cc | 8 ++-- tests/cpp/c_api/test_c_api.cc | 48 ++++++++++++++++++--- tests/cpp/common/test_column_matrix.cc | 17 ++++---- tests/cpp/common/test_hist_util.cc | 44 ++++++++----------- tests/cpp/common/test_hist_util.cu | 8 ++-- tests/cpp/common/test_quantile.cc | 10 ++--- tests/cpp/common/test_transform_range.cc | 10 ++--- tests/cpp/data/test_data.cc | 10 ++--- tests/cpp/data/test_gradient_index.cc | 8 ++-- tests/cpp/data/test_iterative_dmatrix.cc | 11 +++-- tests/cpp/data/test_simple_dmatrix.cc | 11 +++++ tests/cpp/data/test_sparse_page_dmatrix.cc | 24 +++++------ tests/cpp/helpers.cc | 6 +-- tests/cpp/helpers.h | 4 +- tests/cpp/tree/hist/test_evaluate_splits.cc | 16 +++---- tests/cpp/tree/hist/test_histogram.cc | 15 ++++--- tests/cpp/tree/test_prediction_cache.cc | 5 +-- tests/cpp/tree/test_regen.cc | 7 +-- 24 files changed, 212 insertions(+), 152 deletions(-) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index a52108e8e..990274100 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2022 by XGBoost Contributors + * Copyright 2014-2023 by XGBoost Contributors */ #include #include @@ -115,7 +115,9 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) { din = REAL(mat); } std::vector data(nrow * ncol); - int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads)); + xgboost::Context ctx; + ctx.nthread = asInteger(n_threads); + std::int32_t threads = ctx.Threads(); xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) { for (size_t j = 0; j < ncol; ++j) { @@ -149,8 +151,9 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, for (size_t i = 0; i < nindptr; ++i) { col_ptr_[i] = static_cast(p_indptr[i]); } - int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads)); - xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) { + xgboost::Context ctx; + ctx.nthread = asInteger(n_threads); + xgboost::common::ParallelFor(ndata, ctx.Threads(), [&](xgboost::omp_ulong i) { indices_[i] = static_cast(p_indices[i]); data_[i] = static_cast(p_data[i]); }); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index cb85590eb..d388f1506 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1,4 +1,6 @@ -// Copyright (c) 2014-2022 by Contributors +/** + * Copyright 2014-2023 by XGBoost Contributors + */ #include #include @@ -279,7 +281,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy auto jconfig = Json::Load(StringView{config}); auto missing = GetMissing(jconfig); std::string cache = RequiredArg(jconfig, "cache_prefix", __func__); - auto n_threads = OptionalArg(jconfig, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(jconfig, "nthread", 0); xgboost_CHECK_C_ARG_PTR(next); xgboost_CHECK_C_ARG_PTR(reset); @@ -319,7 +321,7 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); auto missing = GetMissing(jconfig); - auto n_threads = OptionalArg(jconfig, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(jconfig, "nthread", 0); auto max_bin = OptionalArg(jconfig, "max_bin", 256); xgboost_CHECK_C_ARG_PTR(next); @@ -420,7 +422,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char xgboost_CHECK_C_ARG_PTR(c_json_config); auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); - auto n_threads = OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(config, "nthread", 0); xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); API_END(); @@ -435,10 +437,9 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, xgboost_CHECK_C_ARG_PTR(c_json_config); auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); - auto n_threads = OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(config, "nthread", 0); xgboost_CHECK_C_ARG_PTR(out); - *out = - new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); + *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); API_END(); } @@ -506,8 +507,7 @@ XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char c auto jconfig = Json::Load(StringView{config}); auto missing = GetMissing(jconfig); auto n_batches = RequiredArg(jconfig, "nbatch", __func__); - auto n_threads = - OptionalArg(jconfig, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(jconfig, "nthread", 0); data::RecordBatchesIterAdapter adapter(next, n_batches); xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 2af36f0ac..b75c70adc 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -1,4 +1,6 @@ -// Copyright (c) 2019-2022 by Contributors +/** + * Copyright 2019-2023 by XGBoost Contributors + */ #include "../common/threading_utils.h" #include "../data/device_adapter.cuh" #include "../data/proxy_dmatrix.h" @@ -68,8 +70,7 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); - auto n_threads = - OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(config, "nthread", 0); data::CudfAdapter adapter(json_str); *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); @@ -83,8 +84,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data, std::string json_str{data}; auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); - auto n_threads = - OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + auto n_threads = OptionalArg(config, "nthread", 0); data::CupyAdapter adapter(json_str); *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); diff --git a/src/common/threading_utils.cc b/src/common/threading_utils.cc index bcff45efb..349cc0ba7 100644 --- a/src/common/threading_utils.cc +++ b/src/common/threading_utils.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2022 by XGBoost Contributors +/** + * Copyright 2022-2023 by XGBoost Contributors */ #include "threading_utils.h" @@ -10,14 +10,6 @@ namespace xgboost { namespace common { -/** - * \brief Get thread limit from CFS - * - * Modified from - * github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp - * - * MIT License: Copyright (c) 2016 Domagoj Šarić - */ int32_t GetCfsCPUCount() noexcept { #if defined(__linux__) // https://bugs.openjdk.java.net/browse/JDK-8146115 @@ -47,5 +39,20 @@ int32_t GetCfsCPUCount() noexcept { #endif // defined(__linux__) return -1; } + +std::int32_t OmpGetNumThreads(std::int32_t n_threads) { + // Don't use parallel if we are in a parallel region. + if (omp_in_parallel()) { + return 1; + } + // If -1 or 0 is specified by the user, we default to maximum number of threads. + if (n_threads <= 0) { + n_threads = std::min(omp_get_num_procs(), omp_get_max_threads()); + } + // Honor the openmp thread limit, which can be set via environment variable. + n_threads = std::min(n_threads, OmpGetThreadLimit()); + n_threads = std::max(n_threads, 1); + return n_threads; +} } // namespace common } // namespace xgboost diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index 656e570ae..a52695e02 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2019-2022 by XGBoost Contributors +/** + * Copyright 2019-2023 by XGBoost Contributors */ #ifndef XGBOOST_COMMON_THREADING_UTILS_H_ #define XGBOOST_COMMON_THREADING_UTILS_H_ @@ -231,23 +231,28 @@ void ParallelFor(Index size, int32_t n_threads, Func fn) { ParallelFor(size, n_threads, Sched::Static(), fn); } -inline int32_t OmpGetThreadLimit() { - int32_t limit = omp_get_thread_limit(); +inline std::int32_t OmpGetThreadLimit() { + std::int32_t limit = omp_get_thread_limit(); CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP."; return limit; } -int32_t GetCfsCPUCount() noexcept; - -inline int32_t OmpGetNumThreads(int32_t n_threads) { - if (n_threads <= 0) { - n_threads = std::min(omp_get_num_procs(), omp_get_max_threads()); - } - n_threads = std::min(n_threads, OmpGetThreadLimit()); - n_threads = std::max(n_threads, 1); - return n_threads; -} +/** + * \brief Get thread limit from CFS. + * + * This function has non-trivial overhead and should not be called repeatly. + * + * Modified from + * github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp + * + * MIT License: Copyright (c) 2016 Domagoj Šarić + */ +std::int32_t GetCfsCPUCount() noexcept; +/** + * \brief Get the number of available threads based on n_threads specified by users. + */ +std::int32_t OmpGetNumThreads(std::int32_t n_threads); /*! * \brief A C-style array with in-stack allocation. As long as the array is smaller than diff --git a/src/context.cc b/src/context.cc index 2628c9d95..28fda9c45 100644 --- a/src/context.cc +++ b/src/context.cc @@ -53,9 +53,6 @@ void Context::ConfigureGpuId(bool require_gpu) { } std::int32_t Context::Threads() const { - if (omp_in_parallel()) { - return 1; - } auto n_threads = common::OmpGetNumThreads(nthread); if (cfs_cpu_count_ > 0) { n_threads = std::min(n_threads, cfs_cpu_count_); diff --git a/src/data/data.cc b/src/data/data.cc index dd1f51717..91052f274 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2015-2022 by XGBoost Contributors +/** + * Copyright 2015-2023 by XGBoost Contributors * \file data.cc */ #include "xgboost/data.h" @@ -27,6 +27,7 @@ #include "sparse_page_writer.h" #include "validation.h" #include "xgboost/c_api.h" +#include "xgboost/context.h" #include "xgboost/host_device_vector.h" #include "xgboost/learner.h" #include "xgboost/logging.h" @@ -850,7 +851,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s std::unique_ptr> parser( dmlc::Parser::Create(fname.c_str(), partid, npart, file_format.c_str())); data::FileAdapter adapter(parser.get()); - dmat = DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1, cache_file); + dmat = DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), Context{}.Threads(), + cache_file); } else { data::FileIterator iter{fname, static_cast(partid), static_cast(npart), file_format}; diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 046f8a317..6b5bc7cb8 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -1,16 +1,21 @@ -/*! - * Copyright 2019-2022 XGBoost contributors +/** + * Copyright 2019-2023 XGBoost contributors */ #include -#include #include #include +#include // Json #include +#include -#include "../helpers.h" -#include "../../../src/common/io.h" +#include // std::size_t +#include // std::numeric_limits +#include // std::string +#include #include "../../../src/c_api/c_api_error.h" +#include "../../../src/common/io.h" +#include "../helpers.h" TEST(CAPI, XGDMatrixCreateFromMatDT) { std::vector col0 = {0, -1, 3}; @@ -83,6 +88,39 @@ TEST(CAPI, Version) { ASSERT_EQ(patch, XGBOOST_VER_PATCH); } +TEST(CAPI, XGDMatrixCreateFromCSR) { + HostDeviceVector indptr{0, 3}; + HostDeviceVector data{0.0, 1.0, 2.0}; + HostDeviceVector indices{0, 1, 2}; + auto indptr_arr = GetArrayInterface(&indptr, 2, 1); + auto indices_arr = GetArrayInterface(&indices, 3, 1); + auto data_arr = GetArrayInterface(&data, 3, 1); + std::string sindptr, sindices, sdata, sconfig; + Json::Dump(indptr_arr, &sindptr); + Json::Dump(indices_arr, &sindices); + Json::Dump(data_arr, &sdata); + Json config{Object{}}; + config["missing"] = Number{std::numeric_limits::quiet_NaN()}; + Json::Dump(config, &sconfig); + + DMatrixHandle handle; + XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), 3, sconfig.c_str(), + &handle); + bst_ulong n; + ASSERT_EQ(XGDMatrixNumRow(handle, &n), 0); + ASSERT_EQ(n, 1); + ASSERT_EQ(XGDMatrixNumCol(handle, &n), 0); + ASSERT_EQ(n, 3); + ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0); + ASSERT_EQ(n, 3); + + std::shared_ptr *pp_fmat = + static_cast *>(handle); + ASSERT_EQ((*pp_fmat)->Ctx()->Threads(), AllThreadsForTest()); + + XGDMatrixFree(handle); +} + TEST(CAPI, ConfigIO) { size_t constexpr kRows = 10; auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix(); diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc index fca5c0c4e..de7b9a258 100644 --- a/tests/cpp/common/test_column_matrix.cc +++ b/tests/cpp/common/test_column_matrix.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2018-2022 by XGBoost Contributors +/** + * Copyright 2018-2023 by XGBoost Contributors */ #include @@ -18,11 +18,10 @@ TEST(DenseColumn, Test) { for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix(); auto sparse_thresh = 0.2; - GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, - common::OmpGetNumThreads(0)}; + GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, AllThreadsForTest()}; ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { - column_matrix.InitFromSparse(page, gmat, sparse_thresh, common::OmpGetNumThreads(0)); + column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest()); } ASSERT_GE(column_matrix.GetTypeSize(), last); ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize); @@ -65,10 +64,10 @@ TEST(SparseColumn, Test) { static_cast(std::numeric_limits::max()) + 2}; for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); - GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)}; + GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, AllThreadsForTest()}; ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { - column_matrix.InitFromSparse(page, gmat, 1.0, common::OmpGetNumThreads(0)); + column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest()); } common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { using T = decltype(dtype); @@ -93,10 +92,10 @@ TEST(DenseColumnWithMissing, Test) { static_cast(std::numeric_limits::max()) + 2}; for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); - GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, AllThreadsForTest()); ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { - column_matrix.InitFromSparse(page, gmat, 0.2, common::OmpGetNumThreads(0)); + column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest()); } ASSERT_TRUE(column_matrix.AnyMissing()); DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index a095f6972..41c728f35 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2019-2022 by XGBoost Contributors +/** + * Copyright 2019-2023 by XGBoost Contributors */ #include #include @@ -14,15 +14,13 @@ namespace xgboost { namespace common { -size_t GetNThreads() { return common::OmpGetNumThreads(0); } - void ParallelGHistBuilderReset() { constexpr size_t kBins = 10; constexpr size_t kNodes = 5; constexpr size_t kNodesExtended = 10; constexpr size_t kTasksPerNode = 10; constexpr double kValue = 1.0; - const size_t nthreads = GetNThreads(); + const size_t nthreads = AllThreadsForTest(); HistCollection collection; collection.Init(kBins); @@ -78,7 +76,7 @@ void ParallelGHistBuilderReduceHist(){ constexpr size_t kNodes = 5; constexpr size_t kTasksPerNode = 10; constexpr double kValue = 1.0; - const size_t nthreads = GetNThreads(); + const size_t nthreads = AllThreadsForTest(); HistCollection collection; collection.Init(kBins); @@ -167,7 +165,7 @@ TEST(HistUtil, DenseCutsCategorical) { std::vector x_sorted(x); std::sort(x_sorted.begin(), x_sorted.end()); auto dmat = GetDMatrixFromData(x, n, 1); - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest()); auto cuts_from_sketch = cuts.Values(); EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); @@ -180,13 +178,12 @@ TEST(HistUtil, DenseCutsCategorical) { TEST(HistUtil, DenseCutsAccuracyTest) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100}; - // omp_set_num_threads(1); int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); for (auto num_bins : bin_sizes) { - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest()); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -203,13 +200,11 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) { dmat->Info().weights_.HostVector() = w; for (auto num_bins : bin_sizes) { { - HistogramCuts cuts = - SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), true); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), true); ValidateCuts(cuts, dmat.get(), num_bins); } { - HistogramCuts cuts = - SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), false); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), false); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -231,14 +226,14 @@ void TestQuantileWithHessian(bool use_sorted) { for (auto num_bins : bin_sizes) { HistogramCuts cuts_hess = - SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted, hessian); + SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted, hessian); for (size_t i = 0; i < w.size(); ++i) { dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i]; } ValidateCuts(cuts_hess, dmat.get(), num_bins); HistogramCuts cuts_wh = - SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted); + SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted); ValidateCuts(cuts_wh, dmat.get(), num_bins); ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size()); @@ -265,7 +260,7 @@ TEST(HistUtil, DenseCutsExternalMemory) { dmlc::TemporaryDirectory tmpdir; auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, tmpdir); for (auto num_bins : bin_sizes) { - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest()); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -285,7 +280,7 @@ TEST(HistUtil, IndexBinBound) { for (auto max_bin : bin_sizes) { auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest()); EXPECT_EQ(hmat.index.Size(), kRows*kCols); EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize()); } @@ -308,7 +303,7 @@ TEST(HistUtil, IndexBinData) { for (auto max_bin : kBinSizes) { auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest()); uint32_t const* offsets = hmat.index.Offset(); EXPECT_EQ(hmat.index.Size(), kRows*kCols); switch (max_bin) { @@ -331,9 +326,8 @@ TEST(HistUtil, IndexBinData) { void TestSketchFromWeights(bool with_group) { size_t constexpr kRows = 300, kCols = 20, kBins = 256; size_t constexpr kGroups = 10; - auto m = - RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix(); - common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); + auto m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix(); + common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest()); MetaInfo info; Context ctx; @@ -369,7 +363,7 @@ void TestSketchFromWeights(bool with_group) { if (with_group) { m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight - HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); + HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest()); for (size_t i = 0; i < cuts.Values().size(); ++i) { EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); } @@ -388,7 +382,7 @@ void TestSketchFromWeights(bool with_group) { for (size_t i = 0; i < h_weights.size(); ++i) { h_weights[i] = static_cast(i + 1) / static_cast(kGroups); } - HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0)); + HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest()); ValidateCuts(weighted, m.get(), kBins); } } @@ -400,10 +394,10 @@ TEST(HistUtil, SketchFromWeights) { TEST(HistUtil, SketchCategoricalFeatures) { TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) { - return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0)); + return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest()); }); TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) { - return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0)); + return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest()); }); } } // namespace common diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 7324531b1..c9db7f646 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -1,5 +1,5 @@ -/*! - * Copyright 2019-2022 by XGBoost Contributors +/** + * Copyright 2019-2023 by XGBoost Contributors */ #include #include @@ -27,7 +27,7 @@ namespace common { template HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { data::SimpleDMatrix dmat(adapter, missing, 1); - HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, common::OmpGetNumThreads(0)); + HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, AllThreadsForTest()); return cuts; } @@ -39,7 +39,7 @@ TEST(HistUtil, DeviceSketch) { auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); - HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0)); + HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest()); EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index bd6932aa3..3541b977f 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2020-2022 by XGBoost Contributors +/** + * Copyright 2020-2023 by XGBoost Contributors */ #include "test_quantile.h" @@ -73,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { auto hess = Span{hessian}; ContainerType sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(), - column_size, false, OmpGetNumThreads(0)); + column_size, false, AllThreadsForTest()); if (use_column) { for (auto const& page : m->GetBatches()) { @@ -94,7 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; }); m->Info().num_row_ = world * rows; ContainerType sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(), - column_size, false, OmpGetNumThreads(0)); + column_size, false, AllThreadsForTest()); m->Info().num_row_ = rows; for (auto rank = 0; rank < world; ++rank) { @@ -188,7 +188,7 @@ void TestSameOnAllWorkers() { .MaxCategory(17) .Seed(rank + seed) .GenerateDMatrix(); - auto cuts = SketchOnDMatrix(m.get(), n_bins, common::OmpGetNumThreads(0)); + auto cuts = SketchOnDMatrix(m.get(), n_bins, AllThreadsForTest()); std::vector cut_values(cuts.Values().size() * world, 0); std::vector< typename std::remove_reference_t::value_type> diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc index 97103d8f3..6e3ae9d82 100644 --- a/tests/cpp/common/test_transform_range.cc +++ b/tests/cpp/common/test_transform_range.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2018-2022 by XGBoost Contributors +/** + * Copyright 2018-2023 by XGBoost Contributors */ #include #include @@ -45,7 +45,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) { out_vec.Fill(0); Transform<>::Init(TestTransformRange{}, - Range{0, static_cast(size)}, common::OmpGetNumThreads(0), + Range{0, static_cast(size)}, AllThreadsForTest(), TRANSFORM_GPU) .Eval(&out_vec, &in_vec); std::vector res = out_vec.HostVector(); @@ -61,8 +61,8 @@ TEST(TransformDeathTest, Exception) { EXPECT_DEATH( { Transform<>::Init([](size_t idx, common::Span _in) { _in[idx + 1]; }, - Range(0, static_cast(kSize)), - common::OmpGetNumThreads(0), -1) + Range(0, static_cast(kSize)), AllThreadsForTest(), + -1) .Eval(&in_vec); }, ""); diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 8f80f3171..7b35c6f6f 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2019-2022 by XGBoost Contributors +/** + * Copyright 2019-2023 by XGBoost Contributors */ #include @@ -70,14 +70,14 @@ TEST(SparsePage, PushCSCAfterTranspose) { SparsePage page; // Consolidated sparse page for (const auto &batch : dmat->GetBatches()) { // Transpose each batch and push - SparsePage tmp = batch.GetTranspose(ncols, common::OmpGetNumThreads(0)); + SparsePage tmp = batch.GetTranspose(ncols, AllThreadsForTest()); page.PushCSC(tmp); } // Make sure that the final sparse page has the right number of entries ASSERT_EQ(kEntries, page.data.Size()); - page.SortRows(common::OmpGetNumThreads(0)); + page.SortRows(AllThreadsForTest()); auto v = page.GetView(); for (size_t i = 0; i < v.Size(); ++i) { auto column = v[i]; @@ -89,7 +89,7 @@ TEST(SparsePage, PushCSCAfterTranspose) { TEST(SparsePage, SortIndices) { auto p_fmat = RandomDataGenerator{100, 10, 0.6}.GenerateDMatrix(); - auto n_threads = common::OmpGetNumThreads(0); + auto n_threads = AllThreadsForTest(); SparsePage copy; for (auto const& page : p_fmat->GetBatches()) { ASSERT_TRUE(page.IsIndicesSorted(n_threads)); diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc index 6233f1b25..2bfb756c1 100644 --- a/tests/cpp/data/test_gradient_index.cc +++ b/tests/cpp/data/test_gradient_index.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022 XGBoost contributors +/** + * Copyright 2021-2023 by XGBoost contributors */ #include #include @@ -46,7 +46,7 @@ TEST(GradientIndex, FromCategoricalBasic) { h_ft.resize(kCols, FeatureType::kCategorical); BatchParam p(max_bins, 0.8); - GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {}); + GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, AllThreadsForTest(), {}); auto x_copy = x; std::sort(x_copy.begin(), x_copy.end()); @@ -75,7 +75,7 @@ TEST(GradientIndex, PushBatch) { auto test = [&](float sparisty) { auto m = RandomDataGenerator{kRows, kCols, sparisty}.GenerateDMatrix(true); - auto cuts = common::SketchOnDMatrix(m.get(), max_bins, common::OmpGetNumThreads(0), false, {}); + auto cuts = common::SketchOnDMatrix(m.get(), max_bins, AllThreadsForTest(), false, {}); common::HistogramCuts copy_cuts = cuts; ASSERT_EQ(m->Info().num_row_, kRows); diff --git a/tests/cpp/data/test_iterative_dmatrix.cc b/tests/cpp/data/test_iterative_dmatrix.cc index 3e9372aab..f95f7c03c 100644 --- a/tests/cpp/data/test_iterative_dmatrix.cc +++ b/tests/cpp/data/test_iterative_dmatrix.cc @@ -1,13 +1,16 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023 by XGBoost contributors */ #include "test_iterative_dmatrix.h" #include +#include +#include #include "../../../src/data/gradient_index.h" #include "../../../src/data/iterative_dmatrix.h" #include "../helpers.h" +#include "xgboost/data.h" // DMatrix namespace xgboost { namespace data { @@ -20,8 +23,10 @@ TEST(IterativeDMatrix, IsDense) { int n_bins = 16; auto test = [n_bins](float sparsity) { NumpyArrayIterForTest iter(sparsity); + auto n_threads = 0; IterativeDMatrix m(&iter, iter.Proxy(), nullptr, Reset, Next, - std::numeric_limits::quiet_NaN(), 0, n_bins); + std::numeric_limits::quiet_NaN(), n_threads, n_bins); + ASSERT_EQ(m.Ctx()->Threads(), AllThreadsForTest()); if (sparsity == 0.0) { ASSERT_TRUE(m.IsDense()); } else { diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 9d54751a7..3dbe0a51a 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -411,3 +411,14 @@ TEST(SimpleDMatrix, SaveLoadBinary) { delete dmat; delete dmat_read; } + +TEST(SimpleDMatrix, Threads) { + size_t constexpr kRows{16}; + size_t constexpr kCols{8}; + HostDeviceVector data; + auto arr_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&data); + auto adapter = data::ArrayAdapter{StringView{arr_str}}; + std::unique_ptr p_fmat{ + DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 0, "")}; + ASSERT_EQ(p_fmat->Ctx()->Threads(), AllThreadsForTest()); +} diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 8f740ca9b..8c2ff9514 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -1,4 +1,6 @@ -// Copyright by Contributors +/** + * Copyright 2016-2023 by XGBoost Contributors + */ #include #include @@ -22,13 +24,15 @@ void TestSparseDMatrixLoadFile() { CreateBigTestData(opath, 3 * 64, false); opath += "?indexing_mode=1"; data::FileIterator iter{opath, 0, 1, "libsvm"}; + auto n_threads = 0; data::SparsePageDMatrix m{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, std::numeric_limits::quiet_NaN(), - 1, + n_threads, tmpdir.path + "cache"}; + ASSERT_EQ(AllThreadsForTest(), m.Ctx()->Threads()); ASSERT_EQ(m.Info().num_col_, 5); ASSERT_EQ(m.Info().num_row_, 64); @@ -213,16 +217,13 @@ TEST(SparsePageDMatrix, ColAccessBatches) { size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2; // Create multiple sparse pages std::unique_ptr dmat{xgboost::CreateSparsePageDMatrix(kEntries)}; - auto n_threads = omp_get_max_threads(); - omp_set_num_threads(16); + ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest()); for (auto const &page : dmat->GetBatches()) { ASSERT_EQ(dmat->Info().num_col_, page.Size()); } - omp_set_num_threads(n_threads); } auto TestSparsePageDMatrixDeterminism(int32_t threads) { - omp_set_num_threads(threads); std::vector sparse_data; std::vector sparse_rptr; std::vector sparse_cids; @@ -231,16 +232,15 @@ auto TestSparsePageDMatrixDeterminism(int32_t threads) { CreateBigTestData(filename, 1 << 16); data::FileIterator iter(filename, 0, 1, "auto"); - std::unique_ptr sparse{new data::SparsePageDMatrix{ - &iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, - std::numeric_limits::quiet_NaN(), 1, filename}}; + std::unique_ptr sparse{ + new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, + std::numeric_limits::quiet_NaN(), threads, filename}}; + CHECK(sparse->Ctx()->Threads() == threads || sparse->Ctx()->Threads() == AllThreadsForTest()); DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids); auto cache_name = - data::MakeId(filename, - dynamic_cast(sparse.get())) + - ".row.page"; + data::MakeId(filename, dynamic_cast(sparse.get())) + ".row.page"; std::string cache = common::LoadSequentialFile(cache_name); return cache; } diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 04c3dd3ad..2c3cb5094 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -438,9 +438,9 @@ std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, size_t n_rows = n_entries / n_columns; NumpyArrayIterForTest iter(0, n_rows, n_columns, 2); - std::unique_ptr dmat{DMatrix::Create( - static_cast(&iter), iter.Proxy(), Reset, Next, - std::numeric_limits::quiet_NaN(), omp_get_max_threads(), prefix)}; + std::unique_ptr dmat{ + DMatrix::Create(static_cast(&iter), iter.Proxy(), Reset, Next, + std::numeric_limits::quiet_NaN(), 0, prefix)}; auto row_page_path = data::MakeId(prefix, dynamic_cast(dmat.get())) + diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 80f504d8e..71424af18 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -10,6 +10,7 @@ #include #include +#include // std::int32_t #include #include #include @@ -461,6 +462,8 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint return mparam; } +inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); } + template void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) { std::vector threads; @@ -481,5 +484,4 @@ void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&& thread.join(); } } - } // namespace xgboost diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 7000240df..c45ed5385 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022 by XGBoost Contributors +/** + * Copyright 2021-2023 by XGBoost Contributors */ #include #include @@ -14,9 +14,7 @@ namespace xgboost { namespace tree { void TestEvaluateSplits(bool force_read_by_column) { int static constexpr kRows = 8, kCols = 16; - auto orig = omp_get_max_threads(); int32_t n_threads = std::min(omp_get_max_threads(), 4); - omp_set_num_threads(n_threads); auto sampler = std::make_shared(); TrainParam param; @@ -32,7 +30,7 @@ void TestEvaluateSplits(bool force_read_by_column) { size_t constexpr kMaxBins = 4; // dense, no missing values - GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, common::OmpGetNumThreads(0)); + GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, AllThreadsForTest()); common::RowSetCollection row_set_collection; std::vector &row_indices = *row_set_collection.Data(); row_indices.resize(kRows); @@ -80,8 +78,6 @@ void TestEvaluateSplits(bool force_read_by_column) { right.SetSubstract(GradStats{total_gpair}, left); } } - - omp_set_num_threads(orig); } TEST(HistEvaluator, Evaluate) { @@ -122,7 +118,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) { // check the evaluator is returning the optimal split std::vector ft{FeatureType::kCategorical}; auto sampler = std::make_shared(); - HistEvaluator evaluator{param_, info_, common::OmpGetNumThreads(0), sampler}; + HistEvaluator evaluator{param_, info_, AllThreadsForTest(), sampler}; evaluator.InitRoot(GradStats{total_gpair_}); RegTree tree; std::vector entries(1); @@ -152,7 +148,7 @@ auto CompareOneHotAndPartition(bool onehot) { auto sampler = std::make_shared(); auto evaluator = - HistEvaluator{param, dmat->Info(), common::OmpGetNumThreads(0), sampler}; + HistEvaluator{param, dmat->Info(), AllThreadsForTest(), sampler}; std::vector entries(1); for (auto const &gmat : dmat->GetBatches({32, param.sparse_threshold})) { @@ -203,7 +199,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) { info.num_col_ = 1; info.feature_types = {FeatureType::kCategorical}; auto evaluator = - HistEvaluator{param_, info, common::OmpGetNumThreads(0), sampler}; + HistEvaluator{param_, info, AllThreadsForTest(), sampler}; evaluator.InitRoot(GradStats{parent_sum_}); std::vector entries(1); diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index d7d0f12cc..1e37f1cd4 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -1,7 +1,8 @@ -/*! - * Copyright 2018-2022 by Contributors +/** + * Copyright 2018-2023 by Contributors */ #include +#include // Context #include @@ -375,6 +376,7 @@ TEST(CPUHistogram, Categorical) { } namespace { void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool force_read_by_column) { + Context ctx; size_t constexpr kEntries = 1 << 16; auto m = CreateSparsePageDMatrix(kEntries, "cache"); @@ -417,7 +419,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo 1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, 256}; - multi_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), rows_set.size(), false); + multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false); size_t page_idx{0}; for (auto const &page : m->GetBatches(batch_param)) { @@ -438,17 +440,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo common::RowSetCollection row_set_collection; InitRowPartitionForTest(&row_set_collection, n_samples); - single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false); + single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false); SparsePage concat; std::vector hess(m->Info().num_row_, 1.0f); for (auto const& page : m->GetBatches()) { concat.Push(page); } - auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0), - false, hess); + auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, ctx.Threads(), false, hess); GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false, - std::numeric_limits::quiet_NaN(), common::OmpGetNumThreads(0)); + std::numeric_limits::quiet_NaN(), ctx.Threads()); single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column); single_page = single_build.Histogram()[0]; } diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc index a6677ad02..dc41b3edd 100644 --- a/tests/cpp/tree/test_prediction_cache.cc +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -1,5 +1,5 @@ -/*! - * Copyright 2021-2022 by XGBoost contributors +/** + * Copyright 2021-2023 by XGBoost contributors */ #include #include @@ -62,7 +62,6 @@ class TestPredictionCache : public ::testing::Test { void RunTest(std::string updater_name) { { - omp_set_num_threads(1); Context ctx; ctx.InitAllowUnknown(Args{{"nthread", "8"}}); if (updater_name == "grow_gpu_hist") { diff --git a/tests/cpp/tree/test_regen.cc b/tests/cpp/tree/test_regen.cc index 47a576f45..b766e0775 100644 --- a/tests/cpp/tree/test_regen.cc +++ b/tests/cpp/tree/test_regen.cc @@ -1,11 +1,12 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023 XGBoost contributors */ #include #include "../../../src/data/adapter.h" #include "../../../src/data/simple_dmatrix.h" #include "../helpers.h" +#include "xgboost/context.h" namespace xgboost { namespace { @@ -50,7 +51,7 @@ class RegenTest : public ::testing::Test { auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage); auto adapter = data::ArrayAdapter(StringView{dense}); p_fmat_ = std::shared_ptr(new DMatrixForTest{ - &adapter, std::numeric_limits::quiet_NaN(), common::OmpGetNumThreads(0)}); + &adapter, std::numeric_limits::quiet_NaN(), AllThreadsForTest()}); p_fmat_->Info().labels.Reshape(256, 1); auto labels = p_fmat_->Info().labels.Data();