Example JSON model parser and Schema. (#5137)
This commit is contained in:
@@ -22,7 +22,7 @@ ENV GOSU_VERSION 1.10
|
||||
# Install Python packages
|
||||
# jsonschema is needed by the Python test that validates a saved JSON model
# against the model schema.
RUN \
    pip install pyyaml cpplint pylint astroid sphinx numpy scipy pandas matplotlib sh recommonmark guzzle_sphinx_theme mock \
    breathe matplotlib graphviz pytest scikit-learn wheel kubernetes urllib3 jsonschema && \
    pip install https://h2o-release.s3.amazonaws.com/datatable/stable/datatable-0.7.0/datatable-0.7.0-cp37-cp37m-linux_x86_64.whl && \
    pip install "dask[complete]"
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ import xgboost as xgb
|
||||
import unittest
|
||||
import os
|
||||
import json
|
||||
import testing as tm
|
||||
import pytest
|
||||
|
||||
dpath = 'demo/data/'
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
@@ -11,6 +13,20 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
def json_model(model_path):
    """Train a tiny booster, save it to ``model_path`` and return the parsed file.

    The booster is fitted on 10 random rows with 3 features and random binary
    labels, using the ``hist`` tree method.  The saved file is read back with
    ``json.load``, so the save has to produce JSON (callers in this file pass
    ``'./model.json'``).  The file is left on disk; removing it is up to the
    caller.
    """
    features = np.random.random((10, 3))
    labels = np.random.randint(2, size=(10,))

    booster = xgb.train({'tree_method': 'hist'}, xgb.DMatrix(features, labels))
    booster.save_model(model_path)

    with open(model_path, 'r') as model_file:
        return json.load(model_file)
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
def test_glm(self):
|
||||
param = {'verbosity': 0, 'objective': 'binary:logistic',
|
||||
@@ -42,8 +58,9 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
# save dmatrix into binary buffer
|
||||
dtest.save_binary('dtest.buffer')
|
||||
model_path = 'xgb.model.dart'
|
||||
# save model
|
||||
bst.save_model('xgb.model.dart')
|
||||
bst.save_model(model_path)
|
||||
# load model and data in
|
||||
bst2 = xgb.Booster(params=param, model_file='xgb.model.dart')
|
||||
dtest2 = xgb.DMatrix('dtest.buffer')
|
||||
@@ -69,6 +86,7 @@ class TestModels(unittest.TestCase):
|
||||
for ii in range(len(preds_list)):
|
||||
for jj in range(ii + 1, len(preds_list)):
|
||||
assert np.sum(np.abs(preds_list[ii] - preds_list[jj])) > 0
|
||||
os.remove(model_path)
|
||||
|
||||
def test_eta_decay(self):
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
@@ -204,21 +222,27 @@ class TestModels(unittest.TestCase):
|
||||
bst.predict(dm2) # success
|
||||
|
||||
def test_model_json_io(self):
    """Round-trip a JSON model: save, parse, reload into a Booster, re-save."""
    model_path = './model.json'

    # json_model() trains a small booster, saves it and returns the parsed JSON.
    j_model = json_model(model_path)
    assert isinstance(j_model['learner'], dict)

    # Load the saved file into a new Booster and write it back out, then
    # check that the re-saved file still parses and keeps its 'learner' entry.
    bst = xgb.Booster(model_file=model_path)
    bst.save_model(fname=model_path)
    with open(model_path, 'r') as fd:
        j_model = json.load(fd)
    assert isinstance(j_model['learner'], dict)

    # Remove the file exactly once; a second os.remove on the same path
    # would raise FileNotFoundError.
    os.remove(model_path)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_json_schema())
def test_json_schema(self):
    """Validate a freshly saved JSON model against ``doc/model.schema``.

    The skip condition and reason come from ``tm.no_json_schema()``.
    """
    import jsonschema
    model_path = './model.json'
    # The schema directory is located by applying dirname three times to
    # this file's absolute path and appending 'doc'.
    path = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    doc = os.path.join(path, 'doc', 'model.schema')
    with open(doc, 'r') as fd:
        schema = json.load(fd)
    try:
        jsonschema.validate(instance=json_model(model_path), schema=schema)
    finally:
        # Clean up even when training or validation fails, so a failing run
        # does not leave ./model.json behind.
        if os.path.exists(model_path):
            os.remove(model_path)
|
||||
|
||||
@@ -55,3 +55,12 @@ def no_dask_cudf():
|
||||
return {'condition': False, 'reason': reason}
|
||||
except ImportError:
|
||||
return {'condition': True, 'reason': reason}
|
||||
|
||||
|
||||
def no_json_schema():
    """Return ``skipif`` keyword arguments keyed on ``jsonschema`` availability.

    The result is a dict with ``'condition'`` (``True`` when importing
    ``jsonschema`` raises ``ImportError``, ``False`` otherwise) and a fixed
    ``'reason'`` string.
    """
    try:
        import jsonschema  # noqa
    except ImportError:
        missing = True
    else:
        missing = False
    return {'condition': missing, 'reason': 'jsonschema is not installed'}
|
||||
|
||||
Reference in New Issue
Block a user