[dask] Add scheduler address to dask config. (#7581)

- Add user configuration.
- Bring back to the logic of using scheduler address from dask.  This was removed when we were trying to support GKE, now we bring it back and let xgboost try it if direct guess or host IP from user config failed.
This commit is contained in:
Jiaming Yuan 2022-01-22 01:56:32 +08:00 committed by GitHub
parent 5ddd4a9d06
commit ef4dae4c0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 136 additions and 24 deletions

View File

@ -475,6 +475,32 @@ interface, including callback functions, custom evaluation metric and objective:
) )
.. _tracker-ip:
***************
Tracker Host IP
***************
.. versionadded:: 1.6.0
In some environments XGBoost might fail to resolve the IP address of the scheduler, a
symptom is user receiving ``OSError: [Errno 99] Cannot assign requested address`` error
during training. A quick workaround is to specify the address explicitly. To do that
dask config is used:
.. code-block:: python
import dask
from distributed import Client
from xgboost import dask as dxgb
# let xgboost know the scheduler address
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
with Client(scheduler_file="sched.json") as client:
reg = dxgb.DaskXGBRegressor()
XGBoost will read configuration before training.
***************************************************************************** *****************************************************************************
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
***************************************************************************** *****************************************************************************

View File

@ -3,8 +3,12 @@
# pylint: disable=too-many-lines, fixme # pylint: disable=too-many-lines, fixme
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
# pylint: disable=import-error # pylint: disable=import-error
"""Dask extensions for distributed training. See :doc:`Distributed XGBoost with Dask """
</tutorials/dask>` for simple tutorial. Also xgboost/demo/dask for some examples. Dask extensions for distributed training
----------------------------------------
See :doc:`Distributed XGBoost with Dask </tutorials/dask>` for simple tutorial. Also
:doc:`/python/dask-examples/index` for some examples.
There are two sets of APIs in this module, one is the functional API including There are two sets of APIs in this module, one is the functional API including
``train`` and ``predict`` methods. Another is stateful Scikit-Learner wrapper ``train`` and ``predict`` methods. Another is stateful Scikit-Learner wrapper
@ -13,10 +17,22 @@ inherited from single-node Scikit-Learn interface.
The implementation is heavily influenced by dask_xgboost: The implementation is heavily influenced by dask_xgboost:
https://github.com/dask/dask-xgboost https://github.com/dask/dask-xgboost
Optional dask configuration
===========================
- **xgboost.scheduler_address**: Specify the scheduler address, see :ref:`tracker-ip`.
.. versionadded:: 1.6.0
.. code-block:: python
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
""" """
import platform import platform
import logging import logging
import collections import collections
import socket
from contextlib import contextmanager from contextlib import contextmanager
from collections import defaultdict from collections import defaultdict
from threading import Thread from threading import Thread
@ -136,17 +152,37 @@ def _multi_lock() -> Any:
return MultiLock return MultiLock
def _start_tracker(n_workers: int) -> Dict[str, Any]: def _try_start_tracker(
"""Start Rabit tracker """ n_workers: int, addrs: List[Optional[str]]
env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers} ) -> Dict[str, Union[int, str]]:
host = get_host_ip('auto') env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False) try:
env.update(rabit_context.worker_envs()) rabit_context = RabitTracker(
hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start()
except socket.error as e:
if len(addrs) < 2 or e.errno != 99:
raise
LOGGER.warning(
"Failed to bind address '%s', trying to use '%s' instead.",
str(addrs[0]),
str(addrs[1]),
)
env = _try_start_tracker(n_workers, addrs[1:])
rabit_context.start(n_workers) return env
thread = Thread(target=rabit_context.join)
thread.daemon = True
thread.start() def _start_tracker(
n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str]
) -> Dict[str, Union[int, str]]:
"""Start Rabit tracker, recurse to try different addresses."""
env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask])
return env return env
@ -174,6 +210,7 @@ class RabitContext:
def __enter__(self) -> None: def __enter__(self) -> None:
rabit.init(self.args) rabit.init(self.args)
assert rabit.is_distributed()
LOGGER.debug('-------------- rabit say hello ------------------') LOGGER.debug('-------------- rabit say hello ------------------')
def __exit__(self, *args: List) -> None: def __exit__(self, *args: List) -> None:
@ -805,12 +842,43 @@ def _dmatrix_from_list_of_parts(
return _create_dmatrix(**kwargs) return _create_dmatrix(**kwargs)
async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]: async def _get_rabit_args(
'''Get rabit context arguments from data distribution in DaskDMatrix.''' n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
env = await client.run_on_scheduler(_start_tracker, n_workers) ) -> List[bytes]:
"""Get rabit context arguments from data distribution in DaskDMatrix.
"""
# There are 3 possible different addresses:
# 1. Provided by user via dask.config
# 2. Guessed by xgboost `get_host_ip` function
# 3. From dask scheduler
# We try 1 and 3 if 1 is available, otherwise 2 and 3.
valid_config = ["scheduler_address"]
# See if user config is available
if dconfig is not None:
for k in dconfig:
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip: Optional[str] = dconfig.get("scheduler_address", None)
else:
host_ip = None
# Try address from dask scheduler, this might not work, see
# https://github.com/dask/dask-xgboost/pull/40
try:
sched_addr = distributed.comm.get_address_host(client.scheduler.address)
sched_addr = sched_addr.strip("/:")
except Exception: # pylint: disable=broad-except
sched_addr = None
env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)
rabit_args = [f"{k}={v}".encode() for k, v in env.items()] rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
return rabit_args return rabit_args
def _get_dask_config() -> Optional[Dict[str, Any]]:
return dask.config.get("xgboost", default=None)
# train and predict methods are supposed to be "functional", which meets the # train and predict methods are supposed to be "functional", which meets the
# dask paradigm. But as a side effect, the `evals_result` in single-node API # dask paradigm. But as a side effect, the `evals_result` in single-node API
# is no longer supported since it mutates the input parameter, and it's not # is no longer supported since it mutates the input parameter, and it's not
@ -837,6 +905,7 @@ def _get_workers_from_data(
async def _train_async( async def _train_async(
client: "distributed.Client", client: "distributed.Client",
global_config: Dict[str, Any], global_config: Dict[str, Any],
dconfig: Optional[Dict[str, Any]],
params: Dict[str, Any], params: Dict[str, Any],
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
num_boost_round: int, num_boost_round: int,
@ -850,7 +919,7 @@ async def _train_async(
custom_metric: Optional[Metric], custom_metric: Optional[Metric],
) -> Optional[TrainReturnT]: ) -> Optional[TrainReturnT]:
workers = _get_workers_from_data(dtrain, evals) workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client) _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
if params.get("booster", None) == "gblinear": if params.get("booster", None) == "gblinear":
raise NotImplementedError( raise NotImplementedError(
@ -948,7 +1017,7 @@ async def _train_async(
@_deprecate_positional_args @_deprecate_positional_args
def train( # pylint: disable=unused-argument def train( # pylint: disable=unused-argument
client: "distributed.Client", client: "distributed.Client",
params: Dict[str, Any], params: Dict[str, Any],
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
@ -995,7 +1064,12 @@ def train( # pylint: disable=unused-argument
_assert_dask_support() _assert_dask_support()
client = _xgb_get_client(client) client = _xgb_get_client(client)
args = locals() args = locals()
return client.sync(_train_async, global_config=config.get_config(), **args) return client.sync(
_train_async,
global_config=config.get_config(),
dconfig=_get_dask_config(),
**args,
)
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool: def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@ -1693,6 +1767,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
asynchronous=True, asynchronous=True,
client=self.client, client=self.client,
global_config=config.get_config(), global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params, params=params,
dtrain=dtrain, dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
@ -1796,6 +1871,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
asynchronous=True, asynchronous=True,
client=self.client, client=self.client,
global_config=config.get_config(), global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params, params=params,
dtrain=dtrain, dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),
@ -1987,6 +2063,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
asynchronous=True, asynchronous=True,
client=self.client, client=self.client,
global_config=config.get_config(), global_config=config.get_config(),
dconfig=_get_dask_config(),
params=params, params=params,
dtrain=dtrain, dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(), num_boost_round=self.get_num_boosting_rounds(),

View File

@ -192,7 +192,8 @@ class RabitTracker:
logging.info('start listen on %s:%d', hostIP, self.port) logging.info('start listen on %s:%d', hostIP, self.port)
def __del__(self) -> None: def __del__(self) -> None:
self.sock.close() if hasattr(self, "sock"):
self.sock.close()
@staticmethod @staticmethod
def get_neighbor(rank: int, n_workers: int) -> List[int]: def get_neighbor(rank: int, n_workers: int) -> List[int]:

View File

@ -371,7 +371,7 @@ class TestDistributedGPU:
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw) m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
workers = _get_client_workers(client) workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client) rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
def worker_fn(worker_addr: str, data_ref: Dict) -> None: def worker_fn(worker_addr: str, data_ref: Dict) -> None:
with dxgb.RabitContext(rabit_args): with dxgb.RabitContext(rabit_args):
@ -473,7 +473,7 @@ class TestDistributedGPU:
with Client(local_cuda_cluster) as client: with Client(local_cuda_cluster) as client:
workers = _get_client_workers(client) workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, client) rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client)
futures = client.map(runit, futures = client.map(runit,
workers, workers,
pure=False, pure=False,

View File

@ -28,7 +28,7 @@ def run_rabit_ops(client, n_workers):
from xgboost import rabit from xgboost import rabit
workers = _get_client_workers(client) workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), client) rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
assert not rabit.is_distributed() assert not rabit.is_distributed()
n_workers_from_dask = len(workers) n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask assert n_workers == n_workers_from_dask

View File

@ -30,6 +30,7 @@ if tm.no_dask()['condition']:
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True) pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
from distributed import LocalCluster, Client from distributed import LocalCluster, Client
import dask
import dask.dataframe as dd import dask.dataframe as dd
import dask.array as da import dask.array as da
from xgboost.dask import DaskDMatrix from xgboost.dask import DaskDMatrix
@ -1219,6 +1220,10 @@ class TestWithDask:
os.remove(before_fname) os.remove(before_fname)
os.remove(after_fname) os.remove(after_fname)
with dask.config.set({'xgboost.foo': "bar"}):
with pytest.raises(ValueError):
xgb.dask.train(client, {}, dtrain, num_boost_round=4)
def run_updater_test( def run_updater_test(
self, self,
client: "Client", client: "Client",
@ -1318,7 +1323,8 @@ class TestWithDask:
with Client(cluster) as client: with Client(cluster) as client:
workers = _get_client_workers(client) workers = _get_client_workers(client)
rabit_args = client.sync( rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), client) xgb.dask._get_rabit_args, len(workers), None, client
)
futures = client.map(runit, futures = client.map(runit,
workers, workers,
pure=False, pure=False,
@ -1446,7 +1452,9 @@ class TestWithDask:
n_partitions = X.npartitions n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y) m = xgb.dask.DaskDMatrix(client, X, y)
workers = _get_client_workers(client) workers = _get_client_workers(client)
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client) rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
n_workers = len(workers) n_workers = len(workers)
def worker_fn(worker_addr: str, data_ref: Dict) -> None: def worker_fn(worker_addr: str, data_ref: Dict) -> None: