Add option to choose booster in scikit interface (gbtree by default) (#2303)

* Add option to choose booster in scikit interface (gbtree by default)
* Add option to choose booster in scikit interface: complete docstring.
* Fix XGBClassifier to work with booster option
* Added test case for gblinear booster
This commit is contained in:
parent 96f9776ab0
commit 29289d2302
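For reference, a minimal usage sketch of the new booster option through the scikit-learn wrapper, mirroring the test added at the end of this commit; the iris data, the split, and the variable names are illustrative and not part of the diff:

import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split  # module used by the tests of this era

iris = load_iris()
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)

# booster defaults to 'gbtree'; 'gblinear' and 'dart' are the other documented values.
clf = xgb.XGBClassifier(booster='gblinear', n_estimators=100)
clf.fit(tr_d, tr_l)
preds = clf.predict(te_d)

# The accessor is renamed in this commit: use get_booster() instead of booster().
raw_booster = clf.get_booster()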
@@ -59,7 +59,7 @@ def plot_importance(booster, ax=None, height=0.2,
         raise ImportError('You must install matplotlib to plot importance')
 
     if isinstance(booster, XGBModel):
-        importance = booster.booster().get_score(importance_type=importance_type)
+        importance = booster.get_booster().get_score(importance_type=importance_type)
     elif isinstance(booster, Booster):
         importance = booster.get_score(importance_type=importance_type)
     elif isinstance(booster, dict):
@@ -196,7 +196,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
         raise ValueError('booster must be Booster or XGBModel instance')
 
     if isinstance(booster, XGBModel):
-        booster = booster.booster()
+        booster = booster.get_booster()
 
     tree = booster.get_dump(fmap=fmap)[num_trees]
     tree = tree.split()
@@ -65,6 +65,8 @@ class XGBModel(XGBModelBase):
     objective : string or callable
         Specify the learning task and the corresponding learning objective or
         a custom objective function to be used (see note below).
+    booster: string
+        Specify which booster to use: gbtree, gblinear or dart.
     nthread : int
         Number of parallel threads used to run xgboost.
     gamma : float
@@ -112,7 +114,7 @@ class XGBModel(XGBModelBase):
     """
 
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 silent=True, objective="reg:linear",
+                 silent=True, objective="reg:linear", booster='gbtree',
                  nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
                  subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
@@ -124,6 +126,7 @@ class XGBModel(XGBModelBase):
         self.n_estimators = n_estimators
         self.silent = silent
         self.objective = objective
+        self.booster = booster
 
         self.nthread = nthread
         self.gamma = gamma
@@ -150,7 +153,7 @@ class XGBModel(XGBModelBase):
             state["_Booster"] = Booster(model_file=bst)
         self.__dict__.update(state)
 
-    def booster(self):
+    def get_booster(self):
         """Get the underlying xgboost Booster of this model.
 
         This will raise an exception when fit was not called
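Because the accessor is renamed here (no deprecated alias is visible in this hunk), callers that used the old method name would need a one-line update; a hedged before/after sketch with an illustrative model variable:

# before this commit
# bst = model.booster()

# after this commit: returns the underlying xgboost.Booster,
# and raises an exception if fit() has not been called yet
bst = model.get_booster()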
@@ -270,7 +273,7 @@ class XGBModel(XGBModelBase):
     def predict(self, data, output_margin=False, ntree_limit=0):
         # pylint: disable=missing-docstring,invalid-name
         test_dmatrix = DMatrix(data, missing=self.missing)
-        return self.booster().predict(test_dmatrix,
+        return self.get_booster().predict(test_dmatrix,
                                       output_margin=output_margin,
                                       ntree_limit=ntree_limit)
 
@@ -293,7 +296,7 @@ class XGBModel(XGBModelBase):
         ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
         """
         test_dmatrix = DMatrix(X, missing=self.missing)
-        return self.booster().predict(test_dmatrix,
+        return self.get_booster().predict(test_dmatrix,
                                       pred_leaf=True,
                                       ntree_limit=ntree_limit)
 
@@ -341,7 +344,7 @@ class XGBModel(XGBModelBase):
         feature_importances_ : array of shape = [n_features]
 
         """
-        b = self.booster()
+        b = self.get_booster()
         fs = b.get_fscore()
         all_features = [fs.get(f, 0.) for f in b.feature_names]
         all_features = np.array(all_features, dtype=np.float32)
@@ -356,13 +359,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def __init__(self, max_depth=3, learning_rate=0.1,
                  n_estimators=100, silent=True,
-                 objective="binary:logistic",
+                 objective="binary:logistic", booster='gbtree',
                  nthread=-1, gamma=0, min_child_weight=1,
                  max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                  base_score=0.5, seed=0, missing=None):
         super(XGBClassifier, self).__init__(max_depth, learning_rate,
-                                            n_estimators, silent, objective,
+                                            n_estimators, silent, objective, booster,
                                             nthread, gamma, min_child_weight,
                                             max_delta_step, subsample,
                                             colsample_bytree, colsample_bylevel,
@@ -479,7 +482,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def predict(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix,
+        class_probs = self.get_booster().predict(test_dmatrix,
                                              output_margin=output_margin,
                                              ntree_limit=ntree_limit)
         if len(class_probs.shape) > 1:
@@ -491,7 +494,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def predict_proba(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix,
+        class_probs = self.get_booster().predict(test_dmatrix,
                                              output_margin=output_margin,
                                              ntree_limit=ntree_limit)
         if self.objective == "multi:softprob":
@@ -221,12 +221,29 @@ def test_sklearn_api():
     iris = load_iris()
     tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
 
-    classifier = xgb.XGBClassifier()
+    classifier = xgb.XGBClassifier(booster='gbtree', n_estimators=10)
     classifier.fit(tr_d, tr_l)
 
     preds = classifier.predict(te_d)
     labels = te_l
-    err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l)
+    err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
+    assert err < 0.2
+
+
+def test_sklearn_api_gblinear():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import load_iris
+    from sklearn.cross_validation import train_test_split
+
+    iris = load_iris()
+    tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
+
+    classifier = xgb.XGBClassifier(booster='gblinear', n_estimators=100)
+    classifier.fit(tr_d, tr_l)
+
+    preds = classifier.predict(te_d)
+    labels = te_l
+    err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
     assert err < 0.2
 
 
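A brief note on the `* 1.0` factor added to the error computation: under Python 2, `sum(...) / len(te_l)` is integer division, so the old error rate truncated to 0 and the assertion could never fail; multiplying by 1.0 forces float division while staying compatible with both Python 2 and 3. An equivalent sketch, not part of the diff:

from __future__ import division  # alternative to the explicit 1.0 factor

# error rate as a float fraction of misclassified test samples
err = sum(1 for p, l in zip(preds, labels) if p != l) / len(te_l)
assert err < 0.2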