68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
"""
|
|
Example of training controller with NVFlare
|
|
===========================================
|
|
"""
|
|
import multiprocessing
|
|
|
|
import xgboost.federated
|
|
from nvflare.apis.client import Client
|
|
from nvflare.apis.fl_context import FLContext
|
|
from nvflare.apis.impl.controller import Controller, Task
|
|
from nvflare.apis.shareable import Shareable
|
|
from nvflare.apis.signal import Signal
|
|
from trainer import SupportedTasks
|
|
|
|
|
|
class XGBoostController(Controller):
|
|
def __init__(self, port: int, world_size: int, server_key_path: str,
|
|
server_cert_path: str, client_cert_path: str):
|
|
"""Controller for federated XGBoost.
|
|
|
|
Args:
|
|
port: the port for the gRPC server to listen on.
|
|
world_size: the number of sites.
|
|
server_key_path: the path to the server key file.
|
|
server_cert_path: the path to the server certificate file.
|
|
client_cert_path: the path to the client certificate file.
|
|
"""
|
|
super().__init__()
|
|
self._port = port
|
|
self._world_size = world_size
|
|
self._server_key_path = server_key_path
|
|
self._server_cert_path = server_cert_path
|
|
self._client_cert_path = client_cert_path
|
|
self._server = None
|
|
|
|
def start_controller(self, fl_ctx: FLContext):
|
|
self._server = multiprocessing.Process(
|
|
target=xgboost.federated.run_federated_server,
|
|
args=(self._port, self._world_size, self._server_key_path,
|
|
self._server_cert_path, self._client_cert_path))
|
|
self._server.start()
|
|
|
|
def stop_controller(self, fl_ctx: FLContext):
|
|
if self._server:
|
|
self._server.terminate()
|
|
|
|
def process_result_of_unknown_task(self, client: Client, task_name: str,
|
|
client_task_id: str, result: Shareable,
|
|
fl_ctx: FLContext):
|
|
self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")
|
|
|
|
def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
|
|
self.log_info(fl_ctx, "XGBoost training control flow started.")
|
|
if abort_signal.triggered:
|
|
return
|
|
task = Task(name=SupportedTasks.TRAIN, data=Shareable())
|
|
self.broadcast_and_wait(
|
|
task=task,
|
|
min_responses=self._world_size,
|
|
fl_ctx=fl_ctx,
|
|
wait_time_after_min_received=1,
|
|
abort_signal=abort_signal,
|
|
)
|
|
if abort_signal.triggered:
|
|
return
|
|
|
|
self.log_info(fl_ctx, "XGBoost training control flow finished.")
|