checkin all python

This commit is contained in:
tqchen 2015-07-30 22:08:48 -07:00
parent c2fec29bfa
commit 60217a2c02
4 changed files with 608 additions and 4 deletions

7
.gitignore vendored
View File

@ -48,10 +48,9 @@ Debug
*.cpage.col *.cpage.col
*.cpage *.cpage
*.Rproj *.Rproj
xgboost ./xgboost
xgboost.mpi ./xgboost.mpi
xgboost.mock ./xgboost.mock
train*
rabit rabit
#.Rbuildignore #.Rbuildignore
R-package.Rproj R-package.Rproj

View File

@ -0,0 +1,12 @@
# coding: utf-8
"""XGBoost: eXtreme Gradient Boosting library.
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
"""
from __future__ import absolute_import
from .core import DMatrix, Booster
from .training import train, cv
from .sklearn import XGBModel, XGBClassifier, XGBRegressor
__version__ = '0.4'

View File

@ -0,0 +1,341 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import
import numpy as np
from .core import Booster, DMatrix, XGBoostError
from .training import train
try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
SKLEARN_INSTALLED = True
except ImportError:
SKLEARN_INSTALLED = False
# used for compatiblity without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object
if SKLEARN_INSTALLED:
XGBModelBase = BaseEstimator
XGBRegressorBase = RegressorMixin
XGBClassifierBase = ClassifierMixin
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
"""Implementation of the Scikit-Learn API for XGBoost.
Parameters
----------
max_depth : int
Maximum tree depth for base learners.
learning_rate : float
Boosting learning rate (xgb's "eta")
n_estimators : int
Number of boosted trees to fit.
silent : boolean
Whether to print messages while running boosting.
objective : string
Specify the learning task and the corresponding learning objective.
nthread : int
Number of parallel threads used to run xgboost.
gamma : float
Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : int
Maximum delta step we allow each tree's weight estimation to be.
subsample : float
Subsample ratio of the training instance.
colsample_bytree : float
Subsample ratio of columns when constructing each tree.
base_score:
The initial prediction score of all instances, global bias.
seed : int
Random number seed.
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
"""
def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
silent=True, objective="reg:linear",
nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
subsample=1, colsample_bytree=1,
base_score=0.5, seed=0, missing=None):
if not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use this module')
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.silent = silent
self.objective = objective
self.nthread = nthread
self.gamma = gamma
self.min_child_weight = min_child_weight
self.max_delta_step = max_delta_step
self.subsample = subsample
self.colsample_bytree = colsample_bytree
self.base_score = base_score
self.seed = seed
self.missing = missing if missing is not None else np.nan
self._Booster = None
def __setstate__(self, state):
# backward compatiblity code
# load booster from raw if it is raw
# the booster now support pickle
bst = state["_Booster"]
if bst is not None and not isinstance(bst, Booster):
state["_Booster"] = Booster(model_file=bst)
self.__dict__.update(state)
def booster(self):
"""Get the underlying xgboost Booster of this model.
This will raise an exception when fit was not called
Returns
-------
booster : a xgboost booster of underlying model
"""
if self._Booster is None:
raise XGBoostError('need to call fit beforehand')
return self._Booster
def get_params(self, deep=False):
"""Get parameter.s"""
params = super(XGBModel, self).get_params(deep=deep)
if params['missing'] is np.nan:
params['missing'] = None # sklearn doesn't handle nan. see #4725
if not params.get('eval_metric', True):
del params['eval_metric'] # don't give as None param to Booster
return params
def get_xgb_params(self):
"""Get xgboost type parameters."""
xgb_params = self.get_params()
xgb_params['silent'] = 1 if self.silent else 0
if self.nthread <= 0:
xgb_params.pop('nthread', None)
return xgb_params
def fit(self, X, y, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
# pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
"""
Fit the gradient boosting model
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
eval_set : list, optional
A list of (X, y) tuple pairs to use as a validation set for
early-stopping
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.md. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
early_stopping_rounds : int
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will
have two additional fields: bst.best_score and bst.best_iteration.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
"""
trainDmatrix = DMatrix(X, label=y, missing=self.missing)
eval_results = {}
if eval_set is not None:
evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
evals = list(zip(evals, ["validation_{}".format(i) for i in
range(len(evals))]))
else:
evals = ()
params = self.get_xgb_params()
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
params.update({'eval_metric': eval_metric})
self._Booster = train(params, trainDmatrix,
self.n_estimators, evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=eval_results, feval=feval,
verbose_eval=verbose)
if eval_results:
eval_results = {k: np.array(v, dtype=float)
for k, v in eval_results.items()}
eval_results = {k: np.array(v) for k, v in eval_results.items()}
self.eval_results = eval_results
if early_stopping_rounds is not None:
self.best_score = self._Booster.best_score
self.best_iteration = self._Booster.best_iteration
return self
def predict(self, data):
# pylint: disable=missing-docstring,invalid-name
test_dmatrix = DMatrix(data, missing=self.missing)
return self.booster().predict(test_dmatrix)
class XGBClassifier(XGBModel, XGBClassifierBase):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
__doc__ = """
Implementation of the scikit-learn API for XGBoost classification
""" + "\n".join(XGBModel.__doc__.split('\n')[2:])
def __init__(self, max_depth=3, learning_rate=0.1,
n_estimators=100, silent=True,
objective="binary:logistic",
nthread=-1, gamma=0, min_child_weight=1,
max_delta_step=0, subsample=1, colsample_bytree=1,
base_score=0.5, seed=0, missing=None):
super(XGBClassifier, self).__init__(max_depth, learning_rate,
n_estimators, silent, objective,
nthread, gamma, min_child_weight,
max_delta_step, subsample,
colsample_bytree,
base_score, seed, missing)
def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
early_stopping_rounds=None, verbose=True):
# pylint: disable = attribute-defined-outside-init,arguments-differ
"""
Fit gradient boosting classifier
Parameters
----------
X : array_like
Feature matrix
y : array_like
Labels
sample_weight : array_like
Weight for each instance
eval_set : list, optional
A list of (X, y) pairs to use as a validation set for
early-stopping
eval_metric : str, callable, optional
If a str, should be a built-in evaluation metric to use. See
doc/parameter.md. If callable, a custom evaluation metric. The call
signature is func(y_predicted, y_true) where y_true will be a
DMatrix object such that you may need to call the get_label
method. It must return a str, value pair where the str is a name
for the evaluation and value is the value of the evaluation
function. This objective is always minimized.
early_stopping_rounds : int, optional
Activates early stopping. Validation error needs to decrease at
least every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals. If there's more than one,
will use the last. Returns the model from the last iteration
(not the best one). If early stopping occurs, the model will
have two additional fields: bst.best_score and bst.best_iteration.
verbose : bool
If `verbose` and an evaluation set is used, writes the evaluation
metric measured on the validation set to stderr.
"""
eval_results = {}
self.classes_ = list(np.unique(y))
self.n_classes_ = len(self.classes_)
if self.n_classes_ > 2:
# Switch to using a multiclass objective in the underlying XGB instance
self.objective = "multi:softprob"
xgb_options = self.get_xgb_params()
xgb_options['num_class'] = self.n_classes_
else:
xgb_options = self.get_xgb_params()
feval = eval_metric if callable(eval_metric) else None
if eval_metric is not None:
if callable(eval_metric):
eval_metric = None
else:
xgb_options.update({"eval_metric": eval_metric})
if eval_set is not None:
# TODO: use sample_weight if given?
evals = list(DMatrix(x[0], label=x[1]) for x in eval_set)
nevals = len(evals)
eval_names = ["validation_{}".format(i) for i in range(nevals)]
evals = list(zip(evals, eval_names))
else:
evals = ()
self._le = LabelEncoder().fit(y)
training_labels = self._le.transform(y)
if sample_weight is not None:
train_dmatrix = DMatrix(X, label=training_labels, weight=sample_weight,
missing=self.missing)
else:
train_dmatrix = DMatrix(X, label=training_labels,
missing=self.missing)
self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=eval_results, feval=feval,
verbose_eval=verbose)
if eval_results:
eval_results = {k: np.array(v, dtype=float)
for k, v in eval_results.items()}
self.eval_results = eval_results
if early_stopping_rounds is not None:
self.best_score = self._Booster.best_score
self.best_iteration = self._Booster.best_iteration
return self
def predict(self, data):
test_dmatrix = DMatrix(data, missing=self.missing)
class_probs = self.booster().predict(test_dmatrix)
if len(class_probs.shape) > 1:
column_indexes = np.argmax(class_probs, axis=1)
else:
column_indexes = np.repeat(0, data.shape[0])
column_indexes[class_probs > 0.5] = 1
return self._le.inverse_transform(column_indexes)
def predict_proba(self, data):
test_dmatrix = DMatrix(data, missing=self.missing)
class_probs = self.booster().predict(test_dmatrix)
if self.objective == "multi:softprob":
return class_probs
else:
classone_probs = class_probs
classzero_probs = 1.0 - classone_probs
return np.vstack((classzero_probs, classone_probs)).transpose()
class XGBRegressor(XGBModel, XGBRegressorBase):
# pylint: disable=missing-docstring
__doc__ = """
Implementation of the scikit-learn API for XGBoost regression
""" + "\n".join(XGBModel.__doc__.split('\n')[2:])

View File

@ -0,0 +1,252 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import
import sys
import re
import numpy as np
from .core import Booster, STRING_TYPES
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
early_stopping_rounds=None, evals_result=None, verbose_eval=True):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters.
Parameters
----------
params : dict
Booster params.
dtrain : DMatrix
Data to be trained.
num_boost_round: int
Number of boosting iterations.
watchlist (evals): list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch
performance on the validation set.
obj : function
Customized objective function.
feval : function
Customized evaluation function.
early_stopping_rounds: int
Activates early stopping. Validation error needs to decrease at least
every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have two additional fields:
bst.best_score and bst.best_iteration.
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist
verbose_eval : bool
If `verbose_eval` then the evaluation metric on the validation set, if
given, is printed at each boosting stage.
Returns
-------
booster : a trained booster model
"""
evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])
if evals_result is not None:
if not isinstance(evals_result, dict):
raise TypeError('evals_result has to be a dictionary')
else:
evals_name = [d[1] for d in evals]
evals_result.clear()
evals_result.update({key: [] for key in evals_name})
if not early_stopping_rounds:
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set
else:
msg = bst_eval_set.decode()
if verbose_eval:
sys.stderr.write(msg + '\n')
if evals_result is not None:
res = re.findall(":-?([0-9.]+).", msg)
for key, val in zip(evals_name, res):
evals_result[key].append(val)
return bst
else:
# early stopping
if len(evals) < 1:
raise ValueError('For early stopping you need at least one set in evals.')
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(\
evals[-1][1], early_stopping_rounds))
# is params a list of tuples? are we using multiple eval metrics?
if isinstance(params, list):
if len(params) != len(dict(params).items()):
raise ValueError('Check your params.'\
'Early stopping works with single eval metric only.')
params = dict(params)
# either minimize loss or maximize AUC/MAP/NDCG
maximize_score = False
if 'eval_metric' in params:
maximize_metrics = ('auc', 'map', 'ndcg')
if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
maximize_score = True
if maximize_score:
best_score = 0.0
else:
best_score = float('inf')
best_msg = ''
best_score_i = 0
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set
else:
msg = bst_eval_set.decode()
if verbose_eval:
sys.stderr.write(msg + '\n')
if evals_result is not None:
res = re.findall(":-([0-9.]+).", msg)
for key, val in zip(evals_name, res):
evals_result[key].append(val)
score = float(msg.rsplit(':', 1)[1])
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
best_score = score
best_score_i = i
best_msg = msg
elif i - best_score_i >= early_stopping_rounds:
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
bst.best_score = best_score
bst.best_iteration = best_score_i
break
bst.best_score = best_score
bst.best_iteration = best_score_i
return bst
class CVPack(object):
""""Auxiliary datastruct to hold one fold of CV."""
def __init__(self, dtrain, dtest, param):
""""Initialize the CVPack"""
self.dtrain = dtrain
self.dtest = dtest
self.watchlist = [(dtrain, 'train'), (dtest, 'test')]
self.bst = Booster(param, [dtrain, dtest])
def update(self, iteration, fobj):
""""Update the boosters for one iteration"""
self.bst.update(self.dtrain, iteration, fobj)
def eval(self, iteration, feval):
""""Evaluate the CVPack for one iteration."""
return self.bst.eval_set(self.watchlist, iteration, feval)
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
"""
Make an n-fold list of CVPack from random indices.
"""
evals = list(evals)
np.random.seed(seed)
randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
ret = []
for k in range(nfold):
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
dtest = dall.slice(idset[k])
# run preprocessing on the data set if needed
if fpreproc is not None:
dtrain, dtest, tparam = fpreproc(dtrain, dtest, param.copy())
else:
tparam = param
plst = list(tparam.items()) + [('eval_metric', itm) for itm in evals]
ret.append(CVPack(dtrain, dtest, plst))
return ret
def aggcv(rlist, show_stdv=True):
# pylint: disable=invalid-name
"""
Aggregate cross-validation results.
"""
cvmap = {}
ret = rlist[0].split()[0]
for line in rlist:
arr = line.split()
assert ret == arr[0]
for it in arr[1:]:
if not isinstance(it, STRING_TYPES):
it = it.decode()
k, v = it.split(':')
if k not in cvmap:
cvmap[k] = []
cvmap[k].append(float(v))
for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
v = np.array(v)
if not isinstance(ret, STRING_TYPES):
ret = ret.decode()
if show_stdv:
ret += '\tcv-%s:%f+%f' % (k, np.mean(v), np.std(v))
else:
ret += '\tcv-%s:%f' % (k, np.mean(v))
return ret
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
obj=None, feval=None, fpreproc=None, show_stdv=True, seed=0):
# pylint: disable = invalid-name
"""Cross-validation with given paramaters.
Parameters
----------
params : dict
Booster params.
dtrain : DMatrix
Data to be trained.
num_boost_round : int
Number of boosting iterations.
nfold : int
Number of folds in CV.
metrics : list of strings
Evaluation metrics to be watched in CV.
obj : function
Custom objective function.
feval : function
Custom evaluation function.
fpreproc : function
Preprocessing function that takes (dtrain, dtest, param) and returns
transformed versions of those.
show_stdv : bool
Whether to display the standard deviation.
seed : int
Seed used to generate the folds (passed to numpy.random.seed).
Returns
-------
evaluation history : list(string)
"""
results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
for i in range(num_boost_round):
for fold in cvfolds:
fold.update(i, obj)
res = aggcv([f.eval(i, feval) for f in cvfolds], show_stdv)
sys.stderr.write(res + '\n')
results.append(res)
return results