From 136e902fb2b966d938ca10a9a787807ee4eef503 Mon Sep 17 00:00:00 2001
From: Jamie Hall
Date: Wed, 1 Apr 2015 23:29:05 -0700
Subject: [PATCH 1/2] Initial commit

---
 demo/guide-python/sklearn_examples.py |  42 +++++++++
 wrapper/xgboost.py                    | 127 ++++++++++++++++++++++++++
 2 files changed, 169 insertions(+)
 create mode 100644 demo/guide-python/sklearn_examples.py

diff --git a/demo/guide-python/sklearn_examples.py b/demo/guide-python/sklearn_examples.py
new file mode 100644
index 000000000..302361876
--- /dev/null
+++ b/demo/guide-python/sklearn_examples.py
@@ -0,0 +1,42 @@
+'''
+Created on 1 Apr 2015
+
+@author: Jamie Hall
+'''
+
+import sys
+sys.path.append('../../wrapper')
+import xgboost as xgb
+
+import numpy as np
+from sklearn.cross_validation import KFold
+from sklearn.grid_search import GridSearchCV
+from sklearn.metrics import confusion_matrix
+from sklearn.datasets import load_iris, load_digits, load_boston
+
+rng = np.random.RandomState(31337)
+
+
+print("Zeros and Ones from the Digits dataset: binary classification")
+digits = load_digits(2)
+y = digits['target']
+X = digits['data']
+kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+for train_index, test_index in kf:
+    xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+    predictions = xgb_model.predict(X[test_index])
+    actuals = y[test_index]
+    print(confusion_matrix(actuals, predictions))
+
+print("Iris: multiclass classification")
+iris = load_iris()
+y = iris['target']
+X = iris['data']
+kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+for train_index, test_index in kf:
+    xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+    predictions = xgb_model.predict(X[test_index])
+    actuals = y[test_index]
+    print(confusion_matrix(actuals, predictions))
+
+
diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py
index affda3ca7..96b027bec 100644
--- a/wrapper/xgboost.py
+++ b/wrapper/xgboost.py
@@ -14,6 +14,15 @@ import collections
 
 import numpy as np
 import scipy.sparse
+try:
+    from sklearn.base import BaseEstimator
+    from sklearn.preprocessing import LabelEncoder
+    SKLEARN_INSTALLED = True
+except ImportError:
+    SKLEARN_INSTALLED = False
+
+
+
 __all__ = ['DMatrix', 'CVPack', 'Booster', 'aggcv', 'cv', 'mknfold', 'train']
 
 if sys.version_info[0] == 3:
@@ -660,3 +669,121 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
             sys.stderr.write(res + '\n')
         results.append(res)
     return results
+
+
+# Fall back to a plain object base class when sklearn is absent, so that
+# importing this module never requires sklearn.
+XGBModelBase = object
+if SKLEARN_INSTALLED:
+    XGBModelBase = BaseEstimator
+
+
+class XGBModel(XGBModelBase):
+    """
+    Implementation of the Scikit-Learn API for XGBoost.
+
+    Parameters
+    ----------
+    max_depth : int
+        Maximum tree depth for base learners.
+    learning_rate : float
+        Boosting learning rate (xgb's "eta").
+    n_estimators : int
+        Number of boosted trees to fit.
+    silent : boolean
+        Whether to print messages while running boosting.
+ """ + def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True, objective="reg:linear"): + if not SKLEARN_INSTALLED: + raise Exception('sklearn needs to be installed in order to use this module') + self.max_depth = max_depth + self.eta = learning_rate + self.silent = 1 if silent else 0 + self.n_rounds = n_estimators + self.objective = objective + self._Booster = Booster() + + def get_params(self, deep=True): + return {'max_depth': self.max_depth, + 'learning_rate': self.eta, + 'n_estimators': self.n_rounds, + 'silent': True if self.silent == 1 else False, + 'objective': self.objective + } + def get_xgb_params(self): + return {'eta': self.eta, 'max_depth': self.max_depth, 'silent': self.silent, 'objective': self.objective} + + def fit(self, X, y): + trainDmatrix = DMatrix(X, label=y) + self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_rounds) + return self + +class XGBClassifier(XGBModel): + def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True): + super().__init__(max_depth, learning_rate, n_estimators, silent, objective="binary:logistic") + + def fit(self, X, y, sample_weight=None): + y_values = list(np.unique(y)) + if len(y_values) == 2: + # Map the two classes in the y vector into {0,1}, and record the mapping so that + # the predict() method can return results in the original range + if not (-1 in y_values and 1 in y_values) or (0 in y_values and 1 in y_values) or (True in y_values and False in y_values): + raise ValueError("For a binary classifier, y must be in (0,1), or (-1,1), or (True,False).") + if -1 in y_values: + self._yspace = "svm_like" + training_labels = y.copy() + training_labels[training_labels == -1] = 0 + elif False in y_values: + self._yspace = "boolean" + training_labels = np.array(y, dtype=int) + else: + self._yspace = "zero_one" + training_labels = y + xgb_options = self.get_xgb_params() + else: + # Switch to using a multiclass objective in the underlying XGB instance + self._yspace = "multiclass" + self.objective = "multi:softprob" + self._le = LabelEncoder().fit(y) + training_labels = self._le.transform(y) + xgb_options = self.get_xgb_params() + xgb_options['num_class'] = len(y_values) + if sample_weight is not None: + trainDmatrix = DMatrix(X, label=training_labels, weight=sample_weight) + else: + trainDmatrix = DMatrix(X, label=training_labels) + self._Booster = train(xgb_options, trainDmatrix, self.n_rounds) + + return self + + def predict(self, X): + testDmatrix = DMatrix(X) + class_probs = self._Booster.predict(testDmatrix) + if self._yspace == "multiclass": + column_indexes = np.argmax(class_probs, axis=1) + fitted_values = self._le.inverse_transform(column_indexes) + else: + if self._yspace == "svm_like": + base_value = -1 + one_value = 1 + elif self._yspace == "boolean": + base_value = False + one_value = True + else: + base_value = 0 + one_value = 1 + fitted_values = np.repeat(base_value, X.shape[0]) + fitted_values[class_probs > 0.5] = one_value + return fitted_values + + def predict_proba(self, X): + testDmatrix = DMatrix(X) + class_probs = self._Booster.predict(testDmatrix) + if self._yspace == "multiclass": + return class_probs + else: + classone_probs = class_probs + classzero_probs = 1.0 - classone_probs + return np.vstack((classzero_probs,classone_probs)).transpose() + + + + From a1a427af37807e1be00f942ac77d50ecd81963ac Mon Sep 17 00:00:00 2001 From: Jamie Hall Date: Thu, 2 Apr 2015 00:05:14 -0700 Subject: [PATCH 2/2] Fix some stuff --- demo/guide-python/sklearn_examples.py | 24 
From a1a427af37807e1be00f942ac77d50ecd81963ac Mon Sep 17 00:00:00 2001
From: Jamie Hall
Date: Thu, 2 Apr 2015 00:05:14 -0700
Subject: [PATCH 2/2] Fix some stuff

---
 demo/guide-python/sklearn_examples.py | 24 +++++++++++-
 wrapper/xgboost.py                    | 56 ++++++++++-----------------
 2 files changed, 43 insertions(+), 37 deletions(-)

diff --git a/demo/guide-python/sklearn_examples.py b/demo/guide-python/sklearn_examples.py
index 302361876..b30d785fa 100644
--- a/demo/guide-python/sklearn_examples.py
+++ b/demo/guide-python/sklearn_examples.py
@@ -11,7 +11,7 @@ import xgboost as xgb
 import numpy as np
 from sklearn.cross_validation import KFold
 from sklearn.grid_search import GridSearchCV
-from sklearn.metrics import confusion_matrix
+from sklearn.metrics import confusion_matrix, mean_squared_error
 from sklearn.datasets import load_iris, load_digits, load_boston
 
 rng = np.random.RandomState(31337)
@@ -39,4 +39,26 @@ for train_index, test_index in kf:
     actuals = y[test_index]
     print(confusion_matrix(actuals, predictions))
 
+print("Boston Housing: regression")
+boston = load_boston()
+y = boston['target']
+X = boston['data']
+kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+for train_index, test_index in kf:
+    xgb_model = xgb.XGBRegressor().fit(X[train_index], y[train_index])
+    predictions = xgb_model.predict(X[test_index])
+    actuals = y[test_index]
+    print(mean_squared_error(actuals, predictions))
+
+print("Parameter optimization")
+y = boston['target']
+X = boston['data']
+xgb_model = xgb.XGBRegressor()
+clf = GridSearchCV(xgb_model,
+                   {'max_depth': [2, 4, 6],
+                    'n_estimators': [50, 100, 200]}, verbose=1)
+clf.fit(X, y)
+print(clf.best_score_)
+print(clf.best_params_)
+
 
diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py
index 96b027bec..ef841da14 100644
--- a/wrapper/xgboost.py
+++ b/wrapper/xgboost.py
@@ -16,6 +16,7 @@ import scipy.sparse
 try:
     from sklearn.base import BaseEstimator
+    from sklearn.base import RegressorMixin, ClassifierMixin
     from sklearn.preprocessing import LabelEncoder
     SKLEARN_INSTALLED = True
 except ImportError:
@@ -715,41 +716,33 @@ class XGBModel(XGBModelBase):
         trainDmatrix = DMatrix(X, label=y)
         self._Booster = train(self.get_xgb_params(), trainDmatrix, self.n_estimators)
         return self
+
+    def predict(self, X):
+        testDmatrix = DMatrix(X)
+        return self._Booster.predict(testDmatrix)
 
-class XGBClassifier(XGBModel):
+class XGBClassifier(XGBModel, ClassifierMixin):
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100, silent=True):
         super(XGBClassifier, self).__init__(max_depth, learning_rate, n_estimators,
                                             silent, objective="binary:logistic")
 
     def fit(self, X, y, sample_weight=None):
         y_values = list(np.unique(y))
-        if len(y_values) == 2:
-            # Map the two classes in the y vector into {0,1}, and record the
-            # mapping so that the predict() method can return results in the
-            # original range.
-            if not ((-1 in y_values and 1 in y_values)
-                    or (0 in y_values and 1 in y_values)
-                    or (True in y_values and False in y_values)):
-                raise ValueError("For a binary classifier, y must be in (0,1), or (-1,1), or (True,False).")
-            if -1 in y_values:
-                self._yspace = "svm_like"
-                training_labels = y.copy()
-                training_labels[training_labels == -1] = 0
-            elif False in y_values:
-                self._yspace = "boolean"
-                training_labels = np.array(y, dtype=int)
-            else:
-                self._yspace = "zero_one"
-                training_labels = y
-            xgb_options = self.get_xgb_params()
-        else:
+        if len(y_values) > 2:
             # Switch to using a multiclass objective in the underlying XGB instance
-            self._yspace = "multiclass"
             self.objective = "multi:softprob"
-            self._le = LabelEncoder().fit(y)
-            training_labels = self._le.transform(y)
             xgb_options = self.get_xgb_params()
             xgb_options['num_class'] = len(y_values)
+        else:
+            xgb_options = self.get_xgb_params()
+
+        # Encode labels to 0..K-1, whatever space y lives in, and keep the
+        # encoder so that predict() can map results back.
+        self._le = LabelEncoder().fit(y)
+        training_labels = self._le.transform(y)
+
         if sample_weight is not None:
             trainDmatrix = DMatrix(X, label=training_labels, weight=sample_weight)
         else:
             trainDmatrix = DMatrix(X, label=training_labels)
+
         self._Booster = train(xgb_options, trainDmatrix, self.n_estimators)
 
         return self
@@ -757,22 +750,12 @@ class XGBClassifier(XGBModel):
     def predict(self, X):
         testDmatrix = DMatrix(X)
         class_probs = self._Booster.predict(testDmatrix)
-        if self._yspace == "multiclass":
+        if len(class_probs.shape) > 1:
             column_indexes = np.argmax(class_probs, axis=1)
-            fitted_values = self._le.inverse_transform(column_indexes)
         else:
-            if self._yspace == "svm_like":
-                base_value = -1
-                one_value = 1
-            elif self._yspace == "boolean":
-                base_value = False
-                one_value = True
-            else:
-                base_value = 0
-                one_value = 1
-            fitted_values = np.repeat(base_value, X.shape[0])
-            fitted_values[class_probs > 0.5] = one_value
-        return fitted_values
+            column_indexes = np.repeat(0, X.shape[0])
+            column_indexes[class_probs > 0.5] = 1
+        return self._le.inverse_transform(column_indexes)
 
     def predict_proba(self, X):
         testDmatrix = DMatrix(X)
         class_probs = self._Booster.predict(testDmatrix)
@@ -784,6 +767,7 @@ class XGBClassifier(XGBModel):
             classzero_probs = 1.0 - classone_probs
             return np.vstack((classzero_probs, classone_probs)).transpose()
 
-
 
+class XGBRegressor(XGBModel, RegressorMixin):
+    pass
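
The second patch replaces the hand-rolled label bookkeeping with a LabelEncoder kept on the estimator: fit() encodes whatever labels the caller provides into 0..K-1 for the booster, and predict() maps thresholded (binary) or argmaxed (multiclass) column indexes back. A minimal sketch of that round trip, using only numpy and scikit-learn with made-up toy data, not part of the commits:

    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    y = np.array([-1, 1, 1, -1])                   # toy labels; any class values work
    le = LabelEncoder().fit(y)
    print(le.transform(y))                         # [0 1 1 0] -- what the DMatrix is trained on

    class_probs = np.array([0.2, 0.8, 0.6, 0.4])   # stand-in for Booster.predict() output
    column_indexes = np.repeat(0, class_probs.shape[0])
    column_indexes[class_probs > 0.5] = 1
    print(le.inverse_transform(column_indexes))    # [-1  1  1 -1], back in the caller's space

One wrinkle the diff appears to leave open: predict_proba() still branches on self._yspace, which fit() no longer sets after this patch, so the multiclass path there would seem to need the same shape-based test that predict() now uses.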