Support doc link for the sklearn module. (#10287)
This commit is contained in:
parent
a269055b2b
commit
3d8107adb8
@ -190,6 +190,27 @@ def _register_log_callback(lib: ctypes.CDLL) -> None:
|
||||
raise XGBoostError(lib.XGBGetLastError())
|
||||
|
||||
|
||||
def _parse_version(ver: str) -> Tuple[Tuple[int, int, int], str]:
|
||||
"""Avoid dependency on packaging (PEP 440)."""
|
||||
# 2.0.0-dev, 2.0.0, 2.0.0.post1, or 2.0.0rc1
|
||||
if ver.find("post") != -1:
|
||||
major, minor, patch = ver.split(".")[:-1]
|
||||
postfix = ver.split(".")[-1]
|
||||
elif "-dev" in ver:
|
||||
major, minor, patch = ver.split("-")[0].split(".")
|
||||
postfix = "dev"
|
||||
else:
|
||||
major, minor, patch = ver.split(".")
|
||||
rc = patch.find("rc")
|
||||
if rc != -1:
|
||||
postfix = patch[rc:]
|
||||
patch = patch[:rc]
|
||||
else:
|
||||
postfix = ""
|
||||
|
||||
return (int(major), int(minor), int(patch)), postfix
|
||||
|
||||
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
"""Load xgboost Library."""
|
||||
lib_paths = find_lib_path()
|
||||
@ -237,17 +258,8 @@ Error message(s): {os_error_list}
|
||||
)
|
||||
_register_log_callback(lib)
|
||||
|
||||
def parse(ver: str) -> Tuple[int, int, int]:
|
||||
"""Avoid dependency on packaging (PEP 440)."""
|
||||
# 2.0.0-dev, 2.0.0, or 2.0.0rc1
|
||||
major, minor, patch = ver.split("-")[0].split(".")
|
||||
rc = patch.find("rc")
|
||||
if rc != -1:
|
||||
patch = patch[:rc]
|
||||
return int(major), int(minor), int(patch)
|
||||
|
||||
libver = _lib_version(lib)
|
||||
pyver = parse(_py_version())
|
||||
pyver, _ = _parse_version(_py_version())
|
||||
|
||||
# verify that we are loading the correct binary.
|
||||
if pyver != libver:
|
||||
|
||||
@ -40,6 +40,8 @@ from .core import (
|
||||
XGBoostError,
|
||||
_deprecate_positional_args,
|
||||
_parse_eval_str,
|
||||
_parse_version,
|
||||
_py_version,
|
||||
)
|
||||
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
|
||||
from .training import train
|
||||
@ -795,6 +797,32 @@ class XGBModel(XGBModelBase):
|
||||
def __sklearn_is_fitted__(self) -> bool:
|
||||
return hasattr(self, "_Booster")
|
||||
|
||||
@property
|
||||
def _doc_link_module(self) -> str:
|
||||
return "xgboost"
|
||||
|
||||
@property
|
||||
def _doc_link_template(self) -> str:
|
||||
ver = _py_version()
|
||||
(major, minor, _), post = _parse_version(ver)
|
||||
|
||||
if post == "dev":
|
||||
rel = "latest"
|
||||
else:
|
||||
# RTD tracks the release branch. We don't have independent branches for
|
||||
# patch releases.
|
||||
rel = f"release_{major}.{minor}.0"
|
||||
|
||||
module = self.__class__.__module__
|
||||
# All sklearn estimators are forwarded to the top level module in both source
|
||||
# code and sphinx api doc.
|
||||
if module == "xgboost.sklearn":
|
||||
module = module.split(".")[0]
|
||||
name = self.__class__.__name__
|
||||
|
||||
base = "https://xgboost.readthedocs.io/en"
|
||||
return f"{base}/{rel}/python/python_api.html#{module}.{name}"
|
||||
|
||||
def get_booster(self) -> Booster:
|
||||
"""Get the underlying xgboost Booster of this model.
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ dependencies:
|
||||
- pylint
|
||||
- numpy
|
||||
- scipy
|
||||
- scikit-learn
|
||||
- scikit-learn>=1.4.1
|
||||
- pandas
|
||||
- matplotlib
|
||||
- dask
|
||||
|
||||
@ -11,7 +11,7 @@ dependencies:
|
||||
- numpy
|
||||
- scipy
|
||||
- llvm-openmp
|
||||
- scikit-learn
|
||||
- scikit-learn>=1.4.1
|
||||
- pandas
|
||||
- matplotlib
|
||||
- dask
|
||||
|
||||
@ -6,7 +6,7 @@ dependencies:
|
||||
- wheel
|
||||
- numpy
|
||||
- scipy
|
||||
- scikit-learn
|
||||
- scikit-learn>=1.4.1
|
||||
- pandas
|
||||
- matplotlib
|
||||
- dask
|
||||
|
||||
@ -9,6 +9,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.core import _parse_version
|
||||
|
||||
dpath = "demo/data/"
|
||||
rng = np.random.RandomState(1994)
|
||||
@ -315,3 +316,14 @@ class TestBasicPathLike:
|
||||
"""An invalid model_file path should raise XGBoostError."""
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
xgb.Booster(model_file=Path("invalidpath"))
|
||||
|
||||
|
||||
def test_parse_ver() -> None:
|
||||
(major, minor, patch), post = _parse_version("2.1.0")
|
||||
assert post == ""
|
||||
(major, minor, patch), post = _parse_version("2.1.0-dev")
|
||||
assert post == "dev"
|
||||
(major, minor, patch), post = _parse_version("2.1.0rc1")
|
||||
assert post == "rc1"
|
||||
(major, minor, patch), post = _parse_version("2.1.0.post1")
|
||||
assert post == "post1"
|
||||
|
||||
@ -1484,3 +1484,16 @@ def test_tags() -> None:
|
||||
|
||||
tags = xgb.XGBRanker()._more_tags()
|
||||
assert "multioutput" not in tags
|
||||
|
||||
|
||||
def test_doc_link() -> None:
|
||||
for est in [
|
||||
xgb.XGBRegressor(),
|
||||
xgb.XGBClassifier(),
|
||||
xgb.XGBRanker(),
|
||||
xgb.XGBRFRegressor(),
|
||||
xgb.XGBRFClassifier(),
|
||||
]:
|
||||
name = est.__class__.__name__
|
||||
link = est._get_doc_link()
|
||||
assert f"xgboost.{name}" in link
|
||||
|
||||
@ -35,6 +35,7 @@ from hypothesis import HealthCheck, assume, given, note, settings
|
||||
from sklearn.datasets import make_classification, make_regression
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import dask as dxgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.data import _is_cudf_df
|
||||
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
|
||||
@ -2324,3 +2325,16 @@ async def test_worker_restarted(c, s, a, b):
|
||||
d_train,
|
||||
evals=[(d_train, "train")],
|
||||
)
|
||||
|
||||
|
||||
def test_doc_link() -> None:
|
||||
for est in [
|
||||
dxgb.DaskXGBRegressor(),
|
||||
dxgb.DaskXGBClassifier(),
|
||||
dxgb.DaskXGBRanker(),
|
||||
dxgb.DaskXGBRFRegressor(),
|
||||
dxgb.DaskXGBRFClassifier(),
|
||||
]:
|
||||
name = est.__class__.__name__
|
||||
link = est._get_doc_link()
|
||||
assert f"xgboost.dask.{name}" in link
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user