Support doc link for the sklearn module. (#10287)

This commit is contained in:
Jiaming Yuan
2024-08-06 02:35:32 +08:00
committed by GitHub
parent a269055b2b
commit 3d8107adb8
8 changed files with 92 additions and 13 deletions

View File

@@ -190,6 +190,27 @@ def _register_log_callback(lib: ctypes.CDLL) -> None:
raise XGBoostError(lib.XGBGetLastError())
def _parse_version(ver: str) -> Tuple[Tuple[int, int, int], str]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, 2.0.0.post1, or 2.0.0rc1
if ver.find("post") != -1:
major, minor, patch = ver.split(".")[:-1]
postfix = ver.split(".")[-1]
elif "-dev" in ver:
major, minor, patch = ver.split("-")[0].split(".")
postfix = "dev"
else:
major, minor, patch = ver.split(".")
rc = patch.find("rc")
if rc != -1:
postfix = patch[rc:]
patch = patch[:rc]
else:
postfix = ""
return (int(major), int(minor), int(patch)), postfix
def _load_lib() -> ctypes.CDLL:
"""Load xgboost Library."""
lib_paths = find_lib_path()
@@ -237,17 +258,8 @@ Error message(s): {os_error_list}
)
_register_log_callback(lib)
def parse(ver: str) -> Tuple[int, int, int]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, or 2.0.0rc1
major, minor, patch = ver.split("-")[0].split(".")
rc = patch.find("rc")
if rc != -1:
patch = patch[:rc]
return int(major), int(minor), int(patch)
libver = _lib_version(lib)
pyver = parse(_py_version())
pyver, _ = _parse_version(_py_version())
# verify that we are loading the correct binary.
if pyver != libver:

View File

@@ -40,6 +40,8 @@ from .core import (
XGBoostError,
_deprecate_positional_args,
_parse_eval_str,
_parse_version,
_py_version,
)
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
from .training import train
@@ -795,6 +797,32 @@ class XGBModel(XGBModelBase):
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
@property
def _doc_link_module(self) -> str:
return "xgboost"
@property
def _doc_link_template(self) -> str:
ver = _py_version()
(major, minor, _), post = _parse_version(ver)
if post == "dev":
rel = "latest"
else:
# RTD tracks the release branch. We don't have independent branches for
# patch releases.
rel = f"release_{major}.{minor}.0"
module = self.__class__.__module__
# All sklearn estimators are forwarded to the top level module in both source
# code and sphinx api doc.
if module == "xgboost.sklearn":
module = module.split(".")[0]
name = self.__class__.__name__
base = "https://xgboost.readthedocs.io/en"
return f"{base}/{rel}/python/python_api.html#{module}.{name}"
def get_booster(self) -> Booster:
"""Get the underlying xgboost Booster of this model.