Fix mknfold to use the new StratifiedKFold API (#1660)
This commit is contained in:
parent
b56c6097d9
commit
63829d656c
@ -46,7 +46,7 @@ except ImportError:
|
|||||||
try:
|
try:
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
from sklearn.base import RegressorMixin, ClassifierMixin
|
from sklearn.base import RegressorMixin, ClassifierMixin
|
||||||
from sklearn.preprocessing import LabelEncoder # noqa
|
from sklearn.preprocessing import LabelEncoder
|
||||||
try:
|
try:
|
||||||
from sklearn.model_selection import KFold, StratifiedKFold
|
from sklearn.model_selection import KFold, StratifiedKFold
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@ -232,14 +232,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
|||||||
randidx = np.random.permutation(dall.num_row())
|
randidx = np.random.permutation(dall.num_row())
|
||||||
kstep = int(len(randidx) / nfold)
|
kstep = int(len(randidx) / nfold)
|
||||||
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
|
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
|
||||||
elif folds is not None:
|
elif folds is not None and isinstance(folds, list):
|
||||||
idset = [x[1] for x in folds]
|
idset = [x[1] for x in folds]
|
||||||
nfold = len(idset)
|
nfold = len(idset)
|
||||||
else:
|
else:
|
||||||
idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
|
sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
|
||||||
n_folds=nfold,
|
idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]
|
||||||
shuffle=True,
|
|
||||||
random_state=seed)]
|
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
for k in range(nfold):
|
for k in range(nfold):
|
||||||
|
|||||||
@ -251,7 +251,7 @@ def test_sklearn_plotting():
|
|||||||
def test_sklearn_nfolds_cv():
|
def test_sklearn_nfolds_cv():
|
||||||
tm._skip_if_no_sklearn()
|
tm._skip_if_no_sklearn()
|
||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
from sklearn.cross_validation import StratifiedKFold
|
from sklearn.model_selection import StratifiedKFold
|
||||||
|
|
||||||
digits = load_digits(3)
|
digits = load_digits(3)
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
@ -269,10 +269,10 @@ def test_sklearn_nfolds_cv():
|
|||||||
|
|
||||||
seed = 2016
|
seed = 2016
|
||||||
nfolds = 5
|
nfolds = 5
|
||||||
skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)
|
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)
|
||||||
|
|
||||||
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
|
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
|
||||||
cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
|
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, folds=skf, seed=seed)
|
||||||
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
|
||||||
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
|
||||||
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user