Verify shared object version at load. (#7928) (#8168)

This commit is contained in:
Jiaming Yuan 2022-08-15 15:16:22 +08:00 committed by GitHub
parent 97d89c3ca1
commit 0e2b5c467e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 12 deletions

View File

@ -1,12 +1,16 @@
# coding: utf-8
"""XGBoost: eXtreme Gradient Boosting library. """XGBoost: eXtreme Gradient Boosting library.
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
""" """
import os from .core import (
DMatrix,
from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter, build_info DeviceQuantileDMatrix,
Booster,
DataIter,
build_info,
_py_version,
)
from .training import train, cv from .training import train, cv
from . import rabit # noqa from . import rabit # noqa
from . import tracker # noqa from . import tracker # noqa
@ -21,9 +25,9 @@ try:
except ImportError: except ImportError:
pass pass
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
with open(VERSION_FILE, encoding="ascii") as f: __version__ = _py_version()
__version__ = f.read().strip()
__all__ = [ __all__ = [
# core # core

View File

@ -137,6 +137,22 @@ def _get_log_callback_func() -> Callable:
return c_callback(_log_callback) return c_callback(_log_callback)
def _lib_version(lib: ctypes.CDLL) -> Tuple[int, int, int]:
"""Get the XGBoost version from native shared object."""
major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()
lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
return major.value, minor.value, patch.value
def _py_version() -> str:
"""Get the XGBoost version from Python version file."""
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
with open(VERSION_FILE, encoding="ascii") as f:
return f.read().strip()
def _load_lib() -> ctypes.CDLL: def _load_lib() -> ctypes.CDLL:
"""Load xgboost Library.""" """Load xgboost Library."""
lib_paths = find_lib_path() lib_paths = find_lib_path()
@ -144,7 +160,7 @@ def _load_lib() -> ctypes.CDLL:
# This happens only when building document. # This happens only when building document.
return None # type: ignore return None # type: ignore
try: try:
pathBackup = os.environ['PATH'].split(os.pathsep) pathBackup = os.environ["PATH"].split(os.pathsep)
except KeyError: except KeyError:
pathBackup = [] pathBackup = []
lib_success = False lib_success = False
@ -153,15 +169,16 @@ def _load_lib() -> ctypes.CDLL:
try: try:
# needed when the lib is linked with non-system-available # needed when the lib is linked with non-system-available
# dependencies # dependencies
os.environ['PATH'] = os.pathsep.join( os.environ["PATH"] = os.pathsep.join(
pathBackup + [os.path.dirname(lib_path)]) pathBackup + [os.path.dirname(lib_path)]
)
lib = ctypes.cdll.LoadLibrary(lib_path) lib = ctypes.cdll.LoadLibrary(lib_path)
lib_success = True lib_success = True
except OSError as e: except OSError as e:
os_error_list.append(str(e)) os_error_list.append(str(e))
continue continue
finally: finally:
os.environ['PATH'] = os.pathsep.join(pathBackup) os.environ["PATH"] = os.pathsep.join(pathBackup)
if not lib_success: if not lib_success:
libname = os.path.basename(lib_paths[0]) libname = os.path.basename(lib_paths[0])
raise XGBoostError( raise XGBoostError(
@ -177,11 +194,36 @@ Likely causes:
* You are running 32-bit Python on a 64-bit OS * You are running 32-bit Python on a 64-bit OS
Error message(s): {os_error_list} Error message(s): {os_error_list}
""") """
)
lib.XGBGetLastError.restype = ctypes.c_char_p lib.XGBGetLastError.restype = ctypes.c_char_p
lib.callback = _get_log_callback_func() # type: ignore lib.callback = _get_log_callback_func() # type: ignore
if lib.XGBRegisterLogCallback(lib.callback) != 0: if lib.XGBRegisterLogCallback(lib.callback) != 0:
raise XGBoostError(lib.XGBGetLastError()) raise XGBoostError(lib.XGBGetLastError())
def parse(ver: str) -> Tuple[int, int, int]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev or 2.0.0
major, minor, patch = ver.split("-")[0].split(".")
return int(major), int(minor), int(patch)
libver = _lib_version(lib)
pyver = parse(_py_version())
# verify that we are loading the correct binary.
if pyver != libver:
pyver_str = ".".join((str(v) for v in pyver))
libver_str = ".".join((str(v) for v in libver))
msg = (
"Mismatched version between the Python package and the native shared "
f"""object. Python package version: {pyver_str}. Shared object """
f"""version: {libver_str}. Shared object is loaded from: {lib.path}.
Likely cause:
* XGBoost is first installed with anaconda then upgraded with pip. To fix it """
"please remove one of the installations."
)
raise ValueError(msg)
return lib return lib