Refactor Python tests. (#3897)

* Deprecate nose tests.
* Format python tests.
Jiaming Yuan 2018-11-15 13:56:33 +13:00 committed by GitHub
parent c76d993681
commit 2ea0f887c1
23 changed files with 302 additions and 225 deletions
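
The same pattern repeats throughout the diff below: nose's @attr(...) decorators and --eval-attr selection become pytest marks and -m expressions, and the imperative tm._skip_if_no_*() helpers become declarative @pytest.mark.skipif(...) decorators. A minimal before/after sketch of that pattern (the test name and body are illustrative, not taken from this diff):

# before: nose attribute for selection, imperative skip inside the test
from nose.plugins.attrib import attr
import testing as tm

@attr('mgpu')
def test_example():
    tm._skip_if_no_sklearn()  # raises nose.SkipTest if sklearn is missing
    pass

# after: pytest marks; the skip condition is declared on the decorator
import pytest
import testing as tm

@pytest.mark.mgpu
@pytest.mark.skipif(**tm.no_sklearn())
def test_example():
    pass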

Jenkinsfile
View File

@@ -96,11 +96,11 @@ def buildPlatformCmake(buildName, conf, nodeReq, dockerTarget) {
# Test the wheel for compatibility on a barebones CPU container
${dockerRun} release ${dockerArgs} bash -c " \
pip install --user python-package/dist/xgboost-*-none-any.whl && \
python -m nose -v tests/python"
pytest -v --fulltrace -s tests/python"
# Test the wheel for compatibility on CUDA 10.0 container
${dockerRun} gpu --build-arg CUDA_VERSION=10.0 bash -c " \
pip install --user python-package/dist/xgboost-*-none-any.whl && \
python -m nose -v --eval-attr='(not slow) and (not mgpu)' tests/python-gpu"
pytest -v -s --fulltrace -m '(not mgpu) and (not slow)' tests/python-gpu"
"""
}
}
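
Here nose's --eval-attr='(not slow) and (not mgpu)' selection is replaced by pytest's -m marker expression; the attributes it evaluated now live on the tests as @pytest.mark.gpu, @pytest.mark.mgpu and @pytest.mark.slow. Recent pytest versions warn on (and with --strict-markers reject) unregistered marks; this commit does not register them, so the following conftest.py is purely a hypothetical sketch of how that could be done:

# hypothetical conftest.py -- NOT part of this commit
def pytest_configure(config):
    # document the custom marks used by the -m expressions above
    config.addinivalue_line("markers", "gpu: test requires a GPU")
    config.addinivalue_line("markers", "mgpu: test requires multiple GPUs")
    config.addinivalue_line("markers", "slow: test takes a long time to run")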

View File

@@ -44,7 +44,7 @@ install:
- set DO_PYTHON=off
- if /i "%target%" == "mingw" set DO_PYTHON=on
- if /i "%target%_%ver%_%configuration%" == "msvc_2015_Release" set DO_PYTHON=on
- if /i "%DO_PYTHON%" == "on" conda install -y numpy scipy pandas matplotlib nose scikit-learn graphviz python-graphviz
- if /i "%DO_PYTHON%" == "on" conda install -y numpy scipy pandas matplotlib pytest scikit-learn graphviz python-graphviz
# R: based on https://github.com/krlmlr/r-appveyor
- ps: |
if($env:target -eq 'rmingw' -or $env:target -eq 'rmsvc') {
@@ -96,7 +96,7 @@ build_script:
test_script:
- cd %APPVEYOR_BUILD_FOLDER%
- if /i "%DO_PYTHON%" == "on" python -m nose tests/python
- if /i "%DO_PYTHON%" == "on" python -m pytest tests/python
# mingw R package: run the R check (which includes unit tests), and also keep the built binary package
- if /i "%target%" == "rmingw" (
set _R_CHECK_CRAN_INCOMING_=FALSE&&

View File

@@ -37,7 +37,7 @@ ENV CPP=/opt/rh/devtoolset-2/root/usr/bin/cpp
# Install Python packages
RUN \
pip install numpy nose scipy scikit-learn wheel
pip install numpy pytest scipy scikit-learn wheel
ENV GOSU_VERSION 1.10

View File

@@ -15,8 +15,8 @@ ENV PATH=/opt/python/bin:$PATH
# Install Python packages
RUN \
conda install numpy scipy pandas matplotlib nose scikit-learn && \
pip install nose wheel auditwheel graphviz
conda install numpy scipy pandas matplotlib pytest scikit-learn && \
pip install pytest wheel auditwheel graphviz
ENV GOSU_VERSION 1.10

View File

@@ -4,6 +4,5 @@ set -e
cd python-package
python setup.py install --user
cd ..
python -m nose -v --eval-attr='(not slow) and (not mgpu)' tests/python-gpu/
pytest -v -s --fulltrace -m "(not mgpu) and (not slow)" tests/python-gpu
./testxgboost --gtest_filter=-*.MGPU_*

View File

@@ -4,5 +4,5 @@ set -e
cd python-package
python setup.py install --user
cd ..
python -m nose -v --eval-attr='(not slow) and mgpu' tests/python-gpu/
pytest -v -s --fulltrace -m "(not slow) and mgpu" tests/python-gpu
./testxgboost --gtest_filter=*.MGPU_*

View File

@@ -1,18 +1,19 @@
import sys
import pytest
import unittest
sys.path.append('tests/python/')
import test_linear
import testing as tm
import unittest
class TestGPULinear(unittest.TestCase):
datasets = ["Boston", "Digits", "Cancer", "Sparse regression",
"Boston External Memory"]
@pytest.mark.skipif(**tm.no_sklearn())
def test_gpu_coordinate(self):
tm._skip_if_no_sklearn()
variable_param = {
'booster': ['gblinear'],
'updater': ['coord_descent'],

View File

@@ -1,15 +1,14 @@
from __future__ import print_function
import numpy as np
import sys
import unittest
import xgboost as xgb
from nose.plugins.attrib import attr
import pytest
rng = np.random.RandomState(1994)
@attr('gpu')
@pytest.mark.gpu
class TestGPUPredict(unittest.TestCase):
def test_predict(self):
iterations = 10
@@ -18,9 +17,12 @@ class TestGPUPredict(unittest.TestCase):
test_num_cols = [10, 50, 500]
for num_rows in test_num_rows:
for num_cols in test_num_cols:
dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
dval = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols), label=[0, 1] * int(num_rows / 2))
dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
dval = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
watchlist = [(dtrain, 'train'), (dval, 'validation')]
res = {}
param = {
@@ -28,7 +30,8 @@ class TestGPUPredict(unittest.TestCase):
"predictor": "gpu_predictor",
'eval_metric': 'auc',
}
bst = xgb.train(param, dtrain, iterations, evals=watchlist, evals_result=res)
bst = xgb.train(param, dtrain, iterations, evals=watchlist,
evals_result=res)
assert self.non_decreasing(res["train"]["auc"])
gpu_pred_train = bst.predict(dtrain, output_margin=True)
gpu_pred_test = bst.predict(dtest, output_margin=True)
@@ -39,21 +42,26 @@ class TestGPUPredict(unittest.TestCase):
cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
cpu_pred_val = bst_cpu.predict(dval, output_margin=True)
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-5)
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-5)
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-5)
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train,
rtol=1e-5)
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val,
rtol=1e-5)
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test,
rtol=1e-5)
def non_decreasing(self, L):
return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
# Test case for a bug where multiple batch predictions made on a test set produce incorrect results
# Test case for a bug where multiple batch predictions made on a
# test set produce incorrect results
def test_multi_predict(self):
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
n = 1000
X, y = make_regression(n, random_state=rng)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
X_train, X_test, y_train, y_test = train_test_split(X, y,
random_state=123)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)
@@ -85,8 +93,7 @@ class TestGPUPredict(unittest.TestCase):
params = {'tree_method': 'gpu_hist',
'predictor': 'cpu_predictor',
'n_jobs': -1,
'seed': 123
}
'seed': 123}
m = xgb.XGBRegressor(**params).fit(X_train, y_train)
cpu_train_score = m.score(X_train, y_train)
cpu_test_score = m.score(X_test, y_test)

View File

@@ -1,10 +1,9 @@
import numpy as np
import sys
import unittest
from nose.plugins.attrib import attr
import pytest
sys.path.append("tests/python")
import xgboost as xgb
from regression_test_utilities import run_suite, parameter_combinations, \
assert_results_non_increasing
@@ -45,7 +44,7 @@ class TestGPU(unittest.TestCase):
cpu_results = run_suite(param, select_datasets=datasets)
assert_gpu_results(cpu_results, gpu_results)
@attr('mgpu')
@pytest.mark.mgpu
def test_gpu_hist_mgpu(self):
variable_param = {'n_gpus': [-1], 'max_depth': [2, 10],
'max_leaves': [255, 4],
@@ -56,7 +55,7 @@ class TestGPU(unittest.TestCase):
gpu_results = run_suite(param, select_datasets=datasets)
assert_results_non_increasing(gpu_results, 1e-2)
@attr('mgpu')
@pytest.mark.mgpu
def test_specified_gpu_id_gpu_update(self):
variable_param = {'n_gpus': [1],
'gpu_id': [1],

View File

@@ -2,12 +2,12 @@ from __future__ import print_function
import sys
import time
import pytest
sys.path.append("../../tests/python")
import xgboost as xgb
import numpy as np
import unittest
from nose.plugins.attrib import attr
def eprint(*args, **kwargs):
@@ -16,9 +16,11 @@ def eprint(*args, **kwargs):
print(*args, file=sys.stdout, **kwargs)
sys.stdout.flush()
rng = np.random.RandomState(1994)
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/ , which has been processed to one-hot encode categoricalsxsy
# "realistic" size based upon http://stat-computing.org/dataexpo/2009/
# , which has been processed to one-hot encode categoricalsxsy
cols = 31
# reduced to fit onto 1 gpu but still be large
rows3 = 5000 # small
@@ -28,7 +30,7 @@ rows1 = 42360032 # large
rowslist = [rows1, rows2, rows3]
@attr('slow')
@pytest.mark.slow
class TestGPU(unittest.TestCase):
def test_large(self):
for rows in rowslist:
@@ -47,15 +49,8 @@ class TestGPU(unittest.TestCase):
max_depth = 6
max_bin = 1024
# regression test --- hist must be same as exact on all-categorial data
ag_param = {'max_depth': max_depth,
'tree_method': 'exact',
'nthread': 0,
'eta': 1,
'silent': 0,
'debug_verbose': 5,
'objective': 'binary:logistic',
'eval_metric': 'auc'}
# regression test --- hist must be same as exact on
# all-categorial data
ag_paramb = {'max_depth': max_depth,
'tree_method': 'hist',
'nthread': 0,

View File

@@ -1,11 +1,13 @@
from __future__ import print_function
import numpy as np
import unittest
import xgboost as xgb
from nose.plugins.attrib import attr
from sklearn.datasets import make_regression
import unittest
import pytest
import xgboost as xgb
rng = np.random.RandomState(1994)
@@ -33,7 +35,7 @@ def assert_constraint(constraint, tree_method):
assert non_increasing(pred)
@attr('gpu')
@pytest.mark.gpu
class TestMonotonicConstraints(unittest.TestCase):
def test_exact(self):
assert_constraint(1, 'exact')

View File

@@ -21,6 +21,8 @@ def captured_output():
"""
Reassign stdout temporarily in order to test printed statements
Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
Also works for pytest.
"""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
@@ -36,7 +38,8 @@ class TestBasic(unittest.TestCase):
def test_basic(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
param = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
@@ -44,7 +47,8 @@ class TestBasic(unittest.TestCase):
# this is prediction
preds = bst.predict(dtest)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
# error must be smaller than 10%
assert err < 0.1
@@ -62,7 +66,8 @@ class TestBasic(unittest.TestCase):
def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
param = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
@@ -86,7 +91,8 @@ class TestBasic(unittest.TestCase):
# this is prediction
preds = bst.predict(dtest)
labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) if preds[i] != labels[i]) / float(len(preds))
err = sum(1 for i in range(len(preds))
if preds[i] != labels[i]) / float(len(preds))
# error must be smaller than 10%
assert err < 0.1
@@ -248,7 +254,8 @@ class TestBasic(unittest.TestCase):
def test_cv(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
params = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
@@ -257,16 +264,19 @@ class TestBasic(unittest.TestCase):
def test_cv_no_shuffle(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
params = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10,
as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)
def test_cv_explicit_fold_indices(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective':
'binary:logistic'}
folds = [
# Train Test
([1, 3], [5, 8]),
@@ -274,12 +284,14 @@ class TestBasic(unittest.TestCase):
]
# return np.ndarray
cv = xgb.cv(params, dm, num_boost_round=10, folds=folds, as_pandas=False)
cv = xgb.cv(params, dm, num_boost_round=10, folds=folds,
as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)
def test_cv_explicit_fold_indices_labels(self):
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'reg:linear'}
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective':
'reg:linear'}
N = 100
F = 3
dm = xgb.DMatrix(data=np.random.randn(N, F), label=np.arange(N))
@@ -300,7 +312,9 @@ class TestBasic(unittest.TestCase):
as_pandas=False
)
output = out.getvalue().strip()
assert output == '[array([5., 8.], dtype=float32), array([23., 43., 11.], dtype=float32)]'
solution = ('[array([5., 8.], dtype=float32), array([23., 43., 11.],' +
' dtype=float32)]')
assert output == solution
def test_get_info(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import unittest
import pytest
import testing as tm
import xgboost as xgb
@@ -10,14 +11,16 @@ try:
except ImportError:
pass
tm._skip_if_no_dt()
tm._skip_if_no_pandas()
pytestmark = pytest.mark.skipif(
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
class TestDataTable(unittest.TestCase):
def test_dt(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
labels = dt.Frame([1, 2])
dm = xgb.DMatrix(dtable, label=labels)
@@ -34,7 +37,8 @@ class TestDataTable(unittest.TestCase):
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
columns=['a', 'b', 'c'])
dtable = dt.Frame(df)
self.assertRaises(ValueError, xgb.DMatrix, dtable)

View File

@@ -2,24 +2,26 @@ import xgboost as xgb
import testing as tm
import numpy as np
import unittest
import pytest
rng = np.random.RandomState(1994)
class TestEarlyStopping(unittest.TestCase):
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping_nonparallel(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
try:
from sklearn.model_selection import train_test_split
except:
except ImportError:
from sklearn.cross_validation import train_test_split
digits = load_digits(2)
X = digits['data']
y = digits['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y,
random_state=0)
clf1 = xgb.XGBClassifier()
clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
eval_set=[(X_test, y_test)])
@@ -35,36 +37,41 @@ class TestEarlyStopping(unittest.TestCase):
eval_set=[(X_test, y_test)])
assert clf3.best_score == 1
@pytest.mark.skipif(**tm.no_sklearn())
def evalerror(self, preds, dtrain):
tm._skip_if_no_sklearn()
from sklearn.metrics import mean_squared_error
labels = dtrain.get_label()
return 'rmse', mean_squared_error(labels, preds)
@pytest.mark.skipif(**tm.no_sklearn())
def test_cv_early_stopping(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
digits = load_digits(2)
X = digits['data']
y = digits['target']
dm = xgb.DMatrix(X, label=y)
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
params = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, early_stopping_rounds=10)
assert cv.shape[0] == 10
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, early_stopping_rounds=5)
assert cv.shape[0] == 3
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, early_stopping_rounds=1)
assert cv.shape[0] == 1
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, feval=self.evalerror,
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
early_stopping_rounds=10)
assert cv.shape[0] == 10
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, feval=self.evalerror,
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
early_stopping_rounds=5)
assert cv.shape[0] == 3
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
early_stopping_rounds=1)
assert cv.shape[0] == 1
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
feval=self.evalerror, early_stopping_rounds=10)
assert cv.shape[0] == 10
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
feval=self.evalerror, early_stopping_rounds=1)
assert cv.shape[0] == 5
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10,
feval=self.evalerror, maximize=True,
early_stopping_rounds=1)
assert cv.shape[0] == 5
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, feval=self.evalerror,
maximize=True, early_stopping_rounds=1)
assert cv.shape[0] == 1

View File

@@ -2,6 +2,7 @@ import xgboost as xgb
import testing as tm
import numpy as np
import unittest
import pytest
rng = np.random.RandomState(1337)
@@ -39,27 +40,27 @@ class TestEvalMetrics(unittest.TestCase):
labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels))]
@pytest.mark.skipif(**tm.no_sklearn())
def evalerror_03(self, preds, dtrain):
tm._skip_if_no_sklearn()
from sklearn.metrics import mean_squared_error
labels = dtrain.get_label()
return [('rmse', mean_squared_error(labels, preds)),
('error', float(sum(labels != (preds > 0.0))) / len(labels))]
@pytest.mark.skipif(**tm.no_sklearn())
def evalerror_04(self, preds, dtrain):
tm._skip_if_no_sklearn()
from sklearn.metrics import mean_squared_error
labels = dtrain.get_label()
return [('error', float(sum(labels != (preds > 0.0))) / len(labels)),
('rmse', mean_squared_error(labels, preds))]
@pytest.mark.skipif(**tm.no_sklearn())
def test_eval_metrics(self):
tm._skip_if_no_sklearn()
try:
from sklearn.model_selection import train_test_split
except:
except ImportError:
from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_digits

View File

@@ -3,6 +3,8 @@ from __future__ import print_function
import numpy as np
import testing as tm
import unittest
import pytest
import xgboost as xgb
try:
@@ -22,13 +24,16 @@ def is_float(s):
def xgb_get_weights(bst):
return np.array([float(s) for s in bst.get_dump()[0].split() if is_float(s)])
return np.array([float(s) for s in bst.get_dump()[0].split() if
is_float(s)])
def assert_regression_result(results, tol):
regression_results = [r for r in results if r["param"]["objective"] == "reg:linear"]
regression_results = [r for r in results if
r["param"]["objective"] == "reg:linear"]
for res in regression_results:
X = scale(res["dataset"].X, with_mean=isinstance(res["dataset"].X, np.ndarray))
X = scale(res["dataset"].X,
with_mean=isinstance(res["dataset"].X, np.ndarray))
y = res["dataset"].y
reg_alpha = res["param"]["alpha"]
reg_lambda = res["param"]["lambda"]
@@ -38,14 +43,16 @@ def assert_regression_result(results, tol):
l1_ratio=reg_alpha / (reg_alpha + reg_lambda))
enet.fit(X, y)
enet_pred = enet.predict(X)
assert np.isclose(weights, enet.coef_, rtol=tol, atol=tol).all(), (weights, enet.coef_)
assert np.isclose(weights, enet.coef_, rtol=tol,
atol=tol).all(), (weights, enet.coef_)
assert np.isclose(enet_pred, pred, rtol=tol, atol=tol).all(), (
res["dataset"].name, enet_pred[:5], pred[:5])
# TODO: More robust classification tests
def assert_classification_result(results):
classification_results = [r for r in results if r["param"]["objective"] != "reg:linear"]
classification_results = [r for r in results if
r["param"]["objective"] != "reg:linear"]
for res in classification_results:
# Check accuracy is reasonable
assert res["eval"][-1] < 0.5, (res["dataset"].name, res["eval"][-1])
@@ -56,25 +63,26 @@ class TestLinear(unittest.TestCase):
datasets = ["Boston", "Digits", "Cancer", "Sparse regression",
"Boston External Memory"]
@pytest.mark.skipif(**tm.no_sklearn())
def test_coordinate(self):
tm._skip_if_no_sklearn()
variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5],
'top_k': [10], 'tolerance': [1e-5], 'nthread': [2],
variable_param = {'booster': ['gblinear'], 'updater':
['coord_descent'], 'eta': [0.5], 'top_k':
[10], 'tolerance': [1e-5], 'nthread': [2],
'alpha': [.005, .1], 'lambda': [.005],
'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']
}
'feature_selector': ['cyclic', 'shuffle',
'greedy', 'thrifty']}
for param in parameter_combinations(variable_param):
results = run_suite(param, 200, self.datasets, scale_features=True)
assert_regression_result(results, 1e-2)
assert_classification_result(results)
@pytest.mark.skipif(**tm.no_sklearn())
def test_shotgun(self):
tm._skip_if_no_sklearn()
variable_param = {'booster': ['gblinear'], 'updater': ['shotgun'], 'eta': [0.5],
'top_k': [10], 'tolerance': [1e-5], 'nthread': [2],
variable_param = {'booster': ['gblinear'], 'updater':
['shotgun'], 'eta': [0.5], 'top_k': [10],
'tolerance': [1e-5], 'nthread': [2],
'alpha': [.005, .1], 'lambda': [.005],
'feature_selector': ['cyclic', 'shuffle']
}
'feature_selector': ['cyclic', 'shuffle']}
for param in parameter_combinations(variable_param):
results = run_suite(param, 200, self.datasets, True)
assert_regression_result(results, 1e-2)

View File

@@ -2,7 +2,9 @@
import numpy as np
import xgboost as xgb
import testing as tm
import unittest
import pytest
try:
import matplotlib
@@ -13,7 +15,7 @@ except ImportError:
pass
tm._skip_if_no_matplotlib()
pytestmark = pytest.mark.skipif(**tm.no_matplotlib())
dpath = 'demo/data/'

View File

@@ -2,6 +2,7 @@ import xgboost as xgb
import testing as tm
import numpy as np
import unittest
import pytest
rng = np.random.RandomState(1337)
@@ -27,8 +28,8 @@ class TestTrainingContinuation(unittest.TestCase):
'num_parallel_tree': num_parallel_tree
}
@pytest.mark.skipif(**tm.no_sklearn())
def test_training_continuation(self):
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.metrics import mean_squared_error
@@ -44,15 +45,19 @@ class TestTrainingContinuation(unittest.TestCase):
dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10)
gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=10)
ntrees_01 = len(gbdt_01.get_dump())
assert ntrees_01 == 10
gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=0)
gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=0)
gbdt_02.save_model('xgb_tc.model')
gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model=gbdt_02)
gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model="xgb_tc.model")
gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=10, xgb_model=gbdt_02)
gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=10, xgb_model="xgb_tc.model")
ntrees_02a = len(gbdt_02a.get_dump())
ntrees_02b = len(gbdt_02b.get_dump())
assert ntrees_02a == 10
@@ -66,11 +71,14 @@ class TestTrainingContinuation(unittest.TestCase):
res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
assert res1 == res2
gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=3)
gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=3)
gbdt_03.save_model('xgb_tc.model')
gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model=gbdt_03)
gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model="xgb_tc.model")
gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=7, xgb_model=gbdt_03)
gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class,
num_boost_round=7, xgb_model="xgb_tc.model")
ntrees_03a = len(gbdt_03a.get_dump())
ntrees_03b = len(gbdt_03b.get_dump())
assert ntrees_03a == 10
@@ -80,25 +88,42 @@ class TestTrainingContinuation(unittest.TestCase):
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
assert res1 == res2
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=3)
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class,
num_boost_round=3)
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration +
1) * self.num_parallel_tree
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
res2 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
res2 = mean_squared_error(y_2class,
gbdt_04.predict(
dtrain_2class,
ntree_limit=gbdt_04.best_ntree_limit))
assert res1 == res2
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04)
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class,
num_boost_round=7, xgb_model=gbdt_04)
assert gbdt_04.best_ntree_limit == (
gbdt_04.best_iteration + 1) * self.num_parallel_tree
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
res2 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
res2 = mean_squared_error(y_2class,
gbdt_04.predict(
dtrain_2class,
ntree_limit=gbdt_04.best_ntree_limit))
assert res1 == res2
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=7)
assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05)
assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class,
num_boost_round=7)
assert gbdt_05.best_ntree_limit == (
gbdt_05.best_iteration + 1) * self.num_parallel_tree
gbdt_05 = xgb.train(self.xgb_params_03,
dtrain_5class,
num_boost_round=3,
xgb_model=gbdt_05)
assert gbdt_05.best_ntree_limit == (
gbdt_05.best_iteration + 1) * self.num_parallel_tree
res1 = gbdt_05.predict(dtrain_5class)
res2 = gbdt_05.predict(dtrain_5class, ntree_limit=gbdt_05.best_ntree_limit)
res2 = gbdt_05.predict(dtrain_5class,
ntree_limit=gbdt_05.best_ntree_limit)
np.testing.assert_almost_equal(res1, res2)

View File

@@ -1,5 +1,6 @@
import testing as tm
import unittest
import pytest
import xgboost as xgb
try:
@@ -10,24 +11,27 @@ except ImportError:
class TestUpdaters(unittest.TestCase):
@pytest.mark.skipif(**tm.no_sklearn())
def test_histmaker(self):
tm._skip_if_no_sklearn()
variable_param = {'updater': ['grow_histmaker'], 'max_depth': [2, 8]}
for param in parameter_combinations(variable_param):
result = run_suite(param)
assert_results_non_increasing(result, 1e-2)
@pytest.mark.skipif(**tm.no_sklearn())
def test_colmaker(self):
tm._skip_if_no_sklearn()
variable_param = {'updater': ['grow_colmaker'], 'max_depth': [2, 8]}
for param in parameter_combinations(variable_param):
result = run_suite(param)
assert_results_non_increasing(result, 1e-2)
@pytest.mark.skipif(**tm.no_sklearn())
def test_fast_histmaker(self):
tm._skip_if_no_sklearn()
variable_param = {'tree_method': ['hist'], 'max_depth': [2, 8], 'max_bin': [2, 256],
'grow_policy': ['depthwise', 'lossguide'], 'max_leaves': [64, 0],
variable_param = {'tree_method': ['hist'],
'max_depth': [2, 8],
'max_bin': [2, 256],
'grow_policy': ['depthwise', 'lossguide'],
'max_leaves': [64, 0],
'silent': [1]}
for param in parameter_combinations(variable_param):
result = run_suite(param)
@@ -46,10 +50,12 @@ class TestUpdaters(unittest.TestCase):
hist_res = {}
exact_res = {}
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
xgb.train(ag_param, ag_dtrain, 10,
[(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=hist_res)
ag_param["tree_method"] = "exact"
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
xgb.train(ag_param, ag_dtrain, 10,
[(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=exact_res)
assert hist_res['train']['auc'] == exact_res['train']['auc']
assert hist_res['test']['auc'] == exact_res['test']['auc']

View File

@@ -3,6 +3,7 @@ import numpy as np
import xgboost as xgb
import testing as tm
import unittest
import pytest
try:
import pandas as pd
@@ -10,7 +11,7 @@ except ImportError:
pass
tm._skip_if_no_pandas()
pytestmark = pytest.mark.skipif(**tm.no_pandas())
dpath = 'demo/data/'
@@ -21,7 +22,8 @@ class TestPandas(unittest.TestCase):
def test_pandas(self):
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
columns=['a', 'b', 'c'])
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
assert dm.feature_names == ['a', 'b', 'c']
assert dm.feature_types == ['int', 'float', 'i']
@@ -30,14 +32,16 @@ class TestPandas(unittest.TestCase):
# overwrite feature_names and feature_types
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
feature_names=['x', 'y', 'z'], feature_types=['q', 'q', 'q'])
feature_names=['x', 'y', 'z'],
feature_types=['q', 'q', 'q'])
assert dm.feature_names == ['x', 'y', 'z']
assert dm.feature_types == ['q', 'q', 'q']
assert dm.num_row() == 2
assert dm.num_col() == 3
# incorrect dtypes
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
columns=['a', 'b', 'c'])
self.assertRaises(ValueError, xgb.DMatrix, df)
# numeric columns
@@ -107,7 +111,8 @@ class TestPandas(unittest.TestCase):
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
result = xgb.core._maybe_pandas_label(df)
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]], dtype=float))
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
dtype=float))
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
assert dm.num_row() == 3
@@ -115,9 +120,9 @@ class TestPandas(unittest.TestCase):
def test_cv_as_pandas(self):
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
params = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
params = {'max_depth': 2, 'eta': 1, 'silent': 1,
'objective': 'binary:logistic'}
import pandas as pd
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10)
assert isinstance(cv, pd.DataFrame)
exp = pd.Index([u'test-error-mean', u'test-error-std',

View File

@@ -4,10 +4,12 @@ import testing as tm
import tempfile
import os
import shutil
from nose.tools import raises
import pytest
rng = np.random.RandomState(1994)
pytestmark = pytest.mark.skipif(**tm.no_sklearn())
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp()"""
@@ -20,7 +22,6 @@ class TemporaryDirectory(object):
def test_binary_classification():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold
@@ -38,7 +39,6 @@ def test_binary_classification():
def test_multiclass_classification():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
@@ -59,9 +59,12 @@ def test_multiclass_classification():
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
preds2 = xgb_model.predict(X[test_index], output_margin=True,
ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True,
ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False,
ntree_limit=3)
labels = y[test_index]
check_pred(preds, labels, output_margin=False)
@@ -71,7 +74,6 @@ def test_multiclass_classification():
def test_ranking():
tm._skip_if_no_sklearn()
# generate random data
x_train = np.random.rand(1000, 10)
y_train = np.random.randint(5, size=1000)
@@ -105,13 +107,13 @@ def test_ranking():
def test_feature_importances_weight():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
digits = load_digits(2)
y = digits['target']
X = digits['data']
xgb_model = xgb.XGBClassifier(random_state=0, importance_type="weight").fit(X, y)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="weight").fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
@@ -127,28 +129,32 @@ def test_feature_importances_weight():
import pandas as pd
y = pd.Series(digits['target'])
X = pd.DataFrame(digits['data'])
xgb_model = xgb.XGBClassifier(random_state=0, importance_type="weight").fit(X, y)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="weight").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
xgb_model = xgb.XGBClassifier(random_state=0, importance_type="weight").fit(X, y)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="weight").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
def test_feature_importances_gain():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
digits = load_digits(2)
y = digits['target']
X = digits['data']
xgb_model = xgb.XGBClassifier(random_state=0, importance_type="gain").fit(X, y)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="gain").fit(X, y)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00326159, 0., 0., 0.,
0., 0., 0., 0., 0., 0.00297238, 0.00988034, 0., 0., 0., 0.,
0., 0., 0.03512521, 0.41123885, 0., 0., 0., 0., 0.01326332,
0.00160674, 0., 0.4206952, 0., 0., 0., 0., 0.00616747, 0.01237546,
0., 0., 0., 0., 0., 0., 0., 0.08240705, 0., 0., 0., 0.,
0., 0., 0., 0.00100649, 0., 0., 0., 0., 0.], dtype=np.float32)
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
0.00297238, 0.00988034, 0., 0., 0., 0., 0., 0.,
0.03512521, 0.41123885, 0., 0., 0., 0.,
0.01326332, 0.00160674, 0., 0.4206952, 0., 0., 0.,
0., 0.00616747, 0.01237546, 0., 0., 0., 0., 0.,
0., 0., 0.08240705, 0., 0., 0., 0., 0., 0., 0.,
0.00100649, 0., 0., 0., 0., 0.], dtype=np.float32)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
@@ -156,15 +162,16 @@ def test_feature_importances_gain():
import pandas as pd
y = pd.Series(digits['target'])
X = pd.DataFrame(digits['data'])
xgb_model = xgb.XGBClassifier(random_state=0, importance_type="gain").fit(X, y)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="gain").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
xgb_model = xgb.XGBClassifier(random_state=0, importance_type="gain").fit(X, y)
xgb_model = xgb.XGBClassifier(
random_state=0, importance_type="gain").fit(X, y)
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
def test_boston_housing_regression():
tm._skip_if_no_sklearn()
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_boston
from sklearn.model_selection import KFold
@@ -178,9 +185,12 @@ def test_boston_housing_regression():
preds = xgb_model.predict(X[test_index])
# test other params in XGBRegressor().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
preds2 = xgb_model.predict(X[test_index], output_margin=True,
ntree_limit=3)
preds3 = xgb_model.predict(X[test_index], output_margin=True,
ntree_limit=0)
preds4 = xgb_model.predict(X[test_index], output_margin=False,
ntree_limit=3)
labels = y[test_index]
assert mean_squared_error(preds, labels) < 25
@@ -190,7 +200,6 @@ def test_boston_housing_regression():
def test_parameter_tuning():
tm._skip_if_no_sklearn()
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_boston
@@ -207,7 +216,6 @@ def test_parameter_tuning():
def test_regression_with_custom_objective():
tm._skip_if_no_sklearn()
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_boston
from sklearn.model_selection import KFold
@@ -241,7 +249,6 @@ def test_regression_with_custom_objective():
def test_classification_with_custom_objective():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold
@@ -280,7 +287,6 @@ def test_classification_with_custom_objective():
def test_sklearn_api():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
@@ -298,12 +304,12 @@ def test_sklearn_api():
def test_sklearn_api_gblinear():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target, train_size=120)
tr_d, te_d, tr_l, te_l = train_test_split(iris.data, iris.target,
train_size=120)
classifier = xgb.XGBClassifier(booster='gblinear', n_estimators=100)
classifier.fit(tr_d, tr_l)
@@ -314,8 +320,8 @@ def test_sklearn_api_gblinear():
assert err < 0.5
@pytest.mark.skipif(**tm.no_matplotlib())
def test_sklearn_plotting():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_iris
iris = load_iris()
@@ -344,7 +350,6 @@ def test_sklearn_plotting():
def test_sklearn_nfolds_cv():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.model_selection import StratifiedKFold
@@ -367,14 +372,15 @@ def test_sklearn_nfolds_cv():
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, folds=skf, seed=seed)
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds,
folds=skf, seed=seed)
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds,
stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
def test_split_value_histograms():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
digits_2class = load_digits(2)
@@ -383,11 +389,14 @@ def test_split_value_histograms():
y = digits_2class['target']
dm = xgb.DMatrix(X, label=y)
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}
params = {'max_depth': 6, 'eta': 0.01, 'silent': 1,
'objective': 'binary:logistic'}
gbdt = xgb.train(params, dm, num_boost_round=10)
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
assert gbdt.get_split_value_histogram("not_there",
as_pandas=True).shape[0] == 0
assert gbdt.get_split_value_histogram("not_there",
as_pandas=False).shape[0] == 0
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
@@ -396,8 +405,6 @@ def test_split_value_histograms():
def test_sklearn_random_state():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(random_state=402)
assert clf.get_xgb_params()['seed'] == 402
@@ -406,8 +413,6 @@ def test_sklearn_random_state():
def test_sklearn_n_jobs():
tm._skip_if_no_sklearn()
clf = xgb.XGBClassifier(n_jobs=1)
assert clf.get_xgb_params()['nthread'] == 1
@@ -416,8 +421,6 @@ def test_sklearn_n_jobs():
def test_kwargs():
tm._skip_if_no_sklearn()
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_estimators=1000, **params)
assert clf.get_params()['updater'] == 'grow_gpu'
@@ -426,7 +429,6 @@ def test_kwargs():
def test_kwargs_grid_search():
tm._skip_if_no_sklearn()
from sklearn.model_selection import GridSearchCV
from sklearn import datasets
@@ -446,17 +448,14 @@ def test_kwargs_grid_search():
assert len(means) == len(set(means))
@raises(TypeError)
def test_kwargs_error():
tm._skip_if_no_sklearn()
params = {'updater': 'grow_gpu', 'subsample': .5, 'n_jobs': -1}
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)
with pytest.raises(TypeError):
clf = xgb.XGBClassifier(n_jobs=1000, **params)
assert isinstance(clf, xgb.XGBClassifier)
def test_sklearn_clone():
tm._skip_if_no_sklearn()
from sklearn.base import clone
clf = xgb.XGBClassifier(n_jobs=2, nthread=3)
@@ -465,7 +464,6 @@ def test_sklearn_clone():
def test_validation_weights_xgbmodel():
tm._skip_if_no_sklearn()
from sklearn.datasets import make_hastie_10_2
# prepare training and test data
@@ -489,7 +487,8 @@ def test_validation_weights_xgbmodel():
# evaluate logloss metric on test set *without* using weights
evals_result_without_weights = clf.evals_result()
logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
logloss_without_weights = evals_result_without_weights[
"validation_0"]["logloss"]
# now use weights for the test set
np.random.seed(0)
@@ -503,13 +502,13 @@ def test_validation_weights_xgbmodel():
evals_result_with_weights = clf.evals_result()
logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
# check that the logloss in the test set is actually different when using weights
# than when not using them
assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
# check that the logloss in the test set is actually different when using
# weights than when not using them
assert all((logloss_with_weights[i] != logloss_without_weights[i]
for i in [0, 1]))
def test_validation_weights_xgbclassifier():
tm._skip_if_no_sklearn()
from sklearn.datasets import make_hastie_10_2
# prepare training and test data
@@ -533,7 +532,8 @@ def test_validation_weights_xgbclassifier():
# evaluate logloss metric on test set *without* using weights
evals_result_without_weights = clf.evals_result()
logloss_without_weights = evals_result_without_weights["validation_0"]["logloss"]
logloss_without_weights = evals_result_without_weights[
"validation_0"]["logloss"]
# now use weights for the test set
np.random.seed(0)
@@ -547,13 +547,13 @@ def test_validation_weights_xgbclassifier():
evals_result_with_weights = clf.evals_result()
logloss_with_weights = evals_result_with_weights["validation_0"]["logloss"]
# check that the logloss in the test set is actually different when using weights
# than when not using them
assert all((logloss_with_weights[i] != logloss_without_weights[i] for i in [0, 1]))
# check that the logloss in the test set is actually different
# when using weights than when not using them
assert all((logloss_with_weights[i] != logloss_without_weights[i]
for i in [0, 1]))
def test_save_load_model():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold
@@ -576,7 +576,6 @@ def test_save_load_model():
def test_RFECV():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_boston
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
@@ -587,21 +586,25 @@ def test_RFECV():
bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
n_estimators=10, n_jobs=1, objective='reg:linear',
random_state=0, silent=True)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_mean_squared_error')
rfecv = RFECV(
estimator=bst, step=1, cv=3, scoring='neg_mean_squared_error')
rfecv.fit(X, y)
# Binary classification
X, y = load_breast_cancer(return_X_y=True)
bst = xgb.XGBClassifier(booster='gblinear', learning_rate=0.1,
n_estimators=10, n_jobs=1, objective='binary:logistic',
n_estimators=10, n_jobs=1,
objective='binary:logistic',
random_state=0, silent=True)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='roc_auc')
rfecv.fit(X, y)
# Multi-class classification
X, y = load_iris(return_X_y=True)
bst = xgb.XGBClassifier(base_score=0.4, booster='gblinear', learning_rate=0.1,
n_estimators=10, n_jobs=1, objective='multi:softprob',
bst = xgb.XGBClassifier(base_score=0.4, booster='gblinear',
learning_rate=0.1,
n_estimators=10, n_jobs=1,
objective='multi:softprob',
random_state=0, reg_alpha=0.001, reg_lambda=0.01,
scale_pos_weight=0.5, silent=True)
rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss')

View File

@@ -1,27 +1,28 @@
# coding: utf-8
import nose
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
def _skip_if_no_sklearn():
if not SKLEARN_INSTALLED:
raise nose.SkipTest()
def no_sklearn():
return {'condition': not SKLEARN_INSTALLED,
'reason': 'Scikit-Learn is not installed'}
def _skip_if_no_pandas():
if not PANDAS_INSTALLED:
raise nose.SkipTest()
def no_pandas():
return {'condition': not PANDAS_INSTALLED,
'reason': 'Pandas is not installed.'}
def _skip_if_no_dt():
if not DT_INSTALLED:
raise nose.SkipTest()
def no_dt():
return {'condition': not DT_INSTALLED,
'reason': 'Datatable is not installed.'}
def _skip_if_no_matplotlib():
def no_matplotlib():
reason = 'Matplotlib is not installed.'
try:
import matplotlib.pyplot as _ # noqa
return {'condition': False,
'reason': reason}
except ImportError:
raise nose.SkipTest()
return {'condition': True,
'reason': reason}
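
Each rewritten helper now returns a {'condition': ..., 'reason': ...} dict shaped to unpack directly into pytest.mark.skipif's keyword arguments, which is exactly how the test files above consume it. A short usage sketch, assuming only what this diff shows (the test name is illustrative):

import pytest
import testing as tm

# per-test skip, as in the sklearn-dependent tests above
@pytest.mark.skipif(**tm.no_sklearn())
def test_needs_sklearn():
    pass

# whole-module skip, as in the pandas and datatable test modules above
pytestmark = pytest.mark.skipif(**tm.no_pandas())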

View File

@@ -53,7 +53,7 @@ if [ ${TASK} == "python_test" ]; then
echo "-------------------------------"
source activate python3
python --version
conda install numpy scipy pandas matplotlib nose scikit-learn
conda install numpy scipy pandas matplotlib scikit-learn
# Install data table from source
wget http://releases.llvm.org/5.0.2/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
@@ -62,15 +62,15 @@ if [ ${TASK} == "python_test" ]; then
python -m pip install datatable --no-binary datatable
python -m pip install graphviz pytest pytest-cov codecov
python -m nose -v tests/python || exit -1
py.test tests/python --cov=python-package/xgboost
py.test -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
codecov
source activate python2
echo "-------------------------------"
python --version
conda install numpy scipy pandas matplotlib nose scikit-learn
conda install numpy scipy pandas matplotlib scikit-learn
python -m pip install graphviz
python -m nose -v tests/python || exit -1
py.test -v --fulltrace -s tests/python || exit -1
exit 0
fi
@@ -79,17 +79,15 @@ if [ ${TASK} == "python_lightweight_test" ]; then
echo "-------------------------------"
source activate python3
python --version
conda install numpy scipy nose
conda install numpy scipy
python -m pip install graphviz pytest pytest-cov codecov
python -m nose -v tests/python || exit -1
py.test tests/python --cov=python-package/xgboost
py.test -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
codecov
source activate python2
echo "-------------------------------"
python --version
conda install numpy scipy nose
conda install numpy scipy
python -m pip install graphviz
python -m nose -v tests/python || exit -1
python -m pip install flake8==3.4.1
flake8 --ignore E501 python-package || exit -1
flake8 --ignore E501 tests/python || exit -1