enable ROCm on latest XGBoost
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/collective/allreduce.h"
|
||||
#include "../../../src/collective/coll.h" // for Coll
|
||||
#include "../../../src/collective/tracker.h"
|
||||
#include "test_worker.h" // for WorkerForTest, TestDistributed
|
||||
|
||||
@@ -47,6 +48,19 @@ class AllreduceWorker : public WorkerForTest {
|
||||
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
|
||||
}
|
||||
}
|
||||
|
||||
void BitOr() {
|
||||
Context ctx;
|
||||
std::vector<std::uint32_t> data(comm_.World(), 0);
|
||||
data[comm_.Rank()] = ~std::uint32_t{0};
|
||||
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
|
||||
auto rc = pcoll->Allreduce(&ctx, comm_, EraseType(common::Span{data.data(), data.size()}),
|
||||
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
|
||||
ASSERT_TRUE(rc.OK()) << rc.Report();
|
||||
for (auto v : data) {
|
||||
ASSERT_EQ(v, ~std::uint32_t{0});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AllreduceTest : public SocketTest {};
|
||||
@@ -69,4 +83,13 @@ TEST_F(AllreduceTest, Sum) {
|
||||
worker.Acc();
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(AllreduceTest, BitOr) {
|
||||
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
|
||||
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
|
||||
std::int32_t r) {
|
||||
AllreduceWorker worker{host, port, timeout, n_workers, r};
|
||||
worker.BitOr();
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -41,7 +41,7 @@ class LoopTest : public ::testing::Test {
|
||||
rc = pair_.first.NonBlocking(true);
|
||||
ASSERT_TRUE(rc.OK());
|
||||
|
||||
loop_ = std::make_shared<Loop>(timeout);
|
||||
loop_ = std::shared_ptr<Loop>{new Loop{timeout}};
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
|
||||
Reference in New Issue
Block a user