Optimize adapter element counting on GPU. (#9209)
- Implement a simple `IterSpan` for passing iterators along with a size.
- Use shared memory for column size counts.
- Use one thread for each sample in the row count to reduce atomic operations.
parent 097f11b6e0
commit 17fd3f55e9
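For context on the second bullet, here is a minimal standalone CUDA sketch of the shared-memory counting pattern this commit applies; it is not part of the diff, and the kernel and parameter names are illustrative only. Each block first counts into a block-local histogram in shared memory, then flushes one global atomic per (block, column) instead of one per element:

#include <cstddef>
#include <cstdint>

__global__ void CountColumnsKernel(std::uint32_t const* column_idx, std::size_t n,
                                   std::uint32_t n_columns, std::uint32_t* out_counts) {
  extern __shared__ std::uint32_t smem_counts[];
  // Zero the block-local histogram.
  for (std::uint32_t i = threadIdx.x; i < n_columns; i += blockDim.x) {
    smem_counts[i] = 0;
  }
  __syncthreads();
  // Grid-stride loop: block-local atomics are much cheaper than global ones.
  for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
    atomicAdd(&smem_counts[column_idx[i]], 1u);
  }
  __syncthreads();
  // Flush: one global atomic per (block, column) instead of one per element.
  for (std::uint32_t i = threadIdx.x; i < n_columns; i += blockDim.x) {
    atomicAdd(&out_counts[i], smem_counts[i]);
  }
}
// Launch with dynamic shared memory sized to the column count, e.g.:
//   CountColumnsKernel<<<grid, 256, n_columns * sizeof(std::uint32_t)>>>(...);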
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2018 XGBoost contributors
+/**
+ * Copyright 2018-2023, XGBoost contributors
  * \brief span class based on ISO++20 span
  *
  * About NOLINTs in this file:
@@ -32,11 +32,12 @@
 #include <xgboost/base.h>
 #include <xgboost/logging.h>
 
 #include <cinttypes>  // size_t
-#include <limits>  // numeric_limits
-#include <iterator>
-#include <type_traits>
 #include <cstdio>
+#include <iterator>
+#include <limits>  // numeric_limits
+#include <type_traits>
+#include <utility>  // for move
 
 #if defined(__CUDACC__)
 #include <cuda_runtime.h>
@@ -668,6 +669,44 @@ XGBOOST_DEVICE auto as_writable_bytes(Span<T, E> s) __span_noexcept ->  // NOLIN
     Span<byte, detail::ExtentAsBytesValue<T, E>::value> {
   return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
 }
 
+/**
+ * \brief A simple custom Span type that uses general iterator instead of pointer.
+ */
+template <typename It>
+class IterSpan {
+ public:
+  using value_type = typename std::iterator_traits<It>::value_type;  // NOLINT
+  using index_type = std::size_t;                                    // NOLINT
+  using iterator = It;                                               // NOLINT
+
+ private:
+  It it_;
+  index_type size_{0};
+
+ public:
+  IterSpan() = default;
+  XGBOOST_DEVICE IterSpan(It it, index_type size) : it_{std::move(it)}, size_{size} {}
+  XGBOOST_DEVICE explicit IterSpan(common::Span<It, dynamic_extent> span)
+      : it_{span.data()}, size_{span.size()} {}
+
+  [[nodiscard]] XGBOOST_DEVICE index_type size() const noexcept { return size_; }  // NOLINT
+  [[nodiscard]] XGBOOST_DEVICE decltype(auto) operator[](index_type i) const { return it_[i]; }
+  [[nodiscard]] XGBOOST_DEVICE decltype(auto) operator[](index_type i) { return it_[i]; }
+  [[nodiscard]] XGBOOST_DEVICE bool empty() const noexcept { return size() == 0; }  // NOLINT
+  [[nodiscard]] XGBOOST_DEVICE It data() const noexcept { return it_; }  // NOLINT
+  [[nodiscard]] XGBOOST_DEVICE IterSpan<It> subspan(  // NOLINT
+      index_type _offset, index_type _count = dynamic_extent) const {
+    SPAN_CHECK((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size()));
+    return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
+  }
+  [[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept {  // NOLINT
+    return {this, 0};
+  }
+  [[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept {  // NOLINT
+    return {this, size()};
+  }
+};
 }  // namespace common
 }  // namespace xgboost
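As a review aid, a minimal host-side sketch (illustrative, not part of the diff) of how the new `IterSpan` is used. Any random-access iterator works, which is what lets the sketching code below pass device transform iterators where `common::Span` would require a raw pointer:

#include <cstddef>

#include <xgboost/span.h>  // IterSpan lives here after this change

void IterSpanDemo() {
  float values[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  // A plain pointer is the simplest "iterator"; device transform iterators
  // compose the same way on the GPU side.
  auto span = xgboost::common::IterSpan{values, std::size_t{8}};
  float second = span[1];       // 1.0f
  auto tail = span.subspan(4);  // views elements 4..7; tail.size() == 4
  (void)second;
  (void)tail;
}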
@@ -204,9 +204,8 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
         return {0, e.index, e.fvalue};  // row_idx is not needed for scanning column size.
       });
   detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
-                             batch_it, dummy_is_valid,
-                             0, sorted_entries.size(),
-                             &cuts_ptr, &column_sizes_scan);
+                             IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
+                             &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
 
   if (sketch_container->HasCategorical()) {
@@ -273,9 +272,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
         return {0, e.index, e.fvalue};  // row_idx is not needed for scaning column size.
       });
   detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
-                             batch_it, dummy_is_valid,
-                             0, sorted_entries.size(),
-                             &cuts_ptr, &column_sizes_scan);
+                             IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
+                             &column_sizes_scan);
   auto d_cuts_ptr = cuts_ptr.DeviceSpan();
   if (sketch_container->HasCategorical()) {
     detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr,
@@ -53,32 +53,126 @@ struct EntryCompareOp {
 };
 
 // Get column size from adapter batch and for output cuts.
-template <typename Iter>
-void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature,
-                        Iter batch_iter, data::IsValidFunctor is_valid,
-                        size_t begin, size_t end,
-                        HostDeviceVector<SketchContainer::OffsetT> *cuts_ptr,
+template <std::uint32_t kBlockThreads, typename CounterT, typename BatchIt>
+__global__ void GetColumnSizeSharedMemKernel(IterSpan<BatchIt> batch_iter,
+                                             data::IsValidFunctor is_valid,
+                                             Span<std::size_t> out_column_size) {
+  extern __shared__ char smem[];
+
+  auto smem_cs_ptr = reinterpret_cast<CounterT*>(smem);
+
+  dh::BlockFill(smem_cs_ptr, out_column_size.size(), 0);
+
+  cub::CTA_SYNC();
+
+  auto n = batch_iter.size();
+
+  for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n)) {
+    auto e = batch_iter[idx];
+    if (is_valid(e)) {
+      atomicAdd(&smem_cs_ptr[e.column_idx], static_cast<CounterT>(1));
+    }
+  }
+
+  cub::CTA_SYNC();
+
+  auto out_global_ptr = out_column_size;
+  for (auto i : dh::BlockStrideRange(static_cast<std::size_t>(0), out_column_size.size())) {
+    atomicAdd(&out_global_ptr[i], static_cast<std::size_t>(smem_cs_ptr[i]));
+  }
+}
+
+template <std::uint32_t kBlockThreads, typename Kernel>
+std::uint32_t EstimateGridSize(std::int32_t device, Kernel kernel, std::size_t shared_mem) {
+  int n_mps = 0;
+  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
+  int n_blocks_per_mp = 0;
+  dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
+                                                              kBlockThreads, shared_mem));
+  std::uint32_t grid_size = n_blocks_per_mp * n_mps;
+  return grid_size;
+}
+
+/**
+ * \brief Get the size of each column. This is a histogram with additional handling of
+ *        invalid values.
+ *
+ * \tparam BatchIt                 Type of input adapter batch.
+ * \tparam force_use_global_memory Used for testing. Force global atomic add.
+ * \tparam force_use_u64           Used for testing. Force u64 as counter in shared memory.
+ *
+ * \param device          CUDA device ordinal.
+ * \param batch_iter      Iterator for input data from adapter batch.
+ * \param is_valid        Whether an element is considered as missing.
+ * \param out_column_size Output buffer for the size of each column.
+ */
+template <typename BatchIt, bool force_use_global_memory = false, bool force_use_u64 = false>
+void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter,
+                               data::IsValidFunctor is_valid, Span<std::size_t> out_column_size) {
+  thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0);
+
+  std::size_t max_shared_memory = dh::MaxSharedMemory(device);
+  // Not strictly correct as we should use number of samples to determine the type of
+  // counter. However, the sample size is not known due to sliding window on number of
+  // elements.
+  std::size_t n = batch_iter.size();
+
+  std::size_t required_shared_memory = 0;
+  bool use_u32{false};
+  if (!force_use_u64 && n < static_cast<std::size_t>(std::numeric_limits<std::uint32_t>::max())) {
+    required_shared_memory = out_column_size.size() * sizeof(std::uint32_t);
+    use_u32 = true;
+  } else {
+    required_shared_memory = out_column_size.size() * sizeof(std::size_t);
+    use_u32 = false;
+  }
+  bool use_shared = required_shared_memory <= max_shared_memory && required_shared_memory != 0;
+
+  if (!force_use_global_memory && use_shared) {
+    CHECK_NE(required_shared_memory, 0);
+    std::uint32_t constexpr kBlockThreads = 512;
+    if (use_u32) {
+      CHECK(!force_use_u64);
+      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
+      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+          kernel, batch_iter, is_valid, out_column_size);
+    } else {
+      auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
+      auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
+      dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
+          kernel, batch_iter, is_valid, out_column_size);
+    }
+  } else {
+    auto d_out_column_size = out_column_size;
+    dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) {
+      auto e = batch_iter[idx];
+      if (is_valid(e)) {
+        atomicAdd(&d_out_column_size[e.column_idx], static_cast<size_t>(1));
+      }
+    });
+  }
+}
+
+template <typename BatchIt>
+void GetColumnSizesScan(int device, size_t num_columns, std::size_t num_cuts_per_feature,
+                        IterSpan<BatchIt> batch_iter, data::IsValidFunctor is_valid,
+                        HostDeviceVector<SketchContainer::OffsetT>* cuts_ptr,
                         dh::caching_device_vector<size_t>* column_sizes_scan) {
-  column_sizes_scan->resize(num_columns + 1, 0);
+  column_sizes_scan->resize(num_columns + 1);
   cuts_ptr->SetDevice(device);
   cuts_ptr->Resize(num_columns + 1, 0);
 
   dh::XGBCachingDeviceAllocator<char> alloc;
-  auto d_column_sizes_scan = column_sizes_scan->data().get();
-  dh::LaunchN(end - begin, [=] __device__(size_t idx) {
-    auto e = batch_iter[begin + idx];
-    if (is_valid(e)) {
-      atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast<size_t>(1));
-    }
-  });
+  auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan);
+  LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan);
   // Calculate cuts CSC pointer
   auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
       column_sizes_scan->begin(), [=] __device__(size_t column_size) {
         return thrust::min(num_cuts_per_feature, column_size);
       });
   thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
-                         cut_ptr_it + column_sizes_scan->size(),
-                         cuts_ptr->DevicePointer());
+                         cut_ptr_it + column_sizes_scan->size(), cuts_ptr->DevicePointer());
   thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
                          column_sizes_scan->end(), column_sizes_scan->begin());
 }
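`EstimateGridSize` above sizes the grid from occupancy rather than from the input length, which pairs with the kernel's grid-stride loop. For reference, a standalone sketch of the same CUDA occupancy query (the kernel name here is illustrative, not from the diff):

#include <cstddef>

#include <cuda_runtime.h>

__global__ void SomeKernel(float const* in, std::size_t n);  // illustrative

// Resident blocks per SM times the SM count: enough blocks to saturate the
// device, with the grid-stride loop absorbing any input size.
inline unsigned int OccupancyGridSize(int device, int block_threads,
                                      std::size_t dyn_smem_bytes) {
  int n_mps = 0;
  cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device);
  int n_blocks_per_mp = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, SomeKernel,
                                                block_threads, dyn_smem_bytes);
  return static_cast<unsigned int>(n_blocks_per_mp * n_mps);
}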
@@ -121,29 +215,26 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
 
 // Count the valid entries in each column and copy them out.
 template <typename AdapterBatch, typename BatchIter>
-void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
-                            Range1d range, float missing,
-                            size_t columns, size_t cuts_per_feature, int device,
+void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Range1d range,
+                            float missing, size_t columns, size_t cuts_per_feature, int device,
                             HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
                             dh::caching_device_vector<size_t>* column_sizes_scan,
                             dh::device_vector<Entry>* sorted_entries) {
   auto entry_iter = dh::MakeTransformIterator<Entry>(
       thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
-        return Entry(batch.GetElement(idx).column_idx,
-                     batch.GetElement(idx).value);
+        return Entry(batch.GetElement(idx).column_idx, batch.GetElement(idx).value);
       });
+  auto n = range.end() - range.begin();
+  auto span = IterSpan{batch_iter + range.begin(), n};
   data::IsValidFunctor is_valid(missing);
   // Work out how many valid entries we have in each column
-  GetColumnSizesScan(device, columns, cuts_per_feature,
-                     batch_iter, is_valid,
-                     range.begin(), range.end(),
-                     cut_sizes_scan,
+  GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, cut_sizes_scan,
                      column_sizes_scan);
   size_t num_valid = column_sizes_scan->back();
   // Copy current subset of valid elements into temporary storage and sort
   sorted_entries->resize(num_valid);
-  dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(),
-             sorted_entries->begin(), is_valid);
+  dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), sorted_entries->begin(),
+             is_valid);
 }
 
 void SortByWeight(dh::device_vector<float>* weights,
@@ -39,6 +39,14 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
     return {row_idx, column_idx, value};
   }
 
+  __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
+    auto const& column = columns_[fidx];
+    float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
+                      ? column(ridx)
+                      : std::numeric_limits<float>::quiet_NaN();
+    return value;
+  }
+
   XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
   XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }
 
@@ -160,6 +168,10 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
     float value = array_interface_(row_idx, column_idx);
     return {row_idx, column_idx, value};
   }
+  __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
+    float value = array_interface_(ridx, fidx);
+    return value;
+  }
 
   XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); }
   XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); }
@@ -196,24 +208,47 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
 
 // Returns maximum row length
 template <typename AdapterBatchT>
-size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
-                    int device_idx, float missing) {
+std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offset, int device_idx,
+                         float missing) {
   dh::safe_cuda(cudaSetDevice(device_idx));
   IsValidFunctor is_valid(missing);
+  dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes()));
+
+  auto n_samples = batch.NumRows();
+  bst_feature_t n_features = batch.NumCols();
+
+  // Use more than 1 thread for each row in case the dataset is too wide.
+  bst_feature_t stride{0};
+  if (n_features < 32) {
+    stride = std::min(n_features, 4u);
+  } else if (n_features < 64) {
+    stride = 8;
+  } else if (n_features < 128) {
+    stride = 16;
+  } else {
+    stride = 32;
+  }
+
   // Count elements per row
-  dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
-    auto element = batch.GetElement(idx);
-    if (is_valid(element)) {
-      atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
-                    &offset[element.row_idx]),
-                static_cast<unsigned long long>(1));  // NOLINT
+  dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) {
+    bst_row_t cnt{0};
+    auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride);
+    SPAN_CHECK(ridx < n_samples);
+    for (bst_feature_t fidx = fbeg; fidx < n_features; fidx += stride) {
+      if (is_valid(batch.GetElement(ridx, fidx))) {
+        cnt++;
+      }
     }
+
+    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
+                  &offset[ridx]),
+              static_cast<unsigned long long>(cnt));  // NOLINT
   });
   dh::XGBCachingDeviceAllocator<char> alloc;
-  size_t row_stride =
+  bst_row_t row_stride =
       dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
                  thrust::device_pointer_cast(offset.data()) + offset.size(),
-                 static_cast<std::size_t>(0), thrust::maximum<size_t>());
+                 static_cast<bst_row_t>(0), thrust::maximum<bst_row_t>());
   return row_stride;
 }
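The new row-count kernel above assigns `stride` threads to each sample: the thread mapped to `(ridx, fbeg)` inspects features `fbeg, fbeg + stride, ...` and issues a single `atomicAdd` for its private count, instead of one atomic per valid element as before. A host-side sketch of that index decomposition (assuming `linalg::UnravelIndex` is row-major; names illustrative):

#include <cstddef>
#include <cstdio>

int main() {
  std::size_t const n_samples = 3, n_features = 10, stride = 4;
  for (std::size_t idx = 0; idx < n_samples * stride; ++idx) {
    std::size_t ridx = idx / stride;  // which sample this thread handles
    std::size_t fbeg = idx % stride;  // first feature it inspects
    std::printf("thread %zu -> row %zu, features:", idx, ridx);
    for (std::size_t fidx = fbeg; fidx < n_features; fidx += stride) {
      std::printf(" %zu", fidx);  // each element is visited by exactly one thread
    }
    std::printf("\n");
  }
  return 0;
}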
@@ -80,7 +80,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   }
   auto batch_rows = num_rows();
   accumulated_rows += batch_rows;
-  dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
+  dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
   common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
   row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
                           return GetRowCounts(value, row_counts_span, get_device(), missing);
@@ -134,7 +134,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
   init_page();
   dh::safe_cuda(cudaSetDevice(get_device()));
   auto rows = num_rows();
-  dh::caching_device_vector<size_t> row_counts(rows + 1, 0);
+  dh::device_vector<size_t> row_counts(rows + 1, 0);
   common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
   Dispatch(proxy, [=](auto const& value) {
     return GetRowCounts(value, row_counts_span, get_device(), missing);
@@ -483,6 +483,73 @@ TEST(HistUtil, AdapterDeviceSketchBatches) {
   }
 }
 
+namespace {
+auto MakeData(Context const* ctx, std::size_t n_samples, bst_feature_t n_features) {
+  dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
+  auto n = n_samples * n_features;
+  std::vector<float> x;
+  x.resize(n);
+
+  std::iota(x.begin(), x.end(), 0);
+  std::int32_t c{0};
+  float missing = n_samples * n_features;
+  for (std::size_t i = 0; i < x.size(); ++i) {
+    if (i % 5 == 0) {
+      x[i] = missing;
+      c++;
+    }
+  }
+  thrust::device_vector<float> d_x;
+  d_x = x;
+
+  auto n_invalids = n / 10 * 2 + 1;
+  auto is_valid = data::IsValidFunctor{missing};
+  return std::tuple{x, d_x, n_invalids, is_valid};
+}
+
+void TestGetColumnSize(std::size_t n_samples) {
+  auto ctx = MakeCUDACtx(0);
+  bst_feature_t n_features = 12;
+  [[maybe_unused]] auto [x, d_x, n_invalids, is_valid] = MakeData(&ctx, n_samples, n_features);
+
+  auto adapter = AdapterFromData(d_x, n_samples, n_features);
+  auto batch = adapter.Value();
+
+  auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
+      thrust::make_counting_iterator(0llu),
+      [=] __device__(std::size_t idx) { return batch.GetElement(idx); });
+
+  dh::caching_device_vector<std::size_t> column_sizes_scan;
+  column_sizes_scan.resize(n_features + 1);
+  std::vector<std::size_t> h_column_size(column_sizes_scan.size());
+  std::vector<std::size_t> h_column_size_1(column_sizes_scan.size());
+
+  detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, true>(
+      ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
+  thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size.begin());
+
+  detail::LaunchGetColumnSizeKernel<decltype(batch_iter), true, false>(
+      ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
+  thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
+  ASSERT_EQ(h_column_size, h_column_size_1);
+
+  detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, true>(
+      ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
+  thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
+  ASSERT_EQ(h_column_size, h_column_size_1);
+
+  detail::LaunchGetColumnSizeKernel<decltype(batch_iter), false, false>(
+      ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan));
+  thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin());
+  ASSERT_EQ(h_column_size, h_column_size_1);
+}
+}  // namespace
+
+TEST(HistUtil, GetColumnSize) {
+  bst_row_t n_samples = 4096;
+  TestGetColumnSize(n_samples);
+}
+
 // Check sketching from adapter or DMatrix results in the same answer
 // Consistency here is useful for testing and user experience
 TEST(HistUtil, SketchingEquivalent) {
@@ -50,7 +50,7 @@ void TestSketchUnique(float sparsity) {
       thrust::make_counting_iterator(0llu),
       [=] __device__(size_t idx) { return batch.GetElement(idx); });
   auto end = kCols * kRows;
-  detail::GetColumnSizesScan(0, kCols, n_cuts, batch_iter, is_valid, 0, end,
+  detail::GetColumnSizesScan(0, kCols, n_cuts, IterSpan{batch_iter, end}, is_valid,
                              &cut_sizes_scan, &column_sizes_scan);
   auto const& cut_sizes = cut_sizes_scan.HostVector();
   ASSERT_LE(sketch.Data().size(), cut_sizes.back());
@@ -1,15 +1,16 @@
-/*!
- * Copyright 2018 XGBoost contributors
+/**
+ * Copyright 2018-2023, XGBoost contributors
  */
-#include <gtest/gtest.h>
-#include <vector>
-
-#include <xgboost/span.h>
 #include "test_span.h"
 
-namespace xgboost {
-namespace common {
-
+#include <gtest/gtest.h>
+#include <xgboost/span.h>
+
+#include <vector>
+
+#include "../../../src/common/transform_iterator.h"  // for MakeIndexTransformIter
+
+namespace xgboost::common {
 TEST(Span, TestStatus) {
   int status = 1;
   TestTestStatus {&status}();
@@ -526,5 +527,17 @@ TEST(SpanDeathTest, Empty) {
   Span<float> s{data.data(), static_cast<Span<float>::index_type>(0)};
   EXPECT_DEATH(s[0], "");  // not ok to use it.
 }
-}  // namespace common
-}  // namespace xgboost
+
+TEST(IterSpan, Basic) {
+  auto iter = common::MakeIndexTransformIter([](std::size_t i) { return i; });
+  std::size_t n = 13;
+  auto span = IterSpan{iter, n};
+  ASSERT_EQ(span.size(), n);
+  for (std::size_t i = 0; i < n; ++i) {
+    ASSERT_EQ(span[i], i);
+  }
+  ASSERT_EQ(span.subspan(1).size(), n - 1);
+  ASSERT_EQ(span.subspan(1)[0], 1);
+  ASSERT_EQ(span.subspan(1, 2)[1], 2);
+}
+}  // namespace xgboost::common
@@ -51,3 +51,22 @@ void TestCudfAdapter()
 TEST(DeviceAdapter, CudfAdapter) {
   TestCudfAdapter();
 }
+
+namespace xgboost::data {
+TEST(DeviceAdapter, GetRowCounts) {
+  auto ctx = MakeCUDACtx(0);
+
+  for (bst_feature_t n_features : {1, 2, 4, 64, 128, 256}) {
+    HostDeviceVector<float> storage;
+    auto str_arr = RandomDataGenerator{8192, n_features, 0.0}
+                       .Device(ctx.gpu_id)
+                       .GenerateArrayInterface(&storage);
+    auto adapter = CupyAdapter{str_arr};
+    HostDeviceVector<bst_row_t> offset(adapter.NumRows() + 1, 0);
+    offset.SetDevice(ctx.gpu_id);
+    auto rstride = GetRowCounts(adapter.Value(), offset.DeviceSpan(), ctx.gpu_id,
+                                std::numeric_limits<float>::quiet_NaN());
+    ASSERT_EQ(rstride, n_features);
+  }
+}
+}  // namespace xgboost::data