Support GPU training in the NVFlare demo (#7965)
This commit is contained in:
@@ -16,7 +16,7 @@ class SupportedTasks(object):
|
||||
|
||||
class XGBoostTrainer(Executor):
|
||||
def __init__(self, server_address: str, world_size: int, server_cert_path: str,
|
||||
client_key_path: str, client_cert_path: str):
|
||||
client_key_path: str, client_cert_path: str, use_gpus: bool):
|
||||
"""Trainer for federated XGBoost.
|
||||
|
||||
Args:
|
||||
@@ -32,6 +32,7 @@ class XGBoostTrainer(Executor):
|
||||
self._server_cert_path = server_cert_path
|
||||
self._client_key_path = client_key_path
|
||||
self._client_cert_path = client_cert_path
|
||||
self._use_gpus = use_gpus
|
||||
|
||||
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
|
||||
abort_signal: Signal) -> Shareable:
|
||||
@@ -66,6 +67,10 @@ class XGBoostTrainer(Executor):
|
||||
|
||||
# Specify parameters via map, definition are same as c++ version
|
||||
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
|
||||
if self._use_gpus:
|
||||
self.log_info(fl_ctx, f'Training with GPU {rank}')
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
param['gpu_id'] = rank
|
||||
|
||||
# Specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
|
||||
Reference in New Issue
Block a user