diff --git a/demo/nvflare/vertical/custom/trainer.py b/demo/nvflare/vertical/custom/trainer.py index efe320734..b6c3855ef 100644 --- a/demo/nvflare/vertical/custom/trainer.py +++ b/demo/nvflare/vertical/custom/trainer.py @@ -83,9 +83,8 @@ class XGBoostTrainer(Executor): 'eval_metric': 'auc', } if self._use_gpus: - if self._use_gpus: - self.log_info(fl_ctx, f'Training with GPU {rank}') - param['device'] = f"cuda:{rank}" + self.log_info(fl_ctx, f'Training with GPU {rank}') + param['device'] = f"cuda:{rank}" # specify validations set to watch performance watchlist = [(dtest, "eval"), (dtrain, "train")]