diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 4f8f90f65..469ce4499 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -93,6 +93,11 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows, bst_feature_t columns, size_t nnz, int device, size_t num_cuts, bool has_weight) { +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + // device available memory is not accurate when rmm is used. + return nnz; +#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + if (sketch_batch_num_elements == 0) { auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight); // use up to 80% of available space diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 65c51c5b3..a4c18299d 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -45,6 +45,10 @@ TEST(HistUtil, DeviceSketch) { } TEST(HistUtil, SketchBatchNumElements) { +#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 + LOG(WARNING) << "Test not runnable with RMM enabled."; + return; +#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 size_t constexpr kCols = 10000; int device; dh::safe_cuda(cudaGetDevice(&device));