[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -1267,17 +1267,17 @@ def test_dask_iteration_range(client: "Client"):
class TestWithDask:
def test_dmatrix_binary(self, client: "Client") -> None:
def save_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
def save_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
with xgb.dask.CommunicatorContext(**rabit_args):
rank = xgb.collective.get_rank()
X, y = tm.make_categorical(100, 4, 4, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
path = os.path.join(tmpdir, f"{rank}.bin")
Xy.save_binary(path)
def load_dmatrix(rabit_args: List[bytes], tmpdir: str) -> None:
with xgb.dask.RabitContext(rabit_args):
rank = xgb.rabit.get_rank()
def load_dmatrix(rabit_args: Dict[str, Union[int,str]], tmpdir: str) -> None:
with xgb.dask.CommunicatorContext(**rabit_args):
rank = xgb.collective.get_rank()
path = os.path.join(tmpdir, f"{rank}.bin")
Xy = xgb.DMatrix(path)
assert Xy.num_row() == 100
@@ -1488,20 +1488,13 @@ class TestWithDask:
test = "--gtest_filter=Quantile." + name
def runit(
worker_addr: str, rabit_args: List[bytes]
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
port_env = ''
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port_env = arg.decode('utf-8')
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env["DMLC_TRACKER_URI"] = uri[1]
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, capture_output=True)
with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
@@ -1543,8 +1536,8 @@ class TestWithDask:
def get_score(config: Dict) -> float:
return float(config["learner"]["learner_model_param"]["base_score"])
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
with xgb.dask.RabitContext(rabit_args):
def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
with xgb.dask.CommunicatorContext(**rabit_args):
if worker_id == 0:
y = np.array([0.0, 0.0, 0.0])
x = np.array([[0.0]] * 3)
@@ -1686,12 +1679,12 @@ class TestWithDask:
n_workers = len(workers)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with xgb.dask.RabitContext(rabit_args):
with xgb.dask.CommunicatorContext(**rabit_args):
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(
**data_ref, nthread=7
)
total = np.array([local_dtrain.num_row()])
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
total = xgb.collective.allreduce(total, xgb.collective.Op.SUM)
assert total[0] == kRows
futures = []