Add dart to JSON schema. (#5218)

* Add dart to JSON schema.

* Use spaces instead of tab.
This commit is contained in:
Jiaming Yuan
2020-01-28 13:29:09 +08:00
committed by GitHub
parent 0c7455276d
commit ef19480eda
3 changed files with 219 additions and 186 deletions

View File

@@ -13,13 +13,13 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
rng = np.random.RandomState(1994)
def json_model(model_path):
def json_model(model_path, parameters):
X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,))
dm1 = xgb.DMatrix(X, y)
bst = xgb.train({'tree_method': 'hist'}, dm1)
bst = xgb.train(parameters, dm1)
bst.save_model(model_path)
with open(model_path, 'r') as fd:
@@ -285,7 +285,8 @@ class TestModels(unittest.TestCase):
def test_model_json_io(self):
model_path = './model.json'
j_model = json_model(model_path)
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters)
assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json')
@@ -306,5 +307,12 @@ class TestModels(unittest.TestCase):
doc = os.path.join(path, 'doc', 'model.schema')
with open(doc, 'r') as fd:
schema = json.load(fd)
jsonschema.validate(instance=json_model(model_path), schema=schema)
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)
parameters = {'tree_method': 'hist', 'booster': 'dart'}
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)