Reduce thread contention in column split histogram test. (#10708)
This commit is contained in:
parent
2258bc870d
commit
abe65e3769
@@ -222,10 +222,9 @@ TEST(CPUHistogram, SyncHist) {
|
|||||||
TestSyncHist(false);
|
TestSyncHist(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
|
void TestBuildHistogram(Context const* ctx, bool is_distributed, bool force_read_by_column, bool is_col_split) {
|
||||||
size_t constexpr kNRows = 8, kNCols = 16;
|
size_t constexpr kNRows = 8, kNCols = 16;
|
||||||
int32_t constexpr kMaxBins = 4;
|
int32_t constexpr kMaxBins = 4;
|
||||||
Context ctx;
|
|
||||||
auto p_fmat =
|
auto p_fmat =
|
||||||
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||||
if (is_col_split) {
|
if (is_col_split) {
|
||||||
@@ -233,7 +232,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
|||||||
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
}
|
}
|
||||||
auto const &gmat =
|
auto const &gmat =
|
||||||
*(p_fmat->GetBatches<GHistIndexMatrix>(&ctx, BatchParam{kMaxBins, 0.5}).begin());
|
*(p_fmat->GetBatches<GHistIndexMatrix>(ctx, BatchParam{kMaxBins, 0.5}).begin());
|
||||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||||
|
|
||||||
static double constexpr kEps = 1e-6;
|
static double constexpr kEps = 1e-6;
|
||||||
@@ -244,7 +243,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
|||||||
bst_node_t nid = 0;
|
bst_node_t nid = 0;
|
||||||
HistogramBuilder histogram;
|
HistogramBuilder histogram;
|
||||||
HistMakerTrainParam hist_param;
|
HistMakerTrainParam hist_param;
|
||||||
histogram.Reset(&ctx, total_bins, {kMaxBins, 0.5}, is_distributed, is_col_split, &hist_param);
|
histogram.Reset(ctx, total_bins, {kMaxBins, 0.5}, is_distributed, is_col_split, &hist_param);
|
||||||
|
|
||||||
RegTree tree;
|
RegTree tree;
|
||||||
|
|
||||||
@@ -262,11 +261,11 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
|||||||
histogram.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false);
|
histogram.AddHistRows(&tree, &nodes_to_build, &dummy_sub, false);
|
||||||
common::BlockedSpace2d space{
|
common::BlockedSpace2d space{
|
||||||
1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256};
|
1, [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size(); }, 256};
|
||||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {kMaxBins, 0.5})) {
|
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(ctx, {kMaxBins, 0.5})) {
|
||||||
histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
|
histogram.BuildHist(0, space, gidx, row_set_collection, nodes_to_build,
|
||||||
linalg::MakeTensorView(&ctx, gpair, gpair.size()), force_read_by_column);
|
linalg::MakeTensorView(ctx, gpair, gpair.size()), force_read_by_column);
|
||||||
}
|
}
|
||||||
histogram.SyncHistogram(&ctx, &tree, nodes_to_build, {});
|
histogram.SyncHistogram(ctx, &tree, nodes_to_build, {});
|
||||||
|
|
||||||
// Check if number of histogram bins is correct
|
// Check if number of histogram bins is correct
|
||||||
ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
|
ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back());
|
||||||
@@ -292,16 +291,21 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUHistogram, BuildHist) {
|
TEST(CPUHistogram, BuildHist) {
|
||||||
TestBuildHistogram(true, false, false);
|
Context ctx;
|
||||||
TestBuildHistogram(false, false, false);
|
TestBuildHistogram(&ctx, true, false, false);
|
||||||
TestBuildHistogram(true, true, false);
|
TestBuildHistogram(&ctx, false, false, false);
|
||||||
TestBuildHistogram(false, true, false);
|
TestBuildHistogram(&ctx, true, true, false);
|
||||||
|
TestBuildHistogram(&ctx, false, true, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUHistogram, BuildHistColSplit) {
|
TEST(CPUHistogram, BuildHistColumnSplit) {
|
||||||
auto constexpr kWorkers = 4;
|
auto constexpr kWorkers = 4;
|
||||||
collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, true, true); });
|
Context ctx;
|
||||||
collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, false, true); });
|
std::int32_t n_total_threads = std::thread::hardware_concurrency();
|
||||||
|
auto n_threads = std::max(n_total_threads / kWorkers, 1);
|
||||||
|
ctx.UpdateAllowUnknown(Args{{"nthread", std::to_string(n_threads)}});
|
||||||
|
collective::TestDistributedGlobal(kWorkers, [&] { TestBuildHistogram(&ctx, true, true, true); });
|
||||||
|
collective::TestDistributedGlobal(kWorkers, [&] { TestBuildHistogram(&ctx, true, false, true); });
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user