Add dart to JSON schema. (#5218)
* Add dart to JSON schema. * Use spaces instead of tab.
This commit is contained in:
@@ -13,13 +13,13 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
def json_model(model_path):
|
||||
def json_model(model_path, parameters):
|
||||
X = np.random.random((10, 3))
|
||||
y = np.random.randint(2, size=(10,))
|
||||
|
||||
dm1 = xgb.DMatrix(X, y)
|
||||
|
||||
bst = xgb.train({'tree_method': 'hist'}, dm1)
|
||||
bst = xgb.train(parameters, dm1)
|
||||
bst.save_model(model_path)
|
||||
|
||||
with open(model_path, 'r') as fd:
|
||||
@@ -285,7 +285,8 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
def test_model_json_io(self):
|
||||
model_path = './model.json'
|
||||
j_model = json_model(model_path)
|
||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
||||
j_model = json_model(model_path, parameters)
|
||||
assert isinstance(j_model['learner'], dict)
|
||||
|
||||
bst = xgb.Booster(model_file='./model.json')
|
||||
@@ -306,5 +307,12 @@ class TestModels(unittest.TestCase):
|
||||
doc = os.path.join(path, 'doc', 'model.schema')
|
||||
with open(doc, 'r') as fd:
|
||||
schema = json.load(fd)
|
||||
jsonschema.validate(instance=json_model(model_path), schema=schema)
|
||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
||||
jsonschema.validate(instance=json_model(model_path, parameters),
|
||||
schema=schema)
|
||||
os.remove(model_path)
|
||||
|
||||
parameters = {'tree_method': 'hist', 'booster': 'dart'}
|
||||
jsonschema.validate(instance=json_model(model_path, parameters),
|
||||
schema=schema)
|
||||
os.remove(model_path)
|
||||
|
||||
Reference in New Issue
Block a user