Add use_rmm flag to global configuration (#6656)

* Ensure RMM is 0.18 or later
* Add use_rmm flag to global configuration
* Modify XGBCachingDeviceAllocatorImpl to skip CUB when use_rmm=True
* Update the demo
* [CI] Pin NumPy to 1.19.4, since NumPy 1.19.5 doesn't work with the latest Shap

Commit 366f3cb9d8 (parent e4894111ba)
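In user-facing terms, the change is a two-step opt-in: configure RMM itself, then tell XGBoost to route its GPU allocations through it. A minimal Python sketch, assuming a build with the RMM plugin enabled (the updated demos below show the full training flow):

```python
import rmm
import xgboost as xgb

# Step 1: let RMM manage device memory (here, via a pool allocator).
rmm.reinitialize(pool_allocator=True)

# Step 2: flip the new global flag so XGBoost allocates through RMM
# instead of its built-in CUB caching allocator.
xgb.set_config(use_rmm=True)
```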
@@ -9,3 +9,13 @@ test_that('Global configuration works with verbosity', {
   xgb.set.config(verbosity = old_verbosity)
   expect_equal(xgb.get.config()$verbosity, old_verbosity)
 })
+
+test_that('Global configuration works with use_rmm flag', {
+  old_use_rmm_flag <- xgb.get.config()$use_rmm
+  for (v in c(TRUE, FALSE)) {
+    xgb.set.config(use_rmm = v)
+    expect_equal(xgb.get.config()$use_rmm, v)
+  }
+  xgb.set.config(use_rmm = old_use_rmm_flag)
+  expect_equal(xgb.get.config()$use_rmm, old_use_rmm_flag)
+})
@@ -5,13 +5,16 @@ from dask.distributed import Client
 from dask_cuda import LocalCUDACluster

 def main(client):
+    # Inform XGBoost that RMM is used for GPU memory allocation
+    xgb.set_config(use_rmm=True)
+
     X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
     X = dask.array.from_array(X)
     y = dask.array.from_array(y)
     dtrain = xgb.dask.DaskDMatrix(client, X, label=y)

     params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
-              'tree_method': 'gpu_hist'}
+              'tree_method': 'gpu_hist', 'eval_metric': 'merror'}
     output = xgb.dask.train(client, params, dtrain, num_boost_round=100,
                             evals=[(dtrain, 'train')])
     bst = output['booster']
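The Dask demo above assumes every worker already has RMM configured. A hedged sketch of one way to launch such a cluster, using dask-cuda's `rmm_pool_size` option (the pool size here is illustrative) together with the demo's `main` function:

```python
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

if __name__ == '__main__':
    # Give each GPU worker its own RMM memory pool.
    with LocalCUDACluster(rmm_pool_size='2GB') as cluster:
        with Client(cluster) as client:
            main(client)
```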
@@ -4,6 +4,8 @@ from sklearn.datasets import make_classification

 # Initialize RMM pool allocator
 rmm.reinitialize(pool_allocator=True)
+# Inform XGBoost that RMM is used for GPU memory allocation
+xgb.set_config(use_rmm=True)

 X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
 dtrain = xgb.DMatrix(X, label=y)
@@ -22,6 +22,7 @@ Global Configuration
 The following parameters can be set in the global scope, using ``xgb.config_context()`` (Python) or ``xgb.set.config()`` (R).

 * ``verbosity``: Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), and 3 (debug).
+* ``use_rmm``: Whether to use RAPIDS Memory Manager (RMM) to allocate GPU memory. This option is only applicable when XGBoost is built (compiled) with the RMM plugin enabled. Valid values are ``true`` and ``false``.

 ******************
 General Parameters
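To make the documented semantics concrete, here is a short sketch of both scoping styles on the Python side, assuming an RMM-enabled build; ``xgb.set.config()`` is the R counterpart, as the R test above shows:

```python
import xgboost as xgb

# Process-wide: stays in effect until changed again.
xgb.set_config(verbosity=2, use_rmm=True)
assert xgb.get_config()['use_rmm']

# Scoped: the previous value is restored when the block exits.
with xgb.config_context(use_rmm=False):
    assert not xgb.get_config()['use_rmm']
assert xgb.get_config()['use_rmm']
```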
@@ -16,11 +16,15 @@ class Json;

 struct GlobalConfiguration : public XGBoostParameter<GlobalConfiguration> {
   int verbosity { 1 };
+  bool use_rmm { false };
   DMLC_DECLARE_PARAMETER(GlobalConfiguration) {
     DMLC_DECLARE_FIELD(verbosity)
         .set_range(0, 3)
         .set_default(1)  // shows only warning
         .describe("Flag to print out detailed breakdown of runtime.");
+    DMLC_DECLARE_FIELD(use_rmm)
+        .set_default(false)
+        .describe("Whether to use RAPIDS Memory Manager to allocate GPU memory in XGBoost");
   }
 };
@@ -32,6 +32,7 @@
 #include "xgboost/logging.h"
 #include "xgboost/host_device_vector.h"
 #include "xgboost/span.h"
+#include "xgboost/global_config.h"

 #include "common.h"

@@ -42,6 +43,14 @@
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 #include "rmm/mr/device/per_device_resource.hpp"
 #include "rmm/mr/device/thrust_allocator_adaptor.hpp"
+#include "rmm/version_config.hpp"
+
+#if !defined(RMM_VERSION_MAJOR) || !defined(RMM_VERSION_MINOR)
+#error "Please use RMM version 0.18 or later"
+#elif RMM_VERSION_MAJOR == 0 && RMM_VERSION_MINOR < 18
+#error "Please use RMM version 0.18 or later"
+#endif  // !defined(RMM_VERSION_MAJOR) || !defined(RMM_VERSION_MINOR)
+
 #endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1

 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
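The version guard above runs at compile time, via the RMM_VERSION_MAJOR/MINOR macros from rmm/version_config.hpp. For completeness, a rough runtime analogue in Python; this is a hypothetical check, not part of this commit, assuming `rmm.__version__` and the `packaging` library are available:

```python
from packaging.version import Version

import rmm

# Mirror the compile-time guard: require RMM 0.18 or later.
if Version(rmm.__version__) < Version('0.18'):
    raise RuntimeError('Please use RMM version 0.18 or later')
```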
@@ -453,21 +462,42 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
-    T* ptr;
-    auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
-                                                           n * sizeof(T));
-    if (errc != cudaSuccess) {
-      ThrowOOMError("Caching allocator", n * sizeof(T));
+    pointer thrust_ptr;
+    if (use_cub_allocator_) {
+      T* raw_ptr{nullptr};
+      auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&raw_ptr),
+                                                             n * sizeof(T));
+      if (errc != cudaSuccess) {
+        ThrowOOMError("Caching allocator", n * sizeof(T));
+      }
+      thrust_ptr = pointer(raw_ptr);
+    } else {
+      try {
+        thrust_ptr = SuperT::allocate(n);
+        dh::safe_cuda(cudaGetLastError());
+      } catch (const std::exception &e) {
+        ThrowOOMError(e.what(), n * sizeof(T));
+      }
     }
-    pointer thrust_ptr{ ptr };
     GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
     return thrust_ptr;
   }
   void deallocate(pointer ptr, size_t n) {  // NOLINT
     GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
-    GetGlobalCachingAllocator().DeviceFree(ptr.get());
+    if (use_cub_allocator_) {
+      GetGlobalCachingAllocator().DeviceFree(ptr.get());
+    } else {
+      SuperT::deallocate(ptr, n);
+    }
   }
+#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+  XGBCachingDeviceAllocatorImpl()
+      : SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()),
+        use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {}
+#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
   XGBOOST_DEVICE void construct(T *) {}  // NOLINT
+ private:
+  bool use_cub_allocator_{true};
 };
 }  // namespace detail
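One behavioral detail worth noting: `use_cub_allocator_` is initialized in the allocator's constructor from the thread-local global config, so `use_rmm` is read when an allocator is constructed, not on every `allocate()` call. The practical implication, sketched below under the assumption that the flag should be in place before the first device allocation (consistent with the ordering in this commit's demos):

```python
import numpy as np
import rmm
import xgboost as xgb

rmm.reinitialize(pool_allocator=True)
xgb.set_config(use_rmm=True)       # set the flag first...

X = np.random.rand(1024, 16)
y = np.random.randint(0, 2, size=1024)
dtrain = xgb.DMatrix(X, label=y)   # ...before any GPU memory is allocated
xgb.train({'tree_method': 'gpu_hist'}, dtrain, num_boost_round=10)
```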
@@ -28,7 +28,7 @@ ENV PATH=/opt/python/bin:$PATH

 # Create new Conda environment with RMM
 RUN \
-    conda create -n gpu_test -c nvidia -c rapidsai-nightly -c rapidsai -c conda-forge -c defaults \
+    conda create -n gpu_test -c nvidia -c rapidsai -c conda-forge -c defaults \
         python=3.7 rmm=0.18* cudatoolkit=$CUDA_VERSION_ARG

 ENV GOSU_VERSION 1.10
@@ -8,7 +8,7 @@ dependencies:
 - pyyaml
 - cpplint
 - pylint
-- numpy
+- numpy=1.19.4
 - scipy
 - scikit-learn
 - pandas
@@ -220,7 +220,8 @@ TEST(CAPI, XGBGlobalConfig) {
   {
     const char *config_str = R"json(
 {
-  "verbosity": 0
+  "verbosity": 0,
+  "use_rmm": false
 }
 )json";
     ret = XGBSetGlobalConfig(config_str);
@@ -233,6 +234,24 @@ TEST(CAPI, XGBGlobalConfig) {
     auto updated_config =
       Json::Load({updated_config_str.data(), updated_config_str.size()});
     ASSERT_EQ(get<Integer>(updated_config["verbosity"]), 0);
+    ASSERT_EQ(get<Boolean>(updated_config["use_rmm"]), false);
+  }
+  {
+    const char *config_str = R"json(
+{
+  "use_rmm": true
+}
+)json";
+    ret = XGBSetGlobalConfig(config_str);
+    ASSERT_EQ(ret, 0);
+    const char *updated_config_cstr;
+    ret = XGBGetGlobalConfig(&updated_config_cstr);
+    ASSERT_EQ(ret, 0);
+
+    std::string updated_config_str{updated_config_cstr};
+    auto updated_config =
+      Json::Load({updated_config_str.data(), updated_config_str.size()});
+    ASSERT_EQ(get<Boolean>(updated_config["use_rmm"]), true);
   }
   {
     const char *config_str = R"json(
@@ -19,4 +19,14 @@ TEST(GlobalConfiguration, Verbosity) {
   EXPECT_EQ(get<String>(current_config["verbosity"]), "0");
 }

+TEST(GlobalConfiguration, UseRMM) {
+  Json config{JsonObject()};
+  config["use_rmm"] = String("true");
+  auto& global_config = *GlobalConfigThreadLocalStore::Get();
+  FromJson(config, &global_config);
+  // GetConfig() should return updated use_rmm flag
+  Json current_config { ToJson(*GlobalConfigThreadLocalStore::Get()) };
+  EXPECT_EQ(get<String>(current_config["use_rmm"]), "1");
+}
+
 }  // namespace xgboost
@@ -14,3 +14,15 @@ def test_global_config_verbosity(verbosity_level):
         new_verbosity = get_current_verbosity()
         assert new_verbosity == verbosity_level
     assert old_verbosity == get_current_verbosity()
+
+
+@pytest.mark.parametrize('use_rmm', [False, True])
+def test_global_config_use_rmm(use_rmm):
+    def get_current_use_rmm_flag():
+        return xgb.get_config()['use_rmm']
+
+    old_use_rmm_flag = get_current_use_rmm_flag()
+    with xgb.config_context(use_rmm=use_rmm):
+        new_use_rmm_flag = get_current_use_rmm_flag()
+        assert new_use_rmm_flag == use_rmm
+    assert old_use_rmm_flag == get_current_use_rmm_flag()
@@ -834,9 +834,15 @@ def test_dask_predict_leaf(booster: str, client: "Client") -> None:


 class TestWithDask:
-    def test_global_config(self, client: "Client") -> None:
+    @pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
+    def test_global_config(
+        self,
+        client: "Client",
+        config_key: str,
+        config_value: Any
+    ) -> None:
         X, y, _ = generate_array()
-        xgb.config.set_config(verbosity=0)
+        xgb.config.set_config(**{config_key: config_value})
         dtrain = DaskDMatrix(client, X, y)
         before_fname = './before_training-test_global_config'
         after_fname = './after_training-test_global_config'
@@ -844,36 +850,36 @@ class TestWithDask:
         class TestCallback(xgb.callback.TrainingCallback):
             def write_file(self, fname: str) -> None:
                 with open(fname, 'w') as fd:
-                    fd.write(str(xgb.config.get_config()['verbosity']))
+                    fd.write(str(xgb.config.get_config()[config_key]))

             def before_training(self, model: xgb.Booster) -> xgb.Booster:
                 self.write_file(before_fname)
-                assert xgb.config.get_config()['verbosity'] == 0
+                assert xgb.config.get_config()[config_key] == config_value
                 return model

             def after_training(self, model: xgb.Booster) -> xgb.Booster:
-                assert xgb.config.get_config()['verbosity'] == 0
+                assert xgb.config.get_config()[config_key] == config_value
                 return model

             def before_iteration(
                 self, model: xgb.Booster, epoch: int, evals_log: Dict
             ) -> bool:
-                assert xgb.config.get_config()['verbosity'] == 0
+                assert xgb.config.get_config()[config_key] == config_value
                 return False

             def after_iteration(
                 self, model: xgb.Booster, epoch: int, evals_log: Dict
             ) -> bool:
                 self.write_file(after_fname)
-                assert xgb.config.get_config()['verbosity'] == 0
+                assert xgb.config.get_config()[config_key] == config_value
                 return False

         xgb.dask.train(client, {}, dtrain, num_boost_round=4, callbacks=[TestCallback()])[
             'booster']

         with open(before_fname, 'r') as before, open(after_fname, 'r') as after:
-            assert before.read() == '0'
-            assert after.read() == '0'
+            assert before.read() == str(config_value)
+            assert after.read() == str(config_value)

         os.remove(before_fname)
         os.remove(after_fname)