parent
97d89c3ca1
commit
0e2b5c467e
@ -1,12 +1,16 @@
|
||||
# coding: utf-8
|
||||
"""XGBoost: eXtreme Gradient Boosting library.
|
||||
|
||||
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter, build_info
|
||||
from .core import (
|
||||
DMatrix,
|
||||
DeviceQuantileDMatrix,
|
||||
Booster,
|
||||
DataIter,
|
||||
build_info,
|
||||
_py_version,
|
||||
)
|
||||
from .training import train, cv
|
||||
from . import rabit # noqa
|
||||
from . import tracker # noqa
|
||||
@ -21,9 +25,9 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
|
||||
with open(VERSION_FILE, encoding="ascii") as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
__version__ = _py_version()
|
||||
|
||||
|
||||
__all__ = [
|
||||
# core
|
||||
|
||||
@ -137,6 +137,22 @@ def _get_log_callback_func() -> Callable:
|
||||
return c_callback(_log_callback)
|
||||
|
||||
|
||||
def _lib_version(lib: ctypes.CDLL) -> Tuple[int, int, int]:
|
||||
"""Get the XGBoost version from native shared object."""
|
||||
major = ctypes.c_int()
|
||||
minor = ctypes.c_int()
|
||||
patch = ctypes.c_int()
|
||||
lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
|
||||
return major.value, minor.value, patch.value
|
||||
|
||||
|
||||
def _py_version() -> str:
|
||||
"""Get the XGBoost version from Python version file."""
|
||||
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
|
||||
with open(VERSION_FILE, encoding="ascii") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
"""Load xgboost Library."""
|
||||
lib_paths = find_lib_path()
|
||||
@ -144,7 +160,7 @@ def _load_lib() -> ctypes.CDLL:
|
||||
# This happens only when building document.
|
||||
return None # type: ignore
|
||||
try:
|
||||
pathBackup = os.environ['PATH'].split(os.pathsep)
|
||||
pathBackup = os.environ["PATH"].split(os.pathsep)
|
||||
except KeyError:
|
||||
pathBackup = []
|
||||
lib_success = False
|
||||
@ -153,15 +169,16 @@ def _load_lib() -> ctypes.CDLL:
|
||||
try:
|
||||
# needed when the lib is linked with non-system-available
|
||||
# dependencies
|
||||
os.environ['PATH'] = os.pathsep.join(
|
||||
pathBackup + [os.path.dirname(lib_path)])
|
||||
os.environ["PATH"] = os.pathsep.join(
|
||||
pathBackup + [os.path.dirname(lib_path)]
|
||||
)
|
||||
lib = ctypes.cdll.LoadLibrary(lib_path)
|
||||
lib_success = True
|
||||
except OSError as e:
|
||||
os_error_list.append(str(e))
|
||||
continue
|
||||
finally:
|
||||
os.environ['PATH'] = os.pathsep.join(pathBackup)
|
||||
os.environ["PATH"] = os.pathsep.join(pathBackup)
|
||||
if not lib_success:
|
||||
libname = os.path.basename(lib_paths[0])
|
||||
raise XGBoostError(
|
||||
@ -177,11 +194,36 @@ Likely causes:
|
||||
* You are running 32-bit Python on a 64-bit OS
|
||||
|
||||
Error message(s): {os_error_list}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
lib.XGBGetLastError.restype = ctypes.c_char_p
|
||||
lib.callback = _get_log_callback_func() # type: ignore
|
||||
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
||||
raise XGBoostError(lib.XGBGetLastError())
|
||||
|
||||
def parse(ver: str) -> Tuple[int, int, int]:
|
||||
"""Avoid dependency on packaging (PEP 440)."""
|
||||
# 2.0.0-dev or 2.0.0
|
||||
major, minor, patch = ver.split("-")[0].split(".")
|
||||
return int(major), int(minor), int(patch)
|
||||
|
||||
libver = _lib_version(lib)
|
||||
pyver = parse(_py_version())
|
||||
|
||||
# verify that we are loading the correct binary.
|
||||
if pyver != libver:
|
||||
pyver_str = ".".join((str(v) for v in pyver))
|
||||
libver_str = ".".join((str(v) for v in libver))
|
||||
msg = (
|
||||
"Mismatched version between the Python package and the native shared "
|
||||
f"""object. Python package version: {pyver_str}. Shared object """
|
||||
f"""version: {libver_str}. Shared object is loaded from: {lib.path}.
|
||||
Likely cause:
|
||||
* XGBoost is first installed with anaconda then upgraded with pip. To fix it """
|
||||
"please remove one of the installations."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return lib
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user