Fix plotting test. (#6040)
Previously the test loads a model generated by `test_basic.py`, now we generate the model explicitly. * Cleanup saved files for basic tests.
This commit is contained in:
@@ -110,16 +110,19 @@ class TestBasic(unittest.TestCase):
|
||||
# error must be smaller than 10%
|
||||
assert err < 0.1
|
||||
|
||||
# save dmatrix into binary buffer
|
||||
dtest.save_binary('dtest.buffer')
|
||||
# save model
|
||||
bst.save_model('xgb.model')
|
||||
# load model and data in
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
dtest2 = xgb.DMatrix('dtest.buffer')
|
||||
preds2 = bst2.predict(dtest2)
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dtest_path = os.path.join(tmpdir, 'dtest.buffer')
|
||||
model_path = os.path.join(tmpdir, 'xgb.model')
|
||||
# save dmatrix into binary buffer
|
||||
dtest.save_binary(dtest_path)
|
||||
# save model
|
||||
bst.save_model(model_path)
|
||||
# load model and data in
|
||||
bst2 = xgb.Booster(model_file=model_path)
|
||||
dtest2 = xgb.DMatrix(dtest_path)
|
||||
preds2 = bst2.predict(dtest2)
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
|
||||
def test_dump(self):
|
||||
data = np.random.randn(100, 2)
|
||||
|
||||
Reference in New Issue
Block a user