Add GPU support to NVFlare demo (#9552)

This commit is contained in:
Rong Ou
2023-09-06 02:03:59 -07:00
committed by GitHub
parent 3b9e5909fb
commit 0f35493b65
4 changed files with 14 additions and 7 deletions

View File

@@ -77,13 +77,15 @@ class XGBoostTrainer(Executor):
'gamma': 1.0,
'max_depth': 8,
'min_child_weight': 100,
'tree_method': 'approx',
'tree_method': 'hist',
'grow_policy': 'depthwise',
'objective': 'binary:logistic',
'eval_metric': 'auc',
}
if self._use_gpus:
self.log_info(fl_ctx, 'GPUs are not currently supported by vertical federated XGBoost')
if self._use_gpus:
self.log_info(fl_ctx, f'Training with GPU {rank}')
param['device'] = f"cuda:{rank}"
# specify validations set to watch performance
watchlist = [(dtest, "eval"), (dtrain, "train")]