Require black formatter for the python package. (#8748)

Jiaming Yuan 2023-02-07 01:53:33 +08:00 committed by GitHub
parent a2e433a089
commit 0f37a01dd9
10 changed files with 707 additions and 574 deletions
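
Not part of the commit itself, but for context: the gating pattern this commit enables is black's check-only mode, which rewrites nothing and exits non-zero when a file would be reformatted. A minimal sketch follows; `run_black` is the name used in the repo's lint script (last file in this diff), but this body is illustrative, not the script's exact implementation.

    # Illustrative sketch only: gate CI on `black --check`. The real helper lives
    # in the lint script changed at the end of this diff; its body may differ.
    import subprocess
    import sys


    def run_black(path: str) -> bool:
        """Return True iff everything under `path` is already black-formatted."""
        # --check rewrites nothing; black exits non-zero on files it would change.
        ret = subprocess.run(["black", "--check", path], check=False)
        return ret.returncode == 0


    if __name__ == "__main__":
        if not run_black("python-package/"):
            sys.exit("Please run `black python-package/` and commit the result.")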

python-package/setup.py

@@ -19,17 +19,17 @@ sys.path.insert(0, CURRENT_DIR)
 # requires using CMake directly.
 USER_OPTIONS = {
     # libxgboost options.
-    'use-openmp': (None, 'Build with OpenMP support.', 1),
-    'use-cuda': (None, 'Build with GPU acceleration.', 0),
-    'use-nccl': (None, 'Build with NCCL to enable distributed GPU support.', 0),
-    'build-with-shared-nccl': (None, 'Build with shared NCCL library.', 0),
-    'hide-cxx-symbols': (None, 'Hide all C++ symbols during build.', 1),
-    'use-hdfs': (None, 'Build with HDFS support', 0),
-    'use-azure': (None, 'Build with AZURE support.', 0),
-    'use-s3': (None, 'Build with S3 support', 0),
-    'plugin-dense-parser': (None, 'Build dense parser plugin.', 0),
+    "use-openmp": (None, "Build with OpenMP support.", 1),
+    "use-cuda": (None, "Build with GPU acceleration.", 0),
+    "use-nccl": (None, "Build with NCCL to enable distributed GPU support.", 0),
+    "build-with-shared-nccl": (None, "Build with shared NCCL library.", 0),
+    "hide-cxx-symbols": (None, "Hide all C++ symbols during build.", 1),
+    "use-hdfs": (None, "Build with HDFS support", 0),
+    "use-azure": (None, "Build with AZURE support.", 0),
+    "use-s3": (None, "Build with S3 support", 0),
+    "plugin-dense-parser": (None, "Build dense parser plugin.", 0),
     # Python specific
-    'use-system-libxgboost': (None, 'Use libxgboost.so in system path.', 0)
+    "use-system-libxgboost": (None, "Use libxgboost.so in system path.", 0),
 }
 NEED_CLEAN_TREE = set()
@@ -38,20 +38,21 @@ BUILD_TEMP_DIR = None
 def lib_name() -> str:
-    '''Return platform dependent shared object name.'''
-    if system() == 'Linux' or system().upper().endswith('BSD'):
-        name = 'libxgboost.so'
-    elif system() == 'Darwin':
-        name = 'libxgboost.dylib'
-    elif system() == 'Windows':
-        name = 'xgboost.dll'
-    elif system() == 'OS400':
-        name = 'libxgboost.so'
+    """Return platform dependent shared object name."""
+    if system() == "Linux" or system().upper().endswith("BSD"):
+        name = "libxgboost.so"
+    elif system() == "Darwin":
+        name = "libxgboost.dylib"
+    elif system() == "Windows":
+        name = "xgboost.dll"
+    elif system() == "OS400":
+        name = "libxgboost.so"
     return name


 def copy_tree(src_dir: str, target_dir: str) -> None:
-    '''Copy source tree into build directory.'''
+    """Copy source tree into build directory."""
+
     def clean_copy_tree(src: str, dst: str) -> None:
         shutil.copytree(src, dst)
         NEED_CLEAN_TREE.add(os.path.abspath(dst))
@@ -60,30 +61,30 @@ def copy_tree(src_dir: str, target_dir: str) -> None:
         shutil.copy(src, dst)
         NEED_CLEAN_FILE.add(os.path.abspath(dst))

-    src = os.path.join(src_dir, 'src')
-    inc = os.path.join(src_dir, 'include')
-    dmlc_core = os.path.join(src_dir, 'dmlc-core')
+    src = os.path.join(src_dir, "src")
+    inc = os.path.join(src_dir, "include")
+    dmlc_core = os.path.join(src_dir, "dmlc-core")
     gputreeshap = os.path.join(src_dir, "gputreeshap")
-    rabit = os.path.join(src_dir, 'rabit')
-    cmake = os.path.join(src_dir, 'cmake')
-    plugin = os.path.join(src_dir, 'plugin')
+    rabit = os.path.join(src_dir, "rabit")
+    cmake = os.path.join(src_dir, "cmake")
+    plugin = os.path.join(src_dir, "plugin")

-    clean_copy_tree(src, os.path.join(target_dir, 'src'))
-    clean_copy_tree(inc, os.path.join(target_dir, 'include'))
-    clean_copy_tree(dmlc_core, os.path.join(target_dir, 'dmlc-core'))
+    clean_copy_tree(src, os.path.join(target_dir, "src"))
+    clean_copy_tree(inc, os.path.join(target_dir, "include"))
+    clean_copy_tree(dmlc_core, os.path.join(target_dir, "dmlc-core"))
     clean_copy_tree(gputreeshap, os.path.join(target_dir, "gputreeshap"))
-    clean_copy_tree(rabit, os.path.join(target_dir, 'rabit'))
-    clean_copy_tree(cmake, os.path.join(target_dir, 'cmake'))
-    clean_copy_tree(plugin, os.path.join(target_dir, 'plugin'))
+    clean_copy_tree(rabit, os.path.join(target_dir, "rabit"))
+    clean_copy_tree(cmake, os.path.join(target_dir, "cmake"))
+    clean_copy_tree(plugin, os.path.join(target_dir, "plugin"))

-    cmake_list = os.path.join(src_dir, 'CMakeLists.txt')
-    clean_copy_file(cmake_list, os.path.join(target_dir, 'CMakeLists.txt'))
-    lic = os.path.join(src_dir, 'LICENSE')
-    clean_copy_file(lic, os.path.join(target_dir, 'LICENSE'))
+    cmake_list = os.path.join(src_dir, "CMakeLists.txt")
+    clean_copy_file(cmake_list, os.path.join(target_dir, "CMakeLists.txt"))
+    lic = os.path.join(src_dir, "LICENSE")
+    clean_copy_file(lic, os.path.join(target_dir, "LICENSE"))


 def clean_up() -> None:
-    '''Removed copied files.'''
+    """Removed copied files."""
     for path in NEED_CLEAN_TREE:
         shutil.rmtree(path)
     for path in NEED_CLEAN_FILE:
@@ -91,15 +92,16 @@ def clean_up() -> None:

 class CMakeExtension(Extension):  # pylint: disable=too-few-public-methods
-    '''Wrapper for extension'''
+    """Wrapper for extension"""
+
     def __init__(self, name: str) -> None:
         super().__init__(name=name, sources=[])


 class BuildExt(build_ext.build_ext):  # pylint: disable=too-many-ancestors
-    '''Custom build_ext command using CMake.'''
-    logger = logging.getLogger('XGBoost build_ext')
+    """Custom build_ext command using CMake."""
+
+    logger = logging.getLogger("XGBoost build_ext")

     # pylint: disable=too-many-arguments
     def build(
@@ -110,157 +112,171 @@ class BuildExt(build_ext.build_ext):  # pylint: disable=too-many-ancestors
         build_tool: Optional[str] = None,
         use_omp: int = 1,
     ) -> None:
-        '''Build the core library with CMake.'''
-        cmake_cmd = ['cmake', src_dir, generator]
+        """Build the core library with CMake."""
+        cmake_cmd = ["cmake", src_dir, generator]
         for k, v in USER_OPTIONS.items():
-            arg = k.replace('-', '_').upper()
+            arg = k.replace("-", "_").upper()
             value = str(v[2])
-            if arg == 'USE_SYSTEM_LIBXGBOOST':
+            if arg == "USE_SYSTEM_LIBXGBOOST":
                 continue
-            if arg == 'USE_OPENMP' and use_omp == 0:
+            if arg == "USE_OPENMP" and use_omp == 0:
                 cmake_cmd.append("-D" + arg + "=0")
                 continue
-            cmake_cmd.append('-D' + arg + '=' + value)
+            cmake_cmd.append("-D" + arg + "=" + value)

         # Flag for cross-compiling for Apple Silicon
         # We use environment variable because it's the only way to pass down custom flags
         # through the cibuildwheel package, which otherwise calls `python setup.py bdist_wheel`
         # command.
-        if 'CIBW_TARGET_OSX_ARM64' in os.environ:
+        if "CIBW_TARGET_OSX_ARM64" in os.environ:
             cmake_cmd.append("-DCMAKE_OSX_ARCHITECTURES=arm64")

-        self.logger.info('Run CMake command: %s', str(cmake_cmd))
+        self.logger.info("Run CMake command: %s", str(cmake_cmd))
         subprocess.check_call(cmake_cmd, cwd=build_dir)

-        if system() != 'Windows':
+        if system() != "Windows":
             nproc = os.cpu_count()
             assert build_tool is not None
-            subprocess.check_call([build_tool, '-j' + str(nproc)],
-                                  cwd=build_dir)
+            subprocess.check_call([build_tool, "-j" + str(nproc)], cwd=build_dir)
         else:
-            subprocess.check_call(['cmake', '--build', '.',
-                                   '--config', 'Release'], cwd=build_dir)
+            subprocess.check_call(
+                ["cmake", "--build", ".", "--config", "Release"], cwd=build_dir
+            )

     def build_cmake_extension(self) -> None:
-        '''Configure and build using CMake'''
-        if USER_OPTIONS['use-system-libxgboost'][2]:
-            self.logger.info('Using system libxgboost.')
+        """Configure and build using CMake"""
+        if USER_OPTIONS["use-system-libxgboost"][2]:
+            self.logger.info("Using system libxgboost.")
             return

         build_dir = self.build_temp
         global BUILD_TEMP_DIR  # pylint: disable=global-statement
         BUILD_TEMP_DIR = build_dir
         libxgboost = os.path.abspath(
-            os.path.join(CURRENT_DIR, os.path.pardir, 'lib', lib_name()))
+            os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
+        )

         if os.path.exists(libxgboost):
-            self.logger.info('Found shared library, skipping build.')
+            self.logger.info("Found shared library, skipping build.")
             return

-        src_dir = 'xgboost'
+        src_dir = "xgboost"
         try:
-            copy_tree(os.path.join(CURRENT_DIR, os.path.pardir),
-                      os.path.join(self.build_temp, src_dir))
+            copy_tree(
+                os.path.join(CURRENT_DIR, os.path.pardir),
+                os.path.join(self.build_temp, src_dir),
+            )
         except Exception:  # pylint: disable=broad-except
             copy_tree(src_dir, os.path.join(self.build_temp, src_dir))
-        self.logger.info('Building from source. %s', libxgboost)
+        self.logger.info("Building from source. %s", libxgboost)
         if not os.path.exists(build_dir):
             os.mkdir(build_dir)

-        if shutil.which('ninja'):
-            build_tool = 'ninja'
+        if shutil.which("ninja"):
+            build_tool = "ninja"
         else:
-            build_tool = 'make'
+            build_tool = "make"

-        if sys.platform.startswith('os400'):
-            build_tool = 'make'
+        if sys.platform.startswith("os400"):
+            build_tool = "make"

-        if system() == 'Windows':
+        if system() == "Windows":
             # Pick up from LGB, just test every possible tool chain.
             for vs in (
                 "-GVisual Studio 17 2022",
-                '-GVisual Studio 16 2019',
-                '-GVisual Studio 15 2017',
-                '-GVisual Studio 14 2015',
-                '-GMinGW Makefiles',
+                "-GVisual Studio 16 2019",
+                "-GVisual Studio 15 2017",
+                "-GVisual Studio 14 2015",
+                "-GMinGW Makefiles",
             ):
                 try:
                     self.build(src_dir, build_dir, vs)
                     self.logger.info(
-                        '%s is used for building Windows distribution.', vs)
+                        "%s is used for building Windows distribution.", vs
+                    )
                     break
                 except subprocess.CalledProcessError:
                     shutil.rmtree(build_dir)
                     os.mkdir(build_dir)
                     continue
         else:
-            gen = '-GNinja' if build_tool == 'ninja' else '-GUnix Makefiles'
+            gen = "-GNinja" if build_tool == "ninja" else "-GUnix Makefiles"
             try:
                 self.build(src_dir, build_dir, gen, build_tool, use_omp=1)
             except subprocess.CalledProcessError:
-                self.logger.warning('Disabling OpenMP support.')
+                self.logger.warning("Disabling OpenMP support.")
                 self.build(src_dir, build_dir, gen, build_tool, use_omp=0)

     def build_extension(self, ext: Extension) -> None:
-        '''Override the method for dispatching.'''
+        """Override the method for dispatching."""
         if isinstance(ext, CMakeExtension):
             self.build_cmake_extension()
         else:
             super().build_extension(ext)

     def copy_extensions_to_source(self) -> None:
-        '''Dummy override. Invoked during editable installation. Our binary
+        """Dummy override. Invoked during editable installation. Our binary
         should available in `lib`.
-        '''
+        """
         if not os.path.exists(
-                os.path.join(CURRENT_DIR, os.path.pardir, 'lib', lib_name())):
-            raise ValueError('For using editable installation, please ' +
-                             'build the shared object first with CMake.')
+            os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
+        ):
+            raise ValueError(
+                "For using editable installation, please "
+                + "build the shared object first with CMake."
+            )


 class Sdist(sdist.sdist):  # pylint: disable=too-many-ancestors
-    '''Copy c++ source into Python directory.'''
-    logger = logging.getLogger('xgboost sdist')
+    """Copy c++ source into Python directory."""
+
+    logger = logging.getLogger("xgboost sdist")

     def run(self) -> None:
-        copy_tree(os.path.join(CURRENT_DIR, os.path.pardir),
-                  os.path.join(CURRENT_DIR, 'xgboost'))
-        libxgboost = os.path.join(
-            CURRENT_DIR, os.path.pardir, 'lib', lib_name())
+        copy_tree(
+            os.path.join(CURRENT_DIR, os.path.pardir),
+            os.path.join(CURRENT_DIR, "xgboost"),
+        )
+        libxgboost = os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
         if os.path.exists(libxgboost):
             self.logger.warning(
-                'Found shared library, removing to avoid being included in source distribution.'
+                "Found shared library, removing to avoid being included in source distribution."
             )
             os.remove(libxgboost)
         super().run()


 class InstallLib(install_lib.install_lib):
-    '''Copy shared object into installation directory.'''
-    logger = logging.getLogger('xgboost install_lib')
+    """Copy shared object into installation directory."""
+
+    logger = logging.getLogger("xgboost install_lib")

     def install(self) -> List[str]:
         outfiles = super().install()

-        if USER_OPTIONS['use-system-libxgboost'][2] != 0:
-            self.logger.info('Using system libxgboost.')
-            lib_path = os.path.join(sys.prefix, 'lib')
-            msg = 'use-system-libxgboost is specified, but ' + lib_name() + \
-                ' is not found in: ' + lib_path
+        if USER_OPTIONS["use-system-libxgboost"][2] != 0:
+            self.logger.info("Using system libxgboost.")
+            lib_path = os.path.join(sys.prefix, "lib")
+            msg = (
+                "use-system-libxgboost is specified, but "
+                + lib_name()
+                + " is not found in: "
+                + lib_path
+            )
             assert os.path.exists(os.path.join(lib_path, lib_name())), msg
             return []

-        lib_dir = os.path.join(self.install_dir, 'xgboost', 'lib')
+        lib_dir = os.path.join(self.install_dir, "xgboost", "lib")
         if not os.path.exists(lib_dir):
             os.mkdir(lib_dir)
-        dst = os.path.join(self.install_dir, 'xgboost', 'lib', lib_name())
+        dst = os.path.join(self.install_dir, "xgboost", "lib", lib_name())

         libxgboost_path = lib_name()

         assert BUILD_TEMP_DIR is not None
-        dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, 'lib')
-        build_dir = os.path.join(BUILD_TEMP_DIR, 'xgboost', 'lib')
+        dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, "lib")
+        build_dir = os.path.join(BUILD_TEMP_DIR, "xgboost", "lib")
         if os.path.exists(os.path.join(dft_lib_dir, libxgboost_path)):
             # The library is built by CMake directly
@@ -268,18 +284,21 @@ class InstallLib(install_lib.install_lib):
         else:
             # The library is built by setup.py
             src = os.path.join(build_dir, libxgboost_path)
-        self.logger.info('Installing shared library: %s', src)
+        self.logger.info("Installing shared library: %s", src)
         dst, _ = self.copy_file(src, dst)
         outfiles.append(dst)
         return outfiles


 class Install(install.install):  # pylint: disable=too-many-instance-attributes
-    '''An interface to install command, accepting XGBoost specific
+    """An interface to install command, accepting XGBoost specific
     arguments.
-    '''
-    user_options = install.install.user_options + [(k, v[0], v[1]) for k, v in USER_OPTIONS.items()]
+    """
+
+    user_options = install.install.user_options + [
+        (k, v[0], v[1]) for k, v in USER_OPTIONS.items()
+    ]

     def initialize_options(self) -> None:
         super().initialize_options()
@@ -302,13 +321,13 @@ class Install(install.install):  # pylint: disable=too-many-instance-attributes
         # arguments, then here we propagate them into `USER_OPTIONS` for visibility to
         # other sub-commands like `build_ext`.
         for k, v in USER_OPTIONS.items():
-            arg = k.replace('-', '_')
+            arg = k.replace("-", "_")
             if hasattr(self, arg):
                 USER_OPTIONS[k] = (v[0], v[1], getattr(self, arg))
         super().run()


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Supported commands:
     # From internet:
     # - pip install xgboost
@@ -326,51 +345,55 @@ if __name__ == '__main__':
     #  - python setup.py develop  # same as above
     logging.basicConfig(level=logging.INFO)

-    with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd:
+    with open(os.path.join(CURRENT_DIR, "README.rst"), encoding="utf-8") as fd:
         description = fd.read()
-    with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION'), encoding="ascii") as fd:
+    with open(os.path.join(CURRENT_DIR, "xgboost/VERSION"), encoding="ascii") as fd:
         version = fd.read().strip()

-    setup(name='xgboost',
-          version=version,
-          description="XGBoost Python Package",
-          long_description=description,
-          long_description_content_type="text/x-rst",
-          install_requires=[
-              'numpy',
-              'scipy',
-          ],
-          ext_modules=[CMakeExtension('libxgboost')],
-          # error: expected "str": "Type[Command]"
-          cmdclass={
-              'build_ext': BuildExt,  # type: ignore
-              'sdist': Sdist,  # type: ignore
-              'install_lib': InstallLib,  # type: ignore
-              'install': Install  # type: ignore
-          },
-          extras_require={
-              'pandas': ['pandas'],
-              'scikit-learn': ['scikit-learn'],
-              'dask': ['dask', 'pandas', 'distributed'],
-              'datatable': ['datatable'],
-              'plotting': ['graphviz', 'matplotlib'],
-              "pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
-          },
-          maintainer='Hyunsu Cho',
-          maintainer_email='chohyu01@cs.washington.edu',
-          zip_safe=False,
-          packages=find_packages(),
-          include_package_data=True,
-          license='Apache-2.0',
-          classifiers=['License :: OSI Approved :: Apache Software License',
-                       'Development Status :: 5 - Production/Stable',
-                       'Operating System :: OS Independent',
-                       'Programming Language :: Python',
-                       'Programming Language :: Python :: 3',
-                       'Programming Language :: Python :: 3.8',
-                       'Programming Language :: Python :: 3.9',
-                       'Programming Language :: Python :: 3.10'],
-          python_requires=">=3.8",
-          url='https://github.com/dmlc/xgboost')
+    setup(
+        name="xgboost",
+        version=version,
+        description="XGBoost Python Package",
+        long_description=description,
+        long_description_content_type="text/x-rst",
+        install_requires=[
+            "numpy",
+            "scipy",
+        ],
+        ext_modules=[CMakeExtension("libxgboost")],
+        # error: expected "str": "Type[Command]"
+        cmdclass={
+            "build_ext": BuildExt,  # type: ignore
+            "sdist": Sdist,  # type: ignore
+            "install_lib": InstallLib,  # type: ignore
+            "install": Install,  # type: ignore
+        },
+        extras_require={
+            "pandas": ["pandas"],
+            "scikit-learn": ["scikit-learn"],
+            "dask": ["dask", "pandas", "distributed"],
+            "datatable": ["datatable"],
+            "plotting": ["graphviz", "matplotlib"],
+            "pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
+        },
+        maintainer="Hyunsu Cho",
+        maintainer_email="chohyu01@cs.washington.edu",
+        zip_safe=False,
+        packages=find_packages(),
+        include_package_data=True,
+        license="Apache-2.0",
+        classifiers=[
+            "License :: OSI Approved :: Apache Software License",
+            "Development Status :: 5 - Production/Stable",
+            "Operating System :: OS Independent",
+            "Programming Language :: Python",
+            "Programming Language :: Python :: 3",
+            "Programming Language :: Python :: 3.8",
+            "Programming Language :: Python :: 3.9",
+            "Programming Language :: Python :: 3.10",
+        ],
+        python_requires=">=3.8",
+        url="https://github.com/dmlc/xgboost",
+    )

     clean_up()
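
An aside on the setup.py hunks above: the `USER_OPTIONS` table drives both the CMake invocation in `BuildExt.build` and the flags accepted by the `install` command (it extends `install.install.user_options`). A minimal sketch of the mapping, using only names that appear in the hunks; the standalone snippet itself is illustrative:

    # Sketch: how each user option becomes a CMake define in BuildExt.build above.
    # The third tuple field is the value; "use-openmp" -> "-DUSE_OPENMP=1".
    USER_OPTIONS = {"use-openmp": (None, "Build with OpenMP support.", 1)}

    for k, (short_name, desc, value) in USER_OPTIONS.items():
        arg = k.replace("-", "_").upper()
        print("-D" + arg + "=" + str(value))  # prints: -DUSE_OPENMP=1

Because the same table feeds `user_options`, a command such as `python setup.py install --use-system-libxgboost` flips the third field, which `Install.run` then propagates back into `USER_OPTIONS` for `build_ext` to see.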

python-package/xgboost/collective.py

@@ -152,42 +152,52 @@ def broadcast(data: _T, root: int) -> _T:
     rank = get_rank()
     length = ctypes.c_ulong()
     if root == rank:
-        assert data is not None, 'need to pass in data when broadcasting'
+        assert data is not None, "need to pass in data when broadcasting"
         s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
         length.value = len(s)
     # run first broadcast
-    _check_call(_LIB.XGCommunicatorBroadcast(ctypes.byref(length),
-                                             ctypes.sizeof(ctypes.c_ulong), root))
+    _check_call(
+        _LIB.XGCommunicatorBroadcast(
+            ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
+        )
+    )
     if root != rank:
         dptr = (ctypes.c_char * length.value)()
         # run second
-        _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
-                                                 length.value, root))
+        _check_call(
+            _LIB.XGCommunicatorBroadcast(
+                ctypes.cast(dptr, ctypes.c_void_p), length.value, root
+            )
+        )
         data = pickle.loads(dptr.raw)
         del dptr
     else:
-        _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
-                                                 length.value, root))
+        _check_call(
+            _LIB.XGCommunicatorBroadcast(
+                ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), length.value, root
+            )
+        )
         del s
     return data


 # enumeration of dtypes
 DTYPE_ENUM__ = {
-    np.dtype('int8'): 0,
-    np.dtype('uint8'): 1,
-    np.dtype('int32'): 2,
-    np.dtype('uint32'): 3,
-    np.dtype('int64'): 4,
-    np.dtype('uint64'): 5,
-    np.dtype('float32'): 6,
-    np.dtype('float64'): 7
+    np.dtype("int8"): 0,
+    np.dtype("uint8"): 1,
+    np.dtype("int32"): 2,
+    np.dtype("uint32"): 3,
+    np.dtype("int64"): 4,
+    np.dtype("uint64"): 5,
+    np.dtype("float32"): 6,
+    np.dtype("float64"): 7,
 }


 @unique
 class Op(IntEnum):
     """Supported operations for allreduce."""
+
     MAX = 0
     MIN = 1
     SUM = 2
@@ -196,9 +206,7 @@ class Op(IntEnum):
     BITWISE_XOR = 5


-def allreduce(  # pylint:disable=invalid-name
-    data: np.ndarray, op: Op
-) -> np.ndarray:
+def allreduce(data: np.ndarray, op: Op) -> np.ndarray:  # pylint:disable=invalid-name
     """Perform allreduce, return the result.

     Parameters
@@ -218,15 +226,22 @@
     This function is not thread-safe.
     """
     if not isinstance(data, np.ndarray):
-        raise TypeError('allreduce only takes in numpy.ndarray')
+        raise TypeError("allreduce only takes in numpy.ndarray")
     buf = data.ravel()
     if buf.base is data.base:
         buf = buf.copy()
     if buf.dtype not in DTYPE_ENUM__:
         raise Exception(f"data type {buf.dtype} not supported")
-    _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
-                                             buf.size, DTYPE_ENUM__[buf.dtype],
-                                             int(op), None, None))
+    _check_call(
+        _LIB.XGCommunicatorAllreduce(
+            buf.ctypes.data_as(ctypes.c_void_p),
+            buf.size,
+            DTYPE_ENUM__[buf.dtype],
+            int(op),
+            None,
+            None,
+        )
+    )
     return buf

python-package/xgboost/core.py (file diff suppressed because it is too large)

python-package/xgboost/data.py

@@ -1,6 +1,6 @@
 # pylint: disable=too-many-arguments, too-many-branches, too-many-lines
 # pylint: disable=too-many-return-statements, import-error
-'''Data dispatching for DMatrix.'''
+"""Data dispatching for DMatrix."""
 import ctypes
 import json
 import os
@@ -108,6 +108,7 @@ def _from_scipy_csr(
     feature_types: Optional[FeatureTypes],
 ) -> DispatchedDataBackendReturnType:
     """Initialize data from a CSR matrix."""
+
     handle = ctypes.c_void_p()
     data = transform_scipy_sparse(data, True)
     _check_call(
@@ -178,8 +179,7 @@ def _ensure_np_dtype(
 def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray:
-    '''Handle numpy slice. This can be removed if we use __array_interface__.
-    '''
+    """Handle numpy slice. This can be removed if we use __array_interface__."""
     try:
         if not data.flags.c_contiguous:
             data = np.array(data, copy=True, dtype=dtype)
@@ -653,6 +653,7 @@ def _is_arrow(data: DataType) -> bool:
     try:
         import pyarrow as pa
         from pyarrow import dataset as arrow_dataset
+
         return isinstance(data, (pa.Table, arrow_dataset.Dataset))
     except ImportError:
         return False
@@ -878,8 +879,8 @@ def _is_cupy_array(data: DataType) -> bool:
 def _transform_cupy_array(data: DataType) -> CupyT:
     import cupy  # pylint: disable=import-error
-    if not hasattr(data, '__cuda_array_interface__') and hasattr(
-            data, '__array__'):
+
+    if not hasattr(data, "__cuda_array_interface__") and hasattr(data, "__array__"):
         data = cupy.array(data, copy=False)
     if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]:
         data = data.astype(cupy.float32, copy=False)
@@ -900,9 +901,9 @@ def _from_cupy_array(
     config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
     _check_call(
         _LIB.XGDMatrixCreateFromCudaArrayInterface(
-            interface_str,
-            config,
-            ctypes.byref(handle)))
+            interface_str, config, ctypes.byref(handle)
+        )
+    )
     return handle, feature_names, feature_types
@@ -923,12 +924,13 @@ def _is_cupy_csc(data: DataType) -> bool:
 def _is_dlpack(data: DataType) -> bool:
-    return 'PyCapsule' in str(type(data)) and "dltensor" in str(data)
+    return "PyCapsule" in str(type(data)) and "dltensor" in str(data)


 def _transform_dlpack(data: DataType) -> bool:
     from cupy import fromDlpack  # pylint: disable=E0401
-    assert 'used_dltensor' not in str(data)
+
+    assert "used_dltensor" not in str(data)
     data = fromDlpack(data)
     return data
@@ -941,8 +943,7 @@ def _from_dlpack(
     feature_types: Optional[FeatureTypes],
 ) -> DispatchedDataBackendReturnType:
     data = _transform_dlpack(data)
-    return _from_cupy_array(data, missing, nthread, feature_names,
-                            feature_types)
+    return _from_cupy_array(data, missing, nthread, feature_names, feature_types)


 def _is_uri(data: DataType) -> bool:
@@ -1003,13 +1004,13 @@ def _is_iter(data: DataType) -> bool:
 def _has_array_protocol(data: DataType) -> bool:
-    return hasattr(data, '__array__')
+    return hasattr(data, "__array__")


 def _convert_unknown_data(data: DataType) -> DataType:
     warnings.warn(
-        f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
-        UserWarning
+        f"Unknown data type: {type(data)}, trying to convert it to csr_matrix",
+        UserWarning,
     )
     try:
         import scipy.sparse
@@ -1033,7 +1034,7 @@ def dispatch_data_backend(
     enable_categorical: bool = False,
     data_split_mode: DataSplitMode = DataSplitMode.ROW,
 ) -> DispatchedDataBackendReturnType:
-    '''Dispatch data for DMatrix.'''
+    """Dispatch data for DMatrix."""
     if not _is_cudf_ser(data) and not _is_pandas_series(data):
         _check_data_shape(data)
     if _is_scipy_csr(data):
@@ -1054,6 +1055,7 @@ def dispatch_data_backend(
         return _from_tuple(data, missing, threads, feature_names, feature_types)
     if _is_pandas_series(data):
         import pandas as pd
+
         data = pd.DataFrame(data)
     if _is_pandas_df(data):
         return _from_pandas_df(
@@ -1064,39 +1066,41 @@
         data, missing, threads, feature_names, feature_types, enable_categorical
     )
     if _is_cupy_array(data):
-        return _from_cupy_array(data, missing, threads, feature_names,
-                                feature_types)
+        return _from_cupy_array(data, missing, threads, feature_names, feature_types)
     if _is_cupy_csr(data):
-        raise TypeError('cupyx CSR is not supported yet.')
+        raise TypeError("cupyx CSR is not supported yet.")
     if _is_cupy_csc(data):
-        raise TypeError('cupyx CSC is not supported yet.')
+        raise TypeError("cupyx CSC is not supported yet.")
     if _is_dlpack(data):
-        return _from_dlpack(data, missing, threads, feature_names,
-                            feature_types)
+        return _from_dlpack(data, missing, threads, feature_names, feature_types)
     if _is_dt_df(data):
         _warn_unused_missing(data, missing)
         return _from_dt_df(
             data, missing, threads, feature_names, feature_types, enable_categorical
         )
     if _is_modin_df(data):
-        return _from_pandas_df(data, enable_categorical, missing, threads,
-                               feature_names, feature_types)
+        return _from_pandas_df(
+            data, enable_categorical, missing, threads, feature_names, feature_types
+        )
     if _is_modin_series(data):
         return _from_pandas_series(
             data, missing, threads, enable_categorical, feature_names, feature_types
         )
     if _is_arrow(data):
         return _from_arrow(
-            data, missing, threads, feature_names, feature_types, enable_categorical)
+            data, missing, threads, feature_names, feature_types, enable_categorical
+        )
     if _has_array_protocol(data):
         array = np.asarray(data)
         return _from_numpy_array(array, missing, threads, feature_names, feature_types)

     converted = _convert_unknown_data(data)
     if converted is not None:
-        return _from_scipy_csr(converted, missing, threads, feature_names, feature_types)
+        return _from_scipy_csr(
+            converted, missing, threads, feature_names, feature_types
+        )

-    raise TypeError('Not supported type for data.' + str(type(data)))
+    raise TypeError("Not supported type for data." + str(type(data)))


 def _validate_meta_shape(data: DataType, name: str) -> None:
@@ -1128,20 +1132,14 @@ def _meta_from_numpy(
 def _meta_from_list(
-    data: Sequence,
-    field: str,
-    dtype: Optional[NumpyDType],
-    handle: ctypes.c_void_p
+    data: Sequence, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
 ) -> None:
     data_np = np.array(data)
     _meta_from_numpy(data_np, field, dtype, handle)


 def _meta_from_tuple(
-    data: Sequence,
-    field: str,
-    dtype: Optional[NumpyDType],
-    handle: ctypes.c_void_p
+    data: Sequence, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
 ) -> None:
     return _meta_from_list(data, field, dtype, handle)
@@ -1156,39 +1154,27 @@ def _meta_from_cudf_df(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
 def _meta_from_cudf_series(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
-    interface = bytes(json.dumps([data.__cuda_array_interface__],
-                                 indent=2), 'utf-8')
-    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
-                                                   c_str(field),
-                                                   interface))
+    interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8")
+    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))


 def _meta_from_cupy_array(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
     data = _transform_cupy_array(data)
-    interface = bytes(json.dumps([data.__cuda_array_interface__],
-                                 indent=2), 'utf-8')
-    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
-                                                   c_str(field),
-                                                   interface))
+    interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8")
+    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))


 def _meta_from_dt(
-    data: DataType,
-    field: str,
-    dtype: Optional[NumpyDType],
-    handle: ctypes.c_void_p
+    data: DataType, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
 ) -> None:
     data, _, _ = _transform_dt_df(data, None, None, field, dtype)
     _meta_from_numpy(data, field, dtype, handle)


 def dispatch_meta_backend(
-    matrix: DMatrix,
-    data: DataType,
-    name: str,
-    dtype: Optional[NumpyDType] = None
+    matrix: DMatrix, data: DataType, name: str, dtype: Optional[NumpyDType] = None
 ) -> None:
-    '''Dispatch for meta info.'''
+    """Dispatch for meta info."""
     handle = matrix.handle
     assert handle is not None
     _validate_meta_shape(data, name)
@@ -1231,7 +1217,7 @@ def dispatch_meta_backend(
         _meta_from_numpy(data, name, dtype, handle)
         return
     if _is_modin_series(data):
-        data = data.values.astype('float')
+        data = data.values.astype("float")
         assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
         _meta_from_numpy(data, name, dtype, handle)
         return
@@ -1240,16 +1226,17 @@
         array = np.asarray(data)
         _meta_from_numpy(array, name, dtype, handle)
         return
-    raise TypeError('Unsupported type for ' + name, str(type(data)))
+    raise TypeError("Unsupported type for " + name, str(type(data)))


 class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
-    '''An iterator for single batch data to help creating device DMatrix.
+    """An iterator for single batch data to help creating device DMatrix.
     Transforming input directly to histogram with normal single batch data API
     can not access weight for sketching.  So this iterator acts as a staging
     area for meta info.
-    '''
+    """
+
     def __init__(self, **kwargs: Any) -> None:
         self.kwargs = kwargs
         self.it = 0  # pylint: disable=invalid-name

python-package/xgboost/libpath.py

@ -22,45 +22,51 @@ def find_lib_path() -> List[str]:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [ dll_path = [
# normal, after installation `lib` is copied into Python package tree. # normal, after installation `lib` is copied into Python package tree.
os.path.join(curr_path, 'lib'), os.path.join(curr_path, "lib"),
# editable installation, no copying is performed. # editable installation, no copying is performed.
os.path.join(curr_path, os.path.pardir, os.path.pardir, 'lib'), os.path.join(curr_path, os.path.pardir, os.path.pardir, "lib"),
# use libxgboost from a system prefix, if available. This should be the last # use libxgboost from a system prefix, if available. This should be the last
# option. # option.
os.path.join(sys.prefix, 'lib'), os.path.join(sys.prefix, "lib"),
] ]
if sys.platform == 'win32': if sys.platform == "win32":
if platform.architecture()[0] == '64bit': if platform.architecture()[0] == "64bit":
dll_path.append( dll_path.append(os.path.join(curr_path, "../../windows/x64/Release/"))
os.path.join(curr_path, '../../windows/x64/Release/'))
# hack for pip installation when copy all parent source # hack for pip installation when copy all parent source
# directory here # directory here
dll_path.append(os.path.join(curr_path, './windows/x64/Release/')) dll_path.append(os.path.join(curr_path, "./windows/x64/Release/"))
else: else:
dll_path.append(os.path.join(curr_path, '../../windows/Release/')) dll_path.append(os.path.join(curr_path, "../../windows/Release/"))
# hack for pip installation when copy all parent source # hack for pip installation when copy all parent source
# directory here # directory here
dll_path.append(os.path.join(curr_path, './windows/Release/')) dll_path.append(os.path.join(curr_path, "./windows/Release/"))
dll_path = [os.path.join(p, 'xgboost.dll') for p in dll_path] dll_path = [os.path.join(p, "xgboost.dll") for p in dll_path]
elif sys.platform.startswith(('linux', 'freebsd', 'emscripten')): elif sys.platform.startswith(("linux", "freebsd", "emscripten")):
dll_path = [os.path.join(p, 'libxgboost.so') for p in dll_path] dll_path = [os.path.join(p, "libxgboost.so") for p in dll_path]
elif sys.platform == 'darwin': elif sys.platform == "darwin":
dll_path = [os.path.join(p, 'libxgboost.dylib') for p in dll_path] dll_path = [os.path.join(p, "libxgboost.dylib") for p in dll_path]
elif sys.platform == 'cygwin': elif sys.platform == "cygwin":
dll_path = [os.path.join(p, 'cygxgboost.dll') for p in dll_path] dll_path = [os.path.join(p, "cygxgboost.dll") for p in dll_path]
if platform.system() == 'OS400': if platform.system() == "OS400":
dll_path = [os.path.join(p, 'libxgboost.so') for p in dll_path] dll_path = [os.path.join(p, "libxgboost.so") for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
# XGBOOST_BUILD_DOC is defined by sphinx conf. # XGBOOST_BUILD_DOC is defined by sphinx conf.
if not lib_path and not os.environ.get('XGBOOST_BUILD_DOC', False): if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False):
link = 'https://xgboost.readthedocs.io/en/latest/build.html' link = "https://xgboost.readthedocs.io/en/latest/build.html"
msg = 'Cannot find XGBoost Library in the candidate path. ' + \ msg = (
'List of candidates:\n- ' + ('\n- '.join(dll_path)) + \ "Cannot find XGBoost Library in the candidate path. "
'\nXGBoost Python package path: ' + curr_path + \ + "List of candidates:\n- "
'\nsys.prefix: ' + sys.prefix + \ + ("\n- ".join(dll_path))
'\nSee: ' + link + ' for installing XGBoost.' + "\nXGBoost Python package path: "
+ curr_path
+ "\nsys.prefix: "
+ sys.prefix
+ "\nSee: "
+ link
+ " for installing XGBoost."
)
raise XGBoostLibraryNotFound(msg) raise XGBoostLibraryNotFound(msg)
return lib_path return lib_path

python-package/xgboost/plotting.py

@@ -81,22 +81,24 @@ def plot_importance(
     try:
         import matplotlib.pyplot as plt
     except ImportError as e:
-        raise ImportError('You must install matplotlib to plot importance') from e
+        raise ImportError("You must install matplotlib to plot importance") from e

     if isinstance(booster, XGBModel):
         importance = booster.get_booster().get_score(
-            importance_type=importance_type, fmap=fmap)
+            importance_type=importance_type, fmap=fmap
+        )
     elif isinstance(booster, Booster):
         importance = booster.get_score(importance_type=importance_type, fmap=fmap)
     elif isinstance(booster, dict):
         importance = booster
     else:
-        raise ValueError('tree must be Booster, XGBModel or dict instance')
+        raise ValueError("tree must be Booster, XGBModel or dict instance")

     if not importance:
         raise ValueError(
-            'Booster.get_score() results in empty. ' +
-            'This maybe caused by having all trees as decision dumps.')
+            "Booster.get_score() results in empty. "
+            + "This maybe caused by having all trees as decision dumps."
+        )

     tuples = [(k, importance[k]) for k in importance]
     if max_num_features is not None:
@@ -110,25 +112,25 @@ def plot_importance(
         _, ax = plt.subplots(1, 1)

     ylocs = np.arange(len(values))
-    ax.barh(ylocs, values, align='center', height=height, **kwargs)
+    ax.barh(ylocs, values, align="center", height=height, **kwargs)

     if show_values is True:
         for x, y in zip(values, ylocs):
-            ax.text(x + 1, y, values_format.format(v=x), va='center')
+            ax.text(x + 1, y, values_format.format(v=x), va="center")

     ax.set_yticks(ylocs)
     ax.set_yticklabels(labels)

     if xlim is not None:
         if not isinstance(xlim, tuple) or len(xlim) != 2:
-            raise ValueError('xlim must be a tuple of 2 elements')
+            raise ValueError("xlim must be a tuple of 2 elements")
     else:
         xlim = (0, max(values) * 1.1)
     ax.set_xlim(xlim)

     if ylim is not None:
         if not isinstance(ylim, tuple) or len(ylim) != 2:
-            raise ValueError('ylim must be a tuple of 2 elements')
+            raise ValueError("ylim must be a tuple of 2 elements")
     else:
         ylim = (-1, len(values))
     ax.set_ylim(ylim)
@@ -201,44 +203,42 @@ def to_graphviz(
     try:
         from graphviz import Source
     except ImportError as e:
-        raise ImportError('You must install graphviz to plot tree') from e
+        raise ImportError("You must install graphviz to plot tree") from e
     if isinstance(booster, XGBModel):
         booster = booster.get_booster()

     # squash everything back into kwargs again for compatibility
-    parameters = 'dot'
+    parameters = "dot"
     extra = {}
     for key, value in kwargs.items():
         extra[key] = value

     if rankdir is not None:
-        kwargs['graph_attrs'] = {}
-        kwargs['graph_attrs']['rankdir'] = rankdir
+        kwargs["graph_attrs"] = {}
+        kwargs["graph_attrs"]["rankdir"] = rankdir
     for key, value in extra.items():
         if kwargs.get("graph_attrs", None) is not None:
-            kwargs['graph_attrs'][key] = value
+            kwargs["graph_attrs"][key] = value
         else:
-            kwargs['graph_attrs'] = {}
+            kwargs["graph_attrs"] = {}
         del kwargs[key]

     if yes_color is not None or no_color is not None:
-        kwargs['edge'] = {}
+        kwargs["edge"] = {}
     if yes_color is not None:
-        kwargs['edge']['yes_color'] = yes_color
+        kwargs["edge"]["yes_color"] = yes_color
     if no_color is not None:
-        kwargs['edge']['no_color'] = no_color
+        kwargs["edge"]["no_color"] = no_color

     if condition_node_params is not None:
-        kwargs['condition_node_params'] = condition_node_params
+        kwargs["condition_node_params"] = condition_node_params
     if leaf_node_params is not None:
-        kwargs['leaf_node_params'] = leaf_node_params
+        kwargs["leaf_node_params"] = leaf_node_params

     if kwargs:
-        parameters += ':'
+        parameters += ":"
         parameters += json.dumps(kwargs)
-    tree = booster.get_dump(
-        fmap=fmap,
-        dump_format=parameters)[num_trees]
+    tree = booster.get_dump(fmap=fmap, dump_format=parameters)[num_trees]
     g = Source(tree)
     return g
@@ -277,19 +277,18 @@ def plot_tree(
         from matplotlib import image
         from matplotlib import pyplot as plt
     except ImportError as e:
-        raise ImportError('You must install matplotlib to plot tree') from e
+        raise ImportError("You must install matplotlib to plot tree") from e

     if ax is None:
         _, ax = plt.subplots(1, 1)

-    g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
-                    **kwargs)
+    g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)

     s = BytesIO()
-    s.write(g.pipe(format='png'))
+    s.write(g.pipe(format="png"))
     s.seek(0)
     img = image.imread(s)

     ax.imshow(img)
-    ax.axis('off')
+    ax.axis("off")
     return ax

python-package/xgboost/rabit.py

@@ -24,7 +24,7 @@ def init(args: Optional[List[bytes]] = None) -> None:
     parsed = {}
     if args:
         for arg in args:
-            kv = arg.decode().split('=')
+            kv = arg.decode().split("=")
             if len(kv) == 2:
                 parsed[kv[0]] = kv[1]
     collective.init(**parsed)
@@ -104,6 +104,7 @@ def broadcast(data: T, root: int) -> T:
 @unique
 class Op(IntEnum):
     """Supported operations for rabit."""
+
     MAX = 0
     MIN = 1
     SUM = 2

python-package/xgboost/tracker.py

@@ -53,7 +53,7 @@ class ExSocket:


 # magic number used to verify existence of data
-MAGIC_NUM = 0xff99
+MAGIC_NUM = 0xFF99


 def get_some_ip(host: str) -> str:
@@ -334,19 +334,19 @@ class RabitTracker:
         while len(shutdown) != n_workers:
             fd, s_addr = self.sock.accept()
             s = WorkerEntry(fd, s_addr)
-            if s.cmd == 'print':
+            if s.cmd == "print":
                 s.print(self._use_logger)
                 continue
-            if s.cmd == 'shutdown':
+            if s.cmd == "shutdown":
                 assert s.rank >= 0 and s.rank not in shutdown
                 assert s.rank not in wait_conn
                 shutdown[s.rank] = s
-                logging.debug('Received %s signal from %d', s.cmd, s.rank)
+                logging.debug("Received %s signal from %d", s.cmd, s.rank)
                 continue
             assert s.cmd in ("start", "recover")
             # lazily initialize the workers
             if tree_map is None:
-                assert s.cmd == 'start'
+                assert s.cmd == "start"
                 if s.world_size > 0:
                     n_workers = s.world_size
                 tree_map, parent_map, ring_map = self.get_link_map(n_workers)
@@ -354,7 +354,7 @@ class RabitTracker:
                 todo_nodes = list(range(n_workers))
             else:
                 assert s.world_size in (-1, n_workers)
-            if s.cmd == 'recover':
+            if s.cmd == "recover":
                 assert s.rank >= 0

             rank = s.decide_rank(job_map)
@@ -410,24 +410,25 @@ def get_host_ip(host_ip: Optional[str] = None) -> str:
         returned as it's
     """
-    if host_ip is None or host_ip == 'auto':
-        host_ip = 'ip'
+    if host_ip is None or host_ip == "auto":
+        host_ip = "ip"

-    if host_ip == 'dns':
+    if host_ip == "dns":
         host_ip = socket.getfqdn()
-    elif host_ip == 'ip':
+    elif host_ip == "ip":
         from socket import gaierror
+
         try:
             host_ip = socket.gethostbyname(socket.getfqdn())
         except gaierror:
             logging.debug(
-                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
+                "gethostbyname(socket.getfqdn()) failed... trying on hostname()"
             )
             host_ip = socket.gethostbyname(socket.gethostname())
         if host_ip.startswith("127."):
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
             # doesn't have to be reachable
-            s.connect(('10.255.255.255', 1))
+            s.connect(("10.255.255.255", 1))
             host_ip = s.getsockname()[0]

     assert host_ip is not None
@@ -458,25 +459,41 @@ def start_rabit_tracker(args: argparse.Namespace) -> None:
 def main() -> None:
     """Main function if tracker is executed in standalone mode."""
-    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
-    parser.add_argument('--num-workers', required=True, type=int,
-                        help='Number of worker process to be launched.')
+    parser = argparse.ArgumentParser(description="Rabit Tracker start.")
     parser.add_argument(
-        '--num-servers', default=0, type=int,
-        help='Number of server process to be launched. Only used in PS jobs.'
+        "--num-workers",
+        required=True,
+        type=int,
+        help="Number of worker process to be launched.",
     )
-    parser.add_argument('--host-ip', default=None, type=str,
-                        help=('Host IP addressed, this is only needed ' +
-                              'if the host IP cannot be automatically guessed.'))
-    parser.add_argument('--log-level', default='INFO', type=str,
-                        choices=['INFO', 'DEBUG'],
-                        help='Logging level of the logger.')
+    parser.add_argument(
+        "--num-servers",
+        default=0,
+        type=int,
+        help="Number of server process to be launched. Only used in PS jobs.",
+    )
+    parser.add_argument(
+        "--host-ip",
+        default=None,
+        type=str,
+        help=(
+            "Host IP addressed, this is only needed "
+            + "if the host IP cannot be automatically guessed."
+        ),
+    )
+    parser.add_argument(
+        "--log-level",
+        default="INFO",
+        type=str,
+        choices=["INFO", "DEBUG"],
+        help="Logging level of the logger.",
+    )
     args = parser.parse_args()

-    fmt = '%(asctime)s %(levelname)s %(message)s'
-    if args.log_level == 'INFO':
+    fmt = "%(asctime)s %(levelname)s %(message)s"
+    if args.log_level == "INFO":
         level = logging.INFO
-    elif args.log_level == 'DEBUG':
+    elif args.log_level == "DEBUG":
         level = logging.DEBUG
     else:
         raise RuntimeError(f"Unknown logging level {args.log_level}")

python-package/xgboost/training.py

@@ -205,25 +205,29 @@ def train(
 class CVPack:
-    """"Auxiliary datastruct to hold one fold of CV."""
-    def __init__(self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]) -> None:
-        """"Initialize the CVPack"""
+    """ "Auxiliary datastruct to hold one fold of CV."""
+
+    def __init__(
+        self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]
+    ) -> None:
+        """ "Initialize the CVPack"""
         self.dtrain = dtrain
         self.dtest = dtest
-        self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
+        self.watchlist = [(dtrain, "train"), (dtest, "test")]
         self.bst = Booster(param, [dtrain, dtest])

     def __getattr__(self, name: str) -> Callable:
         def _inner(*args: Any, **kwargs: Any) -> Any:
             return getattr(self.bst, name)(*args, **kwargs)

         return _inner

     def update(self, iteration: int, fobj: Optional[Objective]) -> None:
-        """"Update the boosters for one iteration"""
+        """ "Update the boosters for one iteration"""
         self.bst.update(self.dtrain, iteration, fobj)

     def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> str:
-        """"Evaluate the CVPack for one iteration."""
+        """ "Evaluate the CVPack for one iteration."""
         return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)
@@ -232,38 +236,42 @@ class _PackedBooster:
         self.cvfolds = cvfolds

     def update(self, iteration: int, obj: Optional[Objective]) -> None:
-        '''Iterate through folds for update'''
+        """Iterate through folds for update"""
         for fold in self.cvfolds:
             fold.update(iteration, obj)

-    def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> List[str]:
-        '''Iterate through folds for eval'''
+    def eval(
+        self, iteration: int, feval: Optional[Metric], output_margin: bool
+    ) -> List[str]:
+        """Iterate through folds for eval"""
         result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
         return result

     def set_attr(self, **kwargs: Optional[str]) -> Any:
-        '''Iterate through folds for setting attributes'''
+        """Iterate through folds for setting attributes"""
         for f in self.cvfolds:
             f.bst.set_attr(**kwargs)

     def attr(self, key: str) -> Optional[str]:
-        '''Redirect to booster attr.'''
+        """Redirect to booster attr."""
         return self.cvfolds[0].bst.attr(key)

-    def set_param(self,
-                  params: Union[Dict, Iterable[Tuple[str, Any]], str],
-                  value: Optional[str] = None) -> None:
+    def set_param(
+        self,
+        params: Union[Dict, Iterable[Tuple[str, Any]], str],
+        value: Optional[str] = None,
+    ) -> None:
         """Iterate through folds for set_param"""
         for f in self.cvfolds:
             f.bst.set_param(params, value)

     def num_boosted_rounds(self) -> int:
-        '''Number of boosted rounds.'''
+        """Number of boosted rounds."""
         return self.cvfolds[0].num_boosted_rounds()

     @property
     def best_iteration(self) -> int:
-        '''Get best_iteration'''
+        """Get best_iteration"""
         return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))

     @property
@@ -279,7 +287,7 @@ def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarray:
     :param boundaries: rows index limits of each group
     :return: row in group
     """
-    return np.concatenate([np.arange(boundaries[g], boundaries[g+1]) for g in groups])
+    return np.concatenate([np.arange(boundaries[g], boundaries[g + 1]) for g in groups])


 def mkgroupfold(
@@ -305,11 +313,17 @@ def mkgroupfold(
     # list by fold of test group indexes
     out_group_idset = np.array_split(idx, nfold)
     # list by fold of train group indexes
-    in_group_idset = [np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
-                      for k in range(nfold)]
+    in_group_idset = [
+        np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
+        for k in range(nfold)
+    ]
     # from the group indexes, convert them to row indexes
-    in_idset = [groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset]
-    out_idset = [groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset]
+    in_idset = [
+        groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset
+    ]
+    out_idset = [
+        groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset
+    ]

     # build the folds by taking the appropriate slices
     ret = []
@@ -324,7 +338,7 @@ def mkgroupfold(
             dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
         else:
             tparam = param
-        plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
+        plst = list(tparam.items()) + [("eval_metric", itm) for itm in evals]
         ret.append(CVPack(dtrain, dtest, plst))

     return ret
@@ -348,16 +362,20 @@ def mknfold(
     if stratified is False and folds is None:
         # Do standard k-fold cross validation. Automatically determine the folds.
-        if len(dall.get_uint_info('group_ptr')) > 1:
-            return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle)
+        if len(dall.get_uint_info("group_ptr")) > 1:
+            return mkgroupfold(
+                dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle
+            )

         if shuffle is True:
             idx = np.random.permutation(dall.num_row())
         else:
             idx = np.arange(dall.num_row())
         out_idset = np.array_split(idx, nfold)
-        in_idset = [np.concatenate([out_idset[i] for i in range(nfold) if k != i])
-                    for k in range(nfold)]
+        in_idset = [
+            np.concatenate([out_idset[i] for i in range(nfold) if k != i])
+            for k in range(nfold)
+        ]
     elif folds is not None:
         # Use user specified custom split using indices
         try:
@@ -387,7 +405,7 @@ def mknfold(
             dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
         else:
             tparam = param
-        plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
+        plst = list(tparam.items()) + [("eval_metric", itm) for itm in evals]
         ret.append(CVPack(dtrain, dtest, plst))

     return ret
@@ -502,29 +520,32 @@ def cv(
     evaluation history : list(string)
     """
     if stratified is True and not SKLEARN_INSTALLED:
-        raise XGBoostError('sklearn needs to be installed in order to use stratified cv')
+        raise XGBoostError(
+            "sklearn needs to be installed in order to use stratified cv"
+        )

     if isinstance(metrics, str):
         metrics = [metrics]

     params = params.copy()
     if isinstance(params, list):
-        _metrics = [x[1] for x in params if x[0] == 'eval_metric']
+        _metrics = [x[1] for x in params if x[0] == "eval_metric"]
         params = dict(params)
-        if 'eval_metric' in params:
-            params['eval_metric'] = _metrics
+        if "eval_metric" in params:
+            params["eval_metric"] = _metrics

-    if (not metrics) and 'eval_metric' in params:
-        if isinstance(params['eval_metric'], list):
-            metrics = params['eval_metric']
+    if (not metrics) and "eval_metric" in params:
+        if isinstance(params["eval_metric"], list):
+            metrics = params["eval_metric"]
         else:
-            metrics = [params['eval_metric']]
+            metrics = [params["eval_metric"]]

     params.pop("eval_metric", None)

     results: Dict[str, List[float]] = {}
-    cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
-                      stratified, folds, shuffle)
+    cvfolds = mknfold(
+        dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds, shuffle
+    )

     metric_fn = _configure_custom_metric(feval, custom_metric)
@@ -555,20 +576,21 @@ def cv(
         should_break = callbacks_container.after_iteration(booster, i, dtrain, None)
         res = callbacks_container.aggregated_cv
         for key, mean, std in cast(List[Tuple[str, float, float]], res):
-            if key + '-mean' not in results:
-                results[key + '-mean'] = []
-            if key + '-std' not in results:
-                results[key + '-std'] = []
-            results[key + '-mean'].append(mean)
-            results[key + '-std'].append(std)
+            if key + "-mean" not in results:
+                results[key + "-mean"] = []
+            if key + "-std" not in results:
+                results[key + "-std"] = []
+            results[key + "-mean"].append(mean)
+            results[key + "-std"].append(std)

         if should_break:
             for k in results.keys():  # pylint: disable=consider-iterating-dictionary
-                results[k] = results[k][:(booster.best_iteration + 1)]
+                results[k] = results[k][: (booster.best_iteration + 1)]
             break

     if as_pandas:
         try:
             import pandas as pd
+
             results = pd.DataFrame.from_dict(results)
         except ImportError:
             pass

tests/ci_build/lint_python.py

@@ -132,16 +132,7 @@ def main(args: argparse.Namespace) -> None:
             run_black(path)
         for path in [
             # core
-            "python-package/xgboost/__init__.py",
-            "python-package/xgboost/_typing.py",
-            "python-package/xgboost/callback.py",
-            "python-package/xgboost/compat.py",
-            "python-package/xgboost/config.py",
-            "python-package/xgboost/dask.py",
-            "python-package/xgboost/sklearn.py",
-            "python-package/xgboost/spark",
-            "python-package/xgboost/federated.py",
-            "python-package/xgboost/testing",
+            "python-package/",
             # tests
             "tests/python/test_config.py",
             "tests/python/test_data_iterator.py",