Federated learning plugin for XGBoost:
* A gRPC server to aggregate MPI-style requests (allgather, allreduce, broadcast) from federated workers.
* A Rabit engine for the federated environment.
* An integration test to simulate federated learning.

Follow-up work is needed to address GPU support, better security, privacy, etc.

(37 lines, 1.3 KiB, Python)
"""XGBoost Federated Learning related API."""
|
|
|
|
from .core import _LIB, _check_call, c_str, build_info, XGBoostError
|
|
|
|
|
|
def run_federated_server(port: int,
                         world_size: int,
                         server_key_path: str,
                         server_cert_path: str,
                         client_cert_path: str) -> None:
    """Run the Federated Learning server.

    Parameters
    ----------
    port : int
        The port to listen on.
    world_size : int
        The number of federated workers.
    server_key_path : str
        Path to the server private key file.
    server_cert_path : str
        Path to the server certificate file.
    client_cert_path : str
        Path to the client certificate file.

    Raises
    ------
    XGBoostError
        If the loaded XGBoost library was not built with the federated
        learning plugin enabled. Errors reported by the native
        ``XGBRunFederatedServer`` call are surfaced through ``_check_call``.
    """
    # Use ``.get`` instead of indexing: a native library whose build info
    # does not report ``USE_FEDERATED`` at all cannot have the plugin, and
    # the caller should see the descriptive error below, not a bare KeyError.
    if not build_info().get('USE_FEDERATED', False):
        raise XGBoostError(
            "XGBoost needs to be built with the federated learning plugin "
            "enabled in order to use this module"
        )

    # The three paths are handed to the C API as C strings via ``c_str``.
    _check_call(_LIB.XGBRunFederatedServer(port,
                                           world_size,
                                           c_str(server_key_path),
                                           c_str(server_cert_path),
                                           c_str(client_cert_path)))
|