merge latest changes
This commit is contained in:
@@ -15,18 +15,16 @@ class BuildConfiguration: # pylint: disable=R0902
|
||||
use_cuda: bool = False
|
||||
# Whether to enable NCCL
|
||||
use_nccl: bool = False
|
||||
# Whether to load nccl dynamically
|
||||
use_dlopen_nccl: bool = False
|
||||
# Whether to enable federated learning
|
||||
plugin_federated: bool = False
|
||||
# Whether to enable rmm support
|
||||
plugin_rmm: bool = False
|
||||
# Whether to enable HIP
|
||||
use_hip: bool = False
|
||||
# Whether to enable RCCL
|
||||
use_rccl: bool = False
|
||||
# Whether to enable HDFS
|
||||
use_hdfs: bool = False
|
||||
# Whether to enable Azure Storage
|
||||
use_azure: bool = False
|
||||
# Whether to enable AWS S3
|
||||
use_s3: bool = False
|
||||
# Whether to enable the dense parser plugin
|
||||
plugin_dense_parser: bool = False
|
||||
# Special option: See explanation below
|
||||
use_system_libxgboost: bool = False
|
||||
|
||||
|
||||
@@ -29,7 +29,8 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"numpy",
|
||||
"scipy"
|
||||
"scipy",
|
||||
"nvidia-nccl-cu12 ; platform_system == 'Linux' and platform_machine != 'aarch64'"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
import ctypes
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from enum import IntEnum, unique
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._typing import _T
|
||||
from .core import _LIB, _check_call, c_str, from_pystr_to_cstr, py_str
|
||||
from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.collective]")
|
||||
|
||||
@@ -250,6 +251,31 @@ class CommunicatorContext:
|
||||
|
||||
def __init__(self, **args: Any) -> None:
    """Store the communicator arguments and fill in the NCCL library path.

    When the caller did not supply ``dmlc_nccl_path`` and the loaded library
    reports that NCCL/RCCL is opened dynamically, the path is pointed at
    ``libnccl.so.2`` inside the PyPI ``nvidia.nccl`` package, if that package
    can be imported.

    Parameters
    ----------
    args :
        Keyword arguments later forwarded to ``init`` by ``__enter__``.
    """
    self.args = args
    key = "dmlc_nccl_path"
    # A path supplied explicitly by the caller always takes precedence.
    if args.get(key, None) is not None:
        return

    binfo = build_info()
    # A flag that the build info does not report is treated as disabled
    # instead of failing the construction with a KeyError.
    if not binfo.get("USE_DLOPEN_NCCL", False) and not binfo.get(
        "USE_DLOPEN_RCCL", False
    ):
        return

    try:
        # PyPI package of NCCL.
        from nvidia.nccl import lib
    except ImportError:
        # The package is optional; leave the path unset when it is missing.
        return

    # There are two versions of nvidia-nccl, one is from PyPI, another one from
    # nvidia-pyindex. We support only the first one as the second one is too old
    # (2.9.8 as of writing).
    if lib.__file__ is not None:
        dirname: Optional[str] = os.path.dirname(lib.__file__)
    else:
        dirname = None

    if dirname:
        self.args[key] = os.path.join(dirname, "libnccl.so.2")
|
||||
|
||||
def __enter__(self) -> Dict[str, Any]:
|
||||
init(**self.args)
|
||||
|
||||
@@ -184,6 +184,13 @@ def _py_version() -> str:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
def _register_log_callback(lib: ctypes.CDLL) -> None:
    """Register the Python log callback with the loaded XGBoost library.

    Parameters
    ----------
    lib :
        The loaded shared library handle.

    Raises
    ------
    XGBoostError
        If ``XGBRegisterLogCallback`` returns a non-zero status; the message
        is whatever ``XGBGetLastError`` reports.
    """
    # The error getter returns a C string, declare it before it can be called.
    lib.XGBGetLastError.restype = ctypes.c_char_p
    # The callback object is stored on the library handle itself
    # (presumably so it stays referenced for as long as the library is loaded).
    lib.callback = _get_log_callback_func()  # type: ignore
    if lib.XGBRegisterLogCallback(lib.callback) != 0:
        raise XGBoostError(lib.XGBGetLastError())
|
||||
|
||||
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
"""Load xgboost Library."""
|
||||
lib_paths = find_lib_path()
|
||||
@@ -228,10 +235,7 @@ Likely causes:
|
||||
Error message(s): {os_error_list}
|
||||
"""
|
||||
)
|
||||
lib.XGBGetLastError.restype = ctypes.c_char_p
|
||||
lib.callback = _get_log_callback_func() # type: ignore
|
||||
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
||||
raise XGBoostError(lib.XGBGetLastError())
|
||||
_register_log_callback(lib)
|
||||
|
||||
def parse(ver: str) -> Tuple[int, int, int]:
|
||||
"""Avoid dependency on packaging (PEP 440)."""
|
||||
|
||||
@@ -79,7 +79,6 @@ from xgboost.data import _is_cudf_ser, _is_cupy_array
|
||||
from xgboost.sklearn import (
|
||||
XGBClassifier,
|
||||
XGBClassifierBase,
|
||||
XGBClassifierMixIn,
|
||||
XGBModel,
|
||||
XGBRanker,
|
||||
XGBRankerMixIn,
|
||||
@@ -94,6 +93,8 @@ from xgboost.sklearn import (
|
||||
from xgboost.tracker import RabitTracker, get_host_ip
|
||||
from xgboost.training import train as worker_train
|
||||
|
||||
from .utils import get_n_threads
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import dask
|
||||
import distributed
|
||||
@@ -908,6 +909,34 @@ async def _check_workers_are_alive(
|
||||
raise RuntimeError(f"Missing required workers: {missing_workers}")
|
||||
|
||||
|
||||
def _get_dmatrices(
    train_ref: dict,
    train_id: int,
    *refs: dict,
    evals_id: Sequence[int],
    evals_name: Sequence[str],
    n_threads: int,
) -> Tuple[DMatrix, List[Tuple[DMatrix, str]]]:
    """Build the training DMatrix and the evaluation DMatrices.

    Parameters
    ----------
    train_ref :
        Keyword arguments used to construct the training DMatrix.
    train_id :
        Identifier of the training DMatrix.
    refs :
        Keyword arguments for each evaluation DMatrix, in evaluation order.
    evals_id :
        Identifier of each evaluation DMatrix, aligned with ``refs``.
    evals_name :
        Name of each evaluation DMatrix, aligned with ``refs``.
    n_threads :
        Number of threads passed to every constructed DMatrix.

    Returns
    -------
    The training DMatrix and a list of ``(DMatrix, name)`` evaluation pairs.

    Raises
    ------
    ValueError
        If an evaluation entry carries a ``ref`` that is not the training
        DMatrix.
    """
    Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
    evals: List[Tuple[DMatrix, str]] = []
    for i, ref in enumerate(refs):
        if evals_id[i] == train_id:
            # The evaluation set is the training set itself, reuse the object.
            evals.append((Xy, evals_name[i]))
            continue
        if ref.get("ref", None) is not None:
            if ref["ref"] != train_id:
                raise ValueError(
                    "The training DMatrix should be used as a reference to evaluation"
                    " `QuantileDMatrix`."
                )
            # Swap the identifier for the real training DMatrix without
            # mutating the dictionary owned by the caller.
            kwargs = {k: v for k, v in ref.items() if k != "ref"}
            eval_Xy = _dmatrix_from_list_of_parts(**kwargs, nthread=n_threads, ref=Xy)
        else:
            eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
        evals.append((eval_Xy, evals_name[i]))
    return Xy, evals
|
||||
|
||||
|
||||
async def _train_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
@@ -940,41 +969,20 @@ async def _train_async(
|
||||
) -> Optional[TrainReturnT]:
|
||||
worker = distributed.get_worker()
|
||||
local_param = parameters.copy()
|
||||
n_threads = 0
|
||||
# dask worker nthreads, "state" is available in 2022.6.1
|
||||
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
|
||||
for p in ["nthread", "n_jobs"]:
|
||||
if (
|
||||
local_param.get(p, None) is not None
|
||||
and local_param.get(p, dwnt) != dwnt
|
||||
):
|
||||
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
||||
n_threads = local_param[p]
|
||||
break
|
||||
if n_threads == 0 or n_threads is None:
|
||||
n_threads = dwnt
|
||||
n_threads = get_n_threads(local_param, worker)
|
||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||
|
||||
local_history: TrainingCallback.EvalsLog = {}
|
||||
|
||||
with CommunicatorContext(**rabit_args), config.config_context(**global_config):
|
||||
Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads)
|
||||
evals: List[Tuple[DMatrix, str]] = []
|
||||
for i, ref in enumerate(refs):
|
||||
if evals_id[i] == train_id:
|
||||
evals.append((Xy, evals_name[i]))
|
||||
continue
|
||||
if ref.get("ref", None) is not None:
|
||||
if ref["ref"] != train_id:
|
||||
raise ValueError(
|
||||
"The training DMatrix should be used as a reference"
|
||||
" to evaluation `QuantileDMatrix`."
|
||||
)
|
||||
del ref["ref"]
|
||||
eval_Xy = _dmatrix_from_list_of_parts(
|
||||
**ref, nthread=n_threads, ref=Xy
|
||||
)
|
||||
else:
|
||||
eval_Xy = _dmatrix_from_list_of_parts(**ref, nthread=n_threads)
|
||||
evals.append((eval_Xy, evals_name[i]))
|
||||
Xy, evals = _get_dmatrices(
|
||||
train_ref,
|
||||
train_id,
|
||||
*refs,
|
||||
evals_id=evals_id,
|
||||
evals_name=evals_name,
|
||||
n_threads=n_threads,
|
||||
)
|
||||
|
||||
booster = worker_train(
|
||||
params=local_param,
|
||||
@@ -1854,7 +1862,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
"Implementation of the scikit-learn API for XGBoost classification.",
|
||||
["estimators", "model"],
|
||||
)
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBase):
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(
|
||||
self,
|
||||
@@ -2036,10 +2044,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierMixIn, XGBClassifierBa
|
||||
preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
|
||||
return preds
|
||||
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
super().load_model(fname)
|
||||
self._load_model_attributes(self.get_booster())
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost Ranking.
|
||||
|
||||
24
python-package/xgboost/dask/utils.py
Normal file
24
python-package/xgboost/dask/utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Utilities for the XGBoost Dask interface."""
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.dask]")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import distributed
|
||||
|
||||
|
||||
def get_n_threads(local_param: Dict[str, Any], worker: "distributed.Worker") -> int:
    """Get the number of threads from a worker and the user-supplied parameters.

    Parameters
    ----------
    local_param :
        User-supplied parameters; ``nthread`` is consulted before ``n_jobs``.
    worker :
        The dask worker whose thread count is used as the default.

    Returns
    -------
    The first of ``nthread``/``n_jobs`` that is set and differs from the
    worker's thread count, or the worker's thread count when none qualifies
    or the selected value is 0.
    """
    # dask worker nthreads, "state" is available in 2022.6.1
    dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
    n_threads = None
    for p in ("nthread", "n_jobs"):
        value = local_param.get(p, None)
        if value is not None and value != dwnt:
            LOGGER.info("Overriding `nthreads` defined in dask worker.")
            n_threads = value
            break
    # A value of 0 means "not specified", same as a missing value.
    if n_threads == 0 or n_threads is None:
        n_threads = dwnt
    return n_threads
|
||||
@@ -43,19 +43,6 @@ from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
|
||||
from .training import train
|
||||
|
||||
|
||||
class XGBClassifierMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for classification."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _load_model_attributes(self, booster: Booster) -> None:
|
||||
config = json.loads(booster.save_config())
|
||||
self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])
|
||||
# binary classification is treated as regression in XGBoost.
|
||||
self.n_classes_ = 2 if self.n_classes_ < 2 else self.n_classes_
|
||||
|
||||
|
||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
|
||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
|
||||
base classes.
|
||||
@@ -808,7 +795,6 @@ class XGBModel(XGBModelBase):
|
||||
"kwargs",
|
||||
"missing",
|
||||
"n_estimators",
|
||||
"use_label_encoder",
|
||||
"enable_categorical",
|
||||
"early_stopping_rounds",
|
||||
"callbacks",
|
||||
@@ -851,21 +837,38 @@ class XGBModel(XGBModelBase):
|
||||
self.get_booster().load_model(fname)
|
||||
|
||||
meta_str = self.get_booster().attr("scikit_learn")
|
||||
if meta_str is None:
|
||||
return
|
||||
if meta_str is not None:
|
||||
meta = json.loads(meta_str)
|
||||
t = meta.get("_estimator_type", None)
|
||||
if t is not None and t != self._get_type():
|
||||
raise TypeError(
|
||||
"Loading an estimator with different type. Expecting: "
|
||||
f"{self._get_type()}, got: {t}"
|
||||
)
|
||||
|
||||
meta = json.loads(meta_str)
|
||||
t = meta.get("_estimator_type", None)
|
||||
if t is not None and t != self._get_type():
|
||||
raise TypeError(
|
||||
"Loading an estimator with different type. Expecting: "
|
||||
f"{self._get_type()}, got: {t}"
|
||||
)
|
||||
self.feature_types = self.get_booster().feature_types
|
||||
self.get_booster().set_attr(scikit_learn=None)
|
||||
config = json.loads(self.get_booster().save_config())
|
||||
self._load_model_attributes(config)
|
||||
|
||||
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
|
||||
|
||||
def _load_model_attributes(self, config: dict) -> None:
    """Load model attributes without hyper-parameters.

    Parameters
    ----------
    config :
        The booster configuration as a dictionary; the ``learner`` section is
        read for the objective, booster type, base score and class count.
    """
    from sklearn.base import is_classifier

    booster = self.get_booster()
    learner = config["learner"]

    self.objective = learner["objective"]["name"]
    self.booster = learner["gradient_booster"]["name"]
    self.base_score = learner["learner_model_param"]["base_score"]
    self.feature_types = booster.feature_types

    if is_classifier(self):
        num_class = int(learner["learner_model_param"]["num_class"])
        # binary classification is treated as regression in XGBoost.
        self.n_classes_ = 2 if num_class < 2 else num_class
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
def _configure_fit(
|
||||
self,
|
||||
@@ -1415,7 +1418,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
|
||||
Number of boosting rounds.
|
||||
""",
|
||||
)
|
||||
class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
|
||||
class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
|
||||
@_deprecate_positional_args
|
||||
def __init__(
|
||||
@@ -1643,10 +1646,6 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
|
||||
def classes_(self) -> np.ndarray:
|
||||
return np.arange(self.n_classes_)
|
||||
|
||||
def load_model(self, fname: ModelIn) -> None:
|
||||
super().load_model(fname)
|
||||
self._load_model_attributes(self.get_booster())
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"scikit-learn API for XGBoost random forest classification.",
|
||||
@@ -2099,7 +2098,17 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
|
||||
"""
|
||||
X, qid = _get_qid(X, None)
|
||||
Xyq = DMatrix(X, y, qid=qid)
|
||||
# fixme(jiamingy): base margin and group weight is not yet supported. We might
|
||||
# need to make extra special fields in the dataframe.
|
||||
Xyq = DMatrix(
|
||||
X,
|
||||
y,
|
||||
qid=qid,
|
||||
missing=self.missing,
|
||||
enable_categorical=self.enable_categorical,
|
||||
nthread=self.n_jobs,
|
||||
feature_types=self.feature_types,
|
||||
)
|
||||
if callable(self.eval_metric):
|
||||
metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
|
||||
result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric)
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pyspark import RDD, SparkContext, cloudpickle
|
||||
from pyspark import RDD, SparkConf, SparkContext, cloudpickle
|
||||
from pyspark.ml import Estimator, Model
|
||||
from pyspark.ml.functions import array_to_vector, vector_to_array
|
||||
from pyspark.ml.linalg import VectorUDT
|
||||
@@ -138,7 +138,6 @@ _inverse_pyspark_param_alias_map = {v: k for k, v in _pyspark_param_alias_map.it
|
||||
_unsupported_xgb_params = [
|
||||
"gpu_id", # we have "device" pyspark param instead.
|
||||
"enable_categorical", # Use feature_types param to specify categorical feature instead
|
||||
"use_label_encoder",
|
||||
"n_jobs", # Do not allow user to set it, will use `spark.task.cpus` value instead.
|
||||
"nthread", # Ditto
|
||||
]
|
||||
@@ -368,7 +367,10 @@ class _SparkXGBParams(
|
||||
" on GPU."
|
||||
)
|
||||
|
||||
if not (ss.version >= "3.4.0" and _is_standalone_or_localcluster(sc)):
|
||||
if not (
|
||||
ss.version >= "3.4.0"
|
||||
and _is_standalone_or_localcluster(sc.getConf())
|
||||
):
|
||||
# We will enable stage-level scheduling in spark 3.4.0+ which doesn't
|
||||
# require spark.task.resource.gpu.amount to be set explicitly
|
||||
gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount")
|
||||
@@ -907,30 +909,27 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
return booster_params, train_call_kwargs_params, dmatrix_kwargs
|
||||
|
||||
def _skip_stage_level_scheduling(self) -> bool:
|
||||
def _skip_stage_level_scheduling(self, spark_version: str, conf: SparkConf) -> bool:
|
||||
# pylint: disable=too-many-return-statements
|
||||
"""Check if stage-level scheduling is not needed,
|
||||
return true to skip stage-level scheduling"""
|
||||
|
||||
if self._run_on_gpu():
|
||||
ss = _get_spark_session()
|
||||
sc = ss.sparkContext
|
||||
|
||||
if ss.version < "3.4.0":
|
||||
if spark_version < "3.4.0":
|
||||
self.logger.info(
|
||||
"Stage-level scheduling in xgboost requires spark version 3.4.0+"
|
||||
)
|
||||
return True
|
||||
|
||||
if not _is_standalone_or_localcluster(sc):
|
||||
if not _is_standalone_or_localcluster(conf):
|
||||
self.logger.info(
|
||||
"Stage-level scheduling in xgboost requires spark standalone or "
|
||||
"local-cluster mode"
|
||||
)
|
||||
return True
|
||||
|
||||
executor_cores = sc.getConf().get("spark.executor.cores")
|
||||
executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount")
|
||||
executor_cores = conf.get("spark.executor.cores")
|
||||
executor_gpus = conf.get("spark.executor.resource.gpu.amount")
|
||||
if executor_cores is None or executor_gpus is None:
|
||||
self.logger.info(
|
||||
"Stage-level scheduling in xgboost requires spark.executor.cores, "
|
||||
@@ -955,7 +954,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
)
|
||||
return True
|
||||
|
||||
task_gpu_amount = sc.getConf().get("spark.task.resource.gpu.amount")
|
||||
task_gpu_amount = conf.get("spark.task.resource.gpu.amount")
|
||||
|
||||
if task_gpu_amount is None:
|
||||
# The ETL tasks will not grab a gpu when spark.task.resource.gpu.amount is not set,
|
||||
@@ -975,14 +974,13 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
|
||||
|
||||
def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
|
||||
"""Try to enable stage-level scheduling"""
|
||||
|
||||
if self._skip_stage_level_scheduling():
|
||||
ss = _get_spark_session()
|
||||
conf = ss.sparkContext.getConf()
|
||||
if self._skip_stage_level_scheduling(ss.version, conf):
|
||||
return rdd
|
||||
|
||||
ss = _get_spark_session()
|
||||
|
||||
# executor_cores will not be None
|
||||
executor_cores = ss.sparkContext.getConf().get("spark.executor.cores")
|
||||
executor_cores = conf.get("spark.executor.cores")
|
||||
assert executor_cores is not None
|
||||
|
||||
# Spark-rapids is a project to leverage GPUs to accelerate spark SQL.
|
||||
|
||||
@@ -10,7 +10,7 @@ from threading import Thread
|
||||
from typing import Any, Callable, Dict, Optional, Set, Type
|
||||
|
||||
import pyspark
|
||||
from pyspark import BarrierTaskContext, SparkContext, SparkFiles, TaskContext
|
||||
from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext
|
||||
from pyspark.sql.session import SparkSession
|
||||
|
||||
from xgboost import Booster, XGBModel, collective
|
||||
@@ -129,8 +129,8 @@ def _is_local(spark_context: SparkContext) -> bool:
|
||||
return spark_context._jsc.sc().isLocal()
|
||||
|
||||
|
||||
def _is_standalone_or_localcluster(spark_context: SparkContext) -> bool:
|
||||
master = spark_context.getConf().get("spark.master")
|
||||
def _is_standalone_or_localcluster(conf: "SparkConf") -> bool:
    """Whether ``spark.master`` in `conf` is a standalone or local-cluster URL.

    Returns ``False`` when ``spark.master`` is not set.
    """
    master = conf.get("spark.master")
    if master is None:
        return False
    return master.startswith(("spark://", "local-cluster"))
|
||||
|
||||
@@ -75,3 +75,28 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
|
||||
|
||||
with pytest.raises(ValueError, match="Either `group` or `qid`."):
|
||||
ranker.fit(df, y, eval_set=[(X, y)])
|
||||
|
||||
|
||||
def run_ranking_categorical(device: str) -> None:
    """Test LTR with categorical features."""
    from sklearn.model_selection import cross_val_score

    X, y = tm.make_categorical(
        n_samples=512, n_features=10, n_categories=3, onehot=False
    )
    rng = np.random.default_rng(1994)
    # Three query groups, sorted so that rows of one group are contiguous.
    qid = rng.choice(3, size=y.shape[0])
    qid = np.sort(qid)
    X["qid"] = qid

    ltr = xgb.XGBRanker(enable_categorical=True, device=device)
    ltr.fit(X, y)
    score = ltr.score(X, y)
    assert score > 0.9

    ltr = xgb.XGBRanker(enable_categorical=True, device=device)

    # test using the score function inside sklearn.
    scores = cross_val_score(ltr, X, y)
    for s in scores:
        assert s > 0.7
|
||||
|
||||
Reference in New Issue
Block a user