Use in-memory communicator to test quantile (#8710)

Rong Ou 2023-01-27 07:28:28 -08:00 committed by GitHub
parent 96e6b6beba
commit 8af98e30fc
8 changed files with 86 additions and 207 deletions
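
In short: instead of relying on a tracker launched from the Python test suites (the removed InitCommunicatorContext helper reading DMLC_TRACKER_URI/DMLC_TRACKER_PORT, driven by the deleted run_quantile wrappers), the C++ quantile tests now start an in-memory communicator themselves through the new RunWithInMemoryCommunicator helper in the shared test helpers header. The sketch below illustrates that pattern using only the APIs visible in this diff; the CheckRanks body and the InMemoryCommunicator test name are hypothetical stand-ins for the real test bodies such as DoTestDistributedQuantile.

// Minimal sketch of the pattern this commit introduces, assuming the
// collective and Json APIs shown in the diff below. CheckRanks and the
// InMemoryCommunicator test are illustrative only.
#include <gtest/gtest.h>

#include <cstdint>
#include <thread>
#include <vector>

#include "../../src/collective/communicator-inl.h"  // Init, Finalize, GetRank, GetWorldSize
#include "xgboost/json.h"                            // Json, JsonObject, String

namespace {

// Same shape as the RunWithInMemoryCommunicator helper added to the test
// helpers header: one std::thread per rank, each initializing the in-memory
// communicator before running the shared test body.
template <typename Function, typename... Args>
void RunWithInMemoryCommunicator(std::int32_t world_size, Function&& function, Args&&... args) {
  std::vector<std::thread> threads;
  for (auto rank = 0; rank < world_size; rank++) {
    threads.emplace_back([&, rank]() {
      xgboost::Json config{xgboost::JsonObject()};
      config["xgboost_communicator"] = xgboost::String("in-memory");
      config["in_memory_world_size"] = world_size;
      config["in_memory_rank"] = rank;
      xgboost::collective::Init(config);
      function(args...);
      xgboost::collective::Finalize();
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
}

// Hypothetical test body: every "worker" checks that it sees the expected
// world size and a rank within bounds.
void CheckRanks(std::int32_t expected_world) {
  ASSERT_EQ(xgboost::collective::GetWorldSize(), expected_world);
  ASSERT_LT(xgboost::collective::GetRank(), expected_world);
}

}  // anonymous namespace

TEST(InMemoryCommunicator, Sketch) {
  auto constexpr kWorkers = 4;
  RunWithInMemoryCommunicator(kWorkers, CheckRanks, kWorkers);
}

Each thread performs its own Init/Finalize, so a test body only needs collective calls; the helper in the diff additionally forwards the callable and its arguments with std::forward, which the sketch omits for brevity.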


@@ -40,20 +40,10 @@ void PushPage(HostSketchContainer* container, SparsePage const& page, MetaInfo c
Span<float const> hessian) {
container->PushRowPage(page, info, hessian);
}
} // anonymous namespace
template <bool use_column>
void TestDistributedQuantile(size_t rows, size_t cols) {
std::string msg {"Skipping AllReduce test"};
int32_t constexpr kWorkers = 4;
InitCommunicatorContext(msg, kWorkers);
auto world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, kWorkers);
} else {
return;
}
void DoTestDistributedQuantile(size_t rows, size_t cols) {
auto const world = collective::GetWorldSize();
std::vector<MetaInfo> infos(2);
auto& h_weights = infos.front().weights_.HostVector();
h_weights.resize(rows);
@@ -152,47 +142,36 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
}
}
template <bool use_column>
void TestDistributedQuantile(size_t const rows, size_t const cols) {
auto constexpr kWorkers = 4;
RunWithInMemoryCommunicator(kWorkers, DoTestDistributedQuantile<use_column>, rows, cols);
}
} // anonymous namespace
TEST(Quantile, DistributedBasic) {
#if defined(__unix__)
constexpr size_t kRows = 10, kCols = 10;
TestDistributedQuantile<false>(kRows, kCols);
#endif
}
TEST(Quantile, Distributed) {
#if defined(__unix__)
constexpr size_t kRows = 4000, kCols = 200;
TestDistributedQuantile<false>(kRows, kCols);
#endif
}
TEST(Quantile, SortedDistributedBasic) {
#if defined(__unix__)
constexpr size_t kRows = 10, kCols = 10;
TestDistributedQuantile<true>(kRows, kCols);
#endif
}
TEST(Quantile, SortedDistributed) {
#if defined(__unix__)
constexpr size_t kRows = 4000, kCols = 200;
TestDistributedQuantile<true>(kRows, kCols);
#endif
}
TEST(Quantile, SameOnAllWorkers) {
#if defined(__unix__)
std::string msg{"Skipping Quantile AllreduceBasic test"};
int32_t constexpr kWorkers = 4;
InitCommunicatorContext(msg, kWorkers);
auto world = collective::GetWorldSize();
if (world != 1) {
CHECK_EQ(world, kWorkers);
} else {
LOG(WARNING) << msg;
return;
}
namespace {
void TestSameOnAllWorkers() {
auto const world = collective::GetWorldSize();
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
@@ -256,8 +235,13 @@ TEST(Quantile, SameOnAllWorkers) {
}
}
});
collective::Finalize();
#endif // defined(__unix__)
}
} // anonymous namespace
TEST(Quantile, SameOnAllWorkers) {
auto constexpr kWorkers = 4;
RunWithInMemoryCommunicator(kWorkers, TestSameOnAllWorkers);
}
} // namespace common
} // namespace xgboost


@@ -338,12 +338,9 @@ TEST(GPUQuantile, MultiMerge) {
});
}
TEST(GPUQuantile, AllReduceBasic) {
// This test is supposed to be run by a Python test that sets up the environment.
std::string msg {"Skipping AllReduce test"};
auto n_gpus = AllVisibleGPUs();
InitCommunicatorContext(msg, n_gpus);
auto world = collective::GetWorldSize();
namespace {
void TestAllReduceBasic(int32_t n_gpus) {
auto const world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
@@ -420,13 +417,16 @@ TEST(GPUQuantile, AllReduceBasic) {
ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
}
});
collective::Finalize();
}
} // anonymous namespace
TEST(GPUQuantile, AllReduceBasic) {
auto const n_gpus = AllVisibleGPUs();
RunWithInMemoryCommunicator(n_gpus, TestAllReduceBasic, n_gpus);
}
TEST(GPUQuantile, SameOnAllWorkers) {
std::string msg {"Skipping SameOnAllWorkers test"};
auto n_gpus = AllVisibleGPUs();
InitCommunicatorContext(msg, n_gpus);
namespace {
void TestSameOnAllWorkers(int32_t n_gpus) {
auto world = collective::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
@@ -490,6 +490,12 @@ TEST(GPUQuantile, SameOnAllWorkers) {
}
});
}
} // anonymous namespace
TEST(GPUQuantile, SameOnAllWorkers) {
auto const n_gpus = AllVisibleGPUs();
RunWithInMemoryCommunicator(n_gpus, TestSameOnAllWorkers, n_gpus);
}
TEST(GPUQuantile, Push) {
size_t constexpr kRows = 100;


@@ -10,31 +10,6 @@
namespace xgboost {
namespace common {
inline void InitCommunicatorContext(std::string msg, int32_t n_workers) {
auto port = std::getenv("DMLC_TRACKER_PORT");
std::string port_str;
if (port) {
port_str = port;
} else {
LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up.";
return;
}
auto uri = std::getenv("DMLC_TRACKER_URI");
std::string uri_str;
if (uri) {
uri_str = uri;
} else {
LOG(WARNING) << msg << " as `DMLC_TRACKER_URI` is not set up.";
return;
}
Json config{JsonObject()};
config["DMLC_TRACKER_PORT"] = port_str;
config["DMLC_TRACKER_URI"] = uri_str;
config["DMLC_NUM_WORKER"] = n_workers;
collective::Init(config);
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
std::vector<int32_t> seeds(2);
SimpleLCG lcg;


@@ -1,8 +1,7 @@
/**
* Copyright 2016-2023 by XGBoost contributors
*/
#ifndef XGBOOST_TESTS_CPP_HELPERS_H_
#define XGBOOST_TESTS_CPP_HELPERS_H_
#pragma once
#include <gtest/gtest.h>
#include <sys/stat.h>
@@ -16,8 +15,10 @@
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include "../../src/collective/communicator-inl.h"
#include "../../src/common/common.h"
#include "../../src/data/array_interface.h"
#include "../../src/gbm/gbtree_model.h"
@@ -460,5 +461,25 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint
return mparam;
}
template <typename Function, typename... Args>
void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < world_size; rank++) {
threads.emplace_back([&, rank]() {
Json config{JsonObject()};
config["xgboost_communicator"] = String("in-memory");
config["in_memory_world_size"] = world_size;
config["in_memory_rank"] = rank;
xgboost::collective::Init(config);
std::forward<Function>(function)(std::forward<Args>(args)...);
xgboost::collective::Finalize();
});
}
for (auto& thread : threads) {
thread.join();
}
}
} // namespace xgboost
#endif


@@ -90,21 +90,14 @@ TEST(CpuPredictor, Basic) {
}
}
TEST(CpuPredictor, ColumnSplit) {
namespace {
void TestColumnSplitPredictBatch() {
size_t constexpr kRows = 5;
size_t constexpr kCols = 5;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
std::vector<std::thread> threads;
std::int32_t constexpr kWorldSize = 2;
size_t constexpr kSliceSize = (kCols + 1) / kWorldSize;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back([=, &dmat]() {
Json config{JsonObject()};
config["xgboost_communicator"] = String("in-memory");
config["in_memory_world_size"] = kWorldSize;
config["in_memory_rank"] = rank;
xgboost::collective::Init(config);
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const kSliceSize = (kCols + 1) / world_size;
auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> cpu_predictor =
@@ -126,12 +119,12 @@ TEST(CpuPredictor, ColumnSplit) {
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
ASSERT_EQ(out_predictions_h[i], 1.5);
}
xgboost::collective::Finalize();
});
}
for (auto& thread : threads) {
thread.join();
}
} // anonymous namespace
TEST(CpuPredictor, ColumnSplit) {
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, TestColumnSplitPredictBatch);
}
TEST(CpuPredictor, IterationRange) {


@@ -2,4 +2,3 @@
markers =
mgpu: Mark a test that requires multiple GPUs to run.
ci: Mark a test that runs only on CI.
gtest: Mark a test that requires C++ Google Test executable.


@@ -486,49 +486,6 @@ class TestDistributedGPU:
for rn, drn in zip(ranker_names, dranker_names):
assert rn == drn
def run_quantile(self, name: str, local_cuda_client: Client) -> None:
exe = None
for possible_path in {
"./testxgboost",
"./build/testxgboost",
"../build/testxgboost",
"../gpu-build/testxgboost",
}:
if os.path.exists(possible_path):
exe = possible_path
assert exe, "No testxgboost executable found."
test = "--gtest_filter=GPUQuantile." + name
def runit(
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
# Set up the environment for running the C++ part.
env = os.environ.copy()
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
workers = tm.get_client_workers(local_cuda_client)
rabit_args = local_cuda_client.sync(
dxgb._get_rabit_args, len(workers), None, local_cuda_client
)
futures = local_cuda_client.map(
runit, workers, pure=False, workers=workers, rabit_args=rabit_args
)
results = local_cuda_client.gather(futures)
for ret in results:
msg = ret.stdout.decode("utf-8")
assert msg.find("1 test from GPUQuantile") != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.gtest
def test_quantile_basic(self, local_cuda_client: Client) -> None:
self.run_quantile("AllReduceBasic", local_cuda_client)
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self, local_cuda_client: Client) -> None:
self.run_quantile("SameOnAllWorkers", local_cuda_client)
@pytest.mark.skipif(**tm.no_cupy())
def test_with_asyncio(local_cuda_client: Client) -> None:


@@ -1490,62 +1490,6 @@ class TestWithDask:
num_rounds = 10
self.run_updater_test(client, params, num_rounds, dataset, 'approx')
def run_quantile(self, name: str) -> None:
exe: Optional[str] = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/cpubuild/testxgboost',
'../cpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
if exe is None:
return
test = "--gtest_filter=Quantile." + name
def runit(
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
# Set up the environment for running the C++ part.
env = os.environ.copy()
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
with Client(cluster) as client:
workers = tm.get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = client.map(runit,
workers,
pure=False,
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from Quantile') != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_basic(self) -> None:
self.run_quantile('DistributedBasic')
self.run_quantile('SortedDistributedBasic')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile(self) -> None:
self.run_quantile('Distributed')
self.run_quantile('SortedDistributed')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self) -> None:
self.run_quantile("SameOnAllWorkers")
def test_adaptive(self) -> None:
def get_score(config: Dict) -> float:
return float(config["learner"]["learner_model_param"]["base_score"])