committed by
GitHub
parent
4bc59ef7c3
commit
5b76acccff
@@ -39,6 +39,37 @@ def test_rabit_communicator():
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
# TODO(rongou): remove this once we remove the rabit api.
|
||||
def run_rabit_api_worker(rabit_env, world_size):
|
||||
with xgb.rabit.RabitContext(rabit_env):
|
||||
assert xgb.rabit.get_world_size() == world_size
|
||||
assert xgb.rabit.is_distributed()
|
||||
assert xgb.rabit.get_processor_name().decode() == socket.gethostname()
|
||||
ret = xgb.rabit.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
ret = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM)
|
||||
assert np.array_equal(ret, np.asarray([2, 4, 6]))
|
||||
|
||||
|
||||
# TODO(rongou): remove this once we remove the rabit api.
|
||||
def test_rabit_api():
|
||||
world_size = 2
|
||||
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
rabit_env = []
|
||||
for k, v in tracker.worker_envs().items():
|
||||
rabit_env.append(f"{k}={v}".encode())
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = multiprocessing.Process(target=run_rabit_api_worker,
|
||||
args=(rabit_env, world_size))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
assert worker.exitcode == 0
|
||||
|
||||
|
||||
def run_federated_worker(port, world_size, rank):
|
||||
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
|
||||
federated_server_address=f'localhost:{port}',
|
||||
|
||||
Reference in New Issue
Block a user