Clean up code for distributed training. (#9805)

* Clean up code for distributed training.

- Merge the free function `GetNCCLResult` into the NCCL stub as the `GetNcclResult` method.
- Split up utilities from the main dask module.
- Let `Channel` return `Result` to accommodate the NCCL channel.
- Remove old `use_label_encoder` parameter.
This commit is contained in:
Jiaming Yuan
2023-11-25 09:10:56 +08:00
committed by GitHub
parent e9260de3f3
commit 8fe1a2213c
19 changed files with 221 additions and 192 deletions

View File

@@ -25,15 +25,18 @@ TEST_F(CommTest, Channel) {
WorkerForTest worker{host, port, timeout, n_workers, i};
if (i % 2 == 0) {
auto p_chan = worker.Comm().Chan(i + 1);
p_chan->SendAll(
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
auto rc = p_chan->Block();
auto rc = Success() << [&] {
return p_chan->SendAll(
EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
} else {
auto p_chan = worker.Comm().Chan(i - 1);
std::int32_t r{-1};
p_chan->RecvAll(EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
auto rc = p_chan->Block();
auto rc = Success() << [&] {
return p_chan->RecvAll(
EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
} << [&] { return p_chan->Block(); };
ASSERT_TRUE(rc.OK()) << rc.Report();
ASSERT_EQ(r, i - 1);
}

View File

@@ -23,7 +23,7 @@ TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
auto stub = std::make_shared<NcclStub>(DefaultNcclName());
auto rc = GetNCCLResult(stub, ncclSystemError);
auto rc = stub->GetNcclResult(ncclSystemError);
auto msg = rc.Report();
ASSERT_TRUE(msg.find("environment variables") != std::string::npos);
}