Simplify the data backends. (#5893)
This commit is contained in:
@@ -1,17 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
try:
|
||||
# python 2
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
# python 3
|
||||
from io import StringIO
|
||||
from io import StringIO
|
||||
import numpy as np
|
||||
import os
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import json
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
@@ -66,16 +63,19 @@ class TestBasic(unittest.TestCase):
|
||||
# error must be smaller than 10%
|
||||
assert err < 0.1
|
||||
|
||||
# save dmatrix into binary buffer
|
||||
dtest.save_binary('dtest.buffer')
|
||||
# save model
|
||||
bst.save_model('xgb.model')
|
||||
# load model and data in
|
||||
bst2 = xgb.Booster(model_file='xgb.model')
|
||||
dtest2 = xgb.DMatrix('dtest.buffer')
|
||||
preds2 = bst2.predict(dtest2)
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dtest_path = os.path.join(tmpdir, 'dtest.dmatrix')
|
||||
# save dmatrix into binary buffer
|
||||
dtest.save_binary(dtest_path)
|
||||
# save model
|
||||
model_path = os.path.join(tmpdir, 'model.booster')
|
||||
bst.save_model(model_path)
|
||||
# load model and data in
|
||||
bst2 = xgb.Booster(model_file=model_path)
|
||||
dtest2 = xgb.DMatrix(dtest_path)
|
||||
preds2 = bst2.predict(dtest2)
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
|
||||
def test_record_results(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
|
||||
Reference in New Issue
Block a user