Optimize adapter element counting on GPU. (#9209)

- Implement a simple `IterSpan` for passing an iterator together with its size (a minimal sketch follows below).
- Use shared memory for column size counts.
- Use one thread per sample when counting rows to reduce atomic operations.
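
For context, a minimal sketch of what an iterator-plus-size wrapper of this kind can look like. This is an illustration only, not the `IterSpan` implemented by this commit:

    // Hypothetical sketch: a non-owning (iterator, size) pair that can be
    // constructed on the host and indexed on the device.
    #include <cstddef>

    template <typename It>
    class IterSpan {
      It it_;
      std::size_t size_{0};

     public:
      __host__ __device__ IterSpan(It it, std::size_t size) : it_{it}, size_{size} {}
      __host__ __device__ std::size_t size() const { return size_; }
      __host__ __device__ decltype(auto) operator[](std::size_t i) const { return it_[i]; }
      __host__ __device__ It begin() const { return it_; }
      __host__ __device__ It end() const { return it_ + size_; }
    };

Bundling the iterator with its length lets call sites pass one argument instead of a separate (iterator, begin, end) triple, as the call-site changes below show.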
Jiaming Yuan
2023-05-30 23:28:43 +08:00
committed by GitHub
parent 097f11b6e0
commit 17fd3f55e9
9 changed files with 323 additions and 61 deletions


@@ -204,9 +204,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
         return {0, e.index, e.fvalue};  // row_idx is not needed for scanning column size.
       });
   detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
-                             batch_it, dummy_is_valid,
-                             0, sorted_entries.size(),
-                             &cuts_ptr, &column_sizes_scan);
+                             IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
+                             &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
   if (sketch_container->HasCategorical()) {
@@ -273,9 +272,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
         return {0, e.index, e.fvalue};  // row_idx is not needed for scanning column size.
       });
   detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
-                             batch_it, dummy_is_valid,
-                             0, sorted_entries.size(),
-                             &cuts_ptr, &column_sizes_scan);
+                             IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
+                             &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
   if (sketch_container->HasCategorical()) {
     detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,


@@ -53,32 +53,126 @@ struct EntryCompareOp {
 };
 
 // Get column size from adapter batch and for output cuts.
-template <typename Iter>
-void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature,
-                        Iter batch_iter, data::IsValidFunctor is_valid,
-                        size_t begin, size_t end,
-                        HostDeviceVector<SketchContainer::OffsetT> *cuts_ptr,
+template <std::uint32_t kBlockThreads, typename CounterT, typename BatchIt>
+__global__ void GetColumnSizeSharedMemKernel(IterSpan<BatchIt> batch_iter,
+                                             data::IsValidFunctor is_valid,
+                                             Span<std::size_t> out_column_size) {
+  extern __shared__ char smem[];
+  auto smem_cs_ptr = reinterpret_cast<CounterT*>(smem);
+  dh::BlockFill(smem_cs_ptr, out_column_size.size(), 0);
+  cub::CTA_SYNC();
+  auto n = batch_iter.size();
+  for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n)) {
+    auto e = batch_iter[idx];
+    if (is_valid(e)) {
+      atomicAdd(&smem_cs_ptr[e.column_idx], static_cast<CounterT>(1));
+    }
+  }
+  cub::CTA_SYNC();
+  auto out_global_ptr = out_column_size;
+  for (auto i : dh::BlockStrideRange(static_cast<std::size_t>(0), out_column_size.size())) {
+    atomicAdd(&out_global_ptr[i], static_cast<std::size_t>(smem_cs_ptr[i]));
+  }
+}
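
The kernel above leans on XGBoost's `dh::GridStrideRange` and `dh::BlockStrideRange` helpers. Assuming they wrap the canonical CUDA grid-stride and block-stride loop idioms, their behavior corresponds roughly to this sketch (assumed semantics, not the real helpers):

    // Assumed semantics of the dh:: range helpers used in the kernel above.
    __global__ void GridStrideSketch(std::size_t n, std::size_t n_columns) {
      // dh::GridStrideRange(0, n): every thread in the grid walks the range,
      // stepping by the total thread count (the main counting loop).
      for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
        // ... count element i ...
      }
      // dh::BlockStrideRange(0, n_columns): threads of one block walk the
      // range, stepping by the block size (the flush loop).
      for (std::size_t i = threadIdx.x; i < n_columns; i += blockDim.x) {
        // ... merge shared counter i into the global counter ...
      }
    }

Under this two-level scheme, each block performs its atomics against fast shared memory and issues only one global atomic per column at the end, instead of one global atomic per valid element.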
+template <std::uint32_t kBlockThreads, typename Kernel>
+std::uint32_t EstimateGridSize(std::int32_t device, Kernel kernel, std::size_t shared_mem) {
+  int n_mps = 0;
+  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
+  int n_blocks_per_mp = 0;
+  dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
+                                                              kBlockThreads, shared_mem));
+  std::uint32_t grid_size = n_blocks_per_mp * n_mps;
+  return grid_size;
+}
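
With a grid-stride kernel, the grid only needs to be large enough to saturate the device. A worked example with made-up numbers:

    // Suppose the device has n_mps = 80 multiprocessors and the occupancy
    // calculator reports n_blocks_per_mp = 4 resident blocks per SM for
    // kBlockThreads = 512 at the requested shared-memory size:
    //   grid_size = 4 * 80 = 320 blocks
    // The grid-stride loop then covers any input size n with 320 blocks.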
+/**
+ * \brief Get the size of each column. This is a histogram with additional handling of
+ *        invalid values.
+ *
+ * \tparam BatchIt Iterator type for the input adapter batch.
+ * \tparam force_use_global_memory Used for testing. Force the global atomic-add path.
+ * \tparam force_use_u64 Used for testing. Force u64 as the counter type in shared memory.
+ *
+ * \param device CUDA device ordinal.
+ * \param batch_iter Iterator for input data from the adapter batch.
+ * \param is_valid Whether an element is valid (not missing).
+ * \param out_column_size Output buffer for the size of each column.
+ */
+template <typename BatchIt, bool force_use_global_memory = false, bool force_use_u64 = false>
+void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter,
+                               data::IsValidFunctor is_valid, Span<std::size_t> out_column_size) {
+  thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0);
+
+  std::size_t max_shared_memory = dh::MaxSharedMemory(device);
+  // Not strictly correct, as the counter type should be chosen from the number of
+  // samples; the sample count is unknown here, however, because of the sliding window
+  // over the number of elements.
+  std::size_t n = batch_iter.size();
+
+  std::size_t required_shared_memory = 0;
+  bool use_u32{false};
+  if (!force_use_u64 && n < static_cast<std::size_t>(std::numeric_limits<std::uint32_t>::max())) {
+    required_shared_memory = out_column_size.size() * sizeof(std::uint32_t);
+    use_u32 = true;
+  } else {
+    required_shared_memory = out_column_size.size() * sizeof(std::size_t);
+    use_u32 = false;
+  }
+  bool use_shared = required_shared_memory <= max_shared_memory && required_shared_memory != 0;
+
+  if (!force_use_global_memory && use_shared) {
+    CHECK_NE(required_shared_memory, 0);
+    std::uint32_t constexpr kBlockThreads = 512;
+    if (use_u32) {
+      CHECK(!force_use_u64);
+      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
+      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+          kernel, batch_iter, is_valid, out_column_size);
+    } else {
+      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
+      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+          kernel, batch_iter, is_valid, out_column_size);
+    }
+  } else {
+    auto d_out_column_size = out_column_size;
+    dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) {
+      auto e = batch_iter[idx];
+      if (is_valid(e)) {
+        atomicAdd(&d_out_column_size[e.column_idx], static_cast<size_t>(1));
+      }
+    });
+  }
+}
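
The two `force_*` parameters exist so tests can exercise each code path on the same input. A hypothetical comparison, assuming `batch_it`, `is_valid`, `n`, and `n_columns` are already set up by the caller:

    using BatchIt = decltype(batch_it);
    auto span = IterSpan{batch_it, n};
    dh::caching_device_vector<std::size_t> out_smem(n_columns + 1);
    dh::caching_device_vector<std::size_t> out_gmem(n_columns + 1);
    // Default path: shared-memory counters, u32 when the element count fits.
    LaunchGetColumnSizeKernel(device, span, is_valid, dh::ToSpan(out_smem));
    // Forced global-memory path, for comparison against the default.
    LaunchGetColumnSizeKernel<BatchIt, true, false>(device, span, is_valid, dh::ToSpan(out_gmem));
    // out_smem and out_gmem should now be element-wise equal.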
+template <typename BatchIt>
+void GetColumnSizesScan(int device, size_t num_columns, std::size_t num_cuts_per_feature,
+                        IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
+                        HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
                         dh::caching_device_vector<size_t>* column_sizes_scan) {
-  column_sizes_scan->resize(num_columns + 1, 0);
+  column_sizes_scan->resize(num_columns + 1);
   cuts_ptr->SetDevice(device);
   cuts_ptr->Resize(num_columns + 1, 0);
   dh::XGBCachingDeviceAllocator<char> alloc;
-  auto d_column_sizes_scan = column_sizes_scan->data().get();
-  dh::LaunchN(end - begin, [=] __device__(size_t idx) {
-    auto e = batch_iter[begin + idx];
-    if (is_valid(e)) {
-      atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast<size_t>(1));
-    }
-  });
+  auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
+  LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan);
   // Calculate cuts CSC pointer
   auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
       column_sizes_scan->begin(), [=] __device__(size_t column_size) {
         return thrust::min(num_cuts_per_feature, column_size);
       });
   thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
-                         cut_ptr_it + column_sizes_scan->size(),
-                         cuts_ptr->DevicePointer());
+                         cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
   thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
                          column_sizes_scan->end(), column_sizes_scan->begin());
 }
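
A worked example of the two scans, with made-up numbers: three columns whose valid-entry counts are {3, 9, 2}, and num_cuts_per_feature = 4.

    // column_sizes_scan after counting (num_columns + 1 entries): {3, 9, 2, 0}
    // cut_ptr_it = min(num_cuts_per_feature, count)             : {3, 4, 2, 0}
    // exclusive_scan of cut_ptr_it -> cuts_ptr                  : {0, 3, 7, 9}
    // exclusive_scan in place      -> column_sizes_scan         : {0, 3, 12, 14}

After the scans, cuts_ptr is the CSC-style pointer array into the cut values, and column_sizes_scan->back() (14 here) is the total number of valid entries, which MakeEntriesFromAdapter uses below to size its output.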
@@ -121,29 +215,26 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
 
 // Count the valid entries in each column and copy them out.
 template <typename AdapterBatch, typename BatchIter>
-void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
-                            Range1d range, float missing,
-                            size_t columns, size_t cuts_per_feature, int device,
+void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range,
+                            float missing, size_t columns, size_t cuts_per_feature, int device,
                             HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
                             dh::caching_device_vector<size_t>* column_sizes_scan,
                             dh::device_vector<Entry>* sorted_entries) {
   auto entry_iter = dh::MakeTransformIterator<Entry>(
       thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
-        return Entry(batch.GetElement(idx).column_idx,
-                     batch.GetElement(idx).value);
+        return Entry(batch.GetElement(idx).column_idx, batch.GetElement(idx).value);
       });
+  auto n = range.end() - range.begin();
+  auto span = IterSpan{batch_iter + range.begin(), n};
   data::IsValidFunctor is_valid(missing);
   // Work out how many valid entries we have in each column
-  GetColumnSizesScan(device, columns, cuts_per_feature,
-                     batch_iter, is_valid,
-                     range.begin(), range.end(),
-                     cut_sizes_scan,
+  GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
                      column_sizes_scan);
   size_t num_valid = column_sizes_scan->back();
   // Copy current subset of valid elements into temporary storage and sort
   sorted_entries->resize(num_valid);
-  dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(),
-             sorted_entries->begin(), is_valid);
+  dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
+             is_valid);
 }
 
 void SortByWeight(dh::device_vector<float>* weights,
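
MakeEntriesFromAdapter follows a count-then-compact pattern: the scans size the output exactly, then the valid entries are stream-compacted into it. Assuming `dh::CopyIf` wraps `thrust::copy_if`, the pattern in isolation looks roughly like this (a standalone sketch, not the XGBoost helper; requires nvcc's extended lambdas):

    #include <thrust/copy.h>
    #include <thrust/count.h>
    #include <thrust/device_vector.h>

    // Keep only the elements that are not equal to `missing`.
    void CompactValid(thrust::device_vector<float> const& in,
                      thrust::device_vector<float>* out, float missing) {
      auto is_valid = [missing] __device__(float v) { return v != missing; };
      // 1. Count so the output is sized exactly once, with no reallocation.
      auto n_valid = thrust::count_if(in.begin(), in.end(), is_valid);
      out->resize(n_valid);
      // 2. Stream-compact the valid elements into the output.
      thrust::copy_if(in.begin(), in.end(), out->begin(), is_valid);
    }

In the commit itself, the count comes for free from the precomputed column sums: the column_sizes_scan->back() lookup replaces the separate counting pass in this sketch.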