[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator
* fix size_t specialization
* really fix size_t
* try again
* add include
* more include
* fix lint errors
* remove rabit includes
* fix pylint error
* return dict from communicator context
* fix communicator shutdown
* fix dask test
* reset communicator mocklist
* fix distributed tests
* do not save device communicator
* fix jvm gpu tests
* add python test for federated communicator
* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
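The user-visible shape of the change in the Python package: rabit was configured with a list of encoded 'key=value' strings passed to xgb.rabit.RabitContext, while the collective communicator takes plain keyword arguments through xgb.collective.CommunicatorContext. A minimal sketch of the federated setup exercised by the test below; the port number is illustrative, and a federated server must already be listening there:

    import xgboost as xgb

    # Old (removed): rabit wanted encoded 'key=value' strings.
    #   rabit_env = ['xgboost_communicator=federated', ...]
    #   with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): ...

    # New: the collective communicator takes keyword arguments directly.
    with xgb.collective.CommunicatorContext(
        xgboost_communicator='federated',
        federated_server_address='localhost:9091',  # illustrative port
        federated_world_size=1,
        federated_rank=0,
    ):
        # Calls shown in the diff below: rank query and communicator-wide print.
        if xgb.collective.get_rank() == 0:
            xgb.collective.communicator_print('communicator is up\n')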
@@ -21,21 +21,19 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
 
 
 def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
-    rabit_env = [
-        'xgboost_communicator=federated',
-        f'federated_server_address=localhost:{port}',
-        f'federated_world_size={world_size}',
-        f'federated_rank={rank}'
-    ]
+    communicator_env = {
+        'xgboost_communicator': 'federated',
+        'federated_server_address': f'localhost:{port}',
+        'federated_world_size': world_size,
+        'federated_rank': rank
+    }
     if with_ssl:
-        rabit_env = rabit_env + [
-            f'federated_server_cert={SERVER_CERT}',
-            f'federated_client_key={CLIENT_KEY}',
-            f'federated_client_cert={CLIENT_CERT}'
-        ]
+        communicator_env['federated_server_cert'] = SERVER_CERT
+        communicator_env['federated_client_key'] = CLIENT_KEY
+        communicator_env['federated_client_cert'] = CLIENT_CERT
 
     # Always call this before using distributed module
-    with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
+    with xgb.collective.CommunicatorContext(**communicator_env):
         # Load file, file will not be sharded in federated mode.
         dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
         dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)
@@ -55,9 +53,9 @@ def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu:
                         early_stopping_rounds=2)
 
     # Save the model, only ask process 0 to save the model.
-    if xgb.rabit.get_rank() == 0:
+    if xgb.collective.get_rank() == 0:
         bst.save_model("test.model.json")
-        xgb.rabit.tracker_print("Finished training\n")
+        xgb.collective.communicator_print("Finished training\n")
 
 
 def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
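For callers that cannot use a with block, a hedged sketch of the non-context-manager form: in the 1.7 Python package, CommunicatorContext wraps explicit xgb.collective.init(**args) / xgb.collective.finalize() calls, so long-lived workers can manage the lifetime themselves. The single-process default shown here is an assumption:

    import xgboost as xgb

    # Assumption: with no arguments, the default communicator comes up as a
    # single-process world (rank 0, world size 1), as rabit did.
    xgb.collective.init()
    try:
        if xgb.collective.get_rank() == 0:
            xgb.collective.communicator_print('Finished training\n')
    finally:
        # Always release the communicator, mirroring CommunicatorContext.__exit__.
        xgb.collective.finalize()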