[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object will be required for collective operations; this PR passes the context object to some of the functions that will need it, in preparation for swapping out the implementation.
This commit is contained in:
@@ -189,7 +189,7 @@ struct SparsePageView {
|
||||
|
||||
explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); }
|
||||
SparsePage::Inst operator[](size_t i) { return view[i]; }
|
||||
size_t Size() const { return view.Size(); }
|
||||
[[nodiscard]] size_t Size() const { return view.Size(); }
|
||||
};
|
||||
|
||||
struct SingleInstanceView {
|
||||
@@ -250,7 +250,7 @@ struct GHistIndexMatrixView {
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
size_t Size() const { return page_.Size(); }
|
||||
[[nodiscard]] size_t Size() const { return page_.Size(); }
|
||||
};
|
||||
|
||||
template <typename Adapter>
|
||||
@@ -290,7 +290,7 @@ class AdapterView {
|
||||
return ret;
|
||||
}
|
||||
|
||||
size_t Size() const { return adapter_->NumRows(); }
|
||||
[[nodiscard]] size_t Size() const { return adapter_->NumRows(); }
|
||||
|
||||
bst_row_t const static base_rowid = 0; // NOLINT
|
||||
};
|
||||
@@ -408,31 +408,33 @@ class ColumnSplitHelper {
|
||||
ColumnSplitHelper(ColumnSplitHelper &&) noexcept = delete;
|
||||
ColumnSplitHelper &operator=(ColumnSplitHelper &&) noexcept = delete;
|
||||
|
||||
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
void PredictDMatrix(Context const *ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(),
|
||||
p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(SparsePageView{&batch}, out_preds);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize>(ctx, SparsePageView{&batch}, out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(SparsePage::Inst const &inst, std::vector<bst_float> *out_preds) {
|
||||
void PredictInstance(Context const *ctx, SparsePage::Inst const &inst,
|
||||
std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
PredictBatchKernel<SingleInstanceView, 1>(SingleInstanceView{inst}, out_preds);
|
||||
PredictBatchKernel<SingleInstanceView, 1>(ctx, SingleInstanceView{inst}, out_preds);
|
||||
}
|
||||
|
||||
void PredictLeaf(DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
void PredictLeaf(Context const* ctx, DMatrix *p_fmat, std::vector<bst_float> *out_preds) {
|
||||
CHECK(xgboost::collective::IsDistributed())
|
||||
<< "column-split prediction is only supported for distributed training";
|
||||
|
||||
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_));
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(SparsePageView{&batch}, out_preds);
|
||||
PredictBatchKernel<SparsePageView, kBlockOfRowsSize, true>(ctx, SparsePageView{&batch},
|
||||
out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,12 +455,13 @@ class ColumnSplitHelper {
|
||||
std::fill(missing_storage_.begin(), missing_storage_.end(), 0);
|
||||
}
|
||||
|
||||
std::size_t BitIndex(std::size_t tree_id, std::size_t row_id, std::size_t node_id) const {
|
||||
[[nodiscard]] std::size_t BitIndex(std::size_t tree_id, std::size_t row_id,
|
||||
std::size_t node_id) const {
|
||||
size_t tree_index = tree_id - tree_begin_;
|
||||
return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id;
|
||||
}
|
||||
|
||||
void AllreduceBitVectors() {
|
||||
void AllreduceBitVectors(Context const*) {
|
||||
collective::Allreduce<collective::Operation::kBitwiseOR>(decision_storage_.data(),
|
||||
decision_storage_.size());
|
||||
collective::Allreduce<collective::Operation::kBitwiseAND>(missing_storage_.data(),
|
||||
@@ -547,7 +550,7 @@ class ColumnSplitHelper {
|
||||
}
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size, bool predict_leaf = false>
|
||||
void PredictBatchKernel(DataView batch, std::vector<bst_float> *out_preds) {
|
||||
void PredictBatchKernel(Context const* ctx, DataView batch, std::vector<bst_float> *out_preds) {
|
||||
auto const num_group = model_.learner_model_param->num_output_group;
|
||||
|
||||
// parallel over local batch
|
||||
@@ -568,7 +571,7 @@ class ColumnSplitHelper {
|
||||
FVecDrop(block_size, fvec_offset, &feat_vecs_);
|
||||
});
|
||||
|
||||
AllreduceBitVectors();
|
||||
AllreduceBitVectors(ctx);
|
||||
|
||||
// auto block_id has the same type as `n_blocks`.
|
||||
common::ParallelFor(n_blocks, n_threads_, [&](auto block_id) {
|
||||
@@ -646,7 +649,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "Predict DMatrix with column split" << MTNotImplemented();
|
||||
|
||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end);
|
||||
helper.PredictDMatrix(p_fmat, out_preds);
|
||||
helper.PredictDMatrix(ctx_, p_fmat, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -779,7 +782,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "Predict instance with column split" << MTNotImplemented();
|
||||
|
||||
ColumnSplitHelper helper(this->ctx_->Threads(), model, 0, ntree_limit);
|
||||
helper.PredictInstance(inst, out_preds);
|
||||
helper.PredictInstance(ctx_, inst, out_preds);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -811,7 +814,7 @@ class CPUPredictor : public Predictor {
|
||||
<< "Predict leaf with column split" << MTNotImplemented();
|
||||
|
||||
ColumnSplitHelper helper(n_threads, model, 0, ntree_limit);
|
||||
helper.PredictLeaf(p_fmat, &preds);
|
||||
helper.PredictLeaf(ctx_, p_fmat, &preds);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -62,9 +62,7 @@ struct TreeView {
|
||||
cats.node_ptr = tree_cat_ptrs;
|
||||
}
|
||||
|
||||
__device__ bool HasCategoricalSplit() const {
|
||||
return !cats.categories.empty();
|
||||
}
|
||||
[[nodiscard]] __device__ bool HasCategoricalSplit() const { return !cats.categories.empty(); }
|
||||
};
|
||||
|
||||
struct SparsePageView {
|
||||
@@ -77,7 +75,7 @@ struct SparsePageView {
|
||||
common::Span<const bst_row_t> row_ptr,
|
||||
bst_feature_t num_features)
|
||||
: d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}
|
||||
__device__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
[[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
// Binary search
|
||||
auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
|
||||
auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
|
||||
@@ -105,8 +103,8 @@ struct SparsePageView {
|
||||
// Value is missing
|
||||
return nanf("");
|
||||
}
|
||||
XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
|
||||
XGBOOST_DEVICE size_t NumCols() const { return num_features; }
|
||||
[[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
|
||||
[[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; }
|
||||
};
|
||||
|
||||
struct SparsePageLoader {
|
||||
@@ -137,7 +135,7 @@ struct SparsePageLoader {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
__device__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
[[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
if (use_shared) {
|
||||
return smem[threadIdx.x * data.num_features + fidx];
|
||||
} else {
|
||||
@@ -151,7 +149,7 @@ struct EllpackLoader {
|
||||
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_row_t,
|
||||
size_t, float)
|
||||
: matrix{m} {}
|
||||
__device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
[[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
auto gidx = matrix.GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
return nan("");
|
||||
|
||||
Reference in New Issue
Block a user