[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
|
||||
|
||||
class TestWithDask:
|
||||
def test_dmatrix_binary(self, client: "Client") -> None:
|
||||
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
X, y = tm.make_categorical(100, 4, 4, False)
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy.save_binary(path)
|
||||
|
||||
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
rank = xgb.rabit.get_rank()
|
||||
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
rank = xgb.collective.get_rank()
|
||||
path = os.path.join(tmpdir, f"{rank}.bin")
|
||||
Xy = xgb.DMatrix(path)
|
||||
assert Xy.num_row() == 100
|
||||
@@ -1488,20 +1488,13 @@ class TestWithDask:
|
||||
test = "--gtest_filter=Quantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ''
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
|
||||
port_env = arg.decode('utf-8')
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split('=')
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env["DMLC_TRACKER_URI"] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, capture_output=True)
|
||||
|
||||
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
|
||||
@@ -1543,8 +1536,8 @@ class TestWithDask:
|
||||
def get_score(config: Dict) -> float:
|
||||
return float(config["learner"]["learner_model_param"]["base_score"])
|
||||
|
||||
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
if worker_id == 0:
|
||||
y = np.array([0.0, 0.0, 0.0])
|
||||
x = np.array([[0.0]] * 3)
|
||||
@@ -1686,12 +1679,12 @@ class TestWithDask:
|
||||
n_workers = len(workers)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with xgb.dask.RabitContext(rabit_args):
|
||||
with xgb.dask.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
|
||||
**data_ref, nthread=7
|
||||
)
|
||||
total = np.array([local_dtrain.num_row()])
|
||||
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
|
||||
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
|
||||
assert total[0] == kRows
|
||||
|
||||
futures = []
|
||||
|
||||
Reference in New Issue
Block a user