add os.PathLike support for file paths to DMatrix and Booster Python classes (#4757)

This commit is contained in:
Evan Kepner
2019-08-15 04:46:25 -04:00
committed by Jiaming Yuan
parent 7b5cbcc846
commit 53d4272c2a
4 changed files with 184 additions and 24 deletions

View File

@@ -11,6 +11,7 @@ import numpy as np
import xgboost as xgb
import unittest
import json
from pathlib import Path
dpath = 'demo/data/'
rng = np.random.RandomState(1994)
@@ -341,3 +342,92 @@ class TestBasic(unittest.TestCase):
dtrain.get_float_info('weight')
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('root_index')
class TestBasicPathLike(unittest.TestCase):
"""Unit tests using the os_fspath and pathlib.Path for file interaction."""
def test_DMatrix_init_from_path(self):
"""Initialization from the data path."""
dpath = Path('demo/data')
dtrain = xgb.DMatrix(dpath / 'agaricus.txt.train')
assert dtrain.num_row() == 6513
assert dtrain.num_col() == 127
def test_DMatrix_save_to_path(self):
"""Saving to a binary file using pathlib from a DMatrix."""
data = np.random.randn(100, 2)
target = np.array([0, 1] * 50)
features = ['Feature1', 'Feature2']
dm = xgb.DMatrix(data, label=target, feature_names=features)
# save, assert exists, remove file
binary_path = Path("dtrain.bin")
dm.save_binary(binary_path)
assert binary_path.exists()
Path.unlink(binary_path)
def test_Booster_init_invalid_path(self):
"""An invalid model_file path should raise XGBoostError."""
self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
model_file=Path("invalidpath"))
def test_Booster_save_and_load(self):
"""Saving and loading model files from paths."""
save_path = Path("saveload.model")
data = np.random.randn(100, 2)
target = np.array([0, 1] * 50)
features = ['Feature1', 'Feature2']
dm = xgb.DMatrix(data, label=target, feature_names=features)
params = {'objective': 'binary:logistic',
'eval_metric': 'logloss',
'eta': 0.3,
'max_depth': 1}
bst = xgb.train(params, dm, num_boost_round=1)
# save, assert exists
bst.save_model(save_path)
assert save_path.exists()
def dump_assertions(dump):
"""Assertions for the expected dump from Booster"""
assert len(dump) == 1, 'Exepcted only 1 tree to be dumped.'
assert len(dump[0].splitlines()) == 3, 'Expected 1 root and 2 leaves - 3 lines.'
# load the model again using Path
bst2 = xgb.Booster(model_file=save_path)
dump2 = bst2.get_dump()
dump_assertions(dump2)
# load again using load_model
bst3 = xgb.Booster()
bst3.load_model(save_path)
dump3= bst3.get_dump()
dump_assertions(dump3)
# remove file
Path.unlink(save_path)
def test_os_fspath(self):
"""Core properties of the os_fspath function."""
# strings are returned unmodified
assert '' == xgb.compat.os_fspath('')
assert '/this/path' == xgb.compat.os_fspath('/this/path')
# bytes are returned unmodified
assert b'/this/path' == xgb.compat.os_fspath(b'/this/path')
# path objects are returned as string representation
path_test = Path('this') / 'path'
assert str(path_test) == xgb.compat.os_fspath(path_test)
# invalid values raise Type error
self.assertRaises(TypeError, xgb.compat.os_fspath, 123)