RMM integration plugin (#5873)

* [CI] Add RMM as an optional dependency

* Replace caching allocator with pool allocator from RMM

* Revert "Replace caching allocator with pool allocator from RMM"

This reverts commit e15845d4e72e890c2babe31a988b26503a7d9038.

* Use rmm::mr::get_default_resource()

* Try setting default resource (doesn't work yet)

* Allocate pool_mr in the heap

* Prevent leaking pool_mr handle

* Separate EXPECT_DEATH() in separate test suite suffixed DeathTest

* Turn off death tests for RMM

* Address reviewer's feedback

* Prevent leaking of cuda_mr

* Fix Jenkinsfile syntax

* Remove unnecessary function in Jenkinsfile

* [CI] Install NCCL into RMM container

* Run Python tests

* Try building with RMM, CUDA 10.0

* Do not use RMM for CUDA 10.0 target

* Actually test for test_rmm flag

* Fix TestPythonGPU

* Use CNMeM allocator, since pool allocator doesn't yet support multiGPU

* Use 10.0 container to build RMM-enabled XGBoost

* Revert "Use 10.0 container to build RMM-enabled XGBoost"

This reverts commit 789021fa31112e25b683aef39fff375403060141.

* Fix Jenkinsfile

* [CI] Assign larger /dev/shm to NCCL

* Use 10.2 artifact to run multi-GPU Python tests

* Add CUDA 10.0 -> 11.0 cross-version test; remove CUDA 10.0 target

* Rename Conda env rmm_test -> gpu_test

* Use env var to opt into CNMeM pool for C++ tests

* Use identical CUDA version for RMM builds and tests

* Use Pytest fixtures to enable RMM pool in Python tests

* Move RMM to plugin/CMakeLists.txt; use PLUGIN_RMM

* Use per-device MR; use command arg in gtest

* Set CMake prefix path to use Conda env

* Use 0.15 nightly version of RMM

* Remove unnecessary header

* Fix a unit test when cudf is missing

* Add RMM demos

* Remove print()

* Use HostDeviceVector in GPU predictor

* Simplify pytest setup; use LocalCUDACluster fixture

* Address reviewers' comments

Co-authored-by: Hyunsu Cho <chohyu01@cs.wasshington.edu>
Authored by Philip Hyunsu Cho on 2020-08-12 01:26:02 -07:00; committed by GitHub
commit 9adb812a0a (parent c3ea3b7e37)
26 changed files with 508 additions and 140 deletions

@@ -60,6 +60,7 @@ address, leak, undefined and thread.")
 ## Plugins
 option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
 option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
+option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
 ## TODO: 1. Add check if DPC++ compiler is used for building
 option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
 option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
@@ -84,6 +85,9 @@ endif (R_LIB AND GOOGLE_TEST)
 if (USE_AVX)
   message(SEND_ERROR "The option 'USE_AVX' is deprecated as experimental AVX features have been removed from XGBoost.")
 endif (USE_AVX)
+if (PLUGIN_RMM AND NOT (USE_CUDA))
+  message(SEND_ERROR "`PLUGIN_RMM` must be enabled with `USE_CUDA` flag.")
+endif (PLUGIN_RMM AND NOT (USE_CUDA))
 if (ENABLE_ALL_WARNINGS)
   if ((NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang") AND (NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
     message(SEND_ERROR "ENABLE_ALL_WARNINGS is only available for Clang and GCC.")

Jenkinsfile

@@ -73,7 +73,7 @@ pipeline {
         'build-gpu-cuda10.0': { BuildCUDA(cuda_version: '10.0') },
         // The build-gpu-* builds below use Ubuntu image
         'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') },
-        'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') },
+        'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2', build_rmm: true) },
         'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') },
         'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') },
         'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
@@ -89,11 +89,12 @@ pipeline {
         script {
           parallel ([
             'test-python-cpu': { TestPythonCPU() },
-            'test-python-gpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2') },
+            // artifact_cuda_version doesn't apply to RMM tests; RMM tests will always match CUDA version between artifact and host env
+            'test-python-gpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', test_rmm: true) },
             'test-python-gpu-cuda11.0-cross': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '11.0') },
             'test-python-gpu-cuda11.0': { TestPythonGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
-            'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', multi_gpu: true) },
-            'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') },
+            'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', multi_gpu: true, test_rmm: true) },
+            'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', test_rmm: true) },
             'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
             'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') },
             'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') },
@@ -280,6 +281,22 @@ def BuildCUDA(args) {
     }
     echo 'Stashing C++ test executable (testxgboost)...'
     stash name: "xgboost_cpp_tests_cuda${args.cuda_version}", includes: 'build/testxgboost'
+    if (args.build_rmm) {
+      echo "Build with CUDA ${args.cuda_version} and RMM"
+      container_type = "rmm"
+      docker_binary = "docker"
+      docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}"
+      sh """
+      rm -rf build/
+      ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_via_cmake.sh --conda-env=gpu_test -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON ${arch_flag}
+      ${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "cd python-package && rm -rf dist/* && python setup.py bdist_wheel --universal"
+      ${dockerRun} ${container_type} ${docker_binary} ${docker_args} python tests/ci_build/rename_whl.py python-package/dist/*.whl ${commit_id} manylinux2010_x86_64
+      """
+      echo 'Stashing Python wheel...'
+      stash name: "xgboost_whl_rmm_cuda${args.cuda_version}", includes: 'python-package/dist/*.whl'
+      echo 'Stashing C++ test executable (testxgboost)...'
+      stash name: "xgboost_cpp_tests_rmm_cuda${args.cuda_version}", includes: 'build/testxgboost'
+    }
     deleteDir()
   }
 }
@@ -366,18 +383,15 @@ def TestPythonGPU(args) {
     def container_type = "gpu"
     def docker_binary = "nvidia-docker"
     def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
-    if (args.multi_gpu) {
-      echo "Using multiple GPUs"
-      // Allocate extra space in /dev/shm to enable NCCL
-      def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'"
-      sh """
-      ${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu
-      """
-    } else {
-      echo "Using a single GPU"
-      sh """
-      ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh gpu
-      """
+    def mgpu_indicator = (args.multi_gpu) ? 'mgpu' : 'gpu'
+    // Allocate extra space in /dev/shm to enable NCCL
+    def docker_extra_params = (args.multi_gpu) ? "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'" : ''
+    sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator}"
+    if (args.test_rmm) {
+      sh "rm -rfv build/ python-package/dist/"
+      unstash name: "xgboost_whl_rmm_cuda${args.host_cuda_version}"
+      unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}"
+      sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator} --use-rmm-pool"
     }
     deleteDir()
   }
@@ -408,6 +422,17 @@ def TestCppGPU(args) {
     def docker_binary = "nvidia-docker"
     def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
     sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} build/testxgboost"
+    if (args.test_rmm) {
+      sh "rm -rfv build/"
+      unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}"
+      echo "Test C++, CUDA ${args.host_cuda_version} with RMM"
+      container_type = "rmm"
+      docker_binary = "nvidia-docker"
+      docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
+      sh """
+      ${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "source activate gpu_test && build/testxgboost --use-rmm-pool --gtest_filter=-*DeathTest.*"
+      """
+    }
     deleteDir()
   }
 }

demo/rmm_plugin/README.md

@@ -0,0 +1,31 @@
Using XGBoost with RAPIDS Memory Manager (RMM) plugin (EXPERIMENTAL)
====================================================================

The [RAPIDS Memory Manager (RMM)](https://github.com/rapidsai/rmm) library provides a collection of
efficient memory allocators for NVIDIA GPUs. It is now possible to use XGBoost with memory
allocators provided by RMM, by enabling the RMM integration plugin.

The demos in this directory highlight one RMM allocator in particular: **the pool sub-allocator**.
This allocator addresses the slow speed of `cudaMalloc()` by allocating a large chunk of memory
upfront. Subsequent allocations will draw from the pool of already allocated memory and thus avoid
the overhead of calling `cudaMalloc()` directly. See
[these GTC talk slides](https://on-demand.gputechconf.com/gtc/2015/presentation/S5530-Stephen-Jones.pdf)
for more details.

Before running the demos, ensure that XGBoost is compiled with the RMM plugin enabled. To do this,
run CMake with the option `-DPLUGIN_RMM=ON` (`-DUSE_CUDA=ON` is also required):
```
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON
make -j4
```
CMake will attempt to locate the RMM library in your build environment. You may choose to build
RMM from source or install it using the Conda package manager. If CMake cannot find RMM, specify
its location with the CMake prefix path:
```
# If using Conda:
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
# If RMM is installed in a custom location:
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=/path/to/rmm
```

* [Using RMM with a single GPU](./rmm_singlegpu.py)
* [Using RMM with a local Dask cluster consisting of multiple GPUs](./rmm_mgpu_with_dask.py)
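
As a quick illustration of the pool sub-allocator described above, the minimal sketch below (not part of this commit) mirrors `rmm_singlegpu.py` but sizes the pool explicitly; it assumes the `rmm` and `scikit-learn` Python packages are installed alongside an RMM-enabled XGBoost build:
```
import rmm
import xgboost as xgb
from sklearn.datasets import make_classification

# Reserve a 1 GiB pool up front; later device allocations draw from this pool
# instead of calling cudaMalloc() each time. initial_pool_size is optional.
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024 * 1024 * 1024)

X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
dtrain = xgb.DMatrix(X, label=y)
params = {'tree_method': 'gpu_hist', 'objective': 'multi:softprob', 'num_class': 3}
# With a -DPLUGIN_RMM=ON build, GPU training below allocates through the RMM pool.
bst = xgb.train(params, dtrain, num_boost_round=10)
```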

@@ -0,0 +1,27 @@
import xgboost as xgb
from sklearn.datasets import make_classification
import dask
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

def main(client):
    X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
    X = dask.array.from_array(X)
    y = dask.array.from_array(y)
    dtrain = xgb.dask.DaskDMatrix(client, X, label=y)

    params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
              'tree_method': 'gpu_hist'}
    output = xgb.dask.train(client, params, dtrain, num_boost_round=100,
                            evals=[(dtrain, 'train')])
    bst = output['booster']
    history = output['history']
    for i, e in enumerate(history['train']['merror']):
        print(f'[{i}] train-merror: {e}')

if __name__ == '__main__':
    # To use RMM pool allocator with a GPU Dask cluster, just add rmm_pool_size option to
    # LocalCUDACluster constructor.
    with LocalCUDACluster(rmm_pool_size='2GB') as cluster:
        with Client(cluster) as client:
            main(client)

@@ -0,0 +1,14 @@
import xgboost as xgb
import rmm
from sklearn.datasets import make_classification

# Initialize RMM pool allocator
rmm.reinitialize(pool_allocator=True)

X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
dtrain = xgb.DMatrix(X, label=y)
params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
          'tree_method': 'gpu_hist'}
# XGBoost will automatically use the RMM pool allocator
bst = xgb.train(params, dtrain, num_boost_round=100, evals=[(dtrain, 'train')])

@@ -7,6 +7,25 @@ if (PLUGIN_DENSE_PARSER)
   target_sources(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/dense_parser/dense_libsvm.cc)
 endif (PLUGIN_DENSE_PARSER)
 
+if (PLUGIN_RMM)
+  find_path(RMM_INCLUDE "rmm"
+    HINTS "$ENV{RMM_ROOT}/include")
+
+  find_library(RMM_LIBRARY "rmm"
+    HINTS "$ENV{RMM_ROOT}/lib" "$ENV{RMM_ROOT}/build")
+
+  if ((NOT RMM_LIBRARY) OR (NOT RMM_INCLUDE))
+    message(FATAL_ERROR "Could not locate RMM library")
+  endif ()
+
+  message(STATUS "RMM: RMM_LIBRARY set to ${RMM_LIBRARY}")
+  message(STATUS "RMM: RMM_INCLUDE set to ${RMM_INCLUDE}")
+
+  target_include_directories(objxgboost PUBLIC ${RMM_INCLUDE})
+  target_link_libraries(objxgboost PUBLIC ${RMM_LIBRARY} cuda)
+  target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_RMM=1)
+endif (PLUGIN_RMM)
+
 if (PLUGIN_UPDATER_ONEAPI)
   add_library(oneapi_plugin OBJECT
     ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/regression_obj_oneapi.cc

@@ -317,7 +317,7 @@ def _is_cudf_df(data):
         import cudf
     except ImportError:
         return False
-    return isinstance(data, cudf.DataFrame)
+    return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)
 
 
 def _cudf_array_interfaces(data):

@@ -36,7 +36,12 @@
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
-#endif
+#endif  // XGBOOST_USE_NCCL
+
+#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+#include "rmm/mr/device/per_device_resource.hpp"
+#include "rmm/mr/device/thrust_allocator_adaptor.hpp"
+#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 
 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
@@ -370,12 +375,21 @@ inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
 }
 
 namespace detail {
+
+#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+template <typename T>
+using XGBBaseDeviceAllocator = rmm::mr::thrust_allocator<T>;
+#else  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+template <typename T>
+using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
+#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+
 /**
  * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
  */
 template <class T>
-struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
-  using SuperT = thrust::device_malloc_allocator<T>;
+struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
+  using SuperT = XGBBaseDeviceAllocator<T>;
   using pointer = thrust::device_ptr<T>;  // NOLINT
   template<typename U>
   struct rebind  // NOLINT
@@ -391,10 +405,15 @@ struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
     GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
     return SuperT::deallocate(ptr, n);
   }
+#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+  XGBDefaultDeviceAllocatorImpl()
+    : SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{0}) {}
+#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 };
 
 /**
- * \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs allocations if verbose. Does not initialise memory on construction.
+ * \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs
+ * allocations if verbose. Does not initialise memory on construction.
  */
 template <class T>
 struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {

@@ -11,6 +11,7 @@
 #include "xgboost/data.h"
 #include "xgboost/host_device_vector.h"
+#include "xgboost/tree_model.h"
 #include "device_helpers.cuh"
 
 namespace xgboost {
@@ -402,6 +403,7 @@ template class HostDeviceVector<FeatureType>;
 template class HostDeviceVector<Entry>;
 template class HostDeviceVector<uint64_t>;  // bst_row_t
 template class HostDeviceVector<uint32_t>;  // bst_feature_t
+template class HostDeviceVector<RegTree::Node>;
 
 #if defined(__APPLE__)
 /*

@@ -213,39 +213,21 @@ __global__ void PredictKernel(Data data,
 
 class DeviceModel {
  public:
-  dh::device_vector<RegTree::Node> nodes;
-  dh::device_vector<size_t> tree_segments;
-  dh::device_vector<int> tree_group;
+  // Need to lazily construct the vectors because GPU id is only known at runtime
+  HostDeviceVector<RegTree::Node> nodes;
+  HostDeviceVector<size_t> tree_segments;
+  HostDeviceVector<int> tree_group;
   size_t tree_beg_;  // NOLINT
   size_t tree_end_;  // NOLINT
   int num_group;
 
-  void CopyModel(const gbm::GBTreeModel& model,
-                 const thrust::host_vector<size_t>& h_tree_segments,
-                 const thrust::host_vector<RegTree::Node>& h_nodes,
-                 size_t tree_begin, size_t tree_end) {
-    nodes.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
-                                  sizeof(RegTree::Node) * h_nodes.size(),
-                                  cudaMemcpyHostToDevice));
-    tree_segments.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpyAsync(tree_segments.data().get(), h_tree_segments.data(),
-                                  sizeof(size_t) * h_tree_segments.size(),
-                                  cudaMemcpyHostToDevice));
-    tree_group.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpyAsync(tree_group.data().get(), model.tree_info.data(),
-                                  sizeof(int) * model.tree_info.size(),
-                                  cudaMemcpyHostToDevice));
-    this->tree_beg_ = tree_begin;
-    this->tree_end_ = tree_end;
-    this->num_group = model.learner_model_param->num_output_group;
-  }
-
   void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) {
     dh::safe_cuda(cudaSetDevice(gpu_id));
     CHECK_EQ(model.param.size_leaf_vector, 0);
     // Copy decision trees to device
-    thrust::host_vector<size_t> h_tree_segments{};
+    tree_segments = std::move(HostDeviceVector<size_t>({}, gpu_id));
+    auto& h_tree_segments = tree_segments.HostVector();
     h_tree_segments.reserve((tree_end - tree_begin) + 1);
     size_t sum = 0;
     h_tree_segments.push_back(sum);
@@ -254,13 +236,21 @@ class DeviceModel {
       h_tree_segments.push_back(sum);
     }
-    thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
+    nodes = std::move(HostDeviceVector<RegTree::Node>(h_tree_segments.back(), RegTree::Node(),
+                                                      gpu_id));
+    auto& h_nodes = nodes.HostVector();
     for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
       std::copy(src_nodes.begin(), src_nodes.end(),
                 h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
     }
-    CopyModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
+
+    tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
+    auto& h_tree_group = tree_group.HostVector();
+    std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
+
+    this->tree_beg_ = tree_begin;
+    this->tree_end_ = tree_end;
+    this->num_group = model.learner_model_param->num_output_group;
   }
 };
@@ -287,8 +277,8 @@ class GPUPredictor : public xgboost::Predictor {
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<SparsePageLoader, SparsePageView>,
         data,
-        dh::ToSpan(model_.nodes), predictions->DeviceSpan().subspan(batch_offset),
-        dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
+        model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset),
+        model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(),
        model_.tree_beg_, model_.tree_end_, num_features, num_rows,
         entry_start, use_shared, model_.num_group);
   }
@@ -303,8 +293,8 @@ class GPUPredictor : public xgboost::Predictor {
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
         PredictKernel<EllpackLoader, EllpackDeviceAccessor>,
         batch,
-        dh::ToSpan(model_.nodes), out_preds->DeviceSpan().subspan(batch_offset),
-        dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
+        model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
+        model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(),
         model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows,
         entry_start, use_shared, model_.num_group);
   }
@@ -435,8 +425,8 @@ class GPUPredictor : public xgboost::Predictor {
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<Loader, typename Loader::BatchT>,
         m->Value(),
-        dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
-        dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
+        d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(),
+        d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(),
         tree_begin, tree_end, m->NumColumns(), info.num_row_,
         entry_start, use_shared, output_groups);
   }

@@ -17,8 +17,8 @@ ENV PATH=/opt/python/bin:$PATH
 # Create new Conda environment with cuDF, Dask, and cuPy
 RUN \
-    conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \
-        python=3.7 cudf=0.14 cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
+    conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
+        python=3.7 cudf=0.15* rmm=0.15* cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
         numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
 
 ENV GOSU_VERSION 1.10

@@ -0,0 +1,47 @@
ARG CUDA_VERSION
FROM nvidia/cuda:$CUDA_VERSION-devel-ubuntu16.04
ARG CUDA_VERSION

# Environment
ENV DEBIAN_FRONTEND noninteractive
SHELL ["/bin/bash", "-c"]   # Use Bash as shell

# Install all basic requirements
RUN \
    apt-get update && \
    apt-get install -y wget unzip bzip2 libgomp1 build-essential ninja-build git && \
    # Python
    wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    bash Miniconda3.sh -b -p /opt/python && \
    # CMake
    wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
    bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr

# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
RUN \
    export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
    export NCCL_VERSION=2.7.5-1 && \
    apt-get update && \
    apt-get install -y --allow-downgrades --allow-change-held-packages libnccl2=${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-dev=${NCCL_VERSION}+cuda${CUDA_SHORT}

ENV PATH=/opt/python/bin:$PATH

# Create new Conda environment with RMM
RUN \
    conda create -n gpu_test -c nvidia -c rapidsai-nightly -c rapidsai -c conda-forge -c defaults \
        python=3.7 rmm=0.15* cudatoolkit=$CUDA_VERSION

ENV GOSU_VERSION 1.10

# Install lightweight sudo (not bound to TTY)
RUN set -ex; \
    wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
    chmod +x /usr/local/bin/gosu && \
    gosu nobody true

# Default entry-point to use if running locally
# It will preserve attributes of created files
COPY entrypoint.sh /scripts/

WORKDIR /workspace
ENTRYPOINT ["/scripts/entrypoint.sh"]

@@ -1,10 +1,23 @@
 #!/usr/bin/env bash
 set -e
 
+if [[ "$1" == --conda-env=* ]]
+then
+  conda_env=$(echo "$1" | sed 's/^--conda-env=//g' -)
+  echo "Activating Conda environment ${conda_env}"
+  shift 1
+  cmake_args="$@"
+  source activate ${conda_env}
+  cmake_prefix_flag="-DCMAKE_PREFIX_PATH=$CONDA_PREFIX"
+else
+  cmake_args="$@"
+  cmake_prefix_flag=''
+fi
+
 rm -rf build
 mkdir build
 cd build
-cmake .. "$@" -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DENABLE_ALL_WARNINGS=ON -GNinja
+cmake .. ${cmake_args} -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DENABLE_ALL_WARNINGS=ON -GNinja ${cmake_prefix_flag}
 ninja clean
 time ninja -v
 cd ..

@@ -2,7 +2,15 @@
 set -e
 set -x
 
-suite=$1
+if [ "$#" -lt 1 ]
+then
+  suite=''
+  args=''
+else
+  suite=$1
+  shift 1
+  args="$@"
+fi
 
 # Install XGBoost Python package
 function install_xgboost {
@@ -26,34 +34,40 @@ function install_xgboost {
   fi
 }
 
+function uninstall_xgboost {
+  pip uninstall -y xgboost
+}
+
 # Run specified test suite
 case "$suite" in
   gpu)
     source activate gpu_test
     install_xgboost
-    pytest -v -s -rxXs --fulltrace -m "not mgpu" tests/python-gpu
+    pytest -v -s -rxXs --fulltrace -m "not mgpu" ${args} tests/python-gpu
+    uninstall_xgboost
     ;;
   mgpu)
     source activate gpu_test
     install_xgboost
-    pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu
+    pytest -v -s -rxXs --fulltrace -m "mgpu" ${args} tests/python-gpu
     cd tests/distributed
     ./runtests-gpu.sh
-    cd -
+    uninstall_xgboost
     ;;
   cpu)
     source activate cpu_test
     install_xgboost
-    pytest -v -s --fulltrace tests/python
+    pytest -v -s -rxXs --fulltrace ${args} tests/python
     cd tests/distributed
     ./runtests.sh
+    uninstall_xgboost
     ;;
   *)
-    echo "Usage: $0 {gpu|mgpu|cpu}"
+    echo "Usage: $0 {gpu|mgpu|cpu} [extra args to pass to pytest]"
     exit 1
     ;;
 esac

@@ -37,6 +37,8 @@ if (USE_CUDA)
     $<$<COMPILE_LANGUAGE:CUDA>:${GEN_CODE}>)
   target_compile_definitions(testxgboost
     PRIVATE -DXGBOOST_USE_CUDA=1)
+  find_package(CUDA)
+  target_include_directories(testxgboost PRIVATE ${CUDA_INCLUDE_DIRS})
   set_target_properties(testxgboost PROPERTIES
     CUDA_SEPARABLE_COMPILATION OFF)

@@ -97,11 +97,6 @@ TEST(Span, FromPtrLen) {
     }
   }
 
-  {
-    auto lazy = [=]() {Span<float const, 16> tmp (arr, 5);};
-    EXPECT_DEATH(lazy(), "\\[xgboost\\] Condition .* failed.\n");
-  }
-
   // dynamic extent
   {
     Span<float, 16> s (arr, 16);
@@ -122,6 +117,15 @@
   }
 }
 
+TEST(SpanDeathTest, FromPtrLen) {
+  float arr[16];
+  InitializeRange(arr, arr+16);
+  {
+    auto lazy = [=]() {Span<float const, 16> tmp (arr, 5);};
+    EXPECT_DEATH(lazy(), "\\[xgboost\\] Condition .* failed.\n");
+  }
+}
+
 TEST(Span, FromFirstLast) {
   float arr[16];
   InitializeRange(arr, arr+16);
@@ -285,7 +289,13 @@ TEST(Span, ElementAccess) {
     ASSERT_EQ(i, arr[j]);
     ++j;
   }
+}
+
+TEST(SpanDeathTest, ElementAccess) {
+  float arr[16];
+  InitializeRange(arr, arr + 16);
+  Span<float> s (arr);
   EXPECT_DEATH(s[16], "\\[xgboost\\] Condition .* failed.\n");
   EXPECT_DEATH(s[-1], "\\[xgboost\\] Condition .* failed.\n");
@@ -312,7 +322,9 @@ TEST(Span, FrontBack) {
     ASSERT_EQ(s.front(), 0);
     ASSERT_EQ(s.back(), 3);
   }
+}
 
+TEST(SpanDeathTest, FrontBack) {
   {
     Span<float, 0> s;
     EXPECT_DEATH(s.front(), "\\[xgboost\\] Condition .* failed.\n");
@@ -340,10 +352,6 @@ TEST(Span, FirstLast) {
     for (size_t i = 0; i < first.size(); ++i) {
       ASSERT_EQ(first[i], arr[i]);
     }
-    auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
-    EXPECT_DEATH(s.first<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
-    EXPECT_DEATH(s.first<17>(), "\\[xgboost\\] Condition .* failed.\n");
-    EXPECT_DEATH(s.first<32>(), "\\[xgboost\\] Condition .* failed.\n");
   }
 
   {
@@ -359,10 +367,6 @@ TEST(Span, FirstLast) {
     for (size_t i = 0; i < last.size(); ++i) {
       ASSERT_EQ(last[i], arr[i+12]);
     }
-    auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
-    EXPECT_DEATH(s.last<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
-    EXPECT_DEATH(s.last<17>(), "\\[xgboost\\] Condition .* failed.\n");
-    EXPECT_DEATH(s.last<32>(), "\\[xgboost\\] Condition .* failed.\n");
   }
 
   // dynamic extent
@@ -379,10 +383,6 @@ TEST(Span, FirstLast) {
       ASSERT_EQ(first[i], s[i]);
     }
 
-    EXPECT_DEATH(s.first(-1), "\\[xgboost\\] Condition .* failed.\n");
-    EXPECT_DEATH(s.first(17), "\\[xgboost\\] Condition .* failed.\n");
-    EXPECT_DEATH(s.first(32), "\\[xgboost\\] Condition .* failed.\n");
-
     delete [] arr;
   }
@@ -399,6 +399,50 @@ TEST(Span, FirstLast) {
       ASSERT_EQ(s[12 + i], last[i]);
     }
 
+    delete [] arr;
+  }
+}
+
+TEST(SpanDeathTest, FirstLast) {
+  // static extent
+  {
+    float arr[16];
+    InitializeRange(arr, arr + 16);
+    Span<float> s (arr);
+    auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
+    EXPECT_DEATH(s.first<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
+    EXPECT_DEATH(s.first<17>(), "\\[xgboost\\] Condition .* failed.\n");
+    EXPECT_DEATH(s.first<32>(), "\\[xgboost\\] Condition .* failed.\n");
+  }
+
+  {
+    float arr[16];
+    InitializeRange(arr, arr + 16);
+    Span<float> s (arr);
+    auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
+    EXPECT_DEATH(s.last<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
+    EXPECT_DEATH(s.last<17>(), "\\[xgboost\\] Condition .* failed.\n");
+    EXPECT_DEATH(s.last<32>(), "\\[xgboost\\] Condition .* failed.\n");
+  }
+
+  // dynamic extent
+  {
+    float *arr = new float[16];
+    InitializeRange(arr, arr + 16);
+    Span<float> s (arr, 16);
+    EXPECT_DEATH(s.first(-1), "\\[xgboost\\] Condition .* failed.\n");
+    EXPECT_DEATH(s.first(17), "\\[xgboost\\] Condition .* failed.\n");
+    EXPECT_DEATH(s.first(32), "\\[xgboost\\] Condition .* failed.\n");
+    delete [] arr;
+  }
+
+  {
+    float *arr = new float[16];
+    InitializeRange(arr, arr + 16);
+    Span<float> s (arr, 16);
     EXPECT_DEATH(s.last(-1), "\\[xgboost\\] Condition .* failed.\n");
     EXPECT_DEATH(s.last(17), "\\[xgboost\\] Condition .* failed.\n");
     EXPECT_DEATH(s.last(32), "\\[xgboost\\] Condition .* failed.\n");
@@ -420,7 +464,11 @@ TEST(Span, Subspan) {
   auto s4 = s1.subspan(2, dynamic_extent);
   ASSERT_EQ(s1.data() + 2, s4.data());
   ASSERT_EQ(s4.size(), s1.size() - 2);
+}
 
+TEST(SpanDeathTest, Subspan) {
+  int arr[16] {0};
+  Span<int> s1 (arr);
   EXPECT_DEATH(s1.subspan(-1, 0), "\\[xgboost\\] Condition .* failed.\n");
   EXPECT_DEATH(s1.subspan(17, 0), "\\[xgboost\\] Condition .* failed.\n");

@@ -221,7 +221,7 @@ struct TestElementAccess {
   }
 };
 
-TEST(GPUSpan, ElementAccess) {
+TEST(GPUSpanDeathTest, ElementAccess) {
   dh::safe_cuda(cudaSetDevice(0));
   auto test_element_access = []() {
     thrust::host_vector<float> h_vec (16);

@@ -59,7 +59,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
 }
 
 #if !defined(__CUDACC__)
-TEST(Transform, Exception) {
+TEST(TransformDeathTest, Exception) {
   size_t const kSize {16};
   std::vector<bst_float> h_in(kSize);
   const HostDeviceVector<bst_float> in_vec{h_in, -1};

@@ -20,6 +20,15 @@
 #include "../../src/gbm/gbtree_model.h"
 #include "xgboost/predictor.h"
 
+#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+#include <memory>
+#include <numeric>
+#include <vector>
+#include "rmm/mr/device/per_device_resource.hpp"
+#include "rmm/mr/device/cuda_memory_resource.hpp"
+#include "rmm/mr/device/pool_memory_resource.hpp"
+#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+
 bool FileExists(const std::string& filename) {
   struct stat st;
   return stat(filename.c_str(), &st) == 0;
@@ -478,4 +487,57 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
   return gbm;
 }
 
+#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+using CUDAMemoryResource = rmm::mr::cuda_memory_resource;
+using PoolMemoryResource = rmm::mr::pool_memory_resource<CUDAMemoryResource>;
+class RMMAllocator {
+ public:
+  std::vector<std::unique_ptr<CUDAMemoryResource>> cuda_mr;
+  std::vector<std::unique_ptr<PoolMemoryResource>> pool_mr;
+  int n_gpu;
+  RMMAllocator() : n_gpu(common::AllVisibleGPUs()) {
+    int current_device;
+    CHECK_EQ(cudaGetDevice(&current_device), cudaSuccess);
+    for (int i = 0; i < n_gpu; ++i) {
+      CHECK_EQ(cudaSetDevice(i), cudaSuccess);
+      cuda_mr.push_back(std::make_unique<CUDAMemoryResource>());
+      pool_mr.push_back(std::make_unique<PoolMemoryResource>(cuda_mr[i].get()));
+    }
+    CHECK_EQ(cudaSetDevice(current_device), cudaSuccess);
+  }
+  ~RMMAllocator() = default;
+};
+
+void DeleteRMMResource(RMMAllocator* r) {
+  delete r;
+}
+
+RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
+  bool use_rmm_pool = false;
+  for (int i = 1; i < argc; ++i) {
+    if (argv[i] == std::string("--use-rmm-pool")) {
+      use_rmm_pool = true;
+    }
+  }
+  if (!use_rmm_pool) {
+    return RMMAllocatorPtr(nullptr, DeleteRMMResource);
+  }
+  LOG(INFO) << "Using RMM memory pool";
+  auto ptr = RMMAllocatorPtr(new RMMAllocator(), DeleteRMMResource);
+  for (int i = 0; i < ptr->n_gpu; ++i) {
+    rmm::mr::set_per_device_resource(rmm::cuda_device_id(i), ptr->pool_mr[i].get());
+  }
+  return ptr;
+}
+#else  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
+class RMMAllocator {};
+
+void DeleteRMMResource(RMMAllocator* r) {}
+
+RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
+  return RMMAllocatorPtr(nullptr, DeleteRMMResource);
+}
+#endif  // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
 }  // namespace xgboost

@@ -8,6 +8,7 @@
 #include <fstream>
 #include <cstdio>
 #include <string>
+#include <memory>
 #include <vector>
 #include <sys/stat.h>
 #include <sys/types.h>
@@ -352,5 +353,9 @@ inline int Next(DataIterHandle self) {
   return static_cast<CudaArrayIterForTest*>(self)->Next();
 }
 
+class RMMAllocator;
+using RMMAllocatorPtr = std::unique_ptr<RMMAllocator, void(*)(RMMAllocator*)>;
+RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv);
+
 }  // namespace xgboost
 #endif

@@ -3,13 +3,17 @@
 #include <xgboost/base.h>
 #include <xgboost/logging.h>
 #include <string>
+#include <memory>
 #include <vector>
 
+#include "helpers.h"
+
 int main(int argc, char ** argv) {
   xgboost::Args args {{"verbosity", "2"}};
   xgboost::ConsoleLogger::Configure(args);
 
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
+  auto rmm_alloc = xgboost::SetUpRMMResourceForCppTests(argc, argv);
   return RUN_ALL_TESTS();
 }

@@ -119,7 +119,7 @@ void TestIncorrectRow() {
   });
 }
 
-TEST(RowPartitioner, IncorrectRow) {
+TEST(RowPartitionerDeathTest, IncorrectRow) {
   ASSERT_DEATH({ TestIncorrectRow(); }, ".*");
 }
 }  // namespace tree

@@ -2,4 +2,4 @@
 markers =
   mgpu: Mark a test that requires multiple GPUs to run.
   ci: Mark a test that runs only on CI.
   gtest: Mark a test that requires C++ Google Test executable.

@@ -0,0 +1,45 @@
import sys
import pytest
import logging

sys.path.append("tests/python")
import testing as tm  # noqa

def has_rmm():
    try:
        import rmm
        return True
    except ImportError:
        return False

@pytest.fixture(scope='session', autouse=True)
def setup_rmm_pool(request, pytestconfig):
    if pytestconfig.getoption('--use-rmm-pool'):
        if not has_rmm():
            raise ImportError('The --use-rmm-pool option requires the RMM package')
        import rmm
        from dask_cuda.utils import get_n_gpus
        rmm.reinitialize(pool_allocator=True, initial_pool_size=1024*1024*1024,
                         devices=list(range(get_n_gpus())))

@pytest.fixture(scope='function')
def local_cuda_cluster(request, pytestconfig):
    kwargs = {}
    if hasattr(request, 'param'):
        kwargs.update(request.param)
    if pytestconfig.getoption('--use-rmm-pool'):
        if not has_rmm():
            raise ImportError('The --use-rmm-pool option requires the RMM package')
        import rmm
        from dask_cuda.utils import get_n_gpus
        rmm.reinitialize()
        kwargs['rmm_pool_size'] = '2GB'
    if tm.no_dask_cuda()['condition']:
        raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
    from dask_cuda import LocalCUDACluster
    cluster = LocalCUDACluster(**kwargs)
    yield cluster
    cluster.close()

def pytest_addoption(parser):
    parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')
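
For orientation, a hypothetical test consuming these fixtures might look like the sketch below (not part of this commit; the real consumers are the multi-GPU Dask tests changed later on this page). pytest injects `local_cuda_cluster` by name, and the RMM pool, when `--use-rmm-pool` is passed, is configured once per session by `setup_rmm_pool`:
```
import numpy as np
import xgboost as xgb
from dask import array as da
from dask.distributed import Client

def test_gpu_hist_with_rmm(local_cuda_cluster):
    # The fixture yields a LocalCUDACluster; wrap it in a Dask client as usual.
    with Client(local_cuda_cluster) as client:
        X = da.from_array(np.random.rand(1000, 10), chunks=(100, 10))
        y = da.from_array(np.random.rand(1000), chunks=(100,))
        dtrain = xgb.dask.DaskDMatrix(client, X, label=y)
        output = xgb.dask.train(client, {'tree_method': 'gpu_hist'},
                                dtrain, num_boost_round=2)
        assert isinstance(output['booster'], xgb.Booster)
```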

@@ -6,7 +6,6 @@ sys.path.append("tests/python")
 import testing as tm
 import test_demos as td  # noqa
 
-
 @pytest.mark.skipif(**tm.no_cupy())
 def test_data_iterator():
     script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py')

@@ -3,7 +3,6 @@ import os
 import pytest
 import numpy as np
 import asyncio
-import unittest
 import xgboost
 import subprocess
 from hypothesis import given, strategies, settings, note
@@ -23,7 +22,6 @@ import testing as tm  # noqa
 try:
     import dask.dataframe as dd
     from xgboost import dask as dxgb
-    from dask_cuda import LocalCUDACluster
     from dask.distributed import Client
     from dask import array as da
     import cudf
@@ -151,50 +149,51 @@ def run_gpu_hist(params, num_rounds, dataset, DMatrixT, client):
     assert tm.non_increasing(history['train'][dataset.metric])
 
 
-class TestDistributedGPU(unittest.TestCase):
+class TestDistributedGPU:
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_cudf())
     @pytest.mark.skipif(**tm.no_dask_cudf())
     @pytest.mark.skipif(**tm.no_dask_cuda())
     @pytest.mark.mgpu
-    def test_dask_dataframe(self):
-        with LocalCUDACluster() as cluster:
-            with Client(cluster) as client:
-                run_with_dask_dataframe(dxgb.DaskDMatrix, client)
-                run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)
+    def test_dask_dataframe(self, local_cuda_cluster):
+        with Client(local_cuda_cluster) as client:
+            run_with_dask_dataframe(dxgb.DaskDMatrix, client)
+            run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)
 
-    @given(parameter_strategy, strategies.integers(1, 20),
-           tm.dataset_strategy)
+    @given(params=parameter_strategy, num_rounds=strategies.integers(1, 20),
+           dataset=tm.dataset_strategy)
     @settings(deadline=duration(seconds=120))
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
+    @pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], indirect=['local_cuda_cluster'])
     @pytest.mark.mgpu
-    def test_gpu_hist(self, params, num_rounds, dataset):
-        with LocalCUDACluster(n_workers=2) as cluster:
-            with Client(cluster) as client:
-                run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix,
-                             client)
-                run_gpu_hist(params, num_rounds, dataset,
-                             dxgb.DaskDeviceQuantileDMatrix, client)
+    def test_gpu_hist(self, params, num_rounds, dataset, local_cuda_cluster):
+        with Client(local_cuda_cluster) as client:
+            run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix,
+                         client)
+            run_gpu_hist(params, num_rounds, dataset,
+                         dxgb.DaskDeviceQuantileDMatrix, client)
 
     @pytest.mark.skipif(**tm.no_cupy())
+    @pytest.mark.skipif(**tm.no_dask())
+    @pytest.mark.skipif(**tm.no_dask_cuda())
     @pytest.mark.mgpu
-    def test_dask_array(self):
-        with LocalCUDACluster() as cluster:
-            with Client(cluster) as client:
-                run_with_dask_array(dxgb.DaskDMatrix, client)
-                run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
+    def test_dask_array(self, local_cuda_cluster):
+        with Client(local_cuda_cluster) as client:
+            run_with_dask_array(dxgb.DaskDMatrix, client)
+            run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
 
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
     @pytest.mark.mgpu
-    def test_empty_dmatrix(self):
-        with LocalCUDACluster() as cluster:
-            with Client(cluster) as client:
-                parameters = {'tree_method': 'gpu_hist',
-                              'debug_synchronize': True}
-                run_empty_dmatrix_reg(client, parameters)
-                run_empty_dmatrix_cls(client, parameters)
+    def test_empty_dmatrix(self, local_cuda_cluster):
+        with Client(local_cuda_cluster) as client:
+            parameters = {'tree_method': 'gpu_hist',
+                          'debug_synchronize': True}
+            run_empty_dmatrix_reg(client, parameters)
+            run_empty_dmatrix_cls(client, parameters)
 
-    def run_quantile(self, name):
+    def run_quantile(self, name, local_cuda_cluster):
         if sys.platform.startswith("win"):
             pytest.skip("Skipping dask tests on Windows")
@@ -217,34 +216,33 @@ class TestDistributedGPU(unittest.TestCase):
             env[port[0]] = port[1]
             return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
 
-        with LocalCUDACluster() as cluster:
-            with Client(cluster) as client:
-                workers = list(dxgb._get_client_workers(client).keys())
-                rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
-                futures = client.map(runit,
-                                     workers,
-                                     pure=False,
-                                     workers=workers,
-                                     rabit_args=rabit_args)
-                results = client.gather(futures)
-                for ret in results:
-                    msg = ret.stdout.decode('utf-8')
-                    assert msg.find('1 test from GPUQuantile') != -1, msg
-                    assert ret.returncode == 0, msg
+        with Client(local_cuda_cluster) as client:
+            workers = list(dxgb._get_client_workers(client).keys())
+            rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
+            futures = client.map(runit,
+                                 workers,
+                                 pure=False,
+                                 workers=workers,
+                                 rabit_args=rabit_args)
+            results = client.gather(futures)
+            for ret in results:
+                msg = ret.stdout.decode('utf-8')
+                assert msg.find('1 test from GPUQuantile') != -1, msg
+                assert ret.returncode == 0, msg
 
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
     @pytest.mark.mgpu
     @pytest.mark.gtest
-    def test_quantile_basic(self):
-        self.run_quantile('AllReduceBasic')
+    def test_quantile_basic(self, local_cuda_cluster):
+        self.run_quantile('AllReduceBasic', local_cuda_cluster)
 
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
     @pytest.mark.mgpu
     @pytest.mark.gtest
-    def test_quantile_same_on_all_workers(self):
-        self.run_quantile('SameOnAllWorkers')
+    def test_quantile_same_on_all_workers(self, local_cuda_cluster):
        self.run_quantile('SameOnAllWorkers', local_cuda_cluster)
 
 
 async def run_from_dask_array_asyncio(scheduler_address):
@@ -275,11 +273,11 @@ async def run_from_dask_array_asyncio(scheduler_address):
 
 @pytest.mark.skipif(**tm.no_dask())
+@pytest.mark.skipif(**tm.no_dask_cuda())
 @pytest.mark.mgpu
-def test_with_asyncio():
-    with LocalCUDACluster() as cluster:
-        with Client(cluster) as client:
-            address = client.scheduler.address
-            output = asyncio.run(run_from_dask_array_asyncio(address))
-            assert isinstance(output['booster'], xgboost.Booster)
-            assert isinstance(output['history'], dict)
+def test_with_asyncio(local_cuda_cluster):
+    with Client(local_cuda_cluster) as client:
+        address = client.scheduler.address
+        output = asyncio.run(run_from_dask_array_asyncio(address))
+        assert isinstance(output['booster'], xgboost.Booster)
+        assert isinstance(output['history'], dict)