Dask device dmatrix (#5901)

* Fix softprob with empty dmatrix.
This commit is contained in:
Jiaming Yuan
2020-07-17 13:17:43 +08:00
committed by GitHub
parent e471056ec4
commit 7c2686146e
12 changed files with 392 additions and 149 deletions

View File

@@ -14,6 +14,7 @@ https://github.com/dask/dask-xgboost
import platform
import logging
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread
import numpy
@@ -28,7 +29,7 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import CUDF_concat
from .compat import lazy_isinstance
from .core import DMatrix, Booster, _expect
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .training import train as worker_train
from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
@@ -357,6 +358,146 @@ class DaskDMatrix:
return (rows, cols)
class DaskPartitionIter(DataIter): # pylint: disable=R0902
'''A data iterator for `DaskDeviceQuantileDMatrix`.
'''
def __init__(self, data, label=None, weight=None, base_margin=None,
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None):
self._data = data
self._labels = label
self._weights = weight
self._base_margin = base_margin
self._label_lower_bound = label_lower_bound
self._label_upper_bound = label_upper_bound
self._feature_names = feature_names
self._feature_types = feature_types
assert isinstance(self._data, Sequence)
types = (Sequence, type(None))
assert isinstance(self._labels, types)
assert isinstance(self._weights, types)
assert isinstance(self._base_margin, types)
assert isinstance(self._label_lower_bound, types)
assert isinstance(self._label_upper_bound, types)
self._iter = 0 # set iterator to 0
super().__init__()
def data(self):
'''Utility function for obtaining current batch of data.'''
return self._data[self._iter]
def labels(self):
'''Utility function for obtaining current batch of label.'''
if self._labels is not None:
return self._labels[self._iter]
return None
def weights(self):
'''Utility function for obtaining current batch of label.'''
if self._weights is not None:
return self._weights[self._iter]
return None
def base_margins(self):
'''Utility function for obtaining current batch of base_margin.'''
if self._base_margin is not None:
return self._base_margin[self._iter]
return None
def label_lower_bounds(self):
'''Utility function for obtaining current batch of label_lower_bound.
'''
if self._label_lower_bound is not None:
return self._label_lower_bound[self._iter]
return None
def label_upper_bounds(self):
'''Utility function for obtaining current batch of label_upper_bound.
'''
if self._label_upper_bound is not None:
return self._label_upper_bound[self._iter]
return None
def reset(self):
'''Reset the iterator'''
self._iter = 0
def next(self, input_data):
'''Yield next batch of data'''
if self._iter == len(self._data):
# Return 0 when there's no more batch.
return 0
if self._feature_names:
feature_names = self._feature_names
else:
if hasattr(self.data(), 'columns'):
feature_names = self.data().columns.format()
else:
feature_names = None
input_data(data=self.data(), label=self.labels(),
weight=self.weights(), group=None,
label_lower_bound=self.label_lower_bounds(),
label_upper_bound=self.label_upper_bounds(),
feature_names=feature_names,
feature_types=self._feature_types)
self._iter += 1
return 1
class DaskDeviceQuantileDMatrix(DaskDMatrix):
'''Specialized data type for `gpu_hist` tree method. This class is
used to reduce the memory usage by eliminating data copies.
Internally the data is merged by weighted GK sketching. So the
number of partitions from dask may affect training accuracy as GK
generates error for each merge.
.. versionadded:: 1.2.0
Parameters
----------
max_bin: Number of bins for histogram construction.
'''
def __init__(self, client, data, label=None, weight=None,
missing=None,
feature_names=None,
feature_types=None,
max_bin=256):
super().__init__(client=client, data=data, label=label, weight=weight,
missing=missing,
feature_names=feature_names,
feature_types=feature_types)
self.max_bin = max_bin
def get_worker_data(self, worker):
if worker.address not in set(self.worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
address=worker.address,
workers=set(self.worker_map.keys()))
LOGGER.warning(msg)
import cupy # pylint: disable=import-error
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
feature_names=self.feature_names,
feature_types=self.feature_types,
max_bin=self.max_bin)
return d
data, labels, weights = self.get_worker_parts(worker)
it = DaskPartitionIter(data=data, label=labels, weight=weights)
dmatrix = DeviceQuantileDMatrix(it,
missing=self.missing,
feature_names=self.feature_names,
feature_types=self.feature_types,
nthread=worker.nthreads,
max_bin=self.max_bin)
return dmatrix
def _get_rabit_args(worker_map, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address)

View File

@@ -15,7 +15,7 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
def _warn_unused_missing(data, missing):
if not (np.isnan(missing) or None):
if (not np.isnan(missing)) or (missing is None):
warnings.warn(
'`missing` is not used for current input data type:' +
str(type(data)))