[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object would be required for collective operations, this PR
passes the context object to some required functions to prepare for swapping out the
implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -62,6 +62,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type) {
if (comm.World() == 1) {
return Success();
}
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.

View File

@@ -10,6 +10,7 @@
#include <sstream> // for stringstream
#include <vector> // for vector
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" // for DefaultStream
#include "../common/type.h" // for EraseType
#include "broadcast.h" // for Broadcast
@@ -60,7 +61,7 @@ Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
root.TaskID()},
stream_{dh::DefaultStream()} {
stream_{ctx->CUDACtx()->Stream()} {
this->world_ = root.World();
this->rank_ = root.Rank();
this->domain_ = root.Domain();

View File

@@ -105,7 +105,7 @@ CommGroup::CommGroup()
}
std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
static std::unique_ptr<collective::CommGroup> sptr;
static thread_local std::unique_ptr<collective::CommGroup> sptr;
if (!sptr) {
Json config{Null{}};
sptr.reset(CommGroup::Create(config));