Define metainfo and other parameters for all DMatrix interfaces. (#6601)
This PR ensures all DMatrix types have a common interface.

* Fix logic in avoiding duplicated DMatrix in sklearn.
* Check for consistency between DMatrix types.
* Add doc for bounds.
This commit is contained in:
parent
561809200a
commit
8942c98054
@ -312,7 +312,9 @@ class DataIter:
|
|||||||
data, feature_names, feature_types
|
data, feature_names, feature_types
|
||||||
)
|
)
|
||||||
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
||||||
self.proxy.set_info(label=label, weight=weight,
|
self.proxy.set_info(
|
||||||
|
label=label,
|
||||||
|
weight=weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
group=group,
|
group=group,
|
||||||
qid=qid,
|
qid=qid,
|
||||||
@ -320,7 +322,8 @@ class DataIter:
|
|||||||
label_upper_bound=label_upper_bound,
|
label_upper_bound=label_upper_bound,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
feature_weights=feature_weights)
|
feature_weights=feature_weights
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# Differ the exception in order to return 0 and stop the iteration.
|
# Differ the exception in order to return 0 and stop the iteration.
|
||||||
# Exception inside a ctype callback function has no effect except
|
# Exception inside a ctype callback function has no effect except
|
||||||
@ -416,13 +419,26 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
You can construct DMatrix from multiple different sources of data.
|
You can construct DMatrix from multiple different sources of data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
@_deprecate_positional_args
|
||||||
missing=None,
|
def __init__(
|
||||||
|
self,
|
||||||
|
data,
|
||||||
|
label=None,
|
||||||
|
*,
|
||||||
|
weight=None,
|
||||||
|
base_margin=None,
|
||||||
|
missing: Optional[float] = None,
|
||||||
silent=False,
|
silent=False,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
nthread=None,
|
nthread: Optional[int] = None,
|
||||||
enable_categorical=False):
|
group=None,
|
||||||
|
qid=None,
|
||||||
|
label_lower_bound=None,
|
||||||
|
label_upper_bound=None,
|
||||||
|
feature_weights=None,
|
||||||
|
enable_categorical: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Parameters
|
"""Parameters
|
||||||
----------
|
----------
|
||||||
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
|
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
|
||||||
@ -432,12 +448,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
libsvm format txt file, csv file (by specifying uri parameter
|
libsvm format txt file, csv file (by specifying uri parameter
|
||||||
'path_to_csv?format=csv'), or binary file that xgboost can read
|
'path_to_csv?format=csv'), or binary file that xgboost can read
|
||||||
from.
|
from.
|
||||||
label : list, numpy 1-D array or cudf.DataFrame, optional
|
label : array_like
|
||||||
Label of the training data.
|
Label of the training data.
|
||||||
missing : float, optional
|
weight : array_like
|
||||||
Value in the input data which needs to be present as a missing
|
|
||||||
value. If None, defaults to np.nan.
|
|
||||||
weight : list, numpy 1-D array or cudf.DataFrame , optional
|
|
||||||
Weight for each instance.
|
Weight for each instance.
|
||||||
|
|
||||||
.. note:: For ranking task, weights are per-group.
|
.. note:: For ranking task, weights are per-group.
|
||||||
@ -447,6 +460,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
ordering of data points within each group, so it doesn't make
|
ordering of data points within each group, so it doesn't make
|
||||||
sense to assign weights to individual data points.
|
sense to assign weights to individual data points.
|
||||||
|
|
||||||
|
base_margin: array_like
|
||||||
|
Base margin used for boosting from existing model.
|
||||||
|
missing : float, optional
|
||||||
|
Value in the input data which needs to be present as a missing
|
||||||
|
value. If None, defaults to np.nan.
|
||||||
silent : boolean, optional
|
silent : boolean, optional
|
||||||
Whether print messages during construction
|
Whether print messages during construction
|
||||||
feature_names : list, optional
|
feature_names : list, optional
|
||||||
@ -456,7 +474,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
nthread : integer, optional
|
nthread : integer, optional
|
||||||
Number of threads to use for loading data when parallelization is
|
Number of threads to use for loading data when parallelization is
|
||||||
applicable. If -1, uses maximum threads available on the system.
|
applicable. If -1, uses maximum threads available on the system.
|
||||||
|
group : array_like
|
||||||
|
Group size for all ranking group.
|
||||||
|
qid : array_like
|
||||||
|
Query ID for data samples, used for ranking.
|
||||||
|
label_lower_bound : array_like
|
||||||
|
Lower bound for survival training.
|
||||||
|
label_upper_bound : array_like
|
||||||
|
Upper bound for survival training.
|
||||||
|
feature_weights : array_like, optional
|
||||||
|
Set feature weights for column sampling.
|
||||||
enable_categorical: boolean, optional
|
enable_categorical: boolean, optional
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
@ -469,7 +496,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if isinstance(data, list):
|
if isinstance(data, list):
|
||||||
raise TypeError('Input data can not be a list.')
|
raise TypeError("Input data can not be a list.")
|
||||||
|
if group is not None and qid is not None:
|
||||||
|
raise ValueError("Either one of `group` or `qid` should be None.")
|
||||||
|
|
||||||
self.missing = missing if missing is not None else np.nan
|
self.missing = missing if missing is not None else np.nan
|
||||||
self.nthread = nthread if nthread is not None else -1
|
self.nthread = nthread if nthread is not None else -1
|
||||||
@ -481,16 +510,28 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
return
|
return
|
||||||
|
|
||||||
from .data import dispatch_data_backend
|
from .data import dispatch_data_backend
|
||||||
|
|
||||||
handle, feature_names, feature_types = dispatch_data_backend(
|
handle, feature_names, feature_types = dispatch_data_backend(
|
||||||
data, missing=self.missing,
|
data,
|
||||||
|
missing=self.missing,
|
||||||
threads=self.nthread,
|
threads=self.nthread,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
enable_categorical=enable_categorical)
|
enable_categorical=enable_categorical,
|
||||||
|
)
|
||||||
assert handle is not None
|
assert handle is not None
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
|
||||||
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
self.set_info(
|
||||||
|
label=label,
|
||||||
|
weight=weight,
|
||||||
|
base_margin=base_margin,
|
||||||
|
group=group,
|
||||||
|
qid=qid,
|
||||||
|
label_lower_bound=label_lower_bound,
|
||||||
|
label_upper_bound=label_upper_bound,
|
||||||
|
feature_weights=feature_weights,
|
||||||
|
)
|
||||||
|
|
||||||
if feature_names is not None:
|
if feature_names is not None:
|
||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
@ -503,17 +544,23 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def set_info(self, *,
|
def set_info(
|
||||||
label=None, weight=None, base_margin=None,
|
self,
|
||||||
|
*,
|
||||||
|
label=None,
|
||||||
|
weight=None,
|
||||||
|
base_margin=None,
|
||||||
group=None,
|
group=None,
|
||||||
qid=None,
|
qid=None,
|
||||||
label_lower_bound=None,
|
label_lower_bound=None,
|
||||||
label_upper_bound=None,
|
label_upper_bound=None,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
feature_weights=None):
|
feature_weights=None
|
||||||
'''Set meta info for DMatrix.'''
|
) -> None:
|
||||||
|
"""Set meta info for DMatrix. See doc string for DMatrix constructor."""
|
||||||
from .data import dispatch_meta_backend
|
from .data import dispatch_meta_backend
|
||||||
|
|
||||||
if label is not None:
|
if label is not None:
|
||||||
self.set_label(label)
|
self.set_label(label)
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
@ -918,39 +965,67 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
information may be lost in quantisation. This DMatrix is primarily designed
|
information may be lost in quantisation. This DMatrix is primarily designed
|
||||||
to save memory in training from device memory inputs by avoiding
|
to save memory in training from device memory inputs by avoiding
|
||||||
intermediate storage. Set max_bin to control the number of bins during
|
intermediate storage. Set max_bin to control the number of bins during
|
||||||
quantisation.
|
quantisation. See doc string in `DMatrix` for documents on meta info.
|
||||||
|
|
||||||
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
|
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
|
||||||
|
|
||||||
.. versionadded:: 1.1.0
|
.. versionadded:: 1.1.0
|
||||||
"""
|
"""
|
||||||
|
@_deprecate_positional_args
|
||||||
def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
|
def __init__( # pylint: disable=super-init-not-called
|
||||||
|
self,
|
||||||
|
data,
|
||||||
|
label=None,
|
||||||
|
*,
|
||||||
|
weight=None,
|
||||||
base_margin=None,
|
base_margin=None,
|
||||||
missing=None,
|
missing=None,
|
||||||
silent=False,
|
silent=False,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
nthread=None, max_bin=256):
|
nthread: Optional[int] = None,
|
||||||
|
max_bin: int = 256,
|
||||||
|
group=None,
|
||||||
|
qid=None,
|
||||||
|
label_lower_bound=None,
|
||||||
|
label_upper_bound=None,
|
||||||
|
feature_weights=None,
|
||||||
|
enable_categorical: bool = False,
|
||||||
|
):
|
||||||
self.max_bin = max_bin
|
self.max_bin = max_bin
|
||||||
self.missing = missing if missing is not None else np.nan
|
self.missing = missing if missing is not None else np.nan
|
||||||
self.nthread = nthread if nthread is not None else 1
|
self.nthread = nthread if nthread is not None else 1
|
||||||
|
self._silent = silent # unused, kept for compatibility
|
||||||
|
|
||||||
if isinstance(data, ctypes.c_void_p):
|
if isinstance(data, ctypes.c_void_p):
|
||||||
self.handle = data
|
self.handle = data
|
||||||
return
|
return
|
||||||
from .data import init_device_quantile_dmatrix
|
from .data import init_device_quantile_dmatrix
|
||||||
handle, feature_names, feature_types = init_device_quantile_dmatrix(
|
handle, feature_names, feature_types = init_device_quantile_dmatrix(
|
||||||
data, missing=self.missing, threads=self.nthread,
|
data,
|
||||||
max_bin=self.max_bin,
|
|
||||||
label=label, weight=weight,
|
label=label, weight=weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
group=None,
|
group=group,
|
||||||
label_lower_bound=None,
|
qid=qid,
|
||||||
label_upper_bound=None,
|
missing=self.missing,
|
||||||
|
label_lower_bound=label_lower_bound,
|
||||||
|
label_upper_bound=label_upper_bound,
|
||||||
|
feature_weights=feature_weights,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types)
|
feature_types=feature_types,
|
||||||
|
threads=self.nthread,
|
||||||
|
max_bin=self.max_bin,
|
||||||
|
)
|
||||||
|
if enable_categorical:
|
||||||
|
raise NotImplementedError(
|
||||||
|
'categorical support is not enabled on DeviceQuantileDMatrix.'
|
||||||
|
)
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
|
if qid is not None and group is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'Only one of the eval_qid or eval_group for each evaluation '
|
||||||
|
'dataset should be provided.'
|
||||||
|
)
|
||||||
|
|
||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
|
|||||||
@ -38,8 +38,9 @@ from .core import Objective, Metric
|
|||||||
from .core import _deprecate_positional_args
|
from .core import _deprecate_positional_args
|
||||||
from .training import train as worker_train
|
from .training import train as worker_train
|
||||||
from .tracker import RabitTracker, get_host_ip
|
from .tracker import RabitTracker, get_host_ip
|
||||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
|
||||||
from .sklearn import xgboost_model_doc, _objective_decorator
|
from .sklearn import XGBRankerMixIn
|
||||||
|
from .sklearn import xgboost_model_doc
|
||||||
from .sklearn import _cls_predict_proba
|
from .sklearn import _cls_predict_proba
|
||||||
from .sklearn import XGBRanker
|
from .sklearn import XGBRanker
|
||||||
|
|
||||||
@ -180,10 +181,12 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
|
|||||||
|
|
||||||
class DaskDMatrix:
|
class DaskDMatrix:
|
||||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||||
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing
|
'''DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
|
||||||
a `DaskDMatrix` forces all lazy computation to be carried out. Wait for
|
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
|
||||||
the input data explicitly if you want to see actual computation of
|
explicitly if you want to see actual computation of constructing `DaskDMatrix`.
|
||||||
constructing `DaskDMatrix`.
|
|
||||||
|
See doc string for DMatrix constructor for other parameters. DaskDMatrix accepts only
|
||||||
|
dask collection.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@ -197,29 +200,6 @@ class DaskDMatrix:
|
|||||||
client :
|
client :
|
||||||
Specify the dask client used for training. Use default client returned from dask
|
Specify the dask client used for training. Use default client returned from dask
|
||||||
if it's set to None.
|
if it's set to None.
|
||||||
data :
|
|
||||||
data source of DMatrix.
|
|
||||||
label :
|
|
||||||
label used for trainin.
|
|
||||||
missing :
|
|
||||||
Value in the input data (e.g. `numpy.ndarray`) which needs to be present as a
|
|
||||||
missing value. If None, defaults to np.nan.
|
|
||||||
weight :
|
|
||||||
Weight for each instance.
|
|
||||||
base_margin :
|
|
||||||
Global bias for each instance.
|
|
||||||
qid :
|
|
||||||
Query ID for ranking.
|
|
||||||
label_lower_bound :
|
|
||||||
Upper bound for survival training.
|
|
||||||
label_upper_bound :
|
|
||||||
Lower bound for survival training.
|
|
||||||
feature_weights :
|
|
||||||
Weight for features used in column sampling.
|
|
||||||
feature_names :
|
|
||||||
Set names for features.
|
|
||||||
feature_types :
|
|
||||||
Set types for features
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@ -230,15 +210,18 @@ class DaskDMatrix:
|
|||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
label: Optional[_DaskCollection] = None,
|
label: Optional[_DaskCollection] = None,
|
||||||
*,
|
*,
|
||||||
missing: float = None,
|
|
||||||
weight: Optional[_DaskCollection] = None,
|
weight: Optional[_DaskCollection] = None,
|
||||||
base_margin: Optional[_DaskCollection] = None,
|
base_margin: Optional[_DaskCollection] = None,
|
||||||
|
missing: float = None,
|
||||||
|
silent: bool = False, # pylint: disable=unused-argument
|
||||||
|
feature_names: Optional[Union[str, List[str]]] = None,
|
||||||
|
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||||
|
group: Optional[_DaskCollection] = None,
|
||||||
qid: Optional[_DaskCollection] = None,
|
qid: Optional[_DaskCollection] = None,
|
||||||
label_lower_bound: Optional[_DaskCollection] = None,
|
label_lower_bound: Optional[_DaskCollection] = None,
|
||||||
label_upper_bound: Optional[_DaskCollection] = None,
|
label_upper_bound: Optional[_DaskCollection] = None,
|
||||||
feature_weights: Optional[_DaskCollection] = None,
|
feature_weights: Optional[_DaskCollection] = None,
|
||||||
feature_names: Optional[Union[str, List[str]]] = None,
|
enable_categorical: bool = False
|
||||||
feature_types: Optional[Union[Any, List[Any]]] = None
|
|
||||||
) -> None:
|
) -> None:
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
client = _xgb_get_client(client)
|
client = _xgb_get_client(client)
|
||||||
@ -248,30 +231,41 @@ class DaskDMatrix:
|
|||||||
self.missing = missing
|
self.missing = missing
|
||||||
|
|
||||||
if qid is not None and weight is not None:
|
if qid is not None and weight is not None:
|
||||||
raise NotImplementedError('per-group weight is not implemented.')
|
raise NotImplementedError("per-group weight is not implemented.")
|
||||||
|
if group is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"group structure is not implemented, use qid instead."
|
||||||
|
)
|
||||||
|
if enable_categorical:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"categorical support is not enabled on `DaskDMatrix`."
|
||||||
|
)
|
||||||
|
|
||||||
if len(data.shape) != 2:
|
if len(data.shape) != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Expecting 2 dimensional input, got: {shape}'.format(
|
"Expecting 2 dimensional input, got: {shape}".format(shape=data.shape)
|
||||||
shape=data.shape))
|
)
|
||||||
|
|
||||||
if not isinstance(data, (dd.DataFrame, da.Array)):
|
if not isinstance(data, (dd.DataFrame, da.Array)):
|
||||||
raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
|
raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
|
||||||
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series,
|
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
|
||||||
type(None))):
|
raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
|
||||||
raise TypeError(
|
|
||||||
_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
|
|
||||||
|
|
||||||
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
|
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
|
||||||
self.is_quantile: bool = False
|
self.is_quantile: bool = False
|
||||||
|
|
||||||
self._init = client.sync(self.map_local_data,
|
self._init = client.sync(
|
||||||
client, data, label=label, weights=weight,
|
self.map_local_data,
|
||||||
|
client,
|
||||||
|
data,
|
||||||
|
label=label,
|
||||||
|
weights=weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
qid=qid,
|
qid=qid,
|
||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
label_lower_bound=label_lower_bound,
|
label_lower_bound=label_lower_bound,
|
||||||
label_upper_bound=label_upper_bound)
|
label_upper_bound=label_upper_bound,
|
||||||
|
)
|
||||||
|
|
||||||
def __await__(self) -> Generator:
|
def __await__(self) -> Generator:
|
||||||
return self._init.__await__()
|
return self._init.__await__()
|
||||||
@ -571,11 +565,11 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
|||||||
|
|
||||||
|
|
||||||
class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||||
'''Specialized data type for `gpu_hist` tree method. This class is used to
|
'''Specialized data type for `gpu_hist` tree method. This class is used to reduce the
|
||||||
reduce the memory usage by eliminating data copies. Internally the all
|
memory usage by eliminating data copies. Internally the all partitions/chunks of data
|
||||||
partitions/chunks of data are merged by weighted GK sketching. So the
|
are merged by weighted GK sketching. So the number of partitions from dask may affect
|
||||||
number of partitions from dask may affect training accuracy as GK generates
|
training accuracy as GK generates bounded error for each merge. See doc string for
|
||||||
bounded error for each merge.
|
`DeviceQuantileDMatrix` and `DMatrix` for other parameters.
|
||||||
|
|
||||||
.. versionadded:: 1.2.0
|
.. versionadded:: 1.2.0
|
||||||
|
|
||||||
@ -584,42 +578,50 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
|||||||
max_bin : Number of bins for histogram construction.
|
max_bin : Number of bins for histogram construction.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
@_deprecate_positional_args
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
label: Optional[_DaskCollection] = None,
|
label: Optional[_DaskCollection] = None,
|
||||||
missing: float = None,
|
*,
|
||||||
weight: Optional[_DaskCollection] = None,
|
weight: Optional[_DaskCollection] = None,
|
||||||
base_margin: Optional[_DaskCollection] = None,
|
base_margin: Optional[_DaskCollection] = None,
|
||||||
|
missing: float = None,
|
||||||
|
silent: bool = False,
|
||||||
|
feature_names: Optional[Union[str, List[str]]] = None,
|
||||||
|
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||||
|
max_bin: int = 256,
|
||||||
|
group: Optional[_DaskCollection] = None,
|
||||||
qid: Optional[_DaskCollection] = None,
|
qid: Optional[_DaskCollection] = None,
|
||||||
label_lower_bound: Optional[_DaskCollection] = None,
|
label_lower_bound: Optional[_DaskCollection] = None,
|
||||||
label_upper_bound: Optional[_DaskCollection] = None,
|
label_upper_bound: Optional[_DaskCollection] = None,
|
||||||
feature_weights: Optional[_DaskCollection] = None,
|
feature_weights: Optional[_DaskCollection] = None,
|
||||||
feature_names: Optional[Union[str, List[str]]] = None,
|
enable_categorical: bool = False,
|
||||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
|
||||||
max_bin: int = 256
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
client=client,
|
client=client,
|
||||||
data=data,
|
data=data,
|
||||||
label=label,
|
label=label,
|
||||||
missing=missing,
|
|
||||||
feature_weights=feature_weights,
|
|
||||||
weight=weight,
|
weight=weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
|
group=group,
|
||||||
qid=qid,
|
qid=qid,
|
||||||
label_lower_bound=label_lower_bound,
|
label_lower_bound=label_lower_bound,
|
||||||
label_upper_bound=label_upper_bound,
|
label_upper_bound=label_upper_bound,
|
||||||
|
missing=missing,
|
||||||
|
silent=silent,
|
||||||
|
feature_weights=feature_weights,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types
|
feature_types=feature_types,
|
||||||
|
enable_categorical=enable_categorical,
|
||||||
)
|
)
|
||||||
self.max_bin = max_bin
|
self.max_bin = max_bin
|
||||||
self.is_quantile = True
|
self.is_quantile = True
|
||||||
|
|
||||||
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
def create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
|
||||||
args = super().create_fn_args(worker_addr)
|
args = super().create_fn_args(worker_addr)
|
||||||
args['max_bin'] = self.max_bin
|
args["max_bin"] = self.max_bin
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
@ -630,35 +632,49 @@ def _create_device_quantile_dmatrix(
|
|||||||
meta_names: List[str],
|
meta_names: List[str],
|
||||||
missing: float,
|
missing: float,
|
||||||
parts: Optional[_DataParts],
|
parts: Optional[_DataParts],
|
||||||
max_bin: int
|
max_bin: int,
|
||||||
) -> DeviceQuantileDMatrix:
|
) -> DeviceQuantileDMatrix:
|
||||||
worker = distributed.get_worker()
|
worker = distributed.get_worker()
|
||||||
if parts is None:
|
if parts is None:
|
||||||
msg = 'worker {address} has an empty DMatrix. '.format(
|
msg = "worker {address} has an empty DMatrix.".format(address=worker.address)
|
||||||
address=worker.address)
|
|
||||||
LOGGER.warning(msg)
|
LOGGER.warning(msg)
|
||||||
import cupy
|
import cupy
|
||||||
d = DeviceQuantileDMatrix(cupy.zeros((0, 0)),
|
|
||||||
|
d = DeviceQuantileDMatrix(
|
||||||
|
cupy.zeros((0, 0)),
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
max_bin=max_bin)
|
max_bin=max_bin,
|
||||||
|
)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
(data, labels, weights, base_margin, qid,
|
(
|
||||||
label_lower_bound, label_upper_bound) = _get_worker_parts(
|
data,
|
||||||
parts, meta_names)
|
labels,
|
||||||
it = DaskPartitionIter(data=data, label=labels, weight=weights,
|
weights,
|
||||||
|
base_margin,
|
||||||
|
qid,
|
||||||
|
label_lower_bound,
|
||||||
|
label_upper_bound,
|
||||||
|
) = _get_worker_parts(parts, meta_names)
|
||||||
|
it = DaskPartitionIter(
|
||||||
|
data=data,
|
||||||
|
label=labels,
|
||||||
|
weight=weights,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
qid=qid,
|
qid=qid,
|
||||||
label_lower_bound=label_lower_bound,
|
label_lower_bound=label_lower_bound,
|
||||||
label_upper_bound=label_upper_bound)
|
label_upper_bound=label_upper_bound,
|
||||||
|
)
|
||||||
|
|
||||||
dmatrix = DeviceQuantileDMatrix(it,
|
dmatrix = DeviceQuantileDMatrix(
|
||||||
|
it,
|
||||||
missing=missing,
|
missing=missing,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
nthread=worker.nthreads,
|
nthread=worker.nthreads,
|
||||||
max_bin=max_bin)
|
max_bin=max_bin,
|
||||||
|
)
|
||||||
dmatrix.set_info(feature_weights=feature_weights)
|
dmatrix.set_info(feature_weights=feature_weights)
|
||||||
return dmatrix
|
return dmatrix
|
||||||
|
|
||||||
@ -712,13 +728,15 @@ def _create_dmatrix(
|
|||||||
missing=missing,
|
missing=missing,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
nthread=worker.nthreads
|
nthread=worker.nthreads,
|
||||||
)
|
)
|
||||||
dmatrix.set_info(
|
dmatrix.set_info(
|
||||||
base_margin=_base_margin, qid=_qid, weight=_weights,
|
base_margin=_base_margin,
|
||||||
|
qid=_qid,
|
||||||
|
weight=_weights,
|
||||||
label_lower_bound=_label_lower_bound,
|
label_lower_bound=_label_lower_bound,
|
||||||
label_upper_bound=_label_upper_bound,
|
label_upper_bound=_label_upper_bound,
|
||||||
feature_weights=feature_weights
|
feature_weights=feature_weights,
|
||||||
)
|
)
|
||||||
return dmatrix
|
return dmatrix
|
||||||
|
|
||||||
@ -753,6 +771,8 @@ def _get_workers_from_data(
|
|||||||
for e in evals:
|
for e in evals:
|
||||||
assert len(e) == 2
|
assert len(e) == 2
|
||||||
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
|
assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
|
||||||
|
if e[0] is dtrain:
|
||||||
|
continue
|
||||||
worker_map = set(e[0].worker_map.keys())
|
worker_map = set(e[0].worker_map.keys())
|
||||||
X_worker_map = X_worker_map.union(worker_map)
|
X_worker_map = X_worker_map.union(worker_map)
|
||||||
return X_worker_map
|
return X_worker_map
|
||||||
@ -960,7 +980,7 @@ async def _predict_async(
|
|||||||
worker = distributed.get_worker()
|
worker = distributed.get_worker()
|
||||||
with config.config_context(**global_config):
|
with config.config_context(**global_config):
|
||||||
booster.set_param({'nthread': worker.nthreads})
|
booster.set_param({'nthread': worker.nthreads})
|
||||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
|
||||||
predt = booster.predict(
|
predt = booster.predict(
|
||||||
data=m,
|
data=m,
|
||||||
output_margin=output_margin,
|
output_margin=output_margin,
|
||||||
@ -1587,7 +1607,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
For dask implementation, group is not supported, use qid instead.
|
For dask implementation, group is not supported, use qid instead.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
class DaskXGBRanker(DaskScikitLearnBase):
|
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
|
def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
|
||||||
if callable(objective):
|
if callable(objective):
|
||||||
@ -1632,11 +1652,10 @@ class DaskXGBRanker(DaskScikitLearnBase):
|
|||||||
if eval_metric is not None:
|
if eval_metric is not None:
|
||||||
if callable(eval_metric):
|
if callable(eval_metric):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Custom evaluation metric is not yet supported for XGBRanker.')
|
"Custom evaluation metric is not yet supported for XGBRanker."
|
||||||
|
)
|
||||||
model, metric, params = self._configure_fit(
|
model, metric, params = self._configure_fit(
|
||||||
booster=xgb_model,
|
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||||
eval_metric=eval_metric,
|
|
||||||
params=params
|
|
||||||
)
|
)
|
||||||
results = await train(
|
results = await train(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
|
|||||||
@ -737,16 +737,28 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
|||||||
area for meta info.
|
area for meta info.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(self, data, label, weight, base_margin, group,
|
def __init__(
|
||||||
label_lower_bound, label_upper_bound,
|
self, data,
|
||||||
feature_names, feature_types):
|
label,
|
||||||
|
weight,
|
||||||
|
base_margin,
|
||||||
|
group,
|
||||||
|
qid,
|
||||||
|
label_lower_bound,
|
||||||
|
label_upper_bound,
|
||||||
|
feature_weights,
|
||||||
|
feature_names,
|
||||||
|
feature_types
|
||||||
|
):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.label = label
|
self.label = label
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.base_margin = base_margin
|
self.base_margin = base_margin
|
||||||
self.group = group
|
self.group = group
|
||||||
|
self.qid = qid
|
||||||
self.label_lower_bound = label_lower_bound
|
self.label_lower_bound = label_lower_bound
|
||||||
self.label_upper_bound = label_upper_bound
|
self.label_upper_bound = label_upper_bound
|
||||||
|
self.feature_weights = feature_weights
|
||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
self.it = 0 # pylint: disable=invalid-name
|
self.it = 0 # pylint: disable=invalid-name
|
||||||
@ -759,8 +771,10 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
|||||||
input_data(data=self.data, label=self.label,
|
input_data(data=self.data, label=self.label,
|
||||||
weight=self.weight, base_margin=self.base_margin,
|
weight=self.weight, base_margin=self.base_margin,
|
||||||
group=self.group,
|
group=self.group,
|
||||||
|
qid=self.qid,
|
||||||
label_lower_bound=self.label_lower_bound,
|
label_lower_bound=self.label_lower_bound,
|
||||||
label_upper_bound=self.label_upper_bound,
|
label_upper_bound=self.label_upper_bound,
|
||||||
|
feature_weights=self.feature_weights,
|
||||||
feature_names=self.feature_names,
|
feature_names=self.feature_names,
|
||||||
feature_types=self.feature_types)
|
feature_types=self.feature_types)
|
||||||
return 1
|
return 1
|
||||||
@ -770,7 +784,8 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
|||||||
|
|
||||||
|
|
||||||
def init_device_quantile_dmatrix(
|
def init_device_quantile_dmatrix(
|
||||||
data, missing, max_bin, threads, feature_names, feature_types, **meta):
|
data, missing, max_bin, threads, feature_names, feature_types, **meta
|
||||||
|
):
|
||||||
'''Constructor for DeviceQuantileDMatrix.'''
|
'''Constructor for DeviceQuantileDMatrix.'''
|
||||||
if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
|
if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
|
||||||
_is_dlpack(data), _is_iter(data)]):
|
_is_dlpack(data), _is_iter(data)]):
|
||||||
|
|||||||
@ -556,7 +556,7 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def _configure_fit(
|
def _configure_fit(
|
||||||
self,
|
self,
|
||||||
booster: Optional[Booster],
|
booster: Optional[Union[Booster, "XGBModel"]],
|
||||||
eval_metric: Optional[Union[Callable, str, List[str]]],
|
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||||
params: Dict[str, Any],
|
params: Dict[str, Any],
|
||||||
) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
|
) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
|
||||||
@ -631,7 +631,7 @@ class XGBModel(XGBModelBase):
|
|||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
metric measured on the validation set to stderr.
|
metric measured on the validation set to stderr.
|
||||||
xgb_model : str
|
xgb_model : Union[str, Booster, XGBModel]
|
||||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||||
loaded before training (allows training continuation).
|
loaded before training (allows training continuation).
|
||||||
sample_weight_eval_set : list, optional
|
sample_weight_eval_set : list, optional
|
||||||
@ -942,10 +942,22 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
super().__init__(objective=objective, **kwargs)
|
super().__init__(objective=objective, **kwargs)
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
def fit(
|
||||||
eval_set=None, eval_metric=None,
|
self,
|
||||||
early_stopping_rounds=None, verbose=True, xgb_model=None,
|
X,
|
||||||
sample_weight_eval_set=None, feature_weights=None, callbacks=None):
|
y,
|
||||||
|
*,
|
||||||
|
sample_weight=None,
|
||||||
|
base_margin=None,
|
||||||
|
eval_set=None,
|
||||||
|
eval_metric=None,
|
||||||
|
early_stopping_rounds=None,
|
||||||
|
verbose=True,
|
||||||
|
xgb_model=None,
|
||||||
|
sample_weight_eval_set=None,
|
||||||
|
feature_weights=None,
|
||||||
|
callbacks=None
|
||||||
|
):
|
||||||
# pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements
|
# pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements
|
||||||
|
|
||||||
can_use_label_encoder = True
|
can_use_label_encoder = True
|
||||||
@ -1283,7 +1295,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self, X, y, *,
|
self,
|
||||||
|
X,
|
||||||
|
y,
|
||||||
|
*,
|
||||||
group=None,
|
group=None,
|
||||||
qid=None,
|
qid=None,
|
||||||
sample_weight=None,
|
sample_weight=None,
|
||||||
@ -1372,7 +1387,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation
|
||||||
metric measured on the validation set to stderr.
|
metric measured on the validation set to stderr.
|
||||||
xgb_model : str
|
xgb_model : Union[str, Booster, XGBModel]
|
||||||
file name of stored XGBoost model or 'Booster' instance XGBoost
|
file name of stored XGBoost model or 'Booster' instance XGBoost
|
||||||
model to be loaded before training (allows training continuation).
|
model to be loaded before training (allows training continuation).
|
||||||
feature_weights: array_like
|
feature_weights: array_like
|
||||||
@ -1391,9 +1406,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
save_best=True)]
|
save_best=True)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# check if group information is provided
|
if group is None and qid is None:
|
||||||
if group is None:
|
raise ValueError("group or qid is required for ranking task")
|
||||||
raise ValueError("group is required for ranking task")
|
|
||||||
|
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
if eval_group is None and eval_qid is None:
|
if eval_group is None and eval_qid is None:
|
||||||
|
|||||||
@ -34,3 +34,25 @@ class TestDeviceQuantileDMatrix:
|
|||||||
import cupy as cp
|
import cupy as cp
|
||||||
data = cp.random.randn(5, 5)
|
data = cp.random.randn(5, 5)
|
||||||
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_cupy())
def test_metainfo(self) -> None:
    """Check that labels and feature weights round-trip through DeviceQuantileDMatrix.

    Builds a small random cupy matrix, attaches labels and feature weights at
    construction time, and verifies that the values read back from the
    matrix match what was passed in.
    """
    import cupy as cp

    rng = cp.random.RandomState(1994)

    rows = 10
    cols = 3
    data = rng.randn(rows, cols)
    labels = rng.randn(rows)

    # Feature weights describe columns, so there is one weight per feature
    # (not one per row).  Shift them so the smallest weight is zero and none
    # of them is negative.
    fw = rng.randn(cols)
    fw -= fw.min()

    m = xgb.DeviceQuantileDMatrix(data=data, label=labels, feature_weights=fw)

    got_fw = m.get_float_info("feature_weights")
    got_labels = m.get_label()

    cp.testing.assert_allclose(fw, got_fw)
    cp.testing.assert_allclose(labels, got_labels)
|
||||||
|
|||||||
@ -6,7 +6,9 @@ import numpy as np
|
|||||||
import asyncio
|
import asyncio
|
||||||
import xgboost
|
import xgboost
|
||||||
import subprocess
|
import subprocess
|
||||||
from hypothesis import given, strategies, settings, note, HealthCheck
|
from collections import OrderedDict
|
||||||
|
from inspect import signature
|
||||||
|
from hypothesis import given, strategies, settings, note
|
||||||
from hypothesis._settings import duration
|
from hypothesis._settings import duration
|
||||||
from test_gpu_updaters import parameter_strategy
|
from test_gpu_updaters import parameter_strategy
|
||||||
|
|
||||||
@ -18,13 +20,15 @@ from test_with_dask import run_empty_dmatrix_reg # noqa
|
|||||||
from test_with_dask import run_empty_dmatrix_cls # noqa
|
from test_with_dask import run_empty_dmatrix_cls # noqa
|
||||||
from test_with_dask import _get_client_workers # noqa
|
from test_with_dask import _get_client_workers # noqa
|
||||||
from test_with_dask import generate_array # noqa
|
from test_with_dask import generate_array # noqa
|
||||||
from test_with_dask import suppress
|
from test_with_dask import kCols as random_cols # noqa
|
||||||
|
from test_with_dask import suppress # noqa
|
||||||
import testing as tm # noqa
|
import testing as tm # noqa
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import dask.dataframe as dd
|
import dask.dataframe as dd
|
||||||
from xgboost import dask as dxgb
|
from xgboost import dask as dxgb
|
||||||
|
import xgboost as xgb
|
||||||
from dask.distributed import Client
|
from dask.distributed import Client
|
||||||
from dask import array as da
|
from dask import array as da
|
||||||
from dask_cuda import LocalCUDACluster
|
from dask_cuda import LocalCUDACluster
|
||||||
@ -252,6 +256,64 @@ class TestDistributedGPU:
|
|||||||
run_empty_dmatrix_reg(client, parameters)
|
run_empty_dmatrix_reg(client, parameters)
|
||||||
run_empty_dmatrix_cls(client, parameters)
|
run_empty_dmatrix_cls(client, parameters)
|
||||||
|
|
||||||
|
def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
    """Check feature weights survive DaskDMatrix construction on each worker.

    A DaskDMatrix is created with feature weights, then one task is pinned to
    every worker; each task rebuilds its local DMatrix and asserts that the
    number of feature weights equals the number of columns.
    """
    with Client(local_cuda_cluster) as client:
        X, y, _ = generate_array()
        feature_weights = da.random.random((random_cols, ))
        feature_weights = feature_weights - feature_weights.min()
        dtrain = dxgb.DaskDMatrix(client, X, y, feature_weights=feature_weights)

        worker_addrs = list(_get_client_workers(client).keys())
        rabit_args = client.sync(dxgb._get_rabit_args, len(worker_addrs), client)

        def worker_fn(worker_addr: str, data_ref: Dict) -> None:
            # Runs on a worker: rebuild the local DMatrix from its parts and
            # compare the feature weight count with the column count.
            with dxgb.RabitContext(rabit_args):
                local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref)
                n_weights = local_dtrain.get_float_info("feature_weights").shape[0]
                assert n_weights == local_dtrain.num_col()

        # pure=False so that identical-looking submissions are not deduplicated;
        # workers=[addr] pins each task to the worker whose data it inspects.
        futures = [
            client.submit(
                worker_fn,
                addr,
                dtrain.create_fn_args(addr),
                pure=False,
                workers=[addr],
            )
            for addr in worker_addrs
        ]
        client.gather(futures)
|
||||||
|
|
||||||
|
def test_interface_consistency(self) -> None:
    """Check that the DMatrix constructors share parameter names and order.

    Parameters that exist only on some classes (``client``, ``max_bin``,
    ``nthread``) are removed before the comparison.
    """

    def param_names(cls, *excluded):
        # Constructor parameter names of ``cls`` in declaration order, minus
        # ``excluded``; ``del`` raises KeyError if an excluded name is absent.
        params = OrderedDict(signature(cls).parameters)
        for name in excluded:
            del params[name]
        return list(params.keys())

    ddm_names = param_names(dxgb.DaskDMatrix, "client")
    ddqdm_names = param_names(dxgb.DaskDeviceQuantileDMatrix, "client", "max_bin")
    assert len(ddm_names) == len(ddqdm_names)

    # between dask
    for lhs, rhs in zip(ddm_names, ddqdm_names):
        assert lhs == rhs

    dm_names = param_names(xgb.DMatrix, "nthread")  # no nthread in dask
    dqdm_names = param_names(xgb.DeviceQuantileDMatrix, "nthread", "max_bin")

    # between single node
    assert len(dm_names) == len(dqdm_names)
    for lhs, rhs in zip(dm_names, dqdm_names):
        assert lhs == rhs

    # ddm <-> dm; indexed on purpose: lengths are not asserted equal here, so
    # only the leading len(ddm_names) entries are compared.
    for i, name in enumerate(ddm_names):
        assert name == dm_names[i]

    # dqdm <-> ddqdm
    for i, name in enumerate(ddqdm_names):
        assert name == dqdm_names[i]
|
||||||
|
|
||||||
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
|
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows")
|
pytest.skip("Skipping dask tests on Windows")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user