[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -62,6 +62,9 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
|
||||
|
||||
Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
|
||||
ArrayInterfaceHandler::Type type) {
|
||||
if (comm.World() == 1) {
|
||||
return Success();
|
||||
}
|
||||
return DispatchDType(type, [&](auto t) {
|
||||
using T = decltype(t);
|
||||
// Divide the data into segments according to the number of workers.
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <sstream> // for stringstream
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../common/device_helpers.cuh" // for DefaultStream
|
||||
#include "../common/type.h" // for EraseType
|
||||
#include "broadcast.h" // for Broadcast
|
||||
@@ -60,7 +61,7 @@ Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
|
||||
NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> pimpl)
|
||||
: Comm{root.TrackerInfo().host, root.TrackerInfo().port, root.Timeout(), root.Retry(),
|
||||
root.TaskID()},
|
||||
stream_{dh::DefaultStream()} {
|
||||
stream_{ctx->CUDACtx()->Stream()} {
|
||||
this->world_ = root.World();
|
||||
this->rank_ = root.Rank();
|
||||
this->domain_ = root.Domain();
|
||||
|
||||
@@ -105,7 +105,7 @@ CommGroup::CommGroup()
|
||||
}
|
||||
|
||||
std::unique_ptr<collective::CommGroup>& GlobalCommGroup() {
|
||||
static std::unique_ptr<collective::CommGroup> sptr;
|
||||
static thread_local std::unique_ptr<collective::CommGroup> sptr;
|
||||
if (!sptr) {
|
||||
Json config{Null{}};
|
||||
sptr.reset(CommGroup::Create(config));
|
||||
|
||||
Reference in New Issue
Block a user