Use in-memory communicator to test quantile (#8710)
This commit is contained in:
@@ -40,20 +40,10 @@ void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo c
|
||||
Span<float const> hessian) {
|
||||
container->PushRowPage(page, info, hessian);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
template <bool use_column>
|
||||
void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, kWorkers);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
void DoTestDistributedQuantile(size_t rows, size_t cols) {
|
||||
auto const world = collective::GetWorldSize();
|
||||
std::vector<MetaInfo> infos(2);
|
||||
auto& h_weights = infos.front().weights_.HostVector();
|
||||
h_weights.resize(rows);
|
||||
@@ -152,47 +142,36 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
}
|
||||
}
|
||||
|
||||
template <bool use_column>
|
||||
void TestDistributedQuantile(size_t const rows, size_t const cols) {
|
||||
auto constexpr kWorkers = 4;
|
||||
RunWithInMemoryCommunicator(kWorkers, DoTestDistributedQuantile<use_column>, rows, cols);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Quantile, DistributedBasic) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 10, kCols = 10;
|
||||
TestDistributedQuantile<false>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, Distributed) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 4000, kCols = 200;
|
||||
TestDistributedQuantile<false>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, SortedDistributedBasic) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 10, kCols = 10;
|
||||
TestDistributedQuantile<true>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, SortedDistributed) {
|
||||
#if defined(__unix__)
|
||||
constexpr size_t kRows = 4000, kCols = 200;
|
||||
TestDistributedQuantile<true>(kRows, kCols);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(Quantile, SameOnAllWorkers) {
|
||||
#if defined(__unix__)
|
||||
std::string msg{"Skipping Quantile AllreduceBasic test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
CHECK_EQ(world, kWorkers);
|
||||
} else {
|
||||
LOG(WARNING) << msg;
|
||||
return;
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TestSameOnAllWorkers() {
|
||||
auto const world = collective::GetWorldSize();
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(
|
||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
@@ -256,8 +235,13 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
}
|
||||
}
|
||||
});
|
||||
collective::Finalize();
|
||||
#endif // defined(__unix__)
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Quantile, SameOnAllWorkers) {
|
||||
auto constexpr kWorkers = 4;
|
||||
RunWithInMemoryCommunicator(kWorkers, TestSameOnAllWorkers);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user