add os.PathLike support for file paths to DMatrix and Booster Python classes (#4757)
This commit is contained in:
committed by
Jiaming Yuan
parent
7b5cbcc846
commit
53d4272c2a
@@ -11,6 +11,7 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
@@ -341,3 +342,92 @@ class TestBasic(unittest.TestCase):
|
||||
dtrain.get_float_info('weight')
|
||||
dtrain.get_float_info('base_margin')
|
||||
dtrain.get_uint_info('root_index')
|
||||
|
||||
|
||||
class TestBasicPathLike(unittest.TestCase):
|
||||
"""Unit tests using the os_fspath and pathlib.Path for file interaction."""
|
||||
|
||||
def test_DMatrix_init_from_path(self):
|
||||
"""Initialization from the data path."""
|
||||
dpath = Path('demo/data')
|
||||
dtrain = xgb.DMatrix(dpath / 'agaricus.txt.train')
|
||||
assert dtrain.num_row() == 6513
|
||||
assert dtrain.num_col() == 127
|
||||
|
||||
|
||||
def test_DMatrix_save_to_path(self):
|
||||
"""Saving to a binary file using pathlib from a DMatrix."""
|
||||
data = np.random.randn(100, 2)
|
||||
target = np.array([0, 1] * 50)
|
||||
features = ['Feature1', 'Feature2']
|
||||
|
||||
dm = xgb.DMatrix(data, label=target, feature_names=features)
|
||||
|
||||
# save, assert exists, remove file
|
||||
binary_path = Path("dtrain.bin")
|
||||
dm.save_binary(binary_path)
|
||||
assert binary_path.exists()
|
||||
Path.unlink(binary_path)
|
||||
|
||||
|
||||
def test_Booster_init_invalid_path(self):
|
||||
"""An invalid model_file path should raise XGBoostError."""
|
||||
self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
|
||||
model_file=Path("invalidpath"))
|
||||
|
||||
|
||||
def test_Booster_save_and_load(self):
|
||||
"""Saving and loading model files from paths."""
|
||||
save_path = Path("saveload.model")
|
||||
|
||||
data = np.random.randn(100, 2)
|
||||
target = np.array([0, 1] * 50)
|
||||
features = ['Feature1', 'Feature2']
|
||||
|
||||
dm = xgb.DMatrix(data, label=target, feature_names=features)
|
||||
params = {'objective': 'binary:logistic',
|
||||
'eval_metric': 'logloss',
|
||||
'eta': 0.3,
|
||||
'max_depth': 1}
|
||||
|
||||
bst = xgb.train(params, dm, num_boost_round=1)
|
||||
|
||||
# save, assert exists
|
||||
bst.save_model(save_path)
|
||||
assert save_path.exists()
|
||||
|
||||
def dump_assertions(dump):
|
||||
"""Assertions for the expected dump from Booster"""
|
||||
assert len(dump) == 1, 'Exepcted only 1 tree to be dumped.'
|
||||
assert len(dump[0].splitlines()) == 3, 'Expected 1 root and 2 leaves - 3 lines.'
|
||||
|
||||
# load the model again using Path
|
||||
bst2 = xgb.Booster(model_file=save_path)
|
||||
dump2 = bst2.get_dump()
|
||||
dump_assertions(dump2)
|
||||
|
||||
# load again using load_model
|
||||
bst3 = xgb.Booster()
|
||||
bst3.load_model(save_path)
|
||||
dump3= bst3.get_dump()
|
||||
dump_assertions(dump3)
|
||||
|
||||
# remove file
|
||||
Path.unlink(save_path)
|
||||
|
||||
def test_os_fspath(self):
|
||||
"""Core properties of the os_fspath function."""
|
||||
# strings are returned unmodified
|
||||
assert '' == xgb.compat.os_fspath('')
|
||||
assert '/this/path' == xgb.compat.os_fspath('/this/path')
|
||||
|
||||
# bytes are returned unmodified
|
||||
assert b'/this/path' == xgb.compat.os_fspath(b'/this/path')
|
||||
|
||||
# path objects are returned as string representation
|
||||
path_test = Path('this') / 'path'
|
||||
assert str(path_test) == xgb.compat.os_fspath(path_test)
|
||||
|
||||
# invalid values raise Type error
|
||||
self.assertRaises(TypeError, xgb.compat.os_fspath, 123)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user