Consistent use of context to specify number of threads. (#8733)

- Use context in all tests.
- Use context in R.
- Use context in C API DMatrix initialization. (0 threads is used as dft).
This commit is contained in:
Jiaming Yuan
2023-01-30 15:25:31 +08:00
committed by GitHub
parent 21a28f2cc5
commit 3760cede0f
24 changed files with 212 additions and 152 deletions

View File

@@ -1,7 +1,8 @@
/*!
* Copyright 2018-2022 by Contributors
/**
* Copyright 2018-2023 by Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // Context
#include <limits>
@@ -375,6 +376,7 @@ TEST(CPUHistogram, Categorical) {
}
namespace {
void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool force_read_by_column) {
Context ctx;
size_t constexpr kEntries = 1 << 16;
auto m = CreateSparsePageDMatrix(kEntries, "cache");
@@ -417,7 +419,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
256};
multi_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), rows_set.size(), false);
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false);
size_t page_idx{0};
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
@@ -438,17 +440,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
common::RowSetCollection row_set_collection;
InitRowPartitionForTest(&row_set_collection, n_samples);
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false);
SparsePage concat;
std::vector<float> hess(m->Info().num_row_, 1.0f);
for (auto const& page : m->GetBatches<SparsePage>()) {
concat.Push(page);
}
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
false, hess);
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, ctx.Threads(), false, hess);
GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
std::numeric_limits<double>::quiet_NaN(), common::OmpGetNumThreads(0));
std::numeric_limits<double>::quiet_NaN(), ctx.Threads());
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair, force_read_by_column);
single_page = single_build.Histogram()[0];
}