Add GPU support to NVFlare demo (#9552)
This commit is contained in:
@@ -77,13 +77,15 @@ class XGBoostTrainer(Executor):
|
||||
'gamma': 1.0,
|
||||
'max_depth': 8,
|
||||
'min_child_weight': 100,
|
||||
'tree_method': 'approx',
|
||||
'tree_method': 'hist',
|
||||
'grow_policy': 'depthwise',
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': 'auc',
|
||||
}
|
||||
if self._use_gpus:
|
||||
self.log_info(fl_ctx, 'GPUs are not currently supported by vertical federated XGBoost')
|
||||
if self._use_gpus:
|
||||
self.log_info(fl_ctx, f'Training with GPU {rank}')
|
||||
param['device'] = f"cuda:{rank}"
|
||||
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, "eval"), (dtrain, "train")]
|
||||
|
||||
Reference in New Issue
Block a user