From 301540f1d931fdf5f5c4e859700ba0bb5757cf2d Mon Sep 17 00:00:00 2001 From: Icyblade Dai Date: Fri, 17 Mar 2017 21:38:22 +0800 Subject: [PATCH] fix DeprecationWarning on sklearn.cross_validation (#2075) * fix DeprecationWarning on sklearn.cross_validation * fix syntax * fix kfold n_split issue * fix mistype * fix n_splits multiple value issue * split should pass a iterable * use np.arange instead of xrange, py3 compatibility --- demo/guide-python/sklearn_examples.py | 5 ++++- tests/python/test_early_stopping.py | 5 ++++- tests/python/test_eval_metrics.py | 5 ++++- tests/python/test_fast_hist.py | 5 ++++- tests/python/test_with_sklearn.py | 17 ++++++++++++++--- 5 files changed, 30 insertions(+), 7 deletions(-) diff --git a/demo/guide-python/sklearn_examples.py b/demo/guide-python/sklearn_examples.py index 7ce95b491..a5a91ffa1 100755 --- a/demo/guide-python/sklearn_examples.py +++ b/demo/guide-python/sklearn_examples.py @@ -8,7 +8,10 @@ import pickle import xgboost as xgb import numpy as np -from sklearn.cross_validation import KFold, train_test_split +try: + from sklearn.model_selection import KFold, train_test_split +except: + from sklearn.cross_validation import KFold, train_test_split from sklearn.metrics import confusion_matrix, mean_squared_error from sklearn.grid_search import GridSearchCV from sklearn.datasets import load_iris, load_digits, load_boston diff --git a/tests/python/test_early_stopping.py b/tests/python/test_early_stopping.py index 67e725b74..7553aed66 100644 --- a/tests/python/test_early_stopping.py +++ b/tests/python/test_early_stopping.py @@ -11,7 +11,10 @@ class TestEarlyStopping(unittest.TestCase): def test_early_stopping_nonparallel(self): tm._skip_if_no_sklearn() from sklearn.datasets import load_digits - from sklearn.cross_validation import train_test_split + try: + from sklearn.model_selection import train_test_split + except: + from sklearn.cross_validation import train_test_split digits = load_digits(2) X = digits['data'] diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 529ef698c..611b7e8fe 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -57,7 +57,10 @@ class TestEvalMetrics(unittest.TestCase): def test_eval_metrics(self): tm._skip_if_no_sklearn() - from sklearn.cross_validation import train_test_split + try: + from sklearn.model_selection import train_test_split + except: + from sklearn.cross_validation import train_test_split from sklearn.datasets import load_digits digits = load_digits(2) diff --git a/tests/python/test_fast_hist.py b/tests/python/test_fast_hist.py index a79f402a6..791c32226 100644 --- a/tests/python/test_fast_hist.py +++ b/tests/python/test_fast_hist.py @@ -10,7 +10,10 @@ class TestFastHist(unittest.TestCase): def test_fast_hist(self): tm._skip_if_no_sklearn() from sklearn.datasets import load_digits - from sklearn.cross_validation import train_test_split + try: + from sklearn.model_selection import train_test_split + except: + from sklearn.cross_validation import train_test_split digits = load_digits(2) X = digits['data'] diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 19de5abb9..12726b002 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -9,12 +9,20 @@ rng = np.random.RandomState(1994) def test_binary_classification(): tm._skip_if_no_sklearn() from sklearn.datasets import load_digits - from sklearn.cross_validation import KFold + try: + from sklearn.model_selection import KFold + except: + from sklearn.cross_validation import KFold digits = load_digits(2) y = digits['target'] X = digits['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) + try: + kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) + except TypeError: # sklearn.model_selection.KFold uses n_split + kf = KFold( + n_splits=2, shuffle=True, random_state=rng + ).split(np.arange(y.shape[0])) for train_index, test_index in kf: xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index]) preds = xgb_model.predict(X[test_index]) @@ -27,7 +35,10 @@ def test_binary_classification(): def test_multiclass_classification(): tm._skip_if_no_sklearn() from sklearn.datasets import load_iris - from sklearn.cross_validation import KFold + try: + from sklearn.cross_validation import KFold + except: + from sklearn.model_selection import KFold def check_pred(preds, labels): err = sum(1 for i in range(len(preds))