[dask] Add scheduler address to dask config. (#7581)
- Add user configuration.
- Bring back the logic of using the scheduler address from dask. This was removed when we were trying to support GKE; now we bring it back and let xgboost try it if the direct guess or the host IP from user config fails.
This commit is contained in:
parent 5ddd4a9d06
commit ef4dae4c0e
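The change gives the tracker an ordered list of candidate addresses to bind to: the user-configured xgboost.scheduler_address, the host IP guessed by get_host_ip, and the address reported by the dask scheduler. A rough sketch of the selection order; candidate_addresses is an illustrative helper, not code from this commit:

    from typing import List, Optional

    def candidate_addresses(
        user_addr: Optional[str], dask_addr: Optional[str]
    ) -> List[Optional[str]]:
        """Tracker bind candidates, in order of preference. A leading None
        means "let xgboost guess the host IP"."""
        # with user config:    1. user address  then  3. dask scheduler address
        # without user config: 2. guessed IP    then  3. dask scheduler address
        return [user_addr, dask_addr]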
@@ -475,6 +475,32 @@ interface, including callback functions, custom evaluation metric and objective:
 )
 
 
+.. _tracker-ip:
+
+***************
+Tracker Host IP
+***************
+
+.. versionadded:: 1.6.0
+
+In some environments XGBoost might fail to resolve the IP address of the scheduler, a
+symptom is the user receiving an ``OSError: [Errno 99] Cannot assign requested address``
+error during training. A quick workaround is to specify the address explicitly. To do
+that, dask config is used:
+
+.. code-block:: python
+
+    import dask
+    from distributed import Client
+    from xgboost import dask as dxgb
+    # let xgboost know the scheduler address
+    dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
+
+    with Client(scheduler_file="sched.json") as client:
+        reg = dxgb.DaskXGBRegressor()
+
+XGBoost will read the configuration before training.
+
 *****************************************************************************
 Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
 *****************************************************************************
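Because the new key lives in dask's standard configuration system, it can presumably also be supplied without touching code, via dask's environment-variable mapping; this relies on documented dask behavior, not on anything this commit adds:

    import os

    # dask maps DASK_* environment variables onto nested config keys
    # (a double underscore becomes a dot), so this is equivalent to
    # dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}).
    os.environ["DASK_XGBOOST__SCHEDULER_ADDRESS"] = "192.0.0.100"

    import dask
    dask.config.refresh()  # re-read the environment if dask was already imported
    assert dask.config.get("xgboost.scheduler_address") == "192.0.0.100"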
@@ -3,8 +3,12 @@
 # pylint: disable=too-many-lines, fixme
 # pylint: disable=too-few-public-methods
 # pylint: disable=import-error
-"""Dask extensions for distributed training. See :doc:`Distributed XGBoost with Dask
-</tutorials/dask>` for simple tutorial. Also xgboost/demo/dask for some examples.
+"""
+Dask extensions for distributed training
+----------------------------------------
+
+See :doc:`Distributed XGBoost with Dask </tutorials/dask>` for simple tutorial. Also
+:doc:`/python/dask-examples/index` for some examples.
 
 There are two sets of APIs in this module, one is the functional API including
 ``train`` and ``predict`` methods. Another is stateful Scikit-Learner wrapper
@@ -13,10 +17,22 @@ inherited from single-node Scikit-Learn interface.
 The implementation is heavily influenced by dask_xgboost:
 https://github.com/dask/dask-xgboost
 
+Optional dask configuration
+===========================
+
+- **xgboost.scheduler_address**: Specify the scheduler address, see :ref:`tracker-ip`.
+
+  .. versionadded:: 1.6.0
+
+  .. code-block:: python
+
+      dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
+
 """
 import platform
 import logging
 import collections
+import socket
 from contextlib import contextmanager
 from collections import defaultdict
 from threading import Thread
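Since the option goes through dask.config, it can also be scoped with a context manager instead of being set globally. A small usage sketch, assuming a running client and an existing DaskDMatrix; the function name and parameter values are illustrative:

    import dask
    from distributed import Client
    from xgboost import dask as dxgb

    def train_with_explicit_tracker_host(client: Client, dtrain: dxgb.DaskDMatrix):
        # the address applies only to training calls made inside this block
        with dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}):
            return dxgb.train(client, {"tree_method": "hist"}, dtrain, num_boost_round=10)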
@@ -136,17 +152,37 @@ def _multi_lock() -> Any:
     return MultiLock
 
 
-def _start_tracker(n_workers: int) -> Dict[str, Any]:
-    """Start Rabit tracker """
-    env: Dict[str, Union[int, str]] = {'DMLC_NUM_WORKER': n_workers}
-    host = get_host_ip('auto')
-    rabit_context = RabitTracker(hostIP=host, n_workers=n_workers, use_logger=False)
-    env.update(rabit_context.worker_envs())
-    rabit_context.start(n_workers)
-    thread = Thread(target=rabit_context.join)
-    thread.daemon = True
-    thread.start()
+def _try_start_tracker(
+    n_workers: int, addrs: List[Optional[str]]
+) -> Dict[str, Union[int, str]]:
+    env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers}
+    try:
+        rabit_context = RabitTracker(
+            hostIP=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
+        )
+        env.update(rabit_context.worker_envs())
+
+        rabit_context.start(n_workers)
+        thread = Thread(target=rabit_context.join)
+        thread.daemon = True
+        thread.start()
+    except socket.error as e:
+        if len(addrs) < 2 or e.errno != 99:
+            raise
+        LOGGER.warning(
+            "Failed to bind address '%s', trying to use '%s' instead.",
+            str(addrs[0]),
+            str(addrs[1]),
+        )
+        env = _try_start_tracker(n_workers, addrs[1:])
+
+    return env
+
+
+def _start_tracker(
+    n_workers: int, addr_from_dask: Optional[str], addr_from_user: Optional[str]
+) -> Dict[str, Union[int, str]]:
+    """Start Rabit tracker, recurse to try different addresses."""
+    env = _try_start_tracker(n_workers, [addr_from_user, addr_from_dask])
     return env
 
 
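Note the fallback above only triggers on errno 99 (EADDRNOTAVAIL on Linux, the error quoted in the tutorial) and only while candidates remain; anything else re-raises. The same retry pattern, sketched with plain sockets; try_bind is an illustrative helper, not part of the commit:

    import errno
    import socket
    from typing import List, Optional

    def try_bind(addrs: List[Optional[str]]) -> socket.socket:
        # bind to the first usable candidate; on EADDRNOTAVAIL fall back to
        # the next one, mirroring the recursion in _try_start_tracker
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # None stands in for "guess", here approximated by the wildcard address
            sock.bind((addrs[0] or "0.0.0.0", 0))
            return sock
        except OSError as e:  # socket.error is an alias of OSError in Python 3
            sock.close()
            if len(addrs) < 2 or e.errno != errno.EADDRNOTAVAIL:
                raise
            return try_bind(addrs[1:])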
@@ -174,6 +210,7 @@ class RabitContext:
 
     def __enter__(self) -> None:
         rabit.init(self.args)
+        assert rabit.is_distributed()
         LOGGER.debug('-------------- rabit say hello ------------------')
 
     def __exit__(self, *args: List) -> None:
@@ -805,12 +842,43 @@ def _dmatrix_from_list_of_parts(
     return _create_dmatrix(**kwargs)
 
 
-async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[bytes]:
-    '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    env = await client.run_on_scheduler(_start_tracker, n_workers)
+async def _get_rabit_args(
+    n_workers: int, dconfig: Optional[Dict[str, Any]], client: "distributed.Client"
+) -> List[bytes]:
+    """Get rabit context arguments from data distribution in DaskDMatrix.
+
+    """
+    # There are 3 possible different addresses:
+    # 1. Provided by user via dask.config
+    # 2. Guessed by xgboost `get_host_ip` function
+    # 3. From dask scheduler
+    # We try 1 and 3 if 1 is available, otherwise 2 and 3.
+    valid_config = ["scheduler_address"]
+    # See if user config is available
+    if dconfig is not None:
+        for k in dconfig:
+            if k not in valid_config:
+                raise ValueError(f"Unknown configuration: {k}")
+        host_ip: Optional[str] = dconfig.get("scheduler_address", None)
+    else:
+        host_ip = None
+    # Try address from dask scheduler, this might not work, see
+    # https://github.com/dask/dask-xgboost/pull/40
+    try:
+        sched_addr = distributed.comm.get_address_host(client.scheduler.address)
+        sched_addr = sched_addr.strip("/:")
+    except Exception:  # pylint: disable=broad-except
+        sched_addr = None
+
+    env = await client.run_on_scheduler(_start_tracker, n_workers, sched_addr, host_ip)
+
     rabit_args = [f"{k}={v}".encode() for k, v in env.items()]
     return rabit_args
 
 
+def _get_dask_config() -> Optional[Dict[str, Any]]:
+    return dask.config.get("xgboost", default=None)
+
+
 # train and predict methods are supposed to be "functional", which meets the
 # dask paradigm. But as a side effect, the `evals_result` in single-node API
 # is no longer supported since it mutates the input parameter, and it's not
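_get_dask_config simply reads the whole xgboost namespace from dask's config, returning None when nothing was set, so callers can branch on a single value. Expected behavior sketched below; the output comments are my reading of dask.config and worth verifying:

    import dask

    with dask.config.set({"xgboost.scheduler_address": "192.0.0.100"}):
        print(dask.config.get("xgboost", default=None))
        # -> {'scheduler_address': '192.0.0.100'}

    print(dask.config.get("xgboost", default=None))
    # -> None, nothing configured outside the context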
@@ -837,6 +905,7 @@ def _get_workers_from_data(
 async def _train_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
+    dconfig: Optional[Dict[str, Any]],
     params: Dict[str, Any],
     dtrain: DaskDMatrix,
     num_boost_round: int,
@@ -850,7 +919,7 @@ async def _train_async(
     custom_metric: Optional[Metric],
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
-    _rabit_args = await _get_rabit_args(len(workers), client)
+    _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
 
     if params.get("booster", None) == "gblinear":
         raise NotImplementedError(
@@ -995,7 +1064,12 @@ def train(  # pylint: disable=unused-argument
     _assert_dask_support()
     client = _xgb_get_client(client)
     args = locals()
-    return client.sync(_train_async, global_config=config.get_config(), **args)
+    return client.sync(
+        _train_async,
+        global_config=config.get_config(),
+        dconfig=_get_dask_config(),
+        **args,
+    )
 
 
 def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
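One detail worth noting in this hunk: args = locals() captures only train's own parameters, while dconfig=_get_dask_config() is evaluated eagerly on the client, so it is the caller's dask configuration that travels with the task, not whatever the scheduler or workers have configured. A toy illustration of the pattern; CONFIG, submit, and _train_impl are hypothetical stand-ins, not xgboost code:

    from typing import Any, Dict

    CONFIG: Dict[str, Any] = {"scheduler_address": "192.0.0.100"}  # client-local state

    def _train_impl(snapshot: Dict[str, Any], params: Dict[str, Any]) -> str:
        # imagine this running in another process: it sees the caller's
        # snapshot, not whatever config that process holds
        return f"train({params}) with {snapshot}"

    def submit(fn, **kwargs):  # stand-in for client.sync
        return fn(**kwargs)

    def train(params: Dict[str, Any]) -> str:
        args = locals()  # {'params': ...}, just this function's arguments
        # evaluate client-local config *now*, before the call leaves this process
        return submit(_train_impl, snapshot=dict(CONFIG), **args)

    print(train({"tree_method": "hist"}))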
@@ -1693,6 +1767,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
             asynchronous=True,
             client=self.client,
             global_config=config.get_config(),
+            dconfig=_get_dask_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
@@ -1796,6 +1871,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             asynchronous=True,
             client=self.client,
             global_config=config.get_config(),
+            dconfig=_get_dask_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
@@ -1987,6 +2063,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
             asynchronous=True,
             client=self.client,
             global_config=config.get_config(),
+            dconfig=_get_dask_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
@@ -192,6 +192,7 @@ class RabitTracker:
         logging.info('start listen on %s:%d', hostIP, self.port)
 
     def __del__(self) -> None:
-        self.sock.close()
+        if hasattr(self, "sock"):
+            self.sock.close()
 
     @staticmethod
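The hasattr guard matters because RabitTracker.__init__ can raise before self.sock is assigned (exactly what happens when a bind fails and _try_start_tracker retries), and Python still runs __del__ on the half-constructed instance. A minimal reproduction of the hazard, illustrative rather than xgboost code:

    class Tracker:
        def __init__(self) -> None:
            # the bind fails before self.sock is ever assigned
            raise OSError(99, "Cannot assign requested address")

        def __del__(self) -> None:
            if hasattr(self, "sock"):  # without the guard: AttributeError in __del__
                self.sock.close()

    try:
        Tracker()
    except OSError:
        pass  # __del__ still runs on the half-built instance, now silently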
@@ -371,7 +371,7 @@ class TestDistributedGPU:
         m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
 
         workers = _get_client_workers(client)
-        rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
+        rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client)
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
             with dxgb.RabitContext(rabit_args):
@@ -473,7 +473,7 @@ class TestDistributedGPU:
 
         with Client(local_cuda_cluster) as client:
             workers = _get_client_workers(client)
-            rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
+            rabit_args = client.sync(dxgb._get_rabit_args, workers, None, client)
             futures = client.map(runit,
                                  workers,
                                  pure=False,
@@ -28,7 +28,7 @@ def run_rabit_ops(client, n_workers):
     from xgboost import rabit
 
     workers = _get_client_workers(client)
-    rabit_args = client.sync(_get_rabit_args, len(workers), client)
+    rabit_args = client.sync(_get_rabit_args, len(workers), None, client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)
     assert n_workers == n_workers_from_dask
@@ -30,6 +30,7 @@ if tm.no_dask()['condition']:
     pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
 
 from distributed import LocalCluster, Client
+import dask
 import dask.dataframe as dd
 import dask.array as da
 from xgboost.dask import DaskDMatrix
@@ -1219,6 +1220,10 @@ class TestWithDask:
         os.remove(before_fname)
         os.remove(after_fname)
 
+        with dask.config.set({'xgboost.foo': "bar"}):
+            with pytest.raises(ValueError):
+                xgb.dask.train(client, {}, dtrain, num_boost_round=4)
+
     def run_updater_test(
         self,
         client: "Client",
@@ -1318,7 +1323,8 @@ class TestWithDask:
         with Client(cluster) as client:
             workers = _get_client_workers(client)
             rabit_args = client.sync(
-                xgb.dask._get_rabit_args, len(workers), client)
+                xgb.dask._get_rabit_args, len(workers), None, client
+            )
             futures = client.map(runit,
                                  workers,
                                  pure=False,
@@ -1446,7 +1452,9 @@ class TestWithDask:
         n_partitions = X.npartitions
         m = xgb.dask.DaskDMatrix(client, X, y)
         workers = _get_client_workers(client)
-        rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
+        rabit_args = client.sync(
+            xgb.dask._get_rabit_args, len(workers), None, client
+        )
         n_workers = len(workers)
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None: