Demo of federated learning using NVFlare (#7879)

Co-authored-by: jiamingy <jm.yuan@outlook.com>
Rong Ou 2022-05-14 07:45:41 -07:00 committed by GitHub
parent 11e46e4bc0
commit af907e2d0d
9 changed files with 298 additions and 14 deletions

demo/nvflare/README.md Normal file

@@ -0,0 +1,55 @@
# Experimental Support of Federated XGBoost using NVFlare
This directory contains a demo of Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).
To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).
Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```
Prepare the data (this generates self-signed certificates, splits the agaricus dataset
between the two sites, and sets up the NVFlare POC workspace; see `prepare_data.sh` below):
```shell
./prepare_data.sh
```
Start the NVFlare federated server:
```shell
./poc/server/startup/start.sh
```
In another terminal, start the first worker:
```shell
./poc/site-1/startup/start.sh
```
And the second worker:
```shell
./poc/site-2/startup/start.sh
```
Then start the admin CLI, using `admin/admin` as username/password:
```shell
./poc/admin/startup/fl_admin.sh
```
In the admin CLI, run the following commands:
```shell
upload_app hello-xgboost
set_run_number 1
deploy_app hello-xgboost all
start_app all
```
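These commands upload the `hello-xgboost` app from the admin's `transfer` directory (populated
by `prepare_data.sh`), assign it run number 1, deploy it to the server and both sites, and start
it everywhere.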
Once the training finishes, the model files should be written to
`./poc/site-1/run_1/test.model.json` and `./poc/site-2/run_1/test.model.json`,
respectively.
Finally, shut down everything from the admin CLI:
```shell
shutdown client
shutdown server
```
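
For a picture of what the NVFlare app automates, here is a minimal sketch (not part of the demo;
it assumes `prepare_data.sh` has been run from this directory, so the certificates and the
`agaricus.txt.train-site-N` shards exist) that drives the same federated plugin directly from
Python, mirroring what `controller.py` and `trainer.py` below do:
```python
import multiprocessing

import xgboost as xgb
import xgboost.federated

PORT, WORLD_SIZE = 9091, 2


def run_worker(rank: int) -> None:
    # The same engine parameters trainer.py passes to rabit.
    rabit_env = [
        f"federated_server_address=localhost:{PORT}",
        f"federated_world_size={WORLD_SIZE}",
        f"federated_rank={rank}",
        "federated_server_cert=server-cert.pem",
        "federated_client_key=client-key.pem",
        "federated_client_cert=client-cert.pem",
    ]
    with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
        # Each worker loads its own shard, as produced by prepare_data.sh.
        dtrain = xgb.DMatrix(f"agaricus.txt.train-site-{rank + 1}")
        xgb.train({"objective": "binary:logistic", "max_depth": 2},
                  dtrain, num_boost_round=2)


if __name__ == "__main__":
    # Start the gRPC federated server, then one process per site.
    server = multiprocessing.Process(
        target=xgboost.federated.run_federated_server,
        args=(PORT, WORLD_SIZE, "server-key.pem", "server-cert.pem",
              "client-cert.pem"))
    server.start()
    workers = [multiprocessing.Process(target=run_worker, args=(rank,))
               for rank in range(WORLD_SIZE)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    server.terminate()
```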

demo/nvflare/config/config_fed_client.json Normal file

@@ -0,0 +1,22 @@
{
  "format_version": 2,
  "executors": [
    {
      "tasks": [
        "train"
      ],
      "executor": {
        "path": "trainer.XGBoostTrainer",
        "args": {
          "server_address": "localhost:9091",
          "world_size": 2,
          "server_cert_path": "server-cert.pem",
          "client_key_path": "client-key.pem",
          "client_cert_path": "client-cert.pem"
        }
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": []
}
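
NVFlare wires this up by importing the class named in `path` and passing the `args` mapping as
keyword arguments; roughly the equivalent of the following sketch (not how the demo is actually
launched, and it assumes `demo/nvflare/custom` is on the Python path):
```python
# Roughly what NVFlare does with the config above: "path" names the class
# (trainer.XGBoostTrainer from custom/trainer.py) and "args" become
# constructor keyword arguments.
from trainer import XGBoostTrainer

executor = XGBoostTrainer(
    server_address="localhost:9091",
    world_size=2,
    server_cert_path="server-cert.pem",
    client_key_path="client-key.pem",
    client_cert_path="client-cert.pem",
)
```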

demo/nvflare/config/config_fed_server.json Normal file

@@ -0,0 +1,22 @@
{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "workflows": [
    {
      "id": "server_workflow",
      "path": "controller.XGBoostController",
      "args": {
        "port": 9091,
        "world_size": 2,
        "server_key_path": "server-key.pem",
        "server_cert_path": "server-cert.pem",
        "client_cert_path": "client-cert.pem"
      }
    }
  ],
  "components": []
}
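
The same instantiation pattern applies on the server side; note that `port` here has to agree
with the port in the client config's `server_address`. A sketch, under the same path assumption:
```python
from controller import XGBoostController

workflow = XGBoostController(
    port=9091,  # must match the client config's server_address port
    world_size=2,
    server_key_path="server-key.pem",
    server_cert_path="server-cert.pem",
    client_cert_path="client-cert.pem",
)
```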

demo/nvflare/custom/controller.py Normal file

@@ -0,0 +1,68 @@
"""
Example of training controller with NVFlare
===========================================
"""
import multiprocessing

import xgboost.federated
from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from trainer import SupportedTasks


class XGBoostController(Controller):
    def __init__(self, port: int, world_size: int, server_key_path: str,
                 server_cert_path: str, client_cert_path: str):
        """Controller for federated XGBoost.

        Args:
            port: the port for the gRPC server to listen on.
            world_size: the number of sites.
            server_key_path: the path to the server key file.
            server_cert_path: the path to the server certificate file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._port = port
        self._world_size = world_size
        self._server_key_path = server_key_path
        self._server_cert_path = server_cert_path
        self._client_cert_path = client_cert_path
        self._server = None

    def start_controller(self, fl_ctx: FLContext):
        self._server = multiprocessing.Process(
            target=xgboost.federated.run_federated_server,
            args=(self._port, self._world_size, self._server_key_path,
                  self._server_cert_path, self._client_cert_path))
        self._server.start()

    def stop_controller(self, fl_ctx: FLContext):
        if self._server:
            self._server.terminate()

    def process_result_of_unknown_task(self, client: Client, task_name: str,
                                       client_task_id: str, result: Shareable,
                                       fl_ctx: FLContext):
        self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        self.log_info(fl_ctx, "XGBoost training control flow started.")
        if abort_signal.triggered:
            return
        task = Task(name=SupportedTasks.TRAIN, data=Shareable())
        self.broadcast_and_wait(
            task=task,
            min_responses=self._world_size,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=1,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return
        self.log_info(fl_ctx, "XGBoost training control flow finished.")
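
Note the design choice here: `run_federated_server` blocks while it serves gRPC traffic, which
is why `start_controller` launches it in a separate `multiprocessing.Process`; `stop_controller`
can then stop the blocking server with `terminate()`.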

demo/nvflare/custom/trainer.py Normal file

@@ -0,0 +1,84 @@
import os

from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode, FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

import xgboost as xgb
from xgboost import callback


class SupportedTasks(object):
    TRAIN = "train"


class XGBoostTrainer(Executor):
    def __init__(self, server_address: str, world_size: int, server_cert_path: str,
                 client_key_path: str, client_cert_path: str):
        """Trainer for federated XGBoost.

        Args:
            server_address: address for the gRPC server to connect to.
            world_size: the number of sites.
            server_cert_path: the path to the server certificate file.
            client_key_path: the path to the client key file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._server_address = server_address
        self._world_size = world_size
        self._server_cert_path = server_cert_path
        self._client_key_path = client_key_path
        self._client_cert_path = client_cert_path

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
                abort_signal: Signal) -> Shareable:
        self.log_info(fl_ctx, f"Executing {task_name}")
        try:
            if task_name == SupportedTasks.TRAIN:
                self._do_training(fl_ctx)
                return make_reply(ReturnCode.OK)
            else:
                self.log_error(fl_ctx, f"{task_name} is not a supported task.")
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except BaseException as e:
            self.log_exception(fl_ctx,
                               f"Task {task_name} failed. Exception: {e.__str__()}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def _do_training(self, fl_ctx: FLContext):
        client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
        # Site names are "site-1", "site-2", ...; federated ranks are zero-based.
        rank = int(client_name.split('-')[1]) - 1
        rabit_env = [
            f'federated_server_address={self._server_address}',
            f'federated_world_size={self._world_size}',
            f'federated_rank={rank}',
            f'federated_server_cert={self._server_cert_path}',
            f'federated_client_key={self._client_key_path}',
            f'federated_client_cert={self._client_cert_path}'
        ]
        with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
            # Load the data. Each site already holds its own shard, so the file
            # is not further sharded in federated mode.
            dtrain = xgb.DMatrix('agaricus.txt.train')
            dtest = xgb.DMatrix('agaricus.txt.test')

            # Specify parameters via a map; definitions are the same as in the C++ version.
            param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

            # Specify a validation set to watch performance.
            watchlist = [(dtest, 'eval'), (dtrain, 'train')]

            # Run training; all the features of the training API are available.
            num_round = 20
            bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                            early_stopping_rounds=2, verbose_eval=False,
                            callbacks=[callback.EvaluationMonitor(rank=rank)])

            # Save the model.
            workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
            run_dir = workspace.get_run_dir(run_number)
            bst.save_model(os.path.join(run_dir, "test.model.json"))
            xgb.rabit.tracker_print("Finished training\n")
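
Note that the federated rank is derived purely from the NVFlare client name, so the demo assumes
sites are named `site-1`, `site-2`, ... as created by the POC setup. A hypothetical check of the
mapping used in `_do_training`:
```python
# "site-N" maps to zero-based federated rank N - 1.
for name, rank in [("site-1", 0), ("site-2", 1)]:
    assert int(name.split("-")[1]) - 1 == rank
```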

demo/nvflare/prepare_data.sh Executable file

@@ -0,0 +1,25 @@
#!/bin/bash
set -e

rm -fr ./agaricus* ./*.pem ./poc

world_size=2

# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Split train and test files manually to simulate a federated environment.
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train agaricus.txt.train-site-
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-

# Create the NVFlare POC workspace (the poc command is installed with nvflare).
poc -n "${world_size}"

# Stage the app for the admin to upload, and distribute certificates and data shards.
mkdir -p poc/admin/transfer/hello-xgboost
cp -fr config custom poc/admin/transfer/hello-xgboost
cp server-*.pem client-cert.pem poc/server/
for id in $(eval echo "{1..$world_size}"); do
  cp server-cert.pem client-*.pem poc/site-"$id"/
  cp agaricus.txt.train-site-"$id" poc/site-"$id"/agaricus.txt.train
  cp agaricus.txt.test-site-"$id" poc/site-"$id"/agaricus.txt.test
done
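
Since `split -n l/2` splits on line boundaries without dropping anything, the shards should
jointly cover the original file. A hypothetical sanity check (not part of the demo; run from
`demo/nvflare` after the script):
```python
# The two site shards should add up to the original training file.
with open("../data/agaricus.txt.train") as f:
    total = sum(1 for _ in f)
shards = 0
for site in (1, 2):
    with open(f"agaricus.txt.train-site-{site}") as f:
        shards += sum(1 for _ in f)
assert shards == total, (shards, total)
```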

plugin/federated/engine_federated.cc

@@ -111,9 +111,7 @@ class FederatedEngine : public IEngine {
   void TrackerPrint(const std::string &msg) override {
     // simply print information into the tracker
-    if (GetRank() == 0) {
-      utils::Printf("%s", msg.c_str());
-    }
+    utils::Printf("%s", msg.c_str());
   }

  private:
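The guard that restricted tracker printing to rank 0 is dropped, presumably because each
federated participant runs in its own isolated workspace and should see tracker messages (such
as the trainer's "Finished training") locally.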

python-package/xgboost/dask.py

@@ -224,25 +224,16 @@ def _assert_dask_support() -> None:
         LOGGER.warning(msg)


-class RabitContext:
+class RabitContext(rabit.RabitContext):
     """A context controlling rabit initialization and finalization."""

     def __init__(self, args: List[bytes]) -> None:
-        self.args = args
+        super().__init__(args)
         worker = distributed.get_worker()
         self.args.append(
             ("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
         )

-    def __enter__(self) -> None:
-        rabit.init(self.args)
-        assert rabit.is_distributed()
-        LOGGER.debug("-------------- rabit say hello ------------------")
-
-    def __exit__(self, *args: List) -> None:
-        rabit.finalize()
-        LOGGER.debug("--------------- rabit say bye ------------------")
-

 def concat(value: Any) -> Any:  # pylint: disable=too-many-return-statements
     """To be replaced with dask builtin."""

python-package/xgboost/rabit.py

@@ -1,6 +1,7 @@
 """Distributed XGBoost Rabit related API."""
 import ctypes
 from enum import IntEnum, unique
+import logging
 import pickle
 from typing import Any, TypeVar, Callable, Optional, cast, List, Union

@@ -8,6 +9,8 @@ import numpy as np

 from .core import _LIB, c_str, _check_call

+LOGGER = logging.getLogger("[xgboost.rabit]")
+

 def _init_rabit() -> None:
     """internal library initializer."""
@@ -224,5 +227,21 @@ def version_number() -> int:
     return ret


+class RabitContext:
+    """A context controlling rabit initialization and finalization."""
+
+    def __init__(self, args: List[bytes]) -> None:
+        self.args = args
+
+    def __enter__(self) -> None:
+        init(self.args)
+        assert is_distributed()
+        LOGGER.debug("-------------- rabit say hello ------------------")
+
+    def __exit__(self, *args: List) -> None:
+        finalize()
+        LOGGER.debug("--------------- rabit say bye ------------------")
+
+
 # initialization script
 _init_rabit()
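
This is the `xgb.rabit.RabitContext` that `trainer.py` uses: `__enter__` calls `init` with the
given engine arguments and asserts the process is actually distributed, and `__exit__` calls
`finalize`, so rabit is cleaned up even if training raises.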