diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py
index ea87a0a50..1ee6d9fd7 100644
--- a/python-package/xgboost/plotting.py
+++ b/python-package/xgboost/plotting.py
@@ -59,7 +59,7 @@ def plot_importance(booster, ax=None, height=0.2,
         raise ImportError('You must install matplotlib to plot importance')
 
     if isinstance(booster, XGBModel):
-        importance = booster.booster().get_score(importance_type=importance_type)
+        importance = booster.get_booster().get_score(importance_type=importance_type)
     elif isinstance(booster, Booster):
         importance = booster.get_score(importance_type=importance_type)
     elif isinstance(booster, dict):
@@ -196,7 +196,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
         raise ValueError('booster must be Booster or XGBModel instance')
 
     if isinstance(booster, XGBModel):
-        booster = booster.booster()
+        booster = booster.get_booster()
 
     tree = booster.get_dump(fmap=fmap)[num_trees]
     tree = tree.split()
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 590bec545..919d929d4 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -65,6 +65,8 @@ class XGBModel(XGBModelBase):
     objective : string or callable
         Specify the learning task and the corresponding learning objective or
         a custom objective function to be used (see note below).
+    booster: string
+        Specify which booster to use: gbtree, gblinear or dart.
     nthread : int
         Number of parallel threads used to run xgboost.
     gamma : float
@@ -112,7 +114,7 @@ class XGBModel(XGBModelBase):
     """
 
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
-                 silent=True, objective="reg:linear",
+                 silent=True, objective="reg:linear", booster='gbtree',
                  nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
                  subsample=1, colsample_bytree=1, colsample_bylevel=1,
                  reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
@@ -124,6 +126,7 @@ class XGBModel(XGBModelBase):
         self.n_estimators = n_estimators
         self.silent = silent
         self.objective = objective
+        self.booster = booster
         self.nthread = nthread
         self.gamma = gamma
         self.min_child_weight = min_child_weight
@@ -150,7 +153,7 @@ class XGBModel(XGBModelBase):
             state["_Booster"] = Booster(model_file=bst)
         self.__dict__.update(state)
 
-    def booster(self):
+    def get_booster(self):
         """Get the underlying xgboost Booster of this model.
 
         This will raise an exception when fit was not called
@@ -270,9 +273,9 @@ class XGBModel(XGBModelBase):
     def predict(self, data, output_margin=False, ntree_limit=0):
         # pylint: disable=missing-docstring,invalid-name
         test_dmatrix = DMatrix(data, missing=self.missing)
-        return self.booster().predict(test_dmatrix,
-                                      output_margin=output_margin,
-                                      ntree_limit=ntree_limit)
+        return self.get_booster().predict(test_dmatrix,
+                                          output_margin=output_margin,
+                                          ntree_limit=ntree_limit)
 
     def apply(self, X, ntree_limit=0):
         """Return the predicted leaf every tree for each sample.
@@ -293,9 +296,9 @@ class XGBModel(XGBModelBase):
             ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
         """
         test_dmatrix = DMatrix(X, missing=self.missing)
-        return self.booster().predict(test_dmatrix,
-                                      pred_leaf=True,
-                                      ntree_limit=ntree_limit)
+        return self.get_booster().predict(test_dmatrix,
+                                          pred_leaf=True,
+                                          ntree_limit=ntree_limit)
 
     def evals_result(self):
         """Return the evaluation results.
@@ -341,7 +344,7 @@ class XGBModel(XGBModelBase):
         feature_importances_ : array of shape = [n_features]
 
         """
-        b = self.booster()
+        b = self.get_booster()
         fs = b.get_fscore()
         all_features = [fs.get(f, 0.) for f in b.feature_names]
         all_features = np.array(all_features, dtype=np.float32)
@@ -356,13 +359,13 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def __init__(self, max_depth=3, learning_rate=0.1,
                  n_estimators=100, silent=True,
-                 objective="binary:logistic",
+                 objective="binary:logistic", booster='gbtree',
                  nthread=-1, gamma=0, min_child_weight=1,
                  max_delta_step=0, subsample=1, colsample_bytree=1,
                  colsample_bylevel=1, reg_alpha=0, reg_lambda=1,
                  scale_pos_weight=1, base_score=0.5, seed=0, missing=None):
         super(XGBClassifier, self).__init__(max_depth, learning_rate,
-                                            n_estimators, silent, objective,
+                                            n_estimators, silent, objective, booster,
                                             nthread, gamma, min_child_weight,
                                             max_delta_step, subsample,
                                             colsample_bytree, colsample_bylevel,
@@ -479,9 +482,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def predict(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix,
-                                             output_margin=output_margin,
-                                             ntree_limit=ntree_limit)
+        class_probs = self.get_booster().predict(test_dmatrix,
+                                                 output_margin=output_margin,
+                                                 ntree_limit=ntree_limit)
         if len(class_probs.shape) > 1:
             column_indexes = np.argmax(class_probs, axis=1)
         else:
@@ -491,9 +494,9 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
 
     def predict_proba(self, data, output_margin=False, ntree_limit=0):
         test_dmatrix = DMatrix(data, missing=self.missing)
-        class_probs = self.booster().predict(test_dmatrix,
-                                             output_margin=output_margin,
-                                             ntree_limit=ntree_limit)
+        class_probs = self.get_booster().predict(test_dmatrix,
+                                                 output_margin=output_margin,
+                                                 ntree_limit=ntree_limit)
         if self.objective == "multi:softprob":
             return class_probs
         else:
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 12726b002..b5fc29411 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -221,12 +221,29 @@ def test_sklearn_api():
     iris = load_iris()
     tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
 
-    classifier = xgb.XGBClassifier()
+    classifier = xgb.XGBClassifier(booster='gbtree', n_estimators=10)
     classifier.fit(tr_d, tr_l)
 
     preds = classifier.predict(te_d)
     labels = te_l
-    err = sum([1 for p, l in zip(preds, labels) if p != l]) / len(te_l)
+    err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
     assert err < 0.2
+
+
+def test_sklearn_api_gblinear():
+    tm._skip_if_no_sklearn()
+    from sklearn.datasets import load_iris
+    from sklearn.cross_validation import train_test_split
+
+    iris = load_iris()
+    tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
+
+    classifier = xgb.XGBClassifier(booster='gblinear', n_estimators=100)
+    classifier.fit(tr_d, tr_l)
+
+    preds = classifier.predict(te_d)
+    labels = te_l
+    err = sum([1 for p, l in zip(preds, labels) if p != l]) * 1.0 / len(te_l)
+    assert err < 0.2
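
For context, a minimal usage sketch of the API change above: the new `booster` constructor argument selects the base learner, and `get_booster()` replaces the old `booster()` accessor. The dataset and parameter values mirror the tests and are otherwise illustrative.

# Illustrative sketch; parameter values mirror the tests in this diff.
import xgboost as xgb
from sklearn.datasets import load_iris

iris = load_iris()

# 'booster' now selects the base learner: 'gbtree', 'gblinear' or 'dart'.
clf = xgb.XGBClassifier(booster='gbtree', n_estimators=10)
clf.fit(iris.data, iris.target)

# get_booster() replaces the old booster() method and returns the
# underlying low-level Booster object.
bst = clf.get_booster()
print(bst.get_fscore())  # per-feature split counts, as used by feature_importances_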