Fix integer overflow. (#9380)
parent 16eb41936d
commit 0a07900b9f
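In short, the diff caps the batch element counts computed by SketchBatchNumElements in the GPU sketching code at std::numeric_limits<std::int32_t>::max() (both the early return used under RMM and the final result), folds the nested namespace blocks into the C++17 form, and extends the GPU large-input regression test to pass an explicit weight vector. An illustrative sketch of the clamp follows the SketchBatchNumElements hunks below.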
@@ -26,9 +26,8 @@
 #include "quantile.h"
 #include "xgboost/host_device_vector.h"
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 
 constexpr float SketchContainer::kFactor;
 
 namespace detail {
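The namespace change above is purely syntactic: namespace xgboost::common { ... } is the C++17 nested-namespace form of the two enclosing blocks it replaces, so no declarations move.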
@@ -87,13 +86,13 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
   return peak;
 }
 
-size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
-                              bst_row_t num_rows, bst_feature_t columns,
-                              size_t nnz, int device,
-                              size_t num_cuts, bool has_weight) {
+size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows,
+                              bst_feature_t columns, size_t nnz, int device, size_t num_cuts,
+                              bool has_weight) {
+  auto constexpr kIntMax = static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
   // device available memory is not accurate when rmm is used.
-  return nnz;
+  return std::min(nnz, kIntMax);
 #endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 
   if (sketch_batch_num_elements == 0) {
@@ -106,7 +105,8 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
       sketch_batch_num_elements = std::min(num_rows * static_cast<size_t>(columns), nnz);
     }
   }
-  return sketch_batch_num_elements;
+
+  return std::min(sketch_batch_num_elements, kIntMax);
 }
 
 void SortByWeight(dh::device_vector<float>* weights,
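To make the intent of the new clamp concrete, here is a minimal, self-contained sketch. It is illustrative only: the variable names and the main() harness are not from the XGBoost sources, and the diff does not show which downstream consumer works in 32-bit arithmetic; the sketch just demonstrates that an element count above INT32_MAX cannot survive narrowing to a 32-bit index, which is what the std::min against kIntMax prevents.

// Illustrative sketch only; names below are made up for this example.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Same constant the diff introduces: INT32_MAX widened to size_t.
  constexpr auto kIntMax =
      static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max());

  // Just over 2^31 elements, roughly the shape the regression test below builds.
  std::size_t num_elements = (std::size_t{1} << 31) + 1000;

  // The unclamped count does not fit in a 32-bit signed index ...
  std::cout << std::boolalpha
            << "fits in int32_t: " << (num_elements <= kIntMax) << "\n";  // false

  // ... so the batch size is capped first, mirroring
  // "return std::min(sketch_batch_num_elements, kIntMax);" above.
  std::size_t batch = std::min(num_elements, kIntMax);
  std::cout << "clamped batch size: " << batch << "\n";  // 2147483647
  return 0;
}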
@@ -355,5 +355,4 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
   sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
   return cuts;
 }
-}  // namespace common
-}  // namespace xgboost
+}  // namespace xgboost::common
@@ -9,15 +9,16 @@ import xgboost as xgb
 def test_large_input():
     available_bytes, _ = cp.cuda.runtime.memGetInfo()
     # 15 GB
-    required_bytes = 1.5e+10
+    required_bytes = 1.5e10
     if available_bytes < required_bytes:
         pytest.skip("Not enough memory on this device")
     n = 1000
     m = ((1 << 31) + n - 1) // n
-    assert (np.log2(m * n) > 31)
+    assert np.log2(m * n) > 31
     X = cp.ones((m, n), dtype=np.float32)
     y = cp.ones(m)
-    dmat = xgb.QuantileDMatrix(X, y)
+    w = cp.ones(m)
+    dmat = xgb.QuantileDMatrix(X, y, weight=w)
     booster = xgb.train({"tree_method": "gpu_hist", "max_depth": 1}, dmat, 1)
     del y
     booster.inplace_predict(X)
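With the added weight vector (w = cp.ones(m), passed as weight=w), the test now pushes the same matrix of just over 2^31 values through the weighted path as well, i.e. the has_weight branch of SketchBatchNumElements shown above; previously only the unweighted path was exercised. The other two test edits (1.5e10 and the unparenthesized assert) are formatting-only.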