Define metainfo and other parameters for all DMatrix interfaces. (#6601)
This PR ensures all DMatrix types have a common interface. * Fix logic in avoiding duplicated DMatrix in sklearn. * Check for consistency between DMatrix types. * Add doc for bounds.
This commit is contained in:
parent
561809200a
commit
8942c98054
@ -312,15 +312,18 @@ class DataIter:
|
||||
data, feature_names, feature_types
|
||||
)
|
||||
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
||||
self.proxy.set_info(label=label, weight=weight,
|
||||
base_margin=base_margin,
|
||||
group=group,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
feature_weights=feature_weights)
|
||||
self.proxy.set_info(
|
||||
label=label,
|
||||
weight=weight,
|
||||
base_margin=base_margin,
|
||||
group=group,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
feature_weights=feature_weights
|
||||
)
|
||||
try:
|
||||
# Defer the exception in order to return 0 and stop the iteration.
|
||||
# Exception inside a ctype callback function has no effect except
|
||||
@ -408,7 +411,7 @@ def _deprecate_positional_args(f):
|
||||
return inner_f
|
||||
|
||||
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
DMatrix is an internal data structure that is used by XGBoost,
|
||||
@ -416,13 +419,26 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
You can construct DMatrix from multiple different sources of data.
|
||||
"""
|
||||
|
||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
||||
missing=None,
|
||||
silent=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
nthread=None,
|
||||
enable_categorical=False):
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
label=None,
|
||||
*,
|
||||
weight=None,
|
||||
base_margin=None,
|
||||
missing: Optional[float] = None,
|
||||
silent=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
nthread: Optional[int] = None,
|
||||
group=None,
|
||||
qid=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_weights=None,
|
||||
enable_categorical: bool = False,
|
||||
) -> None:
|
||||
"""Parameters
|
||||
----------
|
||||
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
|
||||
@ -432,12 +448,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
libsvm format txt file, csv file (by specifying uri parameter
|
||||
'path_to_csv?format=csv'), or binary file that xgboost can read
|
||||
from.
|
||||
label : list, numpy 1-D array or cudf.DataFrame, optional
|
||||
label : array_like
|
||||
Label of the training data.
|
||||
missing : float, optional
|
||||
Value in the input data which needs to be present as a missing
|
||||
value. If None, defaults to np.nan.
|
||||
weight : list, numpy 1-D array or cudf.DataFrame , optional
|
||||
weight : array_like
|
||||
Weight for each instance.
|
||||
|
||||
.. note:: For ranking task, weights are per-group.
|
||||
@ -447,6 +460,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
ordering of data points within each group, so it doesn't make
|
||||
sense to assign weights to individual data points.
|
||||
|
||||
base_margin: array_like
|
||||
Base margin used for boosting from existing model.
|
||||
missing : float, optional
|
||||
Value in the input data which needs to be present as a missing
|
||||
value. If None, defaults to np.nan.
|
||||
silent : boolean, optional
|
||||
Whether to print messages during construction
|
||||
feature_names : list, optional
|
||||
@ -456,7 +474,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
nthread : integer, optional
|
||||
Number of threads to use for loading data when parallelization is
|
||||
applicable. If -1, uses maximum threads available on the system.
|
||||
|
||||
group : array_like
|
||||
Group sizes for all ranking groups.
|
||||
qid : array_like
|
||||
Query ID for data samples, used for ranking.
|
||||
label_lower_bound : array_like
|
||||
Lower bound for survival training.
|
||||
label_upper_bound : array_like
|
||||
Upper bound for survival training.
|
||||
feature_weights : array_like, optional
|
||||
Set feature weights for column sampling.
|
||||
enable_categorical: boolean, optional
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
@ -469,7 +496,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
raise TypeError('Input data can not be a list.')
|
||||
raise TypeError("Input data can not be a list.")
|
||||
if group is not None and qid is not None:
|
||||
raise ValueError("Either one of `group` or `qid` should be None.")
|
||||
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.nthread = nthread if nthread is not None else -1
|
||||
@ -481,16 +510,28 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
return
|
||||
|
||||
from .data import dispatch_data_backend
|
||||
|
||||
handle, feature_names, feature_types = dispatch_data_backend(
|
||||
data, missing=self.missing,
|
||||
data,
|
||||
missing=self.missing,
|
||||
threads=self.nthread,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
enable_categorical=enable_categorical)
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
assert handle is not None
|
||||
self.handle = handle
|
||||
|
||||
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
||||
self.set_info(
|
||||
label=label,
|
||||
weight=weight,
|
||||
base_margin=base_margin,
|
||||
group=group,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_weights=feature_weights,
|
||||
)
|
||||
|
||||
if feature_names is not None:
|
||||
self.feature_names = feature_names
|
||||
@ -503,17 +544,23 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
self.handle = None
|
||||
|
||||
@_deprecate_positional_args
|
||||
def set_info(self, *,
|
||||
label=None, weight=None, base_margin=None,
|
||||
group=None,
|
||||
qid=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
feature_weights=None):
|
||||
'''Set meta info for DMatrix.'''
|
||||
def set_info(
|
||||
self,
|
||||
*,
|
||||
label=None,
|
||||
weight=None,
|
||||
base_margin=None,
|
||||
group=None,
|
||||
qid=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
feature_weights=None
|
||||
) -> None:
|
||||
"""Set meta info for DMatrix. See doc string for DMatrix constructor."""
|
||||
from .data import dispatch_meta_backend
|
||||
|
||||
if label is not None:
|
||||
self.set_label(label)
|
||||
if weight is not None:
|
||||
@ -918,39 +965,67 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
information may be lost in quantisation. This DMatrix is primarily designed
|
||||
to save memory in training from device memory inputs by avoiding
|
||||
intermediate storage. Set max_bin to control the number of bins during
|
||||
quantisation.
|
||||
quantisation. See doc string in `DMatrix` for documents on meta info.
|
||||
|
||||
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
|
||||
|
||||
.. versionadded:: 1.1.0
|
||||
"""
|
||||
|
||||
def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
|
||||
base_margin=None,
|
||||
missing=None,
|
||||
silent=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
nthread=None, max_bin=256):
|
||||
@_deprecate_positional_args
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
data,
|
||||
label=None,
|
||||
*,
|
||||
weight=None,
|
||||
base_margin=None,
|
||||
missing=None,
|
||||
silent=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
nthread: Optional[int] = None,
|
||||
max_bin: int = 256,
|
||||
group=None,
|
||||
qid=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_weights=None,
|
||||
enable_categorical: bool = False,
|
||||
):
|
||||
self.max_bin = max_bin
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
self.nthread = nthread if nthread is not None else 1
|
||||
self._silent = silent # unused, kept for compatibility
|
||||
|
||||
if isinstance(data, ctypes.c_void_p):
|
||||
self.handle = data
|
||||
return
|
||||
from .data import init_device_quantile_dmatrix
|
||||
handle, feature_names, feature_types = init_device_quantile_dmatrix(
|
||||
data, missing=self.missing, threads=self.nthread,
|
||||
max_bin=self.max_bin,
|
||||
data,
|
||||
label=label, weight=weight,
|
||||
base_margin=base_margin,
|
||||
group=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
group=group,
|
||||
qid=qid,
|
||||
missing=self.missing,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_weights=feature_weights,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types)
|
||||
feature_types=feature_types,
|
||||
threads=self.nthread,
|
||||
max_bin=self.max_bin,
|
||||
)
|
||||
if enable_categorical:
|
||||
raise NotImplementedError(
|
||||
'categorical support is not enabled on DeviceQuantileDMatrix.'
|
||||
)
|
||||
self.handle = handle
|
||||
if qid is not None and group is not None:
|
||||
raise ValueError(
|
||||
'Only one of the eval_qid or eval_group for each evaluation '
|
||||
'dataset should be provided.'
|
||||
)
|
||||
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
|
||||
@ -38,8 +38,9 @@ from .core import Objective, Metric
|
||||
from .core import _deprecate_positional_args
|
||||
from .training import train as worker_train
|
||||
from .tracker import RabitTracker, get_host_ip
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||
from .sklearn import xgboost_model_doc, _objective_decorator
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
|
||||
from .sklearn import XGBRankerMixIn
|
||||
from .sklearn import xgboost_model_doc
|
||||
from .sklearn import _cls_predict_proba
|
||||
from .sklearn import XGBRanker
|
||||
|
||||
@ -180,10 +181,12 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
|
||||
|
||||
class DaskDMatrix:
|
||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
|
||||
a `DaskDMatrix` forces all lazy computation to be carried out. Wait for
|
||||
the input data explicitly if you want to see actual computation of
|
||||
constructing `DaskDMatrix`.
|
||||
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
|
||||
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
|
||||
explicitly if you want to see actual computation of constructing `DaskDMatrix`.
|
||||
|
||||
See doc string for DMatrix constructor for other parameters. DaskDMatrix accepts only
|
||||
dask collections.
|
||||
|
||||
.. note::
|
||||
|
||||
@ -197,29 +200,6 @@ class DaskDMatrix:
|
||||
client :
|
||||
Specify the dask client used for training. Use default client returned from dask
|
||||
if it's set to None.
|
||||
data :
|
||||
data source of DMatrix.
|
||||
label :
|
||||
label used for trainin.
|
||||
missing :
|
||||
Value in the input data (e.g. `numpy.ndarray`) which needs to be present as a
|
||||
missing value. If None, defaults to np.nan.
|
||||
weight :
|
||||
Weight for each instance.
|
||||
base_margin :
|
||||
Global bias for each instance.
|
||||
qid :
|
||||
Query ID for ranking.
|
||||
label_lower_bound :
|
||||
Upper bound for survival training.
|
||||
label_upper_bound :
|
||||
Lower bound for survival training.
|
||||
feature_weights :
|
||||
Weight for features used in column sampling.
|
||||
feature_names :
|
||||
Set names for features.
|
||||
feature_types :
|
||||
Set types for features
|
||||
|
||||
'''
|
||||
|
||||
@ -230,15 +210,18 @@ class DaskDMatrix:
|
||||
data: _DaskCollection,
|
||||
label: Optional[_DaskCollection] = None,
|
||||
*,
|
||||
missing: float = None,
|
||||
weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
missing: float = None,
|
||||
silent: bool = False, # pylint: disable=unused-argument
|
||||
feature_names: Optional[Union[str, List[str]]] = None,
|
||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
label_lower_bound: Optional[_DaskCollection] = None,
|
||||
label_upper_bound: Optional[_DaskCollection] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
feature_names: Optional[Union[str, List[str]]] = None,
|
||||
feature_types: Optional[Union[Any, List[Any]]] = None
|
||||
enable_categorical: bool = False
|
||||
) -> None:
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
@ -248,30 +231,41 @@ class DaskDMatrix:
|
||||
self.missing = missing
|
||||
|
||||
if qid is not None and weight is not None:
|
||||
raise NotImplementedError('per-group weight is not implemented.')
|
||||
raise NotImplementedError("per-group weight is not implemented.")
|
||||
if group is not None:
|
||||
raise NotImplementedError(
|
||||
"group structure is not implemented, use qid instead."
|
||||
)
|
||||
if enable_categorical:
|
||||
raise NotImplementedError(
|
||||
"categorical support is not enabled on `DaskDMatrix`."
|
||||
)
|
||||
|
||||
if len(data.shape) != 2:
|
||||
raise ValueError(
|
||||
'Expecting 2 dimensional input, got: {shape}'.format(
|
||||
shape=data.shape))
|
||||
"Expecting 2 dimensional input, got: {shape}".format(shape=data.shape)
|
||||
)
|
||||
|
||||
if not isinstance(data, (dd.DataFrame, da.Array)):
|
||||
raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
|
||||
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series,
|
||||
type(None))):
|
||||
raise TypeError(
|
||||
_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
|
||||
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
|
||||
raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
|
||||
|
||||
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
|
||||
self.is_quantile: bool = False
|
||||
|
||||
self._init = client.sync(self.map_local_data,
|
||||
client, data, label=label, weights=weight,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
feature_weights=feature_weights,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound)
|
||||
self._init = client.sync(
|
||||
self.map_local_data,
|
||||
client,
|
||||
data,
|
||||
label=label,
|
||||
weights=weight,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
feature_weights=feature_weights,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
)
|
||||
|
||||
def __await__(self) -> Generator:
|
||||
return self._init.__await__()
|
||||
@ -571,11 +565,11 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
|
||||
|
||||
class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
'''Specialized data type for `gpu_hist` tree method. This class is used to
|
||||
reduce the memory usage by eliminating data copies. Internally the all
|
||||
partitions/chunks of data are merged by weighted GK sketching. So the
|
||||
number of partitions from dask may affect training accuracy as GK generates
|
||||
bounded error for each merge.
|
||||
'''Specialized data type for `gpu_hist` tree method. This class is used to reduce the
|
||||
memory usage by eliminating data copies. Internally, all partitions/chunks of data
|
||||
are merged by weighted GK sketching. So the number of partitions from dask may affect
|
||||
training accuracy as GK generates bounded error for each merge. See doc string for
|
||||
`DeviceQuantileDMatrix` and `DMatrix` for other parameters.
|
||||
|
||||
.. versionadded:: 1.2.0
|
||||
|
||||
@ -584,42 +578,50 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
max_bin : Number of bins for histogram construction.
|
||||
|
||||
'''
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
self,
|
||||
client: "distributed.Client",
|
||||
data: _DaskCollection,
|
||||
label: Optional[_DaskCollection] = None,
|
||||
missing: float = None,
|
||||
*,
|
||||
weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
missing: float = None,
|
||||
silent: bool = False,
|
||||
feature_names: Optional[Union[str, List[str]]] = None,
|
||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||
max_bin: int = 256,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
label_lower_bound: Optional[_DaskCollection] = None,
|
||||
label_upper_bound: Optional[_DaskCollection] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
feature_names: Optional[Union[str, List[str]]] = None,
|
||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||
max_bin: int = 256
|
||||
enable_categorical: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
client=client,
|
||||
data=data,
|
||||
label=label,
|
||||
missing=missing,
|
||||
feature_weights=feature_weights,
|
||||
weight=weight,
|
||||
base_margin=base_margin,
|
||||
group=group,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
missing=missing,
|
||||
silent=silent,
|
||||
feature_weights=feature_weights,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types
|
||||
feature_types=feature_types,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
self.max_bin = max_bin
|
||||
self.is_quantile = True
|
||||
|
||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||
args = super().create_fn_args(worker_addr)
|
||||
args['max_bin'] = self.max_bin
|
||||
args["max_bin"] = self.max_bin
|
||||
return args
|
||||
|
||||
|
||||
@ -630,35 +632,49 @@ def _create_device_quantile_dmatrix(
|
||||
meta_names: List[str],
|
||||
missing: float,
|
||||
parts: Optional[_DataParts],
|
||||
max_bin: int
|
||||
max_bin: int,
|
||||
) -> DeviceQuantileDMatrix:
|
||||
worker = distributed.get_worker()
|
||||
if parts is None:
|
||||
msg = 'worker {address} has an empty DMatrix. '.format(
|
||||
address=worker.address)
|
||||
msg = "worker {address} has an empty DMatrix.".format(address=worker.address)
|
||||
LOGGER.warning(msg)
|
||||
import cupy
|
||||
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
max_bin=max_bin)
|
||||
|
||||
d = DeviceQuantileDMatrix(
|
||||
cupy.zeros((0, 0)),
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
max_bin=max_bin,
|
||||
)
|
||||
return d
|
||||
|
||||
(data, labels, weights, base_margin, qid,
|
||||
label_lower_bound, label_upper_bound) = _get_worker_parts(
|
||||
parts, meta_names)
|
||||
it = DaskPartitionIter(data=data, label=labels, weight=weights,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound)
|
||||
(
|
||||
data,
|
||||
labels,
|
||||
weights,
|
||||
base_margin,
|
||||
qid,
|
||||
label_lower_bound,
|
||||
label_upper_bound,
|
||||
) = _get_worker_parts(parts, meta_names)
|
||||
it = DaskPartitionIter(
|
||||
data=data,
|
||||
label=labels,
|
||||
weight=weights,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
)
|
||||
|
||||
dmatrix = DeviceQuantileDMatrix(it,
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads,
|
||||
max_bin=max_bin)
|
||||
dmatrix = DeviceQuantileDMatrix(
|
||||
it,
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads,
|
||||
max_bin=max_bin,
|
||||
)
|
||||
dmatrix.set_info(feature_weights=feature_weights)
|
||||
return dmatrix
|
||||
|
||||
@ -712,13 +728,15 @@ def _create_dmatrix(
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads
|
||||
nthread=worker.nthreads,
|
||||
)
|
||||
dmatrix.set_info(
|
||||
base_margin=_base_margin, qid=_qid, weight=_weights,
|
||||
base_margin=_base_margin,
|
||||
qid=_qid,
|
||||
weight=_weights,
|
||||
label_lower_bound=_label_lower_bound,
|
||||
label_upper_bound=_label_upper_bound,
|
||||
feature_weights=feature_weights
|
||||
feature_weights=feature_weights,
|
||||
)
|
||||
return dmatrix
|
||||
|
||||
@ -753,6 +771,8 @@ def _get_workers_from_data(
|
||||
for e in evals:
|
||||
assert len(e) == 2
|
||||
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
|
||||
if e[0] is dtrain:
|
||||
continue
|
||||
worker_map = set(e[0].worker_map.keys())
|
||||
X_worker_map = X_worker_map.union(worker_map)
|
||||
return X_worker_map
|
||||
@ -960,7 +980,7 @@ async def _predict_async(
|
||||
worker = distributed.get_worker()
|
||||
with config.config_context(**global_config):
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(
|
||||
data=m,
|
||||
output_margin=output_margin,
|
||||
@ -1587,7 +1607,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
For dask implementation, group is not supported, use qid instead.
|
||||
""",
|
||||
)
|
||||
class DaskXGBRanker(DaskScikitLearnBase):
|
||||
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
@_deprecate_positional_args
|
||||
def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
|
||||
if callable(objective):
|
||||
@ -1632,11 +1652,10 @@ class DaskXGBRanker(DaskScikitLearnBase):
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
raise ValueError(
|
||||
'Custom evaluation metric is not yet supported for XGBRanker.')
|
||||
"Custom evaluation metric is not yet supported for XGBRanker."
|
||||
)
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model,
|
||||
eval_metric=eval_metric,
|
||||
params=params
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
client=self.client,
|
||||
|
||||
@ -737,16 +737,28 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
||||
area for meta info.
|
||||
|
||||
'''
|
||||
def __init__(self, data, label, weight, base_margin, group,
|
||||
label_lower_bound, label_upper_bound,
|
||||
feature_names, feature_types):
|
||||
def __init__(
|
||||
self, data,
|
||||
label,
|
||||
weight,
|
||||
base_margin,
|
||||
group,
|
||||
qid,
|
||||
label_lower_bound,
|
||||
label_upper_bound,
|
||||
feature_weights,
|
||||
feature_names,
|
||||
feature_types
|
||||
):
|
||||
self.data = data
|
||||
self.label = label
|
||||
self.weight = weight
|
||||
self.base_margin = base_margin
|
||||
self.group = group
|
||||
self.qid = qid
|
||||
self.label_lower_bound = label_lower_bound
|
||||
self.label_upper_bound = label_upper_bound
|
||||
self.feature_weights = feature_weights
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
self.it = 0 # pylint: disable=invalid-name
|
||||
@ -759,8 +771,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
||||
input_data(data=self.data, label=self.label,
|
||||
weight=self.weight, base_margin=self.base_margin,
|
||||
group=self.group,
|
||||
qid=self.qid,
|
||||
label_lower_bound=self.label_lower_bound,
|
||||
label_upper_bound=self.label_upper_bound,
|
||||
feature_weights=self.feature_weights,
|
||||
feature_names=self.feature_names,
|
||||
feature_types=self.feature_types)
|
||||
return 1
|
||||
@ -770,7 +784,8 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
||||
|
||||
|
||||
def init_device_quantile_dmatrix(
|
||||
data, missing, max_bin, threads, feature_names, feature_types, **meta):
|
||||
data, missing, max_bin, threads, feature_names, feature_types, **meta
|
||||
):
|
||||
'''Constructor for DeviceQuantileDMatrix.'''
|
||||
if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
|
||||
_is_dlpack(data), _is_iter(data)]):
|
||||
|
||||
@ -556,7 +556,7 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
def _configure_fit(
|
||||
self,
|
||||
booster: Optional[Booster],
|
||||
booster: Optional[Union[Booster, "XGBModel"]],
|
||||
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||
params: Dict[str, Any],
|
||||
) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
|
||||
@ -631,7 +631,7 @@ class XGBModel(XGBModelBase):
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.
|
||||
xgb_model : str
|
||||
xgb_model : Union[str, Booster, XGBModel]
|
||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||
loaded before training (allows training continuation).
|
||||
sample_weight_eval_set : list, optional
|
||||
@ -942,10 +942,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
super().__init__(objective=objective, **kwargs)
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
||||
eval_set=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
||||
sample_weight_eval_set=None, feature_weights=None, callbacks=None):
|
||||
def fit(
|
||||
self,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
sample_weight=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
eval_metric=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True,
|
||||
xgb_model=None,
|
||||
sample_weight_eval_set=None,
|
||||
feature_weights=None,
|
||||
callbacks=None
|
||||
):
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements
|
||||
|
||||
can_use_label_encoder = True
|
||||
@ -1283,7 +1295,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self, X, y, *,
|
||||
self,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
group=None,
|
||||
qid=None,
|
||||
sample_weight=None,
|
||||
@ -1372,7 +1387,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.
|
||||
xgb_model : str
|
||||
xgb_model : Union[str, Booster, XGBModel]
|
||||
file name of stored XGBoost model or 'Booster' instance XGBoost
|
||||
model to be loaded before training (allows training continuation).
|
||||
feature_weights: array_like
|
||||
@ -1391,9 +1406,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
save_best=True)]
|
||||
|
||||
"""
|
||||
# check if group information is provided
|
||||
if group is None:
|
||||
raise ValueError("group is required for ranking task")
|
||||
if group is None and qid is None:
|
||||
raise ValueError("group or qid is required for ranking task")
|
||||
|
||||
if eval_set is not None:
|
||||
if eval_group is None and eval_qid is None:
|
||||
|
||||
@ -34,3 +34,25 @@ class TestDeviceQuantileDMatrix:
|
||||
import cupy as cp
|
||||
data = cp.random.randn(5, 5)
|
||||
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
def test_metainfo(self) -> None:
    """Check that meta info passed to the constructor is stored on the matrix.

    Builds a ``DeviceQuantileDMatrix`` from random cupy data together with
    labels and feature weights, then reads both back and compares them with
    the inputs.
    """
    import cupy as cp

    rng = cp.random.RandomState(1994)

    rows = 10
    cols = 3
    data = rng.randn(rows, cols)

    labels = rng.randn(rows)

    # Feature weights are per column (they are used for column sampling), so
    # the vector needs ``cols`` entries rather than ``rows``.
    fw = rng.randn(cols)
    # Shift the weights so that the smallest one is zero and none is negative.
    fw -= fw.min()

    m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)

    got_fw = m.get_float_info("feature_weights")
    got_labels = m.get_label()

    cp.testing.assert_allclose(fw, got_fw)
    cp.testing.assert_allclose(labels, got_labels)
|
||||
|
||||
@ -6,7 +6,9 @@ import numpy as np
|
||||
import asyncio
|
||||
import xgboost
|
||||
import subprocess
|
||||
from hypothesis import given, strategies, settings, note, HealthCheck
|
||||
from collections import OrderedDict
|
||||
from inspect import signature
|
||||
from hypothesis import given, strategies, settings, note
|
||||
from hypothesis._settings import duration
|
||||
from test_gpu_updaters import parameter_strategy
|
||||
|
||||
@ -18,13 +20,15 @@ from test_with_dask import run_empty_dmatrix_reg # noqa
|
||||
from test_with_dask import run_empty_dmatrix_cls # noqa
|
||||
from test_with_dask import _get_client_workers # noqa
|
||||
from test_with_dask import generate_array # noqa
|
||||
from test_with_dask import suppress
|
||||
from test_with_dask import kCols as random_cols # noqa
|
||||
from test_with_dask import suppress # noqa
|
||||
import testing as tm # noqa
|
||||
|
||||
|
||||
try:
|
||||
import dask.dataframe as dd
|
||||
from xgboost import dask as dxgb
|
||||
import xgboost as xgb
|
||||
from dask.distributed import Client
|
||||
from dask import array as da
|
||||
from dask_cuda import LocalCUDACluster
|
||||
@ -252,6 +256,64 @@ class TestDistributedGPU:
|
||||
run_empty_dmatrix_reg(client, parameters)
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
    """Feature weights given to ``DaskDMatrix`` must reach every worker.

    Each worker rebuilds its local matrix from the arguments produced by
    ``create_fn_args`` and checks that the number of stored feature weights
    equals the number of columns.
    """
    with Client(local_cuda_cluster) as client:
        X, y, _ = generate_array()
        raw_weights = da.random.random((random_cols, ))
        feature_weights = raw_weights - raw_weights.min()
        dtrain = dxgb.DaskDMatrix(client, X, y, feature_weights=feature_weights)

        worker_addrs = list(_get_client_workers(client).keys())
        rabit_args = client.sync(
            dxgb._get_rabit_args, len(worker_addrs), client
        )

        def worker_fn(worker_addr: str, data_ref: Dict) -> None:
            with dxgb.RabitContext(rabit_args):
                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
                weights = local_dtrain.get_float_info("feature_weights")
                assert weights.shape[0] == local_dtrain.num_col()

        # One task per worker, each restricted to its worker via ``workers=``.
        futures = [
            client.submit(
                worker_fn,
                addr,
                dtrain.create_fn_args(addr),
                pure=False,
                workers=[addr],
            )
            for addr in worker_addrs
        ]
        client.gather(futures)
|
||||
|
||||
def test_interface_consistency(self) -> None:
    """Constructor parameter names must agree, position by position.

    Compares the dask matrices with each other, the single-node matrices
    with each other, and then each dask class with its single-node
    counterpart. Parameters that exist on only one side (``client``,
    ``nthread``, ``max_bin``) are removed before comparing.
    """

    def param_names(cls, *excluded: str) -> list:
        """Return the constructor parameter names of `cls` minus `excluded`."""
        params = OrderedDict(signature(cls).parameters)
        for key in excluded:
            del params[key]
        return list(params.keys())

    ddm_names = param_names(dxgb.DaskDMatrix, "client")
    ddqdm_names = param_names(
        dxgb.DaskDeviceQuantileDMatrix, "client", "max_bin"
    )
    assert len(ddm_names) == len(ddqdm_names)

    # between dask
    for i, name in enumerate(ddm_names):
        assert name == ddqdm_names[i]

    dm_names = param_names(xgb.DMatrix, "nthread")  # no nthread in dask
    dqdm_names = param_names(xgb.DeviceQuantileDMatrix, "nthread", "max_bin")

    # between single node
    assert len(dm_names) == len(dqdm_names)
    for i, name in enumerate(dm_names):
        assert name == dqdm_names[i]

    # ddm <-> dm
    for i, name in enumerate(ddm_names):
        assert name == dm_names[i]

    # dqdm <-> ddqdm
    for i, name in enumerate(ddqdm_names):
        assert name == dqdm_names[i]
|
||||
|
||||
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user