[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -46,8 +46,8 @@ template <bool use_column>
|
||||
void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitRabitContext(msg, kWorkers);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, kWorkers);
|
||||
} else {
|
||||
@@ -65,7 +65,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto sparsity = 0.5f;
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
std::vector<FeatureType> ft(cols);
|
||||
for (size_t i = 0; i < ft.size(); ++i) {
|
||||
ft[i] = (i % 2 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical;
|
||||
@@ -99,8 +99,8 @@ void TestDistributedQuantile(size_t rows, size_t cols) {
|
||||
sketch_distributed.MakeCuts(&distributed_cuts);
|
||||
|
||||
// Generate cuts for single node environment
|
||||
rabit::Finalize();
|
||||
CHECK_EQ(rabit::GetWorldSize(), 1);
|
||||
collective::Finalize();
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; });
|
||||
m->Info().num_row_ = world * rows;
|
||||
ContainerType<use_column> sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(),
|
||||
@@ -184,8 +184,8 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
#if defined(__unix__)
|
||||
std::string msg{"Skipping Quantile AllreduceBasic test"};
|
||||
int32_t constexpr kWorkers = 4;
|
||||
InitRabitContext(msg, kWorkers);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, kWorkers);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
CHECK_EQ(world, kWorkers);
|
||||
} else {
|
||||
@@ -196,7 +196,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(
|
||||
kRows, [=](int32_t seed, size_t n_bins, MetaInfo const&) {
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<float> storage;
|
||||
std::vector<FeatureType> ft(kCols);
|
||||
for (size_t i = 0; i < ft.size(); ++i) {
|
||||
@@ -217,12 +217,12 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
std::vector<float> cut_min_values(cuts.MinValues().size() * world, 0);
|
||||
|
||||
size_t value_size = cuts.Values().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&value_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&value_size, 1);
|
||||
size_t ptr_size = cuts.Ptrs().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&ptr_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&ptr_size, 1);
|
||||
CHECK_EQ(ptr_size, kCols + 1);
|
||||
size_t min_value_size = cuts.MinValues().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&min_value_size, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&min_value_size, 1);
|
||||
CHECK_EQ(min_value_size, kCols);
|
||||
|
||||
size_t value_offset = value_size * rank;
|
||||
@@ -235,9 +235,9 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(),
|
||||
cut_min_values.begin() + min_values_offset);
|
||||
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_values.data(), cut_values.size());
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_ptrs.data(), cut_ptrs.size());
|
||||
rabit::Allreduce<rabit::op::Sum>(cut_min_values.data(), cut_min_values.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_values.data(), cut_values.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_ptrs.data(), cut_ptrs.size());
|
||||
collective::Allreduce<collective::Operation::kSum>(cut_min_values.data(), cut_min_values.size());
|
||||
|
||||
for (int32_t i = 0; i < world; i++) {
|
||||
for (size_t j = 0; j < value_size; ++j) {
|
||||
@@ -256,7 +256,7 @@ TEST(Quantile, SameOnAllWorkers) {
|
||||
}
|
||||
}
|
||||
});
|
||||
rabit::Finalize();
|
||||
collective::Finalize();
|
||||
#endif // defined(__unix__)
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include "test_quantile.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/collective/device_communicator.cuh"
|
||||
#include "../../../src/common/hist_util.cuh"
|
||||
#include "../../../src/common/quantile.cuh"
|
||||
|
||||
@@ -341,17 +342,14 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
// This test is supposed to run by a python test that setups the environment.
|
||||
std::string msg {"Skipping AllReduce test"};
|
||||
auto n_gpus = AllVisibleGPUs();
|
||||
InitRabitContext(msg, n_gpus);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, n_gpus);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, n_gpus);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
auto reducer = std::make_shared<dh::AllReducer>();
|
||||
reducer->Init(0);
|
||||
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
|
||||
// Set up single node version;
|
||||
@@ -385,8 +383,8 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
|
||||
// Set up distributed version. We rely on using rank as seed to generate
|
||||
// the exact same copy of data.
|
||||
auto rank = rabit::GetRank();
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
|
||||
auto rank = collective::GetRank();
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
|
||||
HostDeviceVector<float> storage;
|
||||
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||
.Device(0)
|
||||
@@ -422,28 +420,26 @@ TEST(GPUQuantile, AllReduceBasic) {
|
||||
ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
|
||||
}
|
||||
});
|
||||
rabit::Finalize();
|
||||
collective::Finalize();
|
||||
}
|
||||
|
||||
TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
std::string msg {"Skipping SameOnAllWorkers test"};
|
||||
auto n_gpus = AllVisibleGPUs();
|
||||
InitRabitContext(msg, n_gpus);
|
||||
auto world = rabit::GetWorldSize();
|
||||
InitCommunicatorContext(msg, n_gpus);
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world != 1) {
|
||||
ASSERT_EQ(world, n_gpus);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
auto reducer = std::make_shared<dh::AllReducer>();
|
||||
reducer->Init(0);
|
||||
|
||||
constexpr size_t kRows = 1000, kCols = 100;
|
||||
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
|
||||
MetaInfo const &info) {
|
||||
auto rank = rabit::GetRank();
|
||||
auto rank = collective::GetRank();
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0, reducer);
|
||||
SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0);
|
||||
HostDeviceVector<float> storage;
|
||||
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
|
||||
.Device(0)
|
||||
@@ -459,7 +455,7 @@ TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
|
||||
// Test for all workers having the same sketch.
|
||||
size_t n_data = sketch_distributed.Data().size();
|
||||
rabit::Allreduce<rabit::op::Max>(&n_data, 1);
|
||||
collective::Allreduce<collective::Operation::kMax>(&n_data, 1);
|
||||
ASSERT_EQ(n_data, sketch_distributed.Data().size());
|
||||
size_t size_as_float =
|
||||
sketch_distributed.Data().size_bytes() / sizeof(float);
|
||||
@@ -472,9 +468,10 @@ TEST(GPUQuantile, SameOnAllWorkers) {
|
||||
thrust::copy(thrust::device, local_data.data(),
|
||||
local_data.data() + local_data.size(),
|
||||
all_workers.begin() + local_data.size() * rank);
|
||||
reducer->AllReduceSum(all_workers.data().get(), all_workers.data().get(),
|
||||
all_workers.size());
|
||||
reducer->Synchronize();
|
||||
collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(0);
|
||||
|
||||
communicator->AllReduceSum(all_workers.data().get(), all_workers.size());
|
||||
communicator->Synchronize();
|
||||
|
||||
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
|
||||
std::vector<float> h_base_line(base_line.size());
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
#ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
|
||||
#define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_
|
||||
|
||||
#include <rabit/rabit.h>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../src/collective/communicator-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
inline void InitRabitContext(std::string msg, int32_t n_workers) {
|
||||
inline void InitCommunicatorContext(std::string msg, int32_t n_workers) {
|
||||
auto port = std::getenv("DMLC_TRACKER_PORT");
|
||||
std::string port_str;
|
||||
if (port) {
|
||||
@@ -28,12 +28,11 @@ inline void InitRabitContext(std::string msg, int32_t n_workers) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::string> envs{
|
||||
"DMLC_TRACKER_PORT=" + port_str,
|
||||
"DMLC_TRACKER_URI=" + uri_str,
|
||||
"DMLC_NUM_WORKER=" + std::to_string(n_workers)};
|
||||
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
|
||||
rabit::Init(3, c_envs);
|
||||
Json config{JsonObject()};
|
||||
config["DMLC_TRACKER_PORT"] = port_str;
|
||||
config["DMLC_TRACKER_URI"] = uri_str;
|
||||
config["DMLC_NUM_WORKER"] = n_workers;
|
||||
collective::Init(config);
|
||||
}
|
||||
|
||||
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
|
||||
|
||||
@@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
||||
|
||||
|
||||
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
||||
rabit_env = [
|
||||
'xgboost_communicator=federated',
|
||||
f'federated_server_address=localhost:{port}',
|
||||
f'federated_world_size={world_size}',
|
||||
f'federated_rank={rank}'
|
||||
]
|
||||
communicator_env = {
|
||||
'xgboost_communicator': 'federated',
|
||||
'federated_server_address': f'localhost:{port}',
|
||||
'federated_world_size': world_size,
|
||||
'federated_rank': rank
|
||||
}
|
||||
if with_ssl:
|
||||
rabit_env = rabit_env + [
|
||||
f'federated_server_cert={SERVER_CERT}',
|
||||
f'federated_client_key={CLIENT_KEY}',
|
||||
f'federated_client_cert={CLIENT_CERT}'
|
||||
]
|
||||
communicator_env['federated_server_cert'] = SERVER_CERT
|
||||
communicator_env['federated_client_key'] = CLIENT_KEY
|
||||
communicator_env['federated_client_cert'] = CLIENT_CERT
|
||||
|
||||
# Always call this before using distributed module
|
||||
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
||||
with xgb.collective.CommunicatorContext(**communicator_env):
|
||||
# Load file, file will not be sharded in federated mode.
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
|
||||
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
|
||||
@@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
|
||||
early_stopping_rounds=2)
|
||||
|
||||
# Save the model, only ask process 0 to save the model.
|
||||
if xgb.rabit.get_rank() == 0:
|
||||
if xgb.collective.get_rank() == 0:
|
||||
bst.save_model("test.model.json")
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
xgb.collective.communicator_print("Finished training\n")
|
||||
|
||||
|
||||
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Copyright 2019-2022 XGBoost contributors"""
|
||||
import sys
|
||||
import os
|
||||
from typing import Type, TypeVar, Any, Dict, List
|
||||
from typing import Type, TypeVar, Any, Dict, List, Union
|
||||
import pytest
|
||||
import numpy as np
|
||||
import asyncio
|
||||
@@ -425,7 +425,7 @@ class TestDistributedGPU:
|
||||
)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with dxgb.RabitContext(rabit_args):
|
||||
with dxgb.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
|
||||
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
|
||||
assert fw_rows == local_dtrain.num_col()
|
||||
@@ -505,20 +505,13 @@ class TestDistributedGPU:
|
||||
test = "--gtest_filter=GPUQuantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ""
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
|
||||
port_env = arg.decode("utf-8")
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split("=")
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env[uri[0]] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
||||
|
||||
workers = _get_client_workers(local_cuda_client)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import multiprocessing
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker
|
||||
from xgboost import collective
|
||||
from xgboost import RabitTracker, build_info, federated
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
|
||||
@@ -37,3 +37,41 @@ def test_rabit_communicator():
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
def run_federated_worker(port, world_size, rank):
|
||||
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
|
||||
federated_server_address=f'localhost:{port}',
|
||||
federated_world_size=world_size,
|
||||
federated_rank=rank):
|
||||
assert xgb.collective.get_world_size() == world_size
|
||||
assert xgb.collective.is_distributed()
|
||||
assert xgb.collective.get_processor_name() == f'rank{rank}'
|
||||
ret = xgb.collective.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
def test_federated_communicator():
|
||||
if not build_info()["USE_FEDERATED"]:
|
||||
pytest.skip("XGBoost not built with federated learning enabled")
|
||||
|
||||
port = 9091
|
||||
world_size = 2
|
||||
server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size))
|
||||
server.start()
|
||||
time.sleep(1)
|
||||
if not server.is_alive():
|
||||
raise Exception("Error starting Federated Learning server")
|
||||
|
||||
workers = []
|
||||
for rank in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_federated_worker,
|
||||
args=(port, world_size, rank))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
server.terminate()
|
||||
|
||||
@@ -15,37 +15,33 @@ if sys.platform.startswith("win"):
|
||||
def test_rabit_tracker():
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1)
|
||||
tracker.start(1)
|
||||
worker_env = tracker.worker_envs()
|
||||
rabit_env = []
|
||||
for k, v in worker_env.items():
|
||||
rabit_env.append(f"{k}={v}".encode())
|
||||
with xgb.rabit.RabitContext(rabit_env):
|
||||
ret = xgb.rabit.broadcast("test1234", 0)
|
||||
with xgb.collective.CommunicatorContext(**tracker.worker_envs()):
|
||||
ret = xgb.collective.broadcast("test1234", 0)
|
||||
assert str(ret) == "test1234"
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from test_with_dask import _get_client_workers
|
||||
from xgboost.dask import RabitContext, _get_dask_config, _get_rabit_args
|
||||
from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
|
||||
|
||||
from xgboost import rabit
|
||||
from xgboost import collective
|
||||
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
|
||||
assert not rabit.is_distributed()
|
||||
assert not collective.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
assert n_workers == n_workers_from_dask
|
||||
|
||||
def local_test(worker_id):
|
||||
with RabitContext(rabit_args):
|
||||
with CommunicatorContext(**rabit_args):
|
||||
a = 1
|
||||
assert rabit.is_distributed()
|
||||
assert collective.is_distributed()
|
||||
a = np.array([a])
|
||||
reduced = rabit.allreduce(a, rabit.Op.SUM)
|
||||
reduced = collective.allreduce(a, collective.Op.SUM)
|
||||
assert reduced[0] == n_workers
|
||||
|
||||
worker_id = np.array([worker_id])
|
||||
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
|
||||
reduced = collective.allreduce(worker_id, collective.Op.MAX)
|
||||
assert reduced == n_workers - 1
|
||||
|
||||
return 1
|
||||
@@ -83,14 +79,10 @@ def test_rank_assignment() -> None:
|
||||
from test_with_dask import _get_client_workers
|
||||
|
||||
def local_test(worker_id):
|
||||
with xgb.dask.RabitContext(args):
|
||||
for val in args:
|
||||
sval = val.decode("utf-8")
|
||||
if sval.startswith("DMLC_TASK_ID"):
|
||||
task_id = sval
|
||||
break
|
||||
with xgb.dask.CommunicatorContext(**args) as ctx:
|
||||
task_id = ctx["DMLC_TASK_ID"]
|
||||
matched = re.search(".*-([0-9]).*", task_id)
|
||||
rank = xgb.rabit.get_rank()
|
||||
rank = xgb.collective.get_rank()
|
||||
# As long as the number of workers is lesser than 10, rank and worker id
|
||||
# should be the same
|
||||
assert rank == int(matched.group(1))
|
||||
|
||||
@@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
|
||||
|
||||
class TestWithDask:
|
||||
def test_dmatrix_binary(self, client: "Client") -> None:
|
||||
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
X, y = tm.make_categorical(100, 4, 4, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy.save_binary(path)
|
||||
|
||||
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy = xgb.DMatrix(path)
|
||||
assert Xy.num_row() == 100
|
||||
@@ -1488,20 +1488,13 @@ class TestWithDask:
|
||||
test = "--gtest_filter=Quantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ''
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env["DMLC_TRACKER_URI"] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
|
||||
@@ -1543,8 +1536,8 @@ class TestWithDask:
|
||||
def get_score(config: Dict) -> float:
|
||||
return float(config["learner"]["learner_model_param"]["base_score"])
|
||||
|
||||
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
if worker_id == 0:
|
||||
y = np.array([0.0, 0.0, 0.0])
|
||||
x = np.array([[0.0]] * 3)
|
||||
@@ -1686,12 +1679,12 @@ class TestWithDask:
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
|
||||
**data_ref, nthread=7
|
||||
)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
futures = []
|
||||
|
||||
Reference in New Issue
Block a user