Update JSON schema. (#5982)

* Update JSON schema for pseudo huber.
* Update JSON model schema.
This commit is contained in:
Jiaming Yuan 2020-08-05 15:21:11 +08:00 committed by GitHub
parent 9c93531709
commit 8599f87597
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 0 deletions

View File

@ -177,6 +177,17 @@
} }
} }
}, },
"aft_loss_param": {
"type": "object",
"properties": {
"aft_loss_distribution": {
"type": "string"
},
"aft_loss_distribution_scale": {
"type": "string"
}
}
},
"softmax_multiclass_param": { "softmax_multiclass_param": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -273,6 +284,17 @@
"reg_loss_param" "reg_loss_param"
] ]
}, },
{
"type": "object",
"properties": {
"name": { "const": "reg:pseudohubererror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{ {
"type": "object", "type": "object",
"properties": { "properties": {
@ -284,6 +306,17 @@
"reg_loss_param" "reg_loss_param"
] ]
}, },
{
"type": "object",
"properties": {
"name": { "const": "reg:linear" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{ {
"type": "object", "type": "object",
"properties": { "properties": {
@ -420,6 +453,19 @@
"name", "name",
"lambda_rank_param" "lambda_rank_param"
] ]
},
{
"type": "object",
"properties": {
"name": {"const": "survival:aft"},
"aft_loss_param": { "$ref": "#/definitions/aft_loss_param"}
}
},
{
"type": "object",
"properties": {
"name": {"const": "binary:hinge"}
}
} }
] ]
}, },

View File

@ -346,6 +346,25 @@ class TestModels(unittest.TestCase):
schema=schema) schema=schema)
os.remove(model_path) os.remove(model_path)
try:
xgb.train({'objective': 'foo'}, dtrain, num_boost_round=1)
except ValueError as e:
e_str = str(e)
beg = e_str.find('Objective candidate')
end = e_str.find('Stack trace')
e_str = e_str[beg: end]
e_str = e_str.strip()
splited = e_str.splitlines()
objectives = [s.split(': ')[1] for s in splited]
j_objectives = schema['properties']['learner']['properties'][
'objective']['oneOf']
objectives_from_schema = set()
for j_obj in j_objectives:
objectives_from_schema.add(
j_obj['properties']['name']['const'])
objectives = set(objectives)
assert objectives == objectives_from_schema
@pytest.mark.skipif(**tm.no_json_schema()) @pytest.mark.skipif(**tm.no_json_schema())
def test_json_dump_schema(self): def test_json_dump_schema(self):
import jsonschema import jsonschema