add os.PathLike support for file paths to DMatrix and Booster Python classes (#4757)

This commit is contained in:
Evan Kepner 2019-08-15 04:46:25 -04:00 committed by Jiaming Yuan
parent 7b5cbcc846
commit 53d4272c2a
4 changed files with 184 additions and 24 deletions

View File

@ -100,3 +100,5 @@ List of Contributors
* [Bryan Woods](https://github.com/bryan-woods) * [Bryan Woods](https://github.com/bryan-woods)
- Bryan added support for cross-validation for the ranking objective - Bryan added support for cross-validation for the ranking objective
* [Haoda Fu](https://github.com/fuhaoda) * [Haoda Fu](https://github.com/fuhaoda)
* [Evan Kepner](https://github.com/EvanKepner)
- Evan Kepner added support for os.PathLike file paths in Python

View File

@ -4,8 +4,12 @@
from __future__ import absolute_import from __future__ import absolute_import
import abc
import os
import sys import sys
from pathlib import PurePath
PY3 = (sys.version_info[0] == 3) PY3 = (sys.version_info[0] == 3)
if PY3: if PY3:
@ -24,6 +28,67 @@ else:
"""convert c string back to python string""" """convert c string back to python string"""
return x return x
########################################################################################
# START NUMPY PATHLIB ATTRIBUTION
########################################################################################
# os.PathLike compatibility used in Numpy: https://github.com/numpy/numpy/tree/v1.17.0
# Attribution:
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/compat/py3k.py#L188-L247
# Backport os.fspath, os.PathLike, and PurePath.__fspath__
if sys.version_info[:2] >= (3, 6):
    # The interpreter already provides the file system path protocol.
    os_fspath = os.fspath
    os_PathLike = os.PathLike
else:
    def _PurePath__fspath__(self):
        """Return ``str(self)``; used for PurePath objects lacking ``__fspath__``."""
        return str(self)

    class os_PathLike(abc.ABC):
        """Abstract base class for implementing the file system path protocol."""

        @abc.abstractmethod
        def __fspath__(self):
            """Return the file system path representation of the object."""
            raise NotImplementedError

        @classmethod
        def __subclasshook__(cls, subclass):
            # PurePath subclasses are accepted outright; any other type
            # qualifies only when it exposes ``__fspath__``.
            return issubclass(subclass, PurePath) or hasattr(subclass, '__fspath__')

    def os_fspath(path):
        """Return the path representation of a path-like object.

        ``str`` and ``bytes`` inputs are handed back untouched. Anything else
        is converted through its type's ``__fspath__`` method, falling back to
        ``str()`` for PurePath instances whose type has no such method.

        Raises
        ------
        TypeError
            If ``path`` is not str, bytes, or path-like, or if ``__fspath__``
            returns something other than str or bytes.
        """
        if isinstance(path, (str, bytes)):
            return path

        # Look the method up on the type, not the instance, to match the
        # method resolution of other magic methods.
        kind = type(path)
        try:
            rendered = kind.__fspath__(path)
        except AttributeError:
            if hasattr(kind, '__fspath__'):
                # The method exists, so the AttributeError came from inside
                # it; let it propagate unchanged.
                raise
            if not issubclass(kind, PurePath):
                raise TypeError("expected str, bytes or os.PathLike object, "
                                "not " + kind.__name__)
            return _PurePath__fspath__(path)

        if not isinstance(rendered, (str, bytes)):
            raise TypeError("expected {}.__fspath__() to return str or bytes, "
                            "not {}".format(kind.__name__,
                                            type(rendered).__name__))
        return rendered
########################################################################################
# END NUMPY PATHLIB ATTRIBUTION
########################################################################################
# pickle
try: try:
import cPickle as pickle # noqa import cPickle as pickle # noqa
except ImportError: except ImportError:

View File

@ -20,7 +20,7 @@ import numpy as np
import scipy.sparse import scipy.sparse
from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str,
PANDAS_INSTALLED, DataTable) PANDAS_INSTALLED, DataTable, os_fspath, os_PathLike)
from .libpath import find_lib_path from .libpath import find_lib_path
@ -336,10 +336,10 @@ class DMatrix(object):
""" """
Parameters Parameters
---------- ----------
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
Data source of DMatrix. Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file, When data is string or os.PathLike type, it represents the path libsvm format
or binary file that xgboost can read from. txt file, or binary file that xgboost can read from.
label : list or numpy 1-D array, optional label : list or numpy 1-D array, optional
Label of the training data. Label of the training data.
missing : float, optional missing : float, optional
@ -390,9 +390,9 @@ class DMatrix(object):
warnings.warn('Initializing DMatrix from List is deprecated.', warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning) DeprecationWarning)
if isinstance(data, STRING_TYPES): if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
ctypes.c_int(silent), ctypes.c_int(silent),
ctypes.byref(handle))) ctypes.byref(handle)))
self.handle = handle self.handle = handle
@ -653,13 +653,13 @@ class DMatrix(object):
Parameters Parameters
---------- ----------
fname : string fname : string or os.PathLike
Name of the output buffer file. Name of the output buffer file.
silent : bool (optional; default: True) silent : bool (optional; default: True)
If set, the output is suppressed. If set, the output is suppressed.
""" """
_check_call(_LIB.XGDMatrixSaveBinary(self.handle, _check_call(_LIB.XGDMatrixSaveBinary(self.handle,
c_str(fname), c_str(os_fspath(fname)),
ctypes.c_int(silent))) ctypes.c_int(silent)))
def set_label(self, label): def set_label(self, label):
@ -937,7 +937,7 @@ class Booster(object):
Parameters for boosters. Parameters for boosters.
cache : list cache : list
List of cache items. List of cache items.
model_file : string model_file : string or os.PathLike
Path to the model file. Path to the model file.
""" """
for d in cache: for d in cache:
@ -1329,11 +1329,11 @@ class Booster(object):
Parameters Parameters
---------- ----------
fname : string fname : string or os.PathLike
Output file name Output file name
""" """
if isinstance(fname, STRING_TYPES): # assume file name if isinstance(fname, (STRING_TYPES, os_PathLike)): # assume file name
_check_call(_LIB.XGBoosterSaveModel(self.handle, c_str(fname))) _check_call(_LIB.XGBoosterSaveModel(self.handle, c_str(os_fspath(fname))))
else: else:
raise TypeError("fname must be a string") raise TypeError("fname must be a string")
@ -1363,12 +1363,12 @@ class Booster(object):
Parameters Parameters
---------- ----------
fname : string or a memory buffer fname : string, os.PathLike, or a memory buffer
Input file name or memory buffer(see also save_raw) Input file name or memory buffer(see also save_raw)
""" """
if isinstance(fname, STRING_TYPES): if isinstance(fname, (STRING_TYPES, os_PathLike)):
# assume file name, cannot use os.path.exist to check, file can be from URL. # assume file name, cannot use os.path.exist to check, file can be from URL.
_check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(fname))) _check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(os_fspath(fname))))
else: else:
buf = fname buf = fname
length = c_bst_ulong(len(buf)) length = c_bst_ulong(len(buf))
@ -1381,17 +1381,17 @@ class Booster(object):
Parameters Parameters
---------- ----------
fout : string fout : string or os.PathLike
Output file name. Output file name.
fmap : string, optional fmap : string or os.PathLike, optional
Name of the file containing feature map names. Name of the file containing feature map names.
with_stats : bool, optional with_stats : bool, optional
Controls whether the split statistics are output. Controls whether the split statistics are output.
dump_format : string, optional dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'. Format of model dump file. Can be 'text' or 'json'.
""" """
if isinstance(fout, STRING_TYPES): if isinstance(fout, (STRING_TYPES, os_PathLike)):
fout = open(fout, 'w') fout = open(os_fspath(fout), 'w')
need_close = True need_close = True
else: else:
need_close = False need_close = False
@ -1416,13 +1416,14 @@ class Booster(object):
Parameters Parameters
---------- ----------
fmap : string, optional fmap : string or os.PathLike, optional
Name of the file containing feature map names. Name of the file containing feature map names.
with_stats : bool, optional with_stats : bool, optional
Controls whether the split statistics are output. Controls whether the split statistics are output.
dump_format : string, optional dump_format : string, optional
Format of model dump. Can be 'text', 'json' or 'dot'. Format of model dump. Can be 'text', 'json' or 'dot'.
""" """
fmap = os_fspath(fmap)
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '': if self.feature_names is not None and fmap == '':
@ -1473,7 +1474,7 @@ class Booster(object):
Parameters Parameters
---------- ----------
fmap: str (optional) fmap: str or os.PathLike (optional)
The name of feature map file The name of feature map file
""" """
@ -1497,11 +1498,12 @@ class Booster(object):
Parameters Parameters
---------- ----------
fmap: str (optional) fmap: str or os.PathLike (optional)
The name of feature map file. The name of feature map file.
importance_type: str, default 'weight' importance_type: str, default 'weight'
One of the importance types defined above. One of the importance types defined above.
""" """
fmap = os_fspath(fmap)
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}: if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
raise ValueError('Feature importance is not defined for Booster type {}' raise ValueError('Feature importance is not defined for Booster type {}'
.format(self.booster)) .format(self.booster))
@ -1591,10 +1593,11 @@ class Booster(object):
Parameters Parameters
---------- ----------
fmap: str (optional) fmap: str or os.PathLike (optional)
The name of feature map file. The name of feature map file.
""" """
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
fmap = os_fspath(fmap)
if not PANDAS_INSTALLED: if not PANDAS_INSTALLED:
raise Exception(('pandas must be available to use this method.' raise Exception(('pandas must be available to use this method.'
'Install pandas before calling again.')) 'Install pandas before calling again.'))
@ -1701,7 +1704,7 @@ class Booster(object):
---------- ----------
feature: str feature: str
The name of the feature. The name of the feature.
fmap: str (optional) fmap: str or os.PathLike (optional)
The name of feature map file. The name of feature map file.
bin: int, default None bin: int, default None
The maximum number of bins. The maximum number of bins.

View File

@ -11,6 +11,7 @@ import numpy as np
import xgboost as xgb import xgboost as xgb
import unittest import unittest
import json import json
from pathlib import Path
dpath = 'demo/data/' dpath = 'demo/data/'
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
@ -341,3 +342,92 @@ class TestBasic(unittest.TestCase):
dtrain.get_float_info('weight') dtrain.get_float_info('weight')
dtrain.get_float_info('base_margin') dtrain.get_float_info('base_margin')
dtrain.get_uint_info('root_index') dtrain.get_uint_info('root_index')
class TestBasicPathLike(unittest.TestCase):
    """Unit tests using the os_fspath and pathlib.Path for file interaction."""

    def test_DMatrix_init_from_path(self):
        """Initialization from the data path."""
        data_dir = Path('demo/data')
        dtrain = xgb.DMatrix(data_dir / 'agaricus.txt.train')
        assert dtrain.num_row() == 6513
        assert dtrain.num_col() == 127

    def test_DMatrix_save_to_path(self):
        """Saving to a binary file using pathlib from a DMatrix."""
        data = np.random.randn(100, 2)
        target = np.array([0, 1] * 50)
        features = ['Feature1', 'Feature2']

        dm = xgb.DMatrix(data, label=target, feature_names=features)

        binary_path = Path("dtrain.bin")
        try:
            dm.save_binary(binary_path)
            assert binary_path.exists()
        finally:
            # Remove the artifact even when the save or the assertion fails,
            # so a failing run does not leave files in the working directory.
            if binary_path.exists():
                binary_path.unlink()

    def test_Booster_init_invalid_path(self):
        """An invalid model_file path should raise XGBoostError."""
        self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
                          model_file=Path("invalidpath"))

    def test_Booster_save_and_load(self):
        """Saving and loading model files from paths."""
        save_path = Path("saveload.model")

        data = np.random.randn(100, 2)
        target = np.array([0, 1] * 50)
        features = ['Feature1', 'Feature2']

        dm = xgb.DMatrix(data, label=target, feature_names=features)
        params = {'objective': 'binary:logistic',
                  'eval_metric': 'logloss',
                  'eta': 0.3,
                  'max_depth': 1}

        bst = xgb.train(params, dm, num_boost_round=1)

        def dump_assertions(dump):
            """Assertions for the expected dump from Booster"""
            assert len(dump) == 1, 'Expected only 1 tree to be dumped.'
            assert len(dump[0].splitlines()) == 3, 'Expected 1 root and 2 leaves - 3 lines.'

        try:
            # save, assert exists
            bst.save_model(save_path)
            assert save_path.exists()

            # load the model again using Path
            bst2 = xgb.Booster(model_file=save_path)
            dump_assertions(bst2.get_dump())

            # load again using load_model
            bst3 = xgb.Booster()
            bst3.load_model(save_path)
            dump_assertions(bst3.get_dump())
        finally:
            # Remove the model file whether or not the assertions passed.
            if save_path.exists():
                save_path.unlink()

    def test_os_fspath(self):
        """Core properties of the os_fspath function."""
        # strings are returned unmodified
        assert '' == xgb.compat.os_fspath('')
        assert '/this/path' == xgb.compat.os_fspath('/this/path')

        # bytes are returned unmodified
        assert b'/this/path' == xgb.compat.os_fspath(b'/this/path')

        # path objects are returned as string representation
        path_test = Path('this') / 'path'
        assert str(path_test) == xgb.compat.os_fspath(path_test)

        # invalid values raise Type error
        self.assertRaises(TypeError, xgb.compat.os_fspath, 123)