parent
97d89c3ca1
commit
0e2b5c467e
@ -1,12 +1,16 @@
|
|||||||
# coding: utf-8
|
|
||||||
"""XGBoost: eXtreme Gradient Boosting library.
|
"""XGBoost: eXtreme Gradient Boosting library.
|
||||||
|
|
||||||
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
from .core import (
|
||||||
|
DMatrix,
|
||||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter, build_info
|
DeviceQuantileDMatrix,
|
||||||
|
Booster,
|
||||||
|
DataIter,
|
||||||
|
build_info,
|
||||||
|
_py_version,
|
||||||
|
)
|
||||||
from .training import train, cv
|
from .training import train, cv
|
||||||
from . import rabit # noqa
|
from . import rabit # noqa
|
||||||
from . import tracker # noqa
|
from . import tracker # noqa
|
||||||
@ -21,9 +25,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
|
|
||||||
with open(VERSION_FILE, encoding="ascii") as f:
|
__version__ = _py_version()
|
||||||
__version__ = f.read().strip()
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# core
|
# core
|
||||||
|
|||||||
@ -137,6 +137,22 @@ def _get_log_callback_func() -> Callable:
|
|||||||
return c_callback(_log_callback)
|
return c_callback(_log_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def _lib_version(lib: ctypes.CDLL) -> Tuple[int, int, int]:
|
||||||
|
"""Get the XGBoost version from native shared object."""
|
||||||
|
major = ctypes.c_int()
|
||||||
|
minor = ctypes.c_int()
|
||||||
|
patch = ctypes.c_int()
|
||||||
|
lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
|
||||||
|
return major.value, minor.value, patch.value
|
||||||
|
|
||||||
|
|
||||||
|
def _py_version() -> str:
|
||||||
|
"""Get the XGBoost version from Python version file."""
|
||||||
|
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
|
||||||
|
with open(VERSION_FILE, encoding="ascii") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
|
||||||
|
|
||||||
def _load_lib() -> ctypes.CDLL:
|
def _load_lib() -> ctypes.CDLL:
|
||||||
"""Load xgboost Library."""
|
"""Load xgboost Library."""
|
||||||
lib_paths = find_lib_path()
|
lib_paths = find_lib_path()
|
||||||
@ -144,7 +160,7 @@ def _load_lib() -> ctypes.CDLL:
|
|||||||
# This happens only when building document.
|
# This happens only when building document.
|
||||||
return None # type: ignore
|
return None # type: ignore
|
||||||
try:
|
try:
|
||||||
pathBackup = os.environ['PATH'].split(os.pathsep)
|
pathBackup = os.environ["PATH"].split(os.pathsep)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pathBackup = []
|
pathBackup = []
|
||||||
lib_success = False
|
lib_success = False
|
||||||
@ -153,15 +169,16 @@ def _load_lib() -> ctypes.CDLL:
|
|||||||
try:
|
try:
|
||||||
# needed when the lib is linked with non-system-available
|
# needed when the lib is linked with non-system-available
|
||||||
# dependencies
|
# dependencies
|
||||||
os.environ['PATH'] = os.pathsep.join(
|
os.environ["PATH"] = os.pathsep.join(
|
||||||
pathBackup + [os.path.dirname(lib_path)])
|
pathBackup + [os.path.dirname(lib_path)]
|
||||||
|
)
|
||||||
lib = ctypes.cdll.LoadLibrary(lib_path)
|
lib = ctypes.cdll.LoadLibrary(lib_path)
|
||||||
lib_success = True
|
lib_success = True
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
os_error_list.append(str(e))
|
os_error_list.append(str(e))
|
||||||
continue
|
continue
|
||||||
finally:
|
finally:
|
||||||
os.environ['PATH'] = os.pathsep.join(pathBackup)
|
os.environ["PATH"] = os.pathsep.join(pathBackup)
|
||||||
if not lib_success:
|
if not lib_success:
|
||||||
libname = os.path.basename(lib_paths[0])
|
libname = os.path.basename(lib_paths[0])
|
||||||
raise XGBoostError(
|
raise XGBoostError(
|
||||||
@ -177,11 +194,36 @@ Likely causes:
|
|||||||
* You are running 32-bit Python on a 64-bit OS
|
* You are running 32-bit Python on a 64-bit OS
|
||||||
|
|
||||||
Error message(s): {os_error_list}
|
Error message(s): {os_error_list}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
lib.XGBGetLastError.restype = ctypes.c_char_p
|
lib.XGBGetLastError.restype = ctypes.c_char_p
|
||||||
lib.callback = _get_log_callback_func() # type: ignore
|
lib.callback = _get_log_callback_func() # type: ignore
|
||||||
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
if lib.XGBRegisterLogCallback(lib.callback) != 0:
|
||||||
raise XGBoostError(lib.XGBGetLastError())
|
raise XGBoostError(lib.XGBGetLastError())
|
||||||
|
|
||||||
|
def parse(ver: str) -> Tuple[int, int, int]:
|
||||||
|
"""Avoid dependency on packaging (PEP 440)."""
|
||||||
|
# 2.0.0-dev or 2.0.0
|
||||||
|
major, minor, patch = ver.split("-")[0].split(".")
|
||||||
|
return int(major), int(minor), int(patch)
|
||||||
|
|
||||||
|
libver = _lib_version(lib)
|
||||||
|
pyver = parse(_py_version())
|
||||||
|
|
||||||
|
# verify that we are loading the correct binary.
|
||||||
|
if pyver != libver:
|
||||||
|
pyver_str = ".".join((str(v) for v in pyver))
|
||||||
|
libver_str = ".".join((str(v) for v in libver))
|
||||||
|
msg = (
|
||||||
|
"Mismatched version between the Python package and the native shared "
|
||||||
|
f"""object. Python package version: {pyver_str}. Shared object """
|
||||||
|
f"""version: {libver_str}. Shared object is loaded from: {lib.path}.
|
||||||
|
Likely cause:
|
||||||
|
* XGBoost is first installed with anaconda then upgraded with pip. To fix it """
|
||||||
|
"please remove one of the installations."
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
return lib
|
return lib
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user