[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Copyright 2019-2022 XGBoost contributors"""
|
||||
import sys
|
||||
import os
|
||||
from typing import Type, TypeVar, Any, Dict, List
|
||||
from typing import Type, TypeVar, Any, Dict, List, Union
|
||||
import pytest
|
||||
import numpy as np
|
||||
import asyncio
|
||||
@@ -425,7 +425,7 @@ class TestDistributedGPU:
|
||||
)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
with dxgb.RabitContext(rabit_args):
|
||||
with dxgb.CommunicatorContext(**rabit_args):
|
||||
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
|
||||
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
|
||||
assert fw_rows == local_dtrain.num_col()
|
||||
@@ -505,20 +505,13 @@ class TestDistributedGPU:
|
||||
test = "--gtest_filter=GPUQuantile." + name
|
||||
|
||||
def runit(
|
||||
worker_addr: str, rabit_args: List[bytes]
|
||||
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
|
||||
) -> subprocess.CompletedProcess:
|
||||
port_env = ""
|
||||
# setup environment for running the c++ part.
|
||||
for arg in rabit_args:
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
|
||||
port_env = arg.decode("utf-8")
|
||||
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
|
||||
uri_env = arg.decode("utf-8")
|
||||
port = port_env.split("=")
|
||||
env = os.environ.copy()
|
||||
env[port[0]] = port[1]
|
||||
uri = uri_env.split("=")
|
||||
env[uri[0]] = uri[1]
|
||||
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
|
||||
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
|
||||
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
||||
|
||||
workers = _get_client_workers(local_cuda_client)
|
||||
|
||||
Reference in New Issue
Block a user