Fix #3648: XGBClassifier.predict() should return margin scores when output_margin=True (#3651)

* Fix #3648: XGBClassifier.predict() should return margin scores when output_margin=True

* Fix tests to reflect correct implementation of XGBClassifier.predict(output_margin=True)

* Fix flaky test test_with_sklearn.test_sklearn_api_gblinear
This commit is contained in:
Philip Hyunsu Cho 2018-08-30 21:05:05 -07:00 committed by GitHub
parent 5b662cbe1c
commit 86d88c0758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 8 deletions

View File

@ -652,6 +652,10 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features)
if output_margin:
# If output_margin is active, simply return the scores
return class_probs
if len(class_probs.shape) > 1:
column_indexes = np.argmax(class_probs, axis=1)
else:

View File

@ -53,9 +53,13 @@ def test_multiclass_classification():
except:
from sklearn.model_selection import KFold
def check_pred(preds, labels):
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
def check_pred(preds, labels, output_margin):
if output_margin:
err = sum(1 for i in range(len(preds))
if preds[i].argmax() != labels[i]) / float(len(preds))
else:
err = sum(1 for i in range(len(preds))
if preds[i] != labels[i]) / float(len(preds))
assert err < 0.4
iris = load_iris()
@ -71,10 +75,10 @@ def test_multiclass_classification():
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
labels = y[test_index]
check_pred(preds, labels)
check_pred(preds2, labels)
check_pred(preds3, labels)
check_pred(preds4, labels)
check_pred(preds, labels, output_margin=False)
check_pred(preds2, labels, output_margin=True)
check_pred(preds3, labels, output_margin=True)
check_pred(preds4, labels, output_margin=False)
def test_ranking():
@ -287,7 +291,7 @@ def test_sklearn_api_gblinear():
preds = classifier.predict(te_d)
labels = te_l
err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
assert err < 0.2
assert err < 0.5
def test_sklearn_plotting():