Define metainfo and other parameters for all DMatrix interfaces. (#6601)

This PR ensures all DMatrix types expose a common construction interface; a short usage sketch follows the change list below.

* Fix the logic for avoiding a duplicated DMatrix in sklearn.
* Check for consistency between DMatrix types.
* Add documentation for label bounds.
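
To illustrate the intent, here is a minimal sketch of the unified interface (my example, not part of the diff; it assumes XGBoost at or after this revision): all meta info can now be passed through the constructor, with the same keyword set shared by `DMatrix`, `DeviceQuantileDMatrix`, and the dask variants.

```python
# Hedged sketch of the unified interface; not taken from the diff itself.
import numpy as np
import xgboost as xgb

X = np.random.randn(32, 4)
y = np.random.randn(32)
qid = np.sort(np.random.randint(0, 4, size=32))  # sorted query ids for ranking

# Meta info that previously required separate set_info() calls can now be
# passed directly; the same keywords work on every DMatrix variant.
m = xgb.DMatrix(
    X,
    label=y,
    qid=qid,
    feature_weights=np.ones(4),  # weights for column sampling
)

# `group` and `qid` are mutually exclusive; passing both raises ValueError.
```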
Jiaming Yuan 2021-01-25 16:06:06 +08:00 committed by GitHub
parent 561809200a
commit 8942c98054
6 changed files with 365 additions and 158 deletions


@@ -312,15 +312,18 @@ class DataIter:
             data, feature_names, feature_types
         )
         dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
-        self.proxy.set_info(label=label, weight=weight,
-                            base_margin=base_margin,
-                            group=group,
-                            qid=qid,
-                            label_lower_bound=label_lower_bound,
-                            label_upper_bound=label_upper_bound,
-                            feature_names=feature_names,
-                            feature_types=feature_types,
-                            feature_weights=feature_weights)
+        self.proxy.set_info(
+            label=label,
+            weight=weight,
+            base_margin=base_margin,
+            group=group,
+            qid=qid,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+            feature_names=feature_names,
+            feature_types=feature_types,
+            feature_weights=feature_weights
+        )
         try:
             # Differ the exception in order to return 0 and stop the iteration.
             # Exception inside a ctype callback function has no effect except
@@ -408,7 +411,7 @@ def _deprecate_positional_args(f):
     return inner_f

 class DMatrix:  # pylint: disable=too-many-instance-attributes
     """Data Matrix used in XGBoost.

     DMatrix is an internal data structure that is used by XGBoost,
@@ -416,13 +419,26 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
     You can construct DMatrix from multiple different sources of data.
     """

-    def __init__(self, data, label=None, weight=None, base_margin=None,
-                 missing=None,
-                 silent=False,
-                 feature_names=None,
-                 feature_types=None,
-                 nthread=None,
-                 enable_categorical=False):
+    @_deprecate_positional_args
+    def __init__(
+        self,
+        data,
+        label=None,
+        *,
+        weight=None,
+        base_margin=None,
+        missing: Optional[float] = None,
+        silent=False,
+        feature_names=None,
+        feature_types=None,
+        nthread: Optional[int] = None,
+        group=None,
+        qid=None,
+        label_lower_bound=None,
+        label_upper_bound=None,
+        feature_weights=None,
+        enable_categorical: bool = False,
+    ) -> None:
         """Parameters
         ----------
         data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
@@ -432,12 +448,9 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
             libsvm format txt file, csv file (by specifying uri parameter
             'path_to_csv?format=csv'), or binary file that xgboost can read
             from.
-        label : list, numpy 1-D array or cudf.DataFrame, optional
+        label : array_like
             Label of the training data.
-        missing : float, optional
-            Value in the input data which needs to be present as a missing
-            value. If None, defaults to np.nan.
-        weight : list, numpy 1-D array or cudf.DataFrame , optional
+        weight : array_like
             Weight for each instance.

             .. note:: For ranking task, weights are per-group.
@@ -447,6 +460,11 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                 ordering of data points within each group, so it doesn't make
                 sense to assign weights to individual data points.
+        base_margin: array_like
+            Base margin used for boosting from existing model.
+        missing : float, optional
+            Value in the input data which needs to be present as a missing
+            value. If None, defaults to np.nan.
         silent : boolean, optional
             Whether print messages during construction
         feature_names : list, optional
@@ -456,7 +474,16 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         nthread : integer, optional
             Number of threads to use for loading data when parallelization is
             applicable. If -1, uses maximum threads available on the system.
+        group : array_like
+            Group size for all ranking group.
+        qid : array_like
+            Query ID for data samples, used for ranking.
+        label_lower_bound : array_like
+            Lower bound for survival training.
+        label_upper_bound : array_like
+            Upper bound for survival training.
+        feature_weights : array_like, optional
+            Set feature weights for column sampling.
         enable_categorical: boolean, optional

             .. versionadded:: 1.3.0
@@ -469,7 +496,9 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         """
         if isinstance(data, list):
-            raise TypeError('Input data can not be a list.')
+            raise TypeError("Input data can not be a list.")
+        if group is not None and qid is not None:
+            raise ValueError("Either one of `group` or `qid` should be None.")

         self.missing = missing if missing is not None else np.nan
         self.nthread = nthread if nthread is not None else -1
@@ -481,16 +510,28 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
             return

         from .data import dispatch_data_backend
         handle, feature_names, feature_types = dispatch_data_backend(
-            data, missing=self.missing,
+            data,
+            missing=self.missing,
             threads=self.nthread,
             feature_names=feature_names,
             feature_types=feature_types,
-            enable_categorical=enable_categorical)
+            enable_categorical=enable_categorical,
+        )
         assert handle is not None
         self.handle = handle

-        self.set_info(label=label, weight=weight, base_margin=base_margin)
+        self.set_info(
+            label=label,
+            weight=weight,
+            base_margin=base_margin,
+            group=group,
+            qid=qid,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+            feature_weights=feature_weights,
+        )

         if feature_names is not None:
             self.feature_names = feature_names
@@ -503,17 +544,23 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
             self.handle = None

     @_deprecate_positional_args
-    def set_info(self, *,
-                 label=None, weight=None, base_margin=None,
-                 group=None,
-                 qid=None,
-                 label_lower_bound=None,
-                 label_upper_bound=None,
-                 feature_names=None,
-                 feature_types=None,
-                 feature_weights=None):
-        '''Set meta info for DMatrix.'''
+    def set_info(
+        self,
+        *,
+        label=None,
+        weight=None,
+        base_margin=None,
+        group=None,
+        qid=None,
+        label_lower_bound=None,
+        label_upper_bound=None,
+        feature_names=None,
+        feature_types=None,
+        feature_weights=None
+    ) -> None:
+        """Set meta info for DMatrix. See doc string for DMatrix constructor."""
         from .data import dispatch_meta_backend

         if label is not None:
             self.set_label(label)
         if weight is not None:
@@ -918,39 +965,67 @@ class DeviceQuantileDMatrix(DMatrix):
     information may be lost in quantisation. This DMatrix is primarily designed
     to save memory in training from device memory inputs by avoiding
     intermediate storage. Set max_bin to control the number of bins during
-    quantisation.
+    quantisation.  See doc string in `DMatrix` for documents on meta info.

     You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.

     .. versionadded:: 1.1.0

     """

-    def __init__(self, data, label=None, weight=None,  # pylint: disable=W0231
-                 base_margin=None,
-                 missing=None,
-                 silent=False,
-                 feature_names=None,
-                 feature_types=None,
-                 nthread=None, max_bin=256):
+    @_deprecate_positional_args
+    def __init__(  # pylint: disable=super-init-not-called
+        self,
+        data,
+        label=None,
+        *,
+        weight=None,
+        base_margin=None,
+        missing=None,
+        silent=False,
+        feature_names=None,
+        feature_types=None,
+        nthread: Optional[int] = None,
+        max_bin: int = 256,
+        group=None,
+        qid=None,
+        label_lower_bound=None,
+        label_upper_bound=None,
+        feature_weights=None,
+        enable_categorical: bool = False,
+    ):
         self.max_bin = max_bin
         self.missing = missing if missing is not None else np.nan
         self.nthread = nthread if nthread is not None else 1
+        self._silent = silent  # unused, kept for compatibility

         if isinstance(data, ctypes.c_void_p):
             self.handle = data
             return

         from .data import init_device_quantile_dmatrix
         handle, feature_names, feature_types = init_device_quantile_dmatrix(
-            data, missing=self.missing, threads=self.nthread,
-            max_bin=self.max_bin,
+            data,
             label=label, weight=weight,
             base_margin=base_margin,
-            group=None,
-            label_lower_bound=None,
-            label_upper_bound=None,
+            group=group,
+            qid=qid,
+            missing=self.missing,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+            feature_weights=feature_weights,
             feature_names=feature_names,
-            feature_types=feature_types)
+            feature_types=feature_types,
+            threads=self.nthread,
+            max_bin=self.max_bin,
+        )
+        if enable_categorical:
+            raise NotImplementedError(
+                'categorical support is not enabled on DeviceQuantileDMatrix.'
+            )
         self.handle = handle
+
+        if qid is not None and group is not None:
+            raise ValueError(
+                'Only one of the eval_qid or eval_group for each evaluation '
+                'dataset should be provided.'
+            )
         self.feature_names = feature_names
         self.feature_types = feature_types
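
The net effect of the hunk above, shown as a hedged sketch of my own (it assumes cupy and a CUDA device, and XGBoost at this revision): `DeviceQuantileDMatrix` now forwards the full meta info set instead of hard-coding `group` and the label bounds to `None`.

```python
# Sketch only; requires cupy and a CUDA device.
import cupy as cp
import xgboost as xgb

X = cp.random.randn(64, 8)
y = cp.random.randn(64)
fw = cp.abs(cp.random.randn(8))  # non-negative feature weights, one per column

# feature_weights (and group/qid/label bounds) are accepted at construction,
# matching the DMatrix constructor above.
m = xgb.DeviceQuantileDMatrix(X, label=y, max_bin=64, feature_weights=fw)
assert m.get_float_info("feature_weights").shape[0] == m.num_col()
```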


@@ -38,8 +38,9 @@ from .core import Objective, Metric
 from .core import _deprecate_positional_args
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
-from .sklearn import xgboost_model_doc, _objective_decorator
+from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
+from .sklearn import XGBRankerMixIn
+from .sklearn import xgboost_model_doc
 from .sklearn import _cls_predict_proba
 from .sklearn import XGBRanker
@@ -180,10 +181,12 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client":

 class DaskDMatrix:
     # pylint: disable=missing-docstring, too-many-instance-attributes
-    '''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
-    a `DaskDMatrix` forces all lazy computation to be carried out.  Wait for
-    the input data explicitly if you want to see actual computation of
-    constructing `DaskDMatrix`.
+    '''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
+    `DaskDMatrix` forces all lazy computation to be carried out.  Wait for the input data
+    explicitly if you want to see actual computation of constructing `DaskDMatrix`.
+
+    See doc string for DMatrix constructor for other parameters.  DaskDMatrix accepts only
+    dask collection.

     .. note::
@@ -197,29 +200,6 @@ class DaskDMatrix:
     client :
         Specify the dask client used for training.  Use default client returned from dask
         if it's set to None.
-    data :
-        data source of DMatrix.
-    label :
-        label used for trainin.
-    missing :
-        Value in the input data (e.g. `numpy.ndarray`) which needs to be present as a
-        missing value. If None, defaults to np.nan.
-    weight :
-        Weight for each instance.
-    base_margin :
-        Global bias for each instance.
-    qid :
-        Query ID for ranking.
-    label_lower_bound :
-        Upper bound for survival training.
-    label_upper_bound :
-        Lower bound for survival training.
-    feature_weights :
-        Weight for features used in column sampling.
-    feature_names :
-        Set names for features.
-    feature_types :
-        Set types for features

     '''
@@ -230,15 +210,18 @@ class DaskDMatrix:
         data: _DaskCollection,
         label: Optional[_DaskCollection] = None,
         *,
-        missing: float = None,
         weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
+        missing: float = None,
+        silent: bool = False,  # pylint: disable=unused-argument
+        feature_names: Optional[Union[str, List[str]]] = None,
+        feature_types: Optional[Union[Any, List[Any]]] = None,
+        group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
         label_upper_bound: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        feature_names: Optional[Union[str, List[str]]] = None,
-        feature_types: Optional[Union[Any, List[Any]]] = None
+        enable_categorical: bool = False
     ) -> None:
         _assert_dask_support()
         client = _xgb_get_client(client)
@@ -248,30 +231,41 @@ class DaskDMatrix:
         self.missing = missing

         if qid is not None and weight is not None:
-            raise NotImplementedError('per-group weight is not implemented.')
+            raise NotImplementedError("per-group weight is not implemented.")
+        if group is not None:
+            raise NotImplementedError(
+                "group structure is not implemented, use qid instead."
+            )
+        if enable_categorical:
+            raise NotImplementedError(
+                "categorical support is not enabled on `DaskDMatrix`."
+            )

         if len(data.shape) != 2:
             raise ValueError(
-                'Expecting 2 dimensional input, got: {shape}'.format(
-                    shape=data.shape))
+                "Expecting 2 dimensional input, got: {shape}".format(shape=data.shape)
+            )

         if not isinstance(data, (dd.DataFrame, da.Array)):
             raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))

-        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series,
-                                  type(None))):
-            raise TypeError(
-                _expect((dd.DataFrame, da.Array, dd.Series), type(label)))
+        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
+            raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

         self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
         self.is_quantile: bool = False

-        self._init = client.sync(self.map_local_data,
-                                 client, data, label=label, weights=weight,
-                                 base_margin=base_margin,
-                                 qid=qid,
-                                 feature_weights=feature_weights,
-                                 label_lower_bound=label_lower_bound,
-                                 label_upper_bound=label_upper_bound)
+        self._init = client.sync(
+            self.map_local_data,
+            client,
+            data,
+            label=label,
+            weights=weight,
+            base_margin=base_margin,
+            qid=qid,
+            feature_weights=feature_weights,
+            label_lower_bound=label_lower_bound,
+            label_upper_bound=label_upper_bound,
+        )

     def __await__(self) -> Generator:
         return self._init.__await__()
@@ -571,11 +565,11 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902

 class DaskDeviceQuantileDMatrix(DaskDMatrix):
-    '''Specialized data type for `gpu_hist` tree method.  This class is used to
-    reduce the memory usage by eliminating data copies.  Internally the all
-    partitions/chunks of data are merged by weighted GK sketching.  So the
-    number of partitions from dask may affect training accuracy as GK generates
-    bounded error for each merge.
+    '''Specialized data type for `gpu_hist` tree method.  This class is used to reduce the
+    memory usage by eliminating data copies.  Internally the all partitions/chunks of data
+    are merged by weighted GK sketching.  So the number of partitions from dask may affect
+    training accuracy as GK generates bounded error for each merge.  See doc string for
+    `DeviceQuantileDMatrix` and `DMatrix` for other parameters.

     .. versionadded:: 1.2.0
@@ -584,42 +578,50 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     max_bin : Number of bins for histogram construction.

     '''
+    @_deprecate_positional_args
     def __init__(
         self,
         client: "distributed.Client",
         data: _DaskCollection,
         label: Optional[_DaskCollection] = None,
-        missing: float = None,
+        *,
         weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
+        missing: float = None,
+        silent: bool = False,
+        feature_names: Optional[Union[str, List[str]]] = None,
+        feature_types: Optional[Union[Any, List[Any]]] = None,
+        max_bin: int = 256,
+        group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
         label_upper_bound: Optional[_DaskCollection] = None,
         feature_weights: Optional[_DaskCollection] = None,
-        feature_names: Optional[Union[str, List[str]]] = None,
-        feature_types: Optional[Union[Any, List[Any]]] = None,
-        max_bin: int = 256
+        enable_categorical: bool = False,
    ) -> None:
         super().__init__(
             client=client,
             data=data,
             label=label,
-            missing=missing,
-            feature_weights=feature_weights,
             weight=weight,
             base_margin=base_margin,
+            group=group,
             qid=qid,
             label_lower_bound=label_lower_bound,
             label_upper_bound=label_upper_bound,
+            missing=missing,
+            silent=silent,
+            feature_weights=feature_weights,
             feature_names=feature_names,
-            feature_types=feature_types
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
         )
         self.max_bin = max_bin
         self.is_quantile = True

     def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
         args = super().create_fn_args(worker_addr)
-        args['max_bin'] = self.max_bin
+        args["max_bin"] = self.max_bin
         return args
@@ -630,35 +632,49 @@ def _create_device_quantile_dmatrix(
     meta_names: List[str],
     missing: float,
     parts: Optional[_DataParts],
-    max_bin: int
+    max_bin: int,
 ) -> DeviceQuantileDMatrix:
     worker = distributed.get_worker()
     if parts is None:
-        msg = 'worker {address} has an empty DMatrix.  '.format(
-            address=worker.address)
+        msg = "worker {address} has an empty DMatrix.".format(address=worker.address)
         LOGGER.warning(msg)
         import cupy
-        d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
-                                  feature_names=feature_names,
-                                  feature_types=feature_types,
-                                  max_bin=max_bin)
+
+        d = DeviceQuantileDMatrix(
+            cupy.zeros((0, 0)),
+            feature_names=feature_names,
+            feature_types=feature_types,
+            max_bin=max_bin,
+        )
         return d

-    (data, labels, weights, base_margin, qid,
-     label_lower_bound, label_upper_bound) = _get_worker_parts(
-         parts, meta_names)
-    it = DaskPartitionIter(data=data, label=labels, weight=weights,
-                           base_margin=base_margin,
-                           qid=qid,
-                           label_lower_bound=label_lower_bound,
-                           label_upper_bound=label_upper_bound)
+    (
+        data,
+        labels,
+        weights,
+        base_margin,
+        qid,
+        label_lower_bound,
+        label_upper_bound,
+    ) = _get_worker_parts(parts, meta_names)
+    it = DaskPartitionIter(
+        data=data,
+        label=labels,
+        weight=weights,
+        base_margin=base_margin,
+        qid=qid,
+        label_lower_bound=label_lower_bound,
+        label_upper_bound=label_upper_bound,
+    )

-    dmatrix = DeviceQuantileDMatrix(it,
-                                    missing=missing,
-                                    feature_names=feature_names,
-                                    feature_types=feature_types,
-                                    nthread=worker.nthreads,
-                                    max_bin=max_bin)
+    dmatrix = DeviceQuantileDMatrix(
+        it,
+        missing=missing,
+        feature_names=feature_names,
+        feature_types=feature_types,
+        nthread=worker.nthreads,
+        max_bin=max_bin,
+    )
     dmatrix.set_info(feature_weights=feature_weights)
     return dmatrix
@@ -712,13 +728,15 @@ def _create_dmatrix(
         missing=missing,
         feature_names=feature_names,
         feature_types=feature_types,
-        nthread=worker.nthreads
+        nthread=worker.nthreads,
     )
     dmatrix.set_info(
-        base_margin=_base_margin, qid=_qid, weight=_weights,
+        base_margin=_base_margin,
+        qid=_qid,
+        weight=_weights,
         label_lower_bound=_label_lower_bound,
         label_upper_bound=_label_upper_bound,
-        feature_weights=feature_weights
+        feature_weights=feature_weights,
     )
     return dmatrix
@@ -753,6 +771,8 @@ def _get_workers_from_data(
     for e in evals:
         assert len(e) == 2
         assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
+        if e[0] is dtrain:
+            continue
         worker_map = set(e[0].worker_map.keys())
         X_worker_map = X_worker_map.union(worker_map)
     return X_worker_map
@@ -960,7 +980,7 @@ async def _predict_async(
     worker = distributed.get_worker()
     with config.config_context(**global_config):
         booster.set_param({'nthread': worker.nthreads})
-        m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
+        m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
         predt = booster.predict(
             data=m,
             output_margin=output_margin,
@@ -1587,7 +1607,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     For dask implementation, group is not supported, use qid instead.
 """,
 )
-class DaskXGBRanker(DaskScikitLearnBase):
+class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     @_deprecate_positional_args
     def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
         if callable(objective):
@@ -1632,11 +1652,10 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         if eval_metric is not None:
             if callable(eval_metric):
                 raise ValueError(
-                    'Custom evaluation metric is not yet supported for XGBRanker.')
+                    "Custom evaluation metric is not yet supported for XGBRanker."
+                )
         model, metric, params = self._configure_fit(
-            booster=xgb_model,
-            eval_metric=eval_metric,
-            params=params
+            booster=xgb_model, eval_metric=eval_metric, params=params
         )
         results = await train(
             client=self.client,
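
For the dask side, a hedged sketch of my own (assuming a local dask cluster is acceptable for demonstration): `DaskDMatrix` now mirrors the single-node keyword set, but `group` and `enable_categorical` are rejected explicitly, as the checks added above show.

```python
# Sketch only; not part of the diff. Uses a default local cluster.
import numpy as np
from dask import array as da
from dask.distributed import Client
from xgboost import dask as dxgb

with Client() as client:
    X = da.random.random((100, 10), chunks=(50, 10))
    y = da.random.random(100, chunks=50)
    qid = da.from_array(np.sort(np.random.randint(0, 10, size=100)), chunks=50)

    # Same keywords as DMatrix; passing `group=...` or `enable_categorical=True`
    # raises NotImplementedError per the constructor checks above.
    m = dxgb.DaskDMatrix(client, X, y, qid=qid)
```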


@@ -737,16 +737,28 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
     area for meta info.

     '''
-    def __init__(self, data, label, weight, base_margin, group,
-                 label_lower_bound, label_upper_bound,
-                 feature_names, feature_types):
+    def __init__(
+        self, data,
+        label,
+        weight,
+        base_margin,
+        group,
+        qid,
+        label_lower_bound,
+        label_upper_bound,
+        feature_weights,
+        feature_names,
+        feature_types
+    ):
         self.data = data
         self.label = label
         self.weight = weight
         self.base_margin = base_margin
         self.group = group
+        self.qid = qid
         self.label_lower_bound = label_lower_bound
         self.label_upper_bound = label_upper_bound
+        self.feature_weights = feature_weights
         self.feature_names = feature_names
         self.feature_types = feature_types
         self.it = 0  # pylint: disable=invalid-name
@@ -759,8 +771,10 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
         input_data(data=self.data, label=self.label,
                    weight=self.weight, base_margin=self.base_margin,
                    group=self.group,
+                   qid=self.qid,
                    label_lower_bound=self.label_lower_bound,
                    label_upper_bound=self.label_upper_bound,
+                   feature_weights=self.feature_weights,
                    feature_names=self.feature_names,
                    feature_types=self.feature_types)
         return 1
@@ -770,7 +784,8 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902

 def init_device_quantile_dmatrix(
-        data, missing, max_bin, threads, feature_names, feature_types, **meta):
+        data, missing, max_bin, threads, feature_names, feature_types, **meta
+):
     '''Constructor for DeviceQuantileDMatrix.'''
     if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
                 _is_dlpack(data), _is_iter(data)]):


@@ -556,7 +556,7 @@ class XGBModel(XGBModelBase):
     def _configure_fit(
         self,
-        booster: Optional[Booster],
+        booster: Optional[Union[Booster, "XGBModel"]],
         eval_metric: Optional[Union[Callable, str, List[str]]],
         params: Dict[str, Any],
     ) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
@@ -631,7 +631,7 @@ class XGBModel(XGBModelBase):
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
-        xgb_model : str
+        xgb_model : Union[str, Booster, XGBModel]
             file name of stored XGBoost model or 'Booster' instance XGBoost model to be
             loaded before training (allows training continuation).
         sample_weight_eval_set : list, optional
@@ -942,10 +942,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         super().__init__(objective=objective, **kwargs)

     @_deprecate_positional_args
-    def fit(self, X, y, *, sample_weight=None, base_margin=None,
-            eval_set=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True, xgb_model=None,
-            sample_weight_eval_set=None, feature_weights=None, callbacks=None):
+    def fit(
+        self,
+        X,
+        y,
+        *,
+        sample_weight=None,
+        base_margin=None,
+        eval_set=None,
+        eval_metric=None,
+        early_stopping_rounds=None,
+        verbose=True,
+        xgb_model=None,
+        sample_weight_eval_set=None,
+        feature_weights=None,
+        callbacks=None
+    ):
         # pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements

         can_use_label_encoder = True
@@ -1283,7 +1295,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
     @_deprecate_positional_args
     def fit(
-        self, X, y, *,
+        self,
+        X,
+        y,
+        *,
         group=None,
         qid=None,
         sample_weight=None,
@@ -1372,7 +1387,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         verbose : bool
             If `verbose` and an evaluation set is used, writes the evaluation
             metric measured on the validation set to stderr.
-        xgb_model : str
+        xgb_model : Union[str, Booster, XGBModel]
             file name of stored XGBoost model or 'Booster' instance XGBoost
             model to be loaded before training (allows training continuation).
         feature_weights: array_like
@@ -1391,9 +1406,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
                 save_best=True)]

         """
-        # check if group information is provided
-        if group is None:
-            raise ValueError("group is required for ranking task")
+        if group is None and qid is None:
+            raise ValueError("group or qid is required for ranking task")

         if eval_set is not None:
             if eval_group is None and eval_qid is None:
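
A sketch of what the relaxed check above buys (my example, not from the diff): `XGBRanker.fit` now accepts either `group` or `qid`, and raises only when both are missing.

```python
# Hedged sketch; assumes XGBoost at or after this revision.
import numpy as np
import xgboost as xgb

X = np.random.randn(20, 3)
y = np.random.randint(0, 3, size=20)    # relevance labels
qid = np.repeat(np.arange(4), 5)        # 4 queries of 5 documents, sorted

ranker = xgb.XGBRanker(n_estimators=2)
ranker.fit(X, y, qid=qid)               # equivalent to group=[5, 5, 5, 5]
```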


@@ -34,3 +34,25 @@ class TestDeviceQuantileDMatrix:
         import cupy as cp
         data = cp.random.randn(5, 5)
         xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
+
+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_metainfo(self) -> None:
+        import cupy as cp
+        rng = cp.random.RandomState(1994)
+
+        rows = 10
+        cols = 3
+        data = rng.randn(rows, cols)
+
+        labels = rng.randn(rows)
+
+        fw = rng.randn(rows)
+        fw -= fw.min()
+
+        m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)
+
+        got_fw = m.get_float_info("feature_weights")
+        got_labels = m.get_label()
+
+        cp.testing.assert_allclose(fw, got_fw)
+        cp.testing.assert_allclose(labels, got_labels)


@@ -6,7 +6,9 @@ import numpy as np
 import asyncio
 import xgboost
 import subprocess
-from hypothesis import given, strategies, settings, note, HealthCheck
+from collections import OrderedDict
+from inspect import signature
+from hypothesis import given, strategies, settings, note
 from hypothesis._settings import duration
 from test_gpu_updaters import parameter_strategy
@@ -18,13 +20,15 @@ from test_gpu_updaters import parameter_strategy
 from test_with_dask import run_empty_dmatrix_reg  # noqa
 from test_with_dask import run_empty_dmatrix_cls  # noqa
 from test_with_dask import _get_client_workers  # noqa
 from test_with_dask import generate_array  # noqa
-from test_with_dask import suppress
+from test_with_dask import kCols as random_cols  # noqa
+from test_with_dask import suppress  # noqa
 import testing as tm  # noqa

 try:
     import dask.dataframe as dd
     from xgboost import dask as dxgb
+    import xgboost as xgb
     from dask.distributed import Client
     from dask import array as da
     from dask_cuda import LocalCUDACluster
@@ -252,6 +256,64 @@ class TestDistributedGPU:
             run_empty_dmatrix_reg(client, parameters)
             run_empty_dmatrix_cls(client, parameters)

+    def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
+        with Client(local_cuda_cluster) as client:
+            X, y, _ = generate_array()
+            fw = da.random.random((random_cols, ))
+            fw = fw - fw.min()
+            m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
+
+            workers = list(_get_client_workers(client).keys())
+            rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
+
+            def worker_fn(worker_addr: str, data_ref: Dict) -> None:
+                with dxgb.RabitContext(rabit_args):
+                    local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
+                    fw_rows = local_dtrain.get_float_info("feature_weights").shape[0]
+                    assert fw_rows == local_dtrain.num_col()
+
+            futures = []
+            for i in range(len(workers)):
+                futures.append(client.submit(worker_fn, workers[i],
+                                             m.create_fn_args(workers[i]), pure=False,
+                                             workers=[workers[i]]))
+            client.gather(futures)
+
+    def test_interface_consistency(self) -> None:
+        sig = OrderedDict(signature(dxgb.DaskDMatrix).parameters)
+        del sig["client"]
+        ddm_names = list(sig.keys())
+        sig = OrderedDict(signature(dxgb.DaskDeviceQuantileDMatrix).parameters)
+        del sig["client"]
+        del sig["max_bin"]
+        ddqdm_names = list(sig.keys())
+        assert len(ddm_names) == len(ddqdm_names)
+        # between dask
+        for i in range(len(ddm_names)):
+            assert ddm_names[i] == ddqdm_names[i]
+
+        sig = OrderedDict(signature(xgb.DMatrix).parameters)
+        del sig["nthread"]  # no nthread in dask
+        dm_names = list(sig.keys())
+        sig = OrderedDict(signature(xgb.DeviceQuantileDMatrix).parameters)
+        del sig["nthread"]
+        del sig["max_bin"]
+        dqdm_names = list(sig.keys())
+
+        # between single node
+        assert len(dm_names) == len(dqdm_names)
+        for i in range(len(dm_names)):
+            assert dm_names[i] == dqdm_names[i]
+
+        # ddm <-> dm
+        for i in range(len(ddm_names)):
+            assert ddm_names[i] == dm_names[i]
+
+        # dqdm <-> ddqdm
+        for i in range(len(ddqdm_names)):
+            assert ddqdm_names[i] == dqdm_names[i]
+
     def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
         if sys.platform.startswith("win"):
             pytest.skip("Skipping dask tests on Windows")