[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator
* fix size_t specialization
* really fix size_t
* try again
* add include
* more include
* fix lint errors
* remove rabit includes
* fix pylint error
* return dict from communicator context
* fix communicator shutdown
* fix dask test
* reset communicator mocklist
* fix distributed tests
* do not save device communicator
* fix jvm gpu tests
* add python test for federated communicator
* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -52,15 +52,15 @@ class XGBoostTrainer(Executor):
|
||||
def _do_training(self, fl_ctx: FLContext):
|
||||
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
|
||||
rank = int(client_name.split('-')[1]) - 1
|
||||
rabit_env = [
|
||||
f'federated_server_address={self._server_address}',
|
||||
f'federated_world_size={self._world_size}',
|
||||
f'federated_rank={rank}',
|
||||
f'federated_server_cert={self._server_cert_path}',
|
||||
f'federated_client_key={self._client_key_path}',
|
||||
f'federated_client_cert={self._client_cert_path}'
|
||||
]
|
||||
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
|
||||
communicator_env = {
|
||||
'federated_server_address': self._server_address,
|
||||
'federated_world_size': self._world_size,
|
||||
'federated_rank': rank,
|
||||
'federated_server_cert': self._server_cert_path,
|
||||
'federated_client_key': self._client_key_path,
|
||||
'federated_client_cert': self._client_cert_path
|
||||
}
|
||||
with xgb.collective.CommunicatorContext(**communicator_env):
|
||||
# Load file, file will not be sharded in federated mode.
|
||||
dtrain = xgb.DMatrix('agaricus.txt.train')
|
||||
dtest = xgb.DMatrix('agaricus.txt.test')
|
||||
@@ -86,4 +86,4 @@ class XGBoostTrainer(Executor):
|
||||
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
|
||||
run_dir = workspace.get_run_dir(run_number)
|
||||
bst.save_model(os.path.join(run_dir, "test.model.json"))
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
xgb.collective.communicator_print("Finished training\n")
|
||||
|
||||
Reference in New Issue
Block a user