Support column-split in row partitioner (#8828)

This commit is contained in:
Rong Ou
2023-02-25 12:43:35 -08:00
committed by GitHub
parent 90c0633a28
commit d9688f93c7
8 changed files with 283 additions and 58 deletions

View File

@@ -9,6 +9,7 @@
#include <limits> // std::numeric_limits
#include <vector>
#include "../collective/communicator-inl.h"
#include "../common/numeric.h" // Iota
#include "../common/partition_builder.h"
#include "hist/expand_entry.h" // CPUExpandEntry
@@ -16,17 +17,73 @@
namespace xgboost {
namespace tree {
class CommonRowPartitioner {
static constexpr size_t kPartitionBlockSize = 2048;
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
common::RowSetCollection row_set_collection_;
static constexpr size_t kPartitionBlockSize = 2048;
class ColumnSplitHelper {
public:
ColumnSplitHelper() = default;
ColumnSplitHelper(bst_row_t num_row,
common::PartitionBuilder<kPartitionBlockSize>* partition_builder,
common::RowSetCollection* row_set_collection)
: partition_builder_{partition_builder}, row_set_collection_{row_set_collection} {
decision_storage_.resize(num_row);
decision_bits_ = BitVector(common::Span<BitVector::value_type>(decision_storage_));
missing_storage_.resize(num_row);
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
}
void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads,
GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix,
std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
// When data is split by column, we don't have all the feature values in the local worker, so
// we first collect all the decisions and whether the feature is missing into bit vectors.
std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
const int32_t nid = nodes[node_in_set].nid;
partition_builder_->MaskRows(node_in_set, nodes, r, gmat, column_matrix, *p_tree,
(*row_set_collection_)[nid].begin, &decision_bits_,
&missing_bits_);
});
// Then aggregate the bit vectors across all the workers.
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
decision_storage_.size());
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
missing_storage_.size());
// Finally use the bit vectors to partition the rows.
common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_->GetTaskIdx(node_in_set, begin);
partition_builder_->AllocateForTask(task_id);
partition_builder_->PartitionByMask(node_in_set, nodes, r, gmat, column_matrix, *p_tree,
(*row_set_collection_)[nid].begin, decision_bits_,
missing_bits_);
});
}
private:
using BitVector = RBitField8;
std::vector<BitVector::value_type> decision_storage_{};
BitVector decision_bits_{};
std::vector<BitVector::value_type> missing_storage_{};
BitVector missing_bits_{};
common::PartitionBuilder<kPartitionBlockSize>* partition_builder_;
common::RowSetCollection* row_set_collection_;
};
class CommonRowPartitioner {
public:
bst_row_t base_rowid = 0;
CommonRowPartitioner() = default;
CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid)
: base_rowid{_base_rowid} {
CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid,
bool is_col_split)
: base_rowid{_base_rowid}, is_col_split_{is_col_split} {
row_set_collection_.Clear();
std::vector<size_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(num_row);
@@ -34,6 +91,10 @@ class CommonRowPartitioner {
std::size_t* p_row_indices = row_indices.data();
common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
row_set_collection_.Init();
if (is_col_split_) {
column_split_helper_ = ColumnSplitHelper{num_row, &partition_builder_, &row_set_collection_};
}
}
void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
@@ -156,16 +217,20 @@ class CommonRowPartitioner {
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
// Store results in intermediate buffers from partition_builder_
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
partition_builder_.AllocateForTask(task_id);
bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
});
if (is_col_split_) {
column_split_helper_.Partition(space, ctx->Threads(), gmat, column_matrix, nodes, p_tree);
} else {
common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) {
size_t begin = r.begin();
const int32_t nid = nodes[node_in_set].nid;
const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
partition_builder_.AllocateForTask(task_id);
bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
partition_builder_.template Partition<BinIdxType, any_missing, any_cat>(
node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
row_set_collection_[nid].begin);
});
}
// 3. Compute offsets to copy blocks of row-indexes
// from partition_builder_ to row_set_collection_
@@ -205,6 +270,12 @@ class CommonRowPartitioner {
ctx, tree, this->Partitions(), p_out_position,
[&](size_t idx) -> bool { return gpair[idx].GetHess() - .0f == .0f; });
}
private:
common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
common::RowSetCollection row_set_collection_;
bool is_col_split_;
ColumnSplitHelper column_split_helper_;
};
} // namespace tree

View File

@@ -71,7 +71,7 @@ class GloablApproxBuilder {
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
n_batches_++;
}

View File

@@ -277,7 +277,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid);
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, fmat->IsColumnSplit());
++page_id;
}
histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,