[dask] Support running on GKE. (#6343)

* Avoid accessing `scheduler_info()['workers']`.
* Avoid calling `client.gather` inside task.
* Avoid using `client.scheduler_address`.
Jiaming Yuan 2020-11-11 18:04:34 +08:00 committed by GitHub
parent 8a17610666
commit 6e12c2a6f8
4 changed files with 150 additions and 128 deletions


@@ -66,9 +66,11 @@ distributed = LazyLoader('distributed', globals(), 'dask.distributed')
LOGGER = logging.getLogger('[xgboost.dask]')
def _start_tracker(host, n_workers):
def _start_tracker(n_workers):
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
import socket
host = socket.gethostbyname(socket.gethostname())
rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
env.update(rabit_context.slave_envs())
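The tracker host is now resolved inside the function, on the machine that actually runs the tracker, instead of being parsed out of `client.scheduler_address`; on GKE the address the client dials is typically a proxied service address that the workers cannot reach. A minimal, self-contained sketch of the same idea (the helper name and the throwaway local cluster are illustrative, not the commit's code):

from distributed import Client

def _resolve_tracker_host():
    # Runs in the scheduler process, so gethostname() reports the scheduler's
    # own host rather than whatever address the client happened to dial.
    import socket
    return socket.gethostbyname(socket.gethostname())

client = Client()  # throwaway LocalCluster for the sketch
host = client.run_on_scheduler(_resolve_tracker_host)
print(host)
client.close()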
@@ -141,11 +143,6 @@ def _xgb_get_client(client):
ret = distributed.get_client() if client is None else client
return ret
def _get_client_workers(client):
workers = client.scheduler_info()['workers']
return workers
# From the implementation point of view, DaskDMatrix complicates a lot of
# things. A large portion of the code base is about syncing and extracting
# stuff from DaskDMatrix. But having an independent data structure gives us a
@@ -333,7 +330,7 @@ class DaskDMatrix:
return self
def create_fn_args(self):
def create_fn_args(self, worker_addr: str):
'''Create a dictionary of objects that can be pickled for function
arguments.
@@ -342,57 +339,55 @@ class DaskDMatrix:
'feature_types': self.feature_types,
'meta_names': self.meta_names,
'missing': self.missing,
'worker_map': self.worker_map,
'parts': self.worker_map.get(worker_addr, None),
'is_quantile': self.is_quantile}
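A toy illustration of the per-worker packing above (the worker addresses are made up): each worker now receives only its own partition list, and a worker that owns no partition gets None, which the empty-DMatrix branches later in this diff turn into a warning plus an empty matrix.

worker_map = {'tcp://10.0.0.2:40000': ['part0', 'part2'],   # hypothetical addresses
              'tcp://10.0.0.3:40000': ['part1']}
args_with_data = {'parts': worker_map.get('tcp://10.0.0.2:40000', None)}
args_without_data = {'parts': worker_map.get('tcp://10.0.0.9:40000', None)}
assert args_with_data['parts'] == ['part0', 'part2']
assert args_without_data['parts'] is None  # handled by the empty-DMatrix path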
def _get_worker_parts_ordered(meta_names, worker_map, partition_order, worker):
list_of_parts: List[tuple] = worker_map[worker.address]
def _get_worker_parts_ordered(meta_names, list_of_keys, list_of_parts, partition_order):
# List of partitions like: [(x3, y3, w3, m3, ..), ..], order is not preserved.
assert isinstance(list_of_parts, list)
with distributed.worker_client() as client:
list_of_parts_value = client.gather(list_of_parts)
list_of_parts_value = list_of_parts
result = []
result = []
for i, part in enumerate(list_of_parts):
data = list_of_parts_value[i][0]
labels = None
weights = None
base_margin = None
label_lower_bound = None
label_upper_bound = None
# Iterate through all possible meta info; this brings only a small overhead,
# as xgboost has a constant number of meta info fields.
for j, blob in enumerate(list_of_parts_value[i][1:]):
if meta_names[j] == 'labels':
labels = blob
elif meta_names[j] == 'weights':
weights = blob
elif meta_names[j] == 'base_margin':
base_margin = blob
elif meta_names[j] == 'label_lower_bound':
label_lower_bound = blob
elif meta_names[j] == 'label_upper_bound':
label_upper_bound = blob
else:
raise ValueError('Unknown metainfo:', meta_names[j])
if partition_order:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound, partition_order[part.key]))
for i, _ in enumerate(list_of_parts):
data = list_of_parts_value[i][0]
labels = None
weights = None
base_margin = None
label_lower_bound = None
label_upper_bound = None
# Iterate through all possible meta info; this brings only a small overhead,
# as xgboost has a constant number of meta info fields.
for j, blob in enumerate(list_of_parts_value[i][1:]):
if meta_names[j] == 'labels':
labels = blob
elif meta_names[j] == 'weights':
weights = blob
elif meta_names[j] == 'base_margin':
base_margin = blob
elif meta_names[j] == 'label_lower_bound':
label_lower_bound = blob
elif meta_names[j] == 'label_upper_bound':
label_upper_bound = blob
else:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
return result
raise ValueError('Unknown metainfo:', meta_names[j])
if partition_order:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound, partition_order[list_of_keys[i]]))
else:
result.append((data, labels, weights, base_margin, label_lower_bound,
label_upper_bound))
return result
def _unzip(list_of_parts):
return list(zip(*list_of_parts))
def _get_worker_parts(worker_map, meta_names, worker):
partitions = _get_worker_parts_ordered(meta_names, worker_map, None, worker)
def _get_worker_parts(list_of_parts: List[tuple], meta_names):
partitions = _get_worker_parts_ordered(meta_names, None, list_of_parts, None)
partitions = _unzip(partitions)
return partitions
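Dropping `worker_client()` and `client.gather` above relies on a documented dask.distributed behaviour: futures passed as arguments to `client.submit` (including futures nested inside lists or dicts) are resolved to concrete values before the task body runs, so the partitions arrive as plain data. A small self-contained sketch of that behaviour, using a throwaway local cluster and made-up names:

from distributed import Client

client = Client()                         # throwaway LocalCluster
part = client.submit(lambda: [1, 2, 3])   # a Future whose result lives on a worker

def consume(list_of_parts):
    # list_of_parts is already the concrete list here; no client.gather
    # (and hence no nested client) is needed inside the task.
    return sum(list_of_parts)

print(client.submit(consume, part).result())  # 6
client.close()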
@@ -522,21 +517,19 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
self.max_bin = max_bin
self.is_quantile = True
def create_fn_args(self):
args = super().create_fn_args()
def create_fn_args(self, worker_addr: str):
args = super().create_fn_args(worker_addr)
args['max_bin'] = self.max_bin
return args
def _create_device_quantile_dmatrix(feature_names, feature_types,
meta_names, missing, worker_map,
meta_names, missing, parts,
max_bin):
worker = distributed.get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(worker_map.keys()))
if parts is None:
msg = 'worker {address} has an empty DMatrix. '.format(
address=worker.address)
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
@@ -547,7 +540,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
(data, labels, weights, base_margin,
label_lower_bound, label_upper_bound) = _get_worker_parts(
worker_map, meta_names, worker)
parts, meta_names)
it = DaskPartitionIter(data=data, label=labels, weight=weights,
base_margin=base_margin,
label_lower_bound=label_lower_bound,
@@ -562,8 +555,7 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
return dmatrix
def _create_dmatrix(feature_names, feature_types, meta_names, missing,
worker_map):
def _create_dmatrix(feature_names, feature_types, meta_names, missing, parts):
'''Get data that is local to the worker from DaskDMatrix.
Returns
@@ -572,11 +564,9 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
'''
worker = distributed.get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(worker_map.keys()))
list_of_parts = parts
if list_of_parts is None:
msg = 'worker {address} has an empty DMatrix. '.format(address=worker.address)
LOGGER.warning(msg)
d = DMatrix(numpy.empty((0, 0)),
feature_names=feature_names,
@@ -584,13 +574,12 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
return d
def concat_or_none(data):
if all([part is None for part in data]):
if any([part is None for part in data]):
return None
return concat(data)
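The `all` to `any` change above means a meta field is treated as absent whenever any partition lacks it, rather than only when every partition lacks it. A tiny illustration, with plain lists standing in for partitions and list addition standing in for xgboost's concat():

def concat_or_none_demo(parts):
    # one missing part disables the whole field
    if any(p is None for p in parts):
        return None
    return sum(parts, [])        # stand-in for concat()

assert concat_or_none_demo([[1], [2]]) == [1, 2]
assert concat_or_none_demo([[1], None]) is None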
(data, labels, weights, base_margin,
label_lower_bound, label_upper_bound) = _get_worker_parts(
worker_map, meta_names, worker)
label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
labels = concat_or_none(labels)
weights = concat_or_none(weights)
@@ -611,17 +600,15 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
return dmatrix
def _dmatrix_from_worker_map(is_quantile, **kwargs):
def _dmatrix_from_list_of_parts(is_quantile, **kwargs):
if is_quantile:
return _create_device_quantile_dmatrix(**kwargs)
return _create_dmatrix(**kwargs)
async def _get_rabit_args(worker_map, client):
async def _get_rabit_args(n_workers: int, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed.comm.get_address_host(client.scheduler.address)
env = await client.run_on_scheduler(
_start_tracker, host.strip('/:'), len(worker_map))
env = await client.run_on_scheduler(_start_tracker, n_workers)
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
return rabit_args
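For reference, a small worked example of what the final encoding step yields, with a made-up env of the kind the tracker returns:

env = {'DMLC_NUM_WORKER': 2,
       'DMLC_TRACKER_URI': '10.0.0.5',    # hypothetical tracker host
       'DMLC_TRACKER_PORT': 9091}
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
# [b'DMLC_NUM_WORKER=2', b'DMLC_TRACKER_URI=10.0.0.5', b'DMLC_TRACKER_PORT=9091']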
@@ -632,49 +619,58 @@ async def _get_rabit_args(worker_map, client):
# evaluation history is instead returned.
async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
early_stopping_rounds=None, **kwargs):
_assert_dask_support()
client: distributed.Client = _xgb_get_client(client)
def _get_workers_from_data(dtrain: DaskDMatrix, evals=()):
X_worker_map = set(dtrain.worker_map.keys())
if evals:
for e in evals:
assert len(e) == 2
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
worker_map = set(e[0].worker_map.keys())
X_worker_map.union(worker_map)
return X_worker_map
async def _train_async(client,
params,
dtrain: DaskDMatrix,
*args,
evals=(),
early_stopping_rounds=None,
**kwargs):
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
'The evaluation history is returned as result of training.')
workers = list(_get_client_workers(client).keys())
_rabit_args = await _get_rabit_args(workers, client)
workers = list(_get_workers_from_data(dtrain, evals))
_rabit_args = await _get_rabit_args(len(workers), client)
def dispatched_train(worker_addr, rabit_args, dtrain_ref, evals_ref):
def dispatched_train(worker_addr, rabit_args, dtrain_ref, dtrain_idt, evals_ref):
'''Perform training on a single worker. A local function prevents pickling.
'''
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed.get_worker()
with RabitContext(rabit_args):
local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
local_evals = []
if evals_ref:
for ref, name in evals_ref:
if ref['worker_map'] == dtrain_ref['worker_map']:
for ref, name, idt in evals_ref:
if idt == dtrain_idt:
local_evals.append((local_dtrain, name))
continue
local_evals.append((_dmatrix_from_worker_map(**ref), name))
local_evals.append((_dmatrix_from_list_of_parts(**ref), name))
local_history = {}
local_param = params.copy() # just to be consistent
msg = 'Overriding `nthreads` defined in dask worker.'
if 'nthread' in local_param.keys() and \
local_param['nthread'] is not None and \
local_param['nthread'] != worker.nthreads:
msg += '`nthread` is specified. ' + msg
LOGGER.warning(msg)
elif 'n_jobs' in local_param.keys() and \
local_param['n_jobs'] is not None and \
local_param['n_jobs'] != worker.nthreads:
msg = '`n_jobs` is specified. ' + msg
LOGGER.warning(msg)
else:
local_param['nthread'] = worker.nthreads
override = ['nthread', 'n_jobs']
for p in override:
val = local_param.get(p, None)
if val is not None and val != worker.nthreads:
LOGGER.info(msg)
else:
local_param[p] = worker.nthreads
bst = worker_train(params=local_param,
dtrain=local_dtrain,
*args,
@@ -687,20 +683,26 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
ret = None
return ret
if evals:
evals = [(e.create_fn_args(), name) for e, name in evals]
# Note on function purity:
# XGBoost is deterministic in most cases, which means the train function is
# supposed to be idempotent. One known exception is gblinear with the shotgun updater.
# We haven't been able to verify this fully, so `pure` is kept as False here.
futures = client.map(dispatched_train,
workers,
[_rabit_args] * len(workers),
[dtrain.create_fn_args()] * len(workers),
[evals] * len(workers),
pure=False,
workers=workers)
futures = []
for i, worker_addr in enumerate(workers):
if evals:
evals_per_worker = [(e.create_fn_args(worker_addr), name, id(e))
for e, name in evals]
else:
evals_per_worker = []
f = client.submit(dispatched_train,
worker_addr,
_rabit_args,
dtrain.create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False)
futures.append(f)
results = await client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
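The loop above replaces one `client.map` over every cluster worker with one `client.submit` per worker that actually holds data, each call receiving only that worker's partitions; `pure=False` follows the purity note above. A rough, self-contained sketch of the same dispatch pattern, deriving the worker set from data placement with `Client.who_has` (a stand-in for the commit's `worker_map`, which DaskDMatrix builds at construction time; all names and payloads here are illustrative):

from distributed import Client

client = Client(n_workers=2)               # throwaway LocalCluster
futures = client.scatter(list(range(8)))   # eight toy partitions

# Like _get_workers_from_data: take the workers from where the data lives,
# not from the scheduler's full worker list.
who_has = client.who_has(futures)          # {key: [worker_addr, ...], ...}
workers = sorted({addr for addrs in who_has.values() for addr in addrs})

def per_worker_task(worker_addr, parts):
    # parts are already concrete values here (futures are resolved by dask)
    return worker_addr, sum(parts)

tasks = [client.submit(per_worker_task, addr,
                       [f for f in futures if addr in who_has[f.key]],
                       pure=False, workers=[addr])
         for addr in workers]
print(client.gather(tasks))
client.close()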
@@ -796,14 +798,16 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
missing = data.missing
meta_names = data.meta_names
def dispatched_predict(worker_id):
def dispatched_predict(worker_id, list_of_keys, list_of_parts):
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
c = distributed.get_client()
list_of_keys = c.compute(list_of_keys).result()
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
meta_names, worker_map, partition_order, worker)
meta_names, list_of_keys, list_of_parts, partition_order)
predictions = []
booster.set_param({'nthread': worker.nthreads})
for parts in list_of_parts:
(data, _, _, base_margin, _, _, order) = parts
@@ -822,17 +826,19 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)
return predictions
def dispatched_get_shape(worker_id):
def dispatched_get_shape(worker_id, list_of_keys, list_of_parts):
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
worker = distributed.get_worker()
c = distributed.get_client()
list_of_keys = c.compute(list_of_keys).result()
list_of_parts = _get_worker_parts_ordered(
meta_names,
worker_map,
list_of_keys,
list_of_parts,
partition_order,
worker
)
shapes = []
for parts in list_of_parts:
@@ -843,15 +849,20 @@ async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
async def map_function(func):
'''Run function for each part of the data.'''
futures = []
for wid in range(len(worker_map)):
list_of_workers = [list(worker_map.keys())[wid]]
f = await client.submit(func, wid,
pure=False,
workers=list_of_workers)
workers_address = list(worker_map.keys())
for wid, worker_addr in enumerate(workers_address):
worker_addr = workers_address[wid]
list_of_parts = worker_map[worker_addr]
list_of_keys = [part.key for part in list_of_parts]
f = await client.submit(func, worker_id=wid,
list_of_keys=dask.delayed(list_of_keys),
list_of_parts=list_of_parts,
pure=False, workers=[worker_addr])
futures.append(f)
# Get delayed objects
results = await client.gather(futures)
results = [t for l in results for t in l] # flatten into 1 dim list
# flatten into 1 dim list
results = [t for list_per_worker in results for t in list_per_worker]
# sort by order, l[0] is the delayed object, l[1] is its order
results = sorted(results, key=lambda l: l[1])
results = [predt for predt, order in results] # remove order
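A tiny worked example of the flatten-and-reorder step above, with made-up payloads: each worker returns (prediction, original_partition_index) pairs in arbitrary order, and sorting by the recorded index restores the input partition order.

results = [[('predt_2', 2), ('predt_0', 0)],   # from one worker
           [('predt_1', 1)]]                   # from another
flat = [pair for per_worker in results for pair in per_worker]
ordered = [predt for predt, order in sorted(flat, key=lambda p: p[1])]
assert ordered == ['predt_0', 'predt_1', 'predt_2']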
@@ -1144,6 +1155,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
'Implementation of the scikit-learn API for XGBoost classification.',
['estimators', 'model'])
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-class-docstring
async def _fit_async(self, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set, early_stopping_rounds,
verbose):
@@ -1215,7 +1227,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
output_margin=output_margin)
return pred_probs
def predict_proba(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ,missing-docstring
# pylint: disable=arguments-differ,missing-docstring
def predict_proba(self, data, output_margin=False, base_margin=None):
_assert_dask_support()
return self.client.sync(
self._predict_proba_async,
@@ -1241,7 +1254,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
return preds
def predict(self, data, output_margin=False, base_margin=None): # pylint: disable=arguments-differ
# pylint: disable=arguments-differ
def predict(self, data, output_margin=False, base_margin=None):
_assert_dask_support()
return self.client.sync(
self._predict_async,


@@ -15,6 +15,7 @@ if sys.platform.startswith("win"):
sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa
import testing as tm # noqa
@@ -217,7 +218,7 @@ class TestDistributedGPU:
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
with Client(local_cuda_cluster) as client:
workers = list(dxgb._get_client_workers(client).keys())
workers = list(_get_client_workers(client).keys())
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
futures = client.map(runit,
workers,


@@ -23,11 +23,12 @@ def test_rabit_tracker():
def run_rabit_ops(client, n_workers):
from xgboost.dask import RabitContext, _get_rabit_args, _get_client_workers
from test_with_dask import _get_client_workers
from xgboost.dask import RabitContext, _get_rabit_args
from xgboost import rabit
workers = list(_get_client_workers(client).keys())
rabit_args = client.sync(_get_rabit_args, workers, client)
rabit_args = client.sync(_get_rabit_args, len(workers), client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)
assert n_workers == n_workers_from_dask


@@ -41,6 +41,11 @@ kCols = 10
kWorkers = 5
def _get_client_workers(client):
workers = client.scheduler_info()['workers']
return workers
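For reference, a hedged sketch of what this test helper returns, using a throwaway LocalCluster (addresses and metadata will differ): scheduler_info()['workers'] maps each worker address to a metadata dict, which is why callers take list(...keys()).

from distributed import Client, LocalCluster

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    workers = client.scheduler_info()['workers']
    # keys: worker addresses such as 'tcp://127.0.0.1:37251'
    # values: metadata dicts (nthreads, memory_limit, ...)
    print(list(workers.keys()))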
def generate_array(with_weights=False):
partition_size = 20
X = da.random.random((kRows, kCols), partition_size)
@@ -704,9 +709,9 @@ class TestWithDask:
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
workers = list(xgb.dask._get_client_workers(client).keys())
workers = list(_get_client_workers(client).keys())
rabit_args = client.sync(
xgb.dask._get_rabit_args, workers, client)
xgb.dask._get_rabit_args, len(workers), client)
futures = client.map(runit,
workers,
pure=False,
@@ -750,7 +755,6 @@ class TestDaskCallbacks:
num_boost_round=1000,
early_stopping_rounds=early_stopping_rounds)['booster']
assert hasattr(booster, 'best_score')
assert booster.best_iteration == 10
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@@ -783,20 +787,22 @@ class TestDaskCallbacks:
X, y = generate_array()
n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y)
workers = list(xgb.dask._get_client_workers(client).keys())
rabit_args = client.sync(xgb.dask._get_rabit_args, workers, client)
workers = list(_get_client_workers(client).keys())
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
n_workers = len(workers)
def worker_fn(worker_addr, data_ref):
with xgb.dask.RabitContext(rabit_args):
local_dtrain = xgb.dask._dmatrix_from_worker_map(**data_ref)
local_dtrain = xgb.dask._dmatrix_from_list_of_parts(**data_ref)
total = np.array([local_dtrain.num_row()])
total = xgb.rabit.allreduce(total, xgb.rabit.Op.SUM)
assert total[0] == kRows
futures = client.map(
worker_fn, workers, [m.create_fn_args()] * len(workers),
pure=False, workers=workers)
futures = []
for i in range(len(workers)):
futures.append(client.submit(worker_fn, workers[i],
m.create_fn_args(workers[i]), pure=False,
workers=[workers[i]]))
client.gather(futures)
has_what = client.has_what()