Support doc link for the sklearn module. (#10287)

This commit is contained in:
Jiaming Yuan 2024-08-06 02:35:32 +08:00 committed by GitHub
parent a269055b2b
commit 3d8107adb8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 92 additions and 13 deletions

View File

@ -190,6 +190,27 @@ def _register_log_callback(lib: ctypes.CDLL) -> None:
raise XGBoostError(lib.XGBGetLastError()) raise XGBoostError(lib.XGBGetLastError())
def _parse_version(ver: str) -> Tuple[Tuple[int, int, int], str]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, 2.0.0.post1, or 2.0.0rc1
if ver.find("post") != -1:
major, minor, patch = ver.split(".")[:-1]
postfix = ver.split(".")[-1]
elif "-dev" in ver:
major, minor, patch = ver.split("-")[0].split(".")
postfix = "dev"
else:
major, minor, patch = ver.split(".")
rc = patch.find("rc")
if rc != -1:
postfix = patch[rc:]
patch = patch[:rc]
else:
postfix = ""
return (int(major), int(minor), int(patch)), postfix
def _load_lib() -> ctypes.CDLL: def _load_lib() -> ctypes.CDLL:
"""Load xgboost Library.""" """Load xgboost Library."""
lib_paths = find_lib_path() lib_paths = find_lib_path()
@ -237,17 +258,8 @@ Error message(s): {os_error_list}
) )
_register_log_callback(lib) _register_log_callback(lib)
def parse(ver: str) -> Tuple[int, int, int]:
    """Return the ``(major, minor, patch)`` triple of a version string.

    Avoid dependency on packaging (PEP 440).

    Accepts ``2.0.0-dev``, ``2.0.0``, or ``2.0.0rc1``. Any other layout (for
    example a ``.postN`` suffix, which adds a fourth dot-separated field) raises
    ``ValueError`` from the three-way unpacking or from ``int``.
    """
    # Anything after the first dash (the ``-dev`` marker) is not part of the number.
    major, minor, patch = ver.split("-")[0].split(".")
    # Drop a release-candidate suffix glued onto the patch number, if any.
    patch = patch.partition("rc")[0]
    return int(major), int(minor), int(patch)
libver = _lib_version(lib) libver = _lib_version(lib)
pyver = parse(_py_version()) pyver, _ = _parse_version(_py_version())
# verify that we are loading the correct binary. # verify that we are loading the correct binary.
if pyver != libver: if pyver != libver:

View File

@ -40,6 +40,8 @@ from .core import (
XGBoostError, XGBoostError,
_deprecate_positional_args, _deprecate_positional_args,
_parse_eval_str, _parse_eval_str,
_parse_version,
_py_version,
) )
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
from .training import train from .training import train
@ -795,6 +797,32 @@ class XGBModel(XGBModelBase):
def __sklearn_is_fitted__(self) -> bool: def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster") return hasattr(self, "_Booster")
@property
def _doc_link_module(self) -> str:
return "xgboost"
@property
def _doc_link_template(self) -> str:
    """URL of this estimator's entry in the online Python API reference.

    The documentation branch is derived from the installed package version and
    the anchor from the estimator's module and class name.
    """
    (major, minor, _), post = _parse_version(_py_version())
    # RTD tracks the release branch. We don't have independent branches for
    # patch releases, so the patch component is always rendered as ``0``.
    # Development builds point at the ``latest`` documentation instead.
    rel = "latest" if post == "dev" else f"release_{major}.{minor}.0"

    klass = self.__class__
    name = klass.__name__
    module = klass.__module__
    # All sklearn estimators are forwarded to the top level module in both source
    # code and sphinx api doc.
    if module == "xgboost.sklearn":
        module = module.partition(".")[0]

    base = "https://xgboost.readthedocs.io/en"
    return f"{base}/{rel}/python/python_api.html#{module}.{name}"
def get_booster(self) -> Booster: def get_booster(self) -> Booster:
"""Get the underlying xgboost Booster of this model. """Get the underlying xgboost Booster of this model.

View File

@ -14,7 +14,7 @@ dependencies:
- pylint - pylint
- numpy - numpy
- scipy - scipy
- scikit-learn - scikit-learn>=1.4.1
- pandas - pandas
- matplotlib - matplotlib
- dask - dask

View File

@ -11,7 +11,7 @@ dependencies:
- numpy - numpy
- scipy - scipy
- llvm-openmp - llvm-openmp
- scikit-learn - scikit-learn>=1.4.1
- pandas - pandas
- matplotlib - matplotlib
- dask - dask

View File

@ -6,7 +6,7 @@ dependencies:
- wheel - wheel
- numpy - numpy
- scipy - scipy
- scikit-learn - scikit-learn>=1.4.1
- pandas - pandas
- matplotlib - matplotlib
- dask - dask

View File

@ -9,6 +9,7 @@ import pytest
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.core import _parse_version
dpath = "demo/data/" dpath = "demo/data/"
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
@ -315,3 +316,14 @@ class TestBasicPathLike:
"""An invalid model_file path should raise XGBoostError.""" """An invalid model_file path should raise XGBoostError."""
with pytest.raises(xgb.core.XGBoostError): with pytest.raises(xgb.core.XGBoostError):
xgb.Booster(model_file=Path("invalidpath")) xgb.Booster(model_file=Path("invalidpath"))
def test_parse_ver() -> None:
    """Check the numeric triple and the postfix for every supported layout."""
    # Maps each version layout to the postfix it is expected to produce.
    expected_postfix = {
        "2.1.0": "",
        "2.1.0-dev": "dev",
        "2.1.0rc1": "rc1",
        "2.1.0.post1": "post1",
    }
    for ver, postfix in expected_postfix.items():
        numbers, post = _parse_version(ver)
        # The numeric part must be unaffected by whichever postfix is attached.
        assert numbers == (2, 1, 0)
        assert post == postfix

View File

@ -1484,3 +1484,16 @@ def test_tags() -> None:
tags = xgb.XGBRanker()._more_tags() tags = xgb.XGBRanker()._more_tags()
assert "multioutput" not in tags assert "multioutput" not in tags
def test_doc_link() -> None:
    """Each sklearn estimator's doc link names it under the top-level module."""
    estimators = (
        xgb.XGBRegressor(),
        xgb.XGBClassifier(),
        xgb.XGBRanker(),
        xgb.XGBRFRegressor(),
        xgb.XGBRFClassifier(),
    )
    for estimator in estimators:
        name = type(estimator).__name__
        assert f"xgboost.{name}" in estimator._get_doc_link()

View File

@ -35,6 +35,7 @@ from hypothesis import HealthCheck, assume, given, note, settings
from sklearn.datasets import make_classification, make_regression from sklearn.datasets import make_classification, make_regression
import xgboost as xgb import xgboost as xgb
from xgboost import dask as dxgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.data import _is_cudf_df from xgboost.data import _is_cudf_df
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
@ -2324,3 +2325,16 @@ async def test_worker_restarted(c, s, a, b):
d_train, d_train,
evals=[(d_train, "train")], evals=[(d_train, "train")],
) )
def test_doc_link() -> None:
    """Each dask estimator's doc link names it under ``xgboost.dask``."""
    estimators = (
        dxgb.DaskXGBRegressor(),
        dxgb.DaskXGBClassifier(),
        dxgb.DaskXGBRanker(),
        dxgb.DaskXGBRFRegressor(),
        dxgb.DaskXGBRFClassifier(),
    )
    for estimator in estimators:
        name = type(estimator).__name__
        assert f"xgboost.dask.{name}" in estimator._get_doc_link()