Add option to choose booster in scikit intreface (gbtree by default) (#2303)

* Add option to choose booster in scikit intreface (gbtree by default)

* Add option to choose booster in scikit intreface: complete docstring.

* Fix XGBClassifier to work with booster option

* Added test case for gblinear booster
This commit is contained in:
jayzed82
2017-05-19 05:12:27 +02:00
committed by Yuan (Terry) Tang
parent 96f9776ab0
commit 29289d2302
3 changed files with 41 additions and 21 deletions

View File

@@ -221,12 +221,29 @@ def test_sklearn_api():
iris = load_iris()
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
classifier = xgb.XGBClassifier()
classifier = xgb.XGBClassifier(booster='gbtree', n_estimators=10)
classifier.fit(tr_d, tr_l)
preds = classifier.predict(te_d)
labels = te_l
err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l)
err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
assert err < 0.2
def test_sklearn_api_gblinear():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split
iris = load_iris()
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
classifier = xgb.XGBClassifier(booster='gblinear', n_estimators=100)
classifier.fit(tr_d, tr_l)
preds = classifier.predict(te_d)
labels = te_l
err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
assert err < 0.2