Fix integer overflow. (#10615)

commit b2cae34a8e
parent f6cae4da85
@@ -24,7 +24,7 @@ struct EllpackDeviceAccessor {
   /*! \brief Whether or not the matrix is dense. */
   bool is_dense;
   /*! \brief Row length for ELLPACK, equal to number of features. */
-  size_t row_stride;
+  bst_idx_t row_stride;
   bst_idx_t base_rowid{0};
   bst_idx_t n_rows{0};
   common::CompressedIterator<std::uint32_t> gidx_iter;
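This hunk switches row_stride to bst_idx_t, the fixed 64-bit index type already used by base_rowid and n_rows, so arithmetic that combines them is carried out in one known-wide type. A minimal sketch (standalone C++ with made-up values, not XGBoost code) of the hazard this guards against, where a flat entry index is formed through a 32-bit intermediate:

    #include <cstdint>
    #include <iostream>

    int main() {
      std::uint32_t ridx = 3'000'000;    // row index within the matrix
      std::uint32_t row_stride = 2'000;  // entries per ELLPACK row

      // 32-bit product wraps around: 6'000'000'000 mod 2^32.
      std::uint32_t bad = ridx * row_stride;
      // Widening an operand first, as a 64-bit bst_idx_t stride forces, is safe.
      std::uint64_t good = static_cast<std::uint64_t>(ridx) * row_stride;

      std::cout << bad << " vs " << good << "\n";  // 1705032704 vs 6000000000
    }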
@@ -118,7 +118,7 @@ struct EllpackDeviceAccessor {
    * not found). */
  [[nodiscard]] XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }

-  [[nodiscard]] XGBOOST_DEVICE size_t NullValue() const { return gidx_fvalue_map.size(); }
+  [[nodiscard]] XGBOOST_DEVICE size_t NullValue() const { return this->NumBins(); }

  [[nodiscard]] XGBOOST_DEVICE size_t NumBins() const { return gidx_fvalue_map.size(); }
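The NullValue() rewrite is behavior-preserving: the sentinel for a missing entry is still one past the last valid bin, but it is now expressed through NumBins() instead of repeating gidx_fvalue_map.size(). A standalone sketch of the layout these three accessors imply (hypothetical struct, not the real EllpackDeviceAccessor):

    #include <cstddef>
    #include <vector>

    struct BinLayout {
      std::vector<float> gidx_fvalue_map;  // one cut value per histogram bin

      // Valid bin indices are [0, NumBins()); NumBins() itself is the sentinel
      // for a missing entry, hence the compressed stream needs one extra symbol.
      std::size_t NumBins() const { return gidx_fvalue_map.size(); }
      std::size_t NullValue() const { return NumBins(); }
      std::size_t NumSymbols() const { return NumBins() + 1; }
    };

    int main() {
      BinLayout layout{{0.5f, 1.5f, 2.5f}};
      return layout.NullValue() == 3 && layout.NumSymbols() == 4 ? 0 : 1;
    }

Keeping the sentinel defined in one place matters below, where the kernels compare gidx against NullValue() directly.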
@@ -31,11 +31,12 @@ FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense,

  for (size_t i = 2; i < cut_ptrs.size(); ++i) {
    int last_start = bin_segments_h.back();
+    // Push a new group whenever the size of required bin storage is greater than the
+    // shared memory size.
    if (cut_ptrs[i] - last_start > max_shmem_bins) {
      feature_segments_h.push_back(i - 1);
      bin_segments_h.push_back(cut_ptrs[i - 1]);
-      max_group_bins = std::max(max_group_bins,
-                                bin_segments_h.back() - last_start);
+      max_group_bins = std::max(max_group_bins, bin_segments_h.back() - last_start);
    }
  }
  feature_segments_h.push_back(cut_ptrs.size() - 1);
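The new comment documents the greedy grouping rule: start a new feature group as soon as the bins accumulated since the group began no longer fit in shared memory. A hedged, host-only sketch of that loop on toy numbers (the initialization of max_group_bins and the handling of the final group are assumptions here, not taken from this diff):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // cut_ptrs[i] = cumulative bin count before feature i; 4 features here.
      std::vector<int> cut_ptrs{0, 4, 8, 14, 18};
      int max_shmem_bins = 10;  // pretend shared memory holds 10 bins

      std::vector<int> feature_segments_h{0};
      std::vector<int> bin_segments_h{0};
      int max_group_bins = cut_ptrs[1];  // assumed initialization

      for (size_t i = 2; i < cut_ptrs.size(); ++i) {
        int last_start = bin_segments_h.back();
        // Push a new group whenever required bin storage exceeds shared memory.
        if (cut_ptrs[i] - last_start > max_shmem_bins) {
          feature_segments_h.push_back(i - 1);
          bin_segments_h.push_back(cut_ptrs[i - 1]);
          max_group_bins = std::max(max_group_bins, bin_segments_h.back() - last_start);
        }
      }
      feature_segments_h.push_back(cut_ptrs.size() - 1);
      // Assumed here: the final group's bins are accounted for after the loop.
      max_group_bins = std::max(max_group_bins, cut_ptrs.back() - bin_segments_h.back());

      // Result: features [0,2) use bins [0,8); features [2,4) use bins [8,18).
      std::printf("%zu groups, max %d bins per group\n",
                  feature_segments_h.size() - 1, max_group_bins);
    }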
@@ -23,6 +23,18 @@ struct Pair {
 __host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
   return {lhs.first + rhs.first, lhs.second + rhs.second};
 }

+XGBOOST_DEV_INLINE bst_idx_t IterIdx(EllpackDeviceAccessor const& matrix,
+                                     RowPartitioner::RowIndexT ridx, FeatureGroup const& group,
+                                     bst_idx_t idx, std::int32_t feature_stride) {
+  // ridx_local = ridx - base_rowid  <== Row index local to each batch
+  // entry_idx = ridx_local * row_stride  <== Starting entry index for this row in the matrix
+  // entry_idx += start_feature  <== Inside a row, first column inside this feature group
+  // idx % feature_stride  <== The feature index local to the current feature group
+  // entry_idx += idx % feature_stride  <== Final index.
+  return (ridx - matrix.base_rowid) * matrix.row_stride + group.start_feature +
+         idx % feature_stride;
+}
+
 }  // anonymous namespace

 struct Clip : public thrust::unary_function<GradientPair, Pair> {
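IterIdx is the heart of the fix: the flat-index arithmetic previously repeated inline at three call sites now lives in one function whose result and intermediates are bst_idx_t (64-bit). A standalone restatement with plain integer types, assuming RowPartitioner::RowIndexT is 32-bit:

    #include <cstdint>
    #include <iostream>

    std::uint64_t IterIdx(std::uint64_t base_rowid, std::uint64_t row_stride,
                          std::uint32_t ridx, std::uint64_t start_feature,
                          std::uint64_t idx, std::int32_t feature_stride) {
      // (ridx - base_rowid)       row index local to this batch
      //   * row_stride            first entry of that row in the matrix
      //   + start_feature         first column of this feature group
      //   + idx % feature_stride  feature offset within the group
      return (ridx - base_rowid) * row_stride + start_feature + idx % feature_stride;
    }

    int main() {
      // Row 3'000'000 in a 2'000-wide matrix: the flat index exceeds 2^32,
      // which is why every intermediate must already be 64-bit.
      std::cout << IterIdx(0, 2'000, 3'000'000, 0, 5, 4) << "\n";  // 6000000001
    }

The three hunks that follow each replace their hand-rolled index expression with a call to this helper.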
@@ -159,15 +171,16 @@ class HistogramAgent {
          idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_);
          idx += kBlockThreads) {
       Idx ridx = d_ridx_[idx / feature_stride_];
-      Idx midx = (ridx - matrix_.base_rowid) * matrix_.row_stride + group_.start_feature +
-                 idx % feature_stride_;
-      bst_bin_t gidx = matrix_.gidx_iter[midx] - group_.start_bin;
-      if (matrix_.is_dense || gidx != matrix_.NumBins()) {
+      bst_bin_t gidx = matrix_.gidx_iter[IterIdx(matrix_, ridx, group_, idx, feature_stride_)];
+      if (matrix_.is_dense || gidx != matrix_.NullValue()) {
         auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
-        AtomicAddGpairShared(smem_arr_ + gidx, adjusted);
+        // Subtract start_bin to write to group-local histogram. If this is not a dense
+        // matrix, then start_bin is 0 since feature grouping doesn't support sparse data.
+        AtomicAddGpairShared(smem_arr_ + gidx - group_.start_bin, adjusted);
       }
     }
   }

   // Instruction level parallelism by loop unrolling
   // Allows the kernel to pipeline many operations while waiting for global memory
   // Increases the throughput of this kernel significantly
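Besides routing through IterIdx, this hunk reorders the arithmetic: gidx now stays a global bin index through the missing-value test (gidx != matrix_.NullValue()), and the group_.start_bin shift happens only at the shared-memory write. Per the new comment, that is safe because start_bin is 0 whenever the matrix can contain missing values. A toy illustration of the global-to-group-local mapping (hypothetical types, not the real HistogramAgent):

    #include <cassert>
    #include <cstdint>

    struct Group {
      std::int32_t start_bin;  // first global bin owned by this feature group
      std::int32_t num_bins;   // bins in this group (sized to fit shared memory)
    };

    // The compressed matrix yields a *global* bin index; the shared-memory
    // histogram holds only this group's bins, so shift by start_bin at the write.
    std::int32_t ToSharedSlot(Group const& g, std::int32_t gidx) {
      assert(gidx >= g.start_bin && gidx < g.start_bin + g.num_bins);
      return gidx - g.start_bin;
    }

    int main() {
      Group g{128, 64};
      assert(ToSharedSlot(g, 130) == 2);  // global bin 130 -> shared slot 2
    }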
@@ -187,12 +200,11 @@ class HistogramAgent {
 #pragma unroll
     for (int i = 0; i < kItemsPerThread; i++) {
       gpair[i] = d_gpair_[ridx[i]];
-      gidx[i] = matrix_.gidx_iter[(ridx[i] - matrix_.base_rowid) * matrix_.row_stride +
-                                  group_.start_feature + idx[i] % feature_stride_];
+      gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], group_, idx[i], feature_stride_)];
     }
 #pragma unroll
     for (int i = 0; i < kItemsPerThread; i++) {
-      if ((matrix_.is_dense || gidx[i] != matrix_.NumBins())) {
+      if ((matrix_.is_dense || gidx[i] != matrix_.NullValue())) {
         auto adjusted = rounding_.ToFixedPoint(gpair[i]);
         AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
       }
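The unrolled variant keeps the gather-then-process shape described by the comments above it: all kItemsPerThread loads are issued in the first loop so they can be in flight concurrently, and the histogram updates happen in the second. A plain C++ stand-in for that pattern (hypothetical helper, not the kernel itself):

    #include <cstddef>

    constexpr int kItemsPerThread = 8;

    // Phase 1 issues all loads so they overlap in flight; phase 2 consumes them.
    // On a GPU this hides global-memory latency behind independent work.
    template <typename Load, typename Process>
    void UnrolledPass(std::size_t n, Load load, Process process) {
      int buf[kItemsPerThread];
      for (std::size_t base = 0; base + kItemsPerThread <= n; base += kItemsPerThread) {
    #pragma unroll
        for (int i = 0; i < kItemsPerThread; i++) {
          buf[i] = load(base + i);  // gather only, no dependent work yet
        }
    #pragma unroll
        for (int i = 0; i < kItemsPerThread; i++) {
          process(buf[i]);  // all inputs already requested above
        }
      }
    }

    int main() {
      int sum = 0;
      UnrolledPass(16, [](std::size_t i) { return static_cast<int>(i); },
                   [&](int v) { sum += v; });
      return sum == 120 ? 0 : 1;  // 0 + 1 + ... + 15
    }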
@@ -219,9 +231,8 @@ class HistogramAgent {
   __device__ void BuildHistogramWithGlobal() {
     for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
       Idx ridx = d_ridx_[idx / feature_stride_];
-      bst_bin_t gidx = matrix_.gidx_iter[(ridx - matrix_.base_rowid) * matrix_.row_stride +
-                                         group_.start_feature + idx % feature_stride_];
-      if (matrix_.is_dense || gidx != matrix_.NumBins()) {
+      bst_bin_t gidx = matrix_.gidx_iter[IterIdx(matrix_, ridx, group_, idx, feature_stride_)];
+      if (matrix_.is_dense || gidx != matrix_.NullValue()) {
         auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
         AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted);
       }
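The global-memory path is the same computation minus the group-local shift: d_node_hist_ spans every bin of the model, so the global gidx indexes it directly and no start_bin subtraction is needed. A toy contrast of the two indexing schemes (illustrative values only):

    #include <cstdint>
    #include <vector>

    int main() {
      std::int32_t start_bin = 128;           // this feature group's first bin
      std::vector<double> shared_hist(64);    // group-local: slot = gidx - start_bin
      std::vector<double> global_hist(1024);  // whole model:  slot = gidx

      std::int32_t gidx = 130;                // a global bin index
      shared_hist[gidx - start_bin] += 1.0;   // shared-memory path shifts
      global_hist[gidx] += 1.0;               // global path indexes directly
    }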