Require black formatter for the python package. (#8748)
commit 0f37a01dd9
parent a2e433a089
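For context, black mostly normalizes string quotes to double quotes, reflows long call sites, and adds a trailing comma to the last element of a multi-line collection, without changing runtime behaviour. A minimal sketch of the effect, reusing values from the first hunk below (illustration only, not part of the commit):

    # black rewrites the first form into the second; only layout changes.
    before = {'use-openmp': (None, 'Build with OpenMP support.', 1)}  # old style
    after = {"use-openmp": (None, "Build with OpenMP support.", 1)}  # black style
    assert before == after  # identical objects at runtime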
@@ -19,17 +19,17 @@ sys.path.insert(0, CURRENT_DIR)
 # requires using CMake directly.
 USER_OPTIONS = {
     # libxgboost options.
-    'use-openmp': (None, 'Build with OpenMP support.', 1),
-    'use-cuda': (None, 'Build with GPU acceleration.', 0),
-    'use-nccl': (None, 'Build with NCCL to enable distributed GPU support.', 0),
-    'build-with-shared-nccl': (None, 'Build with shared NCCL library.', 0),
-    'hide-cxx-symbols': (None, 'Hide all C++ symbols during build.', 1),
-    'use-hdfs': (None, 'Build with HDFS support', 0),
-    'use-azure': (None, 'Build with AZURE support.', 0),
-    'use-s3': (None, 'Build with S3 support', 0),
-    'plugin-dense-parser': (None, 'Build dense parser plugin.', 0),
+    "use-openmp": (None, "Build with OpenMP support.", 1),
+    "use-cuda": (None, "Build with GPU acceleration.", 0),
+    "use-nccl": (None, "Build with NCCL to enable distributed GPU support.", 0),
+    "build-with-shared-nccl": (None, "Build with shared NCCL library.", 0),
+    "hide-cxx-symbols": (None, "Hide all C++ symbols during build.", 1),
+    "use-hdfs": (None, "Build with HDFS support", 0),
+    "use-azure": (None, "Build with AZURE support.", 0),
+    "use-s3": (None, "Build with S3 support", 0),
+    "plugin-dense-parser": (None, "Build dense parser plugin.", 0),
     # Python specific
-    'use-system-libxgboost': (None, 'Use libxgboost.so in system path.', 0)
+    "use-system-libxgboost": (None, "Use libxgboost.so in system path.", 0),
 }

 NEED_CLEAN_TREE = set()
@@ -38,20 +38,21 @@ BUILD_TEMP_DIR = None


 def lib_name() -> str:
-    '''Return platform dependent shared object name.'''
-    if system() == 'Linux' or system().upper().endswith('BSD'):
-        name = 'libxgboost.so'
-    elif system() == 'Darwin':
-        name = 'libxgboost.dylib'
-    elif system() == 'Windows':
-        name = 'xgboost.dll'
-    elif system() == 'OS400':
-        name = 'libxgboost.so'
+    """Return platform dependent shared object name."""
+    if system() == "Linux" or system().upper().endswith("BSD"):
+        name = "libxgboost.so"
+    elif system() == "Darwin":
+        name = "libxgboost.dylib"
+    elif system() == "Windows":
+        name = "xgboost.dll"
+    elif system() == "OS400":
+        name = "libxgboost.so"
     return name


 def copy_tree(src_dir: str, target_dir: str) -> None:
-    '''Copy source tree into build directory.'''
+    """Copy source tree into build directory."""
+
     def clean_copy_tree(src: str, dst: str) -> None:
         shutil.copytree(src, dst)
         NEED_CLEAN_TREE.add(os.path.abspath(dst))
@@ -60,30 +61,30 @@ def copy_tree(src_dir: str, target_dir: str) -> None:
         shutil.copy(src, dst)
         NEED_CLEAN_FILE.add(os.path.abspath(dst))

-    src = os.path.join(src_dir, 'src')
-    inc = os.path.join(src_dir, 'include')
-    dmlc_core = os.path.join(src_dir, 'dmlc-core')
+    src = os.path.join(src_dir, "src")
+    inc = os.path.join(src_dir, "include")
+    dmlc_core = os.path.join(src_dir, "dmlc-core")
     gputreeshap = os.path.join(src_dir, "gputreeshap")
-    rabit = os.path.join(src_dir, 'rabit')
-    cmake = os.path.join(src_dir, 'cmake')
-    plugin = os.path.join(src_dir, 'plugin')
+    rabit = os.path.join(src_dir, "rabit")
+    cmake = os.path.join(src_dir, "cmake")
+    plugin = os.path.join(src_dir, "plugin")

-    clean_copy_tree(src, os.path.join(target_dir, 'src'))
-    clean_copy_tree(inc, os.path.join(target_dir, 'include'))
-    clean_copy_tree(dmlc_core, os.path.join(target_dir, 'dmlc-core'))
+    clean_copy_tree(src, os.path.join(target_dir, "src"))
+    clean_copy_tree(inc, os.path.join(target_dir, "include"))
+    clean_copy_tree(dmlc_core, os.path.join(target_dir, "dmlc-core"))
     clean_copy_tree(gputreeshap, os.path.join(target_dir, "gputreeshap"))
-    clean_copy_tree(rabit, os.path.join(target_dir, 'rabit'))
-    clean_copy_tree(cmake, os.path.join(target_dir, 'cmake'))
-    clean_copy_tree(plugin, os.path.join(target_dir, 'plugin'))
+    clean_copy_tree(rabit, os.path.join(target_dir, "rabit"))
+    clean_copy_tree(cmake, os.path.join(target_dir, "cmake"))
+    clean_copy_tree(plugin, os.path.join(target_dir, "plugin"))

-    cmake_list = os.path.join(src_dir, 'CMakeLists.txt')
-    clean_copy_file(cmake_list, os.path.join(target_dir, 'CMakeLists.txt'))
-    lic = os.path.join(src_dir, 'LICENSE')
-    clean_copy_file(lic, os.path.join(target_dir, 'LICENSE'))
+    cmake_list = os.path.join(src_dir, "CMakeLists.txt")
+    clean_copy_file(cmake_list, os.path.join(target_dir, "CMakeLists.txt"))
+    lic = os.path.join(src_dir, "LICENSE")
+    clean_copy_file(lic, os.path.join(target_dir, "LICENSE"))


 def clean_up() -> None:
-    '''Removed copied files.'''
+    """Removed copied files."""
     for path in NEED_CLEAN_TREE:
         shutil.rmtree(path)
     for path in NEED_CLEAN_FILE:
@@ -91,15 +92,16 @@ def clean_up() -> None:


 class CMakeExtension(Extension):  # pylint: disable=too-few-public-methods
-    '''Wrapper for extension'''
+    """Wrapper for extension"""

     def __init__(self, name: str) -> None:
         super().__init__(name=name, sources=[])


 class BuildExt(build_ext.build_ext):  # pylint: disable=too-many-ancestors
-    '''Custom build_ext command using CMake.'''
-    logger = logging.getLogger('XGBoost build_ext')
+    """Custom build_ext command using CMake."""
+
+    logger = logging.getLogger("XGBoost build_ext")

     # pylint: disable=too-many-arguments
     def build(
@@ -110,157 +112,171 @@ class BuildExt(build_ext.build_ext):  # pylint: disable=too-many-ancestors
         build_tool: Optional[str] = None,
         use_omp: int = 1,
     ) -> None:
-        '''Build the core library with CMake.'''
-        cmake_cmd = ['cmake', src_dir, generator]
+        """Build the core library with CMake."""
+        cmake_cmd = ["cmake", src_dir, generator]

         for k, v in USER_OPTIONS.items():
-            arg = k.replace('-', '_').upper()
+            arg = k.replace("-", "_").upper()
             value = str(v[2])
-            if arg == 'USE_SYSTEM_LIBXGBOOST':
+            if arg == "USE_SYSTEM_LIBXGBOOST":
                 continue
-            if arg == 'USE_OPENMP' and use_omp == 0:
+            if arg == "USE_OPENMP" and use_omp == 0:
                 cmake_cmd.append("-D" + arg + "=0")
                 continue
-            cmake_cmd.append('-D' + arg + '=' + value)
+            cmake_cmd.append("-D" + arg + "=" + value)

         # Flag for cross-compiling for Apple Silicon
         # We use environment variable because it's the only way to pass down custom flags
         # through the cibuildwheel package, which otherwise calls `python setup.py bdist_wheel`
         # command.
-        if 'CIBW_TARGET_OSX_ARM64' in os.environ:
+        if "CIBW_TARGET_OSX_ARM64" in os.environ:
             cmake_cmd.append("-DCMAKE_OSX_ARCHITECTURES=arm64")

-        self.logger.info('Run CMake command: %s', str(cmake_cmd))
+        self.logger.info("Run CMake command: %s", str(cmake_cmd))
         subprocess.check_call(cmake_cmd, cwd=build_dir)

-        if system() != 'Windows':
+        if system() != "Windows":
             nproc = os.cpu_count()
             assert build_tool is not None
-            subprocess.check_call([build_tool, '-j' + str(nproc)],
-                                  cwd=build_dir)
+            subprocess.check_call([build_tool, "-j" + str(nproc)], cwd=build_dir)
         else:
-            subprocess.check_call(['cmake', '--build', '.',
-                                   '--config', 'Release'], cwd=build_dir)
+            subprocess.check_call(
+                ["cmake", "--build", ".", "--config", "Release"], cwd=build_dir
+            )

     def build_cmake_extension(self) -> None:
-        '''Configure and build using CMake'''
-        if USER_OPTIONS['use-system-libxgboost'][2]:
-            self.logger.info('Using system libxgboost.')
+        """Configure and build using CMake"""
+        if USER_OPTIONS["use-system-libxgboost"][2]:
+            self.logger.info("Using system libxgboost.")
             return

         build_dir = self.build_temp
         global BUILD_TEMP_DIR  # pylint: disable=global-statement
         BUILD_TEMP_DIR = build_dir
         libxgboost = os.path.abspath(
-            os.path.join(CURRENT_DIR, os.path.pardir, 'lib', lib_name()))
+            os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
+        )

         if os.path.exists(libxgboost):
-            self.logger.info('Found shared library, skipping build.')
+            self.logger.info("Found shared library, skipping build.")
             return

-        src_dir = 'xgboost'
+        src_dir = "xgboost"
         try:
-            copy_tree(os.path.join(CURRENT_DIR, os.path.pardir),
-                      os.path.join(self.build_temp, src_dir))
+            copy_tree(
+                os.path.join(CURRENT_DIR, os.path.pardir),
+                os.path.join(self.build_temp, src_dir),
+            )
         except Exception:  # pylint: disable=broad-except
             copy_tree(src_dir, os.path.join(self.build_temp, src_dir))

-        self.logger.info('Building from source. %s', libxgboost)
+        self.logger.info("Building from source. %s", libxgboost)
         if not os.path.exists(build_dir):
             os.mkdir(build_dir)
-        if shutil.which('ninja'):
-            build_tool = 'ninja'
+        if shutil.which("ninja"):
+            build_tool = "ninja"
         else:
-            build_tool = 'make'
-        if sys.platform.startswith('os400'):
-            build_tool = 'make'
+            build_tool = "make"
+        if sys.platform.startswith("os400"):
+            build_tool = "make"

-        if system() == 'Windows':
+        if system() == "Windows":
             # Pick up from LGB, just test every possible tool chain.
             for vs in (
                 "-GVisual Studio 17 2022",
-                '-GVisual Studio 16 2019',
-                '-GVisual Studio 15 2017',
-                '-GVisual Studio 14 2015',
-                '-GMinGW Makefiles',
+                "-GVisual Studio 16 2019",
+                "-GVisual Studio 15 2017",
+                "-GVisual Studio 14 2015",
+                "-GMinGW Makefiles",
             ):
                 try:
                     self.build(src_dir, build_dir, vs)
                     self.logger.info(
-                        '%s is used for building Windows distribution.', vs)
+                        "%s is used for building Windows distribution.", vs
+                    )
                     break
                 except subprocess.CalledProcessError:
                     shutil.rmtree(build_dir)
                     os.mkdir(build_dir)
                     continue
         else:
-            gen = '-GNinja' if build_tool == 'ninja' else '-GUnix Makefiles'
+            gen = "-GNinja" if build_tool == "ninja" else "-GUnix Makefiles"
             try:
                 self.build(src_dir, build_dir, gen, build_tool, use_omp=1)
             except subprocess.CalledProcessError:
-                self.logger.warning('Disabling OpenMP support.')
+                self.logger.warning("Disabling OpenMP support.")
                 self.build(src_dir, build_dir, gen, build_tool, use_omp=0)

     def build_extension(self, ext: Extension) -> None:
-        '''Override the method for dispatching.'''
+        """Override the method for dispatching."""
         if isinstance(ext, CMakeExtension):
             self.build_cmake_extension()
         else:
             super().build_extension(ext)

     def copy_extensions_to_source(self) -> None:
-        '''Dummy override. Invoked during editable installation. Our binary
+        """Dummy override. Invoked during editable installation. Our binary
         should available in `lib`.

-        '''
+        """
         if not os.path.exists(
-                os.path.join(CURRENT_DIR, os.path.pardir, 'lib', lib_name())):
-            raise ValueError('For using editable installation, please ' +
-                             'build the shared object first with CMake.')
+            os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
+        ):
+            raise ValueError(
+                "For using editable installation, please "
+                + "build the shared object first with CMake."
+            )


 class Sdist(sdist.sdist):  # pylint: disable=too-many-ancestors
-    '''Copy c++ source into Python directory.'''
-    logger = logging.getLogger('xgboost sdist')
+    """Copy c++ source into Python directory."""
+
+    logger = logging.getLogger("xgboost sdist")

     def run(self) -> None:
-        copy_tree(os.path.join(CURRENT_DIR, os.path.pardir),
-                  os.path.join(CURRENT_DIR, 'xgboost'))
-        libxgboost = os.path.join(
-            CURRENT_DIR, os.path.pardir, 'lib', lib_name())
+        copy_tree(
+            os.path.join(CURRENT_DIR, os.path.pardir),
+            os.path.join(CURRENT_DIR, "xgboost"),
+        )
+        libxgboost = os.path.join(CURRENT_DIR, os.path.pardir, "lib", lib_name())
         if os.path.exists(libxgboost):
             self.logger.warning(
-                'Found shared library, removing to avoid being included in source distribution.'
+                "Found shared library, removing to avoid being included in source distribution."
             )
             os.remove(libxgboost)
         super().run()


 class InstallLib(install_lib.install_lib):
-    '''Copy shared object into installation directory.'''
-    logger = logging.getLogger('xgboost install_lib')
+    """Copy shared object into installation directory."""
+
+    logger = logging.getLogger("xgboost install_lib")

     def install(self) -> List[str]:
         outfiles = super().install()

-        if USER_OPTIONS['use-system-libxgboost'][2] != 0:
-            self.logger.info('Using system libxgboost.')
-            lib_path = os.path.join(sys.prefix, 'lib')
-            msg = 'use-system-libxgboost is specified, but ' + lib_name() + \
-                ' is not found in: ' + lib_path
+        if USER_OPTIONS["use-system-libxgboost"][2] != 0:
+            self.logger.info("Using system libxgboost.")
+            lib_path = os.path.join(sys.prefix, "lib")
+            msg = (
+                "use-system-libxgboost is specified, but "
+                + lib_name()
+                + " is not found in: "
+                + lib_path
+            )
             assert os.path.exists(os.path.join(lib_path, lib_name())), msg
             return []

-        lib_dir = os.path.join(self.install_dir, 'xgboost', 'lib')
+        lib_dir = os.path.join(self.install_dir, "xgboost", "lib")
         if not os.path.exists(lib_dir):
             os.mkdir(lib_dir)
-        dst = os.path.join(self.install_dir, 'xgboost', 'lib', lib_name())
+        dst = os.path.join(self.install_dir, "xgboost", "lib", lib_name())

         libxgboost_path = lib_name()

         assert BUILD_TEMP_DIR is not None
-        dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, 'lib')
-        build_dir = os.path.join(BUILD_TEMP_DIR, 'xgboost', 'lib')
+        dft_lib_dir = os.path.join(CURRENT_DIR, os.path.pardir, "lib")
+        build_dir = os.path.join(BUILD_TEMP_DIR, "xgboost", "lib")

         if os.path.exists(os.path.join(dft_lib_dir, libxgboost_path)):
             # The library is built by CMake directly
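As an aside, the loop in BuildExt.build above turns each USER_OPTIONS key into a CMake define; a self-contained sketch of that mapping, reusing names and values from the hunk:

    # "use-openmp" with default value 1 becomes the CMake flag "-DUSE_OPENMP=1".
    k, v = "use-openmp", (None, "Build with OpenMP support.", 1)
    flag = "-D" + k.replace("-", "_").upper() + "=" + str(v[2])
    assert flag == "-DUSE_OPENMP=1"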
@@ -268,18 +284,21 @@ class InstallLib(install_lib.install_lib):
         else:
             # The library is built by setup.py
             src = os.path.join(build_dir, libxgboost_path)
-        self.logger.info('Installing shared library: %s', src)
+        self.logger.info("Installing shared library: %s", src)
         dst, _ = self.copy_file(src, dst)
         outfiles.append(dst)
         return outfiles


 class Install(install.install):  # pylint: disable=too-many-instance-attributes
-    '''An interface to install command, accepting XGBoost specific
+    """An interface to install command, accepting XGBoost specific
     arguments.

-    '''
-    user_options = install.install.user_options + [(k, v[0], v[1]) for k, v in USER_OPTIONS.items()]
+    """
+
+    user_options = install.install.user_options + [
+        (k, v[0], v[1]) for k, v in USER_OPTIONS.items()
+    ]

     def initialize_options(self) -> None:
         super().initialize_options()
@@ -302,13 +321,13 @@ class Install(install.install):  # pylint: disable=too-many-instance-attributes
         # arguments, then here we propagate them into `USER_OPTIONS` for visibility to
         # other sub-commands like `build_ext`.
         for k, v in USER_OPTIONS.items():
-            arg = k.replace('-', '_')
+            arg = k.replace("-", "_")
             if hasattr(self, arg):
                 USER_OPTIONS[k] = (v[0], v[1], getattr(self, arg))
         super().run()


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Supported commands:
     # From internet:
     # - pip install xgboost
@@ -326,51 +345,55 @@ if __name__ == '__main__':
     # - python setup.py develop  # same as above
     logging.basicConfig(level=logging.INFO)

-    with open(os.path.join(CURRENT_DIR, 'README.rst'), encoding='utf-8') as fd:
+    with open(os.path.join(CURRENT_DIR, "README.rst"), encoding="utf-8") as fd:
         description = fd.read()
-    with open(os.path.join(CURRENT_DIR, 'xgboost/VERSION'), encoding="ascii") as fd:
+    with open(os.path.join(CURRENT_DIR, "xgboost/VERSION"), encoding="ascii") as fd:
         version = fd.read().strip()

-    setup(name='xgboost',
+    setup(
+        name="xgboost",
         version=version,
         description="XGBoost Python Package",
         long_description=description,
         long_description_content_type="text/x-rst",
         install_requires=[
-            'numpy',
-            'scipy',
+            "numpy",
+            "scipy",
         ],
-        ext_modules=[CMakeExtension('libxgboost')],
+        ext_modules=[CMakeExtension("libxgboost")],
         # error: expected "str": "Type[Command]"
         cmdclass={
-            'build_ext': BuildExt,  # type: ignore
-            'sdist': Sdist,  # type: ignore
-            'install_lib': InstallLib,  # type: ignore
-            'install': Install  # type: ignore
+            "build_ext": BuildExt,  # type: ignore
+            "sdist": Sdist,  # type: ignore
+            "install_lib": InstallLib,  # type: ignore
+            "install": Install,  # type: ignore
         },
         extras_require={
-            'pandas': ['pandas'],
-            'scikit-learn': ['scikit-learn'],
-            'dask': ['dask', 'pandas', 'distributed'],
-            'datatable': ['datatable'],
-            'plotting': ['graphviz', 'matplotlib'],
+            "pandas": ["pandas"],
+            "scikit-learn": ["scikit-learn"],
+            "dask": ["dask", "pandas", "distributed"],
+            "datatable": ["datatable"],
+            "plotting": ["graphviz", "matplotlib"],
             "pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
         },
-        maintainer='Hyunsu Cho',
-        maintainer_email='chohyu01@cs.washington.edu',
+        maintainer="Hyunsu Cho",
+        maintainer_email="chohyu01@cs.washington.edu",
         zip_safe=False,
         packages=find_packages(),
         include_package_data=True,
-        license='Apache-2.0',
-        classifiers=['License :: OSI Approved :: Apache Software License',
-                     'Development Status :: 5 - Production/Stable',
-                     'Operating System :: OS Independent',
-                     'Programming Language :: Python',
-                     'Programming Language :: Python :: 3',
-                     'Programming Language :: Python :: 3.8',
-                     'Programming Language :: Python :: 3.9',
-                     'Programming Language :: Python :: 3.10'],
+        license="Apache-2.0",
+        classifiers=[
+            "License :: OSI Approved :: Apache Software License",
+            "Development Status :: 5 - Production/Stable",
+            "Operating System :: OS Independent",
+            "Programming Language :: Python",
+            "Programming Language :: Python :: 3",
+            "Programming Language :: Python :: 3.8",
+            "Programming Language :: Python :: 3.9",
+            "Programming Language :: Python :: 3.10",
+        ],
         python_requires=">=3.8",
-        url='https://github.com/dmlc/xgboost')
+        url="https://github.com/dmlc/xgboost",
+    )

     clean_up()
@@ -152,42 +152,52 @@ def broadcast(data: _T, root: int) -> _T:
     rank = get_rank()
     length = ctypes.c_ulong()
     if root == rank:
-        assert data is not None, 'need to pass in data when broadcasting'
+        assert data is not None, "need to pass in data when broadcasting"
         s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
         length.value = len(s)
     # run first broadcast
-    _check_call(_LIB.XGCommunicatorBroadcast(ctypes.byref(length),
-                                             ctypes.sizeof(ctypes.c_ulong), root))
+    _check_call(
+        _LIB.XGCommunicatorBroadcast(
+            ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root
+        )
+    )
     if root != rank:
         dptr = (ctypes.c_char * length.value)()
         # run second
-        _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(dptr, ctypes.c_void_p),
-                                                 length.value, root))
+        _check_call(
+            _LIB.XGCommunicatorBroadcast(
+                ctypes.cast(dptr, ctypes.c_void_p), length.value, root
+            )
+        )
         data = pickle.loads(dptr.raw)
         del dptr
     else:
-        _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p),
-                                                 length.value, root))
+        _check_call(
+            _LIB.XGCommunicatorBroadcast(
+                ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), length.value, root
+            )
+        )
         del s
     return data


 # enumeration of dtypes
 DTYPE_ENUM__ = {
-    np.dtype('int8'): 0,
-    np.dtype('uint8'): 1,
-    np.dtype('int32'): 2,
-    np.dtype('uint32'): 3,
-    np.dtype('int64'): 4,
-    np.dtype('uint64'): 5,
-    np.dtype('float32'): 6,
-    np.dtype('float64'): 7
+    np.dtype("int8"): 0,
+    np.dtype("uint8"): 1,
+    np.dtype("int32"): 2,
+    np.dtype("uint32"): 3,
+    np.dtype("int64"): 4,
+    np.dtype("uint64"): 5,
+    np.dtype("float32"): 6,
+    np.dtype("float64"): 7,
 }


 @unique
 class Op(IntEnum):
     """Supported operations for allreduce."""

     MAX = 0
     MIN = 1
     SUM = 2
@@ -196,9 +206,7 @@ class Op(IntEnum):
     BITWISE_XOR = 5


-def allreduce(  # pylint:disable=invalid-name
-    data: np.ndarray, op: Op
-) -> np.ndarray:
+def allreduce(data: np.ndarray, op: Op) -> np.ndarray:  # pylint:disable=invalid-name
     """Perform allreduce, return the result.

     Parameters
@@ -218,15 +226,22 @@ def allreduce(  # pylint:disable=invalid-name
     This function is not thread-safe.
     """
     if not isinstance(data, np.ndarray):
-        raise TypeError('allreduce only takes in numpy.ndarray')
+        raise TypeError("allreduce only takes in numpy.ndarray")
     buf = data.ravel()
     if buf.base is data.base:
         buf = buf.copy()
     if buf.dtype not in DTYPE_ENUM__:
         raise Exception(f"data type {buf.dtype} not supported")
-    _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
-                                             buf.size, DTYPE_ENUM__[buf.dtype],
-                                             int(op), None, None))
+    _check_call(
+        _LIB.XGCommunicatorAllreduce(
+            buf.ctypes.data_as(ctypes.c_void_p),
+            buf.size,
+            DTYPE_ENUM__[buf.dtype],
+            int(op),
+            None,
+            None,
+        )
+    )
     return buf

[File diff suppressed because it is too large]
@@ -1,6 +1,6 @@
 # pylint: disable=too-many-arguments, too-many-branches, too-many-lines
 # pylint: disable=too-many-return-statements, import-error
-'''Data dispatching for DMatrix.'''
+"""Data dispatching for DMatrix."""
 import ctypes
 import json
 import os
@@ -108,6 +108,7 @@ def _from_scipy_csr(
     feature_types: Optional[FeatureTypes],
 ) -> DispatchedDataBackendReturnType:
     """Initialize data from a CSR matrix."""
+
     handle = ctypes.c_void_p()
     data = transform_scipy_sparse(data, True)
     _check_call(
@@ -178,8 +179,7 @@ def _ensure_np_dtype(


 def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray:
-    '''Handle numpy slice. This can be removed if we use __array_interface__.
-    '''
+    """Handle numpy slice. This can be removed if we use __array_interface__."""
     try:
         if not data.flags.c_contiguous:
             data = np.array(data, copy=True, dtype=dtype)
@@ -653,6 +653,7 @@ def _is_arrow(data: DataType) -> bool:
     try:
         import pyarrow as pa
         from pyarrow import dataset as arrow_dataset
+
         return isinstance(data, (pa.Table, arrow_dataset.Dataset))
     except ImportError:
         return False
@@ -878,8 +879,8 @@ def _is_cupy_array(data: DataType) -> bool:

 def _transform_cupy_array(data: DataType) -> CupyT:
     import cupy  # pylint: disable=import-error
-    if not hasattr(data, '__cuda_array_interface__') and hasattr(
-            data, '__array__'):
+
+    if not hasattr(data, "__cuda_array_interface__") and hasattr(data, "__array__"):
         data = cupy.array(data, copy=False)
     if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]:
         data = data.astype(cupy.float32, copy=False)
@@ -900,9 +901,9 @@ def _from_cupy_array(
     config = bytes(json.dumps({"missing": missing, "nthread": nthread}), "utf-8")
     _check_call(
         _LIB.XGDMatrixCreateFromCudaArrayInterface(
-            interface_str,
-            config,
-            ctypes.byref(handle)))
+            interface_str, config, ctypes.byref(handle)
+        )
+    )
     return handle, feature_names, feature_types

@@ -923,12 +924,13 @@ def _is_cupy_csc(data: DataType) -> bool:


 def _is_dlpack(data: DataType) -> bool:
-    return 'PyCapsule' in str(type(data)) and "dltensor" in str(data)
+    return "PyCapsule" in str(type(data)) and "dltensor" in str(data)


 def _transform_dlpack(data: DataType) -> bool:
     from cupy import fromDlpack  # pylint: disable=E0401
-    assert 'used_dltensor' not in str(data)
+
+    assert "used_dltensor" not in str(data)
     data = fromDlpack(data)
     return data

@@ -941,8 +943,7 @@ def _from_dlpack(
     feature_types: Optional[FeatureTypes],
 ) -> DispatchedDataBackendReturnType:
     data = _transform_dlpack(data)
-    return _from_cupy_array(data, missing, nthread, feature_names,
-                            feature_types)
+    return _from_cupy_array(data, missing, nthread, feature_names, feature_types)


 def _is_uri(data: DataType) -> bool:
@@ -1003,13 +1004,13 @@ def _is_iter(data: DataType) -> bool:


 def _has_array_protocol(data: DataType) -> bool:
-    return hasattr(data, '__array__')
+    return hasattr(data, "__array__")


 def _convert_unknown_data(data: DataType) -> DataType:
     warnings.warn(
-        f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
-        UserWarning
+        f"Unknown data type: {type(data)}, trying to convert it to csr_matrix",
+        UserWarning,
     )
     try:
         import scipy.sparse
@@ -1033,7 +1034,7 @@ def dispatch_data_backend(
     enable_categorical: bool = False,
     data_split_mode: DataSplitMode = DataSplitMode.ROW,
 ) -> DispatchedDataBackendReturnType:
-    '''Dispatch data for DMatrix.'''
+    """Dispatch data for DMatrix."""
     if not _is_cudf_ser(data) and not _is_pandas_series(data):
         _check_data_shape(data)
     if _is_scipy_csr(data):
@@ -1054,6 +1055,7 @@ def dispatch_data_backend(
         return _from_tuple(data, missing, threads, feature_names, feature_types)
     if _is_pandas_series(data):
         import pandas as pd
+
         data = pd.DataFrame(data)
     if _is_pandas_df(data):
         return _from_pandas_df(
@@ -1064,39 +1066,41 @@ def dispatch_data_backend(
             data, missing, threads, feature_names, feature_types, enable_categorical
         )
     if _is_cupy_array(data):
-        return _from_cupy_array(data, missing, threads, feature_names,
-                                feature_types)
+        return _from_cupy_array(data, missing, threads, feature_names, feature_types)
     if _is_cupy_csr(data):
-        raise TypeError('cupyx CSR is not supported yet.')
+        raise TypeError("cupyx CSR is not supported yet.")
     if _is_cupy_csc(data):
-        raise TypeError('cupyx CSC is not supported yet.')
+        raise TypeError("cupyx CSC is not supported yet.")
     if _is_dlpack(data):
-        return _from_dlpack(data, missing, threads, feature_names,
-                            feature_types)
+        return _from_dlpack(data, missing, threads, feature_names, feature_types)
     if _is_dt_df(data):
         _warn_unused_missing(data, missing)
         return _from_dt_df(
             data, missing, threads, feature_names, feature_types, enable_categorical
         )
     if _is_modin_df(data):
-        return _from_pandas_df(data, enable_categorical, missing, threads,
-                               feature_names, feature_types)
+        return _from_pandas_df(
+            data, enable_categorical, missing, threads, feature_names, feature_types
+        )
     if _is_modin_series(data):
         return _from_pandas_series(
             data, missing, threads, enable_categorical, feature_names, feature_types
         )
     if _is_arrow(data):
         return _from_arrow(
-            data, missing, threads, feature_names, feature_types, enable_categorical)
+            data, missing, threads, feature_names, feature_types, enable_categorical
+        )
     if _has_array_protocol(data):
         array = np.asarray(data)
         return _from_numpy_array(array, missing, threads, feature_names, feature_types)

     converted = _convert_unknown_data(data)
     if converted is not None:
-        return _from_scipy_csr(converted, missing, threads, feature_names, feature_types)
+        return _from_scipy_csr(
+            converted, missing, threads, feature_names, feature_types
+        )

-    raise TypeError('Not supported type for data.' + str(type(data)))
+    raise TypeError("Not supported type for data." + str(type(data)))

@@ -1128,20 +1132,14 @@ def _meta_from_numpy(


 def _meta_from_list(
-    data: Sequence,
-    field: str,
-    dtype: Optional[NumpyDType],
-    handle: ctypes.c_void_p
+    data: Sequence, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
 ) -> None:
     data_np = np.array(data)
     _meta_from_numpy(data_np, field, dtype, handle)


 def _meta_from_tuple(
-    data: Sequence,
-    field: str,
-    dtype: Optional[NumpyDType],
-    handle: ctypes.c_void_p
+    data: Sequence, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
 ) -> None:
     return _meta_from_list(data, field, dtype, handle)

@@ -1156,39 +1154,27 @@ def _meta_from_cudf_df(data: DataType, field: str, handle: ctypes.c_void_p) -> None:


 def _meta_from_cudf_series(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
-    interface = bytes(json.dumps([data.__cuda_array_interface__],
-                                 indent=2), 'utf-8')
-    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
-                                                   c_str(field),
-                                                   interface))
+    interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8")
+    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))


 def _meta_from_cupy_array(data: DataType, field: str, handle: ctypes.c_void_p) -> None:
     data = _transform_cupy_array(data)
-    interface = bytes(json.dumps([data.__cuda_array_interface__],
-                                 indent=2), 'utf-8')
-    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle,
-                                                   c_str(field),
-                                                   interface))
+    interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), "utf-8")
+    _check_call(_LIB.XGDMatrixSetInfoFromInterface(handle, c_str(field), interface))


 def _meta_from_dt(
-    data: DataType,
-    field: str,
-    dtype: Optional[NumpyDType],
-    handle: ctypes.c_void_p
+    data: DataType, field: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
 ) -> None:
     data, _, _ = _transform_dt_df(data, None, None, field, dtype)
     _meta_from_numpy(data, field, dtype, handle)


 def dispatch_meta_backend(
-    matrix: DMatrix,
-    data: DataType,
-    name: str,
-    dtype: Optional[NumpyDType] = None
+    matrix: DMatrix, data: DataType, name: str, dtype: Optional[NumpyDType] = None
 ) -> None:
-    '''Dispatch for meta info.'''
+    """Dispatch for meta info."""
     handle = matrix.handle
     assert handle is not None
     _validate_meta_shape(data, name)
@@ -1231,7 +1217,7 @@ def dispatch_meta_backend(
         _meta_from_numpy(data, name, dtype, handle)
         return
     if _is_modin_series(data):
-        data = data.values.astype('float')
+        data = data.values.astype("float")
         assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
         _meta_from_numpy(data, name, dtype, handle)
         return
@@ -1240,16 +1226,17 @@ def dispatch_meta_backend(
         array = np.asarray(data)
         _meta_from_numpy(array, name, dtype, handle)
         return
-    raise TypeError('Unsupported type for ' + name, str(type(data)))
+    raise TypeError("Unsupported type for " + name, str(type(data)))


 class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
-    '''An iterator for single batch data to help creating device DMatrix.
+    """An iterator for single batch data to help creating device DMatrix.
     Transforming input directly to histogram with normal single batch data API
     can not access weight for sketching. So this iterator acts as a staging
     area for meta info.

-    '''
+    """
+
     def __init__(self, **kwargs: Any) -> None:
         self.kwargs = kwargs
         self.it = 0  # pylint: disable=invalid-name
@@ -22,45 +22,51 @@ def find_lib_path() -> List[str]:
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
     dll_path = [
         # normal, after installation `lib` is copied into Python package tree.
-        os.path.join(curr_path, 'lib'),
+        os.path.join(curr_path, "lib"),
         # editable installation, no copying is performed.
-        os.path.join(curr_path, os.path.pardir, os.path.pardir, 'lib'),
+        os.path.join(curr_path, os.path.pardir, os.path.pardir, "lib"),
         # use libxgboost from a system prefix, if available. This should be the last
         # option.
-        os.path.join(sys.prefix, 'lib'),
+        os.path.join(sys.prefix, "lib"),
     ]

-    if sys.platform == 'win32':
-        if platform.architecture()[0] == '64bit':
-            dll_path.append(
-                os.path.join(curr_path, '../../windows/x64/Release/'))
+    if sys.platform == "win32":
+        if platform.architecture()[0] == "64bit":
+            dll_path.append(os.path.join(curr_path, "../../windows/x64/Release/"))
             # hack for pip installation when copy all parent source
             # directory here
-            dll_path.append(os.path.join(curr_path, './windows/x64/Release/'))
+            dll_path.append(os.path.join(curr_path, "./windows/x64/Release/"))
         else:
-            dll_path.append(os.path.join(curr_path, '../../windows/Release/'))
+            dll_path.append(os.path.join(curr_path, "../../windows/Release/"))
             # hack for pip installation when copy all parent source
             # directory here
-            dll_path.append(os.path.join(curr_path, './windows/Release/'))
-        dll_path = [os.path.join(p, 'xgboost.dll') for p in dll_path]
-    elif sys.platform.startswith(('linux', 'freebsd', 'emscripten')):
-        dll_path = [os.path.join(p, 'libxgboost.so') for p in dll_path]
-    elif sys.platform == 'darwin':
-        dll_path = [os.path.join(p, 'libxgboost.dylib') for p in dll_path]
-    elif sys.platform == 'cygwin':
-        dll_path = [os.path.join(p, 'cygxgboost.dll') for p in dll_path]
-    if platform.system() == 'OS400':
-        dll_path = [os.path.join(p, 'libxgboost.so') for p in dll_path]
+            dll_path.append(os.path.join(curr_path, "./windows/Release/"))
+        dll_path = [os.path.join(p, "xgboost.dll") for p in dll_path]
+    elif sys.platform.startswith(("linux", "freebsd", "emscripten")):
+        dll_path = [os.path.join(p, "libxgboost.so") for p in dll_path]
+    elif sys.platform == "darwin":
+        dll_path = [os.path.join(p, "libxgboost.dylib") for p in dll_path]
+    elif sys.platform == "cygwin":
+        dll_path = [os.path.join(p, "cygxgboost.dll") for p in dll_path]
+    if platform.system() == "OS400":
+        dll_path = [os.path.join(p, "libxgboost.so") for p in dll_path]

     lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]

     # XGBOOST_BUILD_DOC is defined by sphinx conf.
-    if not lib_path and not os.environ.get('XGBOOST_BUILD_DOC', False):
-        link = 'https://xgboost.readthedocs.io/en/latest/build.html'
-        msg = 'Cannot find XGBoost Library in the candidate path. ' + \
-            'List of candidates:\n- ' + ('\n- '.join(dll_path)) + \
-            '\nXGBoost Python package path: ' + curr_path + \
-            '\nsys.prefix: ' + sys.prefix + \
-            '\nSee: ' + link + ' for installing XGBoost.'
+    if not lib_path and not os.environ.get("XGBOOST_BUILD_DOC", False):
+        link = "https://xgboost.readthedocs.io/en/latest/build.html"
+        msg = (
+            "Cannot find XGBoost Library in the candidate path. "
+            + "List of candidates:\n- "
+            + ("\n- ".join(dll_path))
+            + "\nXGBoost Python package path: "
+            + curr_path
+            + "\nsys.prefix: "
+            + sys.prefix
+            + "\nSee: "
+            + link
+            + " for installing XGBoost."
+        )
         raise XGBoostLibraryNotFound(msg)
     return lib_path
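Note that find_lib_path above appends the platform-specific library name to every candidate directory and keeps only existing files; a small sketch of that resolution under the assumption of a single Linux prefix (the path is an example, not from the diff):

    import os

    dll_path = ["/usr/local/lib"]  # hypothetical candidate directory
    candidates = [os.path.join(p, "libxgboost.so") for p in dll_path]
    assert candidates == ["/usr/local/lib/libxgboost.so"]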
@@ -81,22 +81,24 @@ def plot_importance(
     try:
         import matplotlib.pyplot as plt
     except ImportError as e:
-        raise ImportError('You must install matplotlib to plot importance') from e
+        raise ImportError("You must install matplotlib to plot importance") from e

     if isinstance(booster, XGBModel):
         importance = booster.get_booster().get_score(
-            importance_type=importance_type, fmap=fmap)
+            importance_type=importance_type, fmap=fmap
+        )
     elif isinstance(booster, Booster):
         importance = booster.get_score(importance_type=importance_type, fmap=fmap)
     elif isinstance(booster, dict):
         importance = booster
     else:
-        raise ValueError('tree must be Booster, XGBModel or dict instance')
+        raise ValueError("tree must be Booster, XGBModel or dict instance")

     if not importance:
         raise ValueError(
-            'Booster.get_score() results in empty. ' +
-            'This maybe caused by having all trees as decision dumps.')
+            "Booster.get_score() results in empty. "
+            + "This maybe caused by having all trees as decision dumps."
+        )

     tuples = [(k, importance[k]) for k in importance]
     if max_num_features is not None:
@@ -110,25 +112,25 @@ def plot_importance(
         _, ax = plt.subplots(1, 1)

     ylocs = np.arange(len(values))
-    ax.barh(ylocs, values, align='center', height=height, **kwargs)
+    ax.barh(ylocs, values, align="center", height=height, **kwargs)

     if show_values is True:
         for x, y in zip(values, ylocs):
-            ax.text(x + 1, y, values_format.format(v=x), va='center')
+            ax.text(x + 1, y, values_format.format(v=x), va="center")

     ax.set_yticks(ylocs)
     ax.set_yticklabels(labels)

     if xlim is not None:
         if not isinstance(xlim, tuple) or len(xlim) != 2:
-            raise ValueError('xlim must be a tuple of 2 elements')
+            raise ValueError("xlim must be a tuple of 2 elements")
     else:
         xlim = (0, max(values) * 1.1)
     ax.set_xlim(xlim)

     if ylim is not None:
         if not isinstance(ylim, tuple) or len(ylim) != 2:
-            raise ValueError('ylim must be a tuple of 2 elements')
+            raise ValueError("ylim must be a tuple of 2 elements")
     else:
         ylim = (-1, len(values))
     ax.set_ylim(ylim)
@@ -201,44 +203,42 @@ def to_graphviz(
     try:
         from graphviz import Source
     except ImportError as e:
-        raise ImportError('You must install graphviz to plot tree') from e
+        raise ImportError("You must install graphviz to plot tree") from e
     if isinstance(booster, XGBModel):
         booster = booster.get_booster()

     # squash everything back into kwargs again for compatibility
-    parameters = 'dot'
+    parameters = "dot"
     extra = {}
     for key, value in kwargs.items():
         extra[key] = value

     if rankdir is not None:
-        kwargs['graph_attrs'] = {}
-        kwargs['graph_attrs']['rankdir'] = rankdir
+        kwargs["graph_attrs"] = {}
+        kwargs["graph_attrs"]["rankdir"] = rankdir
     for key, value in extra.items():
         if kwargs.get("graph_attrs", None) is not None:
-            kwargs['graph_attrs'][key] = value
+            kwargs["graph_attrs"][key] = value
         else:
-            kwargs['graph_attrs'] = {}
+            kwargs["graph_attrs"] = {}
         del kwargs[key]

     if yes_color is not None or no_color is not None:
-        kwargs['edge'] = {}
+        kwargs["edge"] = {}
     if yes_color is not None:
-        kwargs['edge']['yes_color'] = yes_color
+        kwargs["edge"]["yes_color"] = yes_color
     if no_color is not None:
-        kwargs['edge']['no_color'] = no_color
+        kwargs["edge"]["no_color"] = no_color

     if condition_node_params is not None:
-        kwargs['condition_node_params'] = condition_node_params
+        kwargs["condition_node_params"] = condition_node_params
     if leaf_node_params is not None:
-        kwargs['leaf_node_params'] = leaf_node_params
+        kwargs["leaf_node_params"] = leaf_node_params

     if kwargs:
-        parameters += ':'
+        parameters += ":"
         parameters += json.dumps(kwargs)
-    tree = booster.get_dump(
-        fmap=fmap,
-        dump_format=parameters)[num_trees]
+    tree = booster.get_dump(fmap=fmap, dump_format=parameters)[num_trees]
     g = Source(tree)
     return g

@@ -277,19 +277,18 @@ def plot_tree(
         from matplotlib import image
         from matplotlib import pyplot as plt
     except ImportError as e:
-        raise ImportError('You must install matplotlib to plot tree') from e
+        raise ImportError("You must install matplotlib to plot tree") from e

     if ax is None:
         _, ax = plt.subplots(1, 1)

-    g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
-                    **kwargs)
+    g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, **kwargs)

     s = BytesIO()
-    s.write(g.pipe(format='png'))
+    s.write(g.pipe(format="png"))
     s.seek(0)
     img = image.imread(s)

     ax.imshow(img)
-    ax.axis('off')
+    ax.axis("off")
     return ax
@@ -24,7 +24,7 @@ def init(args: Optional[List[bytes]] = None) -> None:
     parsed = {}
     if args:
         for arg in args:
-            kv = arg.decode().split('=')
+            kv = arg.decode().split("=")
             if len(kv) == 2:
                 parsed[kv[0]] = kv[1]
     collective.init(**parsed)
@@ -104,6 +104,7 @@ def broadcast(data: T, root: int) -> T:
 @unique
 class Op(IntEnum):
     """Supported operations for rabit."""
+
     MAX = 0
     MIN = 1
     SUM = 2
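Note: Op supplies the reduction operator for the rabit allreduce wrapper. A single-process sketch, assuming the shim keeps its allreduce(data, op) signature (with one worker the reduction returns the input unchanged):

    import numpy as np
    from xgboost import rabit

    rabit.init()
    out = rabit.allreduce(np.array([1.0, 2.0], dtype=np.float32), rabit.Op.SUM)
    rabit.finalize()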
@@ -53,7 +53,7 @@ class ExSocket:


 # magic number used to verify existence of data
-MAGIC_NUM = 0xff99
+MAGIC_NUM = 0xFF99


 def get_some_ip(host: str) -> str:
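Note: the only change here is black's numeric-literal normalization: hex digits are uppercased while the 0x prefix stays lowercase, and the value is untouched:

    assert 0xff99 == 0xFF99 == 65433  # same integer, only the spelling differs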
@@ -334,19 +334,19 @@ class RabitTracker:
         while len(shutdown) != n_workers:
             fd, s_addr = self.sock.accept()
             s = WorkerEntry(fd, s_addr)
-            if s.cmd == 'print':
+            if s.cmd == "print":
                 s.print(self._use_logger)
                 continue
-            if s.cmd == 'shutdown':
+            if s.cmd == "shutdown":
                 assert s.rank >= 0 and s.rank not in shutdown
                 assert s.rank not in wait_conn
                 shutdown[s.rank] = s
-                logging.debug('Received %s signal from %d', s.cmd, s.rank)
+                logging.debug("Received %s signal from %d", s.cmd, s.rank)
                 continue
             assert s.cmd in ("start", "recover")
             # lazily initialize the workers
             if tree_map is None:
-                assert s.cmd == 'start'
+                assert s.cmd == "start"
                 if s.world_size > 0:
                     n_workers = s.world_size
                 tree_map, parent_map, ring_map = self.get_link_map(n_workers)
@@ -354,7 +354,7 @@ class RabitTracker:
                 todo_nodes = list(range(n_workers))
             else:
                 assert s.world_size in (-1, n_workers)
-            if s.cmd == 'recover':
+            if s.cmd == "recover":
                 assert s.rank >= 0

             rank = s.decide_rank(job_map)
@@ -410,24 +410,25 @@ def get_host_ip(host_ip: Optional[str] = None) -> str:
     returned as it's

     """
-    if host_ip is None or host_ip == 'auto':
-        host_ip = 'ip'
+    if host_ip is None or host_ip == "auto":
+        host_ip = "ip"

-    if host_ip == 'dns':
+    if host_ip == "dns":
         host_ip = socket.getfqdn()
-    elif host_ip == 'ip':
+    elif host_ip == "ip":
         from socket import gaierror

         try:
             host_ip = socket.gethostbyname(socket.getfqdn())
         except gaierror:
             logging.debug(
-                'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
+                "gethostbyname(socket.getfqdn()) failed... trying on hostname()"
             )
             host_ip = socket.gethostbyname(socket.gethostname())
             if host_ip.startswith("127."):
                 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                 # doesn't have to be reachable
-                s.connect(('10.255.255.255', 1))
+                s.connect(("10.255.255.255", 1))
                 host_ip = s.getsockname()[0]

     assert host_ip is not None
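Note: the "127." fallback above is the classic UDP-connect trick: connecting a datagram socket transmits nothing, but it forces the kernel to pick an outgoing interface, whose address getsockname() then reports. A self-contained sketch (primary_ip is a hypothetical helper, not part of xgboost):

    import socket

    def primary_ip() -> str:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # No packet is sent; this only selects a route/interface.
            s.connect(("10.255.255.255", 1))
            return s.getsockname()[0]
        finally:
            s.close()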
@@ -458,25 +459,41 @@ def start_rabit_tracker(args: argparse.Namespace) -> None:

 def main() -> None:
     """Main function if tracker is executed in standalone mode."""
-    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
-    parser.add_argument('--num-workers', required=True, type=int,
-                        help='Number of worker process to be launched.')
+    parser = argparse.ArgumentParser(description="Rabit Tracker start.")
     parser.add_argument(
-        '--num-servers', default=0, type=int,
-        help='Number of server process to be launched. Only used in PS jobs.'
+        "--num-workers",
+        required=True,
+        type=int,
+        help="Number of worker process to be launched.",
+    )
+    parser.add_argument(
+        "--num-servers",
+        default=0,
+        type=int,
+        help="Number of server process to be launched. Only used in PS jobs.",
+    )
+    parser.add_argument(
+        "--host-ip",
+        default=None,
+        type=str,
+        help=(
+            "Host IP addressed, this is only needed "
+            + "if the host IP cannot be automatically guessed."
+        ),
+    )
+    parser.add_argument(
+        "--log-level",
+        default="INFO",
+        type=str,
+        choices=["INFO", "DEBUG"],
+        help="Logging level of the logger.",
     )
-    parser.add_argument('--host-ip', default=None, type=str,
-                        help=('Host IP addressed, this is only needed ' +
-                              'if the host IP cannot be automatically guessed.'))
-    parser.add_argument('--log-level', default='INFO', type=str,
-                        choices=['INFO', 'DEBUG'],
-                        help='Logging level of the logger.')
     args = parser.parse_args()

-    fmt = '%(asctime)s %(levelname)s %(message)s'
-    if args.log_level == 'INFO':
+    fmt = "%(asctime)s %(levelname)s %(message)s"
+    if args.log_level == "INFO":
         level = logging.INFO
-    elif args.log_level == 'DEBUG':
+    elif args.log_level == "DEBUG":
         level = logging.DEBUG
     else:
         raise RuntimeError(f"Unknown logging level {args.log_level}")
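Note: the rewritten add_argument calls illustrate black's "magic trailing comma": once a call cannot fit within the 88-column limit, black puts each argument on its own line and appends a trailing comma, and that comma keeps the call exploded on every later run. A generic before/after sketch (f and its arguments are hypothetical):

    result = f(1, 2, 3)  # fits on one line: black leaves it alone

    result = f(  # too long (or already carrying a trailing comma): exploded
        a_rather_long_argument_name,
        another_long_argument_name,
        yet_another_argument_name,
    )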
@@ -205,25 +205,29 @@ def train(


 class CVPack:
-    """"Auxiliary datastruct to hold one fold of CV."""
-    def __init__(self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]) -> None:
-        """"Initialize the CVPack"""
+    """ "Auxiliary datastruct to hold one fold of CV."""
+
+    def __init__(
+        self, dtrain: DMatrix, dtest: DMatrix, param: Optional[Union[Dict, List]]
+    ) -> None:
+        """ "Initialize the CVPack"""
         self.dtrain = dtrain
         self.dtest = dtest
-        self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
+        self.watchlist = [(dtrain, "train"), (dtest, "test")]
         self.bst = Booster(param, [dtrain, dtest])

     def __getattr__(self, name: str) -> Callable:
         def _inner(*args: Any, **kwargs: Any) -> Any:
             return getattr(self.bst, name)(*args, **kwargs)

         return _inner

     def update(self, iteration: int, fobj: Optional[Objective]) -> None:
-        """"Update the boosters for one iteration"""
+        """ "Update the boosters for one iteration"""
         self.bst.update(self.dtrain, iteration, fobj)

     def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> str:
-        """"Evaluate the CVPack for one iteration."""
+        """ "Evaluate the CVPack for one iteration."""
         return self.bst.eval_set(self.watchlist, iteration, feval, output_margin)

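Note: the docstrings in this hunk start with four quotes, a pre-existing typo. Black never edits string contents, so it only inserts a space after the opening triple quote to avoid an ambiguous quadruple-quote run; the stray quote stays part of the docstring text. The real fix would be deleting the extra quote:

    # The fourth quote is data, not delimiter:
    assert """"Auxiliary datastruct to hold one fold of CV.""" == (
        '"Auxiliary datastruct to hold one fold of CV.'
    )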
@@ -232,38 +236,42 @@ class _PackedBooster:
         self.cvfolds = cvfolds

     def update(self, iteration: int, obj: Optional[Objective]) -> None:
-        '''Iterate through folds for update'''
+        """Iterate through folds for update"""
         for fold in self.cvfolds:
             fold.update(iteration, obj)

-    def eval(self, iteration: int, feval: Optional[Metric], output_margin: bool) -> List[str]:
-        '''Iterate through folds for eval'''
+    def eval(
+        self, iteration: int, feval: Optional[Metric], output_margin: bool
+    ) -> List[str]:
+        """Iterate through folds for eval"""
         result = [f.eval(iteration, feval, output_margin) for f in self.cvfolds]
         return result

     def set_attr(self, **kwargs: Optional[str]) -> Any:
-        '''Iterate through folds for setting attributes'''
+        """Iterate through folds for setting attributes"""
         for f in self.cvfolds:
             f.bst.set_attr(**kwargs)

     def attr(self, key: str) -> Optional[str]:
-        '''Redirect to booster attr.'''
+        """Redirect to booster attr."""
         return self.cvfolds[0].bst.attr(key)

-    def set_param(self,
-                  params: Union[Dict, Iterable[Tuple[str, Any]], str],
-                  value: Optional[str] = None) -> None:
+    def set_param(
+        self,
+        params: Union[Dict, Iterable[Tuple[str, Any]], str],
+        value: Optional[str] = None,
+    ) -> None:
         """Iterate through folds for set_param"""
         for f in self.cvfolds:
             f.bst.set_param(params, value)

     def num_boosted_rounds(self) -> int:
-        '''Number of boosted rounds.'''
+        """Number of boosted rounds."""
         return self.cvfolds[0].num_boosted_rounds()

     @property
     def best_iteration(self) -> int:
-        '''Get best_iteration'''
+        """Get best_iteration"""
         return int(cast(int, self.cvfolds[0].bst.attr("best_iteration")))

     @property
@@ -279,7 +287,7 @@ def groups_to_rows(groups: List[np.ndarray], boundaries: np.ndarray) -> np.ndarr
     :param boundaries: rows index limits of each group
     :return: row in group
     """
-    return np.concatenate([np.arange(boundaries[g], boundaries[g+1]) for g in groups])
+    return np.concatenate([np.arange(boundaries[g], boundaries[g + 1]) for g in groups])


 def mkgroupfold(
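Note: groups_to_rows turns group indices into row indices via the cumulative boundaries array. A worked miniature (standalone reimplementation for illustration):

    import numpy as np

    def groups_to_rows(groups, boundaries):
        # boundaries[g]:boundaries[g + 1] are the rows that belong to group g
        return np.concatenate(
            [np.arange(boundaries[g], boundaries[g + 1]) for g in groups]
        )

    # Three groups of sizes 2, 3 and 1 give boundaries [0, 2, 5, 6].
    assert groups_to_rows([0, 2], np.array([0, 2, 5, 6])).tolist() == [0, 1, 5]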
@@ -305,11 +313,17 @@ def mkgroupfold(
     # list by fold of test group indexes
     out_group_idset = np.array_split(idx, nfold)
     # list by fold of train group indexes
-    in_group_idset = [np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
-                      for k in range(nfold)]
+    in_group_idset = [
+        np.concatenate([out_group_idset[i] for i in range(nfold) if k != i])
+        for k in range(nfold)
+    ]
     # from the group indexes, convert them to row indexes
-    in_idset = [groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset]
-    out_idset = [groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset]
+    in_idset = [
+        groups_to_rows(in_groups, group_boundaries) for in_groups in in_group_idset
+    ]
+    out_idset = [
+        groups_to_rows(out_groups, group_boundaries) for out_groups in out_group_idset
+    ]

     # build the folds by taking the appropriate slices
     ret = []
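Note: each training split is built as the complement of its test split. In miniature, with 6 rows and 3 folds (values hypothetical):

    import numpy as np

    idx, nfold = np.arange(6), 3
    out = np.array_split(idx, nfold)  # test splits: [0 1], [2 3], [4 5]
    ins = [
        np.concatenate([out[i] for i in range(nfold) if k != i])
        for k in range(nfold)  # train split k = everything except out[k]
    ]
    assert ins[0].tolist() == [2, 3, 4, 5]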
@@ -324,7 +338,7 @@ def mkgroupfold(
             dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
         else:
             tparam = param
-        plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
+        plst = list(tparam.items()) + [("eval_metric", itm) for itm in evals]
         ret.append(CVPack(dtrain, dtest, plst))
     return ret

@@ -348,16 +362,20 @@ def mknfold(

     if stratified is False and folds is None:
         # Do standard k-fold cross validation. Automatically determine the folds.
-        if len(dall.get_uint_info('group_ptr')) > 1:
-            return mkgroupfold(dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle)
+        if len(dall.get_uint_info("group_ptr")) > 1:
+            return mkgroupfold(
+                dall, nfold, param, evals=evals, fpreproc=fpreproc, shuffle=shuffle
+            )

         if shuffle is True:
             idx = np.random.permutation(dall.num_row())
         else:
             idx = np.arange(dall.num_row())
         out_idset = np.array_split(idx, nfold)
-        in_idset = [np.concatenate([out_idset[i] for i in range(nfold) if k != i])
-                    for k in range(nfold)]
+        in_idset = [
+            np.concatenate([out_idset[i] for i in range(nfold) if k != i])
+            for k in range(nfold)
+        ]
     elif folds is not None:
         # Use user specified custom split using indices
         try:
@@ -387,7 +405,7 @@ def mknfold(
             dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
         else:
             tparam = param
-        plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
+        plst = list(tparam.items()) + [("eval_metric", itm) for itm in evals]
         ret.append(CVPack(dtrain, dtest, plst))
     return ret

@@ -502,29 +520,32 @@ def cv(
     evaluation history : list(string)
     """
     if stratified is True and not SKLEARN_INSTALLED:
-        raise XGBoostError('sklearn needs to be installed in order to use stratified cv')
+        raise XGBoostError(
+            "sklearn needs to be installed in order to use stratified cv"
+        )

     if isinstance(metrics, str):
         metrics = [metrics]

     params = params.copy()
     if isinstance(params, list):
-        _metrics = [x[1] for x in params if x[0] == 'eval_metric']
+        _metrics = [x[1] for x in params if x[0] == "eval_metric"]
         params = dict(params)
-        if 'eval_metric' in params:
-            params['eval_metric'] = _metrics
+        if "eval_metric" in params:
+            params["eval_metric"] = _metrics

-    if (not metrics) and 'eval_metric' in params:
-        if isinstance(params['eval_metric'], list):
-            metrics = params['eval_metric']
+    if (not metrics) and "eval_metric" in params:
+        if isinstance(params["eval_metric"], list):
+            metrics = params["eval_metric"]
         else:
-            metrics = [params['eval_metric']]
+            metrics = [params["eval_metric"]]

     params.pop("eval_metric", None)

     results: Dict[str, List[float]] = {}
-    cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc,
-                      stratified, folds, shuffle)
+    cvfolds = mknfold(
+        dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds, shuffle
+    )

     metric_fn = _configure_custom_metric(feval, custom_metric)

@@ -555,20 +576,21 @@ def cv(
         should_break = callbacks_container.after_iteration(booster, i, dtrain, None)
         res = callbacks_container.aggregated_cv
         for key, mean, std in cast(List[Tuple[str, float, float]], res):
-            if key + '-mean' not in results:
-                results[key + '-mean'] = []
-            if key + '-std' not in results:
-                results[key + '-std'] = []
-            results[key + '-mean'].append(mean)
-            results[key + '-std'].append(std)
+            if key + "-mean" not in results:
+                results[key + "-mean"] = []
+            if key + "-std" not in results:
+                results[key + "-std"] = []
+            results[key + "-mean"].append(mean)
+            results[key + "-std"].append(std)

         if should_break:
             for k in results.keys():  # pylint: disable=consider-iterating-dictionary
-                results[k] = results[k][:(booster.best_iteration + 1)]
+                results[k] = results[k][: (booster.best_iteration + 1)]
             break
     if as_pandas:
         try:
             import pandas as pd
+
             results = pd.DataFrame.from_dict(results)
         except ImportError:
             pass
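Note: the public entry point whose internals are reformatted in these hunks is xgboost.cv. A minimal usage sketch on synthetic data (parameter values are arbitrary):

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(200, 5), np.random.randint(2, size=200)
    dtrain = xgb.DMatrix(X, label=y)

    history = xgb.cv(
        {"objective": "binary:logistic", "max_depth": 2},
        dtrain,
        num_boost_round=10,
        nfold=5,
        metrics="logloss",
        as_pandas=True,  # yields a DataFrame with <metric>-mean / <metric>-std columns
    )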
@@ -132,16 +132,7 @@ def main(args: argparse.Namespace) -> None:
             run_black(path)
     for path in [
         # core
-        "python-package/xgboost/__init__.py",
-        "python-package/xgboost/_typing.py",
-        "python-package/xgboost/callback.py",
-        "python-package/xgboost/compat.py",
-        "python-package/xgboost/config.py",
-        "python-package/xgboost/dask.py",
-        "python-package/xgboost/sklearn.py",
-        "python-package/xgboost/spark",
-        "python-package/xgboost/federated.py",
-        "python-package/xgboost/testing",
+        "python-package/",
         # tests
         "tests/python/test_config.py",
         "tests/python/test_data_iterator.py",
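Note: this last hunk is what actually widens enforcement: instead of an allow-list of individual modules, the lint script now runs black over the entire python-package/ tree. A sketch of reproducing the check locally (black's --check flag reports violations without rewriting):

    import subprocess

    # Exits non-zero if any file under python-package/ would be reformatted.
    subprocess.run(["black", "--check", "python-package/"], check=True)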