Reduce thread contention in column split histogram test. (#10708)
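
The column-split test runs kWorkers collective workers inside a single process, and each call to TestBuildHistogram previously constructed its own Context with the default thread count, so the workers contended for all hardware threads at once. The Context is now created once by the caller, and the column-split test divides std::thread::hardware_concurrency() evenly among the workers via the nthread parameter. The test is also renamed from BuildHistColSplit to BuildHistColumnSplit.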

Jiaming Yuan, 2024-08-17 01:00:32 +08:00 (committed by GitHub)
parent 2258bc870d
commit abe65e3769

@@ -222,10 +222,9 @@ TEST(CPUHistogram, SyncHist) {
   TestSyncHist(false);
 }
 
-void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
+void TestBuildHistogram(Context const* ctx, bool is_distributed, bool force_read_by_column, bool is_col_split) {
   size_t constexpr kNRows = 8, kNCols = 16;
   int32_t constexpr kMaxBins = 4;
-  Context ctx;
   auto p_fmat =
       RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
   if (is_col_split) {
@@ -233,7 +232,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
         p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
   }
   auto const &gmat =
-      *(p_fmat->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{kMaxBins, 0.5}).begin());
+      *(p_fmat->GetBatches<GHistIndexMatrix>(ctx, BatchParam{kMaxBins, 0.5}).begin());
   uint32_t total_bins = gmat.cut.Ptrs().back();
 
   static double constexpr kEps = 1e-6;
@@ -244,7 +243,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
   bst_node_t nid = 0;
   HistogramBuilder histogram;
   HistMakerTrainParam hist_param;
-  histogram.Reset(&ctx, total_bins, {kMaxBins, 0.5}, is_distributed, is_col_split, &hist_param);
+  histogram.Reset(ctx, total_bins, {kMaxBins, 0.5}, is_distributed, is_col_split, &hist_param);
 
   RegTree tree;
@@ -262,11 +261,11 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
   histogram.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false);
   common::BlockedSpace2d space{
       1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256};
-  for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {kMaxBins, 0.5})) {
+  for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx, {kMaxBins, 0.5})) {
     histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
-                        linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column);
+                        linalg::MakeTensorView(ctx, gpair, gpair.size()), force_read_by_column);
   }
-  histogram.SyncHistogram(&ctx, &tree, nodes_to_build, {});
+  histogram.SyncHistogram(ctx, &tree, nodes_to_build, {});
 
   // Check if number of histogram bins is correct
   ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
@@ -292,16 +291,21 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
 }
 
 TEST(CPUHistogram, BuildHist) {
-  TestBuildHistogram(true, false, false);
-  TestBuildHistogram(false, false, false);
-  TestBuildHistogram(true, true, false);
-  TestBuildHistogram(false, true, false);
+  Context ctx;
+  TestBuildHistogram(&ctx, true, false, false);
+  TestBuildHistogram(&ctx, false, false, false);
+  TestBuildHistogram(&ctx, true, true, false);
+  TestBuildHistogram(&ctx, false, true, false);
 }
 
-TEST(CPUHistogram, BuildHistColSplit) {
+TEST(CPUHistogram, BuildHistColumnSplit) {
   auto constexpr kWorkers = 4;
-  collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, true, true); });
-  collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, false, true); });
+  Context ctx;
+  std::int32_t n_total_threads = std::thread::hardware_concurrency();
+  auto n_threads = std::max(n_total_threads / kWorkers, 1);
+  ctx.UpdateAllowUnknown(Args{{"nthread", std::to_string(n_threads)}});
+  collective::TestDistributedGlobal(kWorkers, [&] { TestBuildHistogram(&ctx, true, true, true); });
+  collective::TestDistributedGlobal(kWorkers, [&] { TestBuildHistogram(&ctx, true, false, true); });
 }
 
 namespace {
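
The heart of the fix is the thread split in BuildHistColumnSplit: since collective::TestDistributedGlobal simulates kWorkers workers inside one process, letting each worker default to all hardware threads oversubscribes the CPU. Below is a minimal standalone sketch of that pattern; the helper name ThreadsPerWorker is hypothetical, and only kWorkers mirrors the test itself.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <thread>

    // Divide the machine's hardware threads evenly among n_workers, with a
    // floor of one thread per worker. hardware_concurrency() may report 0
    // when the count is unknown; the std::max guards that case as well.
    // (ThreadsPerWorker is a hypothetical helper, not part of the test.)
    std::int32_t ThreadsPerWorker(std::int32_t n_workers) {
      std::int32_t n_total_threads = std::thread::hardware_concurrency();
      return std::max(n_total_threads / n_workers, 1);
    }

    int main() {
      constexpr std::int32_t kWorkers = 4;  // mirrors the constant in the test
      std::cout << ThreadsPerWorker(kWorkers) << " threads per worker\n";
    }

With this split, the four workers together use at most the machine's hardware thread count instead of four times that, which is what "reduce thread contention" refers to in the title.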