Add dart to JSON schema. (#5218)
* Add dart to JSON schema. * Use spaces instead of tab.
This commit is contained in:
parent
0c7455276d
commit
ef19480eda
173
doc/model.schema
173
doc/model.schema
@ -1,80 +1,7 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"definitions": {
|
||||
"gbtree_model_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_trees": {
|
||||
"type": "string"
|
||||
},
|
||||
"size_leaf_vector": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"num_trees",
|
||||
"size_leaf_vector"
|
||||
]
|
||||
},
|
||||
"tree_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_nodes": {
|
||||
"type": "string"
|
||||
},
|
||||
"size_leaf_vector": {
|
||||
"type": "string"
|
||||
},
|
||||
"num_feature": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"num_nodes",
|
||||
"num_feature",
|
||||
"size_leaf_vector"
|
||||
]
|
||||
},
|
||||
|
||||
"reg_loss_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"scale_pos_weight": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"softmax_multiclass_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_class": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"lambda_rank_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_pairsample": { "type": "string" },
|
||||
"fix_list_weight": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"version": {
|
||||
"type": "array",
|
||||
"const": [
|
||||
1,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"additionalItems": false
|
||||
},
|
||||
"learner": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"gradient_booster": {
|
||||
"oneOf": [
|
||||
{
|
||||
"gbtree": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
@ -198,7 +125,8 @@
|
||||
},
|
||||
"required": [
|
||||
"gbtree_model_param",
|
||||
"trees"
|
||||
"trees",
|
||||
"tree_info"
|
||||
]
|
||||
}
|
||||
},
|
||||
@ -207,6 +135,81 @@
|
||||
"model"
|
||||
]
|
||||
},
|
||||
"gbtree_model_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_trees": {
|
||||
"type": "string"
|
||||
},
|
||||
"size_leaf_vector": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"num_trees",
|
||||
"size_leaf_vector"
|
||||
]
|
||||
},
|
||||
"tree_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_nodes": {
|
||||
"type": "string"
|
||||
},
|
||||
"size_leaf_vector": {
|
||||
"type": "string"
|
||||
},
|
||||
"num_feature": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"num_nodes",
|
||||
"num_feature",
|
||||
"size_leaf_vector"
|
||||
]
|
||||
},
|
||||
"reg_loss_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"scale_pos_weight": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"softmax_multiclass_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_class": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"lambda_rank_param": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_pairsample": { "type": "string" },
|
||||
"fix_list_weight": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"version": {
|
||||
"type": "array",
|
||||
"const": [
|
||||
1,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"additionalItems": false
|
||||
},
|
||||
"learner": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"gradient_booster": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/definitions/gbtree"
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@ -223,6 +226,26 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "const": "dart" },
|
||||
"gbtree": {
|
||||
"$ref": "#/definitions/gbtree"
|
||||
},
|
||||
"weight_drop": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"name",
|
||||
"gbtree",
|
||||
"weight_drop"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -201,6 +201,8 @@ Another important feature of JSON format is a documented `Schema
|
||||
XGBoost. Here is the initial draft of JSON schema for the output model (not
|
||||
serialization, which will not be stable as noted above). It's subject to change due to
|
||||
the beta status. For an example of parsing XGBoost tree model, see ``/demo/json-model``.
|
||||
Please notice the "weight_drop" field used in "dart" booster. XGBoost does not scale tree
|
||||
leaf directly, instead it saves the weights as a separated array.
|
||||
|
||||
.. include:: ../model.schema
|
||||
:code: json
|
||||
|
||||
@ -13,13 +13,13 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
def json_model(model_path):
|
||||
def json_model(model_path, parameters):
|
||||
X = np.random.random((10, 3))
|
||||
y = np.random.randint(2, size=(10,))
|
||||
|
||||
dm1 = xgb.DMatrix(X, y)
|
||||
|
||||
bst = xgb.train({'tree_method': 'hist'}, dm1)
|
||||
bst = xgb.train(parameters, dm1)
|
||||
bst.save_model(model_path)
|
||||
|
||||
with open(model_path, 'r') as fd:
|
||||
@ -285,7 +285,8 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
def test_model_json_io(self):
|
||||
model_path = './model.json'
|
||||
j_model = json_model(model_path)
|
||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
||||
j_model = json_model(model_path, parameters)
|
||||
assert isinstance(j_model['learner'], dict)
|
||||
|
||||
bst = xgb.Booster(model_file='./model.json')
|
||||
@ -306,5 +307,12 @@ class TestModels(unittest.TestCase):
|
||||
doc = os.path.join(path, 'doc', 'model.schema')
|
||||
with open(doc, 'r') as fd:
|
||||
schema = json.load(fd)
|
||||
jsonschema.validate(instance=json_model(model_path), schema=schema)
|
||||
parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
|
||||
jsonschema.validate(instance=json_model(model_path, parameters),
|
||||
schema=schema)
|
||||
os.remove(model_path)
|
||||
|
||||
parameters = {'tree_method': 'hist', 'booster': 'dart'}
|
||||
jsonschema.validate(instance=json_model(model_path, parameters),
|
||||
schema=schema)
|
||||
os.remove(model_path)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user