[coll] Pass context to various functions. (#9772)

* [coll] Pass context to various functions.

In the future, the `Context` object would be required for collective operations, this PR
passes the context object to some required functions to prepare for swapping out the
implementation.
This commit is contained in:
Jiaming Yuan
2023-11-08 09:54:05 +08:00
committed by GitHub
parent 6c0a190f6d
commit 06bdc15e9b
45 changed files with 275 additions and 255 deletions

View File

@@ -370,6 +370,7 @@ void TestAllReduceBasic() {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
auto const device = DeviceOrd::CUDA(GPUIDX);
auto ctx = MakeCUDACtx(device.ordinal);
// Set up single node version;
HostDeviceVector<FeatureType> ft({}, device);
@@ -413,7 +414,7 @@ void TestAllReduceBasic() {
AdapterDeviceSketch(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce(false);
sketch_distributed.AllReduce(&ctx, false);
sketch_distributed.Unique();
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
@@ -517,6 +518,7 @@ void TestSameOnAllWorkers() {
MetaInfo const &info) {
auto const rank = collective::GetRank();
auto const device = DeviceOrd::CUDA(GPUIDX);
Context ctx = MakeCUDACtx(device.ordinal);
HostDeviceVector<FeatureType> ft({}, device);
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, device);
HostDeviceVector<float> storage({}, device);
@@ -528,7 +530,7 @@ void TestSameOnAllWorkers() {
AdapterDeviceSketch(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce(false);
sketch_distributed.AllReduce(&ctx, false);
sketch_distributed.Unique();
TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);