Fix federated learning demos and tests (#9488)

This commit is contained in:
Sean Yang
2023-08-16 00:25:05 -07:00
committed by GitHub
parent b2e93d2742
commit 12fe2fc06c
10 changed files with 60 additions and 11 deletions

View File

@@ -6,7 +6,7 @@ This directory contains a demo of Vertical Federated Learning using
## Training with CPU only
To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).
[README](../../../plugin/federated/README.md)).
Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell

View File

@@ -16,7 +16,7 @@ class SupportedTasks(object):
class XGBoostTrainer(Executor):
def __init__(self, server_address: str, world_size: int, server_cert_path: str,
client_key_path: str, client_cert_path: str):
client_key_path: str, client_cert_path: str, use_gpus: bool):
"""Trainer for federated XGBoost.
Args:
@@ -32,6 +32,7 @@ class XGBoostTrainer(Executor):
self._server_cert_path = server_cert_path
self._client_key_path = client_key_path
self._client_cert_path = client_cert_path
self._use_gpus = use_gpus
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
abort_signal: Signal) -> Shareable:
@@ -81,6 +82,8 @@ class XGBoostTrainer(Executor):
'objective': 'binary:logistic',
'eval_metric': 'auc',
}
if self._use_gpus:
self.log_info(fl_ctx, 'GPUs are not currently supported by vertical federated XGBoost')
# specify validations set to watch performance
watchlist = [(dtest, "eval"), (dtrain, "train")]

View File

@@ -56,7 +56,7 @@ fi
nvflare poc -n 2 --prepare
mkdir -p /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp -fr config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp -fr ../config custom /tmp/nvflare/poc/admin/transfer/vertical-xgboost
cp server-*.pem client-cert.pem /tmp/nvflare/poc/server/
for (( site=1; site<=world_size; site++ )); do
cp server-cert.pem client-*.pem /tmp/nvflare/poc/site-"${site}"/