Refactor python tests (#3410)

* Add unit test utility

* Refactor updater tests. Add coverage for histmaker.
Rory Mitchell 2018-06-27 11:20:27 +12:00 committed by GitHub
parent 0988fb191f
commit a0a1df1aba
5 changed files with 258 additions and 379 deletions

View File

@@ -1,131 +1,41 @@

Before:

from __future__ import print_function
import sys
sys.path.append("../../tests/python")
import xgboost as xgb
import numpy as np
import unittest
from nose.plugins.attrib import attr
from sklearn.datasets import load_digits, load_boston, load_breast_cancer, make_regression
import itertools as it

rng = np.random.RandomState(1994)


def non_increasing(L, tolerance):
    return all((y - x) < tolerance for x, y in zip(L, L[1:]))


# Check result is always decreasing and final accuracy is within tolerance
def assert_accuracy(res, tree_method, comparison_tree_method, tolerance, param):
    assert non_increasing(res[tree_method], tolerance)
    assert np.allclose(res[tree_method][-1], res[comparison_tree_method][-1], 1e-3, 1e-2)


def train_boston(param_in, comparison_tree_method):
    data = load_boston()
    dtrain = xgb.DMatrix(data.data, label=data.target)
    param = {}
    param.update(param_in)
    param['max_depth'] = 2
    res_tmp = {}
    res = {}
    num_rounds = 10
    bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[param['tree_method']] = res_tmp['train']['rmse']
    param["tree_method"] = comparison_tree_method
    bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[comparison_tree_method] = res_tmp['train']['rmse']
    return res


def train_digits(param_in, comparison_tree_method):
    data = load_digits()
    dtrain = xgb.DMatrix(data.data, label=data.target)
    param = {}
    param['objective'] = 'multi:softmax'
    param['num_class'] = 10
    param.update(param_in)
    res_tmp = {}
    res = {}
    num_rounds = 10
    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[param['tree_method']] = res_tmp['train']['merror']
    param["tree_method"] = comparison_tree_method
    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[comparison_tree_method] = res_tmp['train']['merror']
    return res


def train_cancer(param_in, comparison_tree_method):
    data = load_breast_cancer()
    dtrain = xgb.DMatrix(data.data, label=data.target)
    param = {}
    param['objective'] = 'binary:logistic'
    param.update(param_in)
    res_tmp = {}
    res = {}
    num_rounds = 10
    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[param['tree_method']] = res_tmp['train']['error']
    param["tree_method"] = comparison_tree_method
    xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[comparison_tree_method] = res_tmp['train']['error']
    return res


def train_sparse(param_in, comparison_tree_method):
    n = 5000
    sparsity = 0.75
    X, y = make_regression(n, random_state=rng)
    X = np.array([[np.nan if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
    dtrain = xgb.DMatrix(X, label=y)
    param = {}
    param.update(param_in)
    res_tmp = {}
    res = {}
    num_rounds = 10
    bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[param['tree_method']] = res_tmp['train']['rmse']
    param["tree_method"] = comparison_tree_method
    bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')], evals_result=res_tmp)
    res[comparison_tree_method] = res_tmp['train']['rmse']
    return res


# Enumerates all permutations of variable parameters
def assert_updater_accuracy(tree_method, comparison_tree_method, variable_param, tolerance):
    param = {'tree_method': tree_method}
    names = sorted(variable_param)
    combinations = it.product(*(variable_param[Name] for Name in names))
    for set in combinations:
        print(names, file=sys.stderr)
        print(set, file=sys.stderr)
        param_tmp = param.copy()
        for i, name in enumerate(names):
            param_tmp[name] = set[i]
        print(param_tmp, file=sys.stderr)
        assert_accuracy(train_boston(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance,
                        param_tmp)
        assert_accuracy(train_digits(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance,
                        param_tmp)
        assert_accuracy(train_cancer(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance,
                        param_tmp)
        assert_accuracy(train_sparse(param_tmp, comparison_tree_method), tree_method, comparison_tree_method, tolerance,
                        param_tmp)


@attr('gpu')
class TestGPU(unittest.TestCase):
    def test_gpu_exact(self):
        variable_param = {'max_depth': [2, 6, 15]}
        assert_updater_accuracy('gpu_exact', 'exact', variable_param, 0.02)

    def test_gpu_hist(self):
        variable_param = {'n_gpus': [1, -1], 'max_depth': [2, 6], 'max_leaves': [255, 4], 'max_bin': [2, 16, 1024],
                          'grow_policy': ['depthwise', 'lossguide']}
        assert_updater_accuracy('gpu_hist', 'hist', variable_param, 0.01)

After:

from __future__ import print_function
import numpy as np
import sys
import unittest

sys.path.append("tests/python")
import xgboost as xgb
from regression_test_utilities import run_suite, parameter_combinations, \
    assert_results_non_increasing


def assert_gpu_results(cpu_results, gpu_results):
    for cpu_res, gpu_res in zip(cpu_results, gpu_results):
        # Check final eval result roughly equivalent
        assert np.allclose(cpu_res["eval"][-1], gpu_res["eval"][-1], 1e-3, 1e-2)


datasets = ["Boston", "Cancer", "Digits", "Sparse regression"]


class TestGPU(unittest.TestCase):
    def test_gpu_exact(self):
        variable_param = {'max_depth': [2, 6, 15]}
        for param in parameter_combinations(variable_param):
            param['tree_method'] = 'gpu_exact'
            gpu_results = run_suite(param, select_datasets=datasets)
            assert_results_non_increasing(gpu_results, 1e-2)
            param['tree_method'] = 'exact'
            cpu_results = run_suite(param, select_datasets=datasets)
            assert_gpu_results(cpu_results, gpu_results)

    def test_gpu_hist(self):
        variable_param = {'n_gpus': [1, -1], 'max_depth': [2, 6], 'max_leaves': [255, 4],
                          'max_bin': [2, 16, 1024],
                          'grow_policy': ['depthwise', 'lossguide']}
        for param in parameter_combinations(variable_param):
            param['tree_method'] = 'gpu_hist'
            gpu_results = run_suite(param, select_datasets=datasets)
            assert_results_non_increasing(gpu_results, 1e-2)
            param['tree_method'] = 'hist'
            cpu_results = run_suite(param, select_datasets=datasets)
            assert_gpu_results(cpu_results, gpu_results)
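
A note on the comparison above: assert_gpu_results accepts small numerical differences between the CPU and GPU training histories rather than exact equality. A minimal sketch of the tolerance semantics it relies on, with illustrative values that are not part of the commit:

import numpy as np

# np.allclose(a, b, rtol, atol) passes when |a - b| <= atol + rtol * |b|.
cpu_final_rmse = 2.500   # hypothetical final eval entries
gpu_final_rmse = 2.508

# 0.008 <= 0.01 + 1e-3 * 2.508, so this check passes ...
assert np.allclose(cpu_final_rmse, gpu_final_rmse, 1e-3, 1e-2)
# ... while a gap of 0.05 between the two final values would fail it.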

View File

@@ -0,0 +1,134 @@

from __future__ import print_function
import glob
import itertools as it
import numpy as np
import os
import sys
import xgboost as xgb

try:
    from sklearn import datasets
    from sklearn.preprocessing import scale
except ImportError:
    None


class Dataset:
    def __init__(self, name, get_dataset, objective, metric, use_external_memory=False):
        self.name = name
        self.objective = objective
        self.metric = metric
        self.X, self.y = get_dataset()
        self.use_external_memory = use_external_memory


def get_boston():
    data = datasets.load_boston()
    return data.data, data.target


def get_digits():
    data = datasets.load_digits()
    return data.data, data.target


def get_cancer():
    data = datasets.load_breast_cancer()
    return data.data, data.target


def get_sparse():
    rng = np.random.RandomState(199)
    n = 5000
    sparsity = 0.75
    X, y = datasets.make_regression(n, random_state=rng)
    X = np.array([[0.0 if rng.uniform(0, 1) < sparsity else x for x in x_row] for x_row in X])
    from scipy import sparse
    X = sparse.csr_matrix(X)
    return X, y


def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
    param = param_in.copy()
    param["objective"] = dataset.objective
    if dataset.objective == "multi:softmax":
        param["num_class"] = int(np.max(dataset.y) + 1)
    param["eval_metric"] = dataset.metric

    if scale_features:
        X = scale(dataset.X, with_mean=isinstance(dataset.X, np.ndarray))
    else:
        X = dataset.X
    if dataset.use_external_memory:
        np.savetxt('tmptmp_1234.csv', np.hstack((dataset.y.reshape(len(dataset.y), 1), X)),
                   delimiter=',')
        dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
    else:
        dtrain = xgb.DMatrix(X, dataset.y)

    print("Training on dataset: " + dataset.name, file=sys.stderr)
    print("Using parameters: " + str(param), file=sys.stderr)
    res = {}
    bst = xgb.train(param, dtrain, num_rounds, [(dtrain, 'train')],
                    evals_result=res, verbose_eval=False)

    # Free the booster and dmatrix so we can delete temporary files
    bst_copy = bst.copy()
    del bst
    del dtrain

    # Cleanup temporary files
    if dataset.use_external_memory:
        for f in glob.glob("tmptmp_*"):
            os.remove(f)

    return {"dataset": dataset, "bst": bst_copy, "param": param.copy(),
            "eval": res['train'][dataset.metric]}


def parameter_combinations(variable_param):
    """
    Enumerate all possible combinations of parameters
    """
    result = []
    names = sorted(variable_param)
    combinations = it.product(*(variable_param[Name] for Name in names))
    for set in combinations:
        param = {}
        for i, name in enumerate(names):
            param[name] = set[i]
        result.append(param)
    return result


def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
    """
    Run the given parameters on a range of datasets. Objective and eval metric will be
    automatically set.
    """
    datasets = [
        Dataset("Boston", get_boston, "reg:linear", "rmse"),
        Dataset("Digits", get_digits, "multi:softmax", "merror"),
        Dataset("Cancer", get_cancer, "binary:logistic", "error"),
        Dataset("Sparse regression", get_sparse, "reg:linear", "rmse"),
        Dataset("Boston External Memory", get_boston, "reg:linear", "rmse",
                use_external_memory=True)
    ]

    results = []
    for d in datasets:
        if select_datasets is None or d.name in select_datasets:
            results.append(
                train_dataset(d, param, num_rounds=num_rounds, scale_features=scale_features))
    return results


def non_increasing(L, tolerance):
    return all((y - x) < tolerance for x, y in zip(L, L[1:]))


def assert_results_non_increasing(results, tolerance=1e-5):
    for r in results:
        assert non_increasing(r['eval'], tolerance), r
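
Taken together, the module gives each test the same three-step shape: expand a parameter grid, train it across the registered datasets, and assert on the recorded training history. A minimal usage sketch, assuming tests/python is on sys.path and xgboost plus scikit-learn are installed; the parameter values are illustrative, not taken from the commit:

from regression_test_utilities import (run_suite, parameter_combinations,
                                       assert_results_non_increasing)

variable_param = {'max_depth': [2, 8], 'eta': [0.1, 0.5]}
for param in parameter_combinations(variable_param):      # 4 combinations in total
    param['tree_method'] = 'exact'
    results = run_suite(param, num_rounds=10, select_datasets=["Boston", "Cancer"])
    assert_results_non_increasing(results, 1e-2)           # training metric must not rise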

View File

@@ -1,136 +0,0 @@

import xgboost as xgb
import testing as tm
import numpy as np
import unittest

rng = np.random.RandomState(1994)


class TestFastHist(unittest.TestCase):
    def test_fast_hist(self):
        tm._skip_if_no_sklearn()
        from sklearn.datasets import load_digits
        try:
            from sklearn.model_selection import train_test_split
        except:
            from sklearn.cross_validation import train_test_split

        # regression test --- hist must be same as exact on all-categorial data
        dpath = 'demo/data/'
        ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
        ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
        ag_param = {'max_depth': 2,
                    'tree_method': 'exact',
                    'eta': 1,
                    'silent': 1,
                    'objective': 'binary:logistic',
                    'eval_metric': 'auc'}
        ag_param2 = {'max_depth': 2,
                     'tree_method': 'hist',
                     'eta': 1,
                     'silent': 1,
                     'objective': 'binary:logistic',
                     'eval_metric': 'auc'}
        ag_res = {}
        ag_res2 = {}

        xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
                  evals_result=ag_res)
        xgb.train(ag_param2, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
                  evals_result=ag_res2)
        assert ag_res['train']['auc'] == ag_res2['train']['auc']
        assert ag_res['test']['auc'] == ag_res2['test']['auc']

        digits = load_digits(2)
        X = digits['data']
        y = digits['target']
        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
        dtrain = xgb.DMatrix(X_train, y_train)
        dtest = xgb.DMatrix(X_test, y_test)
        param = {'objective': 'binary:logistic',
                 'tree_method': 'hist',
                 'grow_policy': 'depthwise',
                 'max_depth': 3,
                 'eval_metric': 'auc'}
        res = {}
        xgb.train(param, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
                  evals_result=res)
        assert self.non_decreasing(res['train']['auc'])
        assert self.non_decreasing(res['test']['auc'])

        param2 = {'objective': 'binary:logistic',
                  'tree_method': 'hist',
                  'grow_policy': 'lossguide',
                  'max_depth': 0,
                  'max_leaves': 8,
                  'eval_metric': 'auc'}
        res = {}
        xgb.train(param2, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
                  evals_result=res)
        assert self.non_decreasing(res['train']['auc'])
        assert self.non_decreasing(res['test']['auc'])

        param3 = {'objective': 'binary:logistic',
                  'tree_method': 'hist',
                  'grow_policy': 'lossguide',
                  'max_depth': 0,
                  'max_leaves': 8,
                  'max_bin': 16,
                  'eval_metric': 'auc'}
        res = {}
        xgb.train(param3, dtrain, 10, [(dtrain, 'train'), (dtest, 'test')],
                  evals_result=res)
        assert self.non_decreasing(res['train']['auc'])

        # fail-safe test for dense data
        from sklearn.datasets import load_svmlight_file
        dpath = 'demo/data/'
        X2, y2 = load_svmlight_file(dpath + 'agaricus.txt.train')
        X2 = X2.toarray()
        dtrain2 = xgb.DMatrix(X2, label=y2)
        param = {'objective': 'binary:logistic',
                 'tree_method': 'hist',
                 'grow_policy': 'depthwise',
                 'max_depth': 2,
                 'eval_metric': 'auc'}
        res = {}
        xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
        assert self.non_decreasing(res['train']['auc'])
        assert res['train']['auc'][0] >= 0.85

        for j in range(X2.shape[1]):
            for i in np.random.choice(X2.shape[0], size=10, replace=False):
                X2[i, j] = 2
        dtrain3 = xgb.DMatrix(X2, label=y2)
        res = {}
        xgb.train(param, dtrain3, 10, [(dtrain3, 'train')], evals_result=res)
        assert self.non_decreasing(res['train']['auc'])
        assert res['train']['auc'][0] >= 0.85

        for j in range(X2.shape[1]):
            for i in np.random.choice(X2.shape[0], size=10, replace=False):
                X2[i, j] = 3
        dtrain4 = xgb.DMatrix(X2, label=y2)
        res = {}
        xgb.train(param, dtrain4, 10, [(dtrain4, 'train')], evals_result=res)
        assert self.non_decreasing(res['train']['auc'])
        assert res['train']['auc'][0] >= 0.85

        # fail-safe test for max_bin=2
        param = {'objective': 'binary:logistic',
                 'tree_method': 'hist',
                 'grow_policy': 'depthwise',
                 'max_depth': 2,
                 'eval_metric': 'auc',
                 'max_bin': 2}
        res = {}
        xgb.train(param, dtrain2, 10, [(dtrain2, 'train')], evals_result=res)
        assert self.non_decreasing(res['train']['auc'])
        assert res['train']['auc'][0] >= 0.85

    def non_decreasing(self, L):
        return all(x <= y for x, y in zip(L, L[1:]))
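
Most of the monotonicity checks in this deleted class move to test_updaters.py, which relies on the shared non_increasing helper with a tolerance instead of the strict non_decreasing method removed here. A small sketch of the difference, with illustrative values only:

# Both helpers appear verbatim above; the values below are made up.
def non_decreasing(L):
    return all(x <= y for x, y in zip(L, L[1:]))

def non_increasing(L, tolerance):
    return all((y - x) < tolerance for x, y in zip(L, L[1:]))

auc = [0.91, 0.93, 0.929, 0.95]               # tiny dip between rounds 2 and 3
print(non_decreasing(auc))                    # False: any dip fails the strict check
rmse = [3.2, 2.9, 2.91, 2.5]                  # mirror-image case for an error metric
print(non_increasing(rmse, tolerance=0.05))   # True: the 0.01 rise is within tolerance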

View File

@@ -1,24 +1,17 @@

Before:

from __future__ import print_function
import itertools as it
import numpy as np
import sys
import os
import glob
import testing as tm
import unittest
import xgboost as xgb

try:
    from sklearn import metrics, datasets
    from sklearn.linear_model import ElasticNet
    from sklearn.preprocessing import scale
except ImportError:
    None

rng = np.random.RandomState(199)
num_rounds = 1000


def is_float(s):
    try:

After:

from __future__ import print_function
import numpy as np
import testing as tm
import unittest
import xgboost as xgb

try:
    from sklearn.linear_model import ElasticNet
    from sklearn.preprocessing import scale
    from regression_test_utilities import run_suite, parameter_combinations
except ImportError:
    None


def is_float(s):
    try:
@@ -32,130 +25,53 @@ def xgb_get_weights(bst):

Before:

    return np.array([float(s) for s in bst.get_dump()[0].split() if is_float(s)])


def check_ElasticNet(X, y, pred, tol, reg_alpha, reg_lambda, weights):
    enet = ElasticNet(alpha=reg_alpha + reg_lambda,
                      l1_ratio=reg_alpha / (reg_alpha + reg_lambda))
    enet.fit(X, y)
    enet_pred = enet.predict(X)
    assert np.isclose(weights, enet.coef_, rtol=tol, atol=tol).all()
    assert np.isclose(enet_pred, pred, rtol=tol, atol=tol).all()


def train_diabetes(param_in):
    data = datasets.load_diabetes()
    X = scale(data.data)
    dtrain = xgb.DMatrix(X, label=data.target)
    param = {}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    check_ElasticNet(X, data.target, xgb_pred, 1e-2,
                     param['alpha'], param['lambda'],
                     xgb_get_weights(bst)[1:])


def train_breast_cancer(param_in):
    data = datasets.load_breast_cancer()
    X = scale(data.data)
    dtrain = xgb.DMatrix(X, label=data.target)
    param = {'objective': 'binary:logistic'}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    xgb_score = metrics.accuracy_score(data.target, np.round(xgb_pred))
    assert xgb_score >= 0.8


def train_classification(param_in):
    X, y = datasets.make_classification(random_state=rng)
    X = scale(X)
    dtrain = xgb.DMatrix(X, label=y)
    param = {'objective': 'binary:logistic'}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    xgb_score = metrics.accuracy_score(y, np.round(xgb_pred))
    assert xgb_score >= 0.8


def train_classification_multi(param_in):
    num_class = 3
    X, y = datasets.make_classification(n_samples=100, random_state=rng,
                                        n_classes=num_class, n_informative=4,
                                        n_features=4, n_redundant=0)
    X = scale(X)
    dtrain = xgb.DMatrix(X, label=y)
    param = {'objective': 'multi:softmax', 'num_class': num_class}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    xgb_score = metrics.accuracy_score(y, np.round(xgb_pred))
    assert xgb_score >= 0.50


def train_boston(param_in):
    data = datasets.load_boston()
    X = scale(data.data)
    dtrain = xgb.DMatrix(X, label=data.target)
    param = {}
    param.update(param_in)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    check_ElasticNet(X, data.target, xgb_pred, 1e-2,
                     param['alpha'], param['lambda'],
                     xgb_get_weights(bst)[1:])


def train_external_mem(param_in):
    data = datasets.load_boston()
    X = scale(data.data)
    y = data.target
    param = {}
    param.update(param_in)
    dtrain = xgb.DMatrix(X, label=y)
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred = bst.predict(dtrain)
    np.savetxt('tmptmp_1234.csv', np.hstack((y.reshape(len(y), 1), X)),
               delimiter=',', fmt='%10.9f')
    dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_')
    bst = xgb.train(param, dtrain, num_rounds)
    xgb_pred_ext = bst.predict(dtrain)
    assert np.abs(xgb_pred_ext - xgb_pred).max() < 1e-3
    del dtrain, bst
    for f in glob.glob("tmptmp_*"):
        os.remove(f)


# Enumerates all permutations of variable parameters
def assert_updater_accuracy(linear_updater, variable_param):
    param = {'booster': 'gblinear', 'updater': linear_updater, 'eta': 1.,
             'top_k': 10, 'tolerance': 1e-5, 'nthread': 2}
    names = sorted(variable_param)
    combinations = it.product(*(variable_param[Name] for Name in names))
    for set in combinations:
        param_tmp = param.copy()
        for i, name in enumerate(names):
            param_tmp[name] = set[i]
        print(param_tmp, file=sys.stderr)
        train_boston(param_tmp)
        train_diabetes(param_tmp)
        train_classification(param_tmp)
        train_classification_multi(param_tmp)
        train_breast_cancer(param_tmp)
        if 'gpu' not in linear_updater:
            train_external_mem(param_tmp)


class TestLinear(unittest.TestCase):
    def test_coordinate(self):
        tm._skip_if_no_sklearn()
        variable_param = {'alpha': [.005, .1], 'lambda': [.005],
                          'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']}
        assert_updater_accuracy('coord_descent', variable_param)

    def test_shotgun(self):
        tm._skip_if_no_sklearn()
        variable_param = {'alpha': [.005, .1], 'lambda': [.005, .1]}
        assert_updater_accuracy('shotgun', variable_param)

After:

    return np.array([float(s) for s in bst.get_dump()[0].split() if is_float(s)])


def assert_regression_result(results, tol):
    regression_results = [r for r in results if r["param"]["objective"] == "reg:linear"]
    for res in regression_results:
        X = scale(res["dataset"].X, with_mean=isinstance(res["dataset"].X, np.ndarray))
        y = res["dataset"].y
        reg_alpha = res["param"]["alpha"]
        reg_lambda = res["param"]["lambda"]
        pred = res["bst"].predict(xgb.DMatrix(X))
        weights = xgb_get_weights(res["bst"])[1:]
        enet = ElasticNet(alpha=reg_alpha + reg_lambda,
                          l1_ratio=reg_alpha / (reg_alpha + reg_lambda))
        enet.fit(X, y)
        enet_pred = enet.predict(X)
        assert np.isclose(weights, enet.coef_, rtol=tol, atol=tol).all(), (weights, enet.coef_)
        assert np.isclose(enet_pred, pred, rtol=tol, atol=tol).all(), (
            res["dataset"].name, enet_pred[:5], pred[:5])


# TODO: More robust classification tests
def assert_classification_result(results):
    classification_results = [r for r in results if r["param"]["objective"] != "reg:linear"]
    for res in classification_results:
        # Check accuracy is reasonable
        assert res["eval"][-1] < 0.5, (res["dataset"].name, res["eval"][-1])


class TestLinear(unittest.TestCase):
    def test_coordinate(self):
        tm._skip_if_no_sklearn()
        variable_param = {'booster': ['gblinear'], 'updater': ['coord_descent'], 'eta': [0.5],
                          'top_k': [10], 'tolerance': [1e-5], 'nthread': [2],
                          'alpha': [.005, .1], 'lambda': [.005],
                          'feature_selector': ['cyclic', 'shuffle', 'greedy', 'thrifty']
                          }
        for param in parameter_combinations(variable_param):
            results = run_suite(param, 200, None, scale_features=True)
            assert_regression_result(results, 1e-2)
            assert_classification_result(results)

    def test_shotgun(self):
        tm._skip_if_no_sklearn()
        variable_param = {'booster': ['gblinear'], 'updater': ['shotgun'], 'eta': [0.5],
                          'top_k': [10], 'tolerance': [1e-5], 'nthread': [2],
                          'alpha': [.005, .1], 'lambda': [.005],
                          'feature_selector': ['cyclic', 'shuffle']
                          }
        for param in parameter_combinations(variable_param):
            results = run_suite(param, 200, None, True)
            assert_regression_result(results, 1e-2)
            assert_classification_result(results)
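
The new assert_regression_result compares gblinear's fitted weights and predictions against scikit-learn's ElasticNet by folding the alpha/lambda pair into ElasticNet's two parameters. A short sketch of that identification with illustrative values; only the ElasticNet construction mirrors the test code above:

from sklearn.linear_model import ElasticNet

reg_alpha, reg_lambda = 0.005, 0.005               # one combination from variable_param
alpha = reg_alpha + reg_lambda                     # ElasticNet's overall penalty strength: 0.01
l1_ratio = reg_alpha / (reg_alpha + reg_lambda)    # share of the penalty treated as L1: 0.5
# The test fits this model on the scaled features and requires the coefficients and
# predictions to agree with gblinear's to within tol = 1e-2.
enet = ElasticNet(alpha=alpha, l1_ratio=l1_ratio)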

View File

@@ -0,0 +1,55 @@

import testing as tm
import unittest
import xgboost as xgb

try:
    from regression_test_utilities import run_suite, parameter_combinations, \
        assert_results_non_increasing
except ImportError:
    None


class TestUpdaters(unittest.TestCase):
    def test_histmaker(self):
        tm._skip_if_no_sklearn()
        variable_param = {'updater': ['grow_histmaker'], 'max_depth': [2, 8]}
        for param in parameter_combinations(variable_param):
            result = run_suite(param)
            assert_results_non_increasing(result, 1e-2)

    def test_colmaker(self):
        tm._skip_if_no_sklearn()
        variable_param = {'updater': ['grow_colmaker'], 'max_depth': [2, 8]}
        for param in parameter_combinations(variable_param):
            result = run_suite(param)
            assert_results_non_increasing(result, 1e-2)

    def test_fast_histmaker(self):
        tm._skip_if_no_sklearn()
        variable_param = {'tree_method': ['hist'], 'max_depth': [2, 8], 'max_bin': [2, 256],
                          'grow_policy': ['depthwise', 'lossguide'], 'max_leaves': [64, 0],
                          'silent': [1]}
        for param in parameter_combinations(variable_param):
            result = run_suite(param)
            assert_results_non_increasing(result, 1e-2)

        # hist must be same as exact on all-categorial data
        dpath = 'demo/data/'
        ag_dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
        ag_dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
        ag_param = {'max_depth': 2,
                    'tree_method': 'hist',
                    'eta': 1,
                    'silent': 1,
                    'objective': 'binary:logistic',
                    'eval_metric': 'auc'}
        hist_res = {}
        exact_res = {}

        xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
                  evals_result=hist_res)
        ag_param["tree_method"] = "exact"
        xgb.train(ag_param, ag_dtrain, 10, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
                  evals_result=exact_res)

        assert hist_res['train']['auc'] == exact_res['train']['auc']
        assert hist_res['test']['auc'] == exact_res['test']['auc']
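
For reference, the grid in test_histmaker expands to just two configurations, while test_fast_histmaker expands to sixteen. A quick sketch of what parameter_combinations returns for the first case, assuming tests/python is importable:

from regression_test_utilities import parameter_combinations

variable_param = {'updater': ['grow_histmaker'], 'max_depth': [2, 8]}
print(parameter_combinations(variable_param))
# Expected: [{'max_depth': 2, 'updater': 'grow_histmaker'},
#            {'max_depth': 8, 'updater': 'grow_histmaker'}]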