RMM integration plugin (#5873)
* [CI] Add RMM as an optional dependency * Replace caching allocator with pool allocator from RMM * Revert "Replace caching allocator with pool allocator from RMM" This reverts commit e15845d4e72e890c2babe31a988b26503a7d9038. * Use rmm::mr::get_default_resource() * Try setting default resource (doesn't work yet) * Allocate pool_mr in the heap * Prevent leaking pool_mr handle * Separate EXPECT_DEATH() in separate test suite suffixed DeathTest * Turn off death tests for RMM * Address reviewer's feedback * Prevent leaking of cuda_mr * Fix Jenkinsfile syntax * Remove unnecessary function in Jenkinsfile * [CI] Install NCCL into RMM container * Run Python tests * Try building with RMM, CUDA 10.0 * Do not use RMM for CUDA 10.0 target * Actually test for test_rmm flag * Fix TestPythonGPU * Use CNMeM allocator, since pool allocator doesn't yet support multiGPU * Use 10.0 container to build RMM-enabled XGBoost * Revert "Use 10.0 container to build RMM-enabled XGBoost" This reverts commit 789021fa31112e25b683aef39fff375403060141. * Fix Jenkinsfile * [CI] Assign larger /dev/shm to NCCL * Use 10.2 artifact to run multi-GPU Python tests * Add CUDA 10.0 -> 11.0 cross-version test; remove CUDA 10.0 target * Rename Conda env rmm_test -> gpu_test * Use env var to opt into CNMeM pool for C++ tests * Use identical CUDA version for RMM builds and tests * Use Pytest fixtures to enable RMM pool in Python tests * Move RMM to plugin/CMakeLists.txt; use PLUGIN_RMM * Use per-device MR; use command arg in gtest * Set CMake prefix path to use Conda env * Use 0.15 nightly version of RMM * Remove unnecessary header * Fix a unit test when cudf is missing * Add RMM demos * Remove print() * Use HostDeviceVector in GPU predictor * Simplify pytest setup; use LocalCUDACluster fixture * Address reviewers' commments Co-authored-by: Hyunsu Cho <chohyu01@cs.wasshington.edu>
This commit is contained in:
parent
c3ea3b7e37
commit
9adb812a0a
@ -60,6 +60,7 @@ address, leak, undefined and thread.")
|
||||
## Plugins
|
||||
option(PLUGIN_LZ4 "Build lz4 plugin" OFF)
|
||||
option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF)
|
||||
option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF)
|
||||
## TODO: 1. Add check if DPC++ compiler is used for building
|
||||
option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF)
|
||||
option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
|
||||
@ -84,6 +85,9 @@ endif (R_LIB AND GOOGLE_TEST)
|
||||
if (USE_AVX)
|
||||
message(SEND_ERROR "The option 'USE_AVX' is deprecated as experimental AVX features have been removed from XGBoost.")
|
||||
endif (USE_AVX)
|
||||
if (PLUGIN_RMM AND NOT (USE_CUDA))
|
||||
message(SEND_ERROR "`PLUGIN_RMM` must be enabled with `USE_CUDA` flag.")
|
||||
endif (PLUGIN_RMM AND NOT (USE_CUDA))
|
||||
if (ENABLE_ALL_WARNINGS)
|
||||
if ((NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang") AND (NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU"))
|
||||
message(SEND_ERROR "ENABLE_ALL_WARNINGS is only available for Clang and GCC.")
|
||||
|
||||
55
Jenkinsfile
vendored
55
Jenkinsfile
vendored
@ -73,7 +73,7 @@ pipeline {
|
||||
'build-gpu-cuda10.0': { BuildCUDA(cuda_version: '10.0') },
|
||||
// The build-gpu-* builds below use Ubuntu image
|
||||
'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') },
|
||||
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') },
|
||||
'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2', build_rmm: true) },
|
||||
'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') },
|
||||
'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') },
|
||||
'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
|
||||
@ -89,11 +89,12 @@ pipeline {
|
||||
script {
|
||||
parallel ([
|
||||
'test-python-cpu': { TestPythonCPU() },
|
||||
'test-python-gpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2') },
|
||||
// artifact_cuda_version doesn't apply to RMM tests; RMM tests will always match CUDA version between artifact and host env
|
||||
'test-python-gpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', test_rmm: true) },
|
||||
'test-python-gpu-cuda11.0-cross': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '11.0') },
|
||||
'test-python-gpu-cuda11.0': { TestPythonGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
|
||||
'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', multi_gpu: true) },
|
||||
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') },
|
||||
'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', multi_gpu: true, test_rmm: true) },
|
||||
'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', test_rmm: true) },
|
||||
'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') },
|
||||
'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') },
|
||||
'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') },
|
||||
@ -280,6 +281,22 @@ def BuildCUDA(args) {
|
||||
}
|
||||
echo 'Stashing C++ test executable (testxgboost)...'
|
||||
stash name: "xgboost_cpp_tests_cuda${args.cuda_version}", includes: 'build/testxgboost'
|
||||
if (args.build_rmm) {
|
||||
echo "Build with CUDA ${args.cuda_version} and RMM"
|
||||
container_type = "rmm"
|
||||
docker_binary = "docker"
|
||||
docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}"
|
||||
sh """
|
||||
rm -rf build/
|
||||
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_via_cmake.sh --conda-env=gpu_test -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON ${arch_flag}
|
||||
${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "cd python-package && rm -rf dist/* && python setup.py bdist_wheel --universal"
|
||||
${dockerRun} ${container_type} ${docker_binary} ${docker_args} python tests/ci_build/rename_whl.py python-package/dist/*.whl ${commit_id} manylinux2010_x86_64
|
||||
"""
|
||||
echo 'Stashing Python wheel...'
|
||||
stash name: "xgboost_whl_rmm_cuda${args.cuda_version}", includes: 'python-package/dist/*.whl'
|
||||
echo 'Stashing C++ test executable (testxgboost)...'
|
||||
stash name: "xgboost_cpp_tests_rmm_cuda${args.cuda_version}", includes: 'build/testxgboost'
|
||||
}
|
||||
deleteDir()
|
||||
}
|
||||
}
|
||||
@ -366,18 +383,15 @@ def TestPythonGPU(args) {
|
||||
def container_type = "gpu"
|
||||
def docker_binary = "nvidia-docker"
|
||||
def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
|
||||
if (args.multi_gpu) {
|
||||
echo "Using multiple GPUs"
|
||||
def mgpu_indicator = (args.multi_gpu) ? 'mgpu' : 'gpu'
|
||||
// Allocate extra space in /dev/shm to enable NCCL
|
||||
def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'"
|
||||
sh """
|
||||
${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu
|
||||
"""
|
||||
} else {
|
||||
echo "Using a single GPU"
|
||||
sh """
|
||||
${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh gpu
|
||||
"""
|
||||
def docker_extra_params = (args.multi_gpu) ? "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'" : ''
|
||||
sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator}"
|
||||
if (args.test_rmm) {
|
||||
sh "rm -rfv build/ python-package/dist/"
|
||||
unstash name: "xgboost_whl_rmm_cuda${args.host_cuda_version}"
|
||||
unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}"
|
||||
sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator} --use-rmm-pool"
|
||||
}
|
||||
deleteDir()
|
||||
}
|
||||
@ -408,6 +422,17 @@ def TestCppGPU(args) {
|
||||
def docker_binary = "nvidia-docker"
|
||||
def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
|
||||
sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} build/testxgboost"
|
||||
if (args.test_rmm) {
|
||||
sh "rm -rfv build/"
|
||||
unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}"
|
||||
echo "Test C++, CUDA ${args.host_cuda_version} with RMM"
|
||||
container_type = "rmm"
|
||||
docker_binary = "nvidia-docker"
|
||||
docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}"
|
||||
sh """
|
||||
${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "source activate gpu_test && build/testxgboost --use-rmm-pool --gtest_filter=-*DeathTest.*"
|
||||
"""
|
||||
}
|
||||
deleteDir()
|
||||
}
|
||||
}
|
||||
|
||||
31
demo/rmm_plugin/README.md
Normal file
31
demo/rmm_plugin/README.md
Normal file
@ -0,0 +1,31 @@
|
||||
Using XGBoost with RAPIDS Memory Manager (RMM) plugin (EXPERIMENTAL)
|
||||
====================================================================
|
||||
[RAPIDS Memory Manager (RMM)](https://github.com/rapidsai/rmm) library provides a collection of
|
||||
efficient memory allocators for NVIDIA GPUs. It is now possible to use XGBoost with memory
|
||||
allocators provided by RMM, by enabling the RMM integration plugin.
|
||||
|
||||
The demos in this directory highlights one RMM allocator in particular: **the pool sub-allocator**.
|
||||
This allocator addresses the slow speed of `cudaMalloc()` by allocating a large chunk of memory
|
||||
upfront. Subsequent allocations will draw from the pool of already allocated memory and thus avoid
|
||||
the overhead of calling `cudaMalloc()` directly. See
|
||||
[this GTC talk slides](https://on-demand.gputechconf.com/gtc/2015/presentation/S5530-Stephen-Jones.pdf)
|
||||
for more details.
|
||||
|
||||
Before running the demos, ensure that XGBoost is compiled with the RMM plugin enabled. To do this,
|
||||
run CMake with option `-DPLUGIN_RMM=ON` (`-DUSE_CUDA=ON` also required):
|
||||
```
|
||||
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON
|
||||
make -j4
|
||||
```
|
||||
CMake will attempt to locate the RMM library in your build environment. You may choose to build
|
||||
RMM from the source, or install it using the Conda package manager. If CMake cannot find RMM, you
|
||||
should specify the location of RMM with the CMake prefix:
|
||||
```
|
||||
# If using Conda:
|
||||
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
|
||||
# If using RMM installed with a custom location
|
||||
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=/path/to/rmm
|
||||
```
|
||||
|
||||
* [Using RMM with a single GPU](./rmm_singlegpu.py)
|
||||
* [Using RMM with a local Dask cluster consisting of multiple GPUs](./rmm_mgpu_with_dask.py)
|
||||
27
demo/rmm_plugin/rmm_mgpu_with_dask.py
Normal file
27
demo/rmm_plugin/rmm_mgpu_with_dask.py
Normal file
@ -0,0 +1,27 @@
|
||||
import xgboost as xgb
|
||||
from sklearn.datasets import make_classification
|
||||
import dask
|
||||
from dask.distributed import Client
|
||||
from dask_cuda import LocalCUDACluster
|
||||
|
||||
def main(client):
|
||||
X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
|
||||
X = dask.array.from_array(X)
|
||||
y = dask.array.from_array(y)
|
||||
dtrain = xgb.dask.DaskDMatrix(client, X, label=y)
|
||||
|
||||
params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
|
||||
'tree_method': 'gpu_hist'}
|
||||
output = xgb.dask.train(client, params, dtrain, num_boost_round=100,
|
||||
evals=[(dtrain, 'train')])
|
||||
bst = output['booster']
|
||||
history = output['history']
|
||||
for i, e in enumerate(history['train']['merror']):
|
||||
print(f'[{i}] train-merror: {e}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
# To use RMM pool allocator with a GPU Dask cluster, just add rmm_pool_size option to
|
||||
# LocalCUDACluster constructor.
|
||||
with LocalCUDACluster(rmm_pool_size='2GB') as cluster:
|
||||
with Client(cluster) as client:
|
||||
main(client)
|
||||
14
demo/rmm_plugin/rmm_singlegpu.py
Normal file
14
demo/rmm_plugin/rmm_singlegpu.py
Normal file
@ -0,0 +1,14 @@
|
||||
import xgboost as xgb
|
||||
import rmm
|
||||
from sklearn.datasets import make_classification
|
||||
|
||||
# Initialize RMM pool allocator
|
||||
rmm.reinitialize(pool_allocator=True)
|
||||
|
||||
X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
|
||||
dtrain = xgb.DMatrix(X, label=y)
|
||||
|
||||
params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3,
|
||||
'tree_method': 'gpu_hist'}
|
||||
# XGBoost will automatically use the RMM pool allocator
|
||||
bst = xgb.train(params, dtrain, num_boost_round=100, evals=[(dtrain, 'train')])
|
||||
@ -7,6 +7,25 @@ if (PLUGIN_DENSE_PARSER)
|
||||
target_sources(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/dense_parser/dense_libsvm.cc)
|
||||
endif (PLUGIN_DENSE_PARSER)
|
||||
|
||||
if (PLUGIN_RMM)
|
||||
find_path(RMM_INCLUDE "rmm"
|
||||
HINTS "$ENV{RMM_ROOT}/include")
|
||||
|
||||
find_library(RMM_LIBRARY "rmm"
|
||||
HINTS "$ENV{RMM_ROOT}/lib" "$ENV{RMM_ROOT}/build")
|
||||
|
||||
if ((NOT RMM_LIBRARY) OR (NOT RMM_INCLUDE))
|
||||
message(FATAL_ERROR "Could not locate RMM library")
|
||||
endif ()
|
||||
|
||||
message(STATUS "RMM: RMM_LIBRARY set to ${RMM_LIBRARY}")
|
||||
message(STATUS "RMM: RMM_INCLUDE set to ${RMM_INCLUDE}")
|
||||
|
||||
target_include_directories(objxgboost PUBLIC ${RMM_INCLUDE})
|
||||
target_link_libraries(objxgboost PUBLIC ${RMM_LIBRARY} cuda)
|
||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_RMM=1)
|
||||
endif (PLUGIN_RMM)
|
||||
|
||||
if (PLUGIN_UPDATER_ONEAPI)
|
||||
add_library(oneapi_plugin OBJECT
|
||||
${xgboost_SOURCE_DIR}/plugin/updater_oneapi/regression_obj_oneapi.cc
|
||||
|
||||
@ -317,7 +317,7 @@ def _is_cudf_df(data):
|
||||
import cudf
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, cudf.DataFrame)
|
||||
return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)
|
||||
|
||||
|
||||
def _cudf_array_interfaces(data):
|
||||
|
||||
@ -36,7 +36,12 @@
|
||||
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl.h"
|
||||
#endif
|
||||
#endif // XGBOOST_USE_NCCL
|
||||
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
#include "rmm/mr/device/per_device_resource.hpp"
|
||||
#include "rmm/mr/device/thrust_allocator_adaptor.hpp"
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
|
||||
|
||||
@ -370,12 +375,21 @@ inline void DebugSyncDevice(std::string file="", int32_t line = -1) {
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
template <typename T>
|
||||
using XGBBaseDeviceAllocator = rmm::mr::thrust_allocator<T>;
|
||||
#else // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
template <typename T>
|
||||
using XGBBaseDeviceAllocator = thrust::device_malloc_allocator<T>;
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
|
||||
/**
|
||||
* \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
|
||||
*/
|
||||
template <class T>
|
||||
struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
using SuperT = thrust::device_malloc_allocator<T>;
|
||||
struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
using SuperT = XGBBaseDeviceAllocator<T>;
|
||||
using pointer = thrust::device_ptr<T>; // NOLINT
|
||||
template<typename U>
|
||||
struct rebind // NOLINT
|
||||
@ -391,10 +405,15 @@ struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
|
||||
return SuperT::deallocate(ptr, n);
|
||||
}
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
XGBDefaultDeviceAllocatorImpl()
|
||||
: SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{0}) {}
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs allocations if verbose. Does not initialise memory on construction.
|
||||
* \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs
|
||||
* allocations if verbose. Does not initialise memory on construction.
|
||||
*/
|
||||
template <class T>
|
||||
struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@ -402,6 +403,7 @@ template class HostDeviceVector<FeatureType>;
|
||||
template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Node>;
|
||||
|
||||
#if defined(__APPLE__)
|
||||
/*
|
||||
|
||||
@ -213,39 +213,21 @@ __global__ void PredictKernel(Data data,
|
||||
|
||||
class DeviceModel {
|
||||
public:
|
||||
dh::device_vector<RegTree::Node> nodes;
|
||||
dh::device_vector<size_t> tree_segments;
|
||||
dh::device_vector<int> tree_group;
|
||||
// Need to lazily construct the vectors because GPU id is only known at runtime
|
||||
HostDeviceVector<RegTree::Node> nodes;
|
||||
HostDeviceVector<size_t> tree_segments;
|
||||
HostDeviceVector<int> tree_group;
|
||||
size_t tree_beg_; // NOLINT
|
||||
size_t tree_end_; // NOLINT
|
||||
int num_group;
|
||||
|
||||
void CopyModel(const gbm::GBTreeModel& model,
|
||||
const thrust::host_vector<size_t>& h_tree_segments,
|
||||
const thrust::host_vector<RegTree::Node>& h_nodes,
|
||||
size_t tree_begin, size_t tree_end) {
|
||||
nodes.resize(h_nodes.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
|
||||
sizeof(RegTree::Node) * h_nodes.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_segments.resize(h_tree_segments.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_segments.data().get(), h_tree_segments.data(),
|
||||
sizeof(size_t) * h_tree_segments.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
tree_group.resize(model.tree_info.size());
|
||||
dh::safe_cuda(cudaMemcpyAsync(tree_group.data().get(), model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
this->tree_beg_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group = model.learner_model_param->num_output_group;
|
||||
}
|
||||
|
||||
void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) {
|
||||
dh::safe_cuda(cudaSetDevice(gpu_id));
|
||||
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0);
|
||||
// Copy decision trees to device
|
||||
thrust::host_vector<size_t> h_tree_segments{};
|
||||
tree_segments = std::move(HostDeviceVector<size_t>({}, gpu_id));
|
||||
auto& h_tree_segments = tree_segments.HostVector();
|
||||
h_tree_segments.reserve((tree_end - tree_begin) + 1);
|
||||
size_t sum = 0;
|
||||
h_tree_segments.push_back(sum);
|
||||
@ -254,13 +236,21 @@ class DeviceModel {
|
||||
h_tree_segments.push_back(sum);
|
||||
}
|
||||
|
||||
thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
|
||||
nodes = std::move(HostDeviceVector<RegTree::Node>(h_tree_segments.back(), RegTree::Node(),
|
||||
gpu_id));
|
||||
auto& h_nodes = nodes.HostVector();
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
CopyModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
|
||||
|
||||
tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
|
||||
auto& h_tree_group = tree_group.HostVector();
|
||||
std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
|
||||
this->tree_beg_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group = model.learner_model_param->num_output_group;
|
||||
}
|
||||
};
|
||||
|
||||
@ -287,8 +277,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<SparsePageLoader, SparsePageView>,
|
||||
data,
|
||||
dh::ToSpan(model_.nodes), predictions->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
|
||||
model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset),
|
||||
model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(),
|
||||
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, model_.num_group);
|
||||
}
|
||||
@ -303,8 +293,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
|
||||
PredictKernel<EllpackLoader, EllpackDeviceAccessor>,
|
||||
batch,
|
||||
dh::ToSpan(model_.nodes), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group),
|
||||
model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(),
|
||||
model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows,
|
||||
entry_start, use_shared, model_.num_group);
|
||||
}
|
||||
@ -435,8 +425,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<Loader, typename Loader::BatchT>,
|
||||
m->Value(),
|
||||
dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
|
||||
dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
|
||||
d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(),
|
||||
d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(),
|
||||
tree_begin, tree_end, m->NumColumns(), info.num_row_,
|
||||
entry_start, use_shared, output_groups);
|
||||
}
|
||||
|
||||
@ -17,8 +17,8 @@ ENV PATH=/opt/python/bin:$PATH
|
||||
|
||||
# Create new Conda environment with cuDF, Dask, and cuPy
|
||||
RUN \
|
||||
conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \
|
||||
python=3.7 cudf=0.14 cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
|
||||
conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
|
||||
python=3.7 cudf=0.15* rmm=0.15* cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \
|
||||
numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
|
||||
47
tests/ci_build/Dockerfile.rmm
Normal file
47
tests/ci_build/Dockerfile.rmm
Normal file
@ -0,0 +1,47 @@
|
||||
ARG CUDA_VERSION
|
||||
FROM nvidia/cuda:$CUDA_VERSION-devel-ubuntu16.04
|
||||
ARG CUDA_VERSION
|
||||
|
||||
# Environment
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
SHELL ["/bin/bash", "-c"] # Use Bash as shell
|
||||
|
||||
# Install all basic requirements
|
||||
RUN \
|
||||
apt-get update && \
|
||||
apt-get install -y wget unzip bzip2 libgomp1 build-essential ninja-build git && \
|
||||
# Python
|
||||
wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||
bash Miniconda3.sh -b -p /opt/python && \
|
||||
# CMake
|
||||
wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \
|
||||
bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr
|
||||
|
||||
# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
|
||||
RUN \
|
||||
export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \
|
||||
export NCCL_VERSION=2.7.5-1 && \
|
||||
apt-get update && \
|
||||
apt-get install -y --allow-downgrades --allow-change-held-packages libnccl2=${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-dev=${NCCL_VERSION}+cuda${CUDA_SHORT}
|
||||
|
||||
ENV PATH=/opt/python/bin:$PATH
|
||||
|
||||
# Create new Conda environment with RMM
|
||||
RUN \
|
||||
conda create -n gpu_test -c nvidia -c rapidsai-nightly -c rapidsai -c conda-forge -c defaults \
|
||||
python=3.7 rmm=0.15* cudatoolkit=$CUDA_VERSION
|
||||
|
||||
ENV GOSU_VERSION 1.10
|
||||
|
||||
# Install lightweight sudo (not bound to TTY)
|
||||
RUN set -ex; \
|
||||
wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \
|
||||
chmod +x /usr/local/bin/gosu && \
|
||||
gosu nobody true
|
||||
|
||||
# Default entry-point to use if running locally
|
||||
# It will preserve attributes of created files
|
||||
COPY entrypoint.sh /scripts/
|
||||
|
||||
WORKDIR /workspace
|
||||
ENTRYPOINT ["/scripts/entrypoint.sh"]
|
||||
@ -1,10 +1,23 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
if [[ "$1" == --conda-env=* ]]
|
||||
then
|
||||
conda_env=$(echo "$1" | sed 's/^--conda-env=//g' -)
|
||||
echo "Activating Conda environment ${conda_env}"
|
||||
shift 1
|
||||
cmake_args="$@"
|
||||
source activate ${conda_env}
|
||||
cmake_prefix_flag="-DCMAKE_PREFIX_PATH=$CONDA_PREFIX"
|
||||
else
|
||||
cmake_args="$@"
|
||||
cmake_prefix_flag=''
|
||||
fi
|
||||
|
||||
rm -rf build
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. "$@" -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DENABLE_ALL_WARNINGS=ON -GNinja
|
||||
cmake .. ${cmake_args} -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DENABLE_ALL_WARNINGS=ON -GNinja ${cmake_prefix_flag}
|
||||
ninja clean
|
||||
time ninja -v
|
||||
cd ..
|
||||
|
||||
@ -2,7 +2,15 @@
|
||||
set -e
|
||||
set -x
|
||||
|
||||
if [ "$#" -lt 1 ]
|
||||
then
|
||||
suite=''
|
||||
args=''
|
||||
else
|
||||
suite=$1
|
||||
shift 1
|
||||
args="$@"
|
||||
fi
|
||||
|
||||
# Install XGBoost Python package
|
||||
function install_xgboost {
|
||||
@ -26,34 +34,40 @@ function install_xgboost {
|
||||
fi
|
||||
}
|
||||
|
||||
function uninstall_xgboost {
|
||||
pip uninstall -y xgboost
|
||||
}
|
||||
|
||||
# Run specified test suite
|
||||
case "$suite" in
|
||||
gpu)
|
||||
source activate gpu_test
|
||||
install_xgboost
|
||||
pytest -v -s -rxXs --fulltrace -m "not mgpu" tests/python-gpu
|
||||
pytest -v -s -rxXs --fulltrace -m "not mgpu" ${args} tests/python-gpu
|
||||
uninstall_xgboost
|
||||
;;
|
||||
|
||||
mgpu)
|
||||
source activate gpu_test
|
||||
install_xgboost
|
||||
pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu
|
||||
pytest -v -s -rxXs --fulltrace -m "mgpu" ${args} tests/python-gpu
|
||||
|
||||
cd tests/distributed
|
||||
./runtests-gpu.sh
|
||||
cd -
|
||||
uninstall_xgboost
|
||||
;;
|
||||
|
||||
cpu)
|
||||
source activate cpu_test
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace tests/python
|
||||
pytest -v -s -rxXs --fulltrace ${args} tests/python
|
||||
cd tests/distributed
|
||||
./runtests.sh
|
||||
uninstall_xgboost
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Usage: $0 {gpu|mgpu|cpu}"
|
||||
echo "Usage: $0 {gpu|mgpu|cpu} [extra args to pass to pytest]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
@ -37,6 +37,8 @@ if (USE_CUDA)
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:${GEN_CODE}>)
|
||||
target_compile_definitions(testxgboost
|
||||
PRIVATE -DXGBOOST_USE_CUDA=1)
|
||||
find_package(CUDA)
|
||||
target_include_directories(testxgboost PRIVATE ${CUDA_INCLUDE_DIRS})
|
||||
set_target_properties(testxgboost PROPERTIES
|
||||
CUDA_SEPARABLE_COMPILATION OFF)
|
||||
|
||||
|
||||
@ -97,11 +97,6 @@ TEST(Span, FromPtrLen) {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
auto lazy = [=]() {Span<float const, 16> tmp (arr, 5);};
|
||||
EXPECT_DEATH(lazy(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
}
|
||||
|
||||
// dynamic extent
|
||||
{
|
||||
Span<float, 16> s (arr, 16);
|
||||
@ -122,6 +117,15 @@ TEST(Span, FromPtrLen) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SpanDeathTest, FromPtrLen) {
|
||||
float arr[16];
|
||||
InitializeRange(arr, arr+16);
|
||||
{
|
||||
auto lazy = [=]() {Span<float const, 16> tmp (arr, 5);};
|
||||
EXPECT_DEATH(lazy(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Span, FromFirstLast) {
|
||||
float arr[16];
|
||||
InitializeRange(arr, arr+16);
|
||||
@ -285,7 +289,13 @@ TEST(Span, ElementAccess) {
|
||||
ASSERT_EQ(i, arr[j]);
|
||||
++j;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SpanDeathTest, ElementAccess) {
|
||||
float arr[16];
|
||||
InitializeRange(arr, arr + 16);
|
||||
|
||||
Span<float> s (arr);
|
||||
EXPECT_DEATH(s[16], "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s[-1], "\\[xgboost\\] Condition .* failed.\n");
|
||||
|
||||
@ -312,7 +322,9 @@ TEST(Span, FrontBack) {
|
||||
ASSERT_EQ(s.front(), 0);
|
||||
ASSERT_EQ(s.back(), 3);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SpanDeathTest, FrontBack) {
|
||||
{
|
||||
Span<float, 0> s;
|
||||
EXPECT_DEATH(s.front(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
@ -340,10 +352,6 @@ TEST(Span, FirstLast) {
|
||||
for (size_t i = 0; i < first.size(); ++i) {
|
||||
ASSERT_EQ(first[i], arr[i]);
|
||||
}
|
||||
auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
|
||||
EXPECT_DEATH(s.first<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first<17>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first<32>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
}
|
||||
|
||||
{
|
||||
@ -359,10 +367,6 @@ TEST(Span, FirstLast) {
|
||||
for (size_t i = 0; i < last.size(); ++i) {
|
||||
ASSERT_EQ(last[i], arr[i+12]);
|
||||
}
|
||||
auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
|
||||
EXPECT_DEATH(s.last<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.last<17>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.last<32>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
}
|
||||
|
||||
// dynamic extent
|
||||
@ -379,10 +383,6 @@ TEST(Span, FirstLast) {
|
||||
ASSERT_EQ(first[i], s[i]);
|
||||
}
|
||||
|
||||
EXPECT_DEATH(s.first(-1), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first(17), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first(32), "\\[xgboost\\] Condition .* failed.\n");
|
||||
|
||||
delete [] arr;
|
||||
}
|
||||
|
||||
@ -399,6 +399,50 @@ TEST(Span, FirstLast) {
|
||||
ASSERT_EQ(s[12 + i], last[i]);
|
||||
}
|
||||
|
||||
delete [] arr;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SpanDeathTest, FirstLast) {
|
||||
// static extent
|
||||
{
|
||||
float arr[16];
|
||||
InitializeRange(arr, arr + 16);
|
||||
|
||||
Span<float> s (arr);
|
||||
auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
|
||||
EXPECT_DEATH(s.first<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first<17>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first<32>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
}
|
||||
|
||||
{
|
||||
float arr[16];
|
||||
InitializeRange(arr, arr + 16);
|
||||
|
||||
Span<float> s (arr);
|
||||
auto constexpr kOne = static_cast<Span<float, 4>::index_type>(-1);
|
||||
EXPECT_DEATH(s.last<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.last<17>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.last<32>(), "\\[xgboost\\] Condition .* failed.\n");
|
||||
}
|
||||
|
||||
// dynamic extent
|
||||
{
|
||||
float *arr = new float[16];
|
||||
InitializeRange(arr, arr + 16);
|
||||
Span<float> s (arr, 16);
|
||||
EXPECT_DEATH(s.first(-1), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first(17), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.first(32), "\\[xgboost\\] Condition .* failed.\n");
|
||||
|
||||
delete [] arr;
|
||||
}
|
||||
|
||||
{
|
||||
float *arr = new float[16];
|
||||
InitializeRange(arr, arr + 16);
|
||||
Span<float> s (arr, 16);
|
||||
EXPECT_DEATH(s.last(-1), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.last(17), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s.last(32), "\\[xgboost\\] Condition .* failed.\n");
|
||||
@ -420,7 +464,11 @@ TEST(Span, Subspan) {
|
||||
auto s4 = s1.subspan(2, dynamic_extent);
|
||||
ASSERT_EQ(s1.data() + 2, s4.data());
|
||||
ASSERT_EQ(s4.size(), s1.size() - 2);
|
||||
}
|
||||
|
||||
TEST(SpanDeathTest, Subspan) {
|
||||
int arr[16] {0};
|
||||
Span<int> s1 (arr);
|
||||
EXPECT_DEATH(s1.subspan(-1, 0), "\\[xgboost\\] Condition .* failed.\n");
|
||||
EXPECT_DEATH(s1.subspan(17, 0), "\\[xgboost\\] Condition .* failed.\n");
|
||||
|
||||
|
||||
@ -221,7 +221,7 @@ struct TestElementAccess {
|
||||
}
|
||||
};
|
||||
|
||||
TEST(GPUSpan, ElementAccess) {
|
||||
TEST(GPUSpanDeathTest, ElementAccess) {
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
auto test_element_access = []() {
|
||||
thrust::host_vector<float> h_vec (16);
|
||||
|
||||
@ -59,7 +59,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) {
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC__)
|
||||
TEST(Transform, Exception) {
|
||||
TEST(TransformDeathTest, Exception) {
|
||||
size_t const kSize {16};
|
||||
std::vector<bst_float> h_in(kSize);
|
||||
const HostDeviceVector<bst_float> in_vec{h_in, -1};
|
||||
|
||||
@ -20,6 +20,15 @@
|
||||
#include "../../src/gbm/gbtree_model.h"
|
||||
#include "xgboost/predictor.h"
|
||||
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
#include "rmm/mr/device/per_device_resource.hpp"
|
||||
#include "rmm/mr/device/cuda_memory_resource.hpp"
|
||||
#include "rmm/mr/device/pool_memory_resource.hpp"
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
|
||||
bool FileExists(const std::string& filename) {
|
||||
struct stat st;
|
||||
return stat(filename.c_str(), &st) == 0;
|
||||
@ -478,4 +487,57 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
|
||||
return gbm;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
|
||||
using CUDAMemoryResource = rmm::mr::cuda_memory_resource;
|
||||
using PoolMemoryResource = rmm::mr::pool_memory_resource<CUDAMemoryResource>;
|
||||
class RMMAllocator {
|
||||
public:
|
||||
std::vector<std::unique_ptr<CUDAMemoryResource>> cuda_mr;
|
||||
std::vector<std::unique_ptr<PoolMemoryResource>> pool_mr;
|
||||
int n_gpu;
|
||||
RMMAllocator() : n_gpu(common::AllVisibleGPUs()) {
|
||||
int current_device;
|
||||
CHECK_EQ(cudaGetDevice(¤t_device), cudaSuccess);
|
||||
for (int i = 0; i < n_gpu; ++i) {
|
||||
CHECK_EQ(cudaSetDevice(i), cudaSuccess);
|
||||
cuda_mr.push_back(std::make_unique<CUDAMemoryResource>());
|
||||
pool_mr.push_back(std::make_unique<PoolMemoryResource>(cuda_mr[i].get()));
|
||||
}
|
||||
CHECK_EQ(cudaSetDevice(current_device), cudaSuccess);
|
||||
}
|
||||
~RMMAllocator() = default;
|
||||
};
|
||||
|
||||
void DeleteRMMResource(RMMAllocator* r) {
|
||||
delete r;
|
||||
}
|
||||
|
||||
RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
|
||||
bool use_rmm_pool = false;
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
if (argv[i] == std::string("--use-rmm-pool")) {
|
||||
use_rmm_pool = true;
|
||||
}
|
||||
}
|
||||
if (!use_rmm_pool) {
|
||||
return RMMAllocatorPtr(nullptr, DeleteRMMResource);
|
||||
}
|
||||
LOG(INFO) << "Using RMM memory pool";
|
||||
auto ptr = RMMAllocatorPtr(new RMMAllocator(), DeleteRMMResource);
|
||||
for (int i = 0; i < ptr->n_gpu; ++i) {
|
||||
rmm::mr::set_per_device_resource(rmm::cuda_device_id(i), ptr->pool_mr[i].get());
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
#else // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
class RMMAllocator {};
|
||||
|
||||
void DeleteRMMResource(RMMAllocator* r) {}
|
||||
|
||||
RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
|
||||
return RMMAllocatorPtr(nullptr, DeleteRMMResource);
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
@ -352,5 +353,9 @@ inline int Next(DataIterHandle self) {
|
||||
return static_cast<CudaArrayIterForTest*>(self)->Next();
|
||||
}
|
||||
|
||||
class RMMAllocator;
|
||||
using RMMAllocatorPtr = std::unique_ptr<RMMAllocator, void(*)(RMMAllocator*)>;
|
||||
RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv);
|
||||
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
|
||||
@ -3,13 +3,17 @@
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "helpers.h"
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
xgboost::Args args {{"verbosity", "2"}};
|
||||
xgboost::ConsoleLogger::Configure(args);
|
||||
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
auto rmm_alloc = xgboost::SetUpRMMResourceForCppTests(argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
||||
@ -119,7 +119,7 @@ void TestIncorrectRow() {
|
||||
});
|
||||
}
|
||||
|
||||
TEST(RowPartitioner, IncorrectRow) {
|
||||
TEST(RowPartitionerDeathTest, IncorrectRow) {
|
||||
ASSERT_DEATH({ TestIncorrectRow(); },".*");
|
||||
}
|
||||
} // namespace tree
|
||||
|
||||
45
tests/python-gpu/conftest.py
Normal file
45
tests/python-gpu/conftest.py
Normal file
@ -0,0 +1,45 @@
|
||||
import sys
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm # noqa
|
||||
|
||||
def has_rmm():
|
||||
try:
|
||||
import rmm
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
def setup_rmm_pool(request, pytestconfig):
|
||||
if pytestconfig.getoption('--use-rmm-pool'):
|
||||
if not has_rmm():
|
||||
raise ImportError('The --use-rmm-pool option requires the RMM package')
|
||||
import rmm
|
||||
from dask_cuda.utils import get_n_gpus
|
||||
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024*1024*1024,
|
||||
devices=list(range(get_n_gpus())))
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def local_cuda_cluster(request, pytestconfig):
|
||||
kwargs = {}
|
||||
if hasattr(request, 'param'):
|
||||
kwargs.update(request.param)
|
||||
if pytestconfig.getoption('--use-rmm-pool'):
|
||||
if not has_rmm():
|
||||
raise ImportError('The --use-rmm-pool option requires the RMM package')
|
||||
import rmm
|
||||
from dask_cuda.utils import get_n_gpus
|
||||
rmm.reinitialize()
|
||||
kwargs['rmm_pool_size'] = '2GB'
|
||||
if tm.no_dask_cuda()['condition']:
|
||||
raise ImportError('The local_cuda_cluster fixture requires dask_cuda package')
|
||||
from dask_cuda import LocalCUDACluster
|
||||
cluster = LocalCUDACluster(**kwargs)
|
||||
yield cluster
|
||||
cluster.close()
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')
|
||||
@ -6,7 +6,6 @@ sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
import test_demos as td # noqa
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_data_iterator():
|
||||
script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py')
|
||||
|
||||
@ -3,7 +3,6 @@ import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import unittest
|
||||
import xgboost
|
||||
import subprocess
|
||||
from hypothesis import given, strategies, settings, note
|
||||
@ -23,7 +22,6 @@ import testing as tm # noqa
|
||||
try:
|
||||
import dask.dataframe as dd
|
||||
from xgboost import dask as dxgb
|
||||
from dask_cuda import LocalCUDACluster
|
||||
from dask.distributed import Client
|
||||
from dask import array as da
|
||||
import cudf
|
||||
@ -151,50 +149,51 @@ def run_gpu_hist(params, num_rounds, dataset, DMatrixT, client):
|
||||
assert tm.non_increasing(history['train'][dataset.metric])
|
||||
|
||||
|
||||
class TestDistributedGPU(unittest.TestCase):
|
||||
class TestDistributedGPU:
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
@pytest.mark.skipif(**tm.no_dask_cudf())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
def test_dask_dataframe(self):
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_dask_dataframe(self, local_cuda_cluster):
|
||||
with Client(local_cuda_cluster) as client:
|
||||
run_with_dask_dataframe(dxgb.DaskDMatrix, client)
|
||||
run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client)
|
||||
|
||||
@given(parameter_strategy, strategies.integers(1, 20),
|
||||
tm.dataset_strategy)
|
||||
@given(params=parameter_strategy, num_rounds=strategies.integers(1, 20),
|
||||
dataset=tm.dataset_strategy)
|
||||
@settings(deadline=duration(seconds=120))
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], indirect=['local_cuda_cluster'])
|
||||
@pytest.mark.mgpu
|
||||
def test_gpu_hist(self, params, num_rounds, dataset):
|
||||
with LocalCUDACluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_gpu_hist(self, params, num_rounds, dataset, local_cuda_cluster):
|
||||
with Client(local_cuda_cluster) as client:
|
||||
run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix,
|
||||
client)
|
||||
run_gpu_hist(params, num_rounds, dataset,
|
||||
dxgb.DaskDeviceQuantileDMatrix, client)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
def test_dask_array(self):
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_dask_array(self, local_cuda_cluster):
|
||||
with Client(local_cuda_cluster) as client:
|
||||
run_with_dask_array(dxgb.DaskDMatrix, client)
|
||||
run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
def test_empty_dmatrix(self):
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_empty_dmatrix(self, local_cuda_cluster):
|
||||
with Client(local_cuda_cluster) as client:
|
||||
parameters = {'tree_method': 'gpu_hist',
|
||||
'debug_synchronize': True}
|
||||
run_empty_dmatrix_reg(client, parameters)
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
def run_quantile(self, name):
|
||||
def run_quantile(self, name, local_cuda_cluster):
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows")
|
||||
|
||||
@ -217,8 +216,7 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
env[port[0]] = port[1]
|
||||
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
|
||||
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
with Client(local_cuda_cluster) as client:
|
||||
workers = list(dxgb._get_client_workers(client).keys())
|
||||
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
|
||||
futures = client.map(runit,
|
||||
@ -236,15 +234,15 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_basic(self):
|
||||
self.run_quantile('AllReduceBasic')
|
||||
def test_quantile_basic(self, local_cuda_cluster):
|
||||
self.run_quantile('AllReduceBasic', local_cuda_cluster)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_same_on_all_workers(self):
|
||||
self.run_quantile('SameOnAllWorkers')
|
||||
def test_quantile_same_on_all_workers(self, local_cuda_cluster):
|
||||
self.run_quantile('SameOnAllWorkers', local_cuda_cluster)
|
||||
|
||||
|
||||
async def run_from_dask_array_asyncio(scheduler_address):
|
||||
@ -275,10 +273,10 @@ async def run_from_dask_array_asyncio(scheduler_address):
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
@pytest.mark.mgpu
|
||||
def test_with_asyncio():
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_with_asyncio(local_cuda_cluster):
|
||||
with Client(local_cuda_cluster) as client:
|
||||
address = client.scheduler.address
|
||||
output = asyncio.run(run_from_dask_array_asyncio(address))
|
||||
assert isinstance(output['booster'], xgboost.Booster)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user