From 9adb812a0a5fafe44635da646701a66d59267774 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Wed, 12 Aug 2020 01:26:02 -0700 Subject: [PATCH] RMM integration plugin (#5873) * [CI] Add RMM as an optional dependency * Replace caching allocator with pool allocator from RMM * Revert "Replace caching allocator with pool allocator from RMM" This reverts commit e15845d4e72e890c2babe31a988b26503a7d9038. * Use rmm::mr::get_default_resource() * Try setting default resource (doesn't work yet) * Allocate pool_mr in the heap * Prevent leaking pool_mr handle * Separate EXPECT_DEATH() in separate test suite suffixed DeathTest * Turn off death tests for RMM * Address reviewer's feedback * Prevent leaking of cuda_mr * Fix Jenkinsfile syntax * Remove unnecessary function in Jenkinsfile * [CI] Install NCCL into RMM container * Run Python tests * Try building with RMM, CUDA 10.0 * Do not use RMM for CUDA 10.0 target * Actually test for test_rmm flag * Fix TestPythonGPU * Use CNMeM allocator, since pool allocator doesn't yet support multiGPU * Use 10.0 container to build RMM-enabled XGBoost * Revert "Use 10.0 container to build RMM-enabled XGBoost" This reverts commit 789021fa31112e25b683aef39fff375403060141. * Fix Jenkinsfile * [CI] Assign larger /dev/shm to NCCL * Use 10.2 artifact to run multi-GPU Python tests * Add CUDA 10.0 -> 11.0 cross-version test; remove CUDA 10.0 target * Rename Conda env rmm_test -> gpu_test * Use env var to opt into CNMeM pool for C++ tests * Use identical CUDA version for RMM builds and tests * Use Pytest fixtures to enable RMM pool in Python tests * Move RMM to plugin/CMakeLists.txt; use PLUGIN_RMM * Use per-device MR; use command arg in gtest * Set CMake prefix path to use Conda env * Use 0.15 nightly version of RMM * Remove unnecessary header * Fix a unit test when cudf is missing * Add RMM demos * Remove print() * Use HostDeviceVector in GPU predictor * Simplify pytest setup; use LocalCUDACluster fixture * Address reviewers' commments Co-authored-by: Hyunsu Cho --- CMakeLists.txt | 4 + Jenkinsfile | 57 ++++++--- demo/rmm_plugin/README.md | 31 +++++ demo/rmm_plugin/rmm_mgpu_with_dask.py | 27 +++++ demo/rmm_plugin/rmm_singlegpu.py | 14 +++ plugin/CMakeLists.txt | 19 +++ python-package/xgboost/data.py | 2 +- src/common/device_helpers.cuh | 27 ++++- src/common/host_device_vector.cu | 2 + src/predictor/gpu_predictor.cu | 56 ++++----- tests/ci_build/Dockerfile.gpu | 4 +- tests/ci_build/Dockerfile.rmm | 47 ++++++++ tests/ci_build/build_via_cmake.sh | 15 ++- tests/ci_build/test_python.sh | 26 ++++- tests/cpp/CMakeLists.txt | 2 + tests/cpp/common/test_span.cc | 82 ++++++++++--- tests/cpp/common/test_span.cu | 2 +- tests/cpp/common/test_transform_range.cc | 2 +- tests/cpp/helpers.cc | 62 ++++++++++ tests/cpp/helpers.h | 5 + tests/cpp/test_main.cc | 4 + .../cpp/tree/gpu_hist/test_row_partitioner.cu | 2 +- tests/pytest.ini | 2 +- tests/python-gpu/conftest.py | 45 ++++++++ tests/python-gpu/test_gpu_demos.py | 1 - tests/python-gpu/test_gpu_with_dask.py | 108 +++++++++--------- 26 files changed, 508 insertions(+), 140 deletions(-) create mode 100644 demo/rmm_plugin/README.md create mode 100644 demo/rmm_plugin/rmm_mgpu_with_dask.py create mode 100644 demo/rmm_plugin/rmm_singlegpu.py create mode 100644 tests/ci_build/Dockerfile.rmm create mode 100644 tests/python-gpu/conftest.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e564c6c6..e26265130 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ address, leak, undefined and thread.") ## Plugins option(PLUGIN_LZ4 
"Build lz4 plugin" OFF) option(PLUGIN_DENSE_PARSER "Build dense parser plugin" OFF) +option(PLUGIN_RMM "Build with RAPIDS Memory Manager (RMM)" OFF) ## TODO: 1. Add check if DPC++ compiler is used for building option(PLUGIN_UPDATER_ONEAPI "DPC++ updater" OFF) option(ADD_PKGCONFIG "Add xgboost.pc into system." ON) @@ -84,6 +85,9 @@ endif (R_LIB AND GOOGLE_TEST) if (USE_AVX) message(SEND_ERROR "The option 'USE_AVX' is deprecated as experimental AVX features have been removed from XGBoost.") endif (USE_AVX) +if (PLUGIN_RMM AND NOT (USE_CUDA)) + message(SEND_ERROR "`PLUGIN_RMM` must be enabled with `USE_CUDA` flag.") +endif (PLUGIN_RMM AND NOT (USE_CUDA)) if (ENABLE_ALL_WARNINGS) if ((NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang") AND (NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU")) message(SEND_ERROR "ENABLE_ALL_WARNINGS is only available for Clang and GCC.") diff --git a/Jenkinsfile b/Jenkinsfile index 30b683830..274caeb3e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -73,7 +73,7 @@ pipeline { 'build-gpu-cuda10.0': { BuildCUDA(cuda_version: '10.0') }, // The build-gpu-* builds below use Ubuntu image 'build-gpu-cuda10.1': { BuildCUDA(cuda_version: '10.1') }, - 'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2') }, + 'build-gpu-cuda10.2': { BuildCUDA(cuda_version: '10.2', build_rmm: true) }, 'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0') }, 'build-jvm-packages-gpu-cuda10.0': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.0') }, 'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') }, @@ -89,11 +89,12 @@ pipeline { script { parallel ([ 'test-python-cpu': { TestPythonCPU() }, - 'test-python-gpu-cuda10.2': { TestPythonGPU(host_cuda_version: '10.2') }, + // artifact_cuda_version doesn't apply to RMM tests; RMM tests will always match CUDA version between artifact and host env + 'test-python-gpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', test_rmm: true) }, 'test-python-gpu-cuda11.0-cross': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '11.0') }, 'test-python-gpu-cuda11.0': { TestPythonGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') }, - 'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', multi_gpu: true) }, - 'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2') }, + 'test-python-mgpu-cuda10.2': { TestPythonGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.2', multi_gpu: true, test_rmm: true) }, + 'test-cpp-gpu-cuda10.2': { TestCppGPU(artifact_cuda_version: '10.2', host_cuda_version: '10.2', test_rmm: true) }, 'test-cpp-gpu-cuda11.0': { TestCppGPU(artifact_cuda_version: '11.0', host_cuda_version: '11.0') }, 'test-jvm-jdk8-cuda10.0': { CrossTestJVMwithJDKGPU(artifact_cuda_version: '10.0', host_cuda_version: '10.0') }, 'test-jvm-jdk8': { CrossTestJVMwithJDK(jdk_version: '8', spark_version: '3.0.0') }, @@ -280,6 +281,22 @@ def BuildCUDA(args) { } echo 'Stashing C++ test executable (testxgboost)...' 
stash name: "xgboost_cpp_tests_cuda${args.cuda_version}", includes: 'build/testxgboost' + if (args.build_rmm) { + echo "Build with CUDA ${args.cuda_version} and RMM" + container_type = "rmm" + docker_binary = "docker" + docker_args = "--build-arg CUDA_VERSION=${args.cuda_version}" + sh """ + rm -rf build/ + ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/build_via_cmake.sh --conda-env=gpu_test -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON ${arch_flag} + ${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "cd python-package && rm -rf dist/* && python setup.py bdist_wheel --universal" + ${dockerRun} ${container_type} ${docker_binary} ${docker_args} python tests/ci_build/rename_whl.py python-package/dist/*.whl ${commit_id} manylinux2010_x86_64 + """ + echo 'Stashing Python wheel...' + stash name: "xgboost_whl_rmm_cuda${args.cuda_version}", includes: 'python-package/dist/*.whl' + echo 'Stashing C++ test executable (testxgboost)...' + stash name: "xgboost_cpp_tests_rmm_cuda${args.cuda_version}", includes: 'build/testxgboost' + } deleteDir() } } @@ -366,18 +383,15 @@ def TestPythonGPU(args) { def container_type = "gpu" def docker_binary = "nvidia-docker" def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}" - if (args.multi_gpu) { - echo "Using multiple GPUs" - // Allocate extra space in /dev/shm to enable NCCL - def docker_extra_params = "CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'" - sh """ - ${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh mgpu - """ - } else { - echo "Using a single GPU" - sh """ - ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh gpu - """ + def mgpu_indicator = (args.multi_gpu) ? 'mgpu' : 'gpu' + // Allocate extra space in /dev/shm to enable NCCL + def docker_extra_params = (args.multi_gpu) ? 
"CI_DOCKER_EXTRA_PARAMS_INIT='--shm-size=4g'" : '' + sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator}" + if (args.test_rmm) { + sh "rm -rfv build/ python-package/dist/" + unstash name: "xgboost_whl_rmm_cuda${args.host_cuda_version}" + unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}" + sh "${docker_extra_params} ${dockerRun} ${container_type} ${docker_binary} ${docker_args} tests/ci_build/test_python.sh ${mgpu_indicator} --use-rmm-pool" } deleteDir() } @@ -408,6 +422,17 @@ def TestCppGPU(args) { def docker_binary = "nvidia-docker" def docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}" sh "${dockerRun} ${container_type} ${docker_binary} ${docker_args} build/testxgboost" + if (args.test_rmm) { + sh "rm -rfv build/" + unstash name: "xgboost_cpp_tests_rmm_cuda${args.host_cuda_version}" + echo "Test C++, CUDA ${args.host_cuda_version} with RMM" + container_type = "rmm" + docker_binary = "nvidia-docker" + docker_args = "--build-arg CUDA_VERSION=${args.host_cuda_version}" + sh """ + ${dockerRun} ${container_type} ${docker_binary} ${docker_args} bash -c "source activate gpu_test && build/testxgboost --use-rmm-pool --gtest_filter=-*DeathTest.*" + """ + } deleteDir() } } diff --git a/demo/rmm_plugin/README.md b/demo/rmm_plugin/README.md new file mode 100644 index 000000000..ad73c61f3 --- /dev/null +++ b/demo/rmm_plugin/README.md @@ -0,0 +1,31 @@ +Using XGBoost with RAPIDS Memory Manager (RMM) plugin (EXPERIMENTAL) +==================================================================== +[RAPIDS Memory Manager (RMM)](https://github.com/rapidsai/rmm) library provides a collection of +efficient memory allocators for NVIDIA GPUs. It is now possible to use XGBoost with memory +allocators provided by RMM, by enabling the RMM integration plugin. + +The demos in this directory highlights one RMM allocator in particular: **the pool sub-allocator**. +This allocator addresses the slow speed of `cudaMalloc()` by allocating a large chunk of memory +upfront. Subsequent allocations will draw from the pool of already allocated memory and thus avoid +the overhead of calling `cudaMalloc()` directly. See +[this GTC talk slides](https://on-demand.gputechconf.com/gtc/2015/presentation/S5530-Stephen-Jones.pdf) +for more details. + +Before running the demos, ensure that XGBoost is compiled with the RMM plugin enabled. To do this, +run CMake with option `-DPLUGIN_RMM=ON` (`-DUSE_CUDA=ON` also required): +``` +cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON +make -j4 +``` +CMake will attempt to locate the RMM library in your build environment. You may choose to build +RMM from the source, or install it using the Conda package manager. If CMake cannot find RMM, you +should specify the location of RMM with the CMake prefix: +``` +# If using Conda: +cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX +# If using RMM installed with a custom location +cmake .. 
-DUSE_CUDA=ON -DUSE_NCCL=ON -DPLUGIN_RMM=ON -DCMAKE_PREFIX_PATH=/path/to/rmm +``` + +* [Using RMM with a single GPU](./rmm_singlegpu.py) +* [Using RMM with a local Dask cluster consisting of multiple GPUs](./rmm_mgpu_with_dask.py) diff --git a/demo/rmm_plugin/rmm_mgpu_with_dask.py b/demo/rmm_plugin/rmm_mgpu_with_dask.py new file mode 100644 index 000000000..eac0c5da4 --- /dev/null +++ b/demo/rmm_plugin/rmm_mgpu_with_dask.py @@ -0,0 +1,27 @@ +import xgboost as xgb +from sklearn.datasets import make_classification +import dask +from dask.distributed import Client +from dask_cuda import LocalCUDACluster + +def main(client): + X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3) + X = dask.array.from_array(X) + y = dask.array.from_array(y) + dtrain = xgb.dask.DaskDMatrix(client, X, label=y) + + params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3, + 'tree_method': 'gpu_hist'} + output = xgb.dask.train(client, params, dtrain, num_boost_round=100, + evals=[(dtrain, 'train')]) + bst = output['booster'] + history = output['history'] + for i, e in enumerate(history['train']['merror']): + print(f'[{i}] train-merror: {e}') + +if __name__ == '__main__': + # To use RMM pool allocator with a GPU Dask cluster, just add rmm_pool_size option to + # LocalCUDACluster constructor. + with LocalCUDACluster(rmm_pool_size='2GB') as cluster: + with Client(cluster) as client: + main(client) diff --git a/demo/rmm_plugin/rmm_singlegpu.py b/demo/rmm_plugin/rmm_singlegpu.py new file mode 100644 index 000000000..c56e0a0ce --- /dev/null +++ b/demo/rmm_plugin/rmm_singlegpu.py @@ -0,0 +1,14 @@ +import xgboost as xgb +import rmm +from sklearn.datasets import make_classification + +# Initialize RMM pool allocator +rmm.reinitialize(pool_allocator=True) + +X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3) +dtrain = xgb.DMatrix(X, label=y) + +params = {'max_depth': 8, 'eta': 0.01, 'objective': 'multi:softprob', 'num_class': 3, + 'tree_method': 'gpu_hist'} +# XGBoost will automatically use the RMM pool allocator +bst = xgb.train(params, dtrain, num_boost_round=100, evals=[(dtrain, 'train')]) diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index d253b398a..8afce01ff 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -7,6 +7,25 @@ if (PLUGIN_DENSE_PARSER) target_sources(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/dense_parser/dense_libsvm.cc) endif (PLUGIN_DENSE_PARSER) +if (PLUGIN_RMM) + find_path(RMM_INCLUDE "rmm" + HINTS "$ENV{RMM_ROOT}/include") + + find_library(RMM_LIBRARY "rmm" + HINTS "$ENV{RMM_ROOT}/lib" "$ENV{RMM_ROOT}/build") + + if ((NOT RMM_LIBRARY) OR (NOT RMM_INCLUDE)) + message(FATAL_ERROR "Could not locate RMM library") + endif () + + message(STATUS "RMM: RMM_LIBRARY set to ${RMM_LIBRARY}") + message(STATUS "RMM: RMM_INCLUDE set to ${RMM_INCLUDE}") + + target_include_directories(objxgboost PUBLIC ${RMM_INCLUDE}) + target_link_libraries(objxgboost PUBLIC ${RMM_LIBRARY} cuda) + target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_RMM=1) +endif (PLUGIN_RMM) + if (PLUGIN_UPDATER_ONEAPI) add_library(oneapi_plugin OBJECT ${xgboost_SOURCE_DIR}/plugin/updater_oneapi/regression_obj_oneapi.cc diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index b29eac795..9491efd1c 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -317,7 +317,7 @@ def _is_cudf_df(data): import cudf except ImportError: return False - return isinstance(data, cudf.DataFrame) + return 
hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame) def _cudf_array_interfaces(data): diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index d339dc72d..beb94680f 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -36,7 +36,12 @@ #ifdef XGBOOST_USE_NCCL #include "nccl.h" -#endif +#endif // XGBOOST_USE_NCCL + +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 +#include "rmm/mr/device/per_device_resource.hpp" +#include "rmm/mr/device/thrust_allocator_adaptor.hpp" +#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__) @@ -370,12 +375,21 @@ inline void DebugSyncDevice(std::string file="", int32_t line = -1) { } namespace detail { + +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 +template +using XGBBaseDeviceAllocator = rmm::mr::thrust_allocator; +#else // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 +template +using XGBBaseDeviceAllocator = thrust::device_malloc_allocator; +#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + /** * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose. */ template -struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator { - using SuperT = thrust::device_malloc_allocator; +struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator { + using SuperT = XGBBaseDeviceAllocator; using pointer = thrust::device_ptr; // NOLINT template struct rebind // NOLINT @@ -391,10 +405,15 @@ struct XGBDefaultDeviceAllocatorImpl : thrust::device_malloc_allocator { GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); return SuperT::deallocate(ptr, n); } +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + XGBDefaultDeviceAllocatorImpl() + : SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{0}) {} +#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 }; /** - * \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs allocations if verbose. Does not initialise memory on construction. + * \brief Caching memory allocator, uses cub::CachingDeviceAllocator as a back-end and logs + * allocations if verbose. Does not initialise memory on construction. 
*/ template struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator { diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 7470f1c07..39a0fbe9e 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -11,6 +11,7 @@ #include "xgboost/data.h" #include "xgboost/host_device_vector.h" +#include "xgboost/tree_model.h" #include "device_helpers.cuh" namespace xgboost { @@ -402,6 +403,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t +template class HostDeviceVector; #if defined(__APPLE__) /* diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 9a498136d..c05688eaf 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -213,39 +213,21 @@ __global__ void PredictKernel(Data data, class DeviceModel { public: - dh::device_vector nodes; - dh::device_vector tree_segments; - dh::device_vector tree_group; + // Need to lazily construct the vectors because GPU id is only known at runtime + HostDeviceVector nodes; + HostDeviceVector tree_segments; + HostDeviceVector tree_group; size_t tree_beg_; // NOLINT size_t tree_end_; // NOLINT int num_group; - void CopyModel(const gbm::GBTreeModel& model, - const thrust::host_vector& h_tree_segments, - const thrust::host_vector& h_nodes, - size_t tree_begin, size_t tree_end) { - nodes.resize(h_nodes.size()); - dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(), - sizeof(RegTree::Node) * h_nodes.size(), - cudaMemcpyHostToDevice)); - tree_segments.resize(h_tree_segments.size()); - dh::safe_cuda(cudaMemcpyAsync(tree_segments.data().get(), h_tree_segments.data(), - sizeof(size_t) * h_tree_segments.size(), - cudaMemcpyHostToDevice)); - tree_group.resize(model.tree_info.size()); - dh::safe_cuda(cudaMemcpyAsync(tree_group.data().get(), model.tree_info.data(), - sizeof(int) * model.tree_info.size(), - cudaMemcpyHostToDevice)); - this->tree_beg_ = tree_begin; - this->tree_end_ = tree_end; - this->num_group = model.learner_model_param->num_output_group; - } - void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) { dh::safe_cuda(cudaSetDevice(gpu_id)); + CHECK_EQ(model.param.size_leaf_vector, 0); // Copy decision trees to device - thrust::host_vector h_tree_segments{}; + tree_segments = std::move(HostDeviceVector({}, gpu_id)); + auto& h_tree_segments = tree_segments.HostVector(); h_tree_segments.reserve((tree_end - tree_begin) + 1); size_t sum = 0; h_tree_segments.push_back(sum); @@ -254,13 +236,21 @@ class DeviceModel { h_tree_segments.push_back(sum); } - thrust::host_vector h_nodes(h_tree_segments.back()); + nodes = std::move(HostDeviceVector(h_tree_segments.back(), RegTree::Node(), + gpu_id)); + auto& h_nodes = nodes.HostVector(); for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { auto& src_nodes = model.trees.at(tree_idx)->GetNodes(); std::copy(src_nodes.begin(), src_nodes.end(), h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]); } - CopyModel(model, h_tree_segments, h_nodes, tree_begin, tree_end); + + tree_group = std::move(HostDeviceVector(model.tree_info.size(), 0, gpu_id)); + auto& h_tree_group = tree_group.HostVector(); + std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size()); + this->tree_beg_ = tree_begin; + this->tree_end_ = tree_end; + this->num_group = model.learner_model_param->num_output_group; } 
}; @@ -287,8 +277,8 @@ class GPUPredictor : public xgboost::Predictor { dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( PredictKernel, data, - dh::ToSpan(model_.nodes), predictions->DeviceSpan().subspan(batch_offset), - dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group), + model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), + model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), model_.tree_beg_, model_.tree_end_, num_features, num_rows, entry_start, use_shared, model_.num_group); } @@ -303,8 +293,8 @@ class GPUPredictor : public xgboost::Predictor { dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} ( PredictKernel, batch, - dh::ToSpan(model_.nodes), out_preds->DeviceSpan().subspan(batch_offset), - dh::ToSpan(model_.tree_segments), dh::ToSpan(model_.tree_group), + model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), + model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows, entry_start, use_shared, model_.num_group); } @@ -435,8 +425,8 @@ class GPUPredictor : public xgboost::Predictor { dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( PredictKernel, m->Value(), - dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(), - dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group), + d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(), + d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(), tree_begin, tree_end, m->NumColumns(), info.num_row_, entry_start, use_shared, output_groups); } diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu index 055caf3c1..efc3d9186 100644 --- a/tests/ci_build/Dockerfile.gpu +++ b/tests/ci_build/Dockerfile.gpu @@ -17,8 +17,8 @@ ENV PATH=/opt/python/bin:$PATH # Create new Conda environment with cuDF, Dask, and cuPy RUN \ - conda create -n gpu_test -c rapidsai -c nvidia -c conda-forge -c defaults \ - python=3.7 cudf=0.14 cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \ + conda create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \ + python=3.7 cudf=0.15* rmm=0.15* cudatoolkit=$CUDA_VERSION dask dask-cuda dask-cudf cupy \ numpy pytest scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis ENV GOSU_VERSION 1.10 diff --git a/tests/ci_build/Dockerfile.rmm b/tests/ci_build/Dockerfile.rmm new file mode 100644 index 000000000..a92f09c47 --- /dev/null +++ b/tests/ci_build/Dockerfile.rmm @@ -0,0 +1,47 @@ +ARG CUDA_VERSION +FROM nvidia/cuda:$CUDA_VERSION-devel-ubuntu16.04 +ARG CUDA_VERSION + +# Environment +ENV DEBIAN_FRONTEND noninteractive +SHELL ["/bin/bash", "-c"] # Use Bash as shell + +# Install all basic requirements +RUN \ + apt-get update && \ + apt-get install -y wget unzip bzip2 libgomp1 build-essential ninja-build git && \ + # Python + wget -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3.sh -b -p /opt/python && \ + # CMake + wget -nv -nc https://cmake.org/files/v3.13/cmake-3.13.0-Linux-x86_64.sh --no-check-certificate && \ + bash cmake-3.13.0-Linux-x86_64.sh --skip-license --prefix=/usr + +# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html) +RUN \ + export CUDA_SHORT=`echo $CUDA_VERSION | egrep -o '[0-9]+\.[0-9]'` && \ + export NCCL_VERSION=2.7.5-1 && \ + apt-get update && \ + apt-get install -y --allow-downgrades --allow-change-held-packages 
libnccl2=${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-dev=${NCCL_VERSION}+cuda${CUDA_SHORT} + +ENV PATH=/opt/python/bin:$PATH + +# Create new Conda environment with RMM +RUN \ + conda create -n gpu_test -c nvidia -c rapidsai-nightly -c rapidsai -c conda-forge -c defaults \ + python=3.7 rmm=0.15* cudatoolkit=$CUDA_VERSION + +ENV GOSU_VERSION 1.10 + +# Install lightweight sudo (not bound to TTY) +RUN set -ex; \ + wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \ + chmod +x /usr/local/bin/gosu && \ + gosu nobody true + +# Default entry-point to use if running locally +# It will preserve attributes of created files +COPY entrypoint.sh /scripts/ + +WORKDIR /workspace +ENTRYPOINT ["/scripts/entrypoint.sh"] diff --git a/tests/ci_build/build_via_cmake.sh b/tests/ci_build/build_via_cmake.sh index ffb780c8d..44c9b5d4a 100755 --- a/tests/ci_build/build_via_cmake.sh +++ b/tests/ci_build/build_via_cmake.sh @@ -1,10 +1,23 @@ #!/usr/bin/env bash set -e +if [[ "$1" == --conda-env=* ]] +then + conda_env=$(echo "$1" | sed 's/^--conda-env=//g' -) + echo "Activating Conda environment ${conda_env}" + shift 1 + cmake_args="$@" + source activate ${conda_env} + cmake_prefix_flag="-DCMAKE_PREFIX_PATH=$CONDA_PREFIX" +else + cmake_args="$@" + cmake_prefix_flag='' +fi + rm -rf build mkdir build cd build -cmake .. "$@" -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DENABLE_ALL_WARNINGS=ON -GNinja +cmake .. ${cmake_args} -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DENABLE_ALL_WARNINGS=ON -GNinja ${cmake_prefix_flag} ninja clean time ninja -v cd .. diff --git a/tests/ci_build/test_python.sh b/tests/ci_build/test_python.sh index 55c3037d0..28cd5e526 100755 --- a/tests/ci_build/test_python.sh +++ b/tests/ci_build/test_python.sh @@ -2,7 +2,15 @@ set -e set -x -suite=$1 +if [ "$#" -lt 1 ] +then + suite='' + args='' +else + suite=$1 + shift 1 + args="$@" +fi # Install XGBoost Python package function install_xgboost { @@ -26,34 +34,40 @@ function install_xgboost { fi } +function uninstall_xgboost { + pip uninstall -y xgboost +} + # Run specified test suite case "$suite" in gpu) source activate gpu_test install_xgboost - pytest -v -s -rxXs --fulltrace -m "not mgpu" tests/python-gpu + pytest -v -s -rxXs --fulltrace -m "not mgpu" ${args} tests/python-gpu + uninstall_xgboost ;; mgpu) source activate gpu_test install_xgboost - pytest -v -s -rxXs --fulltrace -m "mgpu" tests/python-gpu + pytest -v -s -rxXs --fulltrace -m "mgpu" ${args} tests/python-gpu cd tests/distributed ./runtests-gpu.sh - cd - + uninstall_xgboost ;; cpu) source activate cpu_test install_xgboost - pytest -v -s --fulltrace tests/python + pytest -v -s -rxXs --fulltrace ${args} tests/python cd tests/distributed ./runtests.sh + uninstall_xgboost ;; *) - echo "Usage: $0 {gpu|mgpu|cpu}" + echo "Usage: $0 {gpu|mgpu|cpu} [extra args to pass to pytest]" exit 1 ;; esac diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index e40783b6b..b733eed23 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -37,6 +37,8 @@ if (USE_CUDA) $<$:${GEN_CODE}>) target_compile_definitions(testxgboost PRIVATE -DXGBOOST_USE_CUDA=1) + find_package(CUDA) + target_include_directories(testxgboost PRIVATE ${CUDA_INCLUDE_DIRS}) set_target_properties(testxgboost PROPERTIES CUDA_SEPARABLE_COMPILATION OFF) diff --git a/tests/cpp/common/test_span.cc b/tests/cpp/common/test_span.cc index c49763f3f..53fd32e1a 100644 --- a/tests/cpp/common/test_span.cc +++ 
b/tests/cpp/common/test_span.cc @@ -97,11 +97,6 @@ TEST(Span, FromPtrLen) { } } - { - auto lazy = [=]() {Span tmp (arr, 5);}; - EXPECT_DEATH(lazy(), "\\[xgboost\\] Condition .* failed.\n"); - } - // dynamic extent { Span s (arr, 16); @@ -122,6 +117,15 @@ TEST(Span, FromPtrLen) { } } +TEST(SpanDeathTest, FromPtrLen) { + float arr[16]; + InitializeRange(arr, arr+16); + { + auto lazy = [=]() {Span tmp (arr, 5);}; + EXPECT_DEATH(lazy(), "\\[xgboost\\] Condition .* failed.\n"); + } +} + TEST(Span, FromFirstLast) { float arr[16]; InitializeRange(arr, arr+16); @@ -285,7 +289,13 @@ TEST(Span, ElementAccess) { ASSERT_EQ(i, arr[j]); ++j; } +} +TEST(SpanDeathTest, ElementAccess) { + float arr[16]; + InitializeRange(arr, arr + 16); + + Span s (arr); EXPECT_DEATH(s[16], "\\[xgboost\\] Condition .* failed.\n"); EXPECT_DEATH(s[-1], "\\[xgboost\\] Condition .* failed.\n"); @@ -312,7 +322,9 @@ TEST(Span, FrontBack) { ASSERT_EQ(s.front(), 0); ASSERT_EQ(s.back(), 3); } +} +TEST(SpanDeathTest, FrontBack) { { Span s; EXPECT_DEATH(s.front(), "\\[xgboost\\] Condition .* failed.\n"); @@ -340,10 +352,6 @@ TEST(Span, FirstLast) { for (size_t i = 0; i < first.size(); ++i) { ASSERT_EQ(first[i], arr[i]); } - auto constexpr kOne = static_cast::index_type>(-1); - EXPECT_DEATH(s.first(), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s.first<17>(), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s.first<32>(), "\\[xgboost\\] Condition .* failed.\n"); } { @@ -359,10 +367,6 @@ TEST(Span, FirstLast) { for (size_t i = 0; i < last.size(); ++i) { ASSERT_EQ(last[i], arr[i+12]); } - auto constexpr kOne = static_cast::index_type>(-1); - EXPECT_DEATH(s.last(), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s.last<17>(), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s.last<32>(), "\\[xgboost\\] Condition .* failed.\n"); } // dynamic extent @@ -379,10 +383,6 @@ TEST(Span, FirstLast) { ASSERT_EQ(first[i], s[i]); } - EXPECT_DEATH(s.first(-1), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s.first(17), "\\[xgboost\\] Condition .* failed.\n"); - EXPECT_DEATH(s.first(32), "\\[xgboost\\] Condition .* failed.\n"); - delete [] arr; } @@ -399,6 +399,50 @@ TEST(Span, FirstLast) { ASSERT_EQ(s[12 + i], last[i]); } + delete [] arr; + } +} + +TEST(SpanDeathTest, FirstLast) { + // static extent + { + float arr[16]; + InitializeRange(arr, arr + 16); + + Span s (arr); + auto constexpr kOne = static_cast::index_type>(-1); + EXPECT_DEATH(s.first(), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s.first<17>(), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s.first<32>(), "\\[xgboost\\] Condition .* failed.\n"); + } + + { + float arr[16]; + InitializeRange(arr, arr + 16); + + Span s (arr); + auto constexpr kOne = static_cast::index_type>(-1); + EXPECT_DEATH(s.last(), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s.last<17>(), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s.last<32>(), "\\[xgboost\\] Condition .* failed.\n"); + } + + // dynamic extent + { + float *arr = new float[16]; + InitializeRange(arr, arr + 16); + Span s (arr, 16); + EXPECT_DEATH(s.first(-1), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s.first(17), "\\[xgboost\\] Condition .* failed.\n"); + EXPECT_DEATH(s.first(32), "\\[xgboost\\] Condition .* failed.\n"); + + delete [] arr; + } + + { + float *arr = new float[16]; + InitializeRange(arr, arr + 16); + Span s (arr, 16); EXPECT_DEATH(s.last(-1), "\\[xgboost\\] Condition .* failed.\n"); EXPECT_DEATH(s.last(17), "\\[xgboost\\] Condition .* 
failed.\n"); EXPECT_DEATH(s.last(32), "\\[xgboost\\] Condition .* failed.\n"); @@ -420,7 +464,11 @@ TEST(Span, Subspan) { auto s4 = s1.subspan(2, dynamic_extent); ASSERT_EQ(s1.data() + 2, s4.data()); ASSERT_EQ(s4.size(), s1.size() - 2); +} +TEST(SpanDeathTest, Subspan) { + int arr[16] {0}; + Span s1 (arr); EXPECT_DEATH(s1.subspan(-1, 0), "\\[xgboost\\] Condition .* failed.\n"); EXPECT_DEATH(s1.subspan(17, 0), "\\[xgboost\\] Condition .* failed.\n"); diff --git a/tests/cpp/common/test_span.cu b/tests/cpp/common/test_span.cu index 00e00d4f4..7e9336902 100644 --- a/tests/cpp/common/test_span.cu +++ b/tests/cpp/common/test_span.cu @@ -221,7 +221,7 @@ struct TestElementAccess { } }; -TEST(GPUSpan, ElementAccess) { +TEST(GPUSpanDeathTest, ElementAccess) { dh::safe_cuda(cudaSetDevice(0)); auto test_element_access = []() { thrust::host_vector h_vec (16); diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc index 68319dfd3..84163ea66 100644 --- a/tests/cpp/common/test_transform_range.cc +++ b/tests/cpp/common/test_transform_range.cc @@ -59,7 +59,7 @@ TEST(Transform, DeclareUnifiedTest(Basic)) { } #if !defined(__CUDACC__) -TEST(Transform, Exception) { +TEST(TransformDeathTest, Exception) { size_t const kSize {16}; std::vector h_in(kSize); const HostDeviceVector in_vec{h_in, -1}; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 2274e57e7..858b65198 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -20,6 +20,15 @@ #include "../../src/gbm/gbtree_model.h" #include "xgboost/predictor.h" +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 +#include +#include +#include +#include "rmm/mr/device/per_device_resource.hpp" +#include "rmm/mr/device/cuda_memory_resource.hpp" +#include "rmm/mr/device/pool_memory_resource.hpp" +#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + bool FileExists(const std::string& filename) { struct stat st; return stat(filename.c_str(), &st) == 0; @@ -478,4 +487,57 @@ std::unique_ptr CreateTrainedGBM( return gbm; } +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + +using CUDAMemoryResource = rmm::mr::cuda_memory_resource; +using PoolMemoryResource = rmm::mr::pool_memory_resource; +class RMMAllocator { + public: + std::vector> cuda_mr; + std::vector> pool_mr; + int n_gpu; + RMMAllocator() : n_gpu(common::AllVisibleGPUs()) { + int current_device; + CHECK_EQ(cudaGetDevice(¤t_device), cudaSuccess); + for (int i = 0; i < n_gpu; ++i) { + CHECK_EQ(cudaSetDevice(i), cudaSuccess); + cuda_mr.push_back(std::make_unique()); + pool_mr.push_back(std::make_unique(cuda_mr[i].get())); + } + CHECK_EQ(cudaSetDevice(current_device), cudaSuccess); + } + ~RMMAllocator() = default; +}; + +void DeleteRMMResource(RMMAllocator* r) { + delete r; +} + +RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) { + bool use_rmm_pool = false; + for (int i = 1; i < argc; ++i) { + if (argv[i] == std::string("--use-rmm-pool")) { + use_rmm_pool = true; + } + } + if (!use_rmm_pool) { + return RMMAllocatorPtr(nullptr, DeleteRMMResource); + } + LOG(INFO) << "Using RMM memory pool"; + auto ptr = RMMAllocatorPtr(new RMMAllocator(), DeleteRMMResource); + for (int i = 0; i < ptr->n_gpu; ++i) { + rmm::mr::set_per_device_resource(rmm::cuda_device_id(i), ptr->pool_mr[i].get()); + } + return ptr; +} +#else // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 +class RMMAllocator {}; + +void DeleteRMMResource(RMMAllocator* r) {} + +RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) { + return RMMAllocatorPtr(nullptr, 
DeleteRMMResource); +} +#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1 + } // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 0783b9a89..5d4ce6cef 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -352,5 +353,9 @@ inline int Next(DataIterHandle self) { return static_cast(self)->Next(); } +class RMMAllocator; +using RMMAllocatorPtr = std::unique_ptr; +RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv); + } // namespace xgboost #endif diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc index d9f3e8f33..b93329c2e 100644 --- a/tests/cpp/test_main.cc +++ b/tests/cpp/test_main.cc @@ -3,13 +3,17 @@ #include #include #include +#include #include +#include "helpers.h" + int main(int argc, char ** argv) { xgboost::Args args {{"verbosity", "2"}}; xgboost::ConsoleLogger::Configure(args); testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; + auto rmm_alloc = xgboost::SetUpRMMResourceForCppTests(argc, argv); return RUN_ALL_TESTS(); } diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu index 3210a25a1..4879ca080 100644 --- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu +++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu @@ -119,7 +119,7 @@ void TestIncorrectRow() { }); } -TEST(RowPartitioner, IncorrectRow) { +TEST(RowPartitionerDeathTest, IncorrectRow) { ASSERT_DEATH({ TestIncorrectRow(); },".*"); } } // namespace tree diff --git a/tests/pytest.ini b/tests/pytest.ini index 136782056..5a0d27a6c 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -2,4 +2,4 @@ markers = mgpu: Mark a test that requires multiple GPUs to run. ci: Mark a test that runs only on CI. - gtest: Mark a test that requires C++ Google Test executable. \ No newline at end of file + gtest: Mark a test that requires C++ Google Test executable. 
diff --git a/tests/python-gpu/conftest.py b/tests/python-gpu/conftest.py new file mode 100644 index 000000000..1865ce529 --- /dev/null +++ b/tests/python-gpu/conftest.py @@ -0,0 +1,45 @@ +import sys +import pytest +import logging + +sys.path.append("tests/python") +import testing as tm # noqa + +def has_rmm(): + try: + import rmm + return True + except ImportError: + return False + +@pytest.fixture(scope='session', autouse=True) +def setup_rmm_pool(request, pytestconfig): + if pytestconfig.getoption('--use-rmm-pool'): + if not has_rmm(): + raise ImportError('The --use-rmm-pool option requires the RMM package') + import rmm + from dask_cuda.utils import get_n_gpus + rmm.reinitialize(pool_allocator=True, initial_pool_size=1024*1024*1024, + devices=list(range(get_n_gpus()))) + +@pytest.fixture(scope='function') +def local_cuda_cluster(request, pytestconfig): + kwargs = {} + if hasattr(request, 'param'): + kwargs.update(request.param) + if pytestconfig.getoption('--use-rmm-pool'): + if not has_rmm(): + raise ImportError('The --use-rmm-pool option requires the RMM package') + import rmm + from dask_cuda.utils import get_n_gpus + rmm.reinitialize() + kwargs['rmm_pool_size'] = '2GB' + if tm.no_dask_cuda()['condition']: + raise ImportError('The local_cuda_cluster fixture requires dask_cuda package') + from dask_cuda import LocalCUDACluster + cluster = LocalCUDACluster(**kwargs) + yield cluster + cluster.close() + +def pytest_addoption(parser): + parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool') diff --git a/tests/python-gpu/test_gpu_demos.py b/tests/python-gpu/test_gpu_demos.py index a3a9aaff5..f74d2adc2 100644 --- a/tests/python-gpu/test_gpu_demos.py +++ b/tests/python-gpu/test_gpu_demos.py @@ -6,7 +6,6 @@ sys.path.append("tests/python") import testing as tm import test_demos as td # noqa - @pytest.mark.skipif(**tm.no_cupy()) def test_data_iterator(): script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py') diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index e5c901813..a06bfc283 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -3,7 +3,6 @@ import os import pytest import numpy as np import asyncio -import unittest import xgboost import subprocess from hypothesis import given, strategies, settings, note @@ -23,7 +22,6 @@ import testing as tm # noqa try: import dask.dataframe as dd from xgboost import dask as dxgb - from dask_cuda import LocalCUDACluster from dask.distributed import Client from dask import array as da import cudf @@ -151,50 +149,51 @@ def run_gpu_hist(params, num_rounds, dataset, DMatrixT, client): assert tm.non_increasing(history['train'][dataset.metric]) -class TestDistributedGPU(unittest.TestCase): +class TestDistributedGPU: @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_dask_cudf()) @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu - def test_dask_dataframe(self): - with LocalCUDACluster() as cluster: - with Client(cluster) as client: - run_with_dask_dataframe(dxgb.DaskDMatrix, client) - run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client) + def test_dask_dataframe(self, local_cuda_cluster): + with Client(local_cuda_cluster) as client: + run_with_dask_dataframe(dxgb.DaskDMatrix, client) + run_with_dask_dataframe(dxgb.DaskDeviceQuantileDMatrix, client) - @given(parameter_strategy, strategies.integers(1, 20), - tm.dataset_strategy) + @given(params=parameter_strategy, 
num_rounds=strategies.integers(1, 20), + dataset=tm.dataset_strategy) @settings(deadline=duration(seconds=120)) + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.skipif(**tm.no_dask_cuda()) + @pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], indirect=['local_cuda_cluster']) @pytest.mark.mgpu - def test_gpu_hist(self, params, num_rounds, dataset): - with LocalCUDACluster(n_workers=2) as cluster: - with Client(cluster) as client: - run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix, - client) - run_gpu_hist(params, num_rounds, dataset, - dxgb.DaskDeviceQuantileDMatrix, client) + def test_gpu_hist(self, params, num_rounds, dataset, local_cuda_cluster): + with Client(local_cuda_cluster) as client: + run_gpu_hist(params, num_rounds, dataset, dxgb.DaskDMatrix, + client) + run_gpu_hist(params, num_rounds, dataset, + dxgb.DaskDeviceQuantileDMatrix, client) @pytest.mark.skipif(**tm.no_cupy()) + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu - def test_dask_array(self): - with LocalCUDACluster() as cluster: - with Client(cluster) as client: - run_with_dask_array(dxgb.DaskDMatrix, client) - run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client) + def test_dask_array(self, local_cuda_cluster): + with Client(local_cuda_cluster) as client: + run_with_dask_array(dxgb.DaskDMatrix, client) + run_with_dask_array(dxgb.DaskDeviceQuantileDMatrix, client) @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu - def test_empty_dmatrix(self): - with LocalCUDACluster() as cluster: - with Client(cluster) as client: - parameters = {'tree_method': 'gpu_hist', - 'debug_synchronize': True} - run_empty_dmatrix_reg(client, parameters) - run_empty_dmatrix_cls(client, parameters) + def test_empty_dmatrix(self, local_cuda_cluster): + with Client(local_cuda_cluster) as client: + parameters = {'tree_method': 'gpu_hist', + 'debug_synchronize': True} + run_empty_dmatrix_reg(client, parameters) + run_empty_dmatrix_cls(client, parameters) - def run_quantile(self, name): + def run_quantile(self, name, local_cuda_cluster): if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows") @@ -217,34 +216,33 @@ class TestDistributedGPU(unittest.TestCase): env[port[0]] = port[1] return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE) - with LocalCUDACluster() as cluster: - with Client(cluster) as client: - workers = list(dxgb._get_client_workers(client).keys()) - rabit_args = client.sync(dxgb._get_rabit_args, workers, client) - futures = client.map(runit, - workers, - pure=False, - workers=workers, - rabit_args=rabit_args) - results = client.gather(futures) - for ret in results: - msg = ret.stdout.decode('utf-8') - assert msg.find('1 test from GPUQuantile') != -1, msg - assert ret.returncode == 0, msg + with Client(local_cuda_cluster) as client: + workers = list(dxgb._get_client_workers(client).keys()) + rabit_args = client.sync(dxgb._get_rabit_args, workers, client) + futures = client.map(runit, + workers, + pure=False, + workers=workers, + rabit_args=rabit_args) + results = client.gather(futures) + for ret in results: + msg = ret.stdout.decode('utf-8') + assert msg.find('1 test from GPUQuantile') != -1, msg + assert ret.returncode == 0, msg @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu @pytest.mark.gtest - def test_quantile_basic(self): - self.run_quantile('AllReduceBasic') + def test_quantile_basic(self, local_cuda_cluster): + 
self.run_quantile('AllReduceBasic', local_cuda_cluster) @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu @pytest.mark.gtest - def test_quantile_same_on_all_workers(self): - self.run_quantile('SameOnAllWorkers') + def test_quantile_same_on_all_workers(self, local_cuda_cluster): + self.run_quantile('SameOnAllWorkers', local_cuda_cluster) async def run_from_dask_array_asyncio(scheduler_address): @@ -275,11 +273,11 @@ async def run_from_dask_array_asyncio(scheduler_address): @pytest.mark.skipif(**tm.no_dask()) +@pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu -def test_with_asyncio(): - with LocalCUDACluster() as cluster: - with Client(cluster) as client: - address = client.scheduler.address - output = asyncio.run(run_from_dask_array_asyncio(address)) - assert isinstance(output['booster'], xgboost.Booster) - assert isinstance(output['history'], dict) +def test_with_asyncio(local_cuda_cluster): + with Client(local_cuda_cluster) as client: + address = client.scheduler.address + output = asyncio.run(run_from_dask_array_asyncio(address)) + assert isinstance(output['booster'], xgboost.Booster) + assert isinstance(output['history'], dict)
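
For reference, a minimal usage sketch of the RMM pool path this patch enables, following the pattern in `demo/rmm_plugin/rmm_singlegpu.py` and the `setup_rmm_pool` fixture above; the dataset, pool size, and training parameters here are illustrative only.

```python
import rmm
import xgboost as xgb
from sklearn.datasets import make_classification

# Reserve a 1 GiB RMM memory pool up front; an XGBoost build configured with
# -DPLUGIN_RMM=ON draws its device allocations from this pool via
# rmm::mr::get_current_device_resource(), avoiding repeated cudaMalloc() calls.
rmm.reinitialize(pool_allocator=True, initial_pool_size=1024 * 1024 * 1024)

X, y = make_classification(n_samples=10000, n_informative=5, n_classes=3)
dtrain = xgb.DMatrix(X, label=y)

params = {'tree_method': 'gpu_hist', 'objective': 'multi:softprob', 'num_class': 3}
bst = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtrain, 'train')])
```

For multi-GPU training with Dask, the same effect comes from passing `rmm_pool_size` to `LocalCUDACluster`, as in `demo/rmm_plugin/rmm_mgpu_with_dask.py`.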