Demo of federated learning using NVFlare (#7879)
Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
parent
11e46e4bc0
commit
af907e2d0d
55
demo/nvflare/README.md
Normal file
55
demo/nvflare/README.md
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Experimental Support of Federated XGBoost using NVFlare
|
||||||
|
|
||||||
|
This directory contains a demo of Federated Learning using
|
||||||
|
[NVFlare](https://nvidia.github.io/NVFlare/).
|
||||||
|
|
||||||
|
To run the demo, first build XGBoost with the federated learning plugin enabled (see the
|
||||||
|
[README](../../plugin/federated/README.md)).
|
||||||
|
|
||||||
|
Install NVFlare (note that currently NVFlare only supports Python 3.8):
|
||||||
|
```shell
|
||||||
|
pip install nvflare
|
||||||
|
```
|
||||||
|
|
||||||
|
Prepare the data:
|
||||||
|
```shell
|
||||||
|
./prepare_data.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Start the NVFlare federated server:
|
||||||
|
```shell
|
||||||
|
./poc/server/startup/start.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
In another terminal, start the first worker:
|
||||||
|
```shell
|
||||||
|
./poc/site-1/startup/start.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
And the second worker:
|
||||||
|
```shell
|
||||||
|
./poc/site-2/startup/start.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Then start the admin CLI, using `admin/admin` as username/password:
|
||||||
|
```shell
|
||||||
|
./poc/admin/startup/fl_admin.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
In the admin CLI, run the following commands:
|
||||||
|
```shell
|
||||||
|
upload_app hello-xgboost
|
||||||
|
set_run_number 1
|
||||||
|
deploy_app hello-xgboost all
|
||||||
|
start_app all
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the training finishes, the model files should be written to
|
||||||
|
`./poc/site-1/run_1/test.model.json` and `./poc/site-2/run_1/test.model.json`
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
Finally, shut down everything from the admin CLI:
|
||||||
|
```shell
|
||||||
|
shutdown client
|
||||||
|
shutdown server
|
||||||
|
```
|
||||||
22
demo/nvflare/config/config_fed_client.json
Executable file
22
demo/nvflare/config/config_fed_client.json
Executable file
@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"format_version": 2,
|
||||||
|
"executors": [
|
||||||
|
{
|
||||||
|
"tasks": [
|
||||||
|
"train"
|
||||||
|
],
|
||||||
|
"executor": {
|
||||||
|
"path": "trainer.XGBoostTrainer",
|
||||||
|
"args": {
|
||||||
|
"server_address": "localhost:9091",
|
||||||
|
"world_size": 2,
|
||||||
|
"server_cert_path": "server-cert.pem",
|
||||||
|
"client_key_path": "client-key.pem",
|
||||||
|
"client_cert_path": "client-cert.pem"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"task_result_filters": [],
|
||||||
|
"task_data_filters": []
|
||||||
|
}
|
||||||
22
demo/nvflare/config/config_fed_server.json
Executable file
22
demo/nvflare/config/config_fed_server.json
Executable file
@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"format_version": 2,
|
||||||
|
"server": {
|
||||||
|
"heart_beat_timeout": 600
|
||||||
|
},
|
||||||
|
"task_data_filters": [],
|
||||||
|
"task_result_filters": [],
|
||||||
|
"workflows": [
|
||||||
|
{
|
||||||
|
"id": "server_workflow",
|
||||||
|
"path": "controller.XGBoostController",
|
||||||
|
"args": {
|
||||||
|
"port": 9091,
|
||||||
|
"world_size": 2,
|
||||||
|
"server_key_path": "server-key.pem",
|
||||||
|
"server_cert_path": "server-cert.pem",
|
||||||
|
"client_cert_path": "client-cert.pem"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"components": []
|
||||||
|
}
|
||||||
68
demo/nvflare/custom/controller.py
Normal file
68
demo/nvflare/custom/controller.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
Example of training controller with NVFlare
|
||||||
|
===========================================
|
||||||
|
"""
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
import xgboost.federated
|
||||||
|
from nvflare.apis.client import Client
|
||||||
|
from nvflare.apis.fl_context import FLContext
|
||||||
|
from nvflare.apis.impl.controller import Controller, Task
|
||||||
|
from nvflare.apis.shareable import Shareable
|
||||||
|
from nvflare.apis.signal import Signal
|
||||||
|
|
||||||
|
from trainer import SupportedTasks
|
||||||
|
|
||||||
|
|
||||||
|
class XGBoostController(Controller):
    """NVFlare controller that hosts the federated XGBoost gRPC server.

    The gRPC server runs in a child process started by ``start_controller``
    and stopped by ``stop_controller``; ``control_flow`` broadcasts a single
    training task to the sites and waits for their responses.
    """

    # Seconds to wait for the server process to exit after ``terminate()``
    # before escalating to ``kill()``.
    _SERVER_SHUTDOWN_TIMEOUT = 10

    def __init__(self, port: int, world_size: int, server_key_path: str,
                 server_cert_path: str, client_cert_path: str):
        """Controller for federated XGBoost.

        Args:
            port: the port for the gRPC server to listen on.
            world_size: the number of sites.
            server_key_path: the path to the server key file.
            server_cert_path: the path to the server certificate file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._port = port
        self._world_size = world_size
        self._server_key_path = server_key_path
        self._server_cert_path = server_cert_path
        self._client_cert_path = client_cert_path
        # Handle of the child process running the gRPC server, if any.
        self._server = None

    def start_controller(self, fl_ctx: FLContext):
        """Start the federated gRPC server in a separate process."""
        self._server = multiprocessing.Process(
            target=xgboost.federated.run_federated_server,
            args=(self._port, self._world_size, self._server_key_path,
                  self._server_cert_path, self._client_cert_path))
        self._server.start()

    def stop_controller(self, fl_ctx: FLContext):
        """Stop the gRPC server process, if one was started, and reap it."""
        # Drop the handle first so a repeated call becomes a no-op.
        server, self._server = self._server, None
        if server is None:
            return

        server.terminate()
        # Join the child so it does not linger as a zombie process; bound the
        # wait in case the process does not react to the termination signal.
        server.join(timeout=self._SERVER_SHUTDOWN_TIMEOUT)
        if server.is_alive():
            server.kill()
            server.join()

    def process_result_of_unknown_task(self, client: Client, task_name: str,
                                       client_task_id: str, result: Shareable,
                                       fl_ctx: FLContext):
        """Log a warning for a result that belongs to no known task."""
        self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Broadcast the training task and wait for ``world_size`` responses.

        Returns early, without logging completion, if ``abort_signal`` is
        triggered before or after the broadcast.
        """
        self.log_info(fl_ctx, "XGBoost training control flow started.")
        if abort_signal.triggered:
            return
        task = Task(name=SupportedTasks.TRAIN, data=Shareable())
        self.broadcast_and_wait(
            task=task,
            min_responses=self._world_size,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=1,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return

        self.log_info(fl_ctx, "XGBoost training control flow finished.")
|
||||||
84
demo/nvflare/custom/trainer.py
Normal file
84
demo/nvflare/custom/trainer.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from nvflare.apis.executor import Executor
|
||||||
|
from nvflare.apis.fl_constant import ReturnCode, FLContextKey
|
||||||
|
from nvflare.apis.fl_context import FLContext
|
||||||
|
from nvflare.apis.shareable import Shareable, make_reply
|
||||||
|
from nvflare.apis.signal import Signal
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
from xgboost import callback
|
||||||
|
|
||||||
|
|
||||||
|
class SupportedTasks:
    """Names of the tasks the trainer knows how to execute."""

    # Task name for running federated XGBoost training on a site.
    TRAIN = "train"
|
||||||
|
|
||||||
|
|
||||||
|
class XGBoostTrainer(Executor):
    """NVFlare executor that runs XGBoost training as one federated worker."""

    def __init__(self, server_address: str, world_size: int, server_cert_path: str,
                 client_key_path: str, client_cert_path: str):
        """Trainer for federated XGBoost.

        Args:
            server_address: address for the gRPC server to connect to.
            world_size: the number of sites.
            server_cert_path: the path to the server certificate file.
            client_key_path: the path to the client key file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._server_address = server_address
        self._world_size = world_size
        self._server_cert_path = server_cert_path
        self._client_key_path = client_key_path
        self._client_cert_path = client_cert_path

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
                abort_signal: Signal) -> Shareable:
        """Run the requested task and report the outcome.

        Returns:
            A reply carrying ``ReturnCode.OK`` when training succeeds,
            ``ReturnCode.TASK_UNKNOWN`` for an unsupported task name, or
            ``ReturnCode.EXECUTION_EXCEPTION`` when training raises.
        """
        self.log_info(fl_ctx, f"Executing {task_name}")
        if task_name != SupportedTasks.TRAIN:
            self.log_error(fl_ctx, f"{task_name} is not a supported task.")
            return make_reply(ReturnCode.TASK_UNKNOWN)

        try:
            self._do_training(fl_ctx)
        # Catch Exception rather than BaseException so that KeyboardInterrupt
        # and SystemExit still propagate instead of being reported as a
        # failed task.
        except Exception as e:
            self.log_exception(fl_ctx,
                               f"Task {task_name} failed. Exception: {e}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)
        return make_reply(ReturnCode.OK)

    @staticmethod
    def _rank_from_client_name(client_name: str) -> int:
        """Derive the zero-based rank from a client name such as ``site-1``.

        The number after the last ``-`` is taken as the one-based site index.

        Raises:
            ValueError: if the name does not end in ``-<N>`` with ``N >= 1``.
        """
        _, separator, suffix = client_name.rpartition('-')
        try:
            site_index = int(suffix) if separator else 0
        except ValueError:
            site_index = 0
        if site_index < 1:
            raise ValueError(
                f"Cannot derive a rank from client name {client_name!r}: "
                "expected a name ending in '-<N>' with N >= 1, e.g. 'site-1'.")
        return site_index - 1

    def _do_training(self, fl_ctx: FLContext):
        """Train on the local data files and save the model in the run directory."""
        client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
        rank = self._rank_from_client_name(client_name)
        rabit_env = [
            f'federated_server_address={self._server_address}',
            f'federated_world_size={self._world_size}',
            f'federated_rank={rank}',
            f'federated_server_cert={self._server_cert_path}',
            f'federated_client_key={self._client_key_path}',
            f'federated_client_cert={self._client_cert_path}'
        ]
        with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
            # Load the data files; they are not sharded in federated mode.
            dtrain = xgb.DMatrix('agaricus.txt.train')
            dtest = xgb.DMatrix('agaricus.txt.test')

            # Specify parameters via a map; the definitions are the same as
            # in the C++ version.
            param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

            # Specify the validation sets used to watch performance.
            watchlist = [(dtest, 'eval'), (dtrain, 'train')]
            num_round = 20

            # Run training; all the features of the training API are available.
            bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                            early_stopping_rounds=2, verbose_eval=False,
                            callbacks=[callback.EvaluationMonitor(rank=rank)])

            # Save the model.
            workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
            run_dir = workspace.get_run_dir(run_number)
            bst.save_model(os.path.join(run_dir, "test.model.json"))
            xgb.rabit.tracker_print("Finished training\n")
|
||||||
25
demo/nvflare/prepare_data.sh
Executable file
25
demo/nvflare/prepare_data.sh
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
#
# Prepare the NVFlare proof-of-concept workspace for the federated XGBoost
# demo: generate certificates, split the agaricus data per site, create the
# poc directory tree and copy the app, certificates and data into it.

set -e

# Every path below is relative to this script's directory, so switch to it
# first; otherwise the cleanup would run against the caller's directory.
cd "$(dirname "${BASH_SOURCE[0]}")"

rm -fr ./agaricus* ./*.pem ./poc

# Number of sites; keep in sync with "world_size" in config/*.json.
world_size=2

# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Split train and test files manually to simulate a federated environment.
split -n l/"${world_size}" --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train agaricus.txt.train-site-
split -n l/"${world_size}" --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-

# Create the poc workspace with one directory per site.
poc -n "${world_size}"
mkdir -p poc/admin/transfer/hello-xgboost
cp -fr config custom poc/admin/transfer/hello-xgboost
cp server-*.pem client-cert.pem poc/server/
for ((id = 1; id <= world_size; id++)); do
  cp server-cert.pem client-*.pem poc/site-"$id"/
  cp agaricus.txt.train-site-"$id" poc/site-"$id"/agaricus.txt.train
  cp agaricus.txt.test-site-"$id" poc/site-"$id"/agaricus.txt.test
done
|
||||||
@ -111,9 +111,7 @@ class FederatedEngine : public IEngine {
|
|||||||
|
|
||||||
void TrackerPrint(const std::string &msg) override {
|
void TrackerPrint(const std::string &msg) override {
|
||||||
// simply print information into the tracker
|
// simply print information into the tracker
|
||||||
if (GetRank() == 0) {
|
utils::Printf("%s", msg.c_str());
|
||||||
utils::Printf("%s", msg.c_str());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -224,25 +224,16 @@ def _assert_dask_support() -> None:
|
|||||||
LOGGER.warning(msg)
|
LOGGER.warning(msg)
|
||||||
|
|
||||||
|
|
||||||
class RabitContext:
|
class RabitContext(rabit.RabitContext):
|
||||||
"""A context controlling rabit initialization and finalization."""
|
"""A context controlling rabit initialization and finalization."""
|
||||||
|
|
||||||
def __init__(self, args: List[bytes]) -> None:
|
def __init__(self, args: List[bytes]) -> None:
|
||||||
self.args = args
|
super().__init__(args)
|
||||||
worker = distributed.get_worker()
|
worker = distributed.get_worker()
|
||||||
self.args.append(
|
self.args.append(
|
||||||
("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
|
("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
|
||||||
)
|
)
|
||||||
|
|
||||||
def __enter__(self) -> None:
|
|
||||||
rabit.init(self.args)
|
|
||||||
assert rabit.is_distributed()
|
|
||||||
LOGGER.debug("-------------- rabit say hello ------------------")
|
|
||||||
|
|
||||||
def __exit__(self, *args: List) -> None:
|
|
||||||
rabit.finalize()
|
|
||||||
LOGGER.debug("--------------- rabit say bye ------------------")
|
|
||||||
|
|
||||||
|
|
||||||
def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
|
def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements
|
||||||
"""To be replaced with dask builtin."""
|
"""To be replaced with dask builtin."""
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""Distributed XGBoost Rabit related API."""
|
"""Distributed XGBoost Rabit related API."""
|
||||||
import ctypes
|
import ctypes
|
||||||
from enum import IntEnum, unique
|
from enum import IntEnum, unique
|
||||||
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||||
|
|
||||||
@ -8,6 +9,8 @@ import numpy as np
|
|||||||
|
|
||||||
from .core import _LIB, c_str, _check_call
|
from .core import _LIB, c_str, _check_call
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger("[xgboost.rabit]")
|
||||||
|
|
||||||
|
|
||||||
def _init_rabit() -> None:
|
def _init_rabit() -> None:
|
||||||
"""internal library initializer."""
|
"""internal library initializer."""
|
||||||
@ -224,5 +227,21 @@ def version_number() -> int:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class RabitContext:
    """A context controlling rabit initialization and finalization."""

    def __init__(self, args: List[bytes]) -> None:
        """Store the arguments that are passed to ``init`` on entry.

        The list is kept by reference, not copied.
        """
        self.args = args

    def __enter__(self) -> None:
        """Initialize rabit and verify that it runs in distributed mode.

        Raises:
            AssertionError: if rabit is not distributed after initialization.
        """
        init(self.args)
        if not is_distributed():
            # ``__exit__`` is not invoked when ``__enter__`` raises, so
            # finalize here to avoid leaving rabit initialized.  An explicit
            # raise (rather than ``assert``) keeps the check under ``-O``.
            finalize()
            raise AssertionError("rabit is not running in distributed mode")
        LOGGER.debug("-------------- rabit say hello ------------------")

    def __exit__(self, *args: Any) -> None:
        """Finalize rabit; exceptions from the ``with`` body are not suppressed."""
        finalize()
        LOGGER.debug("--------------- rabit say bye ------------------")
|
||||||
|
|
||||||
|
|
||||||
# initialization script
|
# initialization script
|
||||||
_init_rabit()
|
_init_rabit()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user