Add global configuration (#6414)
* Add management functions for global configuration: XGBSetGlobalConfig(), XGBGetGlobalConfig(). * Add Python interface: set_config(), get_config(), and config_context(). * Add unit tests for Python * Add R interface: xgb.set.config(), xgb.get.config() * Add unit tests for R Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
committed by
GitHub
parent
c2ba4fb957
commit
fb56da5e8b
16
tests/python/test_config.py
Normal file
16
tests/python/test_config.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
import testing as tm
|
||||
|
||||
|
||||
@pytest.mark.parametrize('verbosity_level', [0, 1, 2, 3])
|
||||
def test_global_config_verbosity(verbosity_level):
|
||||
def get_current_verbosity():
|
||||
return xgb.get_config()['verbosity']
|
||||
|
||||
old_verbosity = get_current_verbosity()
|
||||
with xgb.config_context(verbosity=verbosity_level):
|
||||
new_verbosity = get_current_verbosity()
|
||||
assert new_verbosity == verbosity_level
|
||||
assert old_verbosity == get_current_verbosity()
|
||||
@@ -637,6 +637,46 @@ def test_aft_survival():
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_global_config(self, client):
|
||||
X, y = generate_array()
|
||||
xgb.config.set_config(verbosity=0)
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
before_fname = './before_training-test_global_config'
|
||||
after_fname = './after_training-test_global_config'
|
||||
|
||||
class TestCallback(xgb.callback.TrainingCallback):
|
||||
def write_file(self, fname):
|
||||
with open(fname, 'w') as fd:
|
||||
fd.write(str(xgb.config.get_config()['verbosity']))
|
||||
|
||||
def before_training(self, model):
|
||||
self.write_file(before_fname)
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return model
|
||||
|
||||
def after_training(self, model):
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, evals_log):
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
self.write_file(after_fname)
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
return False
|
||||
|
||||
xgb.dask.train(client, {}, dtrain, num_boost_round=4, callbacks=[TestCallback()])[
|
||||
'booster']
|
||||
|
||||
with open(before_fname, 'r') as before, open(after_fname, 'r') as after:
|
||||
assert before.read() == '0'
|
||||
assert after.read() == '0'
|
||||
|
||||
os.remove(before_fname)
|
||||
os.remove(after_fname)
|
||||
|
||||
def run_updater_test(self, client, params, num_rounds, dataset,
|
||||
tree_method):
|
||||
params['tree_method'] = tree_method
|
||||
|
||||
Reference in New Issue
Block a user