stratified cv for python wrapper

finalize docstring
This commit is contained in:
Faron 2016-02-14 11:00:41 +01:00
parent 9b2b81e6a4
commit 4b3a053913
3 changed files with 70 additions and 10 deletions

View File

@ -33,8 +33,11 @@ try:
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import LabelEncoder
from sklearn.cross_validation import KFold, StratifiedKFold
SKLEARN_INSTALLED = True SKLEARN_INSTALLED = True
XGBKFold = KFold
XGBStratifiedKFold = StratifiedKFold
XGBModelBase = BaseEstimator XGBModelBase = BaseEstimator
XGBRegressorBase = RegressorMixin XGBRegressorBase = RegressorMixin
XGBClassifierBase = ClassifierMixin XGBClassifierBase = ClassifierMixin

View File

@ -8,6 +8,7 @@ import sys
import re import re
import numpy as np import numpy as np
from .core import Booster, STRING_TYPES from .core import Booster, STRING_TYPES
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize=False, early_stopping_rounds=None, evals_result=None, maximize=False, early_stopping_rounds=None, evals_result=None,
@ -261,15 +262,26 @@ class CVPack(object):
return self.bst.eval_set(self.watchlist, iteration, feval) return self.bst.eval_set(self.watchlist, iteration, feval)
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None): def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, folds=None):
""" """
Make an n-fold list of CVPack from random indices. Make an n-fold list of CVPack from random indices.
""" """
evals = list(evals) evals = list(evals)
np.random.seed(seed) np.random.seed(seed)
randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold if stratified is False and folds is None:
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)] randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
elif folds is not None:
idset = [x[1] for x in folds]
nfold = len(idset)
else:
idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
n_folds=nfold,
shuffle=True,
random_state=seed)]
ret = [] ret = []
for k in range(nfold): for k in range(nfold):
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i])) dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
@ -345,8 +357,8 @@ def aggcv(rlist, show_stdv=True, show_progress=None, as_pandas=True, trial=0):
return results return results
def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(), def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
obj=None, feval=None, maximize=False, early_stopping_rounds=None, metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, show_progress=None, show_stdv=True, seed=0): fpreproc=None, as_pandas=True, show_progress=None, show_stdv=True, seed=0):
# pylint: disable = invalid-name # pylint: disable = invalid-name
"""Cross-validation with given paramaters. """Cross-validation with given paramaters.
@ -361,6 +373,10 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
Number of boosting iterations. Number of boosting iterations.
nfold : int nfold : int
Number of folds in CV. Number of folds in CV.
stratified : bool
Perform stratified sampling.
folds : KFold or StratifiedKFold
Sklearn KFolds or StratifiedKFolds.
metrics : string or list of strings metrics : string or list of strings
Evaluation metrics to be watched in CV. Evaluation metrics to be watched in CV.
obj : function obj : function
@ -381,9 +397,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
If False or pandas is not installed, return np.ndarray If False or pandas is not installed, return np.ndarray
show_progress : bool, int, or None, default None show_progress : bool, int, or None, default None
Whether to display the progress. If None, progress will be displayed Whether to display the progress. If None, progress will be displayed
when np.ndarray is returned. If True, progress will be displayed at when np.ndarray is returned. If True, progress will be displayed at
boosting stage. If an integer is given, progress will be displayed boosting stage. If an integer is given, progress will be displayed
at every given `show_progress` boosting stage. at every given `show_progress` boosting stage.
show_stdv : bool, default True show_stdv : bool, default True
Whether to display the standard deviation in progress. Whether to display the standard deviation in progress.
Results are not affected, and always contains std. Results are not affected, and always contains std.
@ -394,6 +410,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
------- -------
evaluation history : list(string) evaluation history : list(string)
""" """
if stratified == True and not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use stratified cv')
if isinstance(metrics, str): if isinstance(metrics, str):
metrics = [metrics] metrics = [metrics]
@ -436,7 +455,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
best_score_i = 0 best_score_i = 0
results = [] results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc) cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
for i in range(num_boost_round): for i in range(num_boost_round):
for fold in cvfolds: for fold in cvfolds:
fold.update(i, obj) fold.update(i, obj)
@ -466,3 +485,4 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
results = np.array(results) results = np.array(results)
return results return results

37
tests/python/test_cv.py Normal file
View File

@ -0,0 +1,37 @@
import xgboost as xgb
import numpy as np
from sklearn.datasets import load_digits
from sklearn.cross_validation import KFold, StratifiedKFold, train_test_split
from sklearn.metrics import mean_squared_error
import unittest
# Module-level NumPy random state with a fixed seed (1994).
# NOTE(review): nothing in the visible test references `rng` — confirm whether it is still needed.
rng = np.random.RandomState(1994)
class TestCrossValidation(unittest.TestCase):
    """Consistency checks for the fold-selection modes of ``xgb.cv``."""

    def test_cv(self):
        """Run ``xgb.cv`` with default, user-supplied and stratified folds.

        Three runs are made on the 3-class digits data with the same seed:

        * ``cv1`` -- default fold construction (``nfold`` only),
        * ``cv2`` -- an explicit ``StratifiedKFold`` passed via ``folds``,
        * ``cv3`` -- ``stratified=True`` with the same ``nfold`` and seed.

        All three must report the same number of rows, and the last-row,
        first-column value of ``cv2`` and ``cv3`` must be equal, since both
        are built from a ``StratifiedKFold`` with the same parameters.
        """
        # Keyword form keeps the call unambiguous and works where
        # ``n_class`` is keyword-only.
        digits = load_digits(n_class=3)
        X = digits['data']
        y = digits['target']
        dm = xgb.DMatrix(X, label=y)

        params = {
            'max_depth': 2,
            'eta': 1,
            'silent': 1,
            'objective': 'multi:softprob',
            'num_class': 3,
        }

        seed = 2016
        nfolds = 5
        skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)

        # Not referenced directly; presumably kept so the test fails early
        # when pandas is missing, because the checks below rely on the
        # DataFrame API (``.iloc``) of the ``xgb.cv`` result.
        import pandas as pd  # noqa: F401  pylint: disable=unused-variable

        cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
        cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
        cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)

        # unittest assertions are not stripped under ``python -O`` and
        # report both operands on failure, unlike bare ``assert``.
        self.assertEqual(cv1.shape[0], cv2.shape[0])
        self.assertEqual(cv2.shape[0], cv3.shape[0])
        self.assertEqual(cv2.iloc[-1, 0], cv3.iloc[-1, 0])