Refactor Python tests. (#3897)

* Deprecate nose tests.
* Format python tests.
This commit is contained in:
Jiaming Yuan
2018-11-15 13:56:33 +13:00
committed by GitHub
parent c76d993681
commit 2ea0f887c1
23 changed files with 302 additions and 225 deletions

View File

@@ -1,5 +1,6 @@
import testing as tm
import unittest
import pytest
import xgboost as xgb
try:
@@ -10,24 +11,27 @@ except ImportError:
class TestUpdaters(unittest.TestCase):
@pytest.mark.skipif(**tm.no_sklearn())
def test_histmaker(self):
tm._skip_if_no_sklearn()
variable_param = {'updater': ['grow_histmaker'], 'max_depth': [2, 8]}
for param in parameter_combinations(variable_param):
result = run_suite(param)
assert_results_non_increasing(result, 1e-2)
@pytest.mark.skipif(**tm.no_sklearn())
def test_colmaker(self):
tm._skip_if_no_sklearn()
variable_param = {'updater': ['grow_colmaker'], 'max_depth': [2, 8]}
for param in parameter_combinations(variable_param):
result = run_suite(param)
assert_results_non_increasing(result, 1e-2)
@pytest.mark.skipif(**tm.no_sklearn())
def test_fast_histmaker(self):
tm._skip_if_no_sklearn()
variable_param = {'tree_method': ['hist'], 'max_depth': [2, 8], 'max_bin': [2, 256],
'grow_policy': ['depthwise', 'lossguide'], 'max_leaves': [64, 0],
variable_param = {'tree_method': ['hist'],
'max_depth': [2, 8],
'max_bin': [2, 256],
'grow_policy': ['depthwise', 'lossguide'],
'max_leaves': [64, 0],
'silent': [1]}
for param in parameter_combinations(variable_param):
result = run_suite(param)
@@ -46,10 +50,12 @@ class TestUpdaters(unittest.TestCase):
hist_res = {}
exact_res = {}
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
xgb.train(ag_param, ag_dtrain, 10,
[(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=hist_res)
ag_param["tree_method"] = "exact"
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
xgb.train(ag_param, ag_dtrain, 10,
[(ag_dtrain, 'train'), (ag_dtest, 'test')],
evals_result=exact_res)
assert hist_res['train']['auc'] == exact_res['train']['auc']
assert hist_res['test']['auc'] == exact_res['test']['auc']