106 lines
2.9 KiB
Python

"""XGBoost Experimental Federated Learning related API."""
import ctypes
from threading import Thread
from typing import Any, Dict, Optional
from .core import _LIB, _check_call, make_jcargs
from .tracker import RabitTracker
class FederatedTracker(RabitTracker):
"""Tracker for federated training.
Parameters
----------
n_workers :
The number of federated workers.
port :
The port to listen on.
secure :
Whether this is a secure instance. If True, then the following arguments for SSL
must be provided.
server_key_path :
Path to the server private key file.
server_cert_path :
Path to the server certificate file.
client_cert_path :
Path to the client certificate file.
"""
def __init__( # pylint: disable=R0913, W0231
self,
n_workers: int,
port: int,
secure: bool,
server_key_path: Optional[str] = None,
server_cert_path: Optional[str] = None,
client_cert_path: Optional[str] = None,
timeout: int = 300,
) -> None:
handle = ctypes.c_void_p()
args = make_jcargs(
n_workers=n_workers,
port=port,
dmlc_communicator="federated",
federated_secure=secure,
server_key_path=server_key_path,
server_cert_path=server_cert_path,
client_cert_path=client_cert_path,
timeout=int(timeout),
)
_check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle)))
self.handle = handle
def run_federated_server( # pylint: disable=too-many-arguments
n_workers: int,
port: int,
server_key_path: Optional[str] = None,
server_cert_path: Optional[str] = None,
client_cert_path: Optional[str] = None,
blocking: bool = True,
timeout: int = 300,
) -> Optional[Dict[str, Any]]:
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info.
Parameters
----------
blocking :
Block the server until the training is finished. If set to False, the function
launches an additional thread and returns the worker arguments. The default is
True and a higher level framework is responsible for setting worker parameters.
"""
args: Dict[str, Any] = {"n_workers": n_workers}
secure = all(
path is not None
for path in [server_key_path, server_cert_path, client_cert_path]
)
tracker = FederatedTracker(
n_workers=n_workers,
port=port,
secure=secure,
timeout=timeout,
server_key_path=server_key_path,
server_cert_path=server_cert_path,
client_cert_path=client_cert_path,
)
tracker.start()
if blocking:
tracker.wait_for()
return None
thread = Thread(target=tracker.wait_for)
thread.daemon = True
thread.start()
args.update(tracker.worker_args())
return args