[dask] Add DaskXGBRanker (#6576)

* Initial support for distributed learning-to-rank (LTR) using dask.

* Support `qid` in libxgboost.
* Refactor `predict`, `n_features_in_`, and `best_[score/iteration/ntree_limit]`
  to avoid duplicated code.
* Define `DaskXGBRanker`.

The dask ranker doesn't support the group structure; instead, it uses query IDs
and converts them to group pointers internally.
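
A minimal usage sketch of the new dask interface (hypothetical synthetic data; assumes a local `dask.distributed` client, and that `qid` is sorted so rows belonging to the same query are contiguous):

    import numpy as np
    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client

    client = Client()  # local cluster; the estimator picks up the current client

    rng = np.random.RandomState(1994)
    X = da.from_array(rng.randn(1000, 10), chunks=(100, 10))
    y = da.from_array(rng.randint(0, 5, size=1000), chunks=100)
    # Query id for every row; must be sorted so each query forms a contiguous block.
    qid = da.from_array(np.sort(rng.randint(0, 20, size=1000).astype(np.uint32)), chunks=100)

    ranker = xgb.dask.DaskXGBRanker(n_estimators=10)
    ranker.fit(X, y, qid=qid)  # `group`/`eval_group` are rejected on the dask interface
    scores = ranker.predict(X)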
Jiaming Yuan 2021-01-08 18:35:09 +08:00 committed by GitHub
parent 96d3d32265
commit 80065d571e
18 changed files with 755 additions and 351 deletions

View File

@ -165,7 +165,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Get the number of features of the booster.
* \return number of features
*/
virtual uint32_t GetNumFeature() = 0;
virtual uint32_t GetNumFeature() const = 0;
/*!
* \brief Set additional attribute to the Booster.

View File

@ -321,6 +321,7 @@ class DataIter:
def data_handle(data, label=None, weight=None, base_margin=None,
group=None,
qid=None,
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None,
feature_weights=None):
@ -333,6 +334,7 @@ class DataIter:
self.proxy.set_info(label=label, weight=weight,
base_margin=base_margin,
group=group,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_names=feature_names,
@ -523,12 +525,14 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
def set_info(self, *,
label=None, weight=None, base_margin=None,
group=None,
qid=None,
label_lower_bound=None,
label_upper_bound=None,
feature_names=None,
feature_types=None,
feature_weights=None):
'''Set meta info for DMatrix.'''
from .data import dispatch_meta_backend
if label is not None:
self.set_label(label)
if weight is not None:
@ -537,6 +541,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.set_base_margin(base_margin)
if group is not None:
self.set_group(group)
if qid is not None:
dispatch_meta_backend(matrix=self, data=qid, name='qid')
if label_lower_bound is not None:
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:
@ -546,7 +552,6 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if feature_types is not None:
self.feature_types = feature_types
if feature_weights is not None:
from .data import dispatch_meta_backend
dispatch_meta_backend(matrix=self, data=feature_weights,
name='feature_weights')
@ -993,7 +998,7 @@ class DeviceQuantileDMatrix(DMatrix):
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
class Booster(object):
@ -1743,10 +1748,16 @@ class Booster(object):
'''
rounds = ctypes.c_int()
assert self.handle is not None
_check_call(_LIB.XGBoosterBoostedRounds(
self.handle, ctypes.byref(rounds)))
_check_call(_LIB.XGBoosterBoostedRounds(self.handle, ctypes.byref(rounds)))
return rounds.value
def num_features(self) -> int:
'''Number of features in booster.'''
features = ctypes.c_int()
assert self.handle is not None
_check_call(_LIB.XGBoosterGetNumFeature(self.handle, ctypes.byref(features)))
return features.value
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file. Unlike `save_model`, the
output format is primarily used for visualization or interpretation,

View File

@ -1,6 +1,6 @@
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines, fixme
# pylint: disable=import-error
"""Dask extensions for distributed training. See
https://xgboost.readthedocs.io/en/latest/tutorials/dask.html for simple
@ -41,6 +41,7 @@ from .tracker import RabitTracker, get_host_ip
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
from .sklearn import xgboost_model_doc
from .sklearn import _cls_predict_proba
from .sklearn import XGBRanker
if TYPE_CHECKING:
@ -207,6 +208,8 @@ class DaskDMatrix:
Weight for each instance.
base_margin :
Global bias for each instance.
qid :
Query ID for ranking.
label_lower_bound :
Lower bound for survival training.
label_upper_bound :
@ -220,14 +223,17 @@ class DaskDMatrix:
'''
@_deprecate_positional_args
def __init__(
self,
client: "distributed.Client",
data: _DaskCollection,
label: Optional[_DaskCollection] = None,
*,
missing: float = None,
weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
qid: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None,
label_upper_bound: Optional[_DaskCollection] = None,
feature_weights: Optional[_DaskCollection] = None,
@ -241,6 +247,9 @@ class DaskDMatrix:
self.feature_types = feature_types
self.missing = missing
if qid is not None and weight is not None:
raise NotImplementedError('per-group weight is not implemented.')
if len(data.shape) != 2:
raise ValueError(
'Expecting 2 dimensional input, got: {shape}'.format(
@ -259,6 +268,7 @@ class DaskDMatrix:
self._init = client.sync(self.map_local_data,
client, data, label=label, weights=weight,
base_margin=base_margin,
qid=qid,
feature_weights=feature_weights,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
@ -273,6 +283,7 @@ class DaskDMatrix:
label: Optional[_DaskCollection] = None,
weights: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
qid: Optional[_DaskCollection] = None,
feature_weights: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None,
label_upper_bound: Optional[_DaskCollection] = None
@ -325,6 +336,7 @@ class DaskDMatrix:
y_parts = flatten_meta(label)
w_parts = flatten_meta(weights)
margin_parts = flatten_meta(base_margin)
qid_parts = flatten_meta(qid)
ll_parts = flatten_meta(label_lower_bound)
lu_parts = flatten_meta(label_upper_bound)
@ -343,6 +355,7 @@ class DaskDMatrix:
append_meta(y_parts, 'labels')
append_meta(w_parts, 'weights')
append_meta(margin_parts, 'base_margin')
append_meta(qid_parts, 'qid')
append_meta(ll_parts, 'label_lower_bound')
append_meta(lu_parts, 'label_upper_bound')
# At this point, `parts` looks like:
@ -397,7 +410,7 @@ class DaskDMatrix:
_DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
Optional[Any]]]
Optional[Any], Optional[Any]]]
def _get_worker_parts_ordered(
@ -413,6 +426,7 @@ def _get_worker_parts_ordered(
labels = None
weights = None
base_margin = None
qid = None
label_lower_bound = None
label_upper_bound = None
# Iterate through all possible meta info, brings small overhead as in xgboost
@ -424,13 +438,15 @@ def _get_worker_parts_ordered(
weights = blob
elif meta_names[j] == 'base_margin':
base_margin = blob
elif meta_names[j] == 'qid':
qid = blob
elif meta_names[j] == 'label_lower_bound':
label_lower_bound = blob
elif meta_names[j] == 'label_upper_bound':
label_upper_bound = blob
else:
raise ValueError('Unknown metainfo:', meta_names[j])
result.append((data, labels, weights, base_margin, label_lower_bound,
result.append((data, labels, weights, base_margin, qid, label_lower_bound,
label_upper_bound))
return result
@ -456,6 +472,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
label: Optional[Tuple[Any, ...]] = None,
weight: Optional[Tuple[Any, ...]] = None,
base_margin: Optional[Tuple[Any, ...]] = None,
qid: Optional[Tuple[Any, ...]] = None,
label_lower_bound: Optional[Tuple[Any, ...]] = None,
label_upper_bound: Optional[Tuple[Any, ...]] = None,
feature_names: Optional[Union[str, List[str]]] = None,
@ -465,6 +482,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
self._labels = label
self._weights = weight
self._base_margin = base_margin
self._qid = qid
self._label_lower_bound = label_lower_bound
self._label_upper_bound = label_upper_bound
self._feature_names = feature_names
@ -498,6 +516,12 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
return self._weights[self._iter]
return None
def qids(self) -> Any:
'''Utility function for obtaining current batch of query id.'''
if self._qid is not None:
return self._qid[self._iter]
return None
def base_margins(self) -> Any:
'''Utility function for obtaining current batch of base_margin.'''
if self._base_margin is not None:
@ -537,6 +561,7 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
feature_names = None
input_data(data=self.data(), label=self.labels(),
weight=self.weights(), group=None,
qid=self.qids(),
label_lower_bound=self.label_lower_bounds(),
label_upper_bound=self.label_upper_bounds(),
feature_names=feature_names,
@ -567,6 +592,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
missing: float = None,
weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
qid: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None,
label_upper_bound: Optional[_DaskCollection] = None,
feature_weights: Optional[_DaskCollection] = None,
@ -574,13 +600,20 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
feature_types: Optional[Union[Any, List[Any]]] = None,
max_bin: int = 256
) -> None:
super().__init__(client=client, data=data, label=label,
missing=missing,
weight=weight, base_margin=base_margin,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types)
super().__init__(
client=client,
data=data,
label=label,
missing=missing,
feature_weights=feature_weights,
weight=weight,
base_margin=base_margin,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types
)
self.max_bin = max_bin
self.is_quantile = True
@ -611,11 +644,12 @@ def _create_device_quantile_dmatrix(
max_bin=max_bin)
return d
(data, labels, weights, base_margin,
(data, labels, weights, base_margin, qid,
label_lower_bound, label_upper_bound) = _get_worker_parts(
parts, meta_names)
it = DaskPartitionIter(data=data, label=labels, weight=weights,
base_margin=base_margin,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
@ -661,26 +695,31 @@ def _create_dmatrix(
return None
return concat(data)
(data, labels, weights, base_margin,
(data, labels, weights, base_margin, qid,
label_lower_bound, label_upper_bound) = _get_worker_parts(list_of_parts, meta_names)
_labels = concat_or_none(labels)
_weights = concat_or_none(weights)
_base_margin = concat_or_none(base_margin)
_qid = concat_or_none(qid)
_label_lower_bound = concat_or_none(label_lower_bound)
_label_upper_bound = concat_or_none(label_upper_bound)
_data = concat(data)
dmatrix = DMatrix(_data,
_labels,
missing=missing,
feature_names=feature_names,
feature_types=feature_types,
nthread=worker.nthreads)
dmatrix.set_info(base_margin=_base_margin, weight=_weights,
label_lower_bound=_label_lower_bound,
label_upper_bound=_label_upper_bound,
feature_weights=feature_weights)
dmatrix = DMatrix(
_data,
_labels,
missing=missing,
feature_names=feature_names,
feature_types=feature_types,
nthread=worker.nthreads
)
dmatrix.set_info(
base_margin=_base_margin, qid=_qid, weight=_weights,
label_lower_bound=_label_lower_bound,
label_upper_bound=_label_upper_bound,
feature_weights=feature_weights
)
return dmatrix
@ -746,7 +785,7 @@ async def _train_async(
'''Perform training on a single worker. A local function prevents pickling.
'''
LOGGER.info('Training on %s', str(worker_addr))
LOGGER.debug('Training on %s', str(worker_addr))
worker = distributed.get_worker()
with RabitContext(rabit_args), config.config_context(**global_config):
local_dtrain = _dmatrix_from_list_of_parts(**dtrain_ref)
@ -954,7 +993,7 @@ async def _predict_async(
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
) -> List[Tuple[Tuple["dask.delayed.Delayed", int], int]]:
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
LOGGER.debug('Predicting on %d', worker_id)
with config.config_context(**global_config):
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
@ -962,7 +1001,7 @@ async def _predict_async(
booster.set_param({'nthread': worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _) = parts
(data, _, _, base_margin, _, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
@ -991,11 +1030,11 @@ async def _predict_async(
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
) -> List[Tuple[int, int]]:
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
LOGGER.debug('Get shape on %d', worker_id)
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
shapes = []
for i, parts in enumerate(list_of_parts):
(data, _, _, _, _, _) = parts
(data, _, _, _, _, _, _) = parts
shapes.append((data.shape, list_of_orders[i]))
return shapes
@ -1182,6 +1221,7 @@ async def _evaluation_matrices(
client: "distributed.Client",
validation_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
sample_weight: Optional[List[_DaskCollection]],
sample_qid: Optional[List[_DaskCollection]],
missing: float
) -> Optional[List[Tuple[DaskDMatrix, str]]]:
'''
@ -1206,9 +1246,10 @@ async def _evaluation_matrices(
if validation_set is not None:
assert isinstance(validation_set, list)
for i, e in enumerate(validation_set):
w = (sample_weight[i] if sample_weight is not None else None)
w = sample_weight[i] if sample_weight is not None else None
qid = sample_qid[i] if sample_qid is not None else None
dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
weight=w, missing=missing)
weight=w, missing=missing, qid=qid)
assert isinstance(evals, list)
evals.append((dmat, 'validation_{}'.format(i)))
else:
@ -1223,61 +1264,21 @@ class DaskScikitLearnBase(XGBModel):
# pylint: disable=arguments-differ
@_deprecate_positional_args
def fit(
self,
X: _DaskCollection,
y: _DaskCollection,
*,
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: List[Tuple[_DaskCollection, _DaskCollection]] = None,
eval_metric: Optional[Callable] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: bool = True,
feature_weights: Optional[_DaskCollection] = None,
callbacks: List[TrainingCallback] = None
) -> "DaskScikitLearnBase":
'''Fit gradient boosting model
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
sample_weight : array_like
instance weights
eval_set : list, optional
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
sample_weight_eval_set : list, optional
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list
of group weights on the i-th validation set.
early_stopping_rounds : int
Activates early stopping.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
feature_weights: array_like
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown. Only available for `hist`, `gpu_hist` and
`exact` tree methods.
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using :ref:`callback_api`.
Example:
.. code-block:: python
callbacks = [xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)]
'''
raise NotImplementedError
async def _predict_async(
self, data: _DaskCollection,
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
) -> Any:
test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix,
output_margin=output_margin,
validate_features=validate_features)
return pred_probs
def predict(
self,
@ -1287,21 +1288,13 @@ class DaskScikitLearnBase(XGBModel):
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
) -> Any:
'''Predict with `data`.
Parameters
----------
data: data that can be used to construct a DaskDMatrix
output_margin : Whether to output the raw untransformed margin value.
ntree_limit : NOT supported on dask interface.
validate_features :
When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same.
Returns
-------
prediction:
'''
raise NotImplementedError
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
assert ntree_limit is None, msg
return self.client.sync(self._predict_async, data,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin)
def __await__(self) -> Awaitable[Any]:
# Generate a coroutine wrapper to make this class awaitable.
@ -1320,59 +1313,63 @@ class DaskScikitLearnBase(XGBModel):
self._client = clt
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
['estimators', 'model'])
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
)
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
# pylint: disable=missing-class-docstring
async def _fit_async(
self, X: _DaskCollection,
self,
X: _DaskCollection,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, List[str], Callable]],
eval_metric: Optional[Union[str, List[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]],
early_stopping_rounds: int,
verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection],
callbacks: Optional[List[TrainingCallback]]
callbacks: Optional[List[TrainingCallback]],
) -> _DaskCollection:
dtrain = await DaskDMatrix(client=self.client,
data=X,
label=y,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing)
dtrain = await DaskDMatrix(
client=self.client,
data=X,
label=y,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing,
)
params = self.get_xgb_params()
evals = await _evaluation_matrices(self.client, eval_set,
sample_weight_eval_set,
self.missing)
evals = await _evaluation_matrices(
self.client, eval_set, sample_weight_eval_set, None, self.missing
)
if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})
results = await train(client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks)
self._Booster = results['booster']
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
self.evals_result_ = results["history"]
return self
# pylint: disable=missing-docstring
@ -1384,60 +1381,31 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
*,
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: List[Tuple[_DaskCollection, _DaskCollection]] = None,
eval_metric: Optional[Callable] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: bool = True,
xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: List[TrainingCallback] = None
callbacks: Optional[List[TrainingCallback]] = None
) -> "DaskXGBRegressor":
_assert_dask_support()
return self.client.sync(self._fit_async,
X=X,
y=y,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose,
feature_weights=feature_weights,
callbacks=callbacks)
async def _predict_async(
self, data: _DaskCollection,
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
) -> _DaskCollection:
test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
return self.client.sync(
self._fit_async,
X=X,
y=y,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose,
xgb_model=xgb_model,
feature_weights=feature_weights,
callbacks=callbacks,
)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix,
output_margin=output_margin,
validate_features=validate_features)
return pred_probs
# pylint: disable=arguments-differ
def predict(
self,
data: _DaskCollection,
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
) -> Any:
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
assert ntree_limit is None, msg
return self.client.sync(self._predict_async, data,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin)
@xgboost_model_doc(
@ -1450,10 +1418,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
eval_metric: Optional[Union[str, List[str], Callable]],
eval_metric: Optional[Union[str, List[str], Metric]],
sample_weight_eval_set: Optional[List[_DaskCollection]],
early_stopping_rounds: int,
verbose: bool,
xgb_model: Optional[Union[Booster, XGBModel]],
feature_weights: Optional[_DaskCollection],
callbacks: Optional[List[TrainingCallback]]
) -> "DaskXGBClassifier":
@ -1481,29 +1450,31 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
evals = await _evaluation_matrices(self.client, eval_set,
sample_weight_eval_set,
None,
self.missing)
if callable(self.objective):
obj = _objective_decorator(self.objective)
else:
obj = None
metric = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({"eval_metric": eval_metric})
results = await train(client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks)
model, metric, params = self._configure_fit(
booster=xgb_model,
eval_metric=eval_metric,
params=params
)
results = await train(
client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
xgb_model=model,
)
self._Booster = results['booster']
if not callable(self.objective):
@ -1522,26 +1493,30 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
early_stopping_rounds: int = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: bool = True,
feature_weights: _DaskCollection = None,
xgb_model: Optional[Union[Booster, XGBModel]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None
) -> "DaskXGBClassifier":
_assert_dask_support()
return self.client.sync(self._fit_async,
X=X,
y=y,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose,
feature_weights=feature_weights,
callbacks=callbacks)
return self.client.sync(
self._fit_async,
X=X,
y=y,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
eval_metric=eval_metric,
sample_weight_eval_set=sample_weight_eval_set,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose,
xgb_model=xgb_model,
feature_weights=feature_weights,
callbacks=callbacks,
)
async def _predict_proba_async(
self,
@ -1561,7 +1536,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
output_margin=output_margin)
return _cls_predict_proba(self.objective, pred_probs, da.vstack)
# pylint: disable=arguments-differ,missing-docstring
# pylint: disable=missing-docstring
def predict_proba(
self,
X: _DaskCollection,
@ -1587,16 +1562,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
) -> _DaskCollection:
test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
)
pred_probs = await predict(
client=self.client,
model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin,
validate_features=validate_features
pred_probs = await super()._predict_async(
data, output_margin, validate_features, base_margin
)
if output_margin:
return pred_probs
@ -1608,22 +1575,126 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
return preds
# pylint: disable=arguments-differ
def predict(
@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
["estimators", "model"],
end_note="""
Note
----
For the dask implementation, `group` is not supported; use `qid` instead.
""",
)
class DaskXGBRanker(DaskScikitLearnBase):
def __init__(self, objective: str = "rank:pairwise", **kwargs: Any):
if callable(objective):
raise ValueError("Custom objective function not supported by XGBRanker.")
super().__init__(objective=objective, kwargs=kwargs)
async def _fit_async(
self,
data: _DaskCollection,
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
) -> Any:
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
assert ntree_limit is None, msg
return self.client.sync(
self._predict_async,
data,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin
X: _DaskCollection,
y: _DaskCollection,
qid: Optional[_DaskCollection],
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
sample_weight_eval_set: Optional[List[_DaskCollection]],
eval_qid: Optional[List[_DaskCollection]],
eval_metric: Optional[Union[str, List[str], Metric]],
early_stopping_rounds: int,
verbose: bool,
xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection],
callbacks: Optional[List[TrainingCallback]],
) -> "DaskXGBRanker":
dtrain = await DaskDMatrix(
client=self.client,
data=X,
label=y,
qid=qid,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=self.missing,
)
params = self.get_xgb_params()
evals = await _evaluation_matrices(
self.client,
eval_set,
sample_weight_eval_set,
sample_qid=eval_qid,
missing=self.missing,
)
if eval_metric is not None:
if callable(eval_metric):
raise ValueError(
'Custom evaluation metric is not yet supported for XGBRanker.')
model, metric, params = self._configure_fit(
booster=xgb_model,
eval_metric=eval_metric,
params=params
)
results = await train(
client=self.client,
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
self.evals_result_ = results["history"]
return self
@_deprecate_positional_args
def fit( # pylint: disable=arguments-differ
self,
X: _DaskCollection,
y: _DaskCollection,
*,
group: Optional[_DaskCollection] = None,
qid: Optional[_DaskCollection] = None,
sample_weight: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskCollection]] = None,
eval_qid: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[Union[str, List[str], Metric]] = None,
early_stopping_rounds: int = None,
verbose: bool = False,
xgb_model: Optional[Union[XGBModel, Booster]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None
) -> "DaskXGBRanker":
_assert_dask_support()
msg = "Use `qid` instead of `group` on dask interface."
if not (group is None and eval_group is None):
raise ValueError(msg)
if qid is None:
raise ValueError("`qid` is required for ranking.")
return self.client.sync(
self._fit_async,
X=X,
y=y,
qid=qid,
sample_weight=sample_weight,
base_margin=base_margin,
eval_set=eval_set,
sample_weight_eval_set=sample_weight_eval_set,
eval_qid=eval_qid,
eval_metric=eval_metric,
early_stopping_rounds=early_stopping_rounds,
verbose=verbose,
xgb_model=xgb_model,
feature_weights=feature_weights,
callbacks=callbacks,
)
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
fit.__doc__ = XGBRanker.fit.__doc__

View File

@ -7,6 +7,7 @@ import json
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .core import Metric
from .training import train
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
@ -132,6 +133,10 @@ __model_doc = '''
importance_type: string, default "gain"
The feature importance type for the feature_importances\\_ property:
either "gain", "weight", "cover", "total_gain" or "total_cover".
gpu_id :
Device ordinal.
validate_parameters :
Give warnings for unknown parameter.
\\*\\*kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of
@ -211,20 +216,41 @@ Parameters
['estimators', 'model', 'objective'])
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
def __init__(self, max_depth=None, learning_rate=None, n_estimators=100,
verbosity=None, objective=None, booster=None,
tree_method=None, n_jobs=None, gamma=None,
min_child_weight=None, max_delta_step=None, subsample=None,
colsample_bytree=None, colsample_bylevel=None,
colsample_bynode=None, reg_alpha=None, reg_lambda=None,
scale_pos_weight=None, base_score=None, random_state=None,
missing=np.nan, num_parallel_tree=None,
monotone_constraints=None, interaction_constraints=None,
importance_type="gain", gpu_id=None,
validate_parameters=None, **kwargs):
def __init__(
self,
max_depth=None,
learning_rate=None,
n_estimators=100,
verbosity=None,
objective=None,
booster=None,
tree_method=None,
n_jobs=None,
gamma=None,
min_child_weight=None,
max_delta_step=None,
subsample=None,
colsample_bytree=None,
colsample_bylevel=None,
colsample_bynode=None,
reg_alpha=None,
reg_lambda=None,
scale_pos_weight=None,
base_score=None,
random_state=None,
missing=np.nan,
num_parallel_tree=None,
monotone_constraints=None,
interaction_constraints=None,
importance_type="gain",
gpu_id=None,
validate_parameters=None,
**kwargs
):
if not SKLEARN_INSTALLED:
raise XGBoostError(
'sklearn needs to be installed in order to use this module')
"sklearn needs to be installed in order to use this module"
)
self.n_estimators = n_estimators
self.objective = objective
@ -255,11 +281,23 @@ class XGBModel(XGBModelBase):
self.gpu_id = gpu_id
self.validate_parameters = validate_parameters
def _wrap_evaluation_matrices(self, X, y, group,
sample_weight, base_margin, feature_weights,
eval_set, sample_weight_eval_set, eval_group,
label_transform=lambda x: x):
'''Convert array_like evaluation matrices into DMatrix'''
def _wrap_evaluation_matrices(
self, X, y,
group,
qid,
sample_weight,
base_margin,
feature_weights,
eval_set,
sample_weight_eval_set,
eval_group,
eval_qid,
label_transform=lambda x: x
) -> Tuple[DMatrix, Optional[List[Tuple[DMatrix, str]]]]:
'''Convert array_like evaluation matrices into DMatrix. `group` and `qid` are
only used for validation and are not set in this function.
'''
if sample_weight_eval_set is not None:
assert eval_set is not None
assert len(sample_weight_eval_set) == len(eval_set)
@ -271,26 +309,32 @@ class XGBModel(XGBModelBase):
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
train_dmatrix.set_info(feature_weights=feature_weights, group=group)
train_dmatrix.set_info(feature_weights=feature_weights)
if eval_set is not None:
if sample_weight_eval_set is None:
sample_weight_eval_set = [None] * len(eval_set)
if eval_group is None:
eval_group = [None] * len(eval_set)
if eval_qid is None:
eval_qid = [None] * len(eval_set)
evals = []
for i, (valid_X, valid_y) in enumerate(eval_set):
# Skip the duplicated entry.
if valid_X is X and valid_y is y and \
sample_weight_eval_set[i] is sample_weight and eval_group[i] is group:
if (
valid_X is X and
valid_y is y and
sample_weight_eval_set[i] is sample_weight and
eval_group[i] is group and
eval_qid[i] is qid
):
evals.append(train_dmatrix)
else:
m = DMatrix(valid_X,
label=label_transform(valid_y),
missing=self.missing, weight=sample_weight_eval_set[i],
nthread=self.n_jobs)
m.set_info(group=eval_group[i])
evals.append(m)
nevals = len(evals)
@ -515,7 +559,7 @@ class XGBModel(XGBModelBase):
booster: Optional[Booster],
eval_metric: Optional[Union[Callable, str, List[str]]],
params: Dict[str, Any],
) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]:
) -> Tuple[Booster, Optional[Metric], Dict[str, Any]]:
# pylint: disable=protected-access, no-self-use
model = booster
if hasattr(model, '_Booster'):
@ -602,8 +646,6 @@ class XGBModel(XGBModelBase):
save_best=True)]
"""
self.n_features_in_ = X.shape[1]
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
base_margin=base_margin,
missing=self.missing,
@ -613,9 +655,12 @@ class XGBModel(XGBModelBase):
evals_result = {}
train_dmatrix, evals = self._wrap_evaluation_matrices(
X, y, group=None, sample_weight=sample_weight, base_margin=base_margin,
X, y, group=None, qid=None, sample_weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights, eval_set=eval_set,
sample_weight_eval_set=sample_weight_eval_set, eval_group=None)
sample_weight_eval_set=sample_weight_eval_set,
eval_group=None, eval_qid=None
)
params = self.get_xgb_params()
if callable(self.objective):
@ -640,10 +685,6 @@ class XGBModel(XGBModelBase):
evals_result_key]
self.evals_result_ = evals_result
if early_stopping_rounds is not None:
self.best_score = self._Booster.best_score
self.best_iteration = self._Booster.best_iteration
self.best_ntree_limit = self._Booster.best_ntree_limit
return self
def predict(self, data, output_margin=False, ntree_limit=None,
@ -663,16 +704,20 @@ class XGBModel(XGBModelBase):
Parameters
----------
data : numpy.array/scipy.sparse
data : array_like
Data to predict with
output_margin : bool
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if defined
(i.e. it has been trained with early stopping), otherwise 0 (use all trees).
Limit number of trees in the prediction; defaults to best_ntree_limit if
defined (i.e. it has been trained with early stopping), otherwise 0 (use all
trees).
validate_features : bool
When this is True, validate that the Booster's and data's feature_names are identical.
Otherwise, it is assumed that the feature_names are the same.
base_margin : array_like
Margin added to prediction.
Returns
-------
prediction : numpy array
@ -754,6 +799,32 @@ class XGBModel(XGBModelBase):
return evals_result
@property
def n_features_in_(self) -> int:
booster = self.get_booster()
return booster.num_features()
def _early_stopping_attr(self, attr: str) -> Union[float, int]:
booster = self.get_booster()
try:
return getattr(booster, attr)
except AttributeError as e:
raise AttributeError(
f'`{attr}` is only defined when early stopping is used.'
) from e
@property
def best_score(self) -> float:
return float(self._early_stopping_attr('best_score'))
@property
def best_iteration(self) -> int:
return int(self._early_stopping_attr('best_iteration'))
@property
def best_ntree_limit(self) -> int:
return int(self._early_stopping_attr('best_ntree_limit'))
@property
def feature_importances_(self):
"""
@ -941,14 +1012,17 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
raise ValueError(
'Please reshape the input data X into 2-dimensional matrix.')
self._features_count = X.shape[1]
self.n_features_in_ = self._features_count
train_dmatrix, evals = self._wrap_evaluation_matrices(
X, y, group=None, sample_weight=sample_weight, base_margin=base_margin,
X, y, group=None, qid=None,
sample_weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set,
eval_group=None, label_transform=label_transform)
eval_set=eval_set,
sample_weight_eval_set=sample_weight_eval_set,
eval_group=None,
eval_qid=None,
label_transform=label_transform
)
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(),
@ -968,11 +1042,6 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
evals_result_key] = val[1][evals_result_key]
self.evals_result_ = evals_result
if early_stopping_rounds is not None:
self.best_score = self._Booster.best_score
self.best_iteration = self._Booster.best_iteration
self.best_ntree_limit = self._Booster.best_ntree_limit
return self
fit.__doc__ = XGBModel.fit.__doc__.replace(
@ -1058,6 +1127,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
validate_features : bool
When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : array_like
Margin added to prediction.
Returns
-------
@ -1198,11 +1269,12 @@ class XGBRFRegressor(XGBRegressor):
Note
----
Query group information is required for ranking tasks.
Query group information is required for ranking tasks by either using the `group`
parameter or `qid` parameter in `fit` method.
Before fitting the model, your data need to be sorted by query
group. When fitting the model, you need to provide an additional array
that contains the size of each query group.
Before fitting the model, your data need to be sorted by query group. When fitting
the model, you need to provide an additional array that contains the size of each
query group.
For example, if your original data look like:
@ -1224,7 +1296,8 @@ class XGBRFRegressor(XGBRegressor):
| 2 | 1 | x_7 |
+-------+-----------+---------------+
then your group array should be ``[3, 4]``.
then your group array should be ``[3, 4]``. Sometimes using query id (`qid`)
instead of group can be more convenient.
''')
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@ -1238,11 +1311,23 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
raise ValueError("please use XGBRanker for ranking task")
@_deprecate_positional_args
def fit(self, X, y, *, group, sample_weight=None, base_margin=None,
eval_set=None, sample_weight_eval_set=None,
eval_group=None, eval_metric=None,
early_stopping_rounds=None, verbose=False, xgb_model=None,
feature_weights=None, callbacks=None):
def fit(
self, X, y, *,
group=None,
qid=None,
sample_weight=None,
base_margin=None,
eval_set=None,
sample_weight_eval_set=None,
eval_group=None,
eval_qid=None,
eval_metric=None,
early_stopping_rounds=None,
verbose=False,
xgb_model=None,
feature_weights=None,
callbacks=None
) -> "XGBRanker":
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""Fit gradient boosting ranker
@ -1257,17 +1342,21 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
y : array_like
Labels
group : array_like
Size of each query group of training data. Should have as many
elements as the query groups in the training data
Size of each query group of training data. Should have as many elements as the
query groups in the training data. If this is set to None, then user must
provide qid.
qid : array_like
Query ID for each training sample. Should have the size of n_samples. If
this is set to None, then user must provide group.
sample_weight : array_like
Query group weights
.. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each query group
(not each data point). This is because we only care about the
relative ordering of data points within each group, so it
doesn't make sense to assign weights to individual data points.
In ranking task, one weight is assigned to each query group/id (not each
data point). This is because we only care about the relative ordering of
data points within each group, so it doesn't make sense to assign weights
to individual data points.
base_margin : array_like
Global bias for each instance.
@ -1289,6 +1378,9 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
eval_group : list of arrays, optional
A list in which ``eval_group[i]`` is the list containing the sizes of all
query groups in the ``i``-th pair in **eval_set**.
eval_qid : list of array_like, optional
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
pair in **eval_set**.
eval_metric : str, list of str, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.rst.
@ -1333,23 +1425,60 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
raise ValueError("group is required for ranking task")
if eval_set is not None:
if eval_group is None:
if eval_group is None and eval_qid is None:
raise ValueError(
"eval_group is required if eval_set is not None")
if len(eval_group) != len(eval_set):
"eval_group or eval_qid is required if eval_set is not None")
if (
(eval_group is not None and len(eval_group) != len(eval_set)) and
(eval_qid is not None and len(eval_qid) != len(eval_set))
):
raise ValueError(
"length of eval_group should match that of eval_set")
if any(group is None for group in eval_group):
raise ValueError(
"group is required for all eval datasets for ranking task")
self.n_features_in_ = X.shape[1]
"length of eval_group or eval_qid should match that of eval_set"
)
for i in range(len(eval_set)):
if (
(eval_group is not None and eval_group[i] is not None) and
(eval_qid is not None and eval_qid[i] is not None)
):
raise ValueError(
"Only one of the qid or group should be provided."
)
train_dmatrix, evals = self._wrap_evaluation_matrices(
X, y, group=group, sample_weight=sample_weight, base_margin=base_margin,
feature_weights=feature_weights, eval_set=eval_set,
X, y,
group=group,
qid=qid,
sample_weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
eval_set=eval_set,
sample_weight_eval_set=sample_weight_eval_set,
eval_group=eval_group)
eval_group=eval_group,
eval_qid=eval_qid
)
if qid is not None:
train_dmatrix.set_info(qid=qid)
elif group is not None:
train_dmatrix.set_info(group=group)
else:
raise ValueError("Either qid or group should be provided for ranking task.")
if evals is not None:
for i, e in enumerate(evals):
if eval_qid is not None and eval_qid[i] is not None:
assert eval_group is None or eval_group[i] is None, (
'Only one of the eval_qid or eval_group for each evaluation '
'dataset should be provided.'
)
e[0].set_info(qid=eval_qid[i])
elif eval_group is not None and eval_group[i] is not None:
e[0].set_info(group=eval_group[i])
else:
raise ValueError(
'Either eval_qid or eval_group for each evaluation dataset should'
' be provided for ranking task.'
)
evals_result = {}
params = self.get_xgb_params()
@ -1380,11 +1509,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
evals_result[val[0]][evals_result_key] = val[1][evals_result_key]
self.evals_result = evals_result
if early_stopping_rounds is not None:
self.best_score = self._Booster.best_score
self.best_iteration = self._Booster.best_iteration
self.best_ntree_limit = self._Booster.best_ntree_limit
return self
def predict(self, data, output_margin=False,
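
For comparison, a minimal single-node sketch with the updated scikit-learn wrapper (hypothetical synthetic data; pass either `group` or `qid`, not both):

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    X = rng.randn(100, 10)
    y = rng.randint(0, 5, size=100)
    # Sorted query ids replace the per-group size array, e.g. qid [0, 0, 1, 1, 1]
    # carries the same information as group [2, 3].
    qid = np.sort(rng.randint(0, 10, size=100).astype(np.uint32))

    ranker = xgb.XGBRanker(tree_method="hist", objective="rank:pairwise")
    ranker.fit(X, y, qid=qid)
    scores = ranker.predict(X)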

View File

@ -111,6 +111,7 @@ def _train_internal(params, dtrain,
)
else:
raise ValueError(f'Unknown booster: {booster}')
num_groups = int(config['learner']['learner_model_param']['num_class'])
num_groups = 1 if num_groups == 0 else num_groups
bst.best_ntree_limit = ((bst.best_iteration + 1) * num_parallel_tree * num_groups)

View File

@ -498,6 +498,7 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
static_cast<Learner*>(handle)->Configure();
*out = static_cast<Learner*>(handle)->GetNumFeature();
API_END();
}

View File

@ -374,13 +374,32 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
} else if (!std::strcmp(key, "group")) {
group_ptr_.resize(num + 1);
group_ptr_.clear(); group_ptr_.resize(num + 1, 0);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1));
group_ptr_[0] = 0;
for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
}
} else if (!std::strcmp(key, "qid")) {
std::vector<uint32_t> query_ids(num, 0);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, query_ids.begin()));
bool non_dec = true;
for (size_t i = 1; i < query_ids.size(); ++i) {
if (query_ids[i] < query_ids[i-1]) {
non_dec = false;
break;
}
}
CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data.";
group_ptr_.clear(); group_ptr_.push_back(0);
for (size_t i = 1; i < query_ids.size(); ++i) {
if (query_ids[i] != query_ids[i-1]) {
group_ptr_.push_back(i);
}
}
group_ptr_.push_back(query_ids.size());
} else if (!std::strcmp(key, "label_lower_bound")) {
auto& labels = labels_lower_bound_.HostVector();
labels.resize(num);
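
The `qid` branch above is effectively a run-length encoding of the sorted query ids into the cumulative `group_ptr_` array. A rough Python equivalent of the CPU logic (illustrative helper, not part of the library):

    import numpy as np

    def qid_to_group_ptr(qid: np.ndarray) -> np.ndarray:
        """Turn a sorted qid column into cumulative group boundaries."""
        if np.any(qid[1:] < qid[:-1]):
            raise ValueError(
                "`qid` must be sorted in non-decreasing order along with data.")
        # Positions where a new query starts, bracketed by the two end points.
        starts = np.flatnonzero(qid[1:] != qid[:-1]) + 1
        return np.concatenate(([0], starts, [qid.size]))

    # qid [0, 0, 1, 1, 1, 3] -> group_ptr [0, 2, 5, 6]
    print(qid_to_group_ptr(np.array([0, 0, 1, 1, 1, 3], dtype=np.uint32)))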

View File

@ -34,16 +34,20 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
});
}
namespace {
auto SetDeviceToPtr(void *ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
}
} // anonymous namespace
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
<< "Expected integer metainfo";
auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
};
auto ptr_device = SetDeviceToPtr(column.data);
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
auto d_tmp = temp.data();
@ -95,6 +99,47 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else if (key == "group") {
CopyGroupInfoImpl(array_interface, &group_ptr_);
return;
} else if (key == "qid") {
auto it = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0ul),
[array_interface] __device__(size_t i) {
return static_cast<uint32_t>(array_interface.GetElement(i));
});
dh::caching_device_vector<bool> flag(1);
auto d_flag = dh::ToSpan(flag);
auto d = SetDeviceToPtr(array_interface.data);
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
if (static_cast<uint32_t>(array_interface.GetElement(i)) >
static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
d_flag[0] = false;
}
});
bool non_dec = true;
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
cudaMemcpyDeviceToHost));
CHECK(non_dec)
<< "`qid` must be sorted in increasing order along with data.";
size_t bytes = 0;
dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
HostDeviceVector<int> d_num_runs_out(1, 0, d);
cub::DeviceRunLengthEncode::Encode(nullptr, bytes, it, out.begin(),
cnt.begin(), d_num_runs_out.DevicePointer(),
array_interface.num_rows);
dh::caching_device_vector<char> tmp(bytes);
cub::DeviceRunLengthEncode::Encode(tmp.data().get(), bytes, it, out.begin(),
cnt.begin(), d_num_runs_out.DevicePointer(),
array_interface.num_rows);
auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
group_ptr_.clear(); group_ptr_.resize(h_num_runs_out + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1);
return;
} else if (key == "label_lower_bound") {
CopyInfoImpl(array_interface, &labels_lower_bound_);
return;

View File

@ -436,7 +436,7 @@ class LearnerConfiguration : public Learner {
}
}
uint32_t GetNumFeature() override {
uint32_t GetNumFeature() const override {
return learner_model_param_.num_feature;
}

View File

@ -63,7 +63,7 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
template <typename T>
Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
thrust::device_vector<T>* p_data) {
thrust::device_vector<T> *p_data) {
auto& data = *p_data;
thrust::sequence(data.begin(), data.end());

View File

@ -202,6 +202,24 @@ TEST(MetaInfo, LoadQid) {
}
}
TEST(MetaInfo, CPUQid) {
xgboost::MetaInfo info;
info.num_row_ = 100;
std::vector<uint32_t> qid(info.num_row_, 0);
for (size_t i = 0; i < qid.size(); ++i) {
qid[i] = i;
}
info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
ASSERT_EQ(info.group_ptr_.front(), 0);
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
for (size_t i = 0; i < info.num_row_ + 1; ++i) {
ASSERT_EQ(info.group_ptr_[i], i);
}
}
TEST(MetaInfo, Validate) {
xgboost::MetaInfo info;
info.num_row_ = 10;

View File

@ -4,6 +4,7 @@
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include "test_array_interface.h"
#include "../../../src/common/device_helpers.cuh"
namespace xgboost {
@ -105,6 +106,28 @@ TEST(MetaInfo, Group) {
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
}
TEST(MetaInfo, GPUQid) {
xgboost::MetaInfo info;
info.num_row_ = 100;
thrust::device_vector<uint32_t> qid(info.num_row_, 0);
for (size_t i = 0; i < qid.size(); ++i) {
qid[i] = i;
}
auto column = Generate2dArrayInterface(info.num_row_, 1, "<u4", &qid);
Json array{std::vector<Json>{column}};
std::string array_str;
Json::Dump(array, &array_str);
info.SetInfo("qid", array_str.c_str());
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
ASSERT_EQ(info.group_ptr_.front(), 0);
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
for (size_t i = 0; i < info.num_row_ + 1; ++i) {
ASSERT_EQ(info.group_ptr_[i], i);
}
}
TEST(MetaInfo, DeviceExtend) {
dh::safe_cuda(cudaSetDevice(0));
size_t const kRows = 100;

View File

@ -171,6 +171,22 @@ Arrow specification.'''
with pytest.raises(xgb.core.XGBoostError):
m.slice(rindex=[0, 1, 2])
@pytest.mark.skipif(**tm.no_cupy())
def test_qid(self):
import cupy as cp
rng = cp.random.RandomState(1994)
rows = 100
cols = 10
X, y = rng.randn(rows, cols), rng.randn(rows)
qid = rng.randint(low=0, high=10, size=rows, dtype=np.uint32)
qid = cp.sort(qid)
Xy = xgb.DMatrix(X, y)
Xy.set_info(qid=qid)
group_ptr = Xy.get_uint_info('group_ptr')
assert group_ptr[0] == 0
assert group_ptr[-1] == rows
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_specified_device(self):

View File

@ -239,6 +239,19 @@ class TestDMatrix:
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('group_ptr')
def test_qid(self):
rows = 100
cols = 10
X, y = rng.randn(rows, cols), rng.randn(rows)
qid = rng.randint(low=0, high=10, size=rows, dtype=np.uint32)
qid = np.sort(qid)
Xy = xgb.DMatrix(X, y)
Xy.set_info(qid=qid)
group_ptr = Xy.get_uint_info('group_ptr')
assert group_ptr[0] == 0
assert group_ptr[-1] == rows
def test_feature_weights(self):
kRows = 10
kCols = 50

View File

@ -1,5 +1,6 @@
import numpy as np
from scipy.sparse import csr_matrix
import testing as tm
import xgboost
import os
import itertools
@ -79,22 +80,10 @@ class TestRanking:
"""
Download and setup the test fixtures
"""
from sklearn.datasets import load_svmlight_files
# download the test data
cls.dpath = 'demo/rank/'
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
target = cls.dpath + '/MQ2008.zip'
urllib.request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, 'r') as f:
f.extractall(path=cls.dpath)
(x_train, y_train, qid_train, x_test, y_test, qid_test,
x_valid, y_valid, qid_valid) = load_svmlight_files(
(cls.dpath + "MQ2008/Fold1/train.txt",
cls.dpath + "MQ2008/Fold1/test.txt",
cls.dpath + "MQ2008/Fold1/vali.txt"),
query_id=True, zero_based=False)
x_valid, y_valid, qid_valid) = tm.get_mq2008(cls.dpath)
# instantiate the matrices
cls.dtrain = xgboost.DMatrix(x_train, y_train)
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)

View File

@ -5,6 +5,7 @@ import pytest
import xgboost as xgb
import sys
import numpy as np
import scipy
import json
from typing import List, Tuple, Dict, Optional, Type, Any
import asyncio
@ -670,12 +671,56 @@ def run_aft_survival(client: "Client", dmatrix_t: Type) -> None:
assert nloglik_rec['extreme'][-1] > 4.9
def test_aft_survival() -> None:
def test_dask_aft_survival() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
run_aft_survival(client, DaskDMatrix)
def test_dask_ranking(client: "Client") -> None:
dpath = "demo/rank/"
mq2008 = tm.get_mq2008(dpath)
data = []
for d in mq2008:
if isinstance(d, scipy.sparse.csr_matrix):
d[d == 0] = np.inf
d = d.toarray()
d[d == 0] = np.nan
d[np.isinf(d)] = 0
data.append(da.from_array(d))
else:
data.append(da.from_array(d))
(
x_train,
y_train,
qid_train,
x_test,
y_test,
qid_test,
x_valid,
y_valid,
qid_valid,
) = data
qid_train = qid_train.astype(np.uint32)
qid_valid = qid_valid.astype(np.uint32)
qid_test = qid_test.astype(np.uint32)
rank = xgb.dask.DaskXGBRanker(n_estimators=2500)
rank.fit(
x_train,
y_train,
qid=qid_train,
eval_set=[(x_test, y_test), (x_train, y_train)],
eval_qid=[qid_test, qid_train],
eval_metric=["ndcg"],
verbose=True,
early_stopping_rounds=10,
)
assert rank.n_features_in_ == 46
assert rank.best_score > 0.98
class TestWithDask:
def test_global_config(self, client: "Client") -> None:
X, y, _ = generate_array()
@ -981,7 +1026,7 @@ class TestWithDask:
def test_shap(self, client: "Client") -> None:
from sklearn.datasets import load_boston, load_digits
X, y = load_boston(return_X_y=True)
params = {'objective': 'reg:squarederror'}
params: Dict[str, Any] = {'objective': 'reg:squarederror'}
self.run_shap(X, y, params, client)
X, y = load_digits(return_X_y=True)

View File

@ -125,9 +125,11 @@ def test_ranking():
x_train = np.random.rand(1000, 10)
y_train = np.random.randint(5, size=1000)
train_group = np.repeat(50, 20)
x_valid = np.random.rand(200, 10)
y_valid = np.random.randint(5, size=200)
valid_group = np.repeat(50, 4)
x_test = np.random.rand(100, 10)
params = {'tree_method': 'exact', 'objective': 'rank:pairwise',
@ -136,6 +138,7 @@ def test_ranking():
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, group=train_group,
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
pred = model.predict(x_test)
train_data = xgb.DMatrix(x_train, y_train)

View File

@ -1,5 +1,7 @@
# coding: utf-8
import os
import urllib
import zipfile
import sys
from contextlib import contextmanager
from io import StringIO
@ -209,6 +211,29 @@ def get_sparse():
return X, y
@memory.cache
def get_mq2008(dpath):
from sklearn.datasets import load_svmlight_files
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
target = dpath + '/MQ2008.zip'
if not os.path.exists(target):
urllib.request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, 'r') as f:
f.extractall(path=dpath)
(x_train, y_train, qid_train, x_test, y_test, qid_test,
x_valid, y_valid, qid_valid) = load_svmlight_files(
(dpath + "MQ2008/Fold1/train.txt",
dpath + "MQ2008/Fold1/test.txt",
dpath + "MQ2008/Fold1/vali.txt"),
query_id=True, zero_based=False)
return (x_train, y_train, qid_train, x_test, y_test, qid_test,
x_valid, y_valid, qid_valid)
_unweighted_datasets_strategy = strategies.sampled_from(
[TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),