@@ -14,6 +14,7 @@ https://github.com/dask/dask-xgboost
|
||||
import platform
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from threading import Thread
|
||||
|
||||
import numpy
|
||||
@@ -28,7 +29,7 @@ from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import CUDF_concat
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
from .core import DMatrix, Booster, _expect
|
||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
|
||||
from .training import train as worker_train
|
||||
from .tracker import RabitTracker
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||
@@ -357,6 +358,146 @@ class DaskDMatrix:
|
||||
return (rows, cols)
|
||||
|
||||
|
||||
class DaskPartitionIter(DataIter):  # pylint: disable=R0902
    '''Batch iterator used to build a `DaskDeviceQuantileDMatrix`.

    Each element of `data` is treated as one batch.  Every optional field
    (label, weight, ...) is either ``None`` or a sequence that is indexed
    with the same position as `data`.
    '''
    def __init__(self, data, label=None, weight=None, base_margin=None,
                 label_lower_bound=None, label_upper_bound=None,
                 feature_names=None, feature_types=None):
        self._data = data
        self._labels = label
        self._weights = weight
        self._base_margin = base_margin
        self._label_lower_bound = label_lower_bound
        self._label_upper_bound = label_upper_bound
        self._feature_names = feature_names
        self._feature_types = feature_types

        # The data itself is mandatory and must be indexable by batch.
        assert isinstance(self._data, Sequence)

        # Everything else may be omitted entirely.
        sequence_or_none = (Sequence, type(None))
        for optional_field in (self._labels, self._weights,
                               self._base_margin,
                               self._label_lower_bound,
                               self._label_upper_bound):
            assert isinstance(optional_field, sequence_or_none)

        self._iter = 0          # position of the current batch
        super().__init__()

    def data(self):
        '''Return the data of the current batch.'''
        return self._data[self._iter]

    def labels(self):
        '''Return the labels of the current batch, or None if not given.'''
        return None if self._labels is None else self._labels[self._iter]

    def weights(self):
        '''Return the weights of the current batch, or None if not given.'''
        return None if self._weights is None else self._weights[self._iter]

    def base_margins(self):
        '''Return the base_margin of the current batch, or None if not
        given.

        '''
        if self._base_margin is None:
            return None
        return self._base_margin[self._iter]

    def label_lower_bounds(self):
        '''Return the label_lower_bound of the current batch, or None if
        not given.

        '''
        if self._label_lower_bound is None:
            return None
        return self._label_lower_bound[self._iter]

    def label_upper_bounds(self):
        '''Return the label_upper_bound of the current batch, or None if
        not given.

        '''
        if self._label_upper_bound is None:
            return None
        return self._label_upper_bound[self._iter]

    def reset(self):
        '''Move the iterator back to the first batch.'''
        self._iter = 0

    def next(self, input_data):
        '''Pass the current batch to `input_data` and advance by one.

        Returns 1 when a batch was handed over and 0 once every batch has
        already been consumed.
        '''
        if self._iter == len(self._data):
            # No more batches.
            return 0

        # Explicit names win; otherwise fall back to the batch's own
        # column labels when it has any.
        names = self._feature_names
        if not names:
            batch = self.data()
            if hasattr(batch, 'columns'):
                names = batch.columns.format()
            else:
                names = None

        # NOTE(review): base_margins() is not forwarded to `input_data`
        # here -- confirm whether that is intentional.
        input_data(data=self.data(),
                   label=self.labels(),
                   weight=self.weights(),
                   group=None,
                   label_lower_bound=self.label_lower_bounds(),
                   label_upper_bound=self.label_upper_bounds(),
                   feature_names=names,
                   feature_types=self._feature_types)
        self._iter += 1
        return 1
|
||||
|
||||
|
||||
class DaskDeviceQuantileDMatrix(DaskDMatrix):
    '''Specialized data type for the `gpu_hist` tree method.  It reduces
    memory usage by eliminating data copies.

    Internally the data is merged by weighted GK sketching, and GK
    generates error for each merge, so the number of dask partitions may
    affect training accuracy.

    .. versionadded:: 1.2.0

    Parameters
    ----------
    max_bin: Number of bins for histogram construction.

    '''
    def __init__(self, client, data, label=None, weight=None,
                 missing=None,
                 feature_names=None,
                 feature_types=None,
                 max_bin=256):
        super().__init__(
            client=client,
            data=data,
            label=label,
            weight=weight,
            missing=missing,
            feature_names=feature_names,
            feature_types=feature_types,
        )
        self.max_bin = max_bin

    def get_worker_data(self, worker):
        '''Build the `DeviceQuantileDMatrix` for `worker`.

        A worker whose address is in `worker_map` gets a matrix built from
        its partitions through a `DaskPartitionIter`.  Any other worker
        gets an empty matrix, and a warning is logged.
        '''
        known_workers = set(self.worker_map.keys())

        if worker.address in known_workers:
            data, labels, weights = self.get_worker_parts(worker)
            partitions = DaskPartitionIter(data=data, label=labels,
                                           weight=weights)
            return DeviceQuantileDMatrix(partitions,
                                         missing=self.missing,
                                         feature_names=self.feature_names,
                                         feature_types=self.feature_types,
                                         nthread=worker.nthreads,
                                         max_bin=self.max_bin)

        # This worker has no data: warn, then return an empty matrix.
        template = 'worker {address} has an empty DMatrix. ' \
            'All workers associated with this DMatrix: {workers}'
        LOGGER.warning(template.format(address=worker.address,
                                       workers=known_workers))
        import cupy  # pylint: disable=import-error
        return DeviceQuantileDMatrix(cupy.zeros((0, 0)),
                                     feature_names=self.feature_names,
                                     feature_types=self.feature_types,
                                     max_bin=self.max_bin)
|
||||
|
||||
|
||||
def _get_rabit_args(worker_map, client):
|
||||
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
|
||||
host = distributed_comm.get_address_host(client.scheduler.address)
|
||||
|
||||
@@ -15,7 +15,7 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _warn_unused_missing(data, missing):
|
||||
if not (np.isnan(missing) or None):
|
||||
if (not np.isnan(missing)) or (missing is None):
|
||||
warnings.warn(
|
||||
'`missing` is not used for current input data type:' +
|
||||
str(type(data)))
|
||||
|
||||
Reference in New Issue
Block a user