Add option to choose booster in scikit interface (gbtree by default) (#2303)

* Add option to choose booster in scikit interface (gbtree by default)

* Add option to choose booster in scikit interface: complete docstring.

* Fix XGBClassifier to work with booster option

* Added test case for gblinear booster
jayzed82 2017-05-19 05:12:27 +02:00 committed by Yuan (Terry) Tang
parent 96f9776ab0
commit 29289d2302
3 changed files with 41 additions and 21 deletions
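
For orientation, a minimal usage sketch of the option this commit adds. It is illustrative only: it assumes scikit-learn is installed and uses sklearn.model_selection rather than the older sklearn.cross_validation module that the new test below imports.

import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Choose the booster through the scikit-learn wrapper ('gbtree' is the default).
iris = load_iris()
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)

clf = xgb.XGBClassifier(booster='gblinear', n_estimators=100)  # or 'gbtree' / 'dart'
clf.fit(tr_d, tr_l)
print((clf.predict(te_d) != te_l).mean())  # held-out misclassification rate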

View File

@@ -59,7 +59,7 @@ def plot_importance(booster, ax=None, height=0.2,
         raise ImportError('You must install matplotlib to plot importance')
 
     if isinstance(booster, XGBModel):
-        importance = booster.booster().get_score(importance_type=importance_type)
+        importance = booster.get_booster().get_score(importance_type=importance_type)
     elif isinstance(booster, Booster):
         importance = booster.get_score(importance_type=importance_type)
     elif isinstance(booster, dict):
@@ -196,7 +196,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
         raise ValueError('booster must be Booster or XGBModel instance')
 
     if isinstance(booster, XGBModel):
-        booster = booster.booster()
+        booster = booster.get_booster()
 
     tree = booster.get_dump(fmap=fmap)[num_trees]
     tree = tree.split()
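
Note on the change above: the scikit-learn wrapper's accessor is renamed from booster() to get_booster(), presumably because the old method name would collide with the new booster constructor argument, and the plotting helpers are updated to match. A small caller-side sketch (illustrative only; assumes matplotlib is installed and uses a tiny synthetic dataset):

import numpy as np
import xgboost as xgb

# Tiny synthetic problem, just to have a fitted model to plot.
rng = np.random.RandomState(0)
X = rng.rand(100, 5)
y = (X[:, 0] > 0.5).astype(int)
clf = xgb.XGBClassifier(n_estimators=5).fit(X, y)

# The sklearn wrapper can still be passed directly to plot_importance;
# internally it is now resolved via get_booster() rather than booster().
xgb.plot_importance(clf)

# Explicit access to the low-level Booster also goes through get_booster().
xgb.plot_importance(clf.get_booster())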

View File

@@ -65,6 +65,8 @@ class XGBModel(XGBModelBase):
     objective : string or callable
         Specify the learning task and the corresponding learning objective or
         a custom objective function to be used (see note below).
+    booster: string
+        Specify which booster to use: gbtree, gblinear or dart.
     nthread : int
         Number of parallel threads used to run xgboost.
     gamma : float
@@ -112,7 +114,7 @@ class XGBModel(XGBModelBase):
     """
 
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 silent=True, objective="reg:linear",
+                 silent=True, objective="reg:linear", booster='gbtree',
                  nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
                  subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
@@ -124,6 +126,7 @@ class XGBModel(XGBModelBase):
         self.n_estimators = n_estimators
         self.silent = silent
         self.objective = objective
+        self.booster = booster
         self.nthread = nthread
         self.gamma = gamma
@@ -150,7 +153,7 @@
             state["_Booster"] = Booster(model_file=bst)
         self.__dict__.update(state)
 
-    def booster(self):
+    def get_booster(self):
         """Get the underlying xgboost Booster of this model.
 
         This will raise an exception when fit was not called
@@ -270,9 +273,9 @@
     def predict(self, data, output_margin=False, ntree_limit=0):
         # pylint: disable=missing-docstring,invalid-name
         test_dmatrix = DMatrix(data, missing=self.missing)
-        return self.booster().predict(test_dmatrix,
-                                      output_margin=output_margin,
-                                      ntree_limit=ntree_limit)
+        return self.get_booster().predict(test_dmatrix,
+                                          output_margin=output_margin,
+                                          ntree_limit=ntree_limit)
 
     def apply(self, X, ntree_limit=0):
         """Return the predicted leaf every tree for each sample.
@@ -293,9 +296,9 @@
             ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
         """
         test_dmatrix = DMatrix(X, missing=self.missing)
-        return self.booster().predict(test_dmatrix,
-                                      pred_leaf=True,
-                                      ntree_limit=ntree_limit)
+        return self.get_booster().predict(test_dmatrix,
+                                          pred_leaf=True,
+                                          ntree_limit=ntree_limit)
 
     def evals_result(self):
         """Return the evaluation results.
@@ -341,7 +344,7 @@
         feature_importances_ : array of shape = [n_features]
 
         """
-        b = self.booster()
+        b = self.get_booster()
         fs = b.get_fscore()
         all_features = [fs.get(f, 0.) for f in b.feature_names]
         all_features = np.array(all_features, dtype=np.float32)
@@ -356,13 +359,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def __init__(self, max_depth=3, learning_rate=0.1,
                  n_estimators=100, silent=True,
-                 objective="binary:logistic",
+                 objective="binary:logistic", booster='gbtree',
                  nthread=-1, gamma=0, min_child_weight=1,
                  max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                  base_score=0.5, seed=0, missing=None):
         super(XGBClassifier, self).__init__(max_depth, learning_rate,
-                                            n_estimators, silent, objective,
+                                            n_estimators, silent, objective, booster,
                                             nthread, gamma, min_child_weight,
                                             max_delta_step, subsample,
                                             colsample_bytree, colsample_bylevel,
@@ -479,9 +482,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def predict(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix,
-                                             output_margin=output_margin,
-                                             ntree_limit=ntree_limit)
+        class_probs = self.get_booster().predict(test_dmatrix,
+                                                 output_margin=output_margin,
+                                                 ntree_limit=ntree_limit)
         if len(class_probs.shape) > 1:
             column_indexes = np.argmax(class_probs, axis=1)
         else:
@@ -491,9 +494,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     def predict_proba(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix,
-                                             output_margin=output_margin,
-                                             ntree_limit=ntree_limit)
+        class_probs = self.get_booster().predict(test_dmatrix,
+                                                 output_margin=output_margin,
+                                                 ntree_limit=ntree_limit)
         if self.objective == "multi:softprob":
             return class_probs
         else:
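
Because booster is now an ordinary constructor argument stored on the estimator, it participates in scikit-learn's parameter handling like any other hyper-parameter. A hedged sketch (it assumes scikit-learn is installed, and that XGBRegressor, which reuses XGBModel's constructor, therefore accepts the new argument as well):

import numpy as np
import xgboost as xgb
from sklearn.model_selection import GridSearchCV

# The new argument shows up in get_params() and can be changed via set_params().
reg = xgb.XGBRegressor(n_estimators=20)
print(reg.get_params()['booster'])   # -> 'gbtree' (the default)
reg.set_params(booster='gblinear')

# It can therefore be tuned like any other hyper-parameter, e.g. in a grid
# search over a tiny synthetic regression problem (purely for illustration).
rng = np.random.RandomState(0)
X = rng.rand(200, 4)
y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + 0.1 * rng.randn(200)

search = GridSearchCV(xgb.XGBRegressor(n_estimators=20),
                      param_grid={'booster': ['gbtree', 'gblinear']},
                      cv=3)
search.fit(X, y)
print(search.best_params_)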

View File

@@ -221,12 +221,29 @@ def test_sklearn_api():
     iris = load_iris()
     tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
-    classifier = xgb.XGBClassifier()
+    classifier = xgb.XGBClassifier(booster='gbtree', n_estimators=10)
     classifier.fit(tr_d, tr_l)
     preds = classifier.predict(te_d)
     labels = te_l
-    err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l)
+    err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
     assert err < 0.2
+
+
+def test_sklearn_api_gblinear():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import load_iris
+    from sklearn.cross_validation import train_test_split
+    iris = load_iris()
+    tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
+    classifier = xgb.XGBClassifier(booster='gblinear', n_estimators=100)
+    classifier.fit(tr_d, tr_l)
+    preds = classifier.predict(te_d)
+    labels = te_l
+    err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
+    assert err < 0.2
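
A side note on the new test rather than on the patch itself: sklearn.cross_validation was deprecated in scikit-learn 0.18 (and later removed), so on newer scikit-learn versions the equivalent import would be:

# Equivalent import on scikit-learn >= 0.18, where sklearn.cross_validation
# is deprecated in favour of sklearn.model_selection.
from sklearn.model_selection import train_test_split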