Add use_rmm flag to global configuration (#6656)

* Ensure RMM is 0.18 or later

* Add use_rmm flag to global configuration

* Modify XGBCachingDeviceAllocatorImpl to skip CUB when use_rmm=True

* Update the demo

* [CI] Pin NumPy to 1.19.4, since NumPy 1.19.5 doesn't work with latest Shap
This commit is contained in:
Philip Hyunsu Cho
2021-03-09 14:53:05 -08:00
committed by GitHub
parent e4894111ba
commit 366f3cb9d8
12 changed files with 117 additions and 20 deletions

View File

@@ -220,7 +220,8 @@ TEST(CAPI, XGBGlobalConfig) {
{
const char *config_str = R"json(
{
"verbosity": 0
"verbosity": 0,
"use_rmm": false
}
)json";
ret = XGBSetGlobalConfig(config_str);
@@ -233,6 +234,24 @@ TEST(CAPI, XGBGlobalConfig) {
auto updated_config =
Json::Load({updated_config_str.data(), updated_config_str.size()});
ASSERT_EQ(get<Integer>(updated_config["verbosity"]), 0);
ASSERT_EQ(get<Boolean>(updated_config["use_rmm"]), false);
}
{
const char *config_str = R"json(
{
"use_rmm": true
}
)json";
ret = XGBSetGlobalConfig(config_str);
ASSERT_EQ(ret, 0);
const char *updated_config_cstr;
ret = XGBGetGlobalConfig(&updated_config_cstr);
ASSERT_EQ(ret, 0);
std::string updated_config_str{updated_config_cstr};
auto updated_config =
Json::Load({updated_config_str.data(), updated_config_str.size()});
ASSERT_EQ(get<Boolean>(updated_config["use_rmm"]), true);
}
{
const char *config_str = R"json(