[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -1,7 +1,7 @@
"""Copyright 2019-2022 XGBoost contributors"""
import sys
import os
from typing import Type, TypeVar, Any, Dict, List
from typing import Type, TypeVar, Any, Dict, List, Union
import pytest
import numpy as np
import asyncio
@@ -425,7 +425,7 @@ class TestDistributedGPU:
)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args):
with dxgb.CommunicatorContext(**rabit_args):
local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7)
fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
assert fw_rows == local_dtrain.num_col()
@@ -505,20 +505,13 @@ class TestDistributedGPU:
test = "--gtest_filter=GPUQuantile." + name
def runit(
worker_addr: str, rabit_args: List[bytes]
worker_addr: str, rabit_args: Dict[str, Union[int, str]]
) -> subprocess.CompletedProcess:
port_env = ""
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode("utf-8").startswith("DMLC_TRACKER_PORT"):
port_env = arg.decode("utf-8")
if arg.decode("utf-8").startswith("DMLC_TRACKER_URI"):
uri_env = arg.decode("utf-8")
port = port_env.split("=")
env = os.environ.copy()
env[port[0]] = port[1]
uri = uri_env.split("=")
env[uri[0]] = uri[1]
env['DMLC_TRACKER_PORT'] = str(rabit_args['DMLC_TRACKER_PORT'])
env["DMLC_TRACKER_URI"] = str(rabit_args["DMLC_TRACKER_URI"])
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
workers = _get_client_workers(local_cuda_client)