Add demo for vertical federated learning (#9103)
This commit is contained in:
68
demo/nvflare/vertical/custom/controller.py
Normal file
68
demo/nvflare/vertical/custom/controller.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Example of training controller with NVFlare
|
||||
===========================================
|
||||
"""
|
||||
import multiprocessing
|
||||
|
||||
from nvflare.apis.client import Client
|
||||
from nvflare.apis.fl_context import FLContext
|
||||
from nvflare.apis.impl.controller import Controller, Task
|
||||
from nvflare.apis.shareable import Shareable
|
||||
from nvflare.apis.signal import Signal
|
||||
from trainer import SupportedTasks
|
||||
|
||||
import xgboost.federated
|
||||
|
||||
|
||||
class XGBoostController(Controller):
|
||||
def __init__(self, port: int, world_size: int, server_key_path: str,
|
||||
server_cert_path: str, client_cert_path: str):
|
||||
"""Controller for federated XGBoost.
|
||||
|
||||
Args:
|
||||
port: the port for the gRPC server to listen on.
|
||||
world_size: the number of sites.
|
||||
server_key_path: the path to the server key file.
|
||||
server_cert_path: the path to the server certificate file.
|
||||
client_cert_path: the path to the client certificate file.
|
||||
"""
|
||||
super().__init__()
|
||||
self._port = port
|
||||
self._world_size = world_size
|
||||
self._server_key_path = server_key_path
|
||||
self._server_cert_path = server_cert_path
|
||||
self._client_cert_path = client_cert_path
|
||||
self._server = None
|
||||
|
||||
def start_controller(self, fl_ctx: FLContext):
|
||||
self._server = multiprocessing.Process(
|
||||
target=xgboost.federated.run_federated_server,
|
||||
args=(self._port, self._world_size, self._server_key_path,
|
||||
self._server_cert_path, self._client_cert_path))
|
||||
self._server.start()
|
||||
|
||||
def stop_controller(self, fl_ctx: FLContext):
|
||||
if self._server:
|
||||
self._server.terminate()
|
||||
|
||||
def process_result_of_unknown_task(self, client: Client, task_name: str,
|
||||
client_task_id: str, result: Shareable,
|
||||
fl_ctx: FLContext):
|
||||
self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")
|
||||
|
||||
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
|
||||
self.log_info(fl_ctx, "XGBoost training control flow started.")
|
||||
if abort_signal.triggered:
|
||||
return
|
||||
task = Task(name=SupportedTasks.TRAIN, data=Shareable())
|
||||
self.broadcast_and_wait(
|
||||
task=task,
|
||||
min_responses=self._world_size,
|
||||
fl_ctx=fl_ctx,
|
||||
wait_time_after_min_received=1,
|
||||
abort_signal=abort_signal,
|
||||
)
|
||||
if abort_signal.triggered:
|
||||
return
|
||||
|
||||
self.log_info(fl_ctx, "XGBoost training control flow finished.")
|
||||
Reference in New Issue
Block a user