Support column split in histogram builder (#8811)
This commit is contained in:
@@ -48,7 +48,7 @@ void TestAddHistRows(bool is_distributed) {
|
||||
|
||||
HistogramBuilder<CPUExpandEntry> histogram_builder;
|
||||
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
|
||||
is_distributed);
|
||||
is_distributed, false);
|
||||
histogram_builder.AddHistRows(&starting_index, &sync_count,
|
||||
nodes_for_explicit_hist_build_,
|
||||
nodes_for_subtraction_trick_, &tree);
|
||||
@@ -86,7 +86,7 @@ void TestSyncHist(bool is_distributed) {
|
||||
|
||||
HistogramBuilder<CPUExpandEntry> histogram;
|
||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
|
||||
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed, false);
|
||||
|
||||
common::RowSetCollection row_set_collection_;
|
||||
{
|
||||
@@ -226,11 +226,14 @@ TEST(CPUHistogram, SyncHist) {
|
||||
TestSyncHist(false);
|
||||
}
|
||||
|
||||
void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
|
||||
void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_col_split) {
|
||||
size_t constexpr kNRows = 8, kNCols = 16;
|
||||
int32_t constexpr kMaxBins = 4;
|
||||
auto p_fmat =
|
||||
RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||
auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
|
||||
if (is_col_split) {
|
||||
p_fmat = std::shared_ptr<DMatrix>{
|
||||
p_fmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||
}
|
||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(BatchParam{kMaxBins, 0.5}).begin());
|
||||
uint32_t total_bins = gmat.cut.Ptrs().back();
|
||||
|
||||
@@ -241,7 +244,8 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
|
||||
|
||||
bst_node_t nid = 0;
|
||||
HistogramBuilder<CPUExpandEntry> histogram;
|
||||
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed);
|
||||
histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed,
|
||||
is_col_split);
|
||||
|
||||
RegTree tree;
|
||||
|
||||
@@ -284,11 +288,16 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column) {
|
||||
}
|
||||
|
||||
TEST(CPUHistogram, BuildHist) {
|
||||
TestBuildHistogram(true, false);
|
||||
TestBuildHistogram(false, false);
|
||||
TestBuildHistogram(true, true);
|
||||
TestBuildHistogram(false, true);
|
||||
TestBuildHistogram(true, false, false);
|
||||
TestBuildHistogram(false, false, false);
|
||||
TestBuildHistogram(true, true, false);
|
||||
TestBuildHistogram(false, true, false);
|
||||
}
|
||||
|
||||
TEST(CPUHistogram, BuildHistColSplit) {
|
||||
auto constexpr kWorkers = 4;
|
||||
RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true);
|
||||
RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -340,7 +349,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
||||
HistogramBuilder<CPUExpandEntry> cat_hist;
|
||||
for (auto const &gidx : cat_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
||||
auto total_bins = gidx.cut.TotalBins();
|
||||
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
|
||||
cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
||||
cat_hist.BuildHist(0, gidx, &tree, row_set_collection,
|
||||
nodes_for_explicit_hist_build, {}, gpair.HostVector(),
|
||||
force_read_by_column);
|
||||
@@ -354,7 +363,7 @@ void TestHistogramCategorical(size_t n_categories, bool force_read_by_column) {
|
||||
HistogramBuilder<CPUExpandEntry> onehot_hist;
|
||||
for (auto const &gidx : encode_m->GetBatches<GHistIndexMatrix>({kBins, 0.5})) {
|
||||
auto total_bins = gidx.cut.TotalBins();
|
||||
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false);
|
||||
onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false, false);
|
||||
onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, nodes_for_explicit_hist_build, {},
|
||||
gpair.HostVector(),
|
||||
force_read_by_column);
|
||||
@@ -419,7 +428,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
||||
1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); },
|
||||
256};
|
||||
|
||||
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false);
|
||||
multi_build.Reset(total_bins, batch_param, ctx.Threads(), rows_set.size(), false, false);
|
||||
|
||||
size_t page_idx{0};
|
||||
for (auto const &page : m->GetBatches<GHistIndexMatrix>(batch_param)) {
|
||||
@@ -440,7 +449,7 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx, bool fo
|
||||
common::RowSetCollection row_set_collection;
|
||||
InitRowPartitionForTest(&row_set_collection, n_samples);
|
||||
|
||||
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false);
|
||||
single_build.Reset(total_bins, batch_param, ctx.Threads(), 1, false, false);
|
||||
SparsePage concat;
|
||||
std::vector<float> hess(m->Info().num_row_, 1.0f);
|
||||
for (auto const& page : m->GetBatches<SparsePage>()) {
|
||||
|
||||
Reference in New Issue
Block a user