Add Python binding for rabit ops. (#5743)
This commit is contained in:
parent
e533908922
commit
e49607af19
@ -56,6 +56,12 @@ def get_world_size():
|
||||
return ret
|
||||
|
||||
|
||||
def is_distributed():
|
||||
'''If rabit is distributed.'''
|
||||
is_dist = _LIB.RabitIsDistributed()
|
||||
return is_dist
|
||||
|
||||
|
||||
def tracker_print(msg):
|
||||
"""Print message to the tracker.
|
||||
|
||||
@ -143,6 +149,14 @@ DTYPE_ENUM__ = {
|
||||
}
|
||||
|
||||
|
||||
class Op: # pylint: disable=too-few-public-methods
|
||||
'''Supported operations for rabit.'''
|
||||
MAX = 0
|
||||
MIN = 1
|
||||
SUM = 2
|
||||
OR = 3
|
||||
|
||||
|
||||
def allreduce(data, op, prepare_fun=None):
|
||||
"""Perform allreduce, return the result.
|
||||
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import time
|
||||
|
||||
from xgboost import RabitTracker
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_rabit_tracker():
|
||||
@ -15,3 +16,39 @@ def test_rabit_tracker():
|
||||
ret = xgb.rabit.broadcast('test1234', 0)
|
||||
assert str(ret) == 'test1234'
|
||||
xgb.rabit.finalize()
|
||||
|
||||
|
||||
def run_rabit_ops(client, n_workers):
|
||||
from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
|
||||
from xgboost import rabit
|
||||
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = _get_rabit_args(workers, client)
|
||||
assert not rabit.is_distributed()
|
||||
|
||||
def local_test(worker_id):
|
||||
with RabitContext(rabit_args):
|
||||
a = 1
|
||||
assert rabit.is_distributed()
|
||||
a = np.array([a])
|
||||
reduced = rabit.allreduce(a, rabit.Op.SUM)
|
||||
assert reduced[0] == n_workers
|
||||
|
||||
worker_id = np.array([worker_id])
|
||||
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
|
||||
assert reduced == n_workers - 1
|
||||
|
||||
return 1
|
||||
|
||||
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||
results = client.gather(futures)
|
||||
assert sum(results) == n_workers
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
def test_rabit_ops():
|
||||
from distributed import Client, LocalCluster
|
||||
n_workers = 3
|
||||
with LocalCluster(n_workers=n_workers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_rabit_ops(client, n_workers)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user