Add dart to JSON schema. (#5218)

* Add dart to JSON schema.

* Use spaces instead of tab.
This commit is contained in:
Jiaming Yuan 2020-01-28 13:29:09 +08:00 committed by GitHub
parent 0c7455276d
commit ef19480eda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 219 additions and 186 deletions

View File

@ -1,80 +1,7 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"definitions": { "definitions": {
"gbtree_model_param": { "gbtree": {
"type": "object",
"properties": {
"num_trees": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
}
},
"required": [
"num_trees",
"size_leaf_vector"
]
},
"tree_param": {
"type": "object",
"properties": {
"num_nodes": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
},
"num_feature": {
"type": "string"
}
},
"required": [
"num_nodes",
"num_feature",
"size_leaf_vector"
]
},
"reg_loss_param": {
"type": "object",
"properties": {
"scale_pos_weight": {
"type": "string"
}
}
},
"softmax_multiclass_param": {
"type": "object",
"properties": {
"num_class": { "type": "string" }
}
},
"lambda_rank_param": {
"type": "object",
"properties": {
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
}
},
"type": "object",
"properties": {
"version": {
"type": "array",
"const": [
1,
0,
0
],
"additionalItems": false
},
"learner": {
"type": "object",
"properties": {
"gradient_booster": {
"oneOf": [
{
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "name": {
@ -198,7 +125,8 @@
}, },
"required": [ "required": [
"gbtree_model_param", "gbtree_model_param",
"trees" "trees",
"tree_info"
] ]
} }
}, },
@ -207,6 +135,81 @@
"model" "model"
] ]
}, },
"gbtree_model_param": {
"type": "object",
"properties": {
"num_trees": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
}
},
"required": [
"num_trees",
"size_leaf_vector"
]
},
"tree_param": {
"type": "object",
"properties": {
"num_nodes": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
},
"num_feature": {
"type": "string"
}
},
"required": [
"num_nodes",
"num_feature",
"size_leaf_vector"
]
},
"reg_loss_param": {
"type": "object",
"properties": {
"scale_pos_weight": {
"type": "string"
}
}
},
"softmax_multiclass_param": {
"type": "object",
"properties": {
"num_class": { "type": "string" }
}
},
"lambda_rank_param": {
"type": "object",
"properties": {
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
}
},
"type": "object",
"properties": {
"version": {
"type": "array",
"const": [
1,
0,
0
],
"additionalItems": false
},
"learner": {
"type": "object",
"properties": {
"gradient_booster": {
"oneOf": [
{
"$ref": "#/definitions/gbtree"
},
{ {
"type": "object", "type": "object",
"properties": { "properties": {
@ -223,6 +226,26 @@
} }
} }
} }
},
{
"type": "object",
"properties": {
"name": { "const": "dart" },
"gbtree": {
"$ref": "#/definitions/gbtree"
},
"weight_drop": {
"type": "array",
"items": {
"type": "number"
}
}
},
"required": [
"name",
"gbtree",
"weight_drop"
]
} }
] ]
}, },

View File

@ -201,6 +201,8 @@ Another important feature of JSON format is a documented `Schema
XGBoost. Here is the initial draft of JSON schema for the output model (not XGBoost. Here is the initial draft of JSON schema for the output model (not
serialization, which will not be stable as noted above). It's subject to change due to serialization, which will not be stable as noted above). It's subject to change due to
the beta status. For an example of parsing XGBoost tree model, see ``/demo/json-model``. the beta status. For an example of parsing XGBoost tree model, see ``/demo/json-model``.
Please note the "weight_drop" field used by the "dart" booster. XGBoost does not scale tree
leaves directly; instead, it saves the weights as a separate array.
.. include:: ../model.schema .. include:: ../model.schema
:code: json :code: json

View File

@ -13,13 +13,13 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
def json_model(model_path): def json_model(model_path, parameters):
X = np.random.random((10, 3)) X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,)) y = np.random.randint(2, size=(10,))
dm1 = xgb.DMatrix(X, y) dm1 = xgb.DMatrix(X, y)
bst = xgb.train({'tree_method': 'hist'}, dm1) bst = xgb.train(parameters, dm1)
bst.save_model(model_path) bst.save_model(model_path)
with open(model_path, 'r') as fd: with open(model_path, 'r') as fd:
@ -285,7 +285,8 @@ class TestModels(unittest.TestCase):
def test_model_json_io(self): def test_model_json_io(self):
model_path = './model.json' model_path = './model.json'
j_model = json_model(model_path) parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters)
assert isinstance(j_model['learner'], dict) assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json') bst = xgb.Booster(model_file='./model.json')
@ -306,5 +307,12 @@ class TestModels(unittest.TestCase):
doc = os.path.join(path, 'doc', 'model.schema') doc = os.path.join(path, 'doc', 'model.schema')
with open(doc, 'r') as fd: with open(doc, 'r') as fd:
schema = json.load(fd) schema = json.load(fd)
jsonschema.validate(instance=json_model(model_path), schema=schema) parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)
parameters = {'tree_method': 'hist', 'booster': 'dart'}
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path) os.remove(model_path)