Add GPU support to NVFlare demo (#9552)

This commit is contained in:
Rong Ou
2023-09-06 02:03:59 -07:00
committed by GitHub
parent 3b9e5909fb
commit 0f35493b65
4 changed files with 14 additions and 7 deletions

View File

@@ -67,7 +67,7 @@ class XGBoostTrainer(Executor):
dtest = xgb.DMatrix('agaricus.txt.test?format=libsvm')
# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
param = {'tree_method': 'hist', 'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
if self._use_gpus:
self.log_info(fl_ctx, f'Training with GPU {rank}')
param['device'] = f"cuda:{rank}"