Support sklearn cross validation for ranker. (#8859)
* Support sklearn cross validation for ranker. - Add a convention for X to include a special `qid` column. sklearn utilities consider only `X`, `y` and `sample_weight` for supervised learning algorithms, but we need an additional qid array for ranking. It's important to be able to support the cross validation function in sklearn since all other tuning functions like grid search are based on cross validation.
This commit is contained in:
parent
cad7401783
commit
7eba285a1e
@ -23,7 +23,13 @@ from typing import (
|
|||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from . import collective
|
from . import collective
|
||||||
from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
|
from .core import (
|
||||||
|
Booster,
|
||||||
|
DMatrix,
|
||||||
|
XGBoostError,
|
||||||
|
_get_booster_layer_trees,
|
||||||
|
_parse_eval_str,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingCallback",
|
"TrainingCallback",
|
||||||
@ -250,11 +256,7 @@ class CallbackContainer:
|
|||||||
for _, name in evals:
|
for _, name in evals:
|
||||||
assert name.find("-") == -1, "Dataset name should not contain `-`"
|
assert name.find("-") == -1, "Dataset name should not contain `-`"
|
||||||
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
|
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
|
||||||
splited = score.split()[1:] # into datasets
|
metric_score = _parse_eval_str(score)
|
||||||
# split up `test-error:0.1234`
|
|
||||||
metric_score_str = [tuple(s.split(":")) for s in splited]
|
|
||||||
# convert to float
|
|
||||||
metric_score = [(n, float(s)) for n, s in metric_score_str]
|
|
||||||
self._update_history(metric_score, epoch)
|
self._update_history(metric_score, epoch)
|
||||||
ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
|
ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@ -231,7 +231,7 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid
|
|||||||
if buf.base is data.base:
|
if buf.base is data.base:
|
||||||
buf = buf.copy()
|
buf = buf.copy()
|
||||||
if buf.dtype not in DTYPE_ENUM__:
|
if buf.dtype not in DTYPE_ENUM__:
|
||||||
raise Exception(f"data type {buf.dtype} not supported")
|
raise TypeError(f"data type {buf.dtype} not supported")
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGCommunicatorAllreduce(
|
_LIB.XGCommunicatorAllreduce(
|
||||||
buf.ctypes.data_as(ctypes.c_void_p),
|
buf.ctypes.data_as(ctypes.c_void_p),
|
||||||
|
|||||||
@ -111,6 +111,16 @@ def make_jcargs(**kwargs: Any) -> bytes:
|
|||||||
return from_pystr_to_cstr(json.dumps(kwargs))
|
return from_pystr_to_cstr(json.dumps(kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_eval_str(result: str) -> List[Tuple[str, float]]:
|
||||||
|
"""Parse an eval result string from the booster."""
|
||||||
|
splited = result.split()[1:]
|
||||||
|
# split up `test-error:0.1234`
|
||||||
|
metric_score_str = [tuple(s.split(":")) for s in splited]
|
||||||
|
# convert to float
|
||||||
|
metric_score = [(n, float(s)) for n, s in metric_score_str]
|
||||||
|
return metric_score
|
||||||
|
|
||||||
|
|
||||||
IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])
|
IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -136,7 +136,7 @@ def allreduce( # pylint:disable=invalid-name
|
|||||||
"""
|
"""
|
||||||
if prepare_fun is None:
|
if prepare_fun is None:
|
||||||
return collective.allreduce(data, collective.Op(op))
|
return collective.allreduce(data, collective.Op(op))
|
||||||
raise Exception("preprocessing function is no longer supported")
|
raise ValueError("preprocessing function is no longer supported")
|
||||||
|
|
||||||
|
|
||||||
def version_number() -> int:
|
def version_number() -> int:
|
||||||
|
|||||||
@ -43,8 +43,9 @@ from .core import (
|
|||||||
XGBoostError,
|
XGBoostError,
|
||||||
_convert_ntree_limit,
|
_convert_ntree_limit,
|
||||||
_deprecate_positional_args,
|
_deprecate_positional_args,
|
||||||
|
_parse_eval_str,
|
||||||
)
|
)
|
||||||
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
|
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
|
||||||
from .training import train
|
from .training import train
|
||||||
|
|
||||||
|
|
||||||
@ -1812,32 +1813,43 @@ class XGBRFRegressor(XGBRegressor):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def _get_qid(
    X: ArrayLike, qid: Optional[ArrayLike]
) -> Tuple[ArrayLike, Optional[ArrayLike]]:
    """Get the special qid column from X if exists.

    Parameters
    ----------
    X :
        Input features. Only pandas/cuDF data frames are inspected for a
        ``qid`` column; any other input is passed through untouched.
    qid :
        Query ids supplied separately by the caller, or None.

    Returns
    -------
    A ``(X, qid)`` pair. When `X` carries a ``qid`` column, the returned `X`
    has that column dropped and the returned `qid` is the column itself.
    Otherwise both inputs are returned unchanged.

    Raises
    ------
    ValueError
        If `X` has a ``qid`` column and `qid` is also given.

    """
    if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"):
        if qid is not None:
            # The trailing space matters: adjacent literals are joined verbatim.
            raise ValueError(
                "Found both the special column `qid` in `X` and the `qid` from the "
                "`fit` method. Please remove one of them."
            )
        q_x = X.qid
        # `drop` returns a new frame; the caller's data frame is not modified.
        X = X.drop("qid", axis=1)
        return X, q_x
    return X, qid
|
||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
|
"""Implementation of the Scikit-Learn API for XGBoost Ranking.""",
|
||||||
["estimators", "model"],
|
["estimators", "model"],
|
||||||
end_note="""
|
end_note="""
|
||||||
.. note::
|
|
||||||
|
|
||||||
The default objective for XGBRanker is "rank:pairwise"
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
A custom objective function is currently not supported by XGBRanker.
|
A custom objective function is currently not supported by XGBRanker.
|
||||||
Likewise, a custom metric function is not supported either.
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Query group information is required for ranking tasks by either using the
|
Query group information is only required for ranking training but not
|
||||||
`group` parameter or `qid` parameter in `fit` method. This information is
|
prediction. Multiple groups can be predicted on a single call to
|
||||||
not required in 'predict' method and multiple groups can be predicted on
|
:py:meth:`predict`.
|
||||||
a single call to `predict`.
|
|
||||||
|
|
||||||
When fitting the model with the `group` parameter, your data need to be sorted
|
When fitting the model with the `group` parameter, your data need to be sorted
|
||||||
by query group first. `group` must be an array that contains the size of each
|
by the query group first. `group` is an array that contains the size of each
|
||||||
query group.
|
query group.
|
||||||
When fitting the model with the `qid` parameter, your data does not need
|
|
||||||
sorting. `qid` must be an array that contains the group of each training
|
Similarly, when fitting the model with the `qid` parameter, the data should be
|
||||||
sample.
|
sorted according to query index and `qid` is an array that contains the query
|
||||||
|
index for each training sample.
|
||||||
|
|
||||||
For example, if your original data look like:
|
For example, if your original data look like:
|
||||||
|
|
||||||
@ -1859,9 +1871,10 @@ class XGBRFRegressor(XGBRegressor):
|
|||||||
| 2 | 1 | x_7 |
|
| 2 | 1 | x_7 |
|
||||||
+-------+-----------+---------------+
|
+-------+-----------+---------------+
|
||||||
|
|
||||||
then `fit` method can be called with either `group` array as ``[3, 4]``
|
then :py:meth:`fit` method can be called with either `group` array as ``[3, 4]``
|
||||||
or with `qid` as ``[`1, 1, 1, 2, 2, 2, 2]``, that is the qid column.
|
or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is the qid column. Also, the
|
||||||
""",
|
`qid` can be a special column of input `X` instead of a separate parameter, see
|
||||||
|
:py:meth:`fit` for more info.""",
|
||||||
)
|
)
|
||||||
class XGBRanker(XGBModel, XGBRankerMixIn):
|
class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||||
@ -1873,6 +1886,16 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
if "rank:" not in objective:
|
if "rank:" not in objective:
|
||||||
raise ValueError("please use XGBRanker for ranking task")
|
raise ValueError("please use XGBRanker for ranking task")
|
||||||
|
|
||||||
|
def _create_ltr_dmatrix(
    self, ref: Optional[DMatrix], data: ArrayLike, qid: ArrayLike, **kwargs: Any
) -> DMatrix:
    # Build a DMatrix for learning to rank. The query ids may come either from
    # the `qid` argument or from a special `qid` column inside `data`;
    # `_get_qid` resolves the two and strips the column from the features.
    features, query_id = _get_qid(data, qid)

    # Ranking needs query information from at least one source.
    has_group = kwargs.get("group", None) is not None
    has_qid = query_id is not None
    if not (has_group or has_qid):
        raise ValueError("Either `group` or `qid` is required for ranking task")

    return super()._create_dmatrix(ref=ref, data=features, qid=query_id, **kwargs)
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
@ -1907,6 +1930,23 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
X :
|
X :
|
||||||
Feature matrix. See :ref:`py-data` for a list of supported types.
|
Feature matrix. See :ref:`py-data` for a list of supported types.
|
||||||
|
|
||||||
|
When this is a :py:class:`pandas.DataFrame` or a :py:class:`cudf.DataFrame`,
|
||||||
|
it may contain a special column called ``qid`` for specifying the query
|
||||||
|
index. Using a special column is the same as using the `qid` parameter,
|
||||||
|
except for being compatible with sklearn utility functions like
|
||||||
|
:py:func:`sklearn.model_selection.cross_validate`. The same convention
|
||||||
|
applies to the :py:meth:`XGBRanker.score` and :py:meth:`XGBRanker.predict`.
|
||||||
|
|
||||||
|
+-----+----------------+----------------+
|
||||||
|
| qid | feat_0 | feat_1 |
|
||||||
|
+-----+----------------+----------------+
|
||||||
|
| 0 | :math:`x_{00}` | :math:`x_{01}` |
|
||||||
|
+-----+----------------+----------------+
|
||||||
|
| 1 | :math:`x_{10}` | :math:`x_{11}` |
|
||||||
|
+-----+----------------+----------------+
|
||||||
|
| 1 | :math:`x_{20}` | :math:`x_{21}` |
|
||||||
|
+-----+----------------+----------------+
|
||||||
|
|
||||||
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
|
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
|
||||||
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
||||||
for conserving memory. However, this has performance implications when the
|
for conserving memory. However, this has performance implications when the
|
||||||
@ -1916,12 +1956,12 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
y :
|
y :
|
||||||
Labels
|
Labels
|
||||||
group :
|
group :
|
||||||
Size of each query group of training data. Should have as many elements as the
|
Size of each query group of training data. Should have as many elements as
|
||||||
query groups in the training data. If this is set to None, then user must
|
the query groups in the training data. If this is set to None, then user
|
||||||
provide qid.
|
must provide qid.
|
||||||
qid :
|
qid :
|
||||||
Query ID for each training sample. Should have the size of n_samples. If
|
Query ID for each training sample. Should have the size of n_samples. If
|
||||||
this is set to None, then user must provide group.
|
this is set to None, then user must provide group or a special column in X.
|
||||||
sample_weight :
|
sample_weight :
|
||||||
Query group weights
|
Query group weights
|
||||||
|
|
||||||
@ -1929,8 +1969,9 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
|
|
||||||
In ranking task, one weight is assigned to each query group/id (not each
|
In ranking task, one weight is assigned to each query group/id (not each
|
||||||
data point). This is because we only care about the relative ordering of
|
data point). This is because we only care about the relative ordering of
|
||||||
data points within each group, so it doesn't make sense to assign weights
|
data points within each group, so it doesn't make sense to assign
|
||||||
to individual data points.
|
weights to individual data points.
|
||||||
|
|
||||||
base_margin :
|
base_margin :
|
||||||
Global bias for each instance.
|
Global bias for each instance.
|
||||||
eval_set :
|
eval_set :
|
||||||
@ -1942,7 +1983,8 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
query groups in the ``i``-th pair in **eval_set**.
|
query groups in the ``i``-th pair in **eval_set**.
|
||||||
eval_qid :
|
eval_qid :
|
||||||
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
|
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
|
||||||
pair in **eval_set**.
|
pair in **eval_set**. The special column convention in `X` applies to
|
||||||
|
validation datasets as well.
|
||||||
|
|
||||||
eval_metric : str, list of str, optional
|
eval_metric : str, list of str, optional
|
||||||
.. deprecated:: 1.6.0
|
.. deprecated:: 1.6.0
|
||||||
@ -1985,16 +2027,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
|
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# check if group information is provided
|
|
||||||
with config_context(verbosity=self.verbosity):
|
with config_context(verbosity=self.verbosity):
|
||||||
if group is None and qid is None:
|
|
||||||
raise ValueError("group or qid is required for ranking task")
|
|
||||||
|
|
||||||
if eval_set is not None:
|
|
||||||
if eval_group is None and eval_qid is None:
|
|
||||||
raise ValueError(
|
|
||||||
"eval_group or eval_qid is required if eval_set is not None"
|
|
||||||
)
|
|
||||||
train_dmatrix, evals = _wrap_evaluation_matrices(
|
train_dmatrix, evals = _wrap_evaluation_matrices(
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
X=X,
|
X=X,
|
||||||
@ -2009,7 +2042,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
base_margin_eval_set=base_margin_eval_set,
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
eval_group=eval_group,
|
eval_group=eval_group,
|
||||||
eval_qid=eval_qid,
|
eval_qid=eval_qid,
|
||||||
create_dmatrix=self._create_dmatrix,
|
create_dmatrix=self._create_ltr_dmatrix,
|
||||||
enable_categorical=self.enable_categorical,
|
enable_categorical=self.enable_categorical,
|
||||||
feature_types=self.feature_types,
|
feature_types=self.feature_types,
|
||||||
)
|
)
|
||||||
@ -2044,3 +2077,59 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
|
|
||||||
self._set_evaluation_result(evals_result)
|
self._set_evaluation_result(evals_result)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def predict(
    self,
    X: ArrayLike,
    output_margin: bool = False,
    ntree_limit: Optional[int] = None,
    validate_features: bool = True,
    base_margin: Optional[ArrayLike] = None,
    iteration_range: Optional[Tuple[int, int]] = None,
) -> ArrayLike:
    # Query ids are not used for prediction: drop the special `qid` column (if
    # `X` is a data frame carrying one) so that only the feature columns are
    # forwarded to the parent implementation.
    features = _get_qid(X, None)[0]
    forwarded = (
        output_margin,
        ntree_limit,
        validate_features,
        base_margin,
        iteration_range,
    )
    return super().predict(features, *forwarded)
|
||||||
|
|
||||||
|
def apply(
    self,
    X: ArrayLike,
    ntree_limit: int = 0,
    iteration_range: Optional[Tuple[int, int]] = None,
) -> ArrayLike:
    # Same convention as `predict`: strip the special `qid` column (if any)
    # before delegating to the parent implementation.
    features = _get_qid(X, None)[0]
    return super().apply(features, ntree_limit, iteration_range)
|
||||||
|
|
||||||
|
def score(self, X: ArrayLike, y: ArrayLike) -> float:
    """Evaluate score for data using the last evaluation metric.

    Parameters
    ----------
    X : pd.DataFrame|cudf.DataFrame
      Feature matrix. A DataFrame with a special `qid` column.

    y :
      Labels

    Returns
    -------
    score :
      The result of the last evaluation metric for the ranker.

    """
    # Split the special `qid` column off the features; `qid` stays None when
    # `X` does not carry one.
    X, qid = _get_qid(X, None)
    Xyq = DMatrix(X, y, qid=qid)
    if callable(self.eval_metric):
        # A Python callable has to be wrapped and passed explicitly.
        metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
        result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric)
    else:
        result_str = self.get_booster().eval(Xyq)

    metric_score = _parse_eval_str(result_str)
    # Take the value of the final (name, value) pair in the evaluation string.
    return metric_score[-1][1]
|
||||||
|
|||||||
72
python-package/xgboost/testing/ranking.py
Normal file
72
python-package/xgboost/testing/ranking.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# pylint: disable=too-many-locals
|
||||||
|
"""Tests for learning to rank."""
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
from xgboost import testing as tm
|
||||||
|
|
||||||
|
|
||||||
|
def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
    """Test ranking with qid packed into X.

    Parameters
    ----------
    impl :
        Data frame module providing ``DataFrame``; the callers in the test
        suites pass ``pandas`` or ``cudf``.
    tree_method :
        Tree method forwarded to :py:class:`xgboost.XGBRanker`. The sklearn
        cross validation and sparse input checks are skipped for ``gpu_hist``.

    """
    import scipy.sparse
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import StratifiedGroupKFold, cross_val_score

    X, y, q, _ = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3)

    # pack qid into x using dataframe
    df = impl.DataFrame(X)
    df["qid"] = q
    ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method)
    ranker.fit(df, y)
    s = ranker.score(df, y)
    assert s > 0.7

    # works with validation datasets as well
    valid_df = df.copy()
    # perturb one feature value so the validation frame differs from training
    valid_df.iloc[0, 0] = 3.0
    ranker.fit(df, y, eval_set=[(valid_df, y)])

    # same as passing qid directly
    ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method)
    ranker.fit(X, y, qid=q)
    s1 = ranker.score(df, y)
    assert np.isclose(s, s1)

    # Works with standard sklearn cv
    if tree_method != "gpu_hist":
        # we need cuML for this.
        kfold = StratifiedGroupKFold(shuffle=False)
        results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
        # one score per fold; relies on the splitter's default number of splits
        assert len(results) == 5

    # Works with custom metric
    def neg_mse(*args: Any, **kwargs: Any) -> float:
        return -float(mean_squared_error(*args, **kwargs))

    ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
    ranker.fit(df, y, eval_set=[(valid_df, y)])
    score = ranker.score(valid_df, y)
    # `score` must agree with the last value recorded for the validation set
    assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])

    # Works with sparse data
    if tree_method != "gpu_hist":
        # no sparse with cuDF
        X_csr = scipy.sparse.csr_matrix(X)
        df = impl.DataFrame.sparse.from_spmatrix(
            X_csr, columns=[str(i) for i in range(X.shape[1])]
        )
        df["qid"] = q
        ranker = xgb.XGBRanker(
            n_estimators=3, eval_metric="ndcg", tree_method=tree_method
        )
        ranker.fit(df, y)
        s2 = ranker.score(df, y)
        assert np.isclose(s2, s)

    # a validation set without any query information must be rejected
    with pytest.raises(ValueError, match="Either `group` or `qid`."):
        ranker.fit(df, y, eval_set=[(X, y)])
|
||||||
@ -8,6 +8,7 @@ import pytest
|
|||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.ranking import run_ranking_qid_df
|
||||||
|
|
||||||
sys.path.append("tests/python")
|
sys.path.append("tests/python")
|
||||||
import test_with_sklearn as twskl # noqa
|
import test_with_sklearn as twskl # noqa
|
||||||
@ -153,3 +154,10 @@ def test_classififer():
|
|||||||
y *= 10
|
y *= 10
|
||||||
with pytest.raises(ValueError, match=r"Invalid classes.*"):
|
with pytest.raises(ValueError, match=r"Invalid classes.*"):
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
def test_ranking_qid_df():
    """Run the shared qid-in-data-frame ranking test with cuDF on the GPU."""
    # The pandas guard above says nothing about cuDF being installed; skip
    # (rather than error out with ImportError) when it is missing.
    cudf = pytest.importorskip("cudf")

    run_ranking_qid_df(cudf, "gpu_hist")
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from sklearn.utils.estimator_checks import parametrize_with_checks
|
|||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.ranking import run_ranking_qid_df
|
||||||
from xgboost.testing.shared import get_feature_weights, validate_data_initialization
|
from xgboost.testing.shared import get_feature_weights, validate_data_initialization
|
||||||
from xgboost.testing.updater import get_basescore
|
from xgboost.testing.updater import get_basescore
|
||||||
|
|
||||||
@ -180,6 +181,13 @@ def test_ranking_metric() -> None:
|
|||||||
assert results["validation_0"]["roc_auc_score"][-1] > 0.6
|
assert results["validation_0"]["roc_auc_score"][-1] > 0.6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
def test_ranking_qid_df():
    """Run the shared qid-in-data-frame ranking test with pandas and `hist`."""
    import pandas

    run_ranking_qid_df(pandas, "hist")
|
||||||
|
|
||||||
|
|
||||||
def test_stacking_regression():
|
def test_stacking_regression():
|
||||||
from sklearn.datasets import load_diabetes
|
from sklearn.datasets import load_diabetes
|
||||||
from sklearn.ensemble import RandomForestRegressor, StackingRegressor
|
from sklearn.ensemble import RandomForestRegressor, StackingRegressor
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user