Refactor Python tests. (#3897)
* Deprecate nose tests. * Format python tests.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import testing as tm
|
||||
import unittest
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
|
||||
try:
|
||||
@@ -10,24 +11,27 @@ except ImportError:
|
||||
|
||||
|
||||
class TestUpdaters(unittest.TestCase):
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_histmaker(self):
|
||||
tm._skip_if_no_sklearn()
|
||||
variable_param = {'updater': ['grow_histmaker'], 'max_depth': [2, 8]}
|
||||
for param in parameter_combinations(variable_param):
|
||||
result = run_suite(param)
|
||||
assert_results_non_increasing(result, 1e-2)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_colmaker(self):
|
||||
tm._skip_if_no_sklearn()
|
||||
variable_param = {'updater': ['grow_colmaker'], 'max_depth': [2, 8]}
|
||||
for param in parameter_combinations(variable_param):
|
||||
result = run_suite(param)
|
||||
assert_results_non_increasing(result, 1e-2)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_fast_histmaker(self):
|
||||
tm._skip_if_no_sklearn()
|
||||
variable_param = {'tree_method': ['hist'], 'max_depth': [2, 8], 'max_bin': [2, 256],
|
||||
'grow_policy': ['depthwise', 'lossguide'], 'max_leaves': [64, 0],
|
||||
variable_param = {'tree_method': ['hist'],
|
||||
'max_depth': [2, 8],
|
||||
'max_bin': [2, 256],
|
||||
'grow_policy': ['depthwise', 'lossguide'],
|
||||
'max_leaves': [64, 0],
|
||||
'silent': [1]}
|
||||
for param in parameter_combinations(variable_param):
|
||||
result = run_suite(param)
|
||||
@@ -46,10 +50,12 @@ class TestUpdaters(unittest.TestCase):
|
||||
hist_res = {}
|
||||
exact_res = {}
|
||||
|
||||
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
xgb.train(ag_param, ag_dtrain, 10,
|
||||
[(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=hist_res)
|
||||
ag_param["tree_method"] = "exact"
|
||||
xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
xgb.train(ag_param, ag_dtrain, 10,
|
||||
[(ag_dtrain, 'train'), (ag_dtest, 'test')],
|
||||
evals_result=exact_res)
|
||||
assert hist_res['train']['auc'] == exact_res['train']['auc']
|
||||
assert hist_res['test']['auc'] == exact_res['test']['auc']
|
||||
|
||||
Reference in New Issue
Block a user