Fix integer overflow. (#9380)

Jiaming Yuan 2023-07-15 21:11:02 +08:00 committed by GitHub
parent 16eb41936d
commit 0a07900b9f
2 changed files with 13 additions and 13 deletions

@@ -26,9 +26,8 @@
 #include "quantile.h"
 #include "xgboost/host_device_vector.h"
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 
 constexpr float SketchContainer::kFactor;
 
 namespace detail {
@@ -87,13 +86,13 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
   return peak;
 }
 
-size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
-                              bst_row_t num_rows, bst_feature_t columns,
-                              size_t nnz, int device,
-                              size_t num_cuts, bool has_weight) {
+size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows,
+                              bst_feature_t columns, size_t nnz, int device, size_t num_cuts,
+                              bool has_weight) {
+  auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
   // device available memory is not accurate when rmm is used.
-  return nnz;
+  return std::min(nnz, kIntMax);
 #endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 
   if (sketch_batch_num_elements == 0) {
@@ -106,7 +105,8 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
       sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
     }
   }
-  return sketch_batch_num_elements;
+
+  return std::min(sketch_batch_num_elements, kIntMax);
 }
 
 void SortByWeight(dh::device_vector<float>* weights,
@@ -355,5 +355,4 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
   sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
   return cuts;
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
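
The change above computes the sketching batch size in `size_t` but caps both return paths at `INT32_MAX`, presumably because the result eventually reaches device primitives that index with 32-bit integers, where anything larger wraps. Below is a minimal standalone sketch of that clamping pattern, assuming a dense upper bound of `num_rows * num_columns`; `ClampedBatchSize` is a hypothetical stand-in, not the actual XGBoost function.

```cpp
#include <algorithm>  // std::min
#include <cstddef>    // std::size_t
#include <cstdint>    // std::int32_t
#include <iostream>
#include <limits>     // std::numeric_limits

// Hypothetical stand-in for the clamped computation in SketchBatchNumElements:
// the element count is computed in size_t, then capped at INT32_MAX so it can
// be handed safely to code that indexes with 32-bit integers.
std::size_t ClampedBatchSize(std::size_t num_rows, std::size_t num_columns, std::size_t nnz) {
  auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
  auto batch = std::min(num_rows * num_columns, nnz);  // dense upper bound vs. actual nnz
  return std::min(batch, kIntMax);
}

int main() {
  std::size_t n = 1000;
  std::size_t m = ((std::size_t{1} << 31) + n - 1) / n;  // same shape as the test below
  std::cout << ClampedBatchSize(m, n, m * n) << "\n";    // prints 2147483647, not 2147484000
}
```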

@@ -9,15 +9,16 @@ import xgboost as xgb
 def test_large_input():
     available_bytes, _ = cp.cuda.runtime.memGetInfo()
     # 15 GB
-    required_bytes = 1.5e+10
+    required_bytes = 1.5e10
     if available_bytes < required_bytes:
         pytest.skip("Not enough memory on this device")
     n = 1000
     m = ((1 << 31) + n - 1) // n
-    assert (np.log2(m * n) > 31)
+    assert np.log2(m * n) > 31
     X = cp.ones((m, n), dtype=np.float32)
     y = cp.ones(m)
-    dmat = xgb.QuantileDMatrix(X, y)
+    w = cp.ones(m)
+    dmat = xgb.QuantileDMatrix(X, y, weight=w)
     booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
     del y
     booster.inplace_predict(X)
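
The regression test builds a matrix with just over 2^31 elements (the `assert np.log2(m * n) > 31` guard), so any code path that stores the element count in a signed 32-bit index wraps negative; the added weight vector presumably routes the input through the weighted sketching path as well. A small sketch (not part of the commit) of the wraparound being guarded against:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Same shape as the test: 1000 columns, just over 2^31 total elements.
  std::int64_t n = 1000;
  std::int64_t m = ((std::int64_t{1} << 31) + n - 1) / n;
  std::int64_t total = m * n;                       // 2147484000 > INT32_MAX
  auto wrapped = static_cast<std::int32_t>(total);  // modular wrap (well-defined since C++20)
  std::cout << total << " -> " << wrapped << "\n";  // 2147484000 -> -2147483296
}
```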