[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -189,7 +189,7 @@ struct SparsePageView {
|
||||
|
||||
explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); }
|
||||
SparsePage::Inst operator[](size_t i) { return view[i]; }
|
||||
size_t Size() const { return view.Size(); }
|
||||
[[nodiscard]] size_t Size() const { return view.Size(); }
|
||||
};
|
||||
|
||||
struct SingleInstanceView {
|
||||
@@ -250,7 +250,7 @@ struct GHistIndexMatrixView {
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
size_t Size() const { return page_.Size(); }
|
||||
[[nodiscard]] size_t Size() const { return page_.Size(); }
|
||||
};
|
||||
|
||||
template <typename Adapter>
|
||||
@@ -290,7 +290,7 @@ class AdapterView {
|
||||
return ret;
|
||||
}
|
||||
|
||||
size_t Size() const { return adapter_->NumRows(); }
|
||||
[[nodiscard]] size_t Size() const { return adapter_->NumRows(); }
|
||||
|
||||
bst_row_t const static base_rowid = 0; // NOLINT
|
||||
};
|
||||
@@ -408,31 +408,33 @@ class ColumnSplitHelper {
|
||||
ColumnSplitHelper(ColumnSplitHelper &&) noexcept = delete;
|
||||
ColumnSplitHelper &operator=(ColumnSplitHelper &&) noexcept = delete;
|
||||
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
void PredictDMatrix(Context const *ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&batch}, out_preds);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(ctx, SparsePageView{&batch}, out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(SparsePage::Inst const &inst, std::vector<bst_float> *out_preds) {
|
||||
void PredictInstance(Context const *ctx, SparsePage::Inst const &inst,
|
||||
std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
PredictBatchKernel<SingleInstanceView, 1>(SingleInstanceView{inst}, out_preds);
|
||||
PredictBatchKernel<SingleInstanceView, 1>(ctx, SingleInstanceView{inst}, out_preds);
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
void PredictLeaf(Context const* ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(ctx, SparsePageView{&batch},
|
||||
out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,12 +455,13 @@ class ColumnSplitHelper {
|
||||
std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
|
||||
}
|
||||
|
||||
std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const {
|
||||
[[nodiscard]] std::size_t BitIndex(std::size_t tree_id, std::size_t row_id,
|
||||
std::size_t node_id) const {
|
||||
size_t tree_index = tree_id - tree_begin_;
|
||||
return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
|
||||
}
|
||||
|
||||
void AllreduceBitVectors() {
|
||||
void AllreduceBitVectors(Context const*) {
|
||||
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
|
||||
decision_storage_.size());
|
||||
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
|
||||
@@ -547,7 +550,7 @@ class ColumnSplitHelper {
|
||||
}
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size, bool predict_leaf = false>
|
||||
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
|
||||
void PredictBatchKernel(Context const* ctx, DataView batch, std::vector<bst_float> *out_preds) {
|
||||
auto const num_group = model_.learner_model_param->num_output_group;
|
||||
|
||||
// parallel over local batch
|
||||
@@ -568,7 +571,7 @@ class ColumnSplitHelper {
|
||||
FVecDrop(block_size, fvec_offset, &feat_vecs_);
|
||||
});
|
||||
|
||||
AllreduceBitVectors();
|
||||
AllreduceBitVectors(ctx);
|
||||
|
||||
// auto block_id has the same type as `n_blocks`.
|
||||
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
|
||||
@@ -646,7 +649,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "Predict DMatrix with column split" << MTNotImplemented();
|
||||
|
||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
||||
helper.PredictDMatrix(p_fmat, out_preds);
|
||||
helper.PredictDMatrix(ctx_, p_fmat, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -779,7 +782,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "Predict instance with column split" << MTNotImplemented();
|
||||
|
||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
|
||||
helper.PredictInstance(inst, out_preds);
|
||||
helper.PredictInstance(ctx_, inst, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -811,7 +814,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "Predict leaf with column split" << MTNotImplemented();
|
||||
|
||||
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
|
||||
helper.PredictLeaf(p_fmat, &preds);
|
||||
helper.PredictLeaf(ctx_, p_fmat, &preds);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user