Cleanup Python code. (#6223)

* Remove pathlike as XGBoost 1.2 requires Python 3.6.
* Move conditional import of dask/distributed into dask module.
This commit is contained in:
Jiaming Yuan 2020-10-12 15:44:41 +08:00 committed by GitHub
parent 70c2039748
commit 2443275891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 118 deletions

View File

@ -19,67 +19,6 @@ def py_str(x):
return x.decode('utf-8') return x.decode('utf-8')
###############################################################################
# START NUMPY PATHLIB ATTRIBUTION
###############################################################################
# os.PathLike compatibility used in Numpy:
# https://github.com/numpy/numpy/tree/v1.17.0
# Attribution:
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/compat/py3k.py#L188-L247
# Backport os.fs_path, os.PathLike, and PurePath.__fspath__
if sys.version_info[:2] >= (3, 6):
os_fspath = os.fspath
os_PathLike = os.PathLike
else:
def _PurePath__fspath__(self):
return str(self)
class os_PathLike(abc.ABC):
"""Abstract base class for implementing the file system path protocol."""
@abc.abstractmethod
def __fspath__(self):
"""Return the file system path representation of the object."""
raise NotImplementedError
@classmethod
def __subclasshook__(cls, subclass):
if issubclass(subclass, PurePath):
return True
return hasattr(subclass, '__fspath__')
def os_fspath(path):
"""Return the path representation of a path-like object.
If str or bytes is passed in, it is returned unchanged. Otherwise the
os.PathLike interface is used to get the path representation. If the
path representation is not str or bytes, TypeError is raised. If the
provided path is not str, bytes, or os.PathLike, TypeError is raised.
"""
if isinstance(path, (str, bytes)):
return path
# Work from the object's type to match method resolution of other magic
# methods.
path_type = type(path)
try:
path_repr = path_type.__fspath__(path)
except AttributeError as e:
if hasattr(path_type, '__fspath__'):
raise
if issubclass(path_type, PurePath):
return _PurePath__fspath__(path)
raise TypeError("expected str, bytes or os.PathLike object, "
"not " + path_type.__name__) from e
if isinstance(path_repr, (str, bytes)):
return path_repr
raise TypeError("expected {}.__fspath__() to return str or bytes, "
"not {}".format(path_type.__name__,
type(path_repr).__name__))
###############################################################################
# END NUMPY PATHLIB ATTRIBUTION
###############################################################################
def lazy_isinstance(instance, module, name): def lazy_isinstance(instance, module, name):
'''Use string representation to identify a type.''' '''Use string representation to identify a type.'''
module = type(instance).__module__ == module module = type(instance).__module__ == module
@ -167,26 +106,9 @@ except ImportError:
# dask # dask
try: try:
import dask import dask
from dask import delayed
from dask import dataframe as dd
from dask import array as da
from dask.distributed import Client, get_client
from dask.distributed import comm as distributed_comm
from dask.distributed import wait as distributed_wait
from distributed import get_worker as distributed_get_worker
DASK_INSTALLED = True DASK_INSTALLED = True
except ImportError: except ImportError:
dd = None
da = None
Client = None
delayed = None
get_client = None
distributed_comm = None
distributed_wait = None
distributed_get_worker = None
dask = None dask = None
DASK_INSTALLED = False DASK_INSTALLED = False

View File

@ -16,10 +16,8 @@ import warnings
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import ( from .compat import (STRING_TYPES, DataFrame, py_str, PANDAS_INSTALLED,
STRING_TYPES, DataFrame, py_str, lazy_isinstance)
PANDAS_INSTALLED,
os_fspath, os_PathLike, lazy_isinstance)
from .libpath import find_lib_path from .libpath import find_lib_path
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h # c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
@ -590,7 +588,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
If set, the output is suppressed. If set, the output is suppressed.
""" """
_check_call(_LIB.XGDMatrixSaveBinary(self.handle, _check_call(_LIB.XGDMatrixSaveBinary(self.handle,
c_str(os_fspath(fname)), c_str(os.fspath(fname)),
ctypes.c_int(silent))) ctypes.c_int(silent)))
def set_label(self, label): def set_label(self, label):
@ -982,7 +980,7 @@ class Booster(object):
_check_call( _check_call(
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length)) _LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
self.__dict__.update(state) self.__dict__.update(state)
elif isinstance(model_file, (STRING_TYPES, os_PathLike, bytearray)): elif isinstance(model_file, (STRING_TYPES, os.PathLike, bytearray)):
self.load_model(model_file) self.load_model(model_file)
elif model_file is None: elif model_file is None:
pass pass
@ -1582,11 +1580,11 @@ class Booster(object):
Output file name Output file name
""" """
if isinstance(fname, (STRING_TYPES, os_PathLike)): # assume file name if isinstance(fname, (STRING_TYPES, os.PathLike)): # assume file name
_check_call(_LIB.XGBoosterSaveModel( _check_call(_LIB.XGBoosterSaveModel(
self.handle, c_str(os_fspath(fname)))) self.handle, c_str(os.fspath(fname))))
else: else:
raise TypeError("fname must be a string or os_PathLike") raise TypeError("fname must be a string or os PathLike")
def save_raw(self): def save_raw(self):
"""Save the model to a in memory buffer representation instead of file. """Save the model to a in memory buffer representation instead of file.
@ -1620,11 +1618,11 @@ class Booster(object):
Input file name or memory buffer(see also save_raw) Input file name or memory buffer(see also save_raw)
""" """
if isinstance(fname, (STRING_TYPES, os_PathLike)): if isinstance(fname, (STRING_TYPES, os.PathLike)):
# assume file name, cannot use os.path.exist to check, file can be # assume file name, cannot use os.path.exist to check, file can be
# from URL. # from URL.
_check_call(_LIB.XGBoosterLoadModel( _check_call(_LIB.XGBoosterLoadModel(
self.handle, c_str(os_fspath(fname)))) self.handle, c_str(os.fspath(fname))))
elif isinstance(fname, bytearray): elif isinstance(fname, bytearray):
buf = fname buf = fname
length = c_bst_ulong(len(buf)) length = c_bst_ulong(len(buf))
@ -1650,8 +1648,8 @@ class Booster(object):
dump_format : string, optional dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'. Format of model dump file. Can be 'text' or 'json'.
""" """
if isinstance(fout, (STRING_TYPES, os_PathLike)): if isinstance(fout, (STRING_TYPES, os.PathLike)):
fout = open(os_fspath(fout), 'w') fout = open(os.fspath(fout), 'w')
need_close = True need_close = True
else: else:
need_close = False need_close = False
@ -1685,7 +1683,7 @@ class Booster(object):
Format of model dump. Can be 'text', 'json' or 'dot'. Format of model dump. Can be 'text', 'json' or 'dot'.
""" """
fmap = os_fspath(fmap) fmap = os.fspath(fmap)
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '': if self.feature_names is not None and fmap == '':
@ -1765,7 +1763,7 @@ class Booster(object):
importance_type: str, default 'weight' importance_type: str, default 'weight'
One of the importance types defined above. One of the importance types defined above.
""" """
fmap = os_fspath(fmap) fmap = os.fspath(fmap)
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}: if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
raise ValueError('Feature importance is not defined for Booster type {}' raise ValueError('Feature importance is not defined for Booster type {}'
.format(self.booster)) .format(self.booster))
@ -1858,7 +1856,7 @@ class Booster(object):
The name of feature map file. The name of feature map file.
""" """
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
fmap = os_fspath(fmap) fmap = os.fspath(fmap)
if not PANDAS_INSTALLED: if not PANDAS_INSTALLED:
raise Exception(('pandas must be available to use this method.' raise Exception(('pandas must be available to use this method.'
'Install pandas before calling again.')) 'Install pandas before calling again.'))

View File

@ -24,8 +24,6 @@ import numpy
from . import rabit from . import rabit
from .compat import DASK_INSTALLED from .compat import DASK_INSTALLED
from .compat import distributed_get_worker, distributed_wait, distributed_comm
from .compat import da, dd, delayed, get_client
from .compat import sparse, scipy_sparse from .compat import sparse, scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import CUDF_concat from .compat import CUDF_concat
@ -38,9 +36,22 @@ from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc from .sklearn import xgboost_model_doc
try: try:
from distributed import Client from dask.distributed import Client, get_client
from dask.distributed import comm as distributed_comm
from dask.distributed import wait as distributed_wait
from dask.distributed import get_worker as distributed_get_worker
from dask import dataframe as dd
from dask import array as da
from dask import delayed
except ImportError: except ImportError:
Client = None Client = None
get_client = None
distributed_comm = None
distributed_wait = None
distributed_get_worker = None
dd = None
da = None
delayed = None
# Current status is considered as initial support, many features are # Current status is considered as initial support, many features are
# not properly supported yet. # not properly supported yet.
@ -83,6 +94,9 @@ def _assert_dask_support():
if not DASK_INSTALLED: if not DASK_INSTALLED:
raise ImportError( raise ImportError(
'Dask needs to be installed in order to use this module') 'Dask needs to be installed in order to use this module')
if not distributed_wait:
raise ImportError(
'distributed needs to be installed in order to use this module.')
if platform.system() == 'Windows': if platform.system() == 'Windows':
msg = 'Windows is not officially supported for dask/xgboost,' msg = 'Windows is not officially supported for dask/xgboost,'
msg += ' contribution are welcomed.' msg += ' contribution are welcomed.'

View File

@ -4,12 +4,13 @@
import ctypes import ctypes
import json import json
import warnings import warnings
import os
import numpy as np import numpy as np
from .core import c_array, _LIB, _check_call, c_str from .core import c_array, _LIB, _check_call, c_str
from .core import DataIter, DeviceQuantileDMatrix, DMatrix from .core import DataIter, DeviceQuantileDMatrix, DMatrix
from .compat import lazy_isinstance, os_fspath, os_PathLike from .compat import lazy_isinstance
c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
@ -478,13 +479,13 @@ def _from_dlpack(data, missing, nthread, feature_names, feature_types):
def _is_uri(data): def _is_uri(data):
return isinstance(data, (str, os_PathLike)) return isinstance(data, (str, os.PathLike))
def _from_uri(data, missing, feature_names, feature_types): def _from_uri(data, missing, feature_names, feature_types):
_warn_unused_missing(data, missing) _warn_unused_missing(data, missing)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os.fspath(data)),
ctypes.c_int(1), ctypes.c_int(1),
ctypes.byref(handle))) ctypes.byref(handle)))
return handle, feature_names, feature_types return handle, feature_names, feature_types

View File

@ -248,7 +248,7 @@ class TestBasic(unittest.TestCase):
class TestBasicPathLike(unittest.TestCase): class TestBasicPathLike(unittest.TestCase):
"""Unit tests using the os_fspath and pathlib.Path for file interaction.""" """Unit tests using pathlib.Path for file interaction."""
def test_DMatrix_init_from_path(self): def test_DMatrix_init_from_path(self):
"""Initialization from the data path.""" """Initialization from the data path."""
@ -317,19 +317,3 @@ class TestBasicPathLike(unittest.TestCase):
# remove file # remove file
Path.unlink(save_path) Path.unlink(save_path)
def test_os_fspath(self):
"""Core properties of the os_fspath function."""
# strings are returned unmodified
assert '' == xgb.compat.os_fspath('')
assert '/this/path' == xgb.compat.os_fspath('/this/path')
# bytes are returned unmodified
assert b'/this/path' == xgb.compat.os_fspath(b'/this/path')
# path objects are returned as string representation
path_test = Path('this') / 'path'
assert str(path_test) == xgb.compat.os_fspath(path_test)
# invalid values raise Type error
self.assertRaises(TypeError, xgb.compat.os_fspath, 123)