Add Python binding for rabit ops. (#5743)
This commit is contained in:
parent
e533908922
commit
e49607af19
@ -56,6 +56,12 @@ def get_world_size():
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed():
|
||||||
|
'''If rabit is distributed.'''
|
||||||
|
is_dist = _LIB.RabitIsDistributed()
|
||||||
|
return is_dist
|
||||||
|
|
||||||
|
|
||||||
def tracker_print(msg):
|
def tracker_print(msg):
|
||||||
"""Print message to the tracker.
|
"""Print message to the tracker.
|
||||||
|
|
||||||
@ -143,6 +149,14 @@ DTYPE_ENUM__ = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Op: # pylint: disable=too-few-public-methods
|
||||||
|
'''Supported operations for rabit.'''
|
||||||
|
MAX = 0
|
||||||
|
MIN = 1
|
||||||
|
SUM = 2
|
||||||
|
OR = 3
|
||||||
|
|
||||||
|
|
||||||
def allreduce(data, op, prepare_fun=None):
|
def allreduce(data, op, prepare_fun=None):
|
||||||
"""Perform allreduce, return the result.
|
"""Perform allreduce, return the result.
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import time
|
|
||||||
|
|
||||||
from xgboost import RabitTracker
|
from xgboost import RabitTracker
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
import pytest
|
||||||
|
import testing as tm
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def test_rabit_tracker():
|
def test_rabit_tracker():
|
||||||
@ -15,3 +16,39 @@ def test_rabit_tracker():
|
|||||||
ret = xgb.rabit.broadcast('test1234', 0)
|
ret = xgb.rabit.broadcast('test1234', 0)
|
||||||
assert str(ret) == 'test1234'
|
assert str(ret) == 'test1234'
|
||||||
xgb.rabit.finalize()
|
xgb.rabit.finalize()
|
||||||
|
|
||||||
|
|
||||||
|
def run_rabit_ops(client, n_workers):
|
||||||
|
from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
|
||||||
|
from xgboost import rabit
|
||||||
|
|
||||||
|
workers = list(_get_client_workers(client).keys())
|
||||||
|
rabit_args = _get_rabit_args(workers, client)
|
||||||
|
assert not rabit.is_distributed()
|
||||||
|
|
||||||
|
def local_test(worker_id):
|
||||||
|
with RabitContext(rabit_args):
|
||||||
|
a = 1
|
||||||
|
assert rabit.is_distributed()
|
||||||
|
a = np.array([a])
|
||||||
|
reduced = rabit.allreduce(a, rabit.Op.SUM)
|
||||||
|
assert reduced[0] == n_workers
|
||||||
|
|
||||||
|
worker_id = np.array([worker_id])
|
||||||
|
reduced = rabit.allreduce(worker_id, rabit.Op.MAX)
|
||||||
|
assert reduced == n_workers - 1
|
||||||
|
|
||||||
|
return 1
|
||||||
|
|
||||||
|
futures = client.map(local_test, range(len(workers)), workers=workers)
|
||||||
|
results = client.gather(futures)
|
||||||
|
assert sum(results) == n_workers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
|
def test_rabit_ops():
|
||||||
|
from distributed import Client, LocalCluster
|
||||||
|
n_workers = 3
|
||||||
|
with LocalCluster(n_workers=n_workers) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
run_rabit_ops(client, n_workers)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user