diff --git a/plugin/example/custom_obj.cc b/plugin/example/custom_obj.cc index b996447a3..5d61e812a 100644 --- a/plugin/example/custom_obj.cc +++ b/plugin/example/custom_obj.cc @@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction { void SaveConfig(Json* p_out) const override { auto& out = *p_out; - out["name"] = String("my_logistic"); + out["name"] = String("mylogistic"); out["my_logistic_param"] = ToJson(param_); } diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py index dcba9ec81..2e42c03ac 100644 --- a/python-package/xgboost/federated.py +++ b/python-package/xgboost/federated.py @@ -65,9 +65,19 @@ def run_federated_server( # pylint: disable=too-many-arguments server_key_path: Optional[str] = None, server_cert_path: Optional[str] = None, client_cert_path: Optional[str] = None, + blocking: bool = True, timeout: int = 300, -) -> Dict[str, Any]: - """See :py:class:`~xgboost.federated.FederatedTracker` for more info.""" +) -> Optional[Dict[str, Any]]: + """See :py:class:`~xgboost.federated.FederatedTracker` for more info. + + Parameters + ---------- + blocking : + Block the server until the training is finished. If set to False, the function + launches an additional thread and returns the worker arguments. The default is + True and a higher level framework is responsible for setting worker parameters. + + """ args: Dict[str, Any] = {"n_workers": n_workers} secure = all( path is not None @@ -78,6 +88,10 @@ def run_federated_server( # pylint: disable=too-many-arguments ) tracker.start() + if blocking: + tracker.wait_for() + return None + thread = Thread(target=tracker.wait_for) thread.daemon = True thread.start() diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index a3923e9df..2beedf8a1 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -63,7 +63,7 @@ def test_federated_communicator(): world_size = 2 tracker = multiprocessing.Process( target=federated.run_federated_server, - kwargs={"port": port, "n_workers": world_size}, + kwargs={"port": port, "n_workers": world_size, "blocking": False}, ) tracker.start() if not tracker.is_alive():