Allow blocking launch of federated tracker. (#10414)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2024-06-16 01:43:53 +08:00 committed by GitHub
parent 49e25cfb36
commit 6c83c8c2ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 4 deletions

View File

@@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction {
void SaveConfig(Json* p_out) const override { void SaveConfig(Json* p_out) const override {
auto& out = *p_out; auto& out = *p_out;
out["name"] = String("my_logistic"); out["name"] = String("mylogistic");
out["my_logistic_param"] = ToJson(param_); out["my_logistic_param"] = ToJson(param_);
} }

View File

@@ -65,9 +65,19 @@ def run_federated_server( # pylint: disable=too-many-arguments
server_key_path: Optional[str] = None, server_key_path: Optional[str] = None,
server_cert_path: Optional[str] = None, server_cert_path: Optional[str] = None,
client_cert_path: Optional[str] = None, client_cert_path: Optional[str] = None,
blocking: bool = True,
timeout: int = 300, timeout: int = 300,
) -> Dict[str, Any]: ) -> Optional[Dict[str, Any]]:
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info.""" """See :py:class:`~xgboost.federated.FederatedTracker` for more info.
Parameters
----------
blocking :
Block the server until the training is finished. If set to False, the function
launches an additional thread and returns the worker arguments. The default is
True and a higher level framework is responsible for setting worker parameters.
"""
args: Dict[str, Any] = {"n_workers": n_workers} args: Dict[str, Any] = {"n_workers": n_workers}
secure = all( secure = all(
path is not None path is not None
@@ -78,6 +88,10 @@ def run_federated_server( # pylint: disable=too-many-arguments
) )
tracker.start() tracker.start()
if blocking:
tracker.wait_for()
return None
thread = Thread(target=tracker.wait_for) thread = Thread(target=tracker.wait_for)
thread.daemon = True thread.daemon = True
thread.start() thread.start()

View File

@ -63,7 +63,7 @@ def test_federated_communicator():
world_size = 2 world_size = 2
tracker = multiprocessing.Process( tracker = multiprocessing.Process(
target=federated.run_federated_server, target=federated.run_federated_server,
kwargs={"port": port, "n_workers": world_size}, kwargs={"port": port, "n_workers": world_size, "blocking": False},
) )
tracker.start() tracker.start()
if not tracker.is_alive(): if not tracker.is_alive():