Model IO in JSON. (#5110)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import os
|
||||
import json
|
||||
|
||||
dpath = 'demo/data/'
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
@@ -200,3 +202,23 @@ class TestModels(unittest.TestCase):
|
||||
bst.predict(dm2) # success
|
||||
self.assertRaises(ValueError, bst.predict, dm1)
|
||||
bst.predict(dm2) # success
|
||||
|
||||
def test_json_model_io(self):
|
||||
X = np.random.random((10, 3))
|
||||
y = np.random.randint(2, size=(10,))
|
||||
|
||||
dm1 = xgb.DMatrix(X, y)
|
||||
bst = xgb.train({'tree_method': 'hist'}, dm1)
|
||||
bst.save_model('./model.json')
|
||||
|
||||
with open('./model.json', 'r') as fd:
|
||||
j_model = json.load(fd)
|
||||
assert isinstance(j_model['Learner'], dict)
|
||||
|
||||
bst = xgb.Booster(model_file='./model.json')
|
||||
|
||||
with open('./model.json', 'r') as fd:
|
||||
j_model = json.load(fd)
|
||||
assert isinstance(j_model['Learner'], dict)
|
||||
|
||||
os.remove('model.json')
|
||||
|
||||
Reference in New Issue
Block a user