Support doc link for the sklearn module. (#10287)
This commit is contained in:
parent
a269055b2b
commit
3d8107adb8
@ -190,6 +190,27 @@ def _register_log_callback(lib: ctypes.CDLL) -> None:
|
|||||||
raise XGBoostError(lib.XGBGetLastError())
|
raise XGBoostError(lib.XGBGetLastError())
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_version(ver: str) -> Tuple[Tuple[int, int, int], str]:
|
||||||
|
"""Avoid dependency on packaging (PEP 440)."""
|
||||||
|
# 2.0.0-dev, 2.0.0, 2.0.0.post1, or 2.0.0rc1
|
||||||
|
if ver.find("post") != -1:
|
||||||
|
major, minor, patch = ver.split(".")[:-1]
|
||||||
|
postfix = ver.split(".")[-1]
|
||||||
|
elif "-dev" in ver:
|
||||||
|
major, minor, patch = ver.split("-")[0].split(".")
|
||||||
|
postfix = "dev"
|
||||||
|
else:
|
||||||
|
major, minor, patch = ver.split(".")
|
||||||
|
rc = patch.find("rc")
|
||||||
|
if rc != -1:
|
||||||
|
postfix = patch[rc:]
|
||||||
|
patch = patch[:rc]
|
||||||
|
else:
|
||||||
|
postfix = ""
|
||||||
|
|
||||||
|
return (int(major), int(minor), int(patch)), postfix
|
||||||
|
|
||||||
|
|
||||||
def _load_lib() -> ctypes.CDLL:
|
def _load_lib() -> ctypes.CDLL:
|
||||||
"""Load xgboost Library."""
|
"""Load xgboost Library."""
|
||||||
lib_paths = find_lib_path()
|
lib_paths = find_lib_path()
|
||||||
@ -237,17 +258,8 @@ Error message(s): {os_error_list}
|
|||||||
)
|
)
|
||||||
_register_log_callback(lib)
|
_register_log_callback(lib)
|
||||||
|
|
||||||
def parse(ver: str) -> Tuple[int, int, int]:
|
|
||||||
"""Avoid dependency on packaging (PEP 440)."""
|
|
||||||
# 2.0.0-dev, 2.0.0, or 2.0.0rc1
|
|
||||||
major, minor, patch = ver.split("-")[0].split(".")
|
|
||||||
rc = patch.find("rc")
|
|
||||||
if rc != -1:
|
|
||||||
patch = patch[:rc]
|
|
||||||
return int(major), int(minor), int(patch)
|
|
||||||
|
|
||||||
libver = _lib_version(lib)
|
libver = _lib_version(lib)
|
||||||
pyver = parse(_py_version())
|
pyver, _ = _parse_version(_py_version())
|
||||||
|
|
||||||
# verify that we are loading the correct binary.
|
# verify that we are loading the correct binary.
|
||||||
if pyver != libver:
|
if pyver != libver:
|
||||||
|
|||||||
@ -40,6 +40,8 @@ from .core import (
|
|||||||
XGBoostError,
|
XGBoostError,
|
||||||
_deprecate_positional_args,
|
_deprecate_positional_args,
|
||||||
_parse_eval_str,
|
_parse_eval_str,
|
||||||
|
_parse_version,
|
||||||
|
_py_version,
|
||||||
)
|
)
|
||||||
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
|
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
|
||||||
from .training import train
|
from .training import train
|
||||||
@ -795,6 +797,32 @@ class XGBModel(XGBModelBase):
|
|||||||
def __sklearn_is_fitted__(self) -> bool:
|
def __sklearn_is_fitted__(self) -> bool:
|
||||||
return hasattr(self, "_Booster")
|
return hasattr(self, "_Booster")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _doc_link_module(self) -> str:
|
||||||
|
return "xgboost"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _doc_link_template(self) -> str:
|
||||||
|
ver = _py_version()
|
||||||
|
(major, minor, _), post = _parse_version(ver)
|
||||||
|
|
||||||
|
if post == "dev":
|
||||||
|
rel = "latest"
|
||||||
|
else:
|
||||||
|
# RTD tracks the release branch. We don't have independent branches for
|
||||||
|
# patch releases.
|
||||||
|
rel = f"release_{major}.{minor}.0"
|
||||||
|
|
||||||
|
module = self.__class__.__module__
|
||||||
|
# All sklearn estimators are forwarded to the top level module in both source
|
||||||
|
# code and sphinx api doc.
|
||||||
|
if module == "xgboost.sklearn":
|
||||||
|
module = module.split(".")[0]
|
||||||
|
name = self.__class__.__name__
|
||||||
|
|
||||||
|
base = "https://xgboost.readthedocs.io/en"
|
||||||
|
return f"{base}/{rel}/python/python_api.html#{module}.{name}"
|
||||||
|
|
||||||
def get_booster(self) -> Booster:
|
def get_booster(self) -> Booster:
|
||||||
"""Get the underlying xgboost Booster of this model.
|
"""Get the underlying xgboost Booster of this model.
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,7 @@ dependencies:
|
|||||||
- pylint
|
- pylint
|
||||||
- numpy
|
- numpy
|
||||||
- scipy
|
- scipy
|
||||||
- scikit-learn
|
- scikit-learn>=1.4.1
|
||||||
- pandas
|
- pandas
|
||||||
- matplotlib
|
- matplotlib
|
||||||
- dask
|
- dask
|
||||||
|
|||||||
@ -11,7 +11,7 @@ dependencies:
|
|||||||
- numpy
|
- numpy
|
||||||
- scipy
|
- scipy
|
||||||
- llvm-openmp
|
- llvm-openmp
|
||||||
- scikit-learn
|
- scikit-learn>=1.4.1
|
||||||
- pandas
|
- pandas
|
||||||
- matplotlib
|
- matplotlib
|
||||||
- dask
|
- dask
|
||||||
|
|||||||
@ -6,7 +6,7 @@ dependencies:
|
|||||||
- wheel
|
- wheel
|
||||||
- numpy
|
- numpy
|
||||||
- scipy
|
- scipy
|
||||||
- scikit-learn
|
- scikit-learn>=1.4.1
|
||||||
- pandas
|
- pandas
|
||||||
- matplotlib
|
- matplotlib
|
||||||
- dask
|
- dask
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import pytest
|
|||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.core import _parse_version
|
||||||
|
|
||||||
dpath = "demo/data/"
|
dpath = "demo/data/"
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
@ -315,3 +316,14 @@ class TestBasicPathLike:
|
|||||||
"""An invalid model_file path should raise XGBoostError."""
|
"""An invalid model_file path should raise XGBoostError."""
|
||||||
with pytest.raises(xgb.core.XGBoostError):
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
xgb.Booster(model_file=Path("invalidpath"))
|
xgb.Booster(model_file=Path("invalidpath"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_ver() -> None:
|
||||||
|
(major, minor, patch), post = _parse_version("2.1.0")
|
||||||
|
assert post == ""
|
||||||
|
(major, minor, patch), post = _parse_version("2.1.0-dev")
|
||||||
|
assert post == "dev"
|
||||||
|
(major, minor, patch), post = _parse_version("2.1.0rc1")
|
||||||
|
assert post == "rc1"
|
||||||
|
(major, minor, patch), post = _parse_version("2.1.0.post1")
|
||||||
|
assert post == "post1"
|
||||||
|
|||||||
@ -1484,3 +1484,16 @@ def test_tags() -> None:
|
|||||||
|
|
||||||
tags = xgb.XGBRanker()._more_tags()
|
tags = xgb.XGBRanker()._more_tags()
|
||||||
assert "multioutput" not in tags
|
assert "multioutput" not in tags
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_link() -> None:
|
||||||
|
for est in [
|
||||||
|
xgb.XGBRegressor(),
|
||||||
|
xgb.XGBClassifier(),
|
||||||
|
xgb.XGBRanker(),
|
||||||
|
xgb.XGBRFRegressor(),
|
||||||
|
xgb.XGBRFClassifier(),
|
||||||
|
]:
|
||||||
|
name = est.__class__.__name__
|
||||||
|
link = est._get_doc_link()
|
||||||
|
assert f"xgboost.{name}" in link
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from hypothesis import HealthCheck, assume, given, note, settings
|
|||||||
from sklearn.datasets import make_classification, make_regression
|
from sklearn.datasets import make_classification, make_regression
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
from xgboost import dask as dxgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
from xgboost.data import _is_cudf_df
|
from xgboost.data import _is_cudf_df
|
||||||
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
|
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
|
||||||
@ -2324,3 +2325,16 @@ async def test_worker_restarted(c, s, a, b):
|
|||||||
d_train,
|
d_train,
|
||||||
evals=[(d_train, "train")],
|
evals=[(d_train, "train")],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_link() -> None:
|
||||||
|
for est in [
|
||||||
|
dxgb.DaskXGBRegressor(),
|
||||||
|
dxgb.DaskXGBClassifier(),
|
||||||
|
dxgb.DaskXGBRanker(),
|
||||||
|
dxgb.DaskXGBRFRegressor(),
|
||||||
|
dxgb.DaskXGBRFClassifier(),
|
||||||
|
]:
|
||||||
|
name = est.__class__.__name__
|
||||||
|
link = est._get_doc_link()
|
||||||
|
assert f"xgboost.dask.{name}" in link
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user