[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object would be required for collective operations, this PR
passes the context object to some required functions to prepare for swapping out the
implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -745,7 +745,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}
void MetaInfo::SynchronizeNumberOfColumns() {
void MetaInfo::SynchronizeNumberOfColumns(Context const*) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {

View File

@@ -95,7 +95,7 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_featur
namespace {
// Synchronize feature type in case of empty DMatrix
void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
void SyncFeatureType(Context const*, std::vector<FeatureType>* p_h_ft) {
if (!collective::IsDistributed()) {
return;
}
@@ -193,7 +193,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
Info().SynchronizeNumberOfColumns();
Info().SynchronizeNumberOfColumns(ctx);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";
@@ -213,9 +213,9 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
while (iter.Next()) {
if (!p_sketch) {
h_ft = proxy->Info().feature_types.ConstHostVector();
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
!proxy->Info().group_ptr_.empty()});
SyncFeatureType(ctx, &h_ft);
p_sketch = std::make_unique<common::HostSketchContainer>(ctx, p.max_bin, h_ft, column_sizes,
!proxy->Info().group_ptr_.empty());
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -230,7 +230,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
CHECK_EQ(accumulated_rows, Info().num_row_);
CHECK(p_sketch);
p_sketch->MakeCuts(Info(), &cuts);
p_sketch->MakeCuts(ctx, Info(), &cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), n_features);

View File

@@ -105,7 +105,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
sketch_containers.clear();
sketch_containers.shrink_to_fit();
final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
final_sketch.MakeCuts(ctx, &cuts, this->info_.IsColumnSplit());
} else {
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
}
@@ -167,7 +167,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
iter.Reset();
// Synchronise worker columns
info_.SynchronizeNumberOfColumns();
info_.SynchronizeNumberOfColumns(ctx);
}
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,

View File

@@ -283,7 +283,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
ReindexFeatures(&ctx);
info_.SynchronizeNumberOfColumns();
info_.SynchronizeNumberOfColumns(&ctx);
if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT =

View File

@@ -42,7 +42,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
info_.data_split_mode = data_split_mode;
info_.SynchronizeNumberOfColumns();
info_.SynchronizeNumberOfColumns(&ctx);
this->fmat_ctx_ = ctx;
}

View File

@@ -97,7 +97,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;
info_.SynchronizeNumberOfColumns();
info_.SynchronizeNumberOfColumns(&ctx);
CHECK_NE(info_.num_col_, 0);
fmat_ctx_ = ctx;