add os.PathLike support for file paths to DMatrix and Booster Python classes (#4757)

This commit is contained in:
Evan Kepner 2019-08-15 04:46:25 -04:00 committed by Jiaming Yuan
parent 7b5cbcc846
commit 53d4272c2a
4 changed files with 184 additions and 24 deletions

View File

@ -100,3 +100,5 @@ List of Contributors
* [Bryan Woods](https://github.com/bryan-woods)
- Bryan added support for cross-validation for the ranking objective
* [Haoda Fu](https://github.com/fuhaoda)
* [Evan Kepner](https://github.com/EvanKepner)
- Evan Kepner added support for os.PathLike file paths in Python

View File

@ -4,8 +4,12 @@
from __future__ import absolute_import
import abc
import os
import sys
from pathlib import PurePath
PY3 = (sys.version_info[0] == 3)
if PY3:
@ -24,6 +28,67 @@ else:
"""convert c string back to python string"""
return x
########################################################################################
# START NUMPY PATHLIB ATTRIBUTION
########################################################################################
# os.PathLike compatibility used in Numpy: https://github.com/numpy/numpy/tree/v1.17.0
# Attribution:
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/compat/py3k.py#L188-L247
# Backport os.fspath, os.PathLike, and PurePath.__fspath__
if sys.version_info[:2] < (3, 6):
    # Interpreters older than 3.6 ship neither os.fspath nor os.PathLike,
    # so equivalents are defined here.

    def _PurePath__fspath__(self):
        """Stand-in for ``PurePath.__fspath__``: the path rendered as ``str``."""
        return str(self)

    class os_PathLike(abc.ABC):
        """Abstract base class describing the file system path protocol."""

        @abc.abstractmethod
        def __fspath__(self):
            """Return the file system path representation of the object."""
            raise NotImplementedError

        @classmethod
        def __subclasshook__(cls, subclass):
            # Any PurePath subclass counts as path-like, as does any class
            # that exposes ``__fspath__``.
            return issubclass(subclass, PurePath) or hasattr(subclass, '__fspath__')

    def os_fspath(path):
        """Return the file system representation of a path-like object.

        ``str`` and ``bytes`` inputs are handed back untouched.  For any
        other object the ``__fspath__`` protocol is consulted, with a
        fallback to ``str()`` for ``PurePath`` instances that lack it.

        Raises
        ------
        TypeError
            If ``path`` is not ``str``, ``bytes`` or path-like, or if its
            ``__fspath__`` returns something other than ``str`` or ``bytes``.
        """
        if isinstance(path, (str, bytes)):
            return path

        # Look the method up on the type, not the instance, to mirror how
        # other magic methods are resolved.
        kind = type(path)
        try:
            converted = kind.__fspath__(path)
        except AttributeError:
            if hasattr(kind, '__fspath__'):
                # The method exists, so the AttributeError came from inside it.
                raise
            if not issubclass(kind, PurePath):
                raise TypeError("expected str, bytes or os.PathLike object, "
                                "not " + kind.__name__)
            return _PurePath__fspath__(path)

        if not isinstance(converted, (str, bytes)):
            raise TypeError("expected {}.__fspath__() to return str or bytes, "
                            "not {}".format(kind.__name__,
                                            type(converted).__name__))
        return converted
else:
    # Python 3.6+ provides both names natively.
    os_fspath = os.fspath
    os_PathLike = os.PathLike
########################################################################################
# END NUMPY PATHLIB ATTRIBUTION
########################################################################################
# pickle
try:
import cPickle as pickle # noqa
except ImportError:

View File

@ -20,7 +20,7 @@ import numpy as np
import scipy.sparse
from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str,
PANDAS_INSTALLED, DataTable)
PANDAS_INSTALLED, DataTable, os_fspath, os_PathLike)
from .libpath import find_lib_path
@ -336,10 +336,10 @@ class DMatrix(object):
"""
Parameters
----------
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
Data source of DMatrix.
When data is string type, it represents the path libsvm format txt file,
or binary file that xgboost can read from.
When data is string or os.PathLike type, it represents the path to a libsvm
format txt file, or a binary file that xgboost can read from.
label : list or numpy 1-D array, optional
Label of the training data.
missing : float, optional
@ -390,9 +390,9 @@ class DMatrix(object):
warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning)
if isinstance(data, STRING_TYPES):
if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
ctypes.c_int(silent),
ctypes.byref(handle)))
self.handle = handle
@ -653,13 +653,13 @@ class DMatrix(object):
Parameters
----------
fname : string
fname : string or os.PathLike
Name of the output buffer file.
silent : bool (optional; default: True)
If set, the output is suppressed.
"""
_check_call(_LIB.XGDMatrixSaveBinary(self.handle,
c_str(fname),
c_str(os_fspath(fname)),
ctypes.c_int(silent)))
def set_label(self, label):
@ -937,7 +937,7 @@ class Booster(object):
Parameters for boosters.
cache : list
List of cache items.
model_file : string
model_file : string or os.PathLike
Path to the model file.
"""
for d in cache:
@ -1329,11 +1329,11 @@ class Booster(object):
Parameters
----------
fname : string
fname : string or os.PathLike
Output file name
"""
if isinstance(fname, STRING_TYPES): # assume file name
_check_call(_LIB.XGBoosterSaveModel(self.handle, c_str(fname)))
if isinstance(fname, (STRING_TYPES, os_PathLike)): # assume file name
_check_call(_LIB.XGBoosterSaveModel(self.handle, c_str(os_fspath(fname))))
else:
raise TypeError("fname must be a string")
@ -1363,12 +1363,12 @@ class Booster(object):
Parameters
----------
fname : string or a memory buffer
fname : string, os.PathLike, or a memory buffer
Input file name or memory buffer(see also save_raw)
"""
if isinstance(fname, STRING_TYPES):
if isinstance(fname, (STRING_TYPES, os_PathLike)):
# assume file name, cannot use os.path.exists to check, file can be from URL.
_check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(fname)))
_check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(os_fspath(fname))))
else:
buf = fname
length = c_bst_ulong(len(buf))
@ -1381,17 +1381,17 @@ class Booster(object):
Parameters
----------
fout : string
fout : string or os.PathLike
Output file name.
fmap : string, optional
fmap : string or os.PathLike, optional
Name of the file containing feature map names.
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'.
"""
if isinstance(fout, STRING_TYPES):
fout = open(fout, 'w')
if isinstance(fout, (STRING_TYPES, os_PathLike)):
fout = open(os_fspath(fout), 'w')
need_close = True
else:
need_close = False
@ -1416,13 +1416,14 @@ class Booster(object):
Parameters
----------
fmap : string, optional
fmap : string or os.PathLike, optional
Name of the file containing feature map names.
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump. Can be 'text', 'json' or 'dot'.
"""
fmap = os_fspath(fmap)
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '':
@ -1473,7 +1474,7 @@ class Booster(object):
Parameters
----------
fmap: str (optional)
fmap: str or os.PathLike (optional)
The name of feature map file
"""
@ -1497,11 +1498,12 @@ class Booster(object):
Parameters
----------
fmap: str (optional)
fmap: str or os.PathLike (optional)
The name of feature map file.
importance_type: str, default 'weight'
One of the importance types defined above.
"""
fmap = os_fspath(fmap)
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
raise ValueError('Feature importance is not defined for Booster type {}'
.format(self.booster))
@ -1591,10 +1593,11 @@ class Booster(object):
Parameters
----------
fmap: str (optional)
fmap: str or os.PathLike (optional)
The name of feature map file.
"""
# pylint: disable=too-many-locals
fmap = os_fspath(fmap)
if not PANDAS_INSTALLED:
raise Exception(('pandas must be available to use this method.'
'Install pandas before calling again.'))
@ -1701,7 +1704,7 @@ class Booster(object):
----------
feature: str
The name of the feature.
fmap: str (optional)
fmap: str or os.PathLike (optional)
The name of feature map file.
bin: int, default None
The maximum number of bins.

View File

@ -11,6 +11,7 @@ import numpy as np
import xgboost as xgb
import unittest
import json
from pathlib import Path
dpath = 'demo/data/'
rng = np.random.RandomState(1994)
@ -341,3 +342,92 @@ class TestBasic(unittest.TestCase):
dtrain.get_float_info('weight')
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('root_index')
class TestBasicPathLike(unittest.TestCase):
    """Unit tests using the os_fspath and pathlib.Path for file interaction."""

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path`` when it is present.

        Registered through ``addCleanup`` so that files written by a test are
        removed whether the test passes or fails.
        """
        if path.exists():
            path.unlink()

    def test_DMatrix_init_from_path(self):
        """Initialization from the data path."""
        dpath = Path('demo/data')
        dtrain = xgb.DMatrix(dpath / 'agaricus.txt.train')
        assert dtrain.num_row() == 6513
        assert dtrain.num_col() == 127

    def test_DMatrix_save_to_path(self):
        """Saving to a binary file using pathlib from a DMatrix."""
        data = np.random.randn(100, 2)
        target = np.array([0, 1] * 50)
        features = ['Feature1', 'Feature2']

        dm = xgb.DMatrix(data, label=target, feature_names=features)

        binary_path = Path("dtrain.bin")
        # Register removal before writing so a failed assertion cannot leave
        # the file behind in the working directory.
        self.addCleanup(self._remove_if_exists, binary_path)

        dm.save_binary(binary_path)
        assert binary_path.exists()

    def test_Booster_init_invalid_path(self):
        """An invalid model_file path should raise XGBoostError."""
        self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
                          model_file=Path("invalidpath"))

    def test_Booster_save_and_load(self):
        """Saving and loading model files from paths."""
        save_path = Path("saveload.model")
        # Register removal up front; the model file is deleted even when one
        # of the assertions below fails.
        self.addCleanup(self._remove_if_exists, save_path)

        data = np.random.randn(100, 2)
        target = np.array([0, 1] * 50)
        features = ['Feature1', 'Feature2']

        dm = xgb.DMatrix(data, label=target, feature_names=features)
        params = {'objective': 'binary:logistic',
                  'eval_metric': 'logloss',
                  'eta': 0.3,
                  'max_depth': 1}

        bst = xgb.train(params, dm, num_boost_round=1)

        # save, assert exists
        bst.save_model(save_path)
        assert save_path.exists()

        def dump_assertions(dump):
            """Assertions for the expected dump from Booster"""
            assert len(dump) == 1, 'Expected only 1 tree to be dumped.'
            assert len(dump[0].splitlines()) == 3, 'Expected 1 root and 2 leaves - 3 lines.'

        # load the model again using Path
        bst2 = xgb.Booster(model_file=save_path)
        dump2 = bst2.get_dump()
        dump_assertions(dump2)

        # load again using load_model
        bst3 = xgb.Booster()
        bst3.load_model(save_path)
        dump3 = bst3.get_dump()
        dump_assertions(dump3)

    def test_os_fspath(self):
        """Core properties of the os_fspath function."""
        # strings are returned unmodified
        assert '' == xgb.compat.os_fspath('')
        assert '/this/path' == xgb.compat.os_fspath('/this/path')

        # bytes are returned unmodified
        assert b'/this/path' == xgb.compat.os_fspath(b'/this/path')

        # path objects are returned as string representation
        path_test = Path('this') / 'path'
        assert str(path_test) == xgb.compat.os_fspath(path_test)

        # invalid values raise Type error
        self.assertRaises(TypeError, xgb.compat.os_fspath, 123)