Optimize adapter element counting on GPU. (#9209)

- Implement a simple `IterSpan` for passing iterators with size.
- Use shared memory for column size counts.
- Use one thread for each sample in row count to reduce atomic operations.
This commit is contained in:
Jiaming Yuan
2023-05-30 23:28:43 +08:00
committed by GitHub
parent 097f11b6e0
commit 17fd3f55e9
9 changed files with 323 additions and 61 deletions

View File

@@ -39,6 +39,14 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
return {row_idx, column_idx, value};
}
// Fetch the value at (ridx, fidx); entries masked out by the column's
// validity bitmap are reported as quiet NaN.
__device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
  auto const& col = columns_[fidx];
  // A null validity buffer means the column has no mask: every entry is valid.
  bool const is_present = col.valid.Data() == nullptr || col.valid.Check(ridx);
  if (!is_present) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  return col(ridx);
}
// Number of rows in this batch (cached member).
XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
// Number of feature columns, i.e. the length of the columns_ span.
XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }
@@ -160,6 +168,10 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
float value = array_interface_(row_idx, column_idx);
return {row_idx, column_idx, value};
}
// Read the value at (ridx, fidx) straight from the dense cupy array
// interface; unlike the cudf adapter there is no validity mask here.
__device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
  return array_interface_(ridx, fidx);
}
// Number of rows: extent of the array interface's first dimension.
XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); }
// Number of columns: extent of the array interface's second dimension.
XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); }
@@ -196,24 +208,47 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
// Counts the valid (non-missing) elements of each row in `batch`, writing the
// per-row counts into `offset`, and returns the maximum row length.
//
// `offset` must have room for NumRows() + 1 counters; it is zeroed here before
// the counting kernel accumulates into it. `missing` defines which values are
// treated as absent (via IsValidFunctor). The whole routine runs on
// `device_idx` and uses the stream-ordered memset + dh::LaunchN launch
// machinery, so counters are not read back to the host — the maximum is
// computed with a device-side reduce.
template <typename AdapterBatchT>
std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offset, int device_idx,
                         float missing) {
  dh::safe_cuda(cudaSetDevice(device_idx));
  IsValidFunctor is_valid(missing);
  // Zero the counters; the kernel below accumulates into them with atomics.
  dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes()));

  auto n_samples = batch.NumRows();
  bst_feature_t n_features = batch.NumCols();
  // Use more than one thread per row in case the dataset is very wide; the
  // thresholds trade launch width against per-thread work.
  bst_feature_t stride{0};
  if (n_features < 32) {
    stride = std::min(n_features, 4u);
  } else if (n_features < 64) {
    stride = 8;
  } else if (n_features < 128) {
    stride = 16;
  } else {
    stride = 32;
  }
  // Count elements per row. Each of the `stride` threads assigned to a row
  // scans a strided subset of its features, accumulates a private count, and
  // issues a single atomicAdd — far fewer atomic operations than the naive
  // one-atomic-per-element scheme.
  dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) {
    bst_row_t cnt{0};
    auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride);
    SPAN_CHECK(ridx < n_samples);
    for (bst_feature_t fidx = fbeg; fidx < n_features; fidx += stride) {
      if (is_valid(batch.GetElement(ridx, fidx))) {
        cnt++;
      }
    }
    // NOTE(review): the cast assumes bst_row_t is a 64-bit unsigned type so
    // that the unsigned long long atomicAdd overload aliases it safely.
    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
                  &offset[ridx]),
              static_cast<unsigned long long>(cnt));  // NOLINT
  });
  // Reduce to the maximum per-row count entirely on the device.
  dh::XGBCachingDeviceAllocator<char> alloc;
  bst_row_t row_stride =
      dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
                 thrust::device_pointer_cast(offset.data()) + offset.size(),
                 static_cast<bst_row_t>(0), thrust::maximum<bst_row_t>());
  return row_stride;
}

View File

@@ -80,7 +80,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
}
auto batch_rows = num_rows();
accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, get_device(), missing);
@@ -134,7 +134,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
init_page();
dh::safe_cuda(cudaSetDevice(get_device()));
auto rows = num_rows();
dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
dh::device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, get_device(), missing);