Add dart to JSON schema. (#5218)

* Add dart to JSON schema.

* Use spaces instead of tab.
This commit is contained in:
Jiaming Yuan 2020-01-28 13:29:09 +08:00 committed by GitHub
parent 0c7455276d
commit ef19480eda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 219 additions and 186 deletions

View File

@ -1,6 +1,140 @@
{ {
"$schema": "http://json-schema.org/draft-07/schema#", "$schema": "http://json-schema.org/draft-07/schema#",
"definitions": { "definitions": {
"gbtree": {
"type": "object",
"properties": {
"name": {
"const": "gbtree"
},
"model": {
"type": "object",
"properties": {
"gbtree_model_param": {
"$ref": "#/definitions/gbtree_model_param"
},
"trees": {
"type": "array",
"items": {
"type": "object",
"properties": {
"tree_param": {
"type": "object",
"properties": {
"num_nodes": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
},
"num_feature": {
"type": "string"
}
},
"required": [
"num_nodes",
"num_feature",
"size_leaf_vector"
]
},
"id": {
"type": "integer"
},
"loss_changes": {
"type": "array",
"items": {
"type": "number"
}
},
"sum_hessian": {
"type": "array",
"items": {
"type": "number"
}
},
"base_weights": {
"type": "array",
"items": {
"type": "number"
}
},
"leaf_child_counts": {
"type": "array",
"items": {
"type": "integer"
}
},
"left_children": {
"type": "array",
"items": {
"type": "integer"
}
},
"right_children": {
"type": "array",
"items": {
"type": "integer"
}
},
"parents": {
"type": "array",
"items": {
"type": "integer"
}
},
"split_indices": {
"type": "array",
"items": {
"type": "integer"
}
},
"split_conditions": {
"type": "array",
"items": {
"type": "number"
}
},
"default_left": {
"type": "array",
"items": {
"type": "boolean"
}
}
},
"required": [
"tree_param",
"loss_changes",
"sum_hessian",
"base_weights",
"leaf_child_counts",
"left_children",
"right_children",
"parents",
"split_indices",
"split_conditions",
"default_left"
]
}
},
"tree_info": {
"type": "array",
"items": {
"type": "integer"
}
}
},
"required": [
"gbtree_model_param",
"trees",
"tree_info"
]
}
},
"required": [
"name",
"model"
]
},
"gbtree_model_param": { "gbtree_model_param": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -35,7 +169,6 @@
"size_leaf_vector" "size_leaf_vector"
] ]
}, },
"reg_loss_param": { "reg_loss_param": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -47,14 +180,14 @@
"softmax_multiclass_param": { "softmax_multiclass_param": {
"type": "object", "type": "object",
"properties": { "properties": {
"num_class": { "type": "string" } "num_class": { "type": "string" }
} }
}, },
"lambda_rank_param": { "lambda_rank_param": {
"type": "object", "type": "object",
"properties": { "properties": {
"num_pairsample": { "type": "string" }, "num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" } "fix_list_weight": { "type": "string" }
} }
} }
}, },
@ -74,156 +207,46 @@
"properties": { "properties": {
"gradient_booster": { "gradient_booster": {
"oneOf": [ "oneOf": [
{
"$ref": "#/definitions/gbtree"
},
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "name": { "const": "gblinear" },
"const": "gbtree"
},
"model": { "model": {
"type": "object", "type": "object",
"properties": { "properties": {
"gbtree_model_param": { "weights": {
"$ref": "#/definitions/gbtree_model_param"
},
"trees": {
"type": "array", "type": "array",
"items": { "items": {
"type": "object", "type": "number"
"properties": {
"tree_param": {
"type": "object",
"properties": {
"num_nodes": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
},
"num_feature": {
"type": "string"
}
},
"required": [
"num_nodes",
"num_feature",
"size_leaf_vector"
]
},
"id": {
"type": "integer"
},
"loss_changes": {
"type": "array",
"items": {
"type": "number"
}
},
"sum_hessian": {
"type": "array",
"items": {
"type": "number"
}
},
"base_weights": {
"type": "array",
"items": {
"type": "number"
}
},
"leaf_child_counts": {
"type": "array",
"items": {
"type": "integer"
}
},
"left_children": {
"type": "array",
"items": {
"type": "integer"
}
},
"right_children": {
"type": "array",
"items": {
"type": "integer"
}
},
"parents": {
"type": "array",
"items": {
"type": "integer"
}
},
"split_indices": {
"type": "array",
"items": {
"type": "integer"
}
},
"split_conditions": {
"type": "array",
"items": {
"type": "number"
}
},
"default_left": {
"type": "array",
"items": {
"type": "boolean"
}
}
},
"required": [
"tree_param",
"loss_changes",
"sum_hessian",
"base_weights",
"leaf_child_counts",
"left_children",
"right_children",
"parents",
"split_indices",
"split_conditions",
"default_left"
]
}
},
"tree_info": {
"type": "array",
"items": {
"type": "integer"
} }
} }
}, }
"required": [ }
"gbtree_model_param", }
"trees" },
] {
"type": "object",
"properties": {
"name": { "const": "dart" },
"gbtree": {
"$ref": "#/definitions/gbtree"
},
"weight_drop": {
"type": "array",
"items": {
"type": "number"
}
} }
}, },
"required": [ "required": [
"name", "name",
"model" "gbtree",
"weight_drop"
] ]
}, }
{
"type": "object",
"properties": {
"name": { "const": "gblinear" },
"model": {
"type": "object",
"properties": {
"weights": {
"type": "array",
"items": {
"type": "number"
}
}
}
}
}
}
] ]
}, },
@ -233,51 +256,51 @@
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "reg:squarederror" }, "name": { "const": "reg:squarederror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
}, },
"required": [ "required": [
"name", "name",
"reg_loss_param" "reg_loss_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "reg:squaredlogerror" }, "name": { "const": "reg:squaredlogerror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
}, },
"required": [ "required": [
"name", "name",
"reg_loss_param" "reg_loss_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "reg:logistic" }, "name": { "const": "reg:logistic" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
}, },
"required": [ "required": [
"name", "name",
"reg_loss_param" "reg_loss_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "binary:logistic" }, "name": { "const": "binary:logistic" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
}, },
"required": [ "required": [
"name", "name",
"reg_loss_param" "reg_loss_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "binary:logitraw" }, "name": { "const": "binary:logitraw" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"} "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
}, },
"required": [ "required": [
"name", "name",
@ -285,46 +308,46 @@
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "count:poisson" }, "name": { "const": "count:poisson" },
"poisson_regression_param": { "poisson_regression_param": {
"type": "object", "type": "object",
"properties": { "properties": {
"max_delta_step": { "type": "string" } "max_delta_step": { "type": "string" }
} }
} }
}, },
"required": [ "required": [
"name", "name",
"poisson_regression_param" "poisson_regression_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "reg:tweedie" }, "name": { "const": "reg:tweedie" },
"tweedie_regression_param": { "tweedie_regression_param": {
"type": "object", "type": "object",
"properties": { "properties": {
"tweedie_variance_power": { "type": "string" } "tweedie_variance_power": { "type": "string" }
} }
} }
}, },
"required": [ "required": [
"name", "name",
"tweedie_regression_param" "tweedie_regression_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "survival:cox" } "name": { "const": "survival:cox" }
}, },
"required": [ "name" ] "required": [ "name" ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "reg:gamma" } "name": { "const": "reg:gamma" }
@ -332,22 +355,22 @@
"required": [ "name" ] "required": [ "name" ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "multi:softprob" }, "name": { "const": "multi:softprob" },
"softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"} "softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}
}, },
"required": [ "required": [
"name", "name",
"softmax_multiclass_param" "softmax_multiclass_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "multi:softmax" }, "name": { "const": "multi:softmax" },
"softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"} "softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}
}, },
"required": [ "required": [
"name", "name",
@ -355,33 +378,33 @@
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "rank:pairwise" }, "name": { "const": "rank:pairwise" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
}, },
"required": [ "required": [
"name", "name",
"lambda_rank_param" "lambda_rank_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "rank:ndcg" }, "name": { "const": "rank:ndcg" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
}, },
"required": [ "required": [
"name", "name",
"lambda_rank_param" "lambda_rank_param"
] ]
}, },
{ {
"type": "object", "type": "object",
"properties": { "properties": {
"name": { "const": "rank:map" }, "name": { "const": "rank:map" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
}, },
"required": [ "required": [
"name", "name",
@ -391,14 +414,14 @@
] ]
}, },
"learner_model_param": { "learner_model_param": {
"type": "object", "type": "object",
"properties": { "properties": {
"base_score": { "type": "string" }, "base_score": { "type": "string" },
"num_class": { "type": "string" }, "num_class": { "type": "string" },
"num_feature": { "type": "string" } "num_feature": { "type": "string" }
} }
} }
}, },
"required": [ "required": [
"gradient_booster", "gradient_booster",

View File

@ -201,6 +201,8 @@ Another important feature of JSON format is a documented `Schema
XGBoost. Here is the initial draft of JSON schema for the output model (not XGBoost. Here is the initial draft of JSON schema for the output model (not
serialization, which will not be stable as noted above). It's subject to change due to serialization, which will not be stable as noted above). It's subject to change due to
the beta status. For an example of parsing XGBoost tree model, see ``/demo/json-model``. the beta status. For an example of parsing XGBoost tree model, see ``/demo/json-model``.
Please notice the "weight_drop" field used in "dart" booster. XGBoost does not scale tree
leaf directly, instead it saves the weights as a separated array.
.. include:: ../model.schema .. include:: ../model.schema
:code: json :code: json

View File

@ -13,13 +13,13 @@ dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
def json_model(model_path): def json_model(model_path, parameters):
X = np.random.random((10, 3)) X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,)) y = np.random.randint(2, size=(10,))
dm1 = xgb.DMatrix(X, y) dm1 = xgb.DMatrix(X, y)
bst = xgb.train({'tree_method': 'hist'}, dm1) bst = xgb.train(parameters, dm1)
bst.save_model(model_path) bst.save_model(model_path)
with open(model_path, 'r') as fd: with open(model_path, 'r') as fd:
@ -285,7 +285,8 @@ class TestModels(unittest.TestCase):
def test_model_json_io(self): def test_model_json_io(self):
model_path = './model.json' model_path = './model.json'
j_model = json_model(model_path) parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
j_model = json_model(model_path, parameters)
assert isinstance(j_model['learner'], dict) assert isinstance(j_model['learner'], dict)
bst = xgb.Booster(model_file='./model.json') bst = xgb.Booster(model_file='./model.json')
@ -306,5 +307,12 @@ class TestModels(unittest.TestCase):
doc = os.path.join(path, 'doc', 'model.schema') doc = os.path.join(path, 'doc', 'model.schema')
with open(doc, 'r') as fd: with open(doc, 'r') as fd:
schema = json.load(fd) schema = json.load(fd)
jsonschema.validate(instance=json_model(model_path), schema=schema) parameters = {'tree_method': 'hist', 'booster': 'gbtree'}
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path)
parameters = {'tree_method': 'hist', 'booster': 'dart'}
jsonschema.validate(instance=json_model(model_path, parameters),
schema=schema)
os.remove(model_path) os.remove(model_path)