add os.PathLike support for file paths to DMatrix and Booster Python classes (#4757)
This commit is contained in:
parent
7b5cbcc846
commit
53d4272c2a
@ -100,3 +100,5 @@ List of Contributors
|
||||
* [Bryan Woods](https://github.com/bryan-woods)
|
||||
- Bryan added support for cross-validation for the ranking objective
|
||||
* [Haoda Fu](https://github.com/fuhaoda)
|
||||
* [Evan Kepner](https://github.com/EvanKepner)
|
||||
- Evan Kepner added support for os.PathLike file paths in Python
|
||||
|
||||
@ -4,8 +4,12 @@
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import abc
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pathlib import PurePath
|
||||
|
||||
PY3 = (sys.version_info[0] == 3)
|
||||
|
||||
if PY3:
|
||||
@ -24,6 +28,67 @@ else:
|
||||
"""convert c string back to python string"""
|
||||
return x
|
||||
|
||||
########################################################################################
|
||||
# START NUMPY PATHLIB ATTRIBUTION
|
||||
########################################################################################
|
||||
# os.PathLike compatibility used in Numpy: https://github.com/numpy/numpy/tree/v1.17.0
|
||||
# Attribution:
|
||||
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/compat/py3k.py#L188-L247
|
||||
# Backport os.fs_path, os.PathLike, and PurePath.__fspath__
|
||||
if sys.version_info[:2] >= (3, 6):
|
||||
os_fspath = os.fspath
|
||||
os_PathLike = os.PathLike
|
||||
else:
|
||||
def _PurePath__fspath__(self):
|
||||
return str(self)
|
||||
|
||||
class os_PathLike(abc.ABC):
|
||||
"""Abstract base class for implementing the file system path protocol."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __fspath__(self):
|
||||
"""Return the file system path representation of the object."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, subclass):
|
||||
if issubclass(subclass, PurePath):
|
||||
return True
|
||||
return hasattr(subclass, '__fspath__')
|
||||
|
||||
|
||||
def os_fspath(path):
|
||||
"""Return the path representation of a path-like object.
|
||||
If str or bytes is passed in, it is returned unchanged. Otherwise the
|
||||
os.PathLike interface is used to get the path representation. If the
|
||||
path representation is not str or bytes, TypeError is raised. If the
|
||||
provided path is not str, bytes, or os.PathLike, TypeError is raised.
|
||||
"""
|
||||
if isinstance(path, (str, bytes)):
|
||||
return path
|
||||
|
||||
# Work from the object's type to match method resolution of other magic
|
||||
# methods.
|
||||
path_type = type(path)
|
||||
try:
|
||||
path_repr = path_type.__fspath__(path)
|
||||
except AttributeError:
|
||||
if hasattr(path_type, '__fspath__'):
|
||||
raise
|
||||
if issubclass(path_type, PurePath):
|
||||
return _PurePath__fspath__(path)
|
||||
raise TypeError("expected str, bytes or os.PathLike object, "
|
||||
"not " + path_type.__name__)
|
||||
if isinstance(path_repr, (str, bytes)):
|
||||
return path_repr
|
||||
raise TypeError("expected {}.__fspath__() to return str or bytes, "
|
||||
"not {}".format(path_type.__name__,
|
||||
type(path_repr).__name__))
|
||||
########################################################################################
|
||||
# END NUMPY PATHLIB ATTRIBUTION
|
||||
########################################################################################
|
||||
|
||||
# pickle
|
||||
try:
|
||||
import cPickle as pickle # noqa
|
||||
except ImportError:
|
||||
|
||||
@ -20,7 +20,7 @@ import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str,
|
||||
PANDAS_INSTALLED, DataTable)
|
||||
PANDAS_INSTALLED, DataTable, os_fspath, os_PathLike)
|
||||
from .libpath import find_lib_path
|
||||
|
||||
|
||||
@ -336,10 +336,10 @@ class DMatrix(object):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
|
||||
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
|
||||
Data source of DMatrix.
|
||||
When data is string type, it represents the path libsvm format txt file,
|
||||
or binary file that xgboost can read from.
|
||||
When data is string or os.PathLike type, it represents the path libsvm format
|
||||
txt file, or binary file that xgboost can read from.
|
||||
label : list or numpy 1-D array, optional
|
||||
Label of the training data.
|
||||
missing : float, optional
|
||||
@ -390,9 +390,9 @@ class DMatrix(object):
|
||||
warnings.warn('Initializing DMatrix from List is deprecated.',
|
||||
DeprecationWarning)
|
||||
|
||||
if isinstance(data, STRING_TYPES):
|
||||
if isinstance(data, (STRING_TYPES, os_PathLike)):
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
|
||||
ctypes.c_int(silent),
|
||||
ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
@ -653,13 +653,13 @@ class DMatrix(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : string
|
||||
fname : string or os.PathLike
|
||||
Name of the output buffer file.
|
||||
silent : bool (optional; default: True)
|
||||
If set, the output is suppressed.
|
||||
"""
|
||||
_check_call(_LIB.XGDMatrixSaveBinary(self.handle,
|
||||
c_str(fname),
|
||||
c_str(os_fspath(fname)),
|
||||
ctypes.c_int(silent)))
|
||||
|
||||
def set_label(self, label):
|
||||
@ -937,7 +937,7 @@ class Booster(object):
|
||||
Parameters for boosters.
|
||||
cache : list
|
||||
List of cache items.
|
||||
model_file : string
|
||||
model_file : string or os.PathLike
|
||||
Path to the model file.
|
||||
"""
|
||||
for d in cache:
|
||||
@ -1329,11 +1329,11 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : string
|
||||
fname : string or os.PathLike
|
||||
Output file name
|
||||
"""
|
||||
if isinstance(fname, STRING_TYPES): # assume file name
|
||||
_check_call(_LIB.XGBoosterSaveModel(self.handle, c_str(fname)))
|
||||
if isinstance(fname, (STRING_TYPES, os_PathLike)): # assume file name
|
||||
_check_call(_LIB.XGBoosterSaveModel(self.handle, c_str(os_fspath(fname))))
|
||||
else:
|
||||
raise TypeError("fname must be a string")
|
||||
|
||||
@ -1363,12 +1363,12 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : string or a memory buffer
|
||||
fname : string, os.PathLike, or a memory buffer
|
||||
Input file name or memory buffer(see also save_raw)
|
||||
"""
|
||||
if isinstance(fname, STRING_TYPES):
|
||||
if isinstance(fname, (STRING_TYPES, os_PathLike)):
|
||||
# assume file name, cannot use os.path.exist to check, file can be from URL.
|
||||
_check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(fname)))
|
||||
_check_call(_LIB.XGBoosterLoadModel(self.handle, c_str(os_fspath(fname))))
|
||||
else:
|
||||
buf = fname
|
||||
length = c_bst_ulong(len(buf))
|
||||
@ -1381,17 +1381,17 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fout : string
|
||||
fout : string or os.PathLike
|
||||
Output file name.
|
||||
fmap : string, optional
|
||||
fmap : string or os.PathLike, optional
|
||||
Name of the file containing feature map names.
|
||||
with_stats : bool, optional
|
||||
Controls whether the split statistics are output.
|
||||
dump_format : string, optional
|
||||
Format of model dump file. Can be 'text' or 'json'.
|
||||
"""
|
||||
if isinstance(fout, STRING_TYPES):
|
||||
fout = open(fout, 'w')
|
||||
if isinstance(fout, (STRING_TYPES, os_PathLike)):
|
||||
fout = open(os_fspath(fout), 'w')
|
||||
need_close = True
|
||||
else:
|
||||
need_close = False
|
||||
@ -1416,13 +1416,14 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap : string, optional
|
||||
fmap : string or os.PathLike, optional
|
||||
Name of the file containing feature map names.
|
||||
with_stats : bool, optional
|
||||
Controls whether the split statistics are output.
|
||||
dump_format : string, optional
|
||||
Format of model dump. Can be 'text', 'json' or 'dot'.
|
||||
"""
|
||||
fmap = os_fspath(fmap)
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
if self.feature_names is not None and fmap == '':
|
||||
@ -1473,7 +1474,7 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str (optional)
|
||||
fmap: str or os.PathLike (optional)
|
||||
The name of feature map file
|
||||
"""
|
||||
|
||||
@ -1497,11 +1498,12 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str (optional)
|
||||
fmap: str or os.PathLike (optional)
|
||||
The name of feature map file.
|
||||
importance_type: str, default 'weight'
|
||||
One of the importance types defined above.
|
||||
"""
|
||||
fmap = os_fspath(fmap)
|
||||
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
|
||||
raise ValueError('Feature importance is not defined for Booster type {}'
|
||||
.format(self.booster))
|
||||
@ -1591,10 +1593,11 @@ class Booster(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fmap: str (optional)
|
||||
fmap: str or os.PathLike (optional)
|
||||
The name of feature map file.
|
||||
"""
|
||||
# pylint: disable=too-many-locals
|
||||
fmap = os_fspath(fmap)
|
||||
if not PANDAS_INSTALLED:
|
||||
raise Exception(('pandas must be available to use this method.'
|
||||
'Install pandas before calling again.'))
|
||||
@ -1701,7 +1704,7 @@ class Booster(object):
|
||||
----------
|
||||
feature: str
|
||||
The name of the feature.
|
||||
fmap: str (optional)
|
||||
fmap: str or os.PathLike (optional)
|
||||
The name of feature map file.
|
||||
bin: int, default None
|
||||
The maximum number of bins.
|
||||
|
||||
@ -11,6 +11,7 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
@ -341,3 +342,92 @@ class TestBasic(unittest.TestCase):
|
||||
dtrain.get_float_info('weight')
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('root_index')
|
||||
|
||||
|
||||
class TestBasicPathLike(unittest.TestCase):
|
||||
"""Unit tests using the os_fspath and pathlib.Path for file interaction."""
|
||||
|
||||
def test_DMatrix_init_from_path(self):
|
||||
"""Initialization from the data path."""
|
||||
dpath = Path('demo/data')
|
||||
dtrain = xgb.DMatrix(dpath / 'agaricus.txt.train')
|
||||
assert dtrain.num_row() == 6513
|
||||
assert dtrain.num_col() == 127
|
||||
|
||||
|
||||
def test_DMatrix_save_to_path(self):
|
||||
"""Saving to a binary file using pathlib from a DMatrix."""
|
||||
data = np.random.randn(100, 2)
|
||||
target = np.array([0, 1] * 50)
|
||||
features = ['Feature1', 'Feature2']
|
||||
|
||||
dm = xgb.DMatrix(data, label=target, feature_names=features)
|
||||
|
||||
# save, assert exists, remove file
|
||||
binary_path = Path("dtrain.bin")
|
||||
dm.save_binary(binary_path)
|
||||
assert binary_path.exists()
|
||||
Path.unlink(binary_path)
|
||||
|
||||
|
||||
def test_Booster_init_invalid_path(self):
|
||||
"""An invalid model_file path should raise XGBoostError."""
|
||||
self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
|
||||
model_file=Path("invalidpath"))
|
||||
|
||||
|
||||
def test_Booster_save_and_load(self):
|
||||
"""Saving and loading model files from paths."""
|
||||
save_path = Path("saveload.model")
|
||||
|
||||
data = np.random.randn(100, 2)
|
||||
target = np.array([0, 1] * 50)
|
||||
features = ['Feature1', 'Feature2']
|
||||
|
||||
dm = xgb.DMatrix(data, label=target, feature_names=features)
|
||||
params = {'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'eta': 0.3,
|
||||
'max_depth': 1}
|
||||
|
||||
bst = xgb.train(params, dm, num_boost_round=1)
|
||||
|
||||
# save, assert exists
|
||||
bst.save_model(save_path)
|
||||
assert save_path.exists()
|
||||
|
||||
def dump_assertions(dump):
|
||||
"""Assertions for the expected dump from Booster"""
|
||||
assert len(dump) == 1, 'Exepcted only 1 tree to be dumped.'
|
||||
assert len(dump[0].splitlines()) == 3, 'Expected 1 root and 2 leaves - 3 lines.'
|
||||
|
||||
# load the model again using Path
|
||||
bst2 = xgb.Booster(model_file=save_path)
|
||||
dump2 = bst2.get_dump()
|
||||
dump_assertions(dump2)
|
||||
|
||||
# load again using load_model
|
||||
bst3 = xgb.Booster()
|
||||
bst3.load_model(save_path)
|
||||
dump3= bst3.get_dump()
|
||||
dump_assertions(dump3)
|
||||
|
||||
# remove file
|
||||
Path.unlink(save_path)
|
||||
|
||||
def test_os_fspath(self):
|
||||
"""Core properties of the os_fspath function."""
|
||||
# strings are returned unmodified
|
||||
assert '' == xgb.compat.os_fspath('')
|
||||
assert '/this/path' == xgb.compat.os_fspath('/this/path')
|
||||
|
||||
# bytes are returned unmodified
|
||||
assert b'/this/path' == xgb.compat.os_fspath(b'/this/path')
|
||||
|
||||
# path objects are returned as string representation
|
||||
path_test = Path('this') / 'path'
|
||||
assert str(path_test) == xgb.compat.os_fspath(path_test)
|
||||
|
||||
# invalid values raise Type error
|
||||
self.assertRaises(TypeError, xgb.compat.os_fspath, 123)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user