[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object will be required for collective operations. This PR
passes the context object to some of the functions that will need it, in preparation for
swapping out the implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -189,7 +189,7 @@ struct SparsePageView {
explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); }
SparsePage::Inst operator[](size_t i) { return view[i]; }
size_t Size() const { return view.Size(); }
[[nodiscard]] size_t Size() const { return view.Size(); }
};
struct SingleInstanceView {
@@ -250,7 +250,7 @@ struct GHistIndexMatrixView {
}
return ret;
}
size_t Size() const { return page_.Size(); }
[[nodiscard]] size_t Size() const { return page_.Size(); }
};
template <typename Adapter>
@@ -290,7 +290,7 @@ class AdapterView {
return ret;
}
size_t Size() const { return adapter_->NumRows(); }
[[nodiscard]] size_t Size() const { return adapter_->NumRows(); }
bst_row_t const static base_rowid = 0; // NOLINT
};
@@ -408,31 +408,33 @@ class ColumnSplitHelper {
ColumnSplitHelper(ColumnSplitHelper &&) noexcept = delete;
ColumnSplitHelper &operator=(ColumnSplitHelper &&) noexcept = delete;
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
void PredictDMatrix(Context const *ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
CHECK(xgboost::collective::IsDistributed())
<< "column-split prediction is only supported for distributed training";
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(),
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&batch}, out_preds);
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(ctx, SparsePageView{&batch}, out_preds);
}
}
void PredictInstance(SparsePage::Inst const &inst, std::vector<bst_float> *out_preds) {
void PredictInstance(Context const *ctx, SparsePage::Inst const &inst,
std::vector<bst_float> *out_preds) {
CHECK(xgboost::collective::IsDistributed())
<< "column-split prediction is only supported for distributed training";
PredictBatchKernel<SingleInstanceView, 1>(SingleInstanceView{inst}, out_preds);
PredictBatchKernel<SingleInstanceView, 1>(ctx, SingleInstanceView{inst}, out_preds);
}
void PredictLeaf(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
void PredictLeaf(Context const* ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
CHECK(xgboost::collective::IsDistributed())
<< "column-split prediction is only supported for distributed training";
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(ctx, SparsePageView{&batch},
out_preds);
}
}
@@ -453,12 +455,13 @@ class ColumnSplitHelper {
std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
}
std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const {
[[nodiscard]] std::size_t BitIndex(std::size_t tree_id, std::size_t row_id,
std::size_t node_id) const {
size_t tree_index = tree_id - tree_begin_;
return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
}
void AllreduceBitVectors() {
void AllreduceBitVectors(Context const*) {
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
decision_storage_.size());
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
@@ -547,7 +550,7 @@ class ColumnSplitHelper {
}
template <typename DataView, size_t block_of_rows_size, bool predict_leaf = false>
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
void PredictBatchKernel(Context const* ctx, DataView batch, std::vector<bst_float> *out_preds) {
auto const num_group = model_.learner_model_param->num_output_group;
// parallel over local batch
@@ -568,7 +571,7 @@ class ColumnSplitHelper {
FVecDrop(block_size, fvec_offset, &feat_vecs_);
});
AllreduceBitVectors();
AllreduceBitVectors(ctx);
// auto block_id has the same type as `n_blocks`.
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
@@ -646,7 +649,7 @@ class CPUPredictor : public Predictor {
<< "Predict DMatrix with column split" << MTNotImplemented();
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
helper.PredictDMatrix(p_fmat, out_preds);
helper.PredictDMatrix(ctx_, p_fmat, out_preds);
return;
}
@@ -779,7 +782,7 @@ class CPUPredictor : public Predictor {
<< "Predict instance with column split" << MTNotImplemented();
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
helper.PredictInstance(inst, out_preds);
helper.PredictInstance(ctx_, inst, out_preds);
return;
}
@@ -811,7 +814,7 @@ class CPUPredictor : public Predictor {
<< "Predict leaf with column split" << MTNotImplemented();
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
helper.PredictLeaf(p_fmat, &preds);
helper.PredictLeaf(ctx_, p_fmat, &preds);
return;
}