xgboost/demo/nvflare/custom/controller.py
2023-03-06 17:30:27 +08:00

69 lines
2.5 KiB
Python

"""
Example of training controller with NVFlare
===========================================
"""
import multiprocessing
from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from trainer import SupportedTasks
import xgboost.federated
class XGBoostController(Controller):
    """NVFlare controller that hosts the federated XGBoost server.

    The server (``xgboost.federated.run_federated_server``) runs in a child
    process that is started in ``start_controller`` and stopped in
    ``stop_controller``. Training is driven by broadcasting one
    ``SupportedTasks.TRAIN`` task to the participating sites.
    """

    # Seconds to wait for the server process to exit after ``terminate()``
    # before escalating to ``kill()``.
    _SHUTDOWN_TIMEOUT = 10.0

    def __init__(self, port: int, world_size: int, server_key_path: str,
                 server_cert_path: str, client_cert_path: str):
        """Controller for federated XGBoost.

        Args:
            port: the port for the gRPC server to listen on.
            world_size: the number of sites.
            server_key_path: the path to the server key file.
            server_cert_path: the path to the server certificate file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._port = port
        self._world_size = world_size
        self._server_key_path = server_key_path
        self._server_cert_path = server_cert_path
        self._client_cert_path = client_cert_path
        # Handle to the child process running the federated server; None
        # while no server process is running.
        self._server = None

    def start_controller(self, fl_ctx: FLContext):
        """Launch the federated XGBoost server in a child process.

        Args:
            fl_ctx: the NVFlare context (unused here).
        """
        # NOTE(review): run_federated_server presumably blocks while serving,
        # which is why it is run in a separate process -- confirm.
        server = multiprocessing.Process(
            target=xgboost.federated.run_federated_server,
            args=(self._port, self._world_size, self._server_key_path,
                  self._server_cert_path, self._client_cert_path))
        server.start()
        # Only record the handle once the process has actually started, so
        # stop_controller never tries to terminate an unstarted process.
        self._server = server

    def stop_controller(self, fl_ctx: FLContext):
        """Stop the federated server process, if one is running.

        The process is asked to terminate and is then joined so that it is
        reaped rather than left behind as a zombie. If it is still alive after
        ``_SHUTDOWN_TIMEOUT`` seconds it is killed. Calling this more than once
        is safe.

        Args:
            fl_ctx: the NVFlare context, used for logging.
        """
        server = self._server
        if server is None:
            return
        self._server = None
        server.terminate()
        server.join(self._SHUTDOWN_TIMEOUT)
        if server.is_alive():
            self.log_warning(
                fl_ctx, "Federated server did not exit after terminate(); killing it.")
            server.kill()
            server.join()

    def process_result_of_unknown_task(self, client: Client, task_name: str,
                                       client_task_id: str, result: Shareable,
                                       fl_ctx: FLContext):
        """Handle a result for a task this controller does not recognize.

        The result is only logged as a warning and is otherwise ignored.
        """
        self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Run one round of federated XGBoost training.

        Broadcasts a single ``SupportedTasks.TRAIN`` task with an empty payload
        and waits until ``world_size`` responses have been received. Returns
        early, without logging completion, if the abort signal is triggered
        before or after the broadcast.

        Args:
            abort_signal: signal checked to stop the flow early.
            fl_ctx: the NVFlare context.
        """
        self.log_info(fl_ctx, "XGBoost training control flow started.")
        if abort_signal.triggered:
            return
        task = Task(name=SupportedTasks.TRAIN, data=Shareable())
        # Every site must respond: training uses all world_size participants.
        self.broadcast_and_wait(
            task=task,
            min_responses=self._world_size,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=1,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return
        self.log_info(fl_ctx, "XGBoost training control flow finished.")