Add use_rmm flag to global configuration (#6656)
* Ensure RMM is 0.18 or later * Add use_rmm flag to global configuration * Modify XGBCachingDeviceAllocatorImpl to skip CUB when use_rmm=True * Update the demo * [CI] Pin NumPy to 1.19.4, since NumPy 1.19.5 doesn't work with latest Shap
This commit is contained in:
committed by
GitHub
parent
e4894111ba
commit
366f3cb9d8
@@ -834,9 +834,15 @@ def test_dask_predict_leaf(booster: str, client: "Client") -> None:
|
||||
|
||||
|
||||
class TestWithDask:
|
||||
def test_global_config(self, client: "Client") -> None:
|
||||
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
||||
def test_global_config(
|
||||
self,
|
||||
client: "Client",
|
||||
config_key: str,
|
||||
config_value: Any
|
||||
) -> None:
|
||||
X, y, _ = generate_array()
|
||||
xgb.config.set_config(verbosity=0)
|
||||
xgb.config.set_config(**{config_key: config_value})
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
before_fname = './before_training-test_global_config'
|
||||
after_fname = './after_training-test_global_config'
|
||||
@@ -844,36 +850,36 @@ class TestWithDask:
|
||||
class TestCallback(xgb.callback.TrainingCallback):
|
||||
def write_file(self, fname: str) -> None:
|
||||
with open(fname, 'w') as fd:
|
||||
fd.write(str(xgb.config.get_config()['verbosity']))
|
||||
fd.write(str(xgb.config.get_config()[config_key]))
|
||||
|
||||
def before_training(self, model: xgb.Booster) -> xgb.Booster:
|
||||
self.write_file(before_fname)
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
assert xgb.config.get_config()[config_key] == config_value
|
||||
return model
|
||||
|
||||
def after_training(self, model: xgb.Booster) -> xgb.Booster:
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
assert xgb.config.get_config()[config_key] == config_value
|
||||
return model
|
||||
|
||||
def before_iteration(
|
||||
self, model: xgb.Booster, epoch: int, evals_log: Dict
|
||||
) -> bool:
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
assert xgb.config.get_config()[config_key] == config_value
|
||||
return False
|
||||
|
||||
def after_iteration(
|
||||
self, model: xgb.Booster, epoch: int, evals_log: Dict
|
||||
) -> bool:
|
||||
self.write_file(after_fname)
|
||||
assert xgb.config.get_config()['verbosity'] == 0
|
||||
assert xgb.config.get_config()[config_key] == config_value
|
||||
return False
|
||||
|
||||
xgb.dask.train(client, {}, dtrain, num_boost_round=4, callbacks=[TestCallback()])[
|
||||
'booster']
|
||||
|
||||
with open(before_fname, 'r') as before, open(after_fname, 'r') as after:
|
||||
assert before.read() == '0'
|
||||
assert after.read() == '0'
|
||||
assert before.read() == str(config_value)
|
||||
assert after.read() == str(config_value)
|
||||
|
||||
os.remove(before_fname)
|
||||
os.remove(after_fname)
|
||||
|
||||
Reference in New Issue
Block a user