diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 1c9525a62..eabdb86de 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -26,9 +26,8 @@ #include "quantile.h" #include "xgboost/host_device_vector.h" -namespace xgboost { -namespace common { +namespace xgboost::common { constexpr float SketchContainer::kFactor; namespace detail { @@ -87,13 +86,13 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, return peak; } -size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_row_t num_rows, bst_feature_t columns, - size_t nnz, int device, - size_t num_cuts, bool has_weight) { +size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows, + bst_feature_t columns, size_t nnz, int device, size_t num_cuts, + bool has_weight) { + auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max()); #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 // device available memory is not accurate when rmm is used. 
- return nnz; + return std::min(nnz, kIntMax); #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 if (sketch_batch_num_elements == 0) { @@ -106,7 +105,8 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz); } } - return sketch_batch_num_elements; + + return std::min(sketch_batch_num_elements, kIntMax); } void SortByWeight(dh::device_vector<float>* weights, @@ -355,5 +355,4 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit()); return cuts; } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/python-gpu/test_large_input.py b/tests/python-gpu/test_large_input.py index fdd3493e5..2d85cabc8 100644 --- a/tests/python-gpu/test_large_input.py +++ b/tests/python-gpu/test_large_input.py @@ -9,15 +9,16 @@ import xgboost as xgb def test_large_input(): available_bytes, _ = cp.cuda.runtime.memGetInfo() # 15 GB - required_bytes = 1.5e+10 + required_bytes = 1.5e10 if available_bytes < required_bytes: pytest.skip("Not enough memory on this device") n = 1000 m = ((1 << 31) + n - 1) // n - assert (np.log2(m * n) > 31) + assert np.log2(m * n) > 31 X = cp.ones((m, n), dtype=np.float32) y = cp.ones(m) - dmat = xgb.QuantileDMatrix(X, y) + w = cp.ones(m) + dmat = xgb.QuantileDMatrix(X, y, weight=w) booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1) del y booster.inplace_predict(X)