Consistent use of context to specify number of threads. (#8733)
- Use context in all tests. - Use context in R. - Use context in C API DMatrix initialization (0 threads is used as the default).
This commit is contained in:
parent
21a28f2cc5
commit
3760cede0f
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2014-2022 by XGBoost Contributors
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <dmlc/common.h>
|
||||
#include <dmlc/omp.h>
|
||||
@ -115,7 +115,9 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
|
||||
din = REAL(mat);
|
||||
}
|
||||
std::vector<float> data(nrow * ncol);
|
||||
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
|
||||
xgboost::Context ctx;
|
||||
ctx.nthread = asInteger(n_threads);
|
||||
std::int32_t threads = ctx.Threads();
|
||||
|
||||
xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
|
||||
for (size_t j = 0; j < ncol; ++j) {
|
||||
@ -149,8 +151,9 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
|
||||
for (size_t i = 0; i < nindptr; ++i) {
|
||||
col_ptr_[i] = static_cast<size_t>(p_indptr[i]);
|
||||
}
|
||||
int32_t threads = xgboost::common::OmpGetNumThreads(asInteger(n_threads));
|
||||
xgboost::common::ParallelFor(ndata, threads, [&](xgboost::omp_ulong i) {
|
||||
xgboost::Context ctx;
|
||||
ctx.nthread = asInteger(n_threads);
|
||||
xgboost::common::ParallelFor(ndata, ctx.Threads(), [&](xgboost::omp_ulong i) {
|
||||
indices_[i] = static_cast<unsigned>(p_indices[i]);
|
||||
data_[i] = static_cast<float>(p_data[i]);
|
||||
});
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
// Copyright (c) 2014-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <rabit/c_api.h>
|
||||
|
||||
#include <cstring>
|
||||
@ -279,7 +281,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
xgboost_CHECK_C_ARG_PTR(reset);
|
||||
@ -319,7 +321,7 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", 0);
|
||||
auto max_bin = OptionalArg<Integer, int64_t>(jconfig, "max_bin", 256);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(next);
|
||||
@ -420,7 +422,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
|
||||
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
API_END();
|
||||
@ -435,10 +437,9 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
|
||||
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -506,8 +507,7 @@ XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char c
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
auto missing = GetMissing(jconfig);
|
||||
auto n_batches = RequiredArg<Integer>(jconfig, "nbatch", __func__);
|
||||
auto n_threads =
|
||||
OptionalArg<Integer, std::int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(jconfig, "nthread", 0);
|
||||
data::RecordBatchesIterAdapter adapter(next, n_batches);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
// Copyright (c) 2019-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
@ -68,8 +70,7 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data,
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads =
|
||||
OptionalArg<Integer, std::int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(config, "nthread", 0);
|
||||
data::CudfAdapter adapter(json_str);
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
@ -83,8 +84,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
|
||||
std::string json_str{data};
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads =
|
||||
OptionalArg<Integer, std::int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto n_threads = OptionalArg<Integer, std::int64_t>(config, "nthread", 0);
|
||||
data::CupyAdapter adapter(json_str);
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "threading_utils.h"
|
||||
|
||||
@ -10,14 +10,6 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
/**
|
||||
* \brief Get thread limit from CFS
|
||||
*
|
||||
* Modified from
|
||||
* github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp
|
||||
*
|
||||
* MIT License: Copyright (c) 2016 Domagoj Šarić
|
||||
*/
|
||||
int32_t GetCfsCPUCount() noexcept {
|
||||
#if defined(__linux__)
|
||||
// https://bugs.openjdk.java.net/browse/JDK-8146115
|
||||
@ -47,5 +39,20 @@ int32_t GetCfsCPUCount() noexcept {
|
||||
#endif // defined(__linux__)
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::int32_t OmpGetNumThreads(std::int32_t n_threads) {
|
||||
// Don't use parallel if we are in a parallel region.
|
||||
if (omp_in_parallel()) {
|
||||
return 1;
|
||||
}
|
||||
// If -1 or 0 is specified by the user, we default to maximum number of threads.
|
||||
if (n_threads <= 0) {
|
||||
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
|
||||
}
|
||||
// Honor the openmp thread limit, which can be set via environment variable.
|
||||
n_threads = std::min(n_threads, OmpGetThreadLimit());
|
||||
n_threads = std::max(n_threads, 1);
|
||||
return n_threads;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_THREADING_UTILS_H_
|
||||
#define XGBOOST_COMMON_THREADING_UTILS_H_
|
||||
@ -231,23 +231,28 @@ void ParallelFor(Index size, int32_t n_threads, Func fn) {
|
||||
ParallelFor(size, n_threads, Sched::Static(), fn);
|
||||
}
|
||||
|
||||
inline int32_t OmpGetThreadLimit() {
|
||||
int32_t limit = omp_get_thread_limit();
|
||||
inline std::int32_t OmpGetThreadLimit() {
|
||||
std::int32_t limit = omp_get_thread_limit();
|
||||
CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
|
||||
return limit;
|
||||
}
|
||||
|
||||
int32_t GetCfsCPUCount() noexcept;
|
||||
|
||||
inline int32_t OmpGetNumThreads(int32_t n_threads) {
|
||||
if (n_threads <= 0) {
|
||||
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());
|
||||
}
|
||||
n_threads = std::min(n_threads, OmpGetThreadLimit());
|
||||
n_threads = std::max(n_threads, 1);
|
||||
return n_threads;
|
||||
}
|
||||
/**
|
||||
* \brief Get thread limit from CFS.
|
||||
*
|
||||
* This function has non-trivial overhead and should not be called repeatedly.
|
||||
*
|
||||
* Modified from
|
||||
* github.com/psiha/sweater/blob/master/include/boost/sweater/hardware_concurrency.hpp
|
||||
*
|
||||
* MIT License: Copyright (c) 2016 Domagoj Šarić
|
||||
*/
|
||||
std::int32_t GetCfsCPUCount() noexcept;
|
||||
|
||||
/**
|
||||
* \brief Get the number of available threads based on n_threads specified by users.
|
||||
*/
|
||||
std::int32_t OmpGetNumThreads(std::int32_t n_threads);
|
||||
|
||||
/*!
|
||||
* \brief A C-style array with in-stack allocation. As long as the array is smaller than
|
||||
|
||||
@ -53,9 +53,6 @@ void Context::ConfigureGpuId(bool require_gpu) {
|
||||
}
|
||||
|
||||
std::int32_t Context::Threads() const {
|
||||
if (omp_in_parallel()) {
|
||||
return 1;
|
||||
}
|
||||
auto n_threads = common::OmpGetNumThreads(nthread);
|
||||
if (cfs_cpu_count_ > 0) {
|
||||
n_threads = std::min(n_threads, cfs_cpu_count_);
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include "xgboost/data.h"
|
||||
@ -27,6 +27,7 @@
|
||||
#include "sparse_page_writer.h"
|
||||
#include "validation.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/logging.h"
|
||||
@ -850,7 +851,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, cache_file);
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
|
||||
cache_file);
|
||||
} else {
|
||||
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
|
||||
file_format};
|
||||
|
||||
@ -1,16 +1,21 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2019-2023 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/version_config.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/json.h> // Json
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/version_config.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include <cstddef> // std::size_t
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <string> // std::string
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(CAPI, XGDMatrixCreateFromMatDT) {
|
||||
std::vector<int> col0 = {0, -1, 3};
|
||||
@ -83,6 +88,39 @@ TEST(CAPI, Version) {
|
||||
ASSERT_EQ(patch, XGBOOST_VER_PATCH);
|
||||
}
|
||||
|
||||
/**
 * Create a 1x3 DMatrix from CSR inputs through the C API and verify its shape, its number
 * of non-missing entries and the thread count reported by its context.
 *
 * The JSON config only sets "missing"; "nthread" is left out on purpose so the test
 * exercises the path where the caller does not specify a thread count.
 */
TEST(CAPI, XGDMatrixCreateFromCSR) {
  // One row holding three entries in columns 0, 1 and 2.
  HostDeviceVector<std::size_t> indptr{0, 3};
  HostDeviceVector<double> data{0.0, 1.0, 2.0};
  HostDeviceVector<std::size_t> indices{0, 1, 2};
  auto indptr_arr = GetArrayInterface(&indptr, 2, 1);
  auto indices_arr = GetArrayInterface(&indices, 3, 1);
  auto data_arr = GetArrayInterface(&data, 3, 1);

  // The C API takes each array interface as a JSON string.
  std::string sindptr, sindices, sdata, sconfig;
  Json::Dump(indptr_arr, &sindptr);
  Json::Dump(indices_arr, &sindices);
  Json::Dump(data_arr, &sdata);
  Json config{Object{}};
  config["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
  Json::Dump(config, &sconfig);

  // Start from a null handle and check the return code: if creation fails the test must
  // stop here instead of passing an indeterminate handle to the calls below.
  DMatrixHandle handle{nullptr};
  ASSERT_EQ(XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), 3,
                                   sconfig.c_str(), &handle),
            0);
  ASSERT_NE(handle, nullptr);

  bst_ulong n;
  ASSERT_EQ(XGDMatrixNumRow(handle, &n), 0);
  ASSERT_EQ(n, 1);
  ASSERT_EQ(XGDMatrixNumCol(handle, &n), 0);
  ASSERT_EQ(n, 3);
  ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0);
  ASSERT_EQ(n, 3);

  // The handle wraps a shared_ptr<DMatrix>; look at the context to check the thread count
  // picked when the config carries no "nthread".
  std::shared_ptr<xgboost::DMatrix> *pp_fmat =
      static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
  ASSERT_EQ((*pp_fmat)->Ctx()->Threads(), AllThreadsForTest());

  ASSERT_EQ(XGDMatrixFree(handle), 0);
}
|
||||
|
||||
TEST(CAPI, ConfigIO) {
|
||||
size_t constexpr kRows = 10;
|
||||
auto p_dmat = RandomDataGenerator(kRows, 10, 0).GenerateDMatrix();
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2018-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2018-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -18,11 +18,10 @@ TEST(DenseColumn, Test) {
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||
auto sparse_thresh = 0.2;
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false,
|
||||
common::OmpGetNumThreads(0)};
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, AllThreadsForTest()};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
|
||||
column_matrix.InitFromSparse(page, gmat, sparse_thresh, AllThreadsForTest());
|
||||
}
|
||||
ASSERT_GE(column_matrix.GetTypeSize(), last);
|
||||
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
|
||||
@ -65,10 +64,10 @@ TEST(SparseColumn, Test) {
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
|
||||
GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, AllThreadsForTest()};
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, 1.0, common::OmpGetNumThreads(0));
|
||||
column_matrix.InitFromSparse(page, gmat, 1.0, AllThreadsForTest());
|
||||
}
|
||||
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
using T = decltype(dtype);
|
||||
@ -93,10 +92,10 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
|
||||
GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, AllThreadsForTest());
|
||||
ColumnMatrix column_matrix;
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.InitFromSparse(page, gmat, 0.2, common::OmpGetNumThreads(0));
|
||||
column_matrix.InitFromSparse(page, gmat, 0.2, AllThreadsForTest());
|
||||
}
|
||||
ASSERT_TRUE(column_matrix.AnyMissing());
|
||||
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
@ -14,15 +14,13 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
size_t GetNThreads() { return common::OmpGetNumThreads(0); }
|
||||
|
||||
void ParallelGHistBuilderReset() {
|
||||
constexpr size_t kBins = 10;
|
||||
constexpr size_t kNodes = 5;
|
||||
constexpr size_t kNodesExtended = 10;
|
||||
constexpr size_t kTasksPerNode = 10;
|
||||
constexpr double kValue = 1.0;
|
||||
const size_t nthreads = GetNThreads();
|
||||
const size_t nthreads = AllThreadsForTest();
|
||||
|
||||
HistCollection collection;
|
||||
collection.Init(kBins);
|
||||
@ -78,7 +76,7 @@ void ParallelGHistBuilderReduceHist(){
|
||||
constexpr size_t kNodes = 5;
|
||||
constexpr size_t kTasksPerNode = 10;
|
||||
constexpr double kValue = 1.0;
|
||||
const size_t nthreads = GetNThreads();
|
||||
const size_t nthreads = AllThreadsForTest();
|
||||
|
||||
HistCollection collection;
|
||||
collection.Init(kBins);
|
||||
@ -167,7 +165,7 @@ TEST(HistUtil, DenseCutsCategorical) {
|
||||
std::vector<float> x_sorted(x);
|
||||
std::sort(x_sorted.begin(), x_sorted.end());
|
||||
auto dmat = GetDMatrixFromData(x, n, 1);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
auto cuts_from_sketch = cuts.Values();
|
||||
EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
|
||||
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
|
||||
@ -180,13 +178,12 @@ TEST(HistUtil, DenseCutsCategorical) {
|
||||
TEST(HistUtil, DenseCutsAccuracyTest) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100};
|
||||
// omp_set_num_threads(1);
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -203,13 +200,11 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) {
|
||||
dmat->Info().weights_.HostVector() = w;
|
||||
for (auto num_bins : bin_sizes) {
|
||||
{
|
||||
HistogramCuts cuts =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), true);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), true);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
{
|
||||
HistogramCuts cuts =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), false);
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), false);
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -231,14 +226,14 @@ void TestQuantileWithHessian(bool use_sorted) {
|
||||
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts_hess =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted, hessian);
|
||||
SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted, hessian);
|
||||
for (size_t i = 0; i < w.size(); ++i) {
|
||||
dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i];
|
||||
}
|
||||
ValidateCuts(cuts_hess, dmat.get(), num_bins);
|
||||
|
||||
HistogramCuts cuts_wh =
|
||||
SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0), use_sorted);
|
||||
SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest(), use_sorted);
|
||||
ValidateCuts(cuts_wh, dmat.get(), num_bins);
|
||||
|
||||
ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size());
|
||||
@ -265,7 +260,7 @@ TEST(HistUtil, DenseCutsExternalMemory) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, tmpdir);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
@ -285,7 +280,7 @@ TEST(HistUtil, IndexBinBound) {
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest());
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
|
||||
}
|
||||
@ -308,7 +303,7 @@ TEST(HistUtil, IndexBinData) {
|
||||
|
||||
for (auto max_bin : kBinSizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, common::OmpGetNumThreads(0));
|
||||
GHistIndexMatrix hmat(p_fmat.get(), max_bin, 0.5, false, AllThreadsForTest());
|
||||
uint32_t const* offsets = hmat.index.Offset();
|
||||
EXPECT_EQ(hmat.index.Size(), kRows*kCols);
|
||||
switch (max_bin) {
|
||||
@ -331,9 +326,8 @@ TEST(HistUtil, IndexBinData) {
|
||||
void TestSketchFromWeights(bool with_group) {
|
||||
size_t constexpr kRows = 300, kCols = 20, kBins = 256;
|
||||
size_t constexpr kGroups = 10;
|
||||
auto m =
|
||||
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix();
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
||||
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
@ -369,7 +363,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
|
||||
if (with_group) {
|
||||
m->Info().weights_ = decltype(m->Info().weights_)(); // remove weight
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
||||
for (size_t i = 0; i < cuts.Values().size(); ++i) {
|
||||
EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]);
|
||||
}
|
||||
@ -388,7 +382,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
for (size_t i = 0; i < h_weights.size(); ++i) {
|
||||
h_weights[i] = static_cast<float>(i + 1) / static_cast<float>(kGroups);
|
||||
}
|
||||
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts weighted = SketchOnDMatrix(m.get(), kBins, AllThreadsForTest());
|
||||
ValidateCuts(weighted, m.get(), kBins);
|
||||
}
|
||||
}
|
||||
@ -400,10 +394,10 @@ TEST(HistUtil, SketchFromWeights) {
|
||||
|
||||
TEST(HistUtil, SketchCategoricalFeatures) {
|
||||
TestCategoricalSketch(1000, 256, 32, false, [](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0));
|
||||
return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest());
|
||||
});
|
||||
TestCategoricalSketch(1000, 256, 32, true, [](DMatrix* p_fmat, int32_t num_bins) {
|
||||
return SketchOnDMatrix(p_fmat, num_bins, common::OmpGetNumThreads(0));
|
||||
return SketchOnDMatrix(p_fmat, num_bins, AllThreadsForTest());
|
||||
});
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <thrust/device_vector.h>
|
||||
@ -27,7 +27,7 @@ namespace common {
|
||||
template <typename AdapterT>
|
||||
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
|
||||
data::SimpleDMatrix dmat(adapter, missing, 1);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins, AllThreadsForTest());
|
||||
return cuts;
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ TEST(HistUtil, DeviceSketch) {
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, common::OmpGetNumThreads(0));
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins, AllThreadsForTest());
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2020-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "test_quantile.h"
|
||||
|
||||
@ -73,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
auto hess = Span<float const>{hessian};
|
||||
|
||||
ContainerType<use_column> sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, OmpGetNumThreads(0));
|
||||
column_size, false, AllThreadsForTest());
|
||||
|
||||
if (use_column) {
|
||||
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
|
||||
@ -94,7 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
column_size, false, OmpGetNumThreads(0));
|
||||
column_size, false, AllThreadsForTest());
|
||||
m->Info().num_row_ = rows;
|
||||
|
||||
for (auto rank = 0; rank < world; ++rank) {
|
||||
@ -188,7 +188,7 @@ void TestSameOnAllWorkers() {
|
||||
.MaxCategory(17)
|
||||
.Seed(rank + seed)
|
||||
.GenerateDMatrix();
|
||||
auto cuts = SketchOnDMatrix(m.get(), n_bins, common::OmpGetNumThreads(0));
|
||||
auto cuts = SketchOnDMatrix(m.get(), n_bins, AllThreadsForTest());
|
||||
std::vector<float> cut_values(cuts.Values().size() * world, 0);
|
||||
std::vector<
|
||||
typename std::remove_reference_t<decltype(cuts.Ptrs())>::value_type>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2018-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2018-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/base.h>
|
||||
@ -45,7 +45,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
|
||||
out_vec.Fill(0);
|
||||
|
||||
Transform<>::Init(TestTransformRange<bst_float>{},
|
||||
Range{0, static_cast<Range::DifferenceType>(size)}, common::OmpGetNumThreads(0),
|
||||
Range{0, static_cast<Range::DifferenceType>(size)}, AllThreadsForTest(),
|
||||
TRANSFORM_GPU)
|
||||
.Eval(&out_vec, &in_vec);
|
||||
std::vector<bst_float> res = out_vec.HostVector();
|
||||
@ -61,8 +61,8 @@ TEST(TransformDeathTest, Exception) {
|
||||
EXPECT_DEATH(
|
||||
{
|
||||
Transform<>::Init([](size_t idx, common::Span<float const> _in) { _in[idx + 1]; },
|
||||
Range(0, static_cast<Range::DifferenceType>(kSize)),
|
||||
common::OmpGetNumThreads(0), -1)
|
||||
Range(0, static_cast<Range::DifferenceType>(kSize)), AllThreadsForTest(),
|
||||
-1)
|
||||
.Eval(&in_vec);
|
||||
},
|
||||
"");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -70,14 +70,14 @@ TEST(SparsePage, PushCSCAfterTranspose) {
|
||||
SparsePage page; // Consolidated sparse page
|
||||
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
|
||||
// Transpose each batch and push
|
||||
SparsePage tmp = batch.GetTranspose(ncols, common::OmpGetNumThreads(0));
|
||||
SparsePage tmp = batch.GetTranspose(ncols, AllThreadsForTest());
|
||||
page.PushCSC(tmp);
|
||||
}
|
||||
|
||||
// Make sure that the final sparse page has the right number of entries
|
||||
ASSERT_EQ(kEntries, page.data.Size());
|
||||
|
||||
page.SortRows(common::OmpGetNumThreads(0));
|
||||
page.SortRows(AllThreadsForTest());
|
||||
auto v = page.GetView();
|
||||
for (size_t i = 0; i < v.Size(); ++i) {
|
||||
auto column = v[i];
|
||||
@ -89,7 +89,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
|
||||
|
||||
TEST(SparsePage, SortIndices) {
|
||||
auto p_fmat = RandomDataGenerator{100, 10, 0.6}.GenerateDMatrix();
|
||||
auto n_threads = common::OmpGetNumThreads(0);
|
||||
auto n_threads = AllThreadsForTest();
|
||||
SparsePage copy;
|
||||
for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
|
||||
ASSERT_TRUE(page.IsIndicesSorted(n_threads));
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
@ -46,7 +46,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
|
||||
h_ft.resize(kCols, FeatureType::kCategorical);
|
||||
|
||||
BatchParam p(max_bins, 0.8);
|
||||
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
|
||||
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, AllThreadsForTest(), {});
|
||||
|
||||
auto x_copy = x;
|
||||
std::sort(x_copy.begin(), x_copy.end());
|
||||
@ -75,7 +75,7 @@ TEST(GradientIndex, PushBatch) {
|
||||
|
||||
auto test = [&](float sparisty) {
|
||||
auto m = RandomDataGenerator{kRows, kCols, sparisty}.GenerateDMatrix(true);
|
||||
auto cuts = common::SketchOnDMatrix(m.get(), max_bins, common::OmpGetNumThreads(0), false, {});
|
||||
auto cuts = common::SketchOnDMatrix(m.get(), max_bins, AllThreadsForTest(), false, {});
|
||||
common::HistogramCuts copy_cuts = cuts;
|
||||
|
||||
ASSERT_EQ(m->Info().num_row_, kRows);
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors
|
||||
*/
|
||||
#include "test_iterative_dmatrix.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
#include "../../../src/data/gradient_index.h"
|
||||
#include "../../../src/data/iterative_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/data.h" // DMatrix
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@ -20,8 +23,10 @@ TEST(IterativeDMatrix, IsDense) {
|
||||
int n_bins = 16;
|
||||
auto test = [n_bins](float sparsity) {
|
||||
NumpyArrayIterForTest iter(sparsity);
|
||||
auto n_threads = 0;
|
||||
IterativeDMatrix m(&iter, iter.Proxy(), nullptr, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||
std::numeric_limits<float>::quiet_NaN(), n_threads, n_bins);
|
||||
ASSERT_EQ(m.Ctx()->Threads(), AllThreadsForTest());
|
||||
if (sparsity == 0.0) {
|
||||
ASSERT_TRUE(m.IsDense());
|
||||
} else {
|
||||
|
||||
@ -411,3 +411,14 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
|
||||
delete dmat;
|
||||
delete dmat_read;
|
||||
}
|
||||
|
||||
/**
 * A DMatrix created from a dense array adapter with a thread count of 0 must report the
 * same number of threads as AllThreadsForTest().
 */
TEST(SimpleDMatrix, Threads) {
  size_t constexpr kRows{16};
  size_t constexpr kCols{8};
  // 0 is passed as the thread count argument of DMatrix::Create.
  int constexpr kUnsetThreads{0};
  float const kMissing = std::numeric_limits<float>::quiet_NaN();

  // Dense (sparsity 0.0) random input exposed through an array-interface string.
  HostDeviceVector<float> storage;
  auto interface_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&storage);
  data::ArrayAdapter adapter{StringView{interface_str}};

  std::unique_ptr<DMatrix> p_fmat{DMatrix::Create(&adapter, kMissing, kUnsetThreads, "")};
  auto const n_threads = p_fmat->Ctx()->Threads();
  ASSERT_EQ(n_threads, AllThreadsForTest());
}
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
// Copyright by Contributors
|
||||
/**
|
||||
* Copyright 2016-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
@ -22,13 +24,15 @@ void TestSparseDMatrixLoadFile() {
|
||||
CreateBigTestData(opath, 3 * 64, false);
|
||||
opath += "?indexing_mode=1";
|
||||
data::FileIterator iter{opath, 0, 1, "libsvm"};
|
||||
auto n_threads = 0;
|
||||
data::SparsePageDMatrix m{&iter,
|
||||
iter.Proxy(),
|
||||
data::fileiter::Reset,
|
||||
data::fileiter::Next,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
1,
|
||||
n_threads,
|
||||
tmpdir.path + "cache"};
|
||||
ASSERT_EQ(AllThreadsForTest(), m.Ctx()->Threads());
|
||||
ASSERT_EQ(m.Info().num_col_, 5);
|
||||
ASSERT_EQ(m.Info().num_row_, 64);
|
||||
|
||||
@ -213,16 +217,13 @@ TEST(SparsePageDMatrix, ColAccessBatches) {
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
// Create multiple sparse pages
|
||||
std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
|
||||
auto n_threads = omp_get_max_threads();
|
||||
omp_set_num_threads(16);
|
||||
ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
|
||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
ASSERT_EQ(dmat->Info().num_col_, page.Size());
|
||||
}
|
||||
omp_set_num_threads(n_threads);
|
||||
}
|
||||
|
||||
auto TestSparsePageDMatrixDeterminism(int32_t threads) {
|
||||
omp_set_num_threads(threads);
|
||||
std::vector<float> sparse_data;
|
||||
std::vector<size_t> sparse_rptr;
|
||||
std::vector<bst_feature_t> sparse_cids;
|
||||
@ -231,16 +232,15 @@ auto TestSparsePageDMatrixDeterminism(int32_t threads) {
|
||||
CreateBigTestData(filename, 1 << 16);
|
||||
|
||||
data::FileIterator iter(filename, 0, 1, "auto");
|
||||
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix{
|
||||
&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1, filename}};
|
||||
std::unique_ptr<DMatrix> sparse{
|
||||
new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), threads, filename}};
|
||||
CHECK(sparse->Ctx()->Threads() == threads || sparse->Ctx()->Threads() == AllThreadsForTest());
|
||||
|
||||
DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
|
||||
|
||||
auto cache_name =
|
||||
data::MakeId(filename,
|
||||
dynamic_cast<data::SparsePageDMatrix *>(sparse.get())) +
|
||||
".row.page";
|
||||
data::MakeId(filename, dynamic_cast<data::SparsePageDMatrix *>(sparse.get())) + ".row.page";
|
||||
std::string cache = common::LoadSequentialFile(cache_name);
|
||||
return cache;
|
||||
}
|
||||
|
||||
@ -438,9 +438,9 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
|
||||
size_t n_rows = n_entries / n_columns;
|
||||
NumpyArrayIterForTest iter(0, n_rows, n_columns, 2);
|
||||
|
||||
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
|
||||
static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix)};
|
||||
std::unique_ptr<DMatrix> dmat{
|
||||
DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, prefix)};
|
||||
auto row_page_path =
|
||||
data::MakeId(prefix,
|
||||
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
@ -461,6 +462,8 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint
|
||||
return mparam;
|
||||
}
|
||||
|
||||
// Test helper: returns Threads() of a default-constructed Context. The tests in
// this diff use it as the expected thread count in place of
// common::OmpGetNumThreads(0).
inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); }
|
||||
|
||||
template <typename Function, typename... Args>
|
||||
void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) {
|
||||
std::vector<std::thread> threads;
|
||||
@ -481,5 +484,4 @@ void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/base.h>
|
||||
@ -14,9 +14,7 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
void TestEvaluateSplits(bool force_read_by_column) {
|
||||
int static constexpr kRows = 8, kCols = 16;
|
||||
auto orig = omp_get_max_threads();
|
||||
int32_t n_threads = std::min(omp_get_max_threads(), 4);
|
||||
omp_set_num_threads(n_threads);
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
|
||||
TrainParam param;
|
||||
@ -32,7 +30,7 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
|
||||
size_t constexpr kMaxBins = 4;
|
||||
// dense, no missing values
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, common::OmpGetNumThreads(0));
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, 0.5, false, AllThreadsForTest());
|
||||
common::RowSetCollection row_set_collection;
|
||||
std::vector<size_t> &row_indices = *row_set_collection.Data();
|
||||
row_indices.resize(kRows);
|
||||
@ -80,8 +78,6 @@ void TestEvaluateSplits(bool force_read_by_column) {
|
||||
right.SetSubstract(GradStats{total_gpair}, left);
|
||||
}
|
||||
}
|
||||
|
||||
omp_set_num_threads(orig);
|
||||
}
|
||||
|
||||
TEST(HistEvaluator, Evaluate) {
|
||||
@ -122,7 +118,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
|
||||
// check the evaluator is returning the optimal split
|
||||
std::vector<FeatureType> ft{FeatureType::kCategorical};
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
HistEvaluator<CPUExpandEntry> evaluator{param_, info_, common::OmpGetNumThreads(0), sampler};
|
||||
HistEvaluator<CPUExpandEntry> evaluator{param_, info_, AllThreadsForTest(), sampler};
|
||||
evaluator.InitRoot(GradStats{total_gpair_});
|
||||
RegTree tree;
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
@ -152,7 +148,7 @@ auto CompareOneHotAndPartition(bool onehot) {
|
||||
|
||||
auto sampler = std::make_shared<common::ColumnSampler>();
|
||||
auto evaluator =
|
||||
HistEvaluator<CPUExpandEntry>{param, dmat->Info(), common::OmpGetNumThreads(0), sampler};
|
||||
HistEvaluator<CPUExpandEntry>{param, dmat->Info(), AllThreadsForTest(), sampler};
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
|
||||
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
|
||||
@ -203,7 +199,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
|
||||
info.num_col_ = 1;
|
||||
info.feature_types = {FeatureType::kCategorical};
|
||||
auto evaluator =
|
||||
HistEvaluator<CPUExpandEntry>{param_, info, common::OmpGetNumThreads(0), sampler};
|
||||
HistEvaluator<CPUExpandEntry>{param_, info, AllThreadsForTest(), sampler};
|
||||
evaluator.InitRoot(GradStats{parent_sum_});
|
||||
|
||||
std::vector<CPUExpandEntry> entries(1);
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
/*!
|
||||
* Copyright 2018-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2018-2023 by Contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
|
||||
#include <limits>
|
||||
|
||||
@ -375,6 +376,7 @@ TEST(CPUHistogram, Categorical) {
|
||||
}
|
||||
namespace {
|
||||
void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool force_read_by_column) {
|
||||
Context ctx;
|
||||
size_t constexpr kEntries = 1 << 16;
|
||||
auto m = CreateSparsePageDMatrix(kEntries, "cache");
|
||||
|
||||
@ -417,7 +419,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
||||
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
||||
256};
|
||||
|
||||
multi_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), rows_set.size(), false);
|
||||
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false);
|
||||
|
||||
size_t page_idx{0};
|
||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
|
||||
@ -438,17 +440,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
||||
common::RowSetCollection row_set_collection;
|
||||
InitRowPartitionForTest(&row_set_collection, n_samples);
|
||||
|
||||
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
|
||||
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false);
|
||||
SparsePage concat;
|
||||
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
concat.Push(page);
|
||||
}
|
||||
|
||||
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
|
||||
false, hess);
|
||||
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, ctx.Threads(), false, hess);
|
||||
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
|
||||
std::numeric_limits<double>::quiet_NaN(), common::OmpGetNumThreads(0));
|
||||
std::numeric_limits<double>::quiet_NaN(), ctx.Threads());
|
||||
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column);
|
||||
single_page = single_build.Histogram()[0];
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2021-2022 by XGBoost contributors
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
@ -62,7 +62,6 @@ class TestPredictionCache : public ::testing::Test {
|
||||
|
||||
void RunTest(std::string updater_name) {
|
||||
{
|
||||
omp_set_num_threads(1);
|
||||
Context ctx;
|
||||
ctx.InitAllowUnknown(Args{{"nthread", "8"}});
|
||||
if (updater_name == "grow_gpu_hist") {
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../../../src/data/simple_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/context.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
@ -50,7 +51,7 @@ class RegenTest : public ::testing::Test {
|
||||
auto dense = RandomDataGenerator{kRows, kCols, 0.5}.GenerateArrayInterface(&storage);
|
||||
auto adapter = data::ArrayAdapter(StringView{dense});
|
||||
p_fmat_ = std::shared_ptr<DMatrix>(new DMatrixForTest{
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), common::OmpGetNumThreads(0)});
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), AllThreadsForTest()});
|
||||
|
||||
p_fmat_->Info().labels.Reshape(256, 1);
|
||||
auto labels = p_fmat_->Info().labels.Data();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user