[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
Author: Rong Ou
Date: 2022-10-05 15:39:01 -07:00
Committed by: GitHub
Parent: e47b3a3da3
Commit: 668b8a0ea4
79 changed files with 805 additions and 2212 deletions
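For reference, the shape of the Python-side migration, condensed from the diff below: a minimal sketch with an illustrative localhost port and world size, not code taken from this commit.

import xgboost as xgb

# Before this commit: rabit took a flat list of 'key=value' strings,
# each encoded to bytes (works only on builds prior to this change).
rabit_env = ['xgboost_communicator=federated',
             'federated_server_address=localhost:9091',
             'federated_world_size=2',
             'federated_rank=0']
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
    xgb.rabit.tracker_print("rank %d up\n" % xgb.rabit.get_rank())

# After this commit: the collective communicator takes keyword arguments,
# with values kept as their native Python types.
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
                                        federated_server_address='localhost:9091',
                                        federated_world_size=2,
                                        federated_rank=0):
    xgb.collective.communicator_print("rank %d up\n" % xgb.collective.get_rank())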

@@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
 def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
-    rabit_env = [
-        'xgboost_communicator=federated',
-        f'federated_server_address=localhost:{port}',
-        f'federated_world_size={world_size}',
-        f'federated_rank={rank}'
-    ]
+    communicator_env = {
+        'xgboost_communicator': 'federated',
+        'federated_server_address': f'localhost:{port}',
+        'federated_world_size': world_size,
+        'federated_rank': rank
+    }
     if with_ssl:
-        rabit_env = rabit_env + [
-            f'federated_server_cert={SERVER_CERT}',
-            f'federated_client_key={CLIENT_KEY}',
-            f'federated_client_cert={CLIENT_CERT}'
-        ]
+        communicator_env['federated_server_cert'] = SERVER_CERT
+        communicator_env['federated_client_key'] = CLIENT_KEY
+        communicator_env['federated_client_cert'] = CLIENT_CERT
     # Always call this before using distributed module
-    with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
+    with xgb.collective.CommunicatorContext(**communicator_env):
         # Load file, file will not be sharded in federated mode.
         dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
         dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
@@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
                         early_stopping_rounds=2)
         # Save the model, only ask process 0 to save the model.
-        if xgb.rabit.get_rank() == 0:
+        if xgb.collective.get_rank() == 0:
             bst.save_model("test.model.json")
-        xgb.rabit.tracker_print("Finished training\n")
+        xgb.collective.communicator_print("Finished training\n")
 def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
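A hypothetical launcher for the functions above, showing how run_server and run_worker might be driven end to end; the actual run_test in this file may differ, and the process-management details here are assumptions.

import multiprocessing

def launch(port: int = 9091, world_size: int = 2, with_ssl: bool = False) -> None:
    # Start the federated server, then one worker process per rank.
    server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
    server.start()
    workers = [multiprocessing.Process(target=run_worker,
                                       args=(port, world_size, rank, with_ssl, False))
               for rank in range(world_size)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    server.terminate()  # assumption: run_server blocks until killed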