Allow blocking launch of federated tracker. (#10414)
--------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
49e25cfb36
commit
6c83c8c2ef
@ -69,7 +69,7 @@ class MyLogistic : public ObjFunction {
|
|||||||
|
|
||||||
void SaveConfig(Json* p_out) const override {
|
void SaveConfig(Json* p_out) const override {
|
||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
out["name"] = String("my_logistic");
|
out["name"] = String("mylogistic");
|
||||||
out["my_logistic_param"] = ToJson(param_);
|
out["my_logistic_param"] = ToJson(param_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -65,9 +65,19 @@ def run_federated_server( # pylint: disable=too-many-arguments
|
|||||||
server_key_path: Optional[str] = None,
|
server_key_path: Optional[str] = None,
|
||||||
server_cert_path: Optional[str] = None,
|
server_cert_path: Optional[str] = None,
|
||||||
client_cert_path: Optional[str] = None,
|
client_cert_path: Optional[str] = None,
|
||||||
|
blocking: bool = True,
|
||||||
timeout: int = 300,
|
timeout: int = 300,
|
||||||
) -> Dict[str, Any]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info."""
|
"""See :py:class:`~xgboost.federated.FederatedTracker` for more info.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
blocking :
|
||||||
|
Block the server until the training is finished. If set to False, the function
|
||||||
|
launches an additional thread and returns the worker arguments. The default is
|
||||||
|
True and a higher level framework is responsible for setting worker parameters.
|
||||||
|
|
||||||
|
"""
|
||||||
args: Dict[str, Any] = {"n_workers": n_workers}
|
args: Dict[str, Any] = {"n_workers": n_workers}
|
||||||
secure = all(
|
secure = all(
|
||||||
path is not None
|
path is not None
|
||||||
@ -78,6 +88,10 @@ def run_federated_server( # pylint: disable=too-many-arguments
|
|||||||
)
|
)
|
||||||
tracker.start()
|
tracker.start()
|
||||||
|
|
||||||
|
if blocking:
|
||||||
|
tracker.wait_for()
|
||||||
|
return None
|
||||||
|
|
||||||
thread = Thread(target=tracker.wait_for)
|
thread = Thread(target=tracker.wait_for)
|
||||||
thread.daemon = True
|
thread.daemon = True
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|||||||
@ -63,7 +63,7 @@ def test_federated_communicator():
|
|||||||
world_size = 2
|
world_size = 2
|
||||||
tracker = multiprocessing.Process(
|
tracker = multiprocessing.Process(
|
||||||
target=federated.run_federated_server,
|
target=federated.run_federated_server,
|
||||||
kwargs={"port": port, "n_workers": world_size},
|
kwargs={"port": port, "n_workers": world_size, "blocking": False},
|
||||||
)
|
)
|
||||||
tracker.start()
|
tracker.start()
|
||||||
if not tracker.is_alive():
|
if not tracker.is_alive():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user