[dask] Add DaskXGBRanker (#6576)
* Initial support for distributed LTR using dask. * Support `qid` in libxgboost. * Refactor `predict` and `n_features_in_`, `best_[score/iteration/ntree_limit]` to avoid duplicated code. * Define `DaskXGBRanker`. The dask ranker doesn't support group structure, instead it uses query id and convert to group ptr internally.
This commit is contained in:
parent
96d3d32265
commit
80065d571e
@ -165,7 +165,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
* \brief Get the number of features of the booster.
|
||||
* \return number of features
|
||||
*/
|
||||
virtual uint32_t GetNumFeature() = 0;
|
||||
virtual uint32_t GetNumFeature() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Set additional attribute to the Booster.
|
||||
|
||||
@ -321,6 +321,7 @@ class DataIter:
|
||||
|
||||
def data_handle(data, label=None, weight=None, base_margin=None,
|
||||
group=None,
|
||||
qid=None,
|
||||
label_lower_bound=None, label_upper_bound=None,
|
||||
feature_names=None, feature_types=None,
|
||||
feature_weights=None):
|
||||
@ -333,6 +334,7 @@ class DataIter:
|
||||
self.proxy.set_info(label=label, weight=weight,
|
||||
base_margin=base_margin,
|
||||
group=group,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_names=feature_names,
|
||||
@ -523,12 +525,14 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
def set_info(self, *,
|
||||
label=None, weight=None, base_margin=None,
|
||||
group=None,
|
||||
qid=None,
|
||||
label_lower_bound=None,
|
||||
label_upper_bound=None,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
feature_weights=None):
|
||||
'''Set meta info for DMatrix.'''
|
||||
from .data import dispatch_meta_backend
|
||||
if label is not None:
|
||||
self.set_label(label)
|
||||
if weight is not None:
|
||||
@ -537,6 +541,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
self.set_base_margin(base_margin)
|
||||
if group is not None:
|
||||
self.set_group(group)
|
||||
if qid is not None:
|
||||
dispatch_meta_backend(matrix=self, data=qid, name='qid')
|
||||
if label_lower_bound is not None:
|
||||
self.set_float_info('label_lower_bound', label_lower_bound)
|
||||
if label_upper_bound is not None:
|
||||
@ -546,7 +552,6 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
if feature_weights is not None:
|
||||
from .data import dispatch_meta_backend
|
||||
dispatch_meta_backend(matrix=self, data=feature_weights,
|
||||
name='feature_weights')
|
||||
|
||||
@ -993,7 +998,7 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
|
||||
|
||||
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||
Metric = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
|
||||
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
|
||||
|
||||
|
||||
class Booster(object):
|
||||
@ -1743,10 +1748,16 @@ class Booster(object):
|
||||
'''
|
||||
rounds = ctypes.c_int()
|
||||
assert self.handle is not None
|
||||
_check_call(_LIB.XGBoosterBoostedRounds(
|
||||
self.handle, ctypes.byref(rounds)))
|
||||
_check_call(_LIB.XGBoosterBoostedRounds(self.handle, ctypes.byref(rounds)))
|
||||
return rounds.value
|
||||
|
||||
def num_features(self) -> int:
|
||||
'''Number of features in booster.'''
|
||||
features = ctypes.c_int()
|
||||
assert self.handle is not None
|
||||
_check_call(_LIB.XGBoosterGetNumFeature(self.handle, ctypes.byref(features)))
|
||||
return features.value
|
||||
|
||||
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
|
||||
"""Dump model into a text or JSON file. Unlike `save_model`, the
|
||||
output format is primarily used for visualization or interpretation,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
|
||||
# pylint: disable=missing-class-docstring, invalid-name
|
||||
# pylint: disable=too-many-lines
|
||||
# pylint: disable=too-many-lines, fixme
|
||||
# pylint: disable=import-error
|
||||
"""Dask extensions for distributed training. See
|
||||
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
|
||||
@ -41,6 +41,7 @@ from .tracker import RabitTracker, get_host_ip
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
|
||||
from .sklearn import xgboost_model_doc
|
||||
from .sklearn import _cls_predict_proba
|
||||
from .sklearn import XGBRanker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -207,6 +208,8 @@ class DaskDMatrix:
|
||||
Weight for each instance.
|
||||
base_margin :
|
||||
Global bias for each instance.
|
||||
qid :
|
||||
Query ID for ranking.
|
||||
label_lower_bound :
|
||||
Upper bound for survival training.
|
||||
label_upper_bound :
|
||||
@ -220,14 +223,17 @@ class DaskDMatrix:
|
||||
|
||||
'''
|
||||
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
self,
|
||||
client: "distributed.Client",
|
||||
data: _DaskCollection,
|
||||
label: Optional[_DaskCollection] = None,
|
||||
*,
|
||||
missing: float = None,
|
||||
weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
label_lower_bound: Optional[_DaskCollection] = None,
|
||||
label_upper_bound: Optional[_DaskCollection] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
@ -241,6 +247,9 @@ class DaskDMatrix:
|
||||
self.feature_types = feature_types
|
||||
self.missing = missing
|
||||
|
||||
if qid is not None and weight is not None:
|
||||
raise NotImplementedError('per-group weight is not implemented.')
|
||||
|
||||
if len(data.shape) != 2:
|
||||
raise ValueError(
|
||||
'Expecting 2 dimensional input, got: {shape}'.format(
|
||||
@ -259,6 +268,7 @@ class DaskDMatrix:
|
||||
self._init = client.sync(self.map_local_data,
|
||||
client, data, label=label, weights=weight,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
feature_weights=feature_weights,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound)
|
||||
@ -273,6 +283,7 @@ class DaskDMatrix:
|
||||
label: Optional[_DaskCollection] = None,
|
||||
weights: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
label_lower_bound: Optional[_DaskCollection] = None,
|
||||
label_upper_bound: Optional[_DaskCollection] = None
|
||||
@ -325,6 +336,7 @@ class DaskDMatrix:
|
||||
y_parts = flatten_meta(label)
|
||||
w_parts = flatten_meta(weights)
|
||||
margin_parts = flatten_meta(base_margin)
|
||||
qid_parts = flatten_meta(qid)
|
||||
ll_parts = flatten_meta(label_lower_bound)
|
||||
lu_parts = flatten_meta(label_upper_bound)
|
||||
|
||||
@ -343,6 +355,7 @@ class DaskDMatrix:
|
||||
append_meta(y_parts, 'labels')
|
||||
append_meta(w_parts, 'weights')
|
||||
append_meta(margin_parts, 'base_margin')
|
||||
append_meta(qid_parts, 'qid')
|
||||
append_meta(ll_parts, 'label_lower_bound')
|
||||
append_meta(lu_parts, 'label_upper_bound')
|
||||
# At this point, `parts` looks like:
|
||||
@ -397,7 +410,7 @@ class DaskDMatrix:
|
||||
|
||||
|
||||
_DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
|
||||
Optional[Any]]]
|
||||
Optional[Any], Optional[Any]]]
|
||||
|
||||
|
||||
def _get_worker_parts_ordered(
|
||||
@ -413,6 +426,7 @@ def _get_worker_parts_ordered(
|
||||
labels = None
|
||||
weights = None
|
||||
base_margin = None
|
||||
qid = None
|
||||
label_lower_bound = None
|
||||
label_upper_bound = None
|
||||
# Iterate through all possible meta info, brings small overhead as in xgboost
|
||||
@ -424,13 +438,15 @@ def _get_worker_parts_ordered(
|
||||
weights = blob
|
||||
elif meta_names[j] == 'base_margin':
|
||||
base_margin = blob
|
||||
elif meta_names[j] == 'qid':
|
||||
qid = blob
|
||||
elif meta_names[j] == 'label_lower_bound':
|
||||
label_lower_bound = blob
|
||||
elif meta_names[j] == 'label_upper_bound':
|
||||
label_upper_bound = blob
|
||||
else:
|
||||
raise ValueError('Unknown metainfo:', meta_names[j])
|
||||
result.append((data, labels, weights, base_margin, label_lower_bound,
|
||||
result.append((data, labels, weights, base_margin, qid, label_lower_bound,
|
||||
label_upper_bound))
|
||||
return result
|
||||
|
||||
@ -456,6 +472,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
label: Optional[Tuple[Any, ...]] = None,
|
||||
weight: Optional[Tuple[Any, ...]] = None,
|
||||
base_margin: Optional[Tuple[Any, ...]] = None,
|
||||
qid: Optional[Tuple[Any, ...]] = None,
|
||||
label_lower_bound: Optional[Tuple[Any, ...]] = None,
|
||||
label_upper_bound: Optional[Tuple[Any, ...]] = None,
|
||||
feature_names: Optional[Union[str, List[str]]] = None,
|
||||
@ -465,6 +482,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
self._labels = label
|
||||
self._weights = weight
|
||||
self._base_margin = base_margin
|
||||
self._qid = qid
|
||||
self._label_lower_bound = label_lower_bound
|
||||
self._label_upper_bound = label_upper_bound
|
||||
self._feature_names = feature_names
|
||||
@ -498,6 +516,12 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
return self._weights[self._iter]
|
||||
return None
|
||||
|
||||
def qids(self) -> Any:
|
||||
'''Utility function for obtaining current batch of query id.'''
|
||||
if self._qid is not None:
|
||||
return self._qid[self._iter]
|
||||
return None
|
||||
|
||||
def base_margins(self) -> Any:
|
||||
'''Utility function for obtaining current batch of base_margin.'''
|
||||
if self._base_margin is not None:
|
||||
@ -537,6 +561,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||
feature_names = None
|
||||
input_data(data=self.data(), label=self.labels(),
|
||||
weight=self.weights(), group=None,
|
||||
qid=self.qids(),
|
||||
label_lower_bound=self.label_lower_bounds(),
|
||||
label_upper_bound=self.label_upper_bounds(),
|
||||
feature_names=feature_names,
|
||||
@ -567,6 +592,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
missing: float = None,
|
||||
weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
label_lower_bound: Optional[_DaskCollection] = None,
|
||||
label_upper_bound: Optional[_DaskCollection] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
@ -574,13 +600,20 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||
max_bin: int = 256
|
||||
) -> None:
|
||||
super().__init__(client=client, data=data, label=label,
|
||||
missing=missing,
|
||||
weight=weight, base_margin=base_margin,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types)
|
||||
super().__init__(
|
||||
client=client,
|
||||
data=data,
|
||||
label=label,
|
||||
missing=missing,
|
||||
feature_weights=feature_weights,
|
||||
weight=weight,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types
|
||||
)
|
||||
self.max_bin = max_bin
|
||||
self.is_quantile = True
|
||||
|
||||
@ -611,11 +644,12 @@ def _create_device_quantile_dmatrix(
|
||||
max_bin=max_bin)
|
||||
return d
|
||||
|
||||
(data, labels, weights, base_margin,
|
||||
(data, labels, weights, base_margin, qid,
|
||||
label_lower_bound, label_upper_bound) = _get_worker_parts(
|
||||
parts, meta_names)
|
||||
it = DaskPartitionIter(data=data, label=labels, weight=weights,
|
||||
base_margin=base_margin,
|
||||
qid=qid,
|
||||
label_lower_bound=label_lower_bound,
|
||||
label_upper_bound=label_upper_bound)
|
||||
|
||||
@ -661,26 +695,31 @@ def _create_dmatrix(
|
||||
return None
|
||||
return concat(data)
|
||||
|
||||
(data, labels, weights, base_margin,
|
||||
(data, labels, weights, base_margin, qid,
|
||||
label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
|
||||
|
||||
_labels = concat_or_none(labels)
|
||||
_weights = concat_or_none(weights)
|
||||
_base_margin = concat_or_none(base_margin)
|
||||
_qid = concat_or_none(qid)
|
||||
_label_lower_bound = concat_or_none(label_lower_bound)
|
||||
_label_upper_bound = concat_or_none(label_upper_bound)
|
||||
|
||||
_data = concat(data)
|
||||
dmatrix = DMatrix(_data,
|
||||
_labels,
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads)
|
||||
dmatrix.set_info(base_margin=_base_margin, weight=_weights,
|
||||
label_lower_bound=_label_lower_bound,
|
||||
label_upper_bound=_label_upper_bound,
|
||||
feature_weights=feature_weights)
|
||||
dmatrix = DMatrix(
|
||||
_data,
|
||||
_labels,
|
||||
missing=missing,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=worker.nthreads
|
||||
)
|
||||
dmatrix.set_info(
|
||||
base_margin=_base_margin, qid=_qid, weight=_weights,
|
||||
label_lower_bound=_label_lower_bound,
|
||||
label_upper_bound=_label_upper_bound,
|
||||
feature_weights=feature_weights
|
||||
)
|
||||
return dmatrix
|
||||
|
||||
|
||||
@ -746,7 +785,7 @@ async def _train_async(
|
||||
'''Perform training on a single worker. A local function prevents pickling.
|
||||
|
||||
'''
|
||||
LOGGER.info('Training on %s', str(worker_addr))
|
||||
LOGGER.debug('Training on %s', str(worker_addr))
|
||||
worker = distributed.get_worker()
|
||||
with RabitContext(rabit_args), config.config_context(**global_config):
|
||||
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
|
||||
@ -954,7 +993,7 @@ async def _predict_async(
|
||||
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
|
||||
) -> List[Tuple[Tuple["dask.delayed.Delayed", int], int]]:
|
||||
'''Perform prediction on each worker.'''
|
||||
LOGGER.info('Predicting on %d', worker_id)
|
||||
LOGGER.debug('Predicting on %d', worker_id)
|
||||
with config.config_context(**global_config):
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
@ -962,7 +1001,7 @@ async def _predict_async(
|
||||
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, base_margin, _, _) = parts
|
||||
(data, _, _, base_margin, _, _, _) = parts
|
||||
order = list_of_orders[i]
|
||||
local_part = DMatrix(
|
||||
data,
|
||||
@ -991,11 +1030,11 @@ async def _predict_async(
|
||||
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
|
||||
) -> List[Tuple[int, int]]:
|
||||
'''Get shape of data in each worker.'''
|
||||
LOGGER.info('Get shape on %d', worker_id)
|
||||
LOGGER.debug('Get shape on %d', worker_id)
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
shapes = []
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, _, _, _) = parts
|
||||
(data, _, _, _, _, _, _) = parts
|
||||
shapes.append((data.shape, list_of_orders[i]))
|
||||
return shapes
|
||||
|
||||
@ -1182,6 +1221,7 @@ async def _evaluation_matrices(
|
||||
client: "distributed.Client",
|
||||
validation_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
sample_weight: Optional[List[_DaskCollection]],
|
||||
sample_qid: Optional[List[_DaskCollection]],
|
||||
missing: float
|
||||
) -> Optional[List[Tuple[DaskDMatrix, str]]]:
|
||||
'''
|
||||
@ -1206,9 +1246,10 @@ async def _evaluation_matrices(
|
||||
if validation_set is not None:
|
||||
assert isinstance(validation_set, list)
|
||||
for i, e in enumerate(validation_set):
|
||||
w = (sample_weight[i] if sample_weight is not None else None)
|
||||
w = sample_weight[i] if sample_weight is not None else None
|
||||
qid = sample_qid[i] if sample_qid is not None else None
|
||||
dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
|
||||
weight=w, missing=missing)
|
||||
weight=w, missing=missing, qid=qid)
|
||||
assert isinstance(evals, list)
|
||||
evals.append((dmat, 'validation_{}'.format(i)))
|
||||
else:
|
||||
@ -1223,61 +1264,21 @@ class DaskScikitLearnBase(XGBModel):
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
y: _DaskCollection,
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: List[Tuple[_DaskCollection, _DaskCollection]] = None,
|
||||
eval_metric: Optional[Callable] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: List[TrainingCallback] = None
|
||||
) -> "DaskScikitLearnBase":
|
||||
'''Fit gradient boosting model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array_like
|
||||
Feature matrix
|
||||
y : array_like
|
||||
Labels
|
||||
sample_weight : array_like
|
||||
instance weights
|
||||
eval_set : list, optional
|
||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||
metrics will be computed.
|
||||
Validation metrics will help us track the performance of the model.
|
||||
eval_metric : str, list of str, or callable, optional
|
||||
sample_weight_eval_set : list, optional
|
||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
|
||||
of group weights on the i-th validation set.
|
||||
early_stopping_rounds : int
|
||||
Activates early stopping.
|
||||
verbose : bool
|
||||
If `verbose` and an evaluation set is used, writes the evaluation
|
||||
metric measured on the validation set to stderr.
|
||||
feature_weights: array_like
|
||||
Weight for each feature, defines the probability of each feature being
|
||||
selected when colsample is being used. All values must be greater than 0,
|
||||
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
|
||||
`exact` tree methods.
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
It is possible to use predefined callbacks by using :ref:`callback_api`.
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=True)]
|
||||
|
||||
'''
|
||||
raise NotImplementedError
|
||||
async def _predict_async(
|
||||
self, data: _DaskCollection,
|
||||
output_margin: bool = False,
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
) -> Any:
|
||||
test_dmatrix = await DaskDMatrix(
|
||||
client=self.client, data=data, base_margin=base_margin,
|
||||
missing=self.missing
|
||||
)
|
||||
pred_probs = await predict(client=self.client,
|
||||
model=self.get_booster(), data=test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features)
|
||||
return pred_probs
|
||||
|
||||
def predict(
|
||||
self,
|
||||
@ -1287,21 +1288,13 @@ class DaskScikitLearnBase(XGBModel):
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
) -> Any:
|
||||
'''Predict with `data`.
|
||||
Parameters
|
||||
----------
|
||||
data: data that can be used to construct a DaskDMatrix
|
||||
output_margin : Whether to output the raw untransformed margin value.
|
||||
ntree_limit : NOT supported on dask interface.
|
||||
validate_features :
|
||||
When this is True, validate that the Booster's and data's feature_names are
|
||||
identical. Otherwise, it is assumed that the feature_names are the same.
|
||||
Returns
|
||||
-------
|
||||
prediction:
|
||||
|
||||
'''
|
||||
raise NotImplementedError
|
||||
_assert_dask_support()
|
||||
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(self._predict_async, data,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features,
|
||||
base_margin=base_margin)
|
||||
|
||||
def __await__(self) -> Awaitable[Any]:
|
||||
# Generate a coroutine wrapper to make this class awaitable.
|
||||
@ -1320,59 +1313,63 @@ class DaskScikitLearnBase(XGBModel):
|
||||
self._client = clt
|
||||
|
||||
|
||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||
['estimators', 'model'])
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
|
||||
)
|
||||
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(
|
||||
self, X: _DaskCollection,
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
y: _DaskCollection,
|
||||
sample_weight: Optional[_DaskCollection],
|
||||
base_margin: Optional[_DaskCollection],
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
eval_metric: Optional[Union[str, List[str], Callable]],
|
||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||
early_stopping_rounds: int,
|
||||
verbose: bool,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]],
|
||||
feature_weights: Optional[_DaskCollection],
|
||||
callbacks: Optional[List[TrainingCallback]]
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
) -> _DaskCollection:
|
||||
dtrain = await DaskDMatrix(client=self.client,
|
||||
data=X,
|
||||
label=y,
|
||||
weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
missing=self.missing)
|
||||
dtrain = await DaskDMatrix(
|
||||
client=self.client,
|
||||
data=X,
|
||||
label=y,
|
||||
weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
missing=self.missing,
|
||||
)
|
||||
params = self.get_xgb_params()
|
||||
evals = await _evaluation_matrices(self.client, eval_set,
|
||||
sample_weight_eval_set,
|
||||
self.missing)
|
||||
evals = await _evaluation_matrices(
|
||||
self.client, eval_set, sample_weight_eval_set, None, self.missing
|
||||
)
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
else:
|
||||
obj = None
|
||||
metric = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
params.update({"eval_metric": eval_metric})
|
||||
|
||||
results = await train(client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
feval=metric,
|
||||
obj=obj,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks)
|
||||
self._Booster = results['booster']
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
feval=metric,
|
||||
obj=obj,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks,
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results["booster"]
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
self.evals_result_ = results["history"]
|
||||
return self
|
||||
|
||||
# pylint: disable=missing-docstring
|
||||
@ -1384,60 +1381,31 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
*,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: List[Tuple[_DaskCollection, _DaskCollection]] = None,
|
||||
eval_metric: Optional[Callable] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: List[TrainingCallback] = None
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBRegressor":
|
||||
_assert_dask_support()
|
||||
return self.client.sync(self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
eval_metric=eval_metric,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks)
|
||||
|
||||
async def _predict_async(
|
||||
self, data: _DaskCollection,
|
||||
output_margin: bool = False,
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
) -> _DaskCollection:
|
||||
test_dmatrix = await DaskDMatrix(
|
||||
client=self.client, data=data, base_margin=base_margin,
|
||||
missing=self.missing
|
||||
return self.client.sync(
|
||||
self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
eval_metric=eval_metric,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose,
|
||||
xgb_model=xgb_model,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
pred_probs = await predict(client=self.client,
|
||||
model=self.get_booster(), data=test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features)
|
||||
return pred_probs
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
def predict(
|
||||
self,
|
||||
data: _DaskCollection,
|
||||
output_margin: bool = False,
|
||||
ntree_limit: Optional[int] = None,
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
) -> Any:
|
||||
_assert_dask_support()
|
||||
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(self._predict_async, data,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features,
|
||||
base_margin=base_margin)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
@ -1450,10 +1418,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
sample_weight: Optional[_DaskCollection],
|
||||
base_margin: Optional[_DaskCollection],
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
eval_metric: Optional[Union[str, List[str], Callable]],
|
||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||
early_stopping_rounds: int,
|
||||
verbose: bool,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]],
|
||||
feature_weights: Optional[_DaskCollection],
|
||||
callbacks: Optional[List[TrainingCallback]]
|
||||
) -> "DaskXGBClassifier":
|
||||
@ -1481,29 +1450,31 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
|
||||
evals = await _evaluation_matrices(self.client, eval_set,
|
||||
sample_weight_eval_set,
|
||||
None,
|
||||
self.missing)
|
||||
|
||||
if callable(self.objective):
|
||||
obj = _objective_decorator(self.objective)
|
||||
else:
|
||||
obj = None
|
||||
metric = eval_metric if callable(eval_metric) else None
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
eval_metric = None
|
||||
else:
|
||||
params.update({"eval_metric": eval_metric})
|
||||
|
||||
results = await train(client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
obj=obj,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks)
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model,
|
||||
eval_metric=eval_metric,
|
||||
params=params
|
||||
)
|
||||
results = await train(
|
||||
client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
obj=obj,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks,
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results['booster']
|
||||
|
||||
if not callable(self.objective):
|
||||
@ -1522,26 +1493,30 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Callable]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
early_stopping_rounds: int = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: Optional[int] = None,
|
||||
verbose: bool = True,
|
||||
feature_weights: _DaskCollection = None,
|
||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBClassifier":
|
||||
_assert_dask_support()
|
||||
return self.client.sync(self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
eval_metric=eval_metric,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks)
|
||||
return self.client.sync(
|
||||
self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
eval_metric=eval_metric,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose,
|
||||
xgb_model=xgb_model,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
async def _predict_proba_async(
|
||||
self,
|
||||
@ -1561,7 +1536,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
output_margin=output_margin)
|
||||
return _cls_predict_proba(self.objective, pred_probs, da.vstack)
|
||||
|
||||
# pylint: disable=arguments-differ,missing-docstring
|
||||
# pylint: disable=missing-docstring
|
||||
def predict_proba(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
@ -1587,16 +1562,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
) -> _DaskCollection:
|
||||
test_dmatrix = await DaskDMatrix(
|
||||
client=self.client, data=data, base_margin=base_margin,
|
||||
missing=self.missing
|
||||
)
|
||||
pred_probs = await predict(
|
||||
client=self.client,
|
||||
model=self.get_booster(),
|
||||
data=test_dmatrix,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features
|
||||
pred_probs = await super()._predict_async(
|
||||
data, output_margin, validate_features, base_margin
|
||||
)
|
||||
if output_margin:
|
||||
return pred_probs
|
||||
@ -1608,22 +1575,126 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
|
||||
return preds
|
||||
|
||||
# pylint: disable=arguments-differ
|
||||
def predict(
|
||||
|
||||
@xgboost_model_doc(
|
||||
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
|
||||
["estimators", "model"],
|
||||
end_note="""
|
||||
Note
|
||||
----
|
||||
For dask implementation, group is not supported, use qid instead.
|
||||
""",
|
||||
)
|
||||
class DaskXGBRanker(DaskScikitLearnBase):
|
||||
def __init__(self, objective: str = "rank:pairwise", **kwargs: Any):
|
||||
if callable(objective):
|
||||
raise ValueError("Custom objective function not supported by XGBRanker.")
|
||||
super().__init__(objective=objective, kwargs=kwargs)
|
||||
|
||||
async def _fit_async(
|
||||
self,
|
||||
data: _DaskCollection,
|
||||
output_margin: bool = False,
|
||||
ntree_limit: Optional[int] = None,
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
) -> Any:
|
||||
_assert_dask_support()
|
||||
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(
|
||||
self._predict_async,
|
||||
data,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features,
|
||||
base_margin=base_margin
|
||||
X: _DaskCollection,
|
||||
y: _DaskCollection,
|
||||
qid: Optional[_DaskCollection],
|
||||
sample_weight: Optional[_DaskCollection],
|
||||
base_margin: Optional[_DaskCollection],
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||
eval_qid: Optional[List[_DaskCollection]],
|
||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||
early_stopping_rounds: int,
|
||||
verbose: bool,
|
||||
xgb_model: Optional[Union[XGBModel, Booster]],
|
||||
feature_weights: Optional[_DaskCollection],
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
) -> "DaskXGBRanker":
|
||||
dtrain = await DaskDMatrix(
|
||||
client=self.client,
|
||||
data=X,
|
||||
label=y,
|
||||
qid=qid,
|
||||
weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
missing=self.missing,
|
||||
)
|
||||
params = self.get_xgb_params()
|
||||
evals = await _evaluation_matrices(
|
||||
self.client,
|
||||
eval_set,
|
||||
sample_weight_eval_set,
|
||||
sample_qid=eval_qid,
|
||||
missing=self.missing,
|
||||
)
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
raise ValueError(
|
||||
'Custom evaluation metric is not yet supported for XGBRanker.')
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model,
|
||||
eval_metric=eval_metric,
|
||||
params=params
|
||||
)
|
||||
results = await train(
|
||||
client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks,
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results["booster"]
|
||||
self.evals_result_ = results["history"]
|
||||
return self
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit( # pylint: disable=arguments-differ
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
y: _DaskCollection,
|
||||
*,
|
||||
group: Optional[_DaskCollection] = None,
|
||||
qid: Optional[_DaskCollection] = None,
|
||||
sample_weight: Optional[_DaskCollection] = None,
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
eval_group: Optional[List[_DaskCollection]] = None,
|
||||
eval_qid: Optional[List[_DaskCollection]] = None,
|
||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||
early_stopping_rounds: int = None,
|
||||
verbose: bool = False,
|
||||
xgb_model: Optional[Union[XGBModel, Booster]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
) -> "DaskXGBRanker":
|
||||
_assert_dask_support()
|
||||
msg = "Use `qid` instead of `group` on dask interface."
|
||||
if not (group is None and eval_group is None):
|
||||
raise ValueError(msg)
|
||||
if qid is None:
|
||||
raise ValueError("`qid` is required for ranking.")
|
||||
return self.client.sync(
|
||||
self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
qid=qid,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
eval_qid=eval_qid,
|
||||
eval_metric=eval_metric,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose,
|
||||
xgb_model=xgb_model,
|
||||
feature_weights=feature_weights,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
||||
fit.__doc__ = XGBRanker.fit.__doc__
|
||||
|
||||
@ -7,6 +7,7 @@ import json
|
||||
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
|
||||
import numpy as np
|
||||
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
|
||||
from .core import Metric
|
||||
from .training import train
|
||||
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
|
||||
|
||||
@ -132,6 +133,10 @@ __model_doc = '''
|
||||
importance_type: string, default "gain"
|
||||
The feature importance type for the feature_importances\\_ property:
|
||||
either "gain", "weight", "cover", "total_gain" or "total_cover".
|
||||
gpu_id :
|
||||
Device ordinal.
|
||||
validate_parameters :
|
||||
Give warnings for unknown parameter.
|
||||
|
||||
\\*\\*kwargs : dict, optional
|
||||
Keyword arguments for XGBoost Booster object. Full documentation of
|
||||
@ -211,20 +216,41 @@ Parameters
|
||||
['estimators', 'model', 'objective'])
|
||||
class XGBModel(XGBModelBase):
|
||||
# pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
|
||||
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
|
||||
verbosity=None, objective=None, booster=None,
|
||||
tree_method=None, n_jobs=None, gamma=None,
|
||||
min_child_weight=None, max_delta_step=None, subsample=None,
|
||||
colsample_bytree=None, colsample_bylevel=None,
|
||||
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
|
||||
scale_pos_weight=None, base_score=None, random_state=None,
|
||||
missing=np.nan, num_parallel_tree=None,
|
||||
monotone_constraints=None, interaction_constraints=None,
|
||||
importance_type="gain", gpu_id=None,
|
||||
validate_parameters=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
max_depth=None,
|
||||
learning_rate=None,
|
||||
n_estimators=100,
|
||||
verbosity=None,
|
||||
objective=None,
|
||||
booster=None,
|
||||
tree_method=None,
|
||||
n_jobs=None,
|
||||
gamma=None,
|
||||
min_child_weight=None,
|
||||
max_delta_step=None,
|
||||
subsample=None,
|
||||
colsample_bytree=None,
|
||||
colsample_bylevel=None,
|
||||
colsample_bynode=None,
|
||||
reg_alpha=None,
|
||||
reg_lambda=None,
|
||||
scale_pos_weight=None,
|
||||
base_score=None,
|
||||
random_state=None,
|
||||
missing=np.nan,
|
||||
num_parallel_tree=None,
|
||||
monotone_constraints=None,
|
||||
interaction_constraints=None,
|
||||
importance_type="gain",
|
||||
gpu_id=None,
|
||||
validate_parameters=None,
|
||||
**kwargs
|
||||
):
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise XGBoostError(
|
||||
'sklearn needs to be installed in order to use this module')
|
||||
"sklearn needs to be installed in order to use this module"
|
||||
)
|
||||
self.n_estimators = n_estimators
|
||||
self.objective = objective
|
||||
|
||||
@ -255,11 +281,23 @@ class XGBModel(XGBModelBase):
|
||||
self.gpu_id = gpu_id
|
||||
self.validate_parameters = validate_parameters
|
||||
|
||||
def _wrap_evaluation_matrices(self, X, y, group,
|
||||
sample_weight, base_margin, feature_weights,
|
||||
eval_set, sample_weight_eval_set, eval_group,
|
||||
label_transform=lambda x: x):
|
||||
'''Convert array_like evaluation matrices into DMatrix'''
|
||||
def _wrap_evaluation_matrices(
|
||||
self, X, y,
|
||||
group,
|
||||
qid,
|
||||
sample_weight,
|
||||
base_margin,
|
||||
feature_weights,
|
||||
eval_set,
|
||||
sample_weight_eval_set,
|
||||
eval_group,
|
||||
eval_qid,
|
||||
label_transform=lambda x: x
|
||||
) -> Tuple[DMatrix, Optional[List[Tuple[DMatrix, str]]]]:
|
||||
'''Convert array_like evaluation matrices into DMatrix, group and qid are only used for
|
||||
valiation but not set in this function
|
||||
|
||||
'''
|
||||
if sample_weight_eval_set is not None:
|
||||
assert eval_set is not None
|
||||
assert len(sample_weight_eval_set) == len(eval_set)
|
||||
@ -271,26 +309,32 @@ class XGBModel(XGBModelBase):
|
||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing, nthread=self.n_jobs)
|
||||
train_dmatrix.set_info(feature_weights=feature_weights, group=group)
|
||||
train_dmatrix.set_info(feature_weights=feature_weights)
|
||||
|
||||
if eval_set is not None:
|
||||
if sample_weight_eval_set is None:
|
||||
sample_weight_eval_set = [None] * len(eval_set)
|
||||
if eval_group is None:
|
||||
eval_group = [None] * len(eval_set)
|
||||
if eval_qid is None:
|
||||
eval_qid = [None] * len(eval_set)
|
||||
|
||||
evals = []
|
||||
for i, (valid_X, valid_y) in enumerate(eval_set):
|
||||
# Skip the duplicated entry.
|
||||
if valid_X is X and valid_y is y and \
|
||||
sample_weight_eval_set[i] is sample_weight and eval_group[i] is group:
|
||||
if (
|
||||
valid_X is X and
|
||||
valid_y is y and
|
||||
sample_weight_eval_set[i] is sample_weight and
|
||||
eval_group[i] is group and
|
||||
eval_qid[i] is qid
|
||||
):
|
||||
evals.append(train_dmatrix)
|
||||
else:
|
||||
m = DMatrix(valid_X,
|
||||
label=label_transform(valid_y),
|
||||
missing=self.missing, weight=sample_weight_eval_set[i],
|
||||
nthread=self.n_jobs)
|
||||
m.set_info(group=eval_group[i])
|
||||
evals.append(m)
|
||||
|
||||
nevals = len(evals)
|
||||
@ -515,7 +559,7 @@ class XGBModel(XGBModelBase):
|
||||
booster: Optional[Booster],
|
||||
eval_metric: Optional[Union[Callable, str, List[str]]],
|
||||
params: Dict[str, Any],
|
||||
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
|
||||
) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
|
||||
# pylint: disable=protected-access, no-self-use
|
||||
model = booster
|
||||
if hasattr(model, '_Booster'):
|
||||
@ -602,8 +646,6 @@ class XGBModel(XGBModelBase):
|
||||
save_best=True)]
|
||||
|
||||
"""
|
||||
self.n_features_in_ = X.shape[1]
|
||||
|
||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing,
|
||||
@ -613,9 +655,12 @@ class XGBModel(XGBModelBase):
|
||||
evals_result = {}
|
||||
|
||||
train_dmatrix, evals = self._wrap_evaluation_matrices(
|
||||
X, y, group=None, sample_weight=sample_weight, base_margin=base_margin,
|
||||
X, y, group=None, qid=None, sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights, eval_set=eval_set,
|
||||
sample_weight_eval_set=sample_weight_eval_set, eval_group=None)
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
eval_group=None, eval_qid=None
|
||||
)
|
||||
params = self.get_xgb_params()
|
||||
|
||||
if callable(self.objective):
|
||||
@ -640,10 +685,6 @@ class XGBModel(XGBModelBase):
|
||||
evals_result_key]
|
||||
self.evals_result_ = evals_result
|
||||
|
||||
if early_stopping_rounds is not None:
|
||||
self.best_score = self._Booster.best_score
|
||||
self.best_iteration = self._Booster.best_iteration
|
||||
self.best_ntree_limit = self._Booster.best_ntree_limit
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=None,
|
||||
@ -663,16 +704,20 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : numpy.array/scipy.sparse
|
||||
data : array_like
|
||||
Data to predict with
|
||||
output_margin : bool
|
||||
Whether to output the raw untransformed margin value.
|
||||
ntree_limit : int
|
||||
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
|
||||
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
|
||||
Limit number of trees in the prediction; defaults to best_ntree_limit if
|
||||
defined (i.e. it has been trained with early stopping), otherwise 0 (use all
|
||||
trees).
|
||||
validate_features : bool
|
||||
When this is True, validate that the Booster's and data's feature_names are identical.
|
||||
Otherwise, it is assumed that the feature_names are the same.
|
||||
base_margin : array_like
|
||||
Margin added to prediction.
|
||||
|
||||
Returns
|
||||
-------
|
||||
prediction : numpy array
|
||||
@ -754,6 +799,32 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
return evals_result
|
||||
|
||||
@property
|
||||
def n_features_in_(self) -> int:
|
||||
booster = self.get_booster()
|
||||
return booster.num_features()
|
||||
|
||||
def _early_stopping_attr(self, attr: str) -> Union[float, int]:
|
||||
booster = self.get_booster()
|
||||
try:
|
||||
return getattr(booster, attr)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f'`{attr}` in only defined when early stopping is used.'
|
||||
) from e
|
||||
|
||||
@property
|
||||
def best_score(self) -> float:
|
||||
return float(self._early_stopping_attr('best_score'))
|
||||
|
||||
@property
|
||||
def best_iteration(self) -> int:
|
||||
return int(self._early_stopping_attr('best_iteration'))
|
||||
|
||||
@property
|
||||
def best_ntree_limit(self) -> int:
|
||||
return int(self._early_stopping_attr('best_ntree_limit'))
|
||||
|
||||
@property
|
||||
def feature_importances_(self):
|
||||
"""
|
||||
@ -941,14 +1012,17 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
raise ValueError(
|
||||
'Please reshape the input data X into 2-dimensional matrix.')
|
||||
|
||||
self._features_count = X.shape[1]
|
||||
self.n_features_in_ = self._features_count
|
||||
|
||||
train_dmatrix, evals = self._wrap_evaluation_matrices(
|
||||
X, y, group=None, sample_weight=sample_weight, base_margin=base_margin,
|
||||
X, y, group=None, qid=None,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
|
||||
eval_group=None, label_transform=label_transform)
|
||||
eval_set=eval_set,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
label_transform=label_transform
|
||||
)
|
||||
|
||||
self._Booster = train(params, train_dmatrix,
|
||||
self.get_num_boosting_rounds(),
|
||||
@ -968,11 +1042,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
evals_result_key] = val[1][evals_result_key]
|
||||
self.evals_result_ = evals_result
|
||||
|
||||
if early_stopping_rounds is not None:
|
||||
self.best_score = self._Booster.best_score
|
||||
self.best_iteration = self._Booster.best_iteration
|
||||
self.best_ntree_limit = self._Booster.best_ntree_limit
|
||||
|
||||
return self
|
||||
|
||||
fit.__doc__ = XGBModel.fit.__doc__.replace(
|
||||
@ -1058,6 +1127,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
validate_features : bool
|
||||
When this is True, validate that the Booster's and data's feature_names are
|
||||
identical. Otherwise, it is assumed that the feature_names are the same.
|
||||
base_margin : array_like
|
||||
Margin added to prediction.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -1198,11 +1269,12 @@ class XGBRFRegressor(XGBRegressor):
|
||||
|
||||
Note
|
||||
----
|
||||
Query group information is required for ranking tasks.
|
||||
Query group information is required for ranking tasks by either using the `group`
|
||||
parameter or `qid` parameter in `fit` method.
|
||||
|
||||
Before fitting the model, your data need to be sorted by query
|
||||
group. When fitting the model, you need to provide an additional array
|
||||
that contains the size of each query group.
|
||||
Before fitting the model, your data need to be sorted by query group. When fitting
|
||||
the model, you need to provide an additional array that contains the size of each
|
||||
query group.
|
||||
|
||||
For example, if your original data look like:
|
||||
|
||||
@ -1224,7 +1296,8 @@ class XGBRFRegressor(XGBRegressor):
|
||||
| 2 | 1 | x_7 |
|
||||
+-------+-----------+---------------+
|
||||
|
||||
then your group array should be ``[3, 4]``.
|
||||
then your group array should be ``[3, 4]``. Sometimes using query id (`qid`)
|
||||
instead of group can be more convenient.
|
||||
''')
|
||||
class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||
@ -1238,11 +1311,23 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
raise ValueError("please use XGBRanker for ranking task")
|
||||
|
||||
@_deprecate_positional_args
|
||||
def fit(self, X, y, *, group, sample_weight=None, base_margin=None,
|
||||
eval_set=None, sample_weight_eval_set=None,
|
||||
eval_group=None, eval_metric=None,
|
||||
early_stopping_rounds=None, verbose=False, xgb_model=None,
|
||||
feature_weights=None, callbacks=None):
|
||||
def fit(
|
||||
self, X, y, *,
|
||||
group=None,
|
||||
qid=None,
|
||||
sample_weight=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None,
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
eval_metric=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=False,
|
||||
xgb_model=None,
|
||||
feature_weights=None,
|
||||
callbacks=None
|
||||
) -> "XGBRanker":
|
||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||
"""Fit gradient boosting ranker
|
||||
|
||||
@ -1257,17 +1342,21 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
y : array_like
|
||||
Labels
|
||||
group : array_like
|
||||
Size of each query group of training data. Should have as many
|
||||
elements as the query groups in the training data
|
||||
Size of each query group of training data. Should have as many elements as the
|
||||
query groups in the training data. If this is set to None, then user must
|
||||
provide qid.
|
||||
qid : array_like
|
||||
Query ID for each training sample. Should have the size of n_samples. If
|
||||
this is set to None, then user must provide group.
|
||||
sample_weight : array_like
|
||||
Query group weights
|
||||
|
||||
.. note:: Weights are per-group for ranking tasks
|
||||
|
||||
In ranking task, one weight is assigned to each query group
|
||||
(not each data point). This is because we only care about the
|
||||
relative ordering of data points within each group, so it
|
||||
doesn't make sense to assign weights to individual data points.
|
||||
In ranking task, one weight is assigned to each query group/id (not each
|
||||
data point). This is because we only care about the relative ordering of
|
||||
data points within each group, so it doesn't make sense to assign weights
|
||||
to individual data points.
|
||||
|
||||
base_margin : array_like
|
||||
Global bias for each instance.
|
||||
@ -1289,6 +1378,9 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
eval_group : list of arrays, optional
|
||||
A list in which ``eval_group[i]`` is the list containing the sizes of all
|
||||
query groups in the ``i``-th pair in **eval_set**.
|
||||
eval_qid : list of array_like, optional
|
||||
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
|
||||
pair in **eval_set**.
|
||||
eval_metric : str, list of str, optional
|
||||
If a str, should be a built-in evaluation metric to use. See
|
||||
doc/parameter.rst.
|
||||
@ -1333,23 +1425,60 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
raise ValueError("group is required for ranking task")
|
||||
|
||||
if eval_set is not None:
|
||||
if eval_group is None:
|
||||
if eval_group is None and eval_qid is None:
|
||||
raise ValueError(
|
||||
"eval_group is required if eval_set is not None")
|
||||
if len(eval_group) != len(eval_set):
|
||||
"eval_group or eval_qid is required if eval_set is not None")
|
||||
if (
|
||||
(eval_group is not None and len(eval_group) != len(eval_set)) and
|
||||
(eval_qid is not None and len(eval_qid) != len(eval_set))
|
||||
):
|
||||
raise ValueError(
|
||||
"length of eval_group should match that of eval_set")
|
||||
if any(group is None for group in eval_group):
|
||||
raise ValueError(
|
||||
"group is required for all eval datasets for ranking task")
|
||||
|
||||
self.n_features_in_ = X.shape[1]
|
||||
"length of eval_group or eval_qid should match that of eval_set"
|
||||
)
|
||||
for i in range(len(eval_set)):
|
||||
if (
|
||||
(eval_group is not None and eval_group[i] is not None) and
|
||||
(eval_qid is not None and eval_qid[i] is not None)
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of the qid or group should be provided."
|
||||
)
|
||||
|
||||
train_dmatrix, evals = self._wrap_evaluation_matrices(
|
||||
X, y, group=group, sample_weight=sample_weight, base_margin=base_margin,
|
||||
feature_weights=feature_weights, eval_set=eval_set,
|
||||
X, y,
|
||||
group=group,
|
||||
qid=qid,
|
||||
sample_weight=sample_weight,
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
eval_set=eval_set,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
eval_group=eval_group)
|
||||
eval_group=eval_group,
|
||||
eval_qid=eval_qid
|
||||
)
|
||||
|
||||
if qid is not None:
|
||||
train_dmatrix.set_info(qid=qid)
|
||||
elif group is not None:
|
||||
train_dmatrix.set_info(group=group)
|
||||
else:
|
||||
raise ValueError("Either qid or group should be provided for ranking task.")
|
||||
|
||||
if evals is not None:
|
||||
for i, e in enumerate(evals):
|
||||
if eval_qid is not None and eval_qid[i] is not None:
|
||||
assert eval_group is None or eval_group[i] is None, (
|
||||
'Only one of the eval_qid or eval_group for each evaluation '
|
||||
'dataset should be provided.'
|
||||
)
|
||||
e[0].set_info(qid=qid)
|
||||
elif eval_group is not None and eval_group[i] is not None:
|
||||
e[0].set_info(group=eval_group[i])
|
||||
else:
|
||||
raise ValueError(
|
||||
'Either eval_qid or eval_group for each evaluation dataset should'
|
||||
' be provided for ranking task.'
|
||||
)
|
||||
|
||||
evals_result = {}
|
||||
params = self.get_xgb_params()
|
||||
@ -1380,11 +1509,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
|
||||
self.evals_result = evals_result
|
||||
|
||||
if early_stopping_rounds is not None:
|
||||
self.best_score = self._Booster.best_score
|
||||
self.best_iteration = self._Booster.best_iteration
|
||||
self.best_ntree_limit = self._Booster.best_ntree_limit
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, data, output_margin=False,
|
||||
|
||||
@ -111,6 +111,7 @@ def _train_internal(params, dtrain,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unknown booster: {booster}')
|
||||
|
||||
num_groups = int(config['learner']['learner_model_param']['num_class'])
|
||||
num_groups = 1 if num_groups == 0 else num_groups
|
||||
bst.best_ntree_limit = ((bst.best_iteration + 1) * num_parallel_tree * num_groups)
|
||||
|
||||
@ -498,6 +498,7 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<Learner*>(handle)->Configure();
|
||||
*out = static_cast<Learner*>(handle)->GetNumFeature();
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -374,13 +374,32 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
|
||||
} else if (!std::strcmp(key, "group")) {
|
||||
group_ptr_.resize(num + 1);
|
||||
group_ptr_.clear(); group_ptr_.resize(num + 1, 0);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1));
|
||||
group_ptr_[0] = 0;
|
||||
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
||||
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
|
||||
}
|
||||
} else if (!std::strcmp(key, "qid")) {
|
||||
std::vector<uint32_t> query_ids(num, 0);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, query_ids.begin()));
|
||||
bool non_dec = true;
|
||||
for (size_t i = 1; i < query_ids.size(); ++i) {
|
||||
if (query_ids[i] < query_ids[i-1]) {
|
||||
non_dec = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data.";
|
||||
group_ptr_.clear(); group_ptr_.push_back(0);
|
||||
for (size_t i = 1; i < query_ids.size(); ++i) {
|
||||
if (query_ids[i] != query_ids[i-1]) {
|
||||
group_ptr_.push_back(i);
|
||||
}
|
||||
}
|
||||
group_ptr_.push_back(query_ids.size());
|
||||
} else if (!std::strcmp(key, "label_lower_bound")) {
|
||||
auto& labels = labels_lower_bound_.HostVector();
|
||||
labels.resize(num);
|
||||
|
||||
@ -34,16 +34,20 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
|
||||
});
|
||||
}
|
||||
|
||||
namespace {
|
||||
auto SetDeviceToPtr(void *ptr) {
|
||||
cudaPointerAttributes attr;
|
||||
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
|
||||
int32_t ptr_device = attr.device;
|
||||
dh::safe_cuda(cudaSetDevice(ptr_device));
|
||||
return ptr_device;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
|
||||
<< "Expected integer metainfo";
|
||||
auto SetDeviceToPtr = [](void* ptr) {
|
||||
cudaPointerAttributes attr;
|
||||
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
|
||||
int32_t ptr_device = attr.device;
|
||||
dh::safe_cuda(cudaSetDevice(ptr_device));
|
||||
return ptr_device;
|
||||
};
|
||||
|
||||
auto ptr_device = SetDeviceToPtr(column.data);
|
||||
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
|
||||
auto d_tmp = temp.data();
|
||||
@ -95,6 +99,47 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
} else if (key == "group") {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
return;
|
||||
} else if (key == "qid") {
|
||||
auto it = dh::MakeTransformIterator<uint32_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[array_interface] __device__(size_t i) {
|
||||
return static_cast<uint32_t>(array_interface.GetElement(i));
|
||||
});
|
||||
dh::caching_device_vector<bool> flag(1);
|
||||
auto d_flag = dh::ToSpan(flag);
|
||||
auto d = SetDeviceToPtr(array_interface.data);
|
||||
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
|
||||
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
|
||||
if (static_cast<uint32_t>(array_interface.GetElement(i)) >
|
||||
static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
|
||||
d_flag[0] = false;
|
||||
}
|
||||
});
|
||||
bool non_dec = true;
|
||||
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
|
||||
cudaMemcpyDeviceToHost));
|
||||
CHECK(non_dec)
|
||||
<< "`qid` must be sorted in increasing order along with data.";
|
||||
size_t bytes = 0;
|
||||
dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
|
||||
dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
|
||||
HostDeviceVector<int> d_num_runs_out(1, 0, d);
|
||||
cub::DeviceRunLengthEncode::Encode(nullptr, bytes, it, out.begin(),
|
||||
cnt.begin(), d_num_runs_out.DevicePointer(),
|
||||
array_interface.num_rows);
|
||||
dh::caching_device_vector<char> tmp(bytes);
|
||||
cub::DeviceRunLengthEncode::Encode(tmp.data().get(), bytes, it, out.begin(),
|
||||
cnt.begin(), d_num_runs_out.DevicePointer(),
|
||||
array_interface.num_rows);
|
||||
|
||||
auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
|
||||
group_ptr_.clear(); group_ptr_.resize(h_num_runs_out + 1, 0);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
|
||||
cnt.begin() + h_num_runs_out, cnt.begin());
|
||||
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
|
||||
group_ptr_.begin() + 1);
|
||||
return;
|
||||
} else if (key == "label_lower_bound") {
|
||||
CopyInfoImpl(array_interface, &labels_lower_bound_);
|
||||
return;
|
||||
|
||||
@ -436,7 +436,7 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t GetNumFeature() override {
|
||||
uint32_t GetNumFeature() const override {
|
||||
return learner_model_param_.num_feature;
|
||||
}
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
|
||||
|
||||
template <typename T>
|
||||
Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
|
||||
thrust::device_vector<T>* p_data) {
|
||||
thrust::device_vector<T> *p_data) {
|
||||
auto& data = *p_data;
|
||||
thrust::sequence(data.begin(), data.end());
|
||||
|
||||
|
||||
@ -202,6 +202,24 @@ TEST(MetaInfo, LoadQid) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MetaInfo, CPUQid) {
|
||||
xgboost::MetaInfo info;
|
||||
info.num_row_ = 100;
|
||||
std::vector<uint32_t> qid(info.num_row_, 0);
|
||||
for (size_t i = 0; i < qid.size(); ++i) {
|
||||
qid[i] = i;
|
||||
}
|
||||
|
||||
info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
|
||||
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
|
||||
ASSERT_EQ(info.group_ptr_.front(), 0);
|
||||
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
|
||||
|
||||
for (size_t i = 0; i < info.num_row_ + 1; ++i) {
|
||||
ASSERT_EQ(info.group_ptr_[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MetaInfo, Validate) {
|
||||
xgboost::MetaInfo info;
|
||||
info.num_row_ = 10;
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include "test_array_interface.h"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@ -105,6 +106,28 @@ TEST(MetaInfo, Group) {
|
||||
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
|
||||
}
|
||||
|
||||
TEST(MetaInfo, GPUQid) {
|
||||
xgboost::MetaInfo info;
|
||||
info.num_row_ = 100;
|
||||
thrust::device_vector<uint32_t> qid(info.num_row_, 0);
|
||||
for (size_t i = 0; i < qid.size(); ++i) {
|
||||
qid[i] = i;
|
||||
}
|
||||
auto column = Generate2dArrayInterface(info.num_row_, 1, "<u4", &qid);
|
||||
Json array{std::vector<Json>{column}};
|
||||
std::string array_str;
|
||||
Json::Dump(array, &array_str);
|
||||
info.SetInfo("qid", array_str.c_str());
|
||||
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
|
||||
ASSERT_EQ(info.group_ptr_.front(), 0);
|
||||
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
|
||||
|
||||
for (size_t i = 0; i < info.num_row_ + 1; ++i) {
|
||||
ASSERT_EQ(info.group_ptr_[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(MetaInfo, DeviceExtend) {
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
size_t const kRows = 100;
|
||||
|
||||
@ -171,6 +171,22 @@ Arrow specification.'''
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
m.slice(rindex=[0, 1, 2])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_qid(self):
|
||||
import cupy as cp
|
||||
rng = cp.random.RandomState(1994)
|
||||
rows = 100
|
||||
cols = 10
|
||||
X, y = rng.randn(rows, cols), rng.randn(rows)
|
||||
qid = rng.randint(low=0, high=10, size=rows, dtype=np.uint32)
|
||||
qid = cp.sort(qid)
|
||||
|
||||
Xy = xgb.DMatrix(X, y)
|
||||
Xy.set_info(qid=qid)
|
||||
group_ptr = Xy.get_uint_info('group_ptr')
|
||||
assert group_ptr[0] == 0
|
||||
assert group_ptr[-1] == rows
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.mgpu
|
||||
def test_specified_device(self):
|
||||
|
||||
@ -239,6 +239,19 @@ class TestDMatrix:
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('group_ptr')
|
||||
|
||||
def test_qid(self):
|
||||
rows = 100
|
||||
cols = 10
|
||||
X, y = rng.randn(rows, cols), rng.randn(rows)
|
||||
qid = rng.randint(low=0, high=10, size=rows, dtype=np.uint32)
|
||||
qid = np.sort(qid)
|
||||
|
||||
Xy = xgb.DMatrix(X, y)
|
||||
Xy.set_info(qid=qid)
|
||||
group_ptr = Xy.get_uint_info('group_ptr')
|
||||
assert group_ptr[0] == 0
|
||||
assert group_ptr[-1] == rows
|
||||
|
||||
def test_feature_weights(self):
|
||||
kRows = 10
|
||||
kCols = 50
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
from scipy.sparse import csr_matrix
|
||||
import testing as tm
|
||||
import xgboost
|
||||
import os
|
||||
import itertools
|
||||
@ -79,22 +80,10 @@ class TestRanking:
|
||||
"""
|
||||
Download and setup the test fixtures
|
||||
"""
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
# download the test data
|
||||
cls.dpath = 'demo/rank/'
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
target = cls.dpath + '/MQ2008.zip'
|
||||
urllib.request.urlretrieve(url=src, filename=target)
|
||||
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=cls.dpath)
|
||||
|
||||
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||
query_id=True, zero_based=False)
|
||||
x_valid, y_valid, qid_valid) = tm.get_mq2008(cls.dpath)
|
||||
|
||||
# instantiate the matrices
|
||||
cls.dtrain = xgboost.DMatrix(x_train, y_train)
|
||||
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import numpy as np
|
||||
import scipy
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Optional, Type, Any
|
||||
import asyncio
|
||||
@ -670,12 +671,56 @@ def run_aft_survival(client: "Client", dmatrix_t: Type) -> None:
|
||||
assert nloglik_rec['extreme'][-1] > 4.9
|
||||
|
||||
|
||||
def test_aft_survival() -> None:
|
||||
def test_dask_aft_survival() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_aft_survival(client, DaskDMatrix)
|
||||
|
||||
|
||||
def test_dask_ranking(client: "Client") -> None:
|
||||
dpath = "demo/rank/"
|
||||
mq2008 = tm.get_mq2008(dpath)
|
||||
data = []
|
||||
for d in mq2008:
|
||||
if isinstance(d, scipy.sparse.csr_matrix):
|
||||
d[d == 0] = np.inf
|
||||
d = d.toarray()
|
||||
d[d == 0] = np.nan
|
||||
d[np.isinf(d)] = 0
|
||||
data.append(da.from_array(d))
|
||||
else:
|
||||
data.append(da.from_array(d))
|
||||
|
||||
(
|
||||
x_train,
|
||||
y_train,
|
||||
qid_train,
|
||||
x_test,
|
||||
y_test,
|
||||
qid_test,
|
||||
x_valid,
|
||||
y_valid,
|
||||
qid_valid,
|
||||
) = data
|
||||
qid_train = qid_train.astype(np.uint32)
|
||||
qid_valid = qid_valid.astype(np.uint32)
|
||||
qid_test = qid_test.astype(np.uint32)
|
||||
|
||||
rank = xgb.dask.DaskXGBRanker(n_estimators=2500)
|
||||
rank.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
qid=qid_train,
|
||||
eval_set=[(x_test, y_test), (x_train, y_train)],
|
||||
eval_qid=[qid_test, qid_train],
|
||||
eval_metric=["ndcg"],
|
||||
verbose=True,
|
||||
early_stopping_rounds=10,
|
||||
)
|
||||
assert rank.n_features_in_ == 46
|
||||
assert rank.best_score > 0.98
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_global_config(self, client: "Client") -> None:
|
||||
X, y, _ = generate_array()
|
||||
@ -981,7 +1026,7 @@ class TestWithDask:
|
||||
def test_shap(self, client: "Client") -> None:
|
||||
from sklearn.datasets import load_boston, load_digits
|
||||
X, y = load_boston(return_X_y=True)
|
||||
params = {'objective': 'reg:squarederror'}
|
||||
params: Dict[str, Any] = {'objective': 'reg:squarederror'}
|
||||
self.run_shap(X, y, params, client)
|
||||
|
||||
X, y = load_digits(return_X_y=True)
|
||||
|
||||
@ -125,9 +125,11 @@ def test_ranking():
|
||||
x_train = np.random.rand(1000, 10)
|
||||
y_train = np.random.randint(5, size=1000)
|
||||
train_group = np.repeat(50, 20)
|
||||
|
||||
x_valid = np.random.rand(200, 10)
|
||||
y_valid = np.random.randint(5, size=200)
|
||||
valid_group = np.repeat(50, 4)
|
||||
|
||||
x_test = np.random.rand(100, 10)
|
||||
|
||||
params = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||
@ -136,6 +138,7 @@ def test_ranking():
|
||||
model = xgb.sklearn.XGBRanker(**params)
|
||||
model.fit(x_train, y_train, group=train_group,
|
||||
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
|
||||
|
||||
pred = model.predict(x_test)
|
||||
|
||||
train_data = xgb.DMatrix(x_train, y_train)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
# coding: utf-8
|
||||
import os
|
||||
import urllib
|
||||
import zipfile
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
@ -209,6 +211,29 @@ def get_sparse():
|
||||
return X, y
|
||||
|
||||
|
||||
@memory.cache
|
||||
def get_mq2008(dpath):
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
target = dpath + '/MQ2008.zip'
|
||||
if not os.path.exists(target):
|
||||
urllib.request.urlretrieve(url=src, filename=target)
|
||||
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=dpath)
|
||||
|
||||
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||
(dpath + "MQ2008/Fold1/train.txt",
|
||||
dpath + "MQ2008/Fold1/test.txt",
|
||||
dpath + "MQ2008/Fold1/vali.txt"),
|
||||
query_id=True, zero_based=False)
|
||||
|
||||
return (x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid)
|
||||
|
||||
|
||||
_unweighted_datasets_strategy = strategies.sampled_from(
|
||||
[TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
|
||||
TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user