Fix column split race condition. (#10572)
This commit is contained in:
@@ -36,10 +36,11 @@ class ColumnSplitHelper {
|
||||
common::PartitionBuilder<kPartitionBlockSize>* partition_builder,
|
||||
common::RowSetCollection* row_set_collection)
|
||||
: partition_builder_{partition_builder}, row_set_collection_{row_set_collection} {
|
||||
decision_storage_.resize(num_row);
|
||||
decision_bits_ = BitVector(common::Span<BitVector::value_type>(decision_storage_));
|
||||
missing_storage_.resize(num_row);
|
||||
missing_bits_ = BitVector(common::Span<BitVector::value_type>(missing_storage_));
|
||||
auto n_bytes = BitVector::ComputeStorageSize(num_row);
|
||||
decision_storage_.resize(n_bytes);
|
||||
decision_bits_ = BitVector{common::Span<BitVector::value_type>{decision_storage_}};
|
||||
missing_storage_.resize(n_bytes);
|
||||
missing_bits_ = BitVector{common::Span<BitVector::value_type>{missing_storage_}};
|
||||
}
|
||||
|
||||
template <typename BinIdxType, bool any_missing, bool any_cat, typename ExpandEntry>
|
||||
@@ -51,14 +52,43 @@ class ColumnSplitHelper {
|
||||
// we first collect all the decisions and whether the feature is missing into bit vectors.
|
||||
std::fill(decision_storage_.begin(), decision_storage_.end(), 0);
|
||||
std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
|
||||
common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
|
||||
const int32_t nid = nodes[node_in_set].nid;
|
||||
|
||||
this->tloc_decision_.resize(decision_storage_.size() * n_threads);
|
||||
this->tloc_missing_.resize(decision_storage_.size() * n_threads);
|
||||
std::fill_n(this->tloc_decision_.data(), this->tloc_decision_.size(), 0);
|
||||
std::fill_n(this->tloc_missing_.data(), this->tloc_missing_.size(), 0);
|
||||
|
||||
// Make thread-local storage.
|
||||
using T = decltype(decision_storage_)::value_type;
|
||||
auto make_tloc = [&](std::vector<T>& storage, std::int32_t tidx) {
|
||||
auto span = common::Span<T>{storage};
|
||||
auto n = decision_storage_.size();
|
||||
auto bitvec = BitVector{span.subspan(n * tidx, n)};
|
||||
return bitvec;
|
||||
};
|
||||
|
||||
common::ParallelFor2d(space, n_threads, [&](std::size_t node_in_set, common::Range1d r) {
|
||||
bst_node_t const nid = nodes[node_in_set].nid;
|
||||
auto tidx = omp_get_thread_num();
|
||||
auto decision = make_tloc(this->tloc_decision_, tidx);
|
||||
auto missing = make_tloc(this->tloc_missing_, tidx);
|
||||
bst_bin_t split_cond = column_matrix.IsInitialized() ? split_conditions[node_in_set] : 0;
|
||||
partition_builder_->MaskRows<BinIdxType, any_missing, any_cat>(
|
||||
node_in_set, nodes, r, split_cond, gmat, column_matrix, *p_tree,
|
||||
(*row_set_collection_)[nid].begin(), &decision_bits_, &missing_bits_);
|
||||
(*row_set_collection_)[nid].begin(), &decision, &missing);
|
||||
});
|
||||
|
||||
// Reduce the thread-local bit vectors by OR-ing every thread's slice into thread 0's slice.
|
||||
auto decision = make_tloc(this->tloc_decision_, 0);
|
||||
auto missing = make_tloc(this->tloc_missing_, 0);
|
||||
for (std::int32_t tidx = 1; tidx < n_threads; ++tidx) {
|
||||
decision |= make_tloc(this->tloc_decision_, tidx);
|
||||
missing |= make_tloc(this->tloc_missing_, tidx);
|
||||
}
|
||||
CHECK_EQ(decision_storage_.size(), decision.NumValues());
|
||||
std::copy_n(decision.Data(), decision_storage_.size(), decision_storage_.data());
|
||||
std::copy_n(missing.Data(), missing_storage_.size(), missing_storage_.data());
|
||||
|
||||
// Then aggregate the bit vectors across all the workers.
|
||||
auto rc = collective::Success() << [&] {
|
||||
return collective::Allreduce(ctx, &decision_storage_, collective::Op::kBitwiseOR);
|
||||
@@ -85,6 +115,10 @@ class ColumnSplitHelper {
|
||||
BitVector decision_bits_{};
|
||||
std::vector<BitVector::value_type> missing_storage_{};
|
||||
BitVector missing_bits_{};
|
||||
|
||||
std::vector<BitVector::value_type> tloc_decision_;
|
||||
std::vector<BitVector::value_type> tloc_missing_;
|
||||
|
||||
common::PartitionBuilder<kPartitionBlockSize>* partition_builder_;
|
||||
common::RowSetCollection* row_set_collection_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user