[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object will be required for collective operations; this PR passes the context object to some of the required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -480,7 +480,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
|
||||
// Configure allocator with maximum cached bin size of ~1GB and no limit on
|
||||
// maximum cached bytes
|
||||
thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
|
||||
thread_local std::unique_ptr<cub::CachingDeviceAllocator> allocator{
|
||||
std::make_unique<cub::CachingDeviceAllocator>(2, 9, 29)};
|
||||
return *allocator;
|
||||
}
|
||||
pointer allocate(size_t n) { // NOLINT
|
||||
|
||||
@@ -51,7 +51,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
|
||||
for (auto const &page : m->GetBatches<SparsePage>()) {
|
||||
container.PushRowPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(m->Info(), &out);
|
||||
container.MakeCuts(ctx, m->Info(), &out);
|
||||
} else {
|
||||
SortedSketchContainer container{ctx,
|
||||
max_bins,
|
||||
@@ -61,7 +61,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins
|
||||
for (auto const &page : m->GetBatches<SortedCSCPage>(ctx)) {
|
||||
container.PushColPage(page, info, hessian);
|
||||
}
|
||||
container.MakeCuts(m->Info(), &out);
|
||||
container.MakeCuts(ctx, m->Info(), &out);
|
||||
}
|
||||
|
||||
return out;
|
||||
|
||||
@@ -359,7 +359,7 @@ HistogramCuts DeviceSketchWithHessian(Context const* ctx, DMatrix* p_fmat, bst_b
|
||||
}
|
||||
}
|
||||
|
||||
sketch_container.MakeCuts(&cuts, p_fmat->Info().IsColumnSplit());
|
||||
sketch_container.MakeCuts(ctx, &cuts, p_fmat->Info().IsColumnSplit());
|
||||
return cuts;
|
||||
}
|
||||
} // namespace xgboost::common
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
#include "categorical.h"
|
||||
#include "hist_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
namespace xgboost::common {
|
||||
template <typename WQSketch>
|
||||
SketchContainerImpl<WQSketch>::SketchContainerImpl(Context const *ctx,
|
||||
std::vector<bst_row_t> columns_size,
|
||||
@@ -129,7 +127,7 @@ struct QuantileAllreduce {
|
||||
* \param rank rank of target worker
|
||||
* \param fidx feature idx
|
||||
*/
|
||||
auto Values(int32_t rank, bst_feature_t fidx) const {
|
||||
[[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const {
|
||||
// get span for worker
|
||||
auto wsize = worker_indptr[rank + 1] - worker_indptr[rank];
|
||||
auto worker_values = global_values.subspan(worker_indptr[rank], wsize);
|
||||
@@ -145,7 +143,7 @@ struct QuantileAllreduce {
|
||||
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
MetaInfo const& info,
|
||||
Context const *, MetaInfo const &info,
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches) {
|
||||
@@ -206,7 +204,7 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
|
||||
}
|
||||
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
|
||||
void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
|
||||
auto world_size = collective::GetWorldSize();
|
||||
auto rank = collective::GetRank();
|
||||
if (world_size == 1 || info.IsColumnSplit()) {
|
||||
@@ -274,16 +272,15 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
|
||||
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
MetaInfo const& info,
|
||||
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t>* p_num_cuts) {
|
||||
Context const *ctx, MetaInfo const &info,
|
||||
std::vector<typename WQSketch::SummaryContainer> *p_reduced, std::vector<int32_t> *p_num_cuts) {
|
||||
monitor_.Start(__func__);
|
||||
|
||||
size_t n_columns = sketches_.size();
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
|
||||
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";
|
||||
|
||||
AllreduceCategories(info);
|
||||
AllreduceCategories(ctx, info);
|
||||
|
||||
auto& num_cuts = *p_num_cuts;
|
||||
CHECK_EQ(num_cuts.size(), 0);
|
||||
@@ -324,7 +321,7 @@ void SketchContainerImpl<WQSketch>::AllReduce(
|
||||
std::vector<bst_row_t> sketches_scan((n_columns + 1) * world, 0);
|
||||
|
||||
std::vector<typename WQSketch::Entry> global_sketches;
|
||||
this->GatherSketchInfo(info, reduced, &worker_segments, &sketches_scan, &global_sketches);
|
||||
this->GatherSketchInfo(ctx, info, reduced, &worker_segments, &sketches_scan, &global_sketches);
|
||||
|
||||
std::vector<typename WQSketch::SummaryContainer> final_sketches(n_columns);
|
||||
|
||||
@@ -383,11 +380,12 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
|
||||
}
|
||||
|
||||
template <typename WQSketch>
|
||||
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const &info, HistogramCuts *p_cuts) {
|
||||
void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const &info,
|
||||
HistogramCuts *p_cuts) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<typename WQSketch::SummaryContainer> reduced;
|
||||
std::vector<int32_t> num_cuts;
|
||||
this->AllReduce(info, &reduced, &num_cuts);
|
||||
this->AllReduce(ctx, info, &reduced, &num_cuts);
|
||||
|
||||
p_cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
|
||||
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
|
||||
@@ -496,5 +494,4 @@ void SortedSketchContainer::PushColPage(SparsePage const &page, MetaInfo const &
|
||||
});
|
||||
monitor_.Stop(__func__);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
@@ -22,9 +22,7 @@
|
||||
#include "transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "xgboost/span.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
namespace xgboost::common {
|
||||
using WQSketch = HostSketchContainer::WQSketch;
|
||||
using SketchEntry = WQSketch::Entry;
|
||||
|
||||
@@ -501,7 +499,7 @@ void SketchContainer::FixError() {
|
||||
});
|
||||
}
|
||||
|
||||
void SketchContainer::AllReduce(bool is_column_split) {
|
||||
void SketchContainer::AllReduce(Context const*, bool is_column_split) {
|
||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world == 1 || is_column_split) {
|
||||
@@ -582,13 +580,13 @@ struct InvalidCatOp {
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
|
||||
void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool is_column_split) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_.ordinal));
|
||||
p_cuts->min_vals_.Resize(num_columns_);
|
||||
|
||||
// Sync between workers.
|
||||
this->AllReduce(is_column_split);
|
||||
this->AllReduce(ctx, is_column_split);
|
||||
|
||||
// Prune to final number of bins.
|
||||
this->Prune(num_bins_ + 1);
|
||||
@@ -731,5 +729,4 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
|
||||
p_cuts->SetCategorical(this->has_categorical_, max_cat);
|
||||
timer_.Stop(__func__);
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
@@ -151,9 +151,9 @@ class SketchContainer {
|
||||
Span<SketchEntry const> that);
|
||||
|
||||
/* \brief Merge quantiles from other GPU workers. */
|
||||
void AllReduce(bool is_column_split);
|
||||
void AllReduce(Context const* ctx, bool is_column_split);
|
||||
/* \brief Create the final histogram cut values. */
|
||||
void MakeCuts(HistogramCuts* cuts, bool is_column_split);
|
||||
void MakeCuts(Context const* ctx, HistogramCuts* cuts, bool is_column_split);
|
||||
|
||||
Span<SketchEntry const> Data() const {
|
||||
return {this->Current().data().get(), this->Current().size()};
|
||||
|
||||
@@ -827,13 +827,14 @@ class SketchContainerImpl {
|
||||
return group_ind;
|
||||
}
|
||||
// Gather sketches from all workers.
|
||||
void GatherSketchInfo(MetaInfo const& info,
|
||||
void GatherSketchInfo(Context const *ctx, MetaInfo const &info,
|
||||
std::vector<typename WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<bst_row_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<typename WQSketch::Entry> *p_global_sketches);
|
||||
// Merge sketches from all workers.
|
||||
void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
void AllReduce(Context const *ctx, MetaInfo const &info,
|
||||
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t> *p_num_cuts);
|
||||
|
||||
template <typename Batch, typename IsValid>
|
||||
@@ -887,11 +888,11 @@ class SketchContainerImpl {
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
|
||||
|
||||
void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);
|
||||
void MakeCuts(Context const *ctx, MetaInfo const &info, HistogramCuts *cuts);
|
||||
|
||||
private:
|
||||
// Merge all categories from other workers.
|
||||
void AllreduceCategories(MetaInfo const& info);
|
||||
void AllreduceCategories(Context const* ctx, MetaInfo const& info);
|
||||
};
|
||||
|
||||
class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
|
||||
|
||||
Reference in New Issue
Block a user