Example JSON model parser and Schema. (#5137)

This commit is contained in:
Jiaming Yuan
2019-12-23 19:47:35 +08:00
committed by GitHub
parent a4b929385e
commit 1d0ca49761
8 changed files with 655 additions and 15 deletions

View File

@@ -3,6 +3,8 @@ import xgboost as xgb
import unittest
import os
import json
import testing as tm
import pytest
dpath = 'demo/data/'
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
@@ -11,6 +13,20 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
rng = np.random.RandomState(1994)
def json_model(model_path):
X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,))
dm1 = xgb.DMatrix(X, y)
bst = xgb.train({'tree_method': 'hist'}, dm1)
bst.save_model(model_path)
with open(model_path, 'r') as fd:
model = json.load(fd)
return model
class TestModels(unittest.TestCase):
def test_glm(self):
param = {'verbosity': 0, 'objective': 'binary:logistic',
@@ -42,8 +58,9 @@ class TestModels(unittest.TestCase):
# save dmatrix into binary buffer
dtest.save_binary('dtest.buffer')
model_path = 'xgb.model.dart'
# save model
bst.save_model('xgb.model.dart')
bst.save_model(model_path)
# load model and data in
bst2 = xgb.Booster(params=param, model_file='xgb.model.dart')
dtest2 = xgb.DMatrix('dtest.buffer')
@@ -69,6 +86,7 @@ class TestModels(unittest.TestCase):
for ii in range(len(preds_list)):
for jj in range(ii + 1, len(preds_list)):
assert np.sum(np.abs(preds_list[ii] - preds_list[jj])) > 0
os.remove(model_path)
def test_eta_decay(self):
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
@@ -204,21 +222,27 @@ class TestModels(unittest.TestCase):
bst.predict(dm2) # success
def test_model_json_io(self):
X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,))
dm1 = xgb.DMatrix(X, y)
bst = xgb.train({'tree_method': 'hist'}, dm1)
bst.save_model('./model.json')
with open('./model.json', 'r') as fd:
j_model = json.load(fd)
model_path = './model.json'
j_model = json_model(model_path)
assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json')
bst.save_model(fname=model_path)
with open('./model.json', 'r') as fd:
j_model = json.load(fd)
assert isinstance(j_model['learner'], dict)
os.remove('model.json')
os.remove(model_path)
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
import jsonschema
model_path = './model.json'
path = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
doc = os.path.join(path, 'doc', 'model.schema')
with open(doc, 'r') as fd:
schema = json.load(fd)
jsonschema.validate(instance=json_model(model_path), schema=schema)
os.remove(model_path)