Enable loading model from <1.0.0 trained with objective='binary:logitraw' (#6517)

* Enable loading model from <1.0.0 trained with objective='binary:logitraw'

* Add binary:logitraw to the model compatibility testing suite

* Address feedback from @trivialfis: override ProbToMargin() for LogisticRaw (sketched below)

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Author: Philip Hyunsu Cho
Date: 2020-12-16 16:53:46 -08:00
Committed by: GitHub
Parent: bf6cfe3b99
Commit: ad1a527709
5 changed files with 45 additions and 21 deletions
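
The ProbToMargin() override itself lands in the C++ objective code, which this page's diff does not show. As a rough Python sketch of the idea only (the function names here are illustrative, not XGBoost's API): the base logistic objective treats base_score as a probability and maps it to a raw margin with the logit transform, while binary:logitraw already works in margin space, so its override is assumed to pass base_score through unchanged.

import math

def prob_to_margin_logistic(base_score):
    # binary:logistic stores base_score as a probability; convert it
    # to a raw margin with the logit (inverse sigmoid) transform.
    return -math.log(1.0 / base_score - 1.0)

def prob_to_margin_logitraw(base_score):
    # binary:logitraw already operates on raw margins, so the override
    # is effectively the identity: base_score passes through unchanged.
    return base_score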


@@ -24,6 +24,10 @@ def run_booster_check(booster, name):
             config['learner']['learner_model_param']['base_score']) == 0.5
         assert config['learner']['learner_train_param'][
             'objective'] == 'multi:softmax'
+    elif name.find('logitraw') != -1:
+        assert len(booster.get_dump()) == gm.kForests * gm.kRounds
+        assert config['learner']['learner_model_param']['num_class'] == str(0)
+        assert config['learner']['learner_train_param']['objective'] == 'binary:logitraw'
     elif name.find('logit') != -1:
         assert len(booster.get_dump()) == gm.kForests * gm.kRounds
         assert config['learner']['learner_model_param']['num_class'] == str(0)
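
The hunk above exercises the native Booster path. A minimal standalone sketch of the same round trip, assuming a hypothetical model file trained with a pre-1.0 XGBoost under objective='binary:logitraw':

import json
import xgboost

booster = xgboost.Booster()
booster.load_model('xgb-0.90-logitraw.bin')  # hypothetical pre-1.0 artifact
config = json.loads(booster.save_config())
# After this fix, the objective survives the upgrade round trip.
assert config['learner']['learner_train_param']['objective'] == 'binary:logitraw'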
@@ -77,6 +81,13 @@ def run_scikit_model_check(name, path):
         assert config['learner']['learner_train_param'][
             'objective'] == 'rank:ndcg'
         run_model_param_check(config)
+    elif name.find('logitraw') != -1:
+        logit = xgboost.XGBClassifier()
+        logit.load_model(path)
+        assert (len(logit.get_booster().get_dump()) ==
+                gm.kRounds * gm.kForests)
+        config = json.loads(logit.get_booster().save_config())
+        assert config['learner']['learner_train_param']['objective'] == 'binary:logitraw'
     elif name.find('logit') != -1:
         logit = xgboost.XGBClassifier()
         logit.load_model(path)
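
As a usage note, binary:logitraw models emit untransformed margins from predict(); callers apply the sigmoid themselves when probabilities are needed. A minimal sketch with synthetic data (illustrative only):

import numpy as np
import xgboost

X = np.random.rand(100, 4)
y = np.random.randint(2, size=100)
dtrain = xgboost.DMatrix(X, label=y)

# Predictions from binary:logitraw are raw scores, not probabilities.
booster = xgboost.train({'objective': 'binary:logitraw'}, dtrain, num_boost_round=4)
margins = booster.predict(dtrain)
probabilities = 1.0 / (1.0 + np.exp(-margins))  # sigmoid recovers probabilities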