Require black formatter for the python package. (#8748)

This commit is contained in:
Jiaming Yuan 2023-02-07 01:53:33 +08:00 committed by GitHub
parent a2e433a089
commit 0f37a01dd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 707 additions and 574 deletions

View File

@ -19,17 +19,17 @@ sys.path.insert(0, CURRENT_DIR)
# requires using CMake directly.
USER_OPTIONS = {
# libxgboost options.
'use-openmp': (None, 'Build with OpenMP support.', 1),
'use-cuda': (None, 'Build with GPU acceleration.', 0),
'use-nccl': (None, 'Build with NCCL to enable distributed GPU support.', 0),
'build-with-shared-nccl': (None, 'Build with shared NCCL library.', 0),
'hide-cxx-symbols': (None, 'Hide all C++ symbols during build.', 1),
'use-hdfs': (None, 'Build with HDFS support', 0),
'use-azure': (None, 'Build with AZURE support.', 0),
'use-s3': (None, 'Build with S3 support', 0),
'plugin-dense-parser': (None, 'Build dense parser plugin.', 0),
"use-openmp": (None, "Build with OpenMP support.", 1),
"use-cuda": (None, "Build with GPU acceleration.", 0),
"use-nccl": (None, "Build with NCCL to enable distributed GPU support.", 0),
"build-with-shared-nccl": (None, "Build with shared NCCL library.", 0),
"hide-cxx-symbols": (None, "Hide all C++ symbols during build.", 1),
"use-hdfs": (None, "Build with HDFS support", 0),
"use-azure": (None, "Build with AZURE support.", 0),
"use-s3": (None, "Build with S3 support", 0),
"plugin-dense-parser": (None, "Build dense parser plugin.", 0),
# Python specific
'use-system-libxgboost': (None, 'Use libxgboost.so in system path.', 0)
"use-system-libxgboost": (None, "Use libxgboost.so in system path.", 0),
}
NEED_CLEAN_TREE = set()
@ -38,20 +38,21 @@ BUILD_TEMP_DIR = None
def lib_name() -> str:
'''Return platform dependent shared object name.'''
if system() == 'Linux' or system().upper().endswith('BSD'):
name = 'libxgboost.so'
elif system() == 'Darwin':
name = 'libxgboost.dylib'
elif system() == 'Windows':
name = 'xgboost.dll'
elif system() == 'OS400':
name = 'libxgboost.so'
"""Return platform dependent shared object name."""
if system() == "Linux" or system().upper().endswith("BSD"):
name = "libxgboost.so"
elif system() == "Darwin":
name = "libxgboost.dylib"
elif system() == "Windows":
name = "xgboost.dll"
elif system() == "OS400":
name = "libxgboost.so"
return name
def copy_tree(src_dir: str, target_dir: str) -> None:
'''Copy source tree into build directory.'''
"""Copy source tree into build directory."""
def clean_copy_tree(src: str, dst: str) -> None:
shutil.copytree(src, dst)
NEED_CLEAN_TREE.add(os.path.abspath(dst))
@ -60,30 +61,30 @@ def copy_tree(src_dir: str, target_dir: str) -> None:
shutil.copy(src, dst)
NEED_CLEAN_FILE.add(os.path.abspath(dst))
src = os.path.join(src_dir, 'src')
inc = os.path.join(src_dir, 'include')
dmlc_core = os.path.join(src_dir, 'dmlc-core')
src = os.path.join(src_dir, "src")
inc = os.path.join(src_dir, "include")
dmlc_core = os.path.join(src_dir, "dmlc-core")
gputreeshap = os.path.join(src_dir, "gputreeshap")
rabit = os.path.join(src_dir, 'rabit')
cmake = os.path.join(src_dir, 'cmake')
plugin = os.path.join(src_dir, 'plugin')
rabit = os.path.join(src_dir, "rabit")
cmake = os.path.join(src_dir, "cmake")
plugin = os.path.join(src_dir, "plugin")
clean_copy_tree(src, os.path.join(target_dir, 'src'))
clean_copy_tree(inc, os.path.join(target_dir, 'include'))
clean_copy_tree(dmlc_core, os.path.join(target_dir, 'dmlc-core'))
clean_copy_tree(src, os.path.join(target_dir, "src"))
clean_copy_tree(inc, os.path.join(target_dir, "include"))
clean_copy_tree(dmlc_core, os.path.join(target_dir, "dmlc-core"))
clean_copy_tree(gputreeshap, os.path.join(target_dir, "gputreeshap"))
clean_copy_tree(rabit, os.path.join(target_dir, 'rabit'))
clean_copy_tree(cmake, os.path.join(target_dir, 'cmake'))
clean_copy_tree(plugin, os.path.join(target_dir, 'plugin'))
clean_copy_tree(rabit, os.path.join(target_dir, "rabit"))
clean_copy_tree(cmake, os.path.join(target_dir, "cmake"))
clean_copy_tree(plugin, os.path.join(target_dir, "plugin"))
cmake_list = os.path.join(src_dir, 'CMakeLists.txt')
clean_copy_file(cmake_list, os.path.join(target_dir, 'CMakeLists.txt'))
lic = os.path.join(src_dir, 'LICENSE')
clean_copy_file(lic, os.path.join(target_dir, 'LICENSE'))
cmake_list = os.path.join(src_dir, "CMakeLists.txt")
clean_copy_file(cmake_list, os.path.join(target_dir, "CMakeLists.txt"))
lic = os.path.join(src_dir, "LICENSE")
clean_copy_file(lic, os.path.join(target_dir, "LICENSE"))
def clean_up() -> None:
'''Removed copied files.'''
"""Removed copied files."""
for path in NEED_CLEAN_TREE:
shutil.rmtree(path)
for path in NEED_CLEAN_FILE:
@ -91,15 +92,16 @@ def clean_up() -> None:
class CMakeExtension(Extension): # pylint: disable=too-few-public-methods
'''Wrapper for extension'''
"""Wrapper for extension"""
def __init__(self, name: str) -> None:
super().__init__(name=name, sources=[])
class BuildExt(build_ext.build_ext): # pylint: disable=too-many-ancestors
'''Custom build_ext command using CMake.'''
"""Custom build_ext command using CMake."""
logger = logging.getLogger('XGBoost build_ext')
logger = logging.getLogger("XGBoost build_ext")
# pylint: disable=too-many-arguments
def build(
@ -110,157 +112,171 @@ class BuildExt(build_ext.build_ext): # pylint: disable=too-many-ancestors
build_tool: Optional[str] = None,
use_omp: int = 1,
) -> None:
'''Build the core library with CMake.'''
cmake_cmd = ['cmake', src_dir, generator]
"""Build the core library with CMake."""
cmake_cmd = ["cmake", src_dir, generator]
for k, v in USER_OPTIONS.items():
arg = k.replace('-', '_').upper()
arg = k.replace("-", "_").upper()
value = str(v[2])
if arg == 'USE_SYSTEM_LIBXGBOOST':
if arg == "USE_SYSTEM_LIBXGBOOST":
continue
if arg == 'USE_OPENMP' and use_omp == 0:
if arg == "USE_OPENMP" and use_omp == 0:
cmake_cmd.append("-D" + arg + "=0")
continue
cmake_cmd.append('-D' + arg + '=' + value)
cmake_cmd.append("-D" + arg + "=" + value)
# Flag for cross-compiling for Apple Silicon
# We use environment variable because it's the only way to pass down custom flags
# through the cibuildwheel package, which otherwise calls `python setup.py bdist_wheel`
# command.
if 'CIBW_TARGET_OSX_ARM64' in os.environ:
if "CIBW_TARGET_OSX_ARM64" in os.environ:
cmake_cmd.append("-DCMAKE_OSX_ARCHITECTURES=arm64")
self.logger.info('Run CMake command: %s', str(cmake_cmd))
self.logger.info("Run CMake command: %s", str(cmake_cmd))
subprocess.check_call(cmake_cmd, cwd=build_dir)
if system() != 'Windows':
if system() != "Windows":
nproc = os.cpu_count()
assert build_tool is not None
subprocess.check_call([build_tool, '-j' + str(nproc)],
cwd=build_dir)
subprocess.check_call([build_tool, "-j" + str(nproc)], cwd=build_dir)
else:
subprocess.check_call(['cmake', '--build', '.',
'--config', 'Release'], cwd=build_dir)
subprocess.check_call(
["cmake", "--build", ".", "--config", "Release"], cwd=build_dir
)
def build_cmake_extension(self) -> None:
'''Configure and build using CMake'''
if USER_OPTIONS['use-system-libxgboost'][2]:
self.logger.info('Using system libxgboost.')
"""Configure and build using CMake"""
if USER_OPTIONS["use-system-libxgboost"][2]:
self.logger.info("Using system libxgboost.")
return
build_dir = self.build_temp
global BUILD_TEMP_DIR # pylint: disable=global-statement
BUILD_TEMP_DIR = build_dir
libxgboost = os.path.abspath(
os.path.join(CURRENT_DIR, os.path.pardir, 'lib', lib_name()))
os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
)
if os.path.exists(libxgboost):
self.logger.info('Found shared library, skipping build.')
self.logger.info("Found shared library, skipping build.")
return
src_dir = 'xgboost'
src_dir = "xgboost"
try:
copy_tree(os.path.join(CURRENT_DIR, os.path.pardir),
os.path.join(self.build_temp, src_dir))
copy_tree(
os.path.join(CURRENT_DIR, os.path.pardir),
os.path.join(self.build_temp, src_dir),
)
except Exception: # pylint: disable=broad-except
copy_tree(src_dir, os.path.join(self.build_temp, src_dir))
self.logger.info('Building from source. %s', libxgboost)
self.logger.info("Building from source. %s", libxgboost)
if not os.path.exists(build_dir):
os.mkdir(build_dir)
if shutil.which('ninja'):
build_tool = 'ninja'
if shutil.which("ninja"):
build_tool = "ninja"
else:
build_tool = 'make'
if sys.platform.startswith('os400'):
build_tool = 'make'
build_tool = "make"
if sys.platform.startswith("os400"):
build_tool = "make"
if system() == 'Windows':
if system() == "Windows":
# Pick up from LGB, just test every possible tool chain.
for vs in (
"-GVisual Studio 17 2022",
'-GVisual Studio 16 2019',
'-GVisual Studio 15 2017',
'-GVisual Studio 14 2015',
'-GMinGW Makefiles',
"-GVisual Studio 16 2019",
"-GVisual Studio 15 2017",
"-GVisual Studio 14 2015",
"-GMinGW Makefiles",
):
try:
self.build(src_dir, build_dir, vs)
self.logger.info(
'%s is used for building Windows distribution.', vs)
"%s is used for building Windows distribution.", vs
)
break
except subprocess.CalledProcessError:
shutil.rmtree(build_dir)
os.mkdir(build_dir)
continue
else:
gen = '-GNinja' if build_tool == 'ninja' else '-GUnix Makefiles'
gen = "-GNinja" if build_tool == "ninja" else "-GUnix Makefiles"
try:
self.build(src_dir, build_dir, gen, build_tool, use_omp=1)
except subprocess.CalledProcessError:
self.logger.warning('Disabling OpenMP support.')
self.logger.warning("Disabling OpenMP support.")
self.build(src_dir, build_dir, gen, build_tool, use_omp=0)
def build_extension(self, ext: Extension) -> None:
'''Override the method for dispatching.'''
"""Override the method for dispatching."""
if isinstance(ext, CMakeExtension):
self.build_cmake_extension()
else:
super().build_extension(ext)
def copy_extensions_to_source(self) -> None:
'''Dummy override. Invoked during editable installation. Our binary
"""Dummy override. Invoked during editable installation. Our binary
should available in `lib`.
'''
"""
if not os.path.exists(
os.path.join(CURRENT_DIR, os.path.pardir, 'lib', lib_name())):
raise ValueError('For using editable installation, please ' +
'build the shared object first with CMake.')
os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
):
raise ValueError(
"For using editable installation, please "
+ "build the shared object first with CMake."
)
class Sdist(sdist.sdist): # pylint: disable=too-many-ancestors
'''Copy c++ source into Python directory.'''
logger = logging.getLogger('xgboost sdist')
class Sdist(sdist.sdist): # pylint: disable=too-many-ancestors
"""Copy c++ source into Python directory."""
logger = logging.getLogger("xgboost sdist")
def run(self) -> None:
copy_tree(os.path.join(CURRENT_DIR, os.path.pardir),
os.path.join(CURRENT_DIR, 'xgboost'))
libxgboost = os.path.join(
CURRENT_DIR, os.path.pardir, 'lib', lib_name())
copy_tree(
os.path.join(CURRENT_DIR, os.path.pardir),
os.path.join(CURRENT_DIR, "xgboost"),
)
libxgboost = os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
if os.path.exists(libxgboost):
self.logger.warning(
'Found shared library, removing to avoid being included in source distribution.'
"Found shared library, removing to avoid being included in source distribution."
)
os.remove(libxgboost)
super().run()
class InstallLib(install_lib.install_lib):
'''Copy shared object into installation directory.'''
logger = logging.getLogger('xgboost install_lib')
"""Copy shared object into installation directory."""
logger = logging.getLogger("xgboost install_lib")
def install(self) -> List[str]:
outfiles = super().install()
if USER_OPTIONS['use-system-libxgboost'][2] != 0:
self.logger.info('Using system libxgboost.')
lib_path = os.path.join(sys.prefix, 'lib')
msg = 'use-system-libxgboost is specified, but ' + lib_name() + \
' is not found in: ' + lib_path
if USER_OPTIONS["use-system-libxgboost"][2] != 0:
self.logger.info("Using system libxgboost.")
lib_path = os.path.join(sys.prefix, "lib")
msg = (
"use-system-libxgboost is specified, but "
+ lib_name()
+ " is not found in: "
+ lib_path
)
assert os.path.exists(os.path.join(lib_path, lib_name())), msg
return []
lib_dir = os.path.join(self.install_dir, 'xgboost', 'lib')
lib_dir = os.path.join(self.install_dir, "xgboost", "lib")
if not os.path.exists(lib_dir):
os.mkdir(lib_dir)
dst = os.path.join(self.install_dir, 'xgboost', 'lib', lib_name())
dst = os.path.join(self.install_dir, "xgboost", "lib", lib_name())
libxgboost_path = lib_name()
assert BUILD_TEMP_DIR is not None
dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, 'lib')
build_dir = os.path.join(BUILD_TEMP_DIR, 'xgboost', 'lib')
dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, "lib")
build_dir = os.path.join(BUILD_TEMP_DIR, "xgboost", "lib")
if os.path.exists(os.path.join(dft_lib_dir, libxgboost_path)):
# The library is built by CMake directly
@ -268,18 +284,21 @@ class InstallLib(install_lib.install_lib):
else:
# The library is built by setup.py
src = os.path.join(build_dir, libxgboost_path)
self.logger.info('Installing shared library: %s', src)
self.logger.info("Installing shared library: %s", src)
dst, _ = self.copy_file(src, dst)
outfiles.append(dst)
return outfiles
class Install(install.install): # pylint: disable=too-many-instance-attributes
'''An interface to install command, accepting XGBoost specific
"""An interface to install command, accepting XGBoost specific
arguments.
'''
user_options = install.install.user_options + [(k, v[0], v[1]) for k, v in USER_OPTIONS.items()]
"""
user_options = install.install.user_options + [
(k, v[0], v[1]) for k, v in USER_OPTIONS.items()
]
def initialize_options(self) -> None:
super().initialize_options()
@ -302,13 +321,13 @@ class Install(install.install): # pylint: disable=too-many-instance-attributes
# arguments, then here we propagate them into `USER_OPTIONS` for visibility to
# other sub-commands like `build_ext`.
for k, v in USER_OPTIONS.items():
arg = k.replace('-', '_')
arg = k.replace("-", "_")
if hasattr(self, arg):
USER_OPTIONS[k] = (v[0], v[1], getattr(self, arg))
super().run()
if __name__ == '__main__':
if __name__ == "__main__":
# Supported commands:
# From internet:
# - pip install xgboost
@ -326,51 +345,55 @@ if __name__ == '__main__':
# - python setup.py develop # same as above
logging.basicConfig(level=logging.INFO)
with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd:
with open(os.path.join(CURRENT_DIR, "README.rst"), encoding="utf-8") as fd:
description = fd.read()
with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION'), encoding="ascii") as fd:
with open(os.path.join(CURRENT_DIR, "xgboost/VERSION"), encoding="ascii") as fd:
version = fd.read().strip()
setup(name='xgboost',
version=version,
description="XGBoost Python Package",
long_description=description,
long_description_content_type="text/x-rst",
install_requires=[
'numpy',
'scipy',
],
ext_modules=[CMakeExtension('libxgboost')],
# error: expected "str": "Type[Command]"
cmdclass={
'build_ext': BuildExt, # type: ignore
'sdist': Sdist, # type: ignore
'install_lib': InstallLib, # type: ignore
'install': Install # type: ignore
},
extras_require={
'pandas': ['pandas'],
'scikit-learn': ['scikit-learn'],
'dask': ['dask', 'pandas', 'distributed'],
'datatable': ['datatable'],
'plotting': ['graphviz', 'matplotlib'],
"pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
},
maintainer='Hyunsu Cho',
maintainer_email='chohyu01@cs.washington.edu',
zip_safe=False,
packages=find_packages(),
include_package_data=True,
license='Apache-2.0',
classifiers=['License :: OSI Approved :: Apache Software License',
'Development Status :: 5 - Production/Stable',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10'],
python_requires=">=3.8",
url='https://github.com/dmlc/xgboost')
setup(
name="xgboost",
version=version,
description="XGBoost Python Package",
long_description=description,
long_description_content_type="text/x-rst",
install_requires=[
"numpy",
"scipy",
],
ext_modules=[CMakeExtension("libxgboost")],
# error: expected "str": "Type[Command]"
cmdclass={
"build_ext": BuildExt, # type: ignore
"sdist": Sdist, # type: ignore
"install_lib": InstallLib, # type: ignore
"install": Install, # type: ignore
},
extras_require={
"pandas": ["pandas"],
"scikit-learn": ["scikit-learn"],
"dask": ["dask", "pandas", "distributed"],
"datatable": ["datatable"],
"plotting": ["graphviz", "matplotlib"],
"pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
},
maintainer="Hyunsu Cho",
maintainer_email="chohyu01@cs.washington.edu",
zip_safe=False,
packages=find_packages(),
include_package_data=True,
license="Apache-2.0",
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Development Status :: 5 - Production/Stable",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
python_requires=">=3.8",
url="https://github.com/dmlc/xgboost",
)
clean_up()

View File

@ -152,42 +152,52 @@ def broadcast(data: _T, root: int) -> _T:
rank = get_rank()
length = ctypes.c_ulong()
if root == rank:
assert data is not None, 'need to pass in data when broadcasting'
assert data is not None, "need to pass in data when broadcasting"
s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
length.value = len(s)
# run first broadcast
_check_call(_LIB.XGCommunicatorBroadcast(ctypes.byref(length),
ctypes.sizeof(ctypes.c_ulong), root))
_check_call(
_LIB.XGCommunicatorBroadcast(
ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
)
)
if root != rank:
dptr = (ctypes.c_char * length.value)()
# run second
_check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
length.value, root))
_check_call(
_LIB.XGCommunicatorBroadcast(
ctypes.cast(dptr, ctypes.c_void_p), length.value, root
)
)
data = pickle.loads(dptr.raw)
del dptr
else:
_check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
length.value, root))
_check_call(
_LIB.XGCommunicatorBroadcast(
ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), length.value, root
)
)
del s
return data
# enumeration of dtypes
DTYPE_ENUM__ = {
np.dtype('int8'): 0,
np.dtype('uint8'): 1,
np.dtype('int32'): 2,
np.dtype('uint32'): 3,
np.dtype('int64'): 4,
np.dtype('uint64'): 5,
np.dtype('float32'): 6,
np.dtype('float64'): 7
np.dtype("int8"): 0,
np.dtype("uint8"): 1,
np.dtype("int32"): 2,
np.dtype("uint32"): 3,
np.dtype("int64"): 4,
np.dtype("uint64"): 5,
np.dtype("float32"): 6,
np.dtype("float64"): 7,
}
@unique
class Op(IntEnum):
"""Supported operations for allreduce."""
MAX = 0
MIN = 1
SUM = 2
@ -196,9 +206,7 @@ class Op(IntEnum):
BITWISE_XOR = 5
def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op
) -> np.ndarray:
def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid-name
"""Perform allreduce, return the result.
Parameters
@ -218,15 +226,22 @@ def allreduce( # pylint:disable=invalid-name
This function is not thread-safe.
"""
if not isinstance(data, np.ndarray):
raise TypeError('allreduce only takes in numpy.ndarray')
raise TypeError("allreduce only takes in numpy.ndarray")
buf = data.ravel()
if buf.base is data.base:
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise Exception(f"data type {buf.dtype} not supported")
_check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
buf.size, DTYPE_ENUM__[buf.dtype],
int(op), None, None))
_check_call(
_LIB.XGCommunicatorAllreduce(
buf.ctypes.data_as(ctypes.c_void_p),
buf.size,
DTYPE_ENUM__[buf.dtype],
int(op),
None,
None,
)
)
return buf

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
# pylint: disable=too-many-arguments, too-many-branches, too-many-lines
# pylint: disable=too-many-return-statements, import-error
'''Data dispatching for DMatrix.'''
"""Data dispatching for DMatrix."""
import ctypes
import json
import os
@ -108,6 +108,7 @@ def _from_scipy_csr(
feature_types: Optional[FeatureTypes],
) -> DispatchedDataBackendReturnType:
"""Initialize data from a CSR matrix."""
handle = ctypes.c_void_p()
data = transform_scipy_sparse(data, True)
_check_call(
@ -178,8 +179,7 @@ def _ensure_np_dtype(
def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray:
'''Handle numpy slice. This can be removed if we use __array_interface__.
'''
"""Handle numpy slice. This can be removed if we use __array_interface__."""
try:
if not data.flags.c_contiguous:
data = np.array(data, copy=True, dtype=dtype)
@ -653,6 +653,7 @@ def _is_arrow(data: DataType) -> bool:
try:
import pyarrow as pa
from pyarrow import dataset as arrow_dataset
return isinstance(data, (pa.Table, arrow_dataset.Dataset))
except ImportError:
return False
@ -878,8 +879,8 @@ def _is_cupy_array(data: DataType) -> bool:
def _transform_cupy_array(data: DataType) -> CupyT:
import cupy # pylint: disable=import-error
if not hasattr(data, '__cuda_array_interface__') and hasattr(
data, '__array__'):
if not hasattr(data, "__cuda_array_interface__") and hasattr(data, "__array__"):
data = cupy.array(data, copy=False)
if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]:
data = data.astype(cupy.float32, copy=False)
@ -900,9 +901,9 @@ def _from_cupy_array(
config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
_check_call(
_LIB.XGDMatrixCreateFromCudaArrayInterface(
interface_str,
config,
ctypes.byref(handle)))
interface_str, config, ctypes.byref(handle)
)
)
return handle, feature_names, feature_types
@ -923,12 +924,13 @@ def _is_cupy_csc(data: DataType) -> bool:
def _is_dlpack(data: DataType) -> bool:
return 'PyCapsule' in str(type(data)) and "dltensor" in str(data)
return "PyCapsule" in str(type(data)) and "dltensor" in str(data)
def _transform_dlpack(data: DataType) -> bool:
from cupy import fromDlpack # pylint: disable=E0401
assert 'used_dltensor' not in str(data)
assert "used_dltensor" not in str(data)
data = fromDlpack(data)
return data
@ -941,8 +943,7 @@ def _from_dlpack(
feature_types: Optional[FeatureTypes],
) -> DispatchedDataBackendReturnType:
data = _transform_dlpack(data)
return _from_cupy_array(data, missing, nthread, feature_names,
feature_types)
return _from_cupy_array(data, missing, nthread, feature_names, feature_types)
def _is_uri(data: DataType) -> bool:
@ -1003,13 +1004,13 @@ def _is_iter(data: DataType) -> bool:
def _has_array_protocol(data: DataType) -> bool:
return hasattr(data, '__array__')
return hasattr(data, "__array__")
def _convert_unknown_data(data: DataType) -> DataType:
warnings.warn(
f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
UserWarning
f"Unknown data type: {type(data)}, trying to convert it to csr_matrix",
UserWarning,
)
try:
import scipy.sparse
@ -1018,7 +1019,7 @@ def _convert_unknown_data(data: DataType) -> DataType:
try:
data = scipy.sparse.csr_matrix(data)
except Exception: # pylint: disable=broad-except
except Exception: # pylint: disable=broad-except
return None
return data
@ -1033,7 +1034,7 @@ def dispatch_data_backend(
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
'''Dispatch data for DMatrix.'''
"""Dispatch data for DMatrix."""
if not _is_cudf_ser(data) and not _is_pandas_series(data):
_check_data_shape(data)
if _is_scipy_csr(data):
@ -1054,6 +1055,7 @@ def dispatch_data_backend(
return _from_tuple(data, missing, threads, feature_names, feature_types)
if _is_pandas_series(data):
import pandas as pd
data = pd.DataFrame(data)
if _is_pandas_df(data):
return _from_pandas_df(
@ -1064,39 +1066,41 @@ def dispatch_data_backend(
data, missing, threads, feature_names, feature_types, enable_categorical
)
if _is_cupy_array(data):
return _from_cupy_array(data, missing, threads, feature_names,
feature_types)
return _from_cupy_array(data, missing, threads, feature_names, feature_types)
if _is_cupy_csr(data):
raise TypeError('cupyx CSR is not supported yet.')
raise TypeError("cupyx CSR is not supported yet.")
if _is_cupy_csc(data):
raise TypeError('cupyx CSC is not supported yet.')
raise TypeError("cupyx CSC is not supported yet.")
if _is_dlpack(data):
return _from_dlpack(data, missing, threads, feature_names,
feature_types)
return _from_dlpack(data, missing, threads, feature_names, feature_types)
if _is_dt_df(data):
_warn_unused_missing(data, missing)
return _from_dt_df(
data, missing, threads, feature_names, feature_types, enable_categorical
)
if _is_modin_df(data):
return _from_pandas_df(data, enable_categorical, missing, threads,
feature_names, feature_types)
return _from_pandas_df(
data, enable_categorical, missing, threads, feature_names, feature_types
)
if _is_modin_series(data):
return _from_pandas_series(
data, missing, threads, enable_categorical, feature_names, feature_types
)
if _is_arrow(data):
return _from_arrow(
data, missing, threads, feature_names, feature_types, enable_categorical)
data, missing, threads, feature_names, feature_types, enable_categorical
)
if _has_array_protocol(data):
array = np.asarray(data)
return _from_numpy_array(array, missing, threads, feature_names, feature_types)
converted = _convert_unknown_data(data)
if converted is not None:
return _from_scipy_csr(converted, missing, threads, feature_names, feature_types)
return _from_scipy_csr(
converted, missing, threads, feature_names, feature_types
)
raise TypeError('Not supported type for data.' + str(type(data)))
raise TypeError("Not supported type for data." + str(type(data)))
def _validate_meta_shape(data: DataType, name: str) -> None:
@ -1128,20 +1132,14 @@ def _meta_from_numpy(
def _meta_from_list(
data: Sequence,
field: str,
dtype: Optional[NumpyDType],
handle: ctypes.c_void_p
data: Sequence, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
) -> None:
data_np = np.array(data)
_meta_from_numpy(data_np, field, dtype, handle)
def _meta_from_tuple(
data: Sequence,
field: str,
dtype: Optional[NumpyDType],
handle: ctypes.c_void_p
data: Sequence, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
) -> None:
return _meta_from_list(data, field, dtype, handle)
@ -1156,39 +1154,27 @@ def _meta_from_cudf_df(data: DataType, field: str, handle: ctypes.c_void_p) -> N
def _meta_from_cudf_series(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
interface = bytes(json.dumps([data.__cuda_array_interface__],
indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
c_str(field),
interface))
interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8")
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))
def _meta_from_cupy_array(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
data = _transform_cupy_array(data)
interface = bytes(json.dumps([data.__cuda_array_interface__],
indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
c_str(field),
interface))
interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8")
_check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))
def _meta_from_dt(
data: DataType,
field: str,
dtype: Optional[NumpyDType],
handle: ctypes.c_void_p
data: DataType, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
) -> None:
data, _, _ = _transform_dt_df(data, None, None, field, dtype)
_meta_from_numpy(data, field, dtype, handle)
def dispatch_meta_backend(
matrix: DMatrix,
data: DataType,
name: str,
dtype: Optional[NumpyDType] = None
matrix: DMatrix, data: DataType, name: str, dtype: Optional[NumpyDType] = None
) -> None:
'''Dispatch for meta info.'''
"""Dispatch for meta info."""
handle = matrix.handle
assert handle is not None
_validate_meta_shape(data, name)
@ -1231,7 +1217,7 @@ def dispatch_meta_backend(
_meta_from_numpy(data, name, dtype, handle)
return
if _is_modin_series(data):
data = data.values.astype('float')
data = data.values.astype("float")
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
_meta_from_numpy(data, name, dtype, handle)
return
@ -1240,19 +1226,20 @@ def dispatch_meta_backend(
array = np.asarray(data)
_meta_from_numpy(array, name, dtype, handle)
return
raise TypeError('Unsupported type for ' + name, str(type(data)))
raise TypeError("Unsupported type for " + name, str(type(data)))
class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
'''An iterator for single batch data to help creating device DMatrix.
"""An iterator for single batch data to help creating device DMatrix.
Transforming input directly to histogram with normal single batch data API
can not access weight for sketching. So this iterator acts as a staging
area for meta info.
'''
"""
def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs
self.it = 0 # pylint: disable=invalid-name
self.it = 0 # pylint: disable=invalid-name
# This does not necessarily increase memory usage as the data transformation
# might use memory.

View File

@ -22,45 +22,51 @@ def find_lib_path() -> List[str]:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [
# normal, after installation `lib` is copied into Python package tree.
os.path.join(curr_path, 'lib'),
os.path.join(curr_path, "lib"),
# editable installation, no copying is performed.
os.path.join(curr_path, os.path.pardir, os.path.pardir, 'lib'),
os.path.join(curr_path, os.path.pardir, os.path.pardir, "lib"),
# use libxgboost from a system prefix, if available. This should be the last
# option.
os.path.join(sys.prefix, 'lib'),
os.path.join(sys.prefix, "lib"),
]
if sys.platform == 'win32':
if platform.architecture()[0] == '64bit':
dll_path.append(
os.path.join(curr_path, '../../windows/x64/Release/'))
if sys.platform == "win32":
if platform.architecture()[0] == "64bit":
dll_path.append(os.path.join(curr_path, "../../windows/x64/Release/"))
# hack for pip installation when copy all parent source
# directory here
dll_path.append(os.path.join(curr_path, './windows/x64/Release/'))
dll_path.append(os.path.join(curr_path, "./windows/x64/Release/"))
else:
dll_path.append(os.path.join(curr_path, '../../windows/Release/'))
dll_path.append(os.path.join(curr_path, "../../windows/Release/"))
# hack for pip installation when copy all parent source
# directory here
dll_path.append(os.path.join(curr_path, './windows/Release/'))
dll_path = [os.path.join(p, 'xgboost.dll') for p in dll_path]
elif sys.platform.startswith(('linux', 'freebsd', 'emscripten')):
dll_path = [os.path.join(p, 'libxgboost.so') for p in dll_path]
elif sys.platform == 'darwin':
dll_path = [os.path.join(p, 'libxgboost.dylib') for p in dll_path]
elif sys.platform == 'cygwin':
dll_path = [os.path.join(p, 'cygxgboost.dll') for p in dll_path]
if platform.system() == 'OS400':
dll_path = [os.path.join(p, 'libxgboost.so') for p in dll_path]
dll_path.append(os.path.join(curr_path, "./windows/Release/"))
dll_path = [os.path.join(p, "xgboost.dll") for p in dll_path]
elif sys.platform.startswith(("linux", "freebsd", "emscripten")):
dll_path = [os.path.join(p, "libxgboost.so") for p in dll_path]
elif sys.platform == "darwin":
dll_path = [os.path.join(p, "libxgboost.dylib") for p in dll_path]
elif sys.platform == "cygwin":
dll_path = [os.path.join(p, "cygxgboost.dll") for p in dll_path]
if platform.system() == "OS400":
dll_path = [os.path.join(p, "libxgboost.so") for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
# XGBOOST_BUILD_DOC is defined by sphinx conf.
if not lib_path and not os.environ.get('XGBOOST_BUILD_DOC', False):
link = 'https://xgboost.readthedocs.io/en/latest/build.html'
msg = 'Cannot find XGBoost Library in the candidate path. ' + \
'List of candidates:\n- ' + ('\n- '.join(dll_path)) + \
'\nXGBoost Python package path: ' + curr_path + \
'\nsys.prefix: ' + sys.prefix + \
'\nSee: ' + link + ' for installing XGBoost.'
if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False):
link = "https://xgboost.readthedocs.io/en/latest/build.html"
msg = (
"Cannot find XGBoost Library in the candidate path. "
+ "List of candidates:\n- "
+ ("\n- ".join(dll_path))
+ "\nXGBoost Python package path: "
+ curr_path
+ "\nsys.prefix: "
+ sys.prefix
+ "\nSee: "
+ link
+ " for installing XGBoost."
)
raise XGBoostLibraryNotFound(msg)
return lib_path

View File

@ -81,22 +81,24 @@ def plot_importance(
try:
import matplotlib.pyplot as plt
except ImportError as e:
raise ImportError('You must install matplotlib to plot importance') from e
raise ImportError("You must install matplotlib to plot importance") from e
if isinstance(booster, XGBModel):
importance = booster.get_booster().get_score(
importance_type=importance_type, fmap=fmap)
importance_type=importance_type, fmap=fmap
)
elif isinstance(booster, Booster):
importance = booster.get_score(importance_type=importance_type, fmap=fmap)
elif isinstance(booster, dict):
importance = booster
else:
raise ValueError('tree must be Booster, XGBModel or dict instance')
raise ValueError("tree must be Booster, XGBModel or dict instance")
if not importance:
raise ValueError(
'Booster.get_score() results in empty. ' +
'This maybe caused by having all trees as decision dumps.')
"Booster.get_score() results in empty. "
+ "This maybe caused by having all trees as decision dumps."
)
tuples = [(k, importance[k]) for k in importance]
if max_num_features is not None:
@ -110,25 +112,25 @@ def plot_importance(
_, ax = plt.subplots(1, 1)
ylocs = np.arange(len(values))
ax.barh(ylocs, values, align='center', height=height, **kwargs)
ax.barh(ylocs, values, align="center", height=height, **kwargs)
if show_values is True:
for x, y in zip(values, ylocs):
ax.text(x + 1, y, values_format.format(v=x), va='center')
ax.text(x + 1, y, values_format.format(v=x), va="center")
ax.set_yticks(ylocs)
ax.set_yticklabels(labels)
if xlim is not None:
if not isinstance(xlim, tuple) or len(xlim) != 2:
raise ValueError('xlim must be a tuple of 2 elements')
raise ValueError("xlim must be a tuple of 2 elements")
else:
xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim)
if ylim is not None:
if not isinstance(ylim, tuple) or len(ylim) != 2:
raise ValueError('ylim must be a tuple of 2 elements')
raise ValueError("ylim must be a tuple of 2 elements")
else:
ylim = (-1, len(values))
ax.set_ylim(ylim)
@ -201,44 +203,42 @@ def to_graphviz(
try:
from graphviz import Source
except ImportError as e:
raise ImportError('You must install graphviz to plot tree') from e
raise ImportError("You must install graphviz to plot tree") from e
if isinstance(booster, XGBModel):
booster = booster.get_booster()
# squash everything back into kwargs again for compatibility
parameters = 'dot'
parameters = "dot"
extra = {}
for key, value in kwargs.items():
extra[key] = value
if rankdir is not None:
kwargs['graph_attrs'] = {}
kwargs['graph_attrs']['rankdir'] = rankdir
kwargs["graph_attrs"] = {}
kwargs["graph_attrs"]["rankdir"] = rankdir
for key, value in extra.items():
if kwargs.get("graph_attrs", None) is not None:
kwargs['graph_attrs'][key] = value
kwargs["graph_attrs"][key] = value
else:
kwargs['graph_attrs'] = {}
kwargs["graph_attrs"] = {}
del kwargs[key]
if yes_color is not None or no_color is not None:
kwargs['edge'] = {}
kwargs["edge"] = {}
if yes_color is not None:
kwargs['edge']['yes_color'] = yes_color
kwargs["edge"]["yes_color"] = yes_color
if no_color is not None:
kwargs['edge']['no_color'] = no_color
kwargs["edge"]["no_color"] = no_color
if condition_node_params is not None:
kwargs['condition_node_params'] = condition_node_params
kwargs["condition_node_params"] = condition_node_params
if leaf_node_params is not None:
kwargs['leaf_node_params'] = leaf_node_params
kwargs["leaf_node_params"] = leaf_node_params
if kwargs:
parameters += ':'
parameters += ":"
parameters += json.dumps(kwargs)
tree = booster.get_dump(
fmap=fmap,
dump_format=parameters)[num_trees]
tree = booster.get_dump(fmap=fmap, dump_format=parameters)[num_trees]
g = Source(tree)
return g
@ -277,19 +277,18 @@ def plot_tree(
from matplotlib import image
from matplotlib import pyplot as plt
except ImportError as e:
raise ImportError('You must install matplotlib to plot tree') from e
raise ImportError("You must install matplotlib to plot tree") from e
if ax is None:
_, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
**kwargs)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)
s = BytesIO()
s.write(g.pipe(format='png'))
s.write(g.pipe(format="png"))
s.seek(0)
img = image.imread(s)
ax.imshow(img)
ax.axis('off')
ax.axis("off")
return ax

View File

@ -24,7 +24,7 @@ def init(args: Optional[List[bytes]] = None) -> None:
parsed = {}
if args:
for arg in args:
kv = arg.decode().split('=')
kv = arg.decode().split("=")
if len(kv) == 2:
parsed[kv[0]] = kv[1]
collective.init(**parsed)
@ -104,6 +104,7 @@ def broadcast(data: T, root: int) -> T:
@unique
class Op(IntEnum):
"""Supported operations for rabit."""
MAX = 0
MIN = 1
SUM = 2
@ -111,7 +112,7 @@ class Op(IntEnum):
def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result.
Parameters

View File

@ -53,7 +53,7 @@ class ExSocket:
# magic number used to verify existence of data
MAGIC_NUM = 0xff99
MAGIC_NUM = 0xFF99
def get_some_ip(host: str) -> str:
@ -334,19 +334,19 @@ class RabitTracker:
while len(shutdown) != n_workers:
fd, s_addr = self.sock.accept()
s = WorkerEntry(fd, s_addr)
if s.cmd == 'print':
if s.cmd == "print":
s.print(self._use_logger)
continue
if s.cmd == 'shutdown':
if s.cmd == "shutdown":
assert s.rank >= 0 and s.rank not in shutdown
assert s.rank not in wait_conn
shutdown[s.rank] = s
logging.debug('Received %s signal from %d', s.cmd, s.rank)
logging.debug("Received %s signal from %d", s.cmd, s.rank)
continue
assert s.cmd in ("start", "recover")
# lazily initialize the workers
if tree_map is None:
assert s.cmd == 'start'
assert s.cmd == "start"
if s.world_size > 0:
n_workers = s.world_size
tree_map, parent_map, ring_map = self.get_link_map(n_workers)
@ -354,7 +354,7 @@ class RabitTracker:
todo_nodes = list(range(n_workers))
else:
assert s.world_size in (-1, n_workers)
if s.cmd == 'recover':
if s.cmd == "recover":
assert s.rank >= 0
rank = s.decide_rank(job_map)
@ -410,24 +410,25 @@ def get_host_ip(host_ip: Optional[str] = None) -> str:
returned as it's
"""
if host_ip is None or host_ip == 'auto':
host_ip = 'ip'
if host_ip is None or host_ip == "auto":
host_ip = "ip"
if host_ip == 'dns':
if host_ip == "dns":
host_ip = socket.getfqdn()
elif host_ip == 'ip':
elif host_ip == "ip":
from socket import gaierror
try:
host_ip = socket.gethostbyname(socket.getfqdn())
except gaierror:
logging.debug(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
"gethostbyname(socket.getfqdn()) failed... trying on hostname()"
)
host_ip = socket.gethostbyname(socket.gethostname())
if host_ip.startswith("127."):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# doesn't have to be reachable
s.connect(('10.255.255.255', 1))
s.connect(("10.255.255.255", 1))
host_ip = s.getsockname()[0]
assert host_ip is not None
@ -458,25 +459,41 @@ def start_rabit_tracker(args: argparse.Namespace) -> None:
def main() -> None:
"""Main function if tracker is executed in standalone mode."""
parser = argparse.ArgumentParser(description='Rabit Tracker start.')
parser.add_argument('--num-workers', required=True, type=int,
help='Number of worker process to be launched.')
parser = argparse.ArgumentParser(description="Rabit Tracker start.")
parser.add_argument(
'--num-servers', default=0, type=int,
help='Number of server process to be launched. Only used in PS jobs.'
"--num-workers",
required=True,
type=int,
help="Number of worker process to be launched.",
)
parser.add_argument(
"--num-servers",
default=0,
type=int,
help="Number of server process to be launched. Only used in PS jobs.",
)
parser.add_argument(
"--host-ip",
default=None,
type=str,
help=(
"Host IP addressed, this is only needed "
+ "if the host IP cannot be automatically guessed."
),
)
parser.add_argument(
"--log-level",
default="INFO",
type=str,
choices=["INFO", "DEBUG"],
help="Logging level of the logger.",
)
parser.add_argument('--host-ip', default=None, type=str,
help=('Host IP addressed, this is only needed ' +
'if the host IP cannot be automatically guessed.'))
parser.add_argument('--log-level', default='INFO', type=str,
choices=['INFO', 'DEBUG'],
help='Logging level of the logger.')
args = parser.parse_args()
fmt = '%(asctime)s %(levelname)s %(message)s'
if args.log_level == 'INFO':
fmt = "%(asctime)s %(levelname)s %(message)s"
if args.log_level == "INFO":
level = logging.INFO
elif args.log_level == 'DEBUG':
elif args.log_level == "DEBUG":
level = logging.DEBUG
else:
raise RuntimeError(f"Unknown logging level {args.log_level}")

View File

@ -205,25 +205,29 @@ def train(
class CVPack:
""""Auxiliary datastruct to hold one fold of CV."""
def __init__(self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]) -> None:
""""Initialize the CVPack"""
""" "Auxiliary datastruct to hold one fold of CV."""
def __init__(
self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]
) -> None:
""" "Initialize the CVPack"""
self.dtrain = dtrain
self.dtest = dtest
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
self.watchlist = [(dtrain, "train"), (dtest, "test")]
self.bst = Booster(param, [dtrain, dtest])
def __getattr__(self, name: str) -> Callable:
def _inner(*args: Any, **kwargs: Any) -> Any:
return getattr(self.bst, name)(*args, **kwargs)
return _inner
def update(self, iteration: int, fobj: Optional[Objective]) -> None:
""""Update the boosters for one iteration"""
""" "Update the boosters for one iteration"""
self.bst.update(self.dtrain, iteration, fobj)
def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> str:
""""Evaluate the CVPack for one iteration."""
""" "Evaluate the CVPack for one iteration."""
return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)
@ -232,38 +236,42 @@ class _PackedBooster:
self.cvfolds = cvfolds
def update(self, iteration: int, obj: Optional[Objective]) -> None:
'''Iterate through folds for update'''
"""Iterate through folds for update"""
for fold in self.cvfolds:
fold.update(iteration, obj)
def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> List[str]:
'''Iterate through folds for eval'''
def eval(
self, iteration: int, feval: Optional[Metric], output_margin: bool
) -> List[str]:
"""Iterate through folds for eval"""
result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
return result
def set_attr(self, **kwargs: Optional[str]) -> Any:
'''Iterate through folds for setting attributes'''
"""Iterate through folds for setting attributes"""
for f in self.cvfolds:
f.bst.set_attr(**kwargs)
def attr(self, key: str) -> Optional[str]:
'''Redirect to booster attr.'''
"""Redirect to booster attr."""
return self.cvfolds[0].bst.attr(key)
def set_param(self,
params: Union[Dict, Iterable[Tuple[str, Any]], str],
value: Optional[str] = None) -> None:
def set_param(
self,
params: Union[Dict, Iterable[Tuple[str, Any]], str],
value: Optional[str] = None,
) -> None:
"""Iterate through folds for set_param"""
for f in self.cvfolds:
f.bst.set_param(params, value)
def num_boosted_rounds(self) -> int:
'''Number of boosted rounds.'''
"""Number of boosted rounds."""
return self.cvfolds[0].num_boosted_rounds()
@property
def best_iteration(self) -> int:
'''Get best_iteration'''
"""Get best_iteration"""
return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))
@property
@ -279,7 +287,7 @@ def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarr
:param boundaries: rows index limits of each group
:return: row in group
"""
return np.concatenate([np.arange(boundaries[g], boundaries[g+1]) for g in groups])
return np.concatenate([np.arange(boundaries[g], boundaries[g + 1]) for g in groups])
def mkgroupfold(
@ -305,11 +313,17 @@ def mkgroupfold(
# list by fold of test group indexes
out_group_idset = np.array_split(idx, nfold)
# list by fold of train group indexes
in_group_idset = [np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)]
in_group_idset = [
np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)
]
# from the group indexes, convert them to row indexes
in_idset = [groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset]
out_idset = [groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset]
in_idset = [
groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset
]
out_idset = [
groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset
]
# build the folds by taking the appropriate slices
ret = []
@ -324,7 +338,7 @@ def mkgroupfold(
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
else:
tparam = param
plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
plst = list(tparam.items()) + [("eval_metric", itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst))
return ret
@ -348,16 +362,20 @@ def mknfold(
if stratified is False and folds is None:
# Do standard k-fold cross validation. Automatically determine the folds.
if len(dall.get_uint_info('group_ptr')) > 1:
return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle)
if len(dall.get_uint_info("group_ptr")) > 1:
return mkgroupfold(
dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle
)
if shuffle is True:
idx = np.random.permutation(dall.num_row())
else:
idx = np.arange(dall.num_row())
out_idset = np.array_split(idx, nfold)
in_idset = [np.concatenate([out_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)]
in_idset = [
np.concatenate([out_idset[i] for i in range(nfold) if k != i])
for k in range(nfold)
]
elif folds is not None:
# Use user specified custom split using indices
try:
@ -387,7 +405,7 @@ def mknfold(
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
else:
tparam = param
plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
plst = list(tparam.items()) + [("eval_metric", itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst))
return ret
@ -502,29 +520,32 @@ def cv(
evaluation history : list(string)
"""
if stratified is True and not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use stratified cv')
raise XGBoostError(
"sklearn needs to be installed in order to use stratified cv"
)
if isinstance(metrics, str):
metrics = [metrics]
params = params.copy()
if isinstance(params, list):
_metrics = [x[1] for x in params if x[0] == 'eval_metric']
_metrics = [x[1] for x in params if x[0] == "eval_metric"]
params = dict(params)
if 'eval_metric' in params:
params['eval_metric'] = _metrics
if "eval_metric" in params:
params["eval_metric"] = _metrics
if (not metrics) and 'eval_metric' in params:
if isinstance(params['eval_metric'], list):
metrics = params['eval_metric']
if (not metrics) and "eval_metric" in params:
if isinstance(params["eval_metric"], list):
metrics = params["eval_metric"]
else:
metrics = [params['eval_metric']]
metrics = [params["eval_metric"]]
params.pop("eval_metric", None)
results: Dict[str, List[float]] = {}
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
stratified, folds, shuffle)
cvfolds = mknfold(
dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds, shuffle
)
metric_fn = _configure_custom_metric(feval, custom_metric)
@ -555,20 +576,21 @@ def cv(
should_break = callbacks_container.after_iteration(booster, i, dtrain, None)
res = callbacks_container.aggregated_cv
for key, mean, std in cast(List[Tuple[str, float, float]], res):
if key + '-mean' not in results:
results[key + '-mean'] = []
if key + '-std' not in results:
results[key + '-std'] = []
results[key + '-mean'].append(mean)
results[key + '-std'].append(std)
if key + "-mean" not in results:
results[key + "-mean"] = []
if key + "-std" not in results:
results[key + "-std"] = []
results[key + "-mean"].append(mean)
results[key + "-std"].append(std)
if should_break:
for k in results.keys(): # pylint: disable=consider-iterating-dictionary
results[k] = results[k][:(booster.best_iteration + 1)]
results[k] = results[k][: (booster.best_iteration + 1)]
break
if as_pandas:
try:
import pandas as pd
results = pd.DataFrame.from_dict(results)
except ImportError:
pass

View File

@ -132,16 +132,7 @@ def main(args: argparse.Namespace) -> None:
run_black(path)
for path in [
# core
"python-package/xgboost/__init__.py",
"python-package/xgboost/_typing.py",
"python-package/xgboost/callback.py",
"python-package/xgboost/compat.py",
"python-package/xgboost/config.py",
"python-package/xgboost/dask.py",
"python-package/xgboost/sklearn.py",
"python-package/xgboost/spark",
"python-package/xgboost/federated.py",
"python-package/xgboost/testing",
"python-package/",
# tests
"tests/python/test_config.py",
"tests/python/test_data_iterator.py",