[EM] Compress dense ellpack. (#10821)

This reduces the memory copying needed for dense data. In addition, it lowers memory usage even when external memory is not used.

- Decouple the number of symbols needed by the compressor from the number of features when the data is dense (see the sketch below).
- Remove the fetch call in the `at_end_` iteration.
- Reduce synchronization and kernel launches by using the `uvector` containers and the `ctx` CUDA stream.
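The dense compression works by storing feature-local bin indices instead of global ones: the CUDA kernels in this diff (`CompressBinEllpackKernel`, `WriteCompressedEllpackFunctor`) subtract the feature's starting cut pointer before writing, and readers such as `EllpackDeviceAccessor` add `feature_segments[fidx]` back. The minimal CPU-side sketch below illustrates the idea and the resulting symbol-count reduction; `EncodeDense`, `DecodeDense`, and the `cut_ptrs` layout here are illustrative only, not XGBoost APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Per-feature cut pointers: feature f owns global bins [cut_ptrs[f], cut_ptrs[f + 1]).
using BinIdx = std::uint32_t;

// Encode a global bin index as a feature-local symbol (dense data only).
BinIdx EncodeDense(std::vector<BinIdx> const& cut_ptrs, std::size_t feature, BinIdx global_bin) {
  return global_bin - cut_ptrs[feature];
}

// Decode a feature-local symbol back to the global bin index.
BinIdx DecodeDense(std::vector<BinIdx> const& cut_ptrs, std::size_t feature, BinIdx local_bin) {
  return local_bin + cut_ptrs[feature];
}

int main() {
  // 3 features with 4, 7, and 5 bins respectively.
  std::vector<BinIdx> cut_ptrs{0, 4, 11, 16};

  // Without the trick, the compressor needs one symbol per bin plus a null symbol.
  std::size_t n_symbols_sparse = cut_ptrs.back() + 1;

  // With the trick, only the largest per-feature bin count is needed; a dense
  // matrix has no missing values, so no extra null symbol is required.
  std::size_t n_symbols_dense = 0;
  for (std::size_t f = 0; f + 1 < cut_ptrs.size(); ++f) {
    n_symbols_dense = std::max<std::size_t>(n_symbols_dense, cut_ptrs[f + 1] - cut_ptrs[f]);
  }

  std::cout << "symbols (global encoding):        " << n_symbols_sparse << "\n";  // 17
  std::cout << "symbols (feature-local encoding): " << n_symbols_dense << "\n";   // 7

  // Round-trip a bin belonging to feature 1 (global bins 4..10).
  BinIdx global_bin = 9;
  BinIdx local = EncodeDense(cut_ptrs, 1, global_bin);
  assert(DecodeDense(cut_ptrs, 1, local) == global_bin);
  return 0;
}
```

Fewer symbols means fewer bits per entry in the compressed ELLPACK buffer, which is where the memory saving for dense data comes from.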
Jiaming Yuan 2024-09-20 18:20:56 +08:00 committed by GitHub
parent d5e1c41b69
commit 24241ed6e3
28 changed files with 485 additions and 285 deletions

.gitignore vendored
View File

@ -33,6 +33,7 @@ ipch
*.filters *.filters
*.user *.user
*log *log
rmm_log.txt
Debug Debug
*suo *suo
.Rhistory .Rhistory

View File

@ -10,6 +10,7 @@
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <limits> // for numeric_limits #include <limits> // for numeric_limits
#include <new> // for bad_array_new_length
#include "common.h" #include "common.h"
@ -28,14 +29,14 @@ struct PinnedAllocPolicy {
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
using value_type = T; // NOLINT: The type of the elements in the allocator using value_type = T; // NOLINT: The type of the elements in the allocator
size_type max_size() const { // NOLINT [[nodiscard]] constexpr size_type max_size() const { // NOLINT
return std::numeric_limits<size_type>::max() / sizeof(value_type); return std::numeric_limits<size_type>::max() / sizeof(value_type);
} }
[[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const { // NOLINT [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const { // NOLINT
if (cnt > this->max_size()) { if (cnt > this->max_size()) {
throw std::bad_alloc{}; throw std::bad_array_new_length{};
} // end if }
pointer result(nullptr); pointer result(nullptr);
dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type))); dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
@ -52,14 +53,14 @@ struct ManagedAllocPolicy {
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
using value_type = T; // NOLINT: The type of the elements in the allocator using value_type = T; // NOLINT: The type of the elements in the allocator
size_type max_size() const { // NOLINT [[nodiscard]] constexpr size_type max_size() const { // NOLINT
return std::numeric_limits<size_type>::max() / sizeof(value_type); return std::numeric_limits<size_type>::max() / sizeof(value_type);
} }
[[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const { // NOLINT [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const { // NOLINT
if (cnt > this->max_size()) { if (cnt > this->max_size()) {
throw std::bad_alloc{}; throw std::bad_array_new_length{};
} // end if }
pointer result(nullptr); pointer result(nullptr);
dh::safe_cuda(cudaMallocManaged(reinterpret_cast<void**>(&result), cnt * sizeof(value_type))); dh::safe_cuda(cudaMallocManaged(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
@ -78,14 +79,14 @@ struct SamAllocPolicy {
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
using value_type = T; // NOLINT: The type of the elements in the allocator using value_type = T; // NOLINT: The type of the elements in the allocator
size_type max_size() const { // NOLINT [[nodiscard]] constexpr size_type max_size() const { // NOLINT
return std::numeric_limits<size_type>::max() / sizeof(value_type); return std::numeric_limits<size_type>::max() / sizeof(value_type);
} }
[[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const { // NOLINT [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const { // NOLINT
if (cnt > this->max_size()) { if (cnt > this->max_size()) {
throw std::bad_alloc{}; throw std::bad_array_new_length{};
} // end if }
size_type n_bytes = cnt * sizeof(value_type); size_type n_bytes = cnt * sizeof(value_type);
pointer result = reinterpret_cast<pointer>(std::malloc(n_bytes)); pointer result = reinterpret_cast<pointer>(std::malloc(n_bytes));
@ -139,10 +140,10 @@ class CudaHostAllocatorImpl : public Policy<T> {
}; };
template <typename T> template <typename T>
using PinnedAllocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>; // NOLINT using PinnedAllocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>;
template <typename T> template <typename T>
using ManagedAllocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>; // NOLINT using ManagedAllocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>;
template <typename T> template <typename T>
using SamAllocator = CudaHostAllocatorImpl<T, SamAllocPolicy>; using SamAllocator = CudaHostAllocatorImpl<T, SamAllocPolicy>;

View File

@ -177,8 +177,10 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
pointer thrust_ptr; pointer thrust_ptr;
if (use_cub_allocator_) { if (use_cub_allocator_) {
T *raw_ptr{nullptr}; T *raw_ptr{nullptr};
// NOLINTBEGIN(clang-analyzer-unix.BlockInCriticalSection)
auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&raw_ptr), auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&raw_ptr),
n * sizeof(T)); n * sizeof(T));
// NOLINTEND(clang-analyzer-unix.BlockInCriticalSection)
if (errc != cudaSuccess) { if (errc != cudaSuccess) {
detail::ThrowOOMError("Caching allocator", n * sizeof(T)); detail::ThrowOOMError("Caching allocator", n * sizeof(T));
} }
@ -290,13 +292,13 @@ LoggingResource *GlobalLoggingResource();
/** /**
* @brief Container class that doesn't initialize the data when RMM is used. * @brief Container class that doesn't initialize the data when RMM is used.
*/ */
template <typename T> template <typename T, bool is_caching>
class DeviceUVector { class DeviceUVectorImpl {
private: private:
#if defined(XGBOOST_USE_RMM) #if defined(XGBOOST_USE_RMM)
rmm::device_uvector<T> data_{0, rmm::cuda_stream_per_thread, GlobalLoggingResource()}; rmm::device_uvector<T> data_{0, rmm::cuda_stream_per_thread, GlobalLoggingResource()};
#else #else
::dh::device_vector<T> data_; std::conditional_t<is_caching, ::dh::caching_device_vector<T>, ::dh::device_vector<T>> data_;
#endif // defined(XGBOOST_USE_RMM) #endif // defined(XGBOOST_USE_RMM)
public: public:
@ -307,12 +309,12 @@ class DeviceUVector {
using const_reference = value_type const &; // NOLINT using const_reference = value_type const &; // NOLINT
public: public:
DeviceUVector() = default; DeviceUVectorImpl() = default;
explicit DeviceUVector(std::size_t n) { this->resize(n); } explicit DeviceUVectorImpl(std::size_t n) { this->resize(n); }
DeviceUVector(DeviceUVector const &that) = delete; DeviceUVectorImpl(DeviceUVectorImpl const &that) = delete;
DeviceUVector &operator=(DeviceUVector const &that) = delete; DeviceUVectorImpl &operator=(DeviceUVectorImpl const &that) = delete;
DeviceUVector(DeviceUVector &&that) = default; DeviceUVectorImpl(DeviceUVectorImpl &&that) = default;
DeviceUVector &operator=(DeviceUVector &&that) = default; DeviceUVectorImpl &operator=(DeviceUVectorImpl &&that) = default;
void resize(std::size_t n) { // NOLINT void resize(std::size_t n) { // NOLINT
#if defined(XGBOOST_USE_RMM) #if defined(XGBOOST_USE_RMM)
@ -356,4 +358,10 @@ class DeviceUVector {
[[nodiscard]] auto data() { return thrust::raw_pointer_cast(data_.data()); } // NOLINT [[nodiscard]] auto data() { return thrust::raw_pointer_cast(data_.data()); } // NOLINT
[[nodiscard]] auto data() const { return thrust::raw_pointer_cast(data_.data()); } // NOLINT [[nodiscard]] auto data() const { return thrust::raw_pointer_cast(data_.data()); } // NOLINT
}; };
template <typename T>
using DeviceUVector = DeviceUVectorImpl<T, false>;
template <typename T>
using CachingDeviceUVector = DeviceUVectorImpl<T, true>;
} // namespace dh } // namespace dh

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2023 by XGBoost Contributors * Copyright 2019-2024, XGBoost Contributors
* \file device_adapter.cuh * \file device_adapter.cuh
*/ */
#ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_ #ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_
@ -7,13 +7,12 @@
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator #include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include <thrust/logical.h> // for none_of #include <thrust/logical.h> // for none_of
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <limits> #include <limits>
#include <memory>
#include <string> #include <string>
#include "../common/cuda_context.cuh"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/math.h"
#include "adapter.h" #include "adapter.h"
#include "array_interface.h" #include "array_interface.h"
@ -208,11 +207,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
// Returns maximum row length // Returns maximum row length
template <typename AdapterBatchT> template <typename AdapterBatchT>
bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_idx_t> offset, DeviceOrd device, bst_idx_t GetRowCounts(Context const* ctx, const AdapterBatchT batch,
float missing) { common::Span<bst_idx_t> offset, DeviceOrd device, float missing) {
dh::safe_cuda(cudaSetDevice(device.ordinal)); dh::safe_cuda(cudaSetDevice(device.ordinal));
IsValidFunctor is_valid(missing); IsValidFunctor is_valid(missing);
dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes())); dh::safe_cuda(
cudaMemsetAsync(offset.data(), '\0', offset.size_bytes(), ctx->CUDACtx()->Stream()));
auto n_samples = batch.NumRows(); auto n_samples = batch.NumRows();
bst_feature_t n_features = batch.NumCols(); bst_feature_t n_features = batch.NumCols();
@ -230,7 +230,7 @@ bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_idx_t> offset
} }
// Count elements per row // Count elements per row
dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) { dh::LaunchN(n_samples * stride, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t idx) {
bst_idx_t cnt{0}; bst_idx_t cnt{0};
auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride); auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride);
SPAN_CHECK(ridx < n_samples); SPAN_CHECK(ridx < n_samples);
@ -244,9 +244,8 @@ bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_idx_t> offset
&offset[ridx]), &offset[ridx]),
static_cast<unsigned long long>(cnt)); // NOLINT static_cast<unsigned long long>(cnt)); // NOLINT
}); });
dh::XGBCachingDeviceAllocator<char> alloc;
bst_idx_t row_stride = bst_idx_t row_stride =
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()), dh::Reduce(ctx->CUDACtx()->CTP(), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(), thrust::device_pointer_cast(offset.data()) + offset.size(),
static_cast<bst_idx_t>(0), thrust::maximum<bst_idx_t>()); static_cast<bst_idx_t>(0), thrust::maximum<bst_idx_t>());
return row_stride; return row_stride;

View File

@ -9,6 +9,7 @@
#include <utility> // for move #include <utility> // for move
#include <vector> // for vector #include <vector> // for vector
#include "../common/algorithm.cuh" // for InclusiveScan
#include "../common/categorical.h" #include "../common/categorical.h"
#include "../common/cuda_context.cuh" #include "../common/cuda_context.cuh"
#include "../common/cuda_rt_utils.h" // for SetDevice #include "../common/cuda_rt_utils.h" // for SetDevice
@ -45,6 +46,7 @@ void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id)
[[nodiscard]] bool EllpackPage::IsDense() const { return this->Impl()->IsDense(); } [[nodiscard]] bool EllpackPage::IsDense() const { return this->Impl()->IsDense(); }
// Bin each input data entry, store the bin indices in compressed form. // Bin each input data entry, store the bin indices in compressed form.
template <bool kIsDense>
__global__ void CompressBinEllpackKernel( __global__ void CompressBinEllpackKernel(
common::CompressedBufferWriter wr, common::CompressedBufferWriter wr,
common::CompressedByteT* __restrict__ buffer, // gidx_buffer common::CompressedByteT* __restrict__ buffer, // gidx_buffer
@ -73,12 +75,11 @@ __global__ void CompressBinEllpackKernel(
// Assigning the bin in current entry. // Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin] // S.t.: fvalue < feature_cuts[bin]
if (is_cat) { if (is_cat) {
auto it = dh::MakeTransformIterator<int>( auto it =
feature_cuts, [](float v) { return common::AsCat(v); }); dh::MakeTransformIterator<int>(feature_cuts, [](float v) { return common::AsCat(v); });
bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it; bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it;
} else { } else {
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, fvalue) -
fvalue) -
feature_cuts; feature_cuts;
} }
@ -86,24 +87,54 @@ __global__ void CompressBinEllpackKernel(
bin = ncuts - 1; bin = ncuts - 1;
} }
// Add the number of bins in previous features. // Add the number of bins in previous features.
bin += cut_ptrs[feature]; if (!kIsDense) {
bin += cut_ptrs[feature];
}
} }
// Write to gidx buffer. // Write to gidx buffer.
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
} }
[[nodiscard]] std::size_t CalcNumSymbols(Context const*, bool /*is_dense*/, namespace {
// Calculate the number of symbols for the compressed ellpack. Similar to what the CPU
// implementation does, we compress the dense data by subtracting the bin values with the
// starting bin of its feature.
[[nodiscard]] std::size_t CalcNumSymbols(Context const* ctx, bool is_dense,
std::shared_ptr<common::HistogramCuts const> cuts) { std::shared_ptr<common::HistogramCuts const> cuts) {
// Return the total number of symbols (total number of bins plus 1 for not found) // Cut values can be empty when the input data is empty.
return cuts->cut_values_.Size() + 1; if (!is_dense || cuts->cut_values_.Empty()) {
// Return the total number of symbols (total number of bins plus 1 for not found)
return cuts->cut_values_.Size() + 1;
}
cuts->cut_ptrs_.SetDevice(ctx->Device());
common::Span<std::uint32_t const> dptrs = cuts->cut_ptrs_.ConstDeviceSpan();
auto cuctx = ctx->CUDACtx();
using PtrT = typename decltype(dptrs)::value_type;
auto it = dh::MakeTransformIterator<PtrT>(
thrust::make_counting_iterator(1ul),
[=] XGBOOST_DEVICE(std::size_t i) { return dptrs[i] - dptrs[i - 1]; });
CHECK_GE(dptrs.size(), 2);
auto max_it = thrust::max_element(cuctx->CTP(), it, it + dptrs.size() - 1);
dh::CachingDeviceUVector<PtrT> max_element(1);
auto d_me = max_element.data();
dh::LaunchN(1, cuctx->Stream(), [=] XGBOOST_DEVICE(std::size_t i) { d_me[i] = *max_it; });
PtrT h_me{0};
dh::safe_cuda(
cudaMemcpyAsync(&h_me, d_me, sizeof(PtrT), cudaMemcpyDeviceToHost, cuctx->Stream()));
cuctx->Stream().Sync();
// No missing, hence no null value, hence no + 1 symbol.
// FIXME(jiamingy): When we extend this to use a sparsity threshold, +1 is needed back.
return h_me;
} }
} // namespace
// Construct an ELLPACK matrix with the given number of empty rows. // Construct an ELLPACK matrix with the given number of empty rows.
EllpackPageImpl::EllpackPageImpl(Context const* ctx, EllpackPageImpl::EllpackPageImpl(Context const* ctx,
std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense, std::shared_ptr<common::HistogramCuts const> cuts, bool is_dense,
bst_idx_t row_stride, bst_idx_t n_rows) bst_idx_t row_stride, bst_idx_t n_rows)
: is_dense(is_dense), : is_dense{is_dense},
cuts_(std::move(cuts)), cuts_{std::move(cuts)},
row_stride{row_stride}, row_stride{row_stride},
n_rows{n_rows}, n_rows{n_rows},
n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} { n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} {
@ -117,11 +148,14 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx,
std::shared_ptr<common::HistogramCuts const> cuts, std::shared_ptr<common::HistogramCuts const> cuts,
const SparsePage& page, bool is_dense, size_t row_stride, const SparsePage& page, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types) common::Span<FeatureType const> feature_types)
: cuts_(std::move(cuts)), : cuts_{std::move(cuts)},
is_dense(is_dense), is_dense{is_dense},
n_rows(page.Size()), n_rows{page.Size()},
row_stride(row_stride), row_stride{row_stride},
n_symbols_(CalcNumSymbols(ctx, this->is_dense, this->cuts_)) { n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} {
monitor_.Init("ellpack_page");
common::SetDevice(ctx->Ordinal());
this->InitCompressedData(ctx); this->InitCompressedData(ctx);
this->CreateHistIndices(ctx, page, feature_types); this->CreateHistIndices(ctx, page, feature_types);
} }
@ -147,8 +181,8 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* p_fmat, const Batc
auto ft = p_fmat->Info().feature_types.ConstDeviceSpan(); auto ft = p_fmat->Info().feature_types.ConstDeviceSpan();
monitor_.Start("BinningCompression"); monitor_.Start("BinningCompression");
CHECK(p_fmat->SingleColBlock()); CHECK(p_fmat->SingleColBlock());
for (const auto& batch : p_fmat->GetBatches<SparsePage>()) { for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
CreateHistIndices(ctx, batch, ft); this->CreateHistIndices(ctx, page, ft);
} }
monitor_.Stop("BinningCompression"); monitor_.Stop("BinningCompression");
} }
@ -186,6 +220,9 @@ struct WriteCompressedEllpackFunctor {
} else { } else {
bin_idx = accessor.SearchBin<false>(e.value, e.column_idx); bin_idx = accessor.SearchBin<false>(e.value, e.column_idx);
} }
if (kIsDense) {
bin_idx -= accessor.feature_segments[e.column_idx];
}
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position); writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
} }
return 0; return 0;
@ -257,7 +294,8 @@ void CopyDataToEllpack(Context const* ctx, const AdapterBatchT& batch,
common::InclusiveScan(ctx, key_value_index_iter, out, TupleScanOp<Tuple>{}, batch.Size()); common::InclusiveScan(ctx, key_value_index_iter, out, TupleScanOp<Tuple>{}, batch.Size());
} }
void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span<size_t> row_counts) { void WriteNullValues(Context const* ctx, EllpackPageImpl* dst,
common::Span<size_t const> row_counts) {
// Write the null values // Write the null values
auto device_accessor = dst->GetDeviceAccessor(ctx); auto device_accessor = dst->GetDeviceAccessor(ctx);
common::CompressedBufferWriter writer(dst->NumSymbols()); common::CompressedBufferWriter writer(dst->NumSymbols());
@ -276,7 +314,7 @@ void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span<size
template <typename AdapterBatch> template <typename AdapterBatch>
EllpackPageImpl::EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing, EllpackPageImpl::EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing,
bool is_dense, common::Span<size_t> row_counts_span, bool is_dense, common::Span<size_t const> row_counts_span,
common::Span<FeatureType const> feature_types, size_t row_stride, common::Span<FeatureType const> feature_types, size_t row_stride,
bst_idx_t n_rows, bst_idx_t n_rows,
std::shared_ptr<common::HistogramCuts const> cuts) std::shared_ptr<common::HistogramCuts const> cuts)
@ -292,10 +330,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, AdapterBatch batch, float m
WriteNullValues(ctx, this, row_counts_span); WriteNullValues(ctx, this, row_counts_span);
} }
#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \ #define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \
template EllpackPageImpl::EllpackPageImpl( \ template EllpackPageImpl::EllpackPageImpl( \
Context const* ctx, __BATCH_T batch, float missing, bool is_dense, \ Context const* ctx, __BATCH_T batch, float missing, bool is_dense, \
common::Span<size_t> row_counts_span, common::Span<FeatureType const> feature_types, \ common::Span<size_t const> row_counts_span, common::Span<FeatureType const> feature_types, \
size_t row_stride, size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts); size_t row_stride, size_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch) ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch)
@ -303,18 +341,15 @@ ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)
namespace { namespace {
void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page, void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page,
common::Span<size_t const> d_row_ptr, size_t row_stride, common::Span<bst_idx_t const> d_row_ptr, bst_idx_t row_stride,
common::CompressedByteT* d_compressed_buffer, size_t null) { bst_bin_t null, bst_idx_t n_symbols,
common::CompressedByteT* d_compressed_buffer) {
dh::device_vector<uint8_t> data(page.index.begin(), page.index.end()); dh::device_vector<uint8_t> data(page.index.begin(), page.index.end());
auto d_data = dh::ToSpan(data); auto d_data = dh::ToSpan(data);
dh::device_vector<size_t> csc_indptr(page.index.Offset(), // GPU employs the same dense compression as CPU, no need to handle page.index.Offset()
page.index.Offset() + page.index.OffsetSize());
auto d_csc_indptr = dh::ToSpan(csc_indptr);
auto bin_type = page.index.GetBinTypeSize(); auto bin_type = page.index.GetBinTypeSize();
common::CompressedBufferWriter writer{page.cut.TotalBins() + common::CompressedBufferWriter writer{n_symbols};
static_cast<std::size_t>(1)}; // +1 for null value
auto cuctx = ctx->CUDACtx(); auto cuctx = ctx->CUDACtx();
dh::LaunchN(row_stride * page.Size(), cuctx->Stream(), [=] __device__(bst_idx_t idx) mutable { dh::LaunchN(row_stride * page.Size(), cuctx->Stream(), [=] __device__(bst_idx_t idx) mutable {
@ -323,22 +358,17 @@ void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page,
auto r_begin = d_row_ptr[ridx]; auto r_begin = d_row_ptr[ridx];
auto r_end = d_row_ptr[ridx + 1]; auto r_end = d_row_ptr[ridx + 1];
size_t r_size = r_end - r_begin; auto r_size = r_end - r_begin;
if (ifeature >= r_size) { if (ifeature >= r_size) {
writer.AtomicWriteSymbol(d_compressed_buffer, null, idx); writer.AtomicWriteSymbol(d_compressed_buffer, null, idx);
return; return;
} }
bst_idx_t offset = 0;
if (!d_csc_indptr.empty()) {
// is dense, ifeature is the actual feature index.
offset = d_csc_indptr[ifeature];
}
common::cuda::DispatchBinType(bin_type, [&](auto t) { common::cuda::DispatchBinType(bin_type, [&](auto t) {
using T = decltype(t); using T = decltype(t);
auto ptr = reinterpret_cast<T const*>(d_data.data()); auto ptr = reinterpret_cast<T const*>(d_data.data());
auto bin_idx = ptr[r_begin + ifeature] + offset; auto bin_idx = ptr[r_begin + ifeature];
writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx); writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx);
}); });
}); });
@ -348,14 +378,16 @@ void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page,
EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
common::Span<FeatureType const> ft) common::Span<FeatureType const> ft)
: is_dense{page.IsDense()}, : is_dense{page.IsDense()},
row_stride{[&] {
auto it = common::MakeIndexTransformIter(
[&](bst_idx_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
return *std::max_element(it, it + page.Size());
}()},
base_rowid{page.base_rowid}, base_rowid{page.base_rowid},
n_rows{page.Size()}, n_rows{page.Size()},
cuts_{std::make_shared<common::HistogramCuts>(page.cut)}, cuts_{std::make_shared<common::HistogramCuts>(page.cut)},
n_symbols_{CalcNumSymbols(ctx, page.IsDense(), cuts_)} { n_symbols_{CalcNumSymbols(ctx, page.IsDense(), cuts_)} {
auto it = common::MakeIndexTransformIter( this->monitor_.Init("ellpack_page");
[&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
row_stride = *std::max_element(it, it + page.Size());
CHECK(ctx->IsCUDA()); CHECK(ctx->IsCUDA());
this->InitCompressedData(ctx); this->InitCompressedData(ctx);
@ -367,12 +399,17 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream())); cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream()));
auto accessor = this->GetDeviceAccessor(ctx, ft); auto accessor = this->GetDeviceAccessor(ctx, ft);
auto null = accessor.NullValue();
this->monitor_.Start("CopyGHistToEllpack"); this->monitor_.Start("CopyGHistToEllpack");
CopyGHistToEllpack(ctx, page, d_row_ptr, row_stride, d_compressed_buffer, null); CopyGHistToEllpack(ctx, page, d_row_ptr, row_stride, accessor.NullValue(), this->NumSymbols(),
d_compressed_buffer);
this->monitor_.Stop("CopyGHistToEllpack"); this->monitor_.Stop("CopyGHistToEllpack");
} }
EllpackPageImpl::~EllpackPageImpl() noexcept(false) {
// Sync the stream to make sure all running CUDA kernels finish before deallocation.
dh::DefaultStream().Sync();
}
// A functor that copies the data from one EllpackPage to another. // A functor that copies the data from one EllpackPage to another.
struct CopyPage { struct CopyPage {
common::CompressedBufferWriter cbw; common::CompressedBufferWriter cbw;
@ -385,7 +422,7 @@ struct CopyPage {
: cbw{dst->NumSymbols()}, : cbw{dst->NumSymbols()},
dst_data_d{dst->gidx_buffer.data()}, dst_data_d{dst->gidx_buffer.data()},
src_iterator_d{src->gidx_buffer.data(), src->NumSymbols()}, src_iterator_d{src->gidx_buffer.data(), src->NumSymbols()},
offset(offset) {} offset{offset} {}
__device__ void operator()(size_t element_id) { __device__ void operator()(size_t element_id) {
cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset); cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset);
@ -393,7 +430,7 @@ struct CopyPage {
}; };
// Copy the data from the given EllpackPage to the current page. // Copy the data from the given EllpackPage to the current page.
size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) { bst_idx_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) {
monitor_.Start(__func__); monitor_.Start(__func__);
bst_idx_t num_elements = page->n_rows * page->row_stride; bst_idx_t num_elements = page->n_rows * page->row_stride;
CHECK_EQ(this->row_stride, page->row_stride); CHECK_EQ(this->row_stride, page->row_stride);
@ -482,10 +519,12 @@ void EllpackPageImpl::InitCompressedData(Context const* ctx) {
void EllpackPageImpl::CreateHistIndices(Context const* ctx, void EllpackPageImpl::CreateHistIndices(Context const* ctx,
const SparsePage& row_batch, const SparsePage& row_batch,
common::Span<FeatureType const> feature_types) { common::Span<FeatureType const> feature_types) {
if (row_batch.Size() == 0) return; if (row_batch.Size() == 0) {
std::uint32_t null_gidx_value = NumSymbols() - 1; return;
}
auto null_gidx_value = this->GetDeviceAccessor(ctx, feature_types).NullValue();
const auto& offset_vec = row_batch.offset.ConstHostVector(); auto const& offset_vec = row_batch.offset.ConstHostVector();
// bin and compress entries in batches of rows // bin and compress entries in batches of rows
size_t gpu_batch_nrows = size_t gpu_batch_nrows =
@ -504,35 +543,46 @@ void EllpackPageImpl::CreateHistIndices(Context const* ctx,
const auto ent_cnt_end = offset_vec[batch_row_end]; const auto ent_cnt_end = offset_vec[batch_row_end];
/*! \brief row offset in SparsePage (the input data). */ /*! \brief row offset in SparsePage (the input data). */
dh::device_vector<size_t> row_ptrs(batch_nrows + 1); using OffT = typename std::remove_reference_t<decltype(offset_vec)>::value_type;
thrust::copy(offset_vec.data() + batch_row_begin, dh::DeviceUVector<OffT> row_ptrs(batch_nrows + 1);
offset_vec.data() + batch_row_end + 1, row_ptrs.begin()); auto size =
std::distance(offset_vec.data() + batch_row_begin, offset_vec.data() + batch_row_end + 1);
dh::safe_cuda(cudaMemcpyAsync(row_ptrs.data(), offset_vec.data() + batch_row_begin,
size * sizeof(OffT), cudaMemcpyDefault,
ctx->CUDACtx()->Stream()));
// number of entries in this batch. // number of entries in this batch.
size_t n_entries = ent_cnt_end - ent_cnt_begin; size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries); dh::DeviceUVector<Entry> entries_d(n_entries);
// copy data entries to device. // copy data entries to device.
if (row_batch.data.DeviceCanRead()) { if (row_batch.data.DeviceCanRead()) {
auto const& d_data = row_batch.data.ConstDeviceSpan(); auto const& d_data = row_batch.data.ConstDeviceSpan();
dh::safe_cuda(cudaMemcpyAsync( dh::safe_cuda(cudaMemcpyAsync(entries_d.data(), d_data.data() + ent_cnt_begin,
entries_d.data().get(), d_data.data() + ent_cnt_begin, n_entries * sizeof(Entry), cudaMemcpyDefault,
n_entries * sizeof(Entry), cudaMemcpyDefault)); ctx->CUDACtx()->Stream()));
} else { } else {
const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector(); const std::vector<Entry>& data_vec = row_batch.data.ConstHostVector();
dh::safe_cuda(cudaMemcpyAsync( dh::safe_cuda(cudaMemcpyAsync(entries_d.data(), data_vec.data() + ent_cnt_begin,
entries_d.data().get(), data_vec.data() + ent_cnt_begin, n_entries * sizeof(Entry), cudaMemcpyDefault,
n_entries * sizeof(Entry), cudaMemcpyDefault)); ctx->CUDACtx()->Stream()));
} }
const dim3 block3(32, 8, 1); // 256 threads const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1); common::DivRoundUp(row_stride, block3.y), 1);
auto device_accessor = this->GetDeviceAccessor(ctx); auto device_accessor = this->GetDeviceAccessor(ctx);
dh::LaunchKernel{grid3, block3}( // NOLINT auto launcher = [&](auto kernel) {
CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), gidx_buffer.data(), dh::LaunchKernel{grid3, block3, 0, ctx->CUDACtx()->Stream()}( // NOLINT
row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), kernel, common::CompressedBufferWriter(this->NumSymbols()), gidx_buffer.data(),
device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows, row_ptrs.data(), entries_d.data(), device_accessor.gidx_fvalue_map.data(),
row_stride, null_gidx_value); device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows,
row_stride, null_gidx_value);
};
if (this->IsDense()) {
launcher(CompressBinEllpackKernel<true>);
} else {
launcher(CompressBinEllpackKernel<false>);
}
} }
} }

View File

@ -85,6 +85,7 @@ struct EllpackDeviceAccessor {
bst_bin_t gidx = -1; bst_bin_t gidx = -1;
if (is_dense) { if (is_dense) {
gidx = gidx_iter[row_begin + fidx]; gidx = gidx_iter[row_begin + fidx];
gidx += this->feature_segments[fidx];
} else { } else {
gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx], gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]); feature_segments[fidx + 1]);
@ -175,7 +176,7 @@ class EllpackPageImpl {
*/ */
template <typename AdapterBatch> template <typename AdapterBatch>
explicit EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing, bool is_dense, explicit EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing, bool is_dense,
common::Span<size_t> row_counts_span, common::Span<size_t const> row_counts_span,
common::Span<FeatureType const> feature_types, size_t row_stride, common::Span<FeatureType const> feature_types, size_t row_stride,
bst_idx_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts); bst_idx_t n_rows, std::shared_ptr<common::HistogramCuts const> cuts);
/** /**
@ -184,6 +185,14 @@ class EllpackPageImpl {
explicit EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, explicit EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
common::Span<FeatureType const> ft); common::Span<FeatureType const> ft);
EllpackPageImpl(EllpackPageImpl const& that) = delete;
EllpackPageImpl& operator=(EllpackPageImpl const& that) = delete;
EllpackPageImpl(EllpackPageImpl&& that) = default;
EllpackPageImpl& operator=(EllpackPageImpl&& that) = default;
~EllpackPageImpl() noexcept(false);
/** /**
* @brief Copy the elements of the given ELLPACK page into this page. * @brief Copy the elements of the given ELLPACK page into this page.
* *

View File

@ -9,16 +9,17 @@
#include <numeric> // for accumulate #include <numeric> // for accumulate
#include <utility> // for move #include <utility> // for move
#include "../common/common.h" // for safe_cuda #include "../common/common.h" // for safe_cuda
#include "../common/ref_resource_view.cuh" #include "../common/cuda_rt_utils.h" // for SetDevice
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream #include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
#include "../common/resource.cuh" // for PrivateCudaMmapConstStream #include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
#include "ellpack_page.cuh" // for EllpackPageImpl #include "../common/resource.cuh" // for PrivateCudaMmapConstStream
#include "ellpack_page.h" // for EllpackPage #include "../common/transform_iterator.h" // for MakeIndexTransformIter
#include "ellpack_page.cuh" // for EllpackPageImpl
#include "ellpack_page.h" // for EllpackPage
#include "ellpack_page_source.h" #include "ellpack_page_source.h"
#include "proxy_dmatrix.cuh" // for Dispatch #include "proxy_dmatrix.cuh" // for Dispatch
#include "xgboost/base.h" // for bst_idx_t #include "xgboost/base.h" // for bst_idx_t
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
namespace xgboost::data { namespace xgboost::data {
/** /**
@ -201,7 +202,7 @@ EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(StringVi
*/ */
template <typename F> template <typename F>
void EllpackPageSourceImpl<F>::Fetch() { void EllpackPageSourceImpl<F>::Fetch() {
dh::safe_cuda(cudaSetDevice(this->Device().ordinal)); common::SetDevice(this->Device().ordinal);
if (!this->ReadCache()) { if (!this->ReadCache()) {
if (this->count_ != 0 && !this->sync_) { if (this->count_ != 0 && !this->sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0 // source is initialized to be the 0th page during construction, so when count_ is 0
@ -235,7 +236,7 @@ EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>
*/ */
template <typename F> template <typename F>
void ExtEllpackPageSourceImpl<F>::Fetch() { void ExtEllpackPageSourceImpl<F>::Fetch() {
dh::safe_cuda(cudaSetDevice(this->Device().ordinal)); common::SetDevice(this->Device().ordinal);
if (!this->ReadCache()) { if (!this->ReadCache()) {
auto iter = this->source_->Iter(); auto iter = this->source_->Iter();
CHECK_EQ(this->count_, iter); CHECK_EQ(this->count_, iter);
@ -250,7 +251,8 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
dh::device_vector<size_t> row_counts(n_samples + 1, 0); dh::device_vector<size_t> row_counts(n_samples + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size()); common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
cuda_impl::Dispatch(proxy_, [=](auto const& value) { cuda_impl::Dispatch(proxy_, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, dh::GetDevice(this->ctx_), this->missing_); return GetRowCounts(this->ctx_, value, row_counts_span, dh::GetDevice(this->ctx_),
this->missing_);
}); });
this->page_.reset(new EllpackPage{}); this->page_.reset(new EllpackPage{});

View File

@ -94,12 +94,12 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureTy
} }
GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<FeatureType const> ft, GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<FeatureType const> ft,
common::HistogramCuts cuts, int32_t max_bins_per_feat, common::HistogramCuts cuts, bst_bin_t max_bins_per_feat,
bool isDense, double sparse_thresh, int32_t n_threads) bool is_dense, double sparse_thresh, std::int32_t n_threads)
: cut{std::move(cuts)}, : cut{std::move(cuts)},
max_numeric_bins_per_feat{max_bins_per_feat}, max_numeric_bins_per_feat{max_bins_per_feat},
base_rowid{batch.base_rowid}, base_rowid{batch.base_rowid},
isDense_{isDense} { isDense_{is_dense} {
CHECK_GE(n_threads, 1); CHECK_GE(n_threads, 1);
CHECK_EQ(row_ptr.size(), 0); CHECK_EQ(row_ptr.size(), 0);
row_ptr = common::MakeFixedVecWithMalloc(batch.Size() + 1, std::size_t{0}); row_ptr = common::MakeFixedVecWithMalloc(batch.Size() + 1, std::size_t{0});

View File

@ -12,9 +12,9 @@
namespace xgboost { namespace xgboost {
// Similar to GHistIndexMatrix::SetIndexData, but without the need for adaptor or bin // Similar to GHistIndexMatrix::SetIndexData, but without the need for adaptor or bin
// searching. Is there a way to unify the code? // searching. Is there a way to unify the code?
template <typename BinT, typename CompressOffset> template <typename BinT, typename DecompressOffset>
void SetIndexData(Context const* ctx, EllpackPageImpl const* page, void SetIndexData(Context const* ctx, EllpackPageImpl const* page,
std::vector<size_t>* p_hit_count_tloc, CompressOffset&& get_offset, std::vector<size_t>* p_hit_count_tloc, DecompressOffset&& get_offset,
GHistIndexMatrix* out) { GHistIndexMatrix* out) {
std::vector<common::CompressedByteT> h_gidx_buffer; std::vector<common::CompressedByteT> h_gidx_buffer;
auto accessor = page->GetHostAccessor(ctx, &h_gidx_buffer); auto accessor = page->GetHostAccessor(ctx, &h_gidx_buffer);
@ -35,8 +35,8 @@ void SetIndexData(Context const* ctx, EllpackPageImpl const* page,
for (size_t j = 0; j < r_size; ++j) { for (size_t j = 0; j < r_size; ++j) {
auto bin_idx = accessor.gidx_iter[in_rbegin + j]; auto bin_idx = accessor.gidx_iter[in_rbegin + j];
assert(bin_idx != kNull); assert(bin_idx != kNull);
index_data_span[out_rbegin + j] = get_offset(bin_idx, j); index_data_span[out_rbegin + j] = bin_idx;
++hit_count_tloc[tid * n_bins_total + bin_idx]; ++hit_count_tloc[tid * n_bins_total + get_offset(bin_idx, j)];
} }
}); });
} }
@ -86,10 +86,13 @@ GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info,
auto n_bins_total = page->Cuts().TotalBins(); auto n_bins_total = page->Cuts().TotalBins();
GetRowPtrFromEllpack(ctx, page, &this->row_ptr); GetRowPtrFromEllpack(ctx, page, &this->row_ptr);
if (page->is_dense) { if (page->IsDense()) {
auto offset = index.Offset();
common::DispatchBinType(this->index.GetBinTypeSize(), [&](auto dtype) { common::DispatchBinType(this->index.GetBinTypeSize(), [&](auto dtype) {
using T = decltype(dtype); using T = decltype(dtype);
::xgboost::SetIndexData<T>(ctx, page, &hit_count_tloc_, index.MakeCompressor<T>(), this); ::xgboost::SetIndexData<T>(
ctx, page, &hit_count_tloc_,
[offset](bst_bin_t bin_idx, bst_feature_t fidx) { return bin_idx + offset[fidx]; }, this);
}); });
} else { } else {
// no compression // no compression

View File

@ -189,7 +189,7 @@ class GHistIndexMatrix {
* @brief Constructor for external memory. * @brief Constructor for external memory.
*/ */
GHistIndexMatrix(SparsePage const& page, common::Span<FeatureType const> ft, GHistIndexMatrix(SparsePage const& page, common::Span<FeatureType const> ft,
common::HistogramCuts cuts, int32_t max_bins_per_feat, bool is_dense, common::HistogramCuts cuts, bst_bin_t max_bins_per_feat, bool is_dense,
double sparse_thresh, std::int32_t n_threads); double sparse_thresh, std::int32_t n_threads);
GHistIndexMatrix(); // also for ext mem, empty ctor so that we can read the cache back. GHistIndexMatrix(); // also for ext mem, empty ctor so that we can read the cache back.

View File

@ -12,18 +12,17 @@
namespace xgboost::data { namespace xgboost::data {
void GradientIndexPageSource::Fetch() { void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) { if (!this->ReadCache()) {
if (count_ != 0 && !sync_) { // source is initialized to be the 0th page during construction, so when count_ is 0
// source is initialized to be the 0th page during construction, so when count_ is 0 // there's no need to increment the source.
// there's no need to increment the source. if (this->count_ != 0 && !this->sync_) {
//
// The mixin doesn't sync the source if `sync_` is false, we need to sync it // The mixin doesn't sync the source if `sync_` is false, we need to sync it
// ourselves. // ourselves.
++(*source_); ++(*source_);
} }
// This is not read from cache so we still need it to be synced with sparse page source. // This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(count_, source_->Iter()); CHECK_EQ(this->count_, this->source_->Iter());
auto const& csr = source_->Page(); auto const& csr = this->source_->Page();
CHECK_NE(cuts_.Values().size(), 0); CHECK_NE(this->cuts_.Values().size(), 0);
this->page_.reset(new GHistIndexMatrix{*csr, feature_types_, cuts_, max_bin_per_feat_, this->page_.reset(new GHistIndexMatrix{*csr, feature_types_, cuts_, max_bin_per_feat_,
is_dense_, sparse_thresh_, nthreads_}); is_dense_, sparse_thresh_, nthreads_});
this->WriteCache(); this->WriteCache();

View File

@ -68,7 +68,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
dh::device_vector<size_t> row_counts(rows + 1, 0); dh::device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size()); common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
cuda_impl::Dispatch(proxy, [=](auto const& value) { cuda_impl::Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); return GetRowCounts(ctx, value, row_counts_span, dh::GetDevice(ctx), missing);
}); });
auto is_dense = this->IsDense(); auto is_dense = this->IsDense();

View File

@ -72,7 +72,7 @@ void MakeSketches(Context const* ctx,
collective::Op::kMax); collective::Op::kMax);
SafeColl(rc); SafeColl(rc);
} else { } else {
CHECK_EQ(ext_info.n_features, ::xgboost::data::BatchColumns(proxy)) CHECK_EQ(ext_info.n_features, data::BatchColumns(proxy))
<< "Inconsistent number of columns."; << "Inconsistent number of columns.";
} }
@ -97,7 +97,7 @@ void MakeSketches(Context const* ctx,
lazy_init_sketch(); // Add a new level. lazy_init_sketch(); // Add a new level.
} }
proxy->Info().weights_.SetDevice(dh::GetDevice(ctx)); proxy->Info().weights_.SetDevice(dh::GetDevice(ctx));
cuda_impl::Dispatch(proxy, [&](auto const& value) { Dispatch(proxy, [&](auto const& value) {
common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing, common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing,
sketches.back().first.get()); sketches.back().first.get());
sketches.back().second++; sketches.back().second++;
@ -110,8 +110,8 @@ void MakeSketches(Context const* ctx,
dh::device_vector<size_t> row_counts(batch_rows + 1, 0); dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size()); common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
ext_info.row_stride = ext_info.row_stride =
std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) { std::max(ext_info.row_stride, Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); return GetRowCounts(ctx, value, row_counts_span, dh::GetDevice(ctx), missing);
})); }));
ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end()); ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end());
ext_info.n_batches++; ext_info.n_batches++;

View File

@ -10,9 +10,9 @@
namespace xgboost::data { namespace xgboost::data {
void Cache::Commit() { void Cache::Commit() {
if (!written) { if (!this->written) {
std::partial_sum(offset.begin(), offset.end(), offset.begin()); std::partial_sum(this->offset.begin(), this->offset.end(), this->offset.begin());
written = true; this->written = true;
} }
} }

View File

@ -241,6 +241,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
float missing_; float missing_;
std::int32_t nthreads_; std::int32_t nthreads_;
bst_feature_t n_features_; bst_feature_t n_features_;
bst_idx_t fetch_cnt_{0}; // Used for sanity check.
// Index to the current page. // Index to the current page.
std::uint32_t count_{0}; std::uint32_t count_{0};
// Total number of batches. // Total number of batches.
@ -267,8 +268,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
if (ring_->empty()) { if (ring_->empty()) {
ring_->resize(n_batches_); ring_->resize(n_batches_);
} }
// An heuristic for number of pre-fetched batches. We can make it part of BatchParam
// to let user adjust number of pre-fetched batches when needed.
std::int32_t n_prefetches = std::min(nthreads_, this->param_.n_prefetch_batches); std::int32_t n_prefetches = std::min(nthreads_, this->param_.n_prefetch_batches);
n_prefetches = std::max(n_prefetches, 1); n_prefetches = std::max(n_prefetches, 1);
std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_); std::int32_t n_prefetch_batches = std::min(static_cast<bst_idx_t>(n_prefetches), n_batches_);
@ -277,14 +277,23 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
std::size_t fetch_it = count_; std::size_t fetch_it = count_;
exce_.Rethrow(); exce_.Rethrow();
// Clear out the existing page before loading new ones. This helps reduce memory usage
// when page is not loaded with mmap, in addition, it triggers necessary CUDA
// synchronizations by freeing memory.
page_.reset();
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
bool restart = fetch_it == n_batches_;
fetch_it %= n_batches_; // ring fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) { if (ring_->at(fetch_it).valid()) {
continue; continue;
} }
auto const* self = this; // make sure it's const auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size()); CHECK_LT(fetch_it, cache_info_->offset.size());
// Make sure the new iteration starts with a copy to avoid spilling configuration.
if (restart) {
this->param_.prefetch_copy = true;
}
ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] { ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] {
auto page = std::make_shared<S>(); auto page = std::make_shared<S>();
this->exce_.Run([&] { this->exce_.Run([&] {
@ -298,17 +307,17 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
}); });
return page; return page;
}); });
this->fetch_cnt_++;
} }
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches) n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration."; << "Sparse DMatrix assumes forward iteration.";
monitor_.Start("Wait"); monitor_.Start("Wait-" + std::to_string(count_));
CHECK((*ring_)[count_].valid()); CHECK((*ring_)[count_].valid());
page_ = (*ring_)[count_].get(); page_ = (*ring_)[count_].get();
CHECK(!(*ring_)[count_].valid()); monitor_.Stop("Wait-" + std::to_string(count_));
monitor_.Stop("Wait");
exce_.Rethrow(); exce_.Rethrow();
@ -328,8 +337,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
timer.Stop(); timer.Stop();
// Not entirely accurate, the kernels doesn't have to flush the data. // Not entirely accurate, the kernels doesn't have to flush the data.
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in " LOG(INFO) << common::HumanMemUnit(bytes) << " written in " << timer.ElapsedSeconds()
<< timer.ElapsedSeconds() << " seconds."; << " seconds.";
cache_info_->Push(bytes); cache_info_->Push(bytes);
} }
@ -373,7 +382,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
return at_end_; return at_end_;
} }
// Call this at the last iteration. // Call this at the last iteration (it == n_batches).
void EndIter() { void EndIter() {
CHECK_EQ(this->cache_info_->offset.size(), this->n_batches_ + 1); CHECK_EQ(this->cache_info_->offset.size(), this->n_batches_ + 1);
this->cache_info_->Commit(); this->cache_info_->Commit();
@ -387,18 +396,22 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
virtual void Reset(BatchParam const& param) { virtual void Reset(BatchParam const& param) {
TryLockGuard guard{single_threaded_}; TryLockGuard guard{single_threaded_};
this->at_end_ = false; auto at_end = false;
auto cnt = this->count_; std::swap(this->at_end_, at_end);
this->count_ = 0;
bool changed = this->param_.n_prefetch_batches != param.n_prefetch_batches; bool changed = this->param_.n_prefetch_batches != param.n_prefetch_batches;
this->param_ = param; this->param_ = param;
if (cnt != 0 || changed) { this->count_ = 0;
if (!at_end || changed) {
// The last iteration did not get to the end, clear the ring to start from 0. // The last iteration did not get to the end, clear the ring to start from 0.
this->ring_ = std::make_unique<Ring>(); this->ring_ = std::make_unique<Ring>();
this->Fetch();
} }
this->Fetch(); // Get the 0^th page, prefetch the next page.
} }
[[nodiscard]] auto FetchCount() const { return this->fetch_cnt_; }
}; };
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
@ -413,10 +426,8 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter_; DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter_;
DMatrixProxy* proxy_; DMatrixProxy* proxy_;
std::size_t base_row_id_{0}; std::size_t base_row_id_{0};
bst_idx_t fetch_cnt_{0}; // Used for sanity check.
void Fetch() final { void Fetch() final {
fetch_cnt_++;
page_ = std::make_shared<SparsePage>(); page_ = std::make_shared<SparsePage>();
// The first round of reading, this is responsible for initialization. // The first round of reading, this is responsible for initialization.
if (!this->ReadCache()) { if (!this->ReadCache()) {
@ -467,9 +478,10 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
if (at_end_) { if (at_end_) {
this->EndIter(); this->EndIter();
this->proxy_ = nullptr; this->proxy_ = nullptr;
} else {
this->Fetch();
} }
this->Fetch();
return *this; return *this;
} }
@ -481,13 +493,13 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
SparsePageSourceImpl::Reset(param); SparsePageSourceImpl::Reset(param);
TryLockGuard guard{single_threaded_}; TryLockGuard guard{single_threaded_};
base_row_id_ = 0; this->base_row_id_ = 0;
} }
[[nodiscard]] auto FetchCount() const { return fetch_cnt_; }
}; };
// A mixin for advancing the iterator. /**
* @brief A mixin for advancing the iterator with a sparse page source.
*/
template <typename S, template <typename S,
typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>> typename FormatCreatePolicy = DefaultFormatStreamPolicy<S, DefaultFormatPolicy>>
class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> { class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
@ -496,7 +508,7 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
using Super = SparsePageSourceImpl<S, FormatCreatePolicy>; using Super = SparsePageSourceImpl<S, FormatCreatePolicy>;
// synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page // synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page
// so we avoid fetching it. // so we avoid fetching it.
bool sync_{true}; bool const sync_;
public: public:
PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features, PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features,
@ -506,8 +518,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
// can assume the source to be ready. // can assume the source to be ready.
[[nodiscard]] PageSourceIncMixIn& operator++() final { [[nodiscard]] PageSourceIncMixIn& operator++() final {
TryLockGuard guard{this->single_threaded_}; TryLockGuard guard{this->single_threaded_};
// Increment the source. // Increment the source.
if (sync_) { if (this->sync_) {
++(*source_); ++(*source_);
} }
// Increment self. // Increment self.
@ -516,24 +529,16 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
this->at_end_ = this->count_ == this->n_batches_; this->at_end_ = this->count_ == this->n_batches_;
if (this->at_end_) { if (this->at_end_) {
// If this is the first round of iterations, we have just built the binary cache
// from soruce. For a non-sync page type, the source hasn't been updated to the end
// iteration yet due to skipped increment. We increment the source here and it will
// call the `EndIter` method itself.
bool src_need_inc = !sync_ && this->source_->Iter() != 0;
if (src_need_inc) {
CHECK_EQ(this->source_->Iter(), this->count_ - 1);
++(*source_);
}
this->EndIter(); this->EndIter();
CHECK(this->cache_info_->written);
if (src_need_inc) { if (!this->sync_) {
CHECK(this->cache_info_->written); source_.reset(); // Make sure no unnecessary fetch.
} }
} else {
this->Fetch();
} }
this->Fetch();
if (sync_) { if (this->sync_) {
// Sanity check. // Sanity check.
CHECK_EQ(source_->Iter(), this->count_); CHECK_EQ(source_->Iter(), this->count_);
} }
@ -541,7 +546,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl<S, FormatCreatePolicy> {
} }
void Reset(BatchParam const& param) final { void Reset(BatchParam const& param) final {
this->source_->Reset(param); if (this->sync_ || !this->cache_info_->written) {
this->source_->Reset(param);
}
Super::Reset(param); Super::Reset(param);
} }
}; };
@ -625,8 +632,9 @@ class ExtQantileSourceMixin : public SparsePageSourceImpl<S, FormatCreatePolicy>
CHECK(this->cache_info_->written); CHECK(this->cache_info_->written);
source_ = nullptr; // release the source source_ = nullptr; // release the source
} else {
this->Fetch();
} }
this->Fetch();
return *this; return *this;
} }

View File

@ -24,16 +24,20 @@ __host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
return {lhs.first + rhs.first, lhs.second + rhs.second}; return {lhs.first + rhs.first, lhs.second + rhs.second};
} }
XGBOOST_DEV_INLINE bst_feature_t FeatIdx(FeatureGroup const& group, bst_idx_t idx,
std::int32_t feature_stride) {
auto fidx = group.start_feature + idx % feature_stride;
return fidx;
}
XGBOOST_DEV_INLINE bst_idx_t IterIdx(EllpackDeviceAccessor const& matrix, XGBOOST_DEV_INLINE bst_idx_t IterIdx(EllpackDeviceAccessor const& matrix,
RowPartitioner::RowIndexT ridx, FeatureGroup const& group, RowPartitioner::RowIndexT ridx, bst_feature_t fidx) {
bst_idx_t idx, std::int32_t feature_stride) {
// ridx_local = ridx - base_rowid <== Row index local to each batch // ridx_local = ridx - base_rowid <== Row index local to each batch
// entry_idx = ridx_local * row_stride <== Starting entry index for this row in the matrix // entry_idx = ridx_local * row_stride <== Starting entry index for this row in the matrix
// entry_idx += start_feature <== Inside a row, first column inside this feature group // entry_idx += start_feature <== Inside a row, first column inside this feature group
// idx % feature_stride <== The feaature index local to the current feature group // idx % feature_stride <== The feaature index local to the current feature group
// entry_idx += idx % feature_stride <== Final index. // entry_idx += idx % feature_stride <== Final index.
return (ridx - matrix.base_rowid) * matrix.row_stride + group.start_feature + return (ridx - matrix.base_rowid) * matrix.row_stride + fidx;
idx % feature_stride;
} }
} // anonymous namespace } // anonymous namespace
@ -134,7 +138,7 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest,
*reinterpret_cast<uint64_t*>(&h)); *reinterpret_cast<uint64_t*>(&h));
} }
template <int kBlockThreads, int kItemsPerThread, template <bool kIsDense, int kBlockThreads, int kItemsPerThread,
int kItemsPerTile = kBlockThreads * kItemsPerThread> int kItemsPerTile = kBlockThreads * kItemsPerThread>
class HistogramAgent { class HistogramAgent {
GradientPairInt64* smem_arr_; GradientPairInt64* smem_arr_;
@ -159,7 +163,7 @@ class HistogramAgent {
d_ridx_(d_ridx.data()), d_ridx_(d_ridx.data()),
group_(group), group_(group),
matrix_(matrix), matrix_(matrix),
feature_stride_(matrix.is_dense ? group.num_features : matrix.row_stride), feature_stride_(kIsDense ? group.num_features : matrix.row_stride),
n_elements_(feature_stride_ * d_ridx.size()), n_elements_(feature_stride_ * d_ridx.size()),
rounding_(rounding), rounding_(rounding),
d_gpair_(d_gpair) {} d_gpair_(d_gpair) {}
@ -169,12 +173,19 @@ class HistogramAgent {
idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_); idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_);
idx += kBlockThreads) { idx += kBlockThreads) {
Idx ridx = d_ridx_[idx / feature_stride_]; Idx ridx = d_ridx_[idx / feature_stride_];
bst_bin_t gidx = matrix_.gidx_iter[IterIdx(matrix_, ridx, group_, idx, feature_stride_)]; auto fidx = FeatIdx(group_, idx, feature_stride_);
if (matrix_.is_dense || gidx != matrix_.NullValue()) { bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)];
if (kIsDense || compressed_bin != matrix_.NullValue()) {
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]); auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
// Subtract start_bin to write to group-local histogram. If this is not a dense // Subtract start_bin to write to group-local histogram. If this is not a dense
// matrix, then start_bin is 0 since feature grouping doesn't support sparse data. // matrix, then start_bin is 0 since feature grouping doesn't support sparse data.
AtomicAddGpairShared(smem_arr_ + gidx - group_.start_bin, adjusted); if (kIsDense) {
AtomicAddGpairShared(
smem_arr_ + compressed_bin + this->matrix_.feature_segments[fidx] - group_.start_bin,
adjusted);
} else {
AtomicAddGpairShared(smem_arr_ + compressed_bin - group_.start_bin, adjusted);
}
} }
} }
} }
@ -185,7 +196,7 @@ class HistogramAgent {
__device__ void ProcessFullTileShared(std::size_t offset) { __device__ void ProcessFullTileShared(std::size_t offset) {
std::size_t idx[kItemsPerThread]; std::size_t idx[kItemsPerThread];
Idx ridx[kItemsPerThread]; Idx ridx[kItemsPerThread];
int gidx[kItemsPerThread]; bst_bin_t gidx[kItemsPerThread];
GradientPair gpair[kItemsPerThread]; GradientPair gpair[kItemsPerThread];
#pragma unroll #pragma unroll
for (int i = 0; i < kItemsPerThread; i++) { for (int i = 0; i < kItemsPerThread; i++) {
@ -198,11 +209,17 @@ class HistogramAgent {
#pragma unroll #pragma unroll
for (int i = 0; i < kItemsPerThread; i++) { for (int i = 0; i < kItemsPerThread; i++) {
gpair[i] = d_gpair_[ridx[i]]; gpair[i] = d_gpair_[ridx[i]];
gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], group_, idx[i], feature_stride_)]; auto fidx = FeatIdx(group_, idx[i], feature_stride_);
if (kIsDense) {
gidx[i] =
matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)] + matrix_.feature_segments[fidx];
} else {
gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)];
}
} }
#pragma unroll #pragma unroll
for (int i = 0; i < kItemsPerThread; i++) { for (int i = 0; i < kItemsPerThread; i++) {
if ((matrix_.is_dense || gidx[i] != matrix_.NullValue())) { if ((kIsDense || gidx[i] != matrix_.NullValue())) {
auto adjusted = rounding_.ToFixedPoint(gpair[i]); auto adjusted = rounding_.ToFixedPoint(gpair[i]);
AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted); AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
} }
@ -229,16 +246,22 @@ class HistogramAgent {
__device__ void BuildHistogramWithGlobal() { __device__ void BuildHistogramWithGlobal() {
for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) { for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
Idx ridx = d_ridx_[idx / feature_stride_]; Idx ridx = d_ridx_[idx / feature_stride_];
bst_bin_t gidx = matrix_.gidx_iter[IterIdx(matrix_, ridx, group_, idx, feature_stride_)]; auto fidx = FeatIdx(group_, idx, feature_stride_);
if (matrix_.is_dense || gidx != matrix_.NullValue()) { bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)];
if (kIsDense || compressed_bin != matrix_.NullValue()) {
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]); auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted); if (kIsDense) {
auto start_bin = this->matrix_.feature_segments[fidx];
AtomicAddGpairGlobal(d_node_hist_ + compressed_bin + start_bin, adjusted);
} else {
AtomicAddGpairGlobal(d_node_hist_ + compressed_bin, adjusted);
}
} }
} }
} }
}; };
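The dense path is the crux of this change: a dense Ellpack now stores each value local to its feature, so the compressor only needs as many symbols as the largest per-feature bin count, and the global bin is recovered by adding feature_segments[fidx]. A short stand-in sketch of that decode rule (assuming feature_segments holds each feature's first global bin; this is not the kernel code itself):

#include <cstdint>
#include <iostream>
#include <vector>

using bst_bin_t = std::int32_t;

template <bool kIsDense>
bst_bin_t GlobalBin(bst_bin_t compressed_bin, std::size_t fidx,
                    std::vector<std::uint32_t> const& feature_segments) {
  if constexpr (kIsDense) {
    // Dense: the stored bin is feature-local; add the feature's cut offset back.
    return compressed_bin + static_cast<bst_bin_t>(feature_segments[fidx]);
  } else {
    // Sparse: the stored bin is already global (or the "missing" sentinel).
    return compressed_bin;
  }
}

int main() {
  std::vector<std::uint32_t> feature_segments{0, 4, 9};           // features with 4 and 5 bins
  std::cout << GlobalBin<true>(2, 1, feature_segments) << "\n";   // 6: the 3rd bin of feature 1
  std::cout << GlobalBin<false>(6, 1, feature_segments) << "\n";  // 6: already a global bin
}

Making kIsDense a template parameter mirrors the kernel above: the branch is resolved at compile time, so the per-element loop carries no runtime is_dense check.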
template <bool use_shared_memory_histograms, int kBlockThreads, int kItemsPerThread> template <bool kIsDense, bool use_shared_memory_histograms, int kBlockThreads, int kItemsPerThread>
__global__ void __launch_bounds__(kBlockThreads) __global__ void __launch_bounds__(kBlockThreads)
SharedMemHistKernel(const EllpackDeviceAccessor matrix, SharedMemHistKernel(const EllpackDeviceAccessor matrix,
const FeatureGroupsAccessor feature_groups, const FeatureGroupsAccessor feature_groups,
@ -249,8 +272,8 @@ __global__ void __launch_bounds__(kBlockThreads)
extern __shared__ char smem[]; extern __shared__ char smem[];
const FeatureGroup group = feature_groups[blockIdx.y]; const FeatureGroup group = feature_groups[blockIdx.y];
auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem); auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
auto agent = HistogramAgent<kBlockThreads, kItemsPerThread>(smem_arr, d_node_hist, group, matrix, auto agent = HistogramAgent<kIsDense, kBlockThreads, kItemsPerThread>(
d_ridx, rounding, d_gpair); smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
if (use_shared_memory_histograms) { if (use_shared_memory_histograms) {
agent.BuildHistogramWithShared(); agent.BuildHistogramWithShared();
} else { } else {
@ -265,11 +288,22 @@ constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread;
} // namespace } // namespace
// Use auto deduction guide to workaround compiler error. // Use auto deduction guide to workaround compiler error.
template <auto Global = SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>, template <auto GlobalDense = SharedMemHistKernel<true, false, kBlockThreads, kItemsPerThread>,
auto Shared = SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>> auto Global = SharedMemHistKernel<false, false, kBlockThreads, kItemsPerThread>,
auto SharedDense = SharedMemHistKernel<true, true, kBlockThreads, kItemsPerThread>,
auto Shared = SharedMemHistKernel<false, true, kBlockThreads, kItemsPerThread>>
struct HistogramKernel { struct HistogramKernel {
decltype(Global) global_kernel{SharedMemHistKernel<false, kBlockThreads, kItemsPerThread>}; // Kernel for working with dense Ellpack using global memory.
decltype(Shared) shared_kernel{SharedMemHistKernel<true, kBlockThreads, kItemsPerThread>}; decltype(Global) global_dense_kernel{
SharedMemHistKernel<true, false, kBlockThreads, kItemsPerThread>};
// Kernel for working with sparse Ellpack using global memory.
decltype(Global) global_kernel{SharedMemHistKernel<false, false, kBlockThreads, kItemsPerThread>};
// Kernel for working with dense Ellpack using shared memory.
decltype(Shared) shared_dense_kernel{
SharedMemHistKernel<true, true, kBlockThreads, kItemsPerThread>};
// Kernel for working with sparse Ellpack using shared memory.
decltype(Shared) shared_kernel{SharedMemHistKernel<false, true, kBlockThreads, kItemsPerThread>};
bool shared{false}; bool shared{false};
std::uint32_t grid_size{0}; std::uint32_t grid_size{0};
std::size_t smem_size{0}; std::size_t smem_size{0};
@ -303,28 +337,30 @@ struct HistogramKernel {
// maximum number of blocks // maximum number of blocks
this->grid_size = n_blocks_per_mp * n_mps; this->grid_size = n_blocks_per_mp * n_mps;
}; };
// Initialize all kernel instantiations
init(this->global_kernel); for (auto& kernel : {global_dense_kernel, global_kernel, shared_dense_kernel, shared_kernel}) {
init(this->shared_kernel); init(kernel);
}
} }
}; };
class DeviceHistogramBuilderImpl { class DeviceHistogramBuilderImpl {
std::unique_ptr<HistogramKernel<>> kernel_{nullptr}; std::unique_ptr<HistogramKernel<>> kernel_{nullptr};
bool force_global_memory_{false};
public: public:
void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups, void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory) { bool force_global_memory) {
this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory); this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
this->force_global_memory_ = force_global_memory; if (force_global_memory) {
CHECK(!this->kernel_->shared);
}
} }
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix, void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups, FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair, common::Span<GradientPair const> gpair,
common::Span<const cuda_impl::RowIndexT> d_ridx, common::Span<const cuda_impl::RowIndexT> d_ridx,
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding) { common::Span<GradientPairInt64> histogram, GradientQuantiser rounding) const {
CHECK(kernel_); CHECK(kernel_);
// Otherwise launch blocks such that each block has a minimum amount of work to do // Otherwise launch blocks such that each block has a minimum amount of work to do
// There are fixed costs to launching each block, e.g. zeroing shared memory // There are fixed costs to launching each block, e.g. zeroing shared memory
@ -338,17 +374,26 @@ class DeviceHistogramBuilderImpl {
auto constexpr kMinItemsPerBlock = ItemsPerTile(); auto constexpr kMinItemsPerBlock = ItemsPerTile();
auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp( auto grid_size = std::min(kernel_->grid_size, static_cast<std::uint32_t>(common::DivRoundUp(
items_per_group, kMinItemsPerBlock))); items_per_group, kMinItemsPerBlock)));
auto launcher = [&](auto kernel) {
dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size, ctx->Stream()}(
kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding);
};
if (this->force_global_memory_ || !this->kernel_->shared) { if (!this->kernel_->shared) {
dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT CHECK_EQ(this->kernel_->smem_size, 0);
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size, if (matrix.is_dense) {
ctx->Stream()}(kernel_->global_kernel, matrix, feature_groups, d_ridx, launcher(this->kernel_->global_dense_kernel);
histogram.data(), gpair.data(), rounding); } else {
launcher(this->kernel_->global_kernel);
}
} else { } else {
dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT CHECK_NE(this->kernel_->smem_size, 0);
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size, if (matrix.is_dense) {
ctx->Stream()}(kernel_->shared_kernel, matrix, feature_groups, d_ridx, launcher(this->kernel_->shared_dense_kernel);
histogram.data(), gpair.data(), rounding); } else {
launcher(this->kernel_->shared_kernel);
}
} }
} }
}; };

View File

@ -172,8 +172,8 @@ class DeviceHistogramBuilder {
// Attempt to do subtraction trick // Attempt to do subtraction trick
// return true if succeeded // return true if succeeded
[[nodiscard]] bool SubtractionTrick(bst_node_t nidx_parent, bst_node_t nidx_histogram, [[nodiscard]] bool SubtractionTrick(Context const* ctx, bst_node_t nidx_parent,
bst_node_t nidx_subtraction) { bst_node_t nidx_histogram, bst_node_t nidx_subtraction) {
if (!hist_.HistogramExists(nidx_histogram) || !hist_.HistogramExists(nidx_parent)) { if (!hist_.HistogramExists(nidx_histogram) || !hist_.HistogramExists(nidx_parent)) {
return false; return false;
} }
@ -181,13 +181,13 @@ class DeviceHistogramBuilder {
auto d_node_hist_histogram = hist_.GetNodeHistogram(nidx_histogram); auto d_node_hist_histogram = hist_.GetNodeHistogram(nidx_histogram);
auto d_node_hist_subtraction = hist_.GetNodeHistogram(nidx_subtraction); auto d_node_hist_subtraction = hist_.GetNodeHistogram(nidx_subtraction);
dh::LaunchN(d_node_hist_parent.size(), [=] __device__(size_t idx) { dh::LaunchN(d_node_hist_parent.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) {
d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx];
}); });
return true; return true;
} }
[[nodiscard]] auto SubtractHist(std::vector<GPUExpandEntry> const& candidates, [[nodiscard]] auto SubtractHist(Context const* ctx, std::vector<GPUExpandEntry> const& candidates,
std::vector<bst_node_t> const& build_nidx, std::vector<bst_node_t> const& build_nidx,
std::vector<bst_node_t> const& subtraction_nidx) { std::vector<bst_node_t> const& subtraction_nidx) {
this->monitor_.Start(__func__); this->monitor_.Start(__func__);
@ -197,7 +197,7 @@ class DeviceHistogramBuilder {
auto subtraction_trick_nidx = subtraction_nidx.at(i); auto subtraction_trick_nidx = subtraction_nidx.at(i);
auto parent_nidx = candidates.at(i).nid; auto parent_nidx = candidates.at(i).nid;
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { if (!this->SubtractionTrick(ctx, parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
need_build.push_back(subtraction_trick_nidx); need_build.push_back(subtraction_trick_nidx);
} }
} }

View File

@ -129,7 +129,7 @@ struct WriteResultsFunctor {
* @param d_batch_info Node data, sized to the number of input nodes. * @param d_batch_info Node data, sized to the number of input nodes.
*/ */
template <typename OpT, typename OpDataT> template <typename OpT, typename OpDataT>
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info, void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpDataT>> d_batch_info,
common::Span<cuda_impl::RowIndexT> ridx, common::Span<cuda_impl::RowIndexT> ridx,
common::Span<cuda_impl::RowIndexT> ridx_tmp, common::Span<cuda_impl::RowIndexT> ridx_tmp,
common::Span<cuda_impl::RowIndexT> d_counts, bst_idx_t total_rows, OpT op, common::Span<cuda_impl::RowIndexT> d_counts, bst_idx_t total_rows, OpT op,
@ -150,17 +150,28 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch, return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch,
go_left}; go_left};
}); });
std::size_t temp_bytes = 0; // Avoid using int as the offset type
// Restriction imposed by cub. std::size_t n_bytes = 0;
CHECK_LE(total_rows, static_cast<bst_idx_t>(std::numeric_limits<std::int32_t>::max()));
if (tmp->empty()) { if (tmp->empty()) {
dh::safe_cuda(cub::DeviceScan::InclusiveScan( auto ret =
nullptr, temp_bytes, input_iterator, discard_write_iterator, IndexFlagOp{}, total_rows)); cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
tmp->resize(temp_bytes); cub::NullType, std::int64_t>::Dispatch(nullptr, n_bytes, input_iterator,
discard_write_iterator,
IndexFlagOp{}, cub::NullType{},
total_rows,
ctx->CUDACtx()->Stream());
dh::safe_cuda(ret);
tmp->resize(n_bytes);
} }
temp_bytes = tmp->size(); n_bytes = tmp->size();
dh::safe_cuda(cub::DeviceScan::InclusiveScan(tmp->data(), temp_bytes, input_iterator, auto ret =
discard_write_iterator, IndexFlagOp{}, total_rows)); cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
cub::NullType, std::int64_t>::Dispatch(tmp->data(), n_bytes, input_iterator,
discard_write_iterator,
IndexFlagOp{}, cub::NullType{},
total_rows,
ctx->CUDACtx()->Stream());
dh::safe_cuda(ret);
constexpr int kBlockSize = 256; constexpr int kBlockSize = 256;
@ -169,7 +180,8 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread); const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
SortPositionCopyKernel<kBlockSize, OpDataT> SortPositionCopyKernel<kBlockSize, OpDataT>
<<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows); <<<grid_size, kBlockSize, 0, ctx->CUDACtx()->Stream()>>>(batch_info_itr, ridx, ridx_tmp,
total_rows);
} }
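The old code went through cub::DeviceScan::InclusiveScan, whose element count was a plain int in the CUB version used here (hence the removed CHECK_LE guard); dispatching the scan with an explicit std::int64_t offset lifts that cap. A trivial host-side illustration of the limit, using an arbitrary row count:

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // total_rows can legitimately exceed what a signed 32-bit count can index.
  std::uint64_t total_rows = 3'000'000'000ULL;
  auto int32_max = static_cast<std::uint64_t>(std::numeric_limits<std::int32_t>::max());
  std::cout << std::boolalpha << (total_rows > int32_max) << "\n";  // true

  // A 64-bit offset type represents the same count without a guard.
  std::int64_t n_items = static_cast<std::int64_t>(total_rows);
  std::cout << n_items << "\n";  // 3000000000
}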
struct NodePositionInfo { struct NodePositionInfo {
@ -293,7 +305,7 @@ class RowPartitioner {
* second. Returns true if this training instance goes on the left partition. * second. Returns true if this training instance goes on the left partition.
*/ */
template <typename UpdatePositionOpT, typename OpDataT> template <typename UpdatePositionOpT, typename OpDataT>
void UpdatePositionBatch(const std::vector<bst_node_t>& nidx, void UpdatePositionBatch(Context const* ctx, const std::vector<bst_node_t>& nidx,
const std::vector<bst_node_t>& left_nidx, const std::vector<bst_node_t>& left_nidx,
const std::vector<bst_node_t>& right_nidx, const std::vector<bst_node_t>& right_nidx,
const std::vector<OpDataT>& op_data, UpdatePositionOpT op) { const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
@ -316,21 +328,22 @@ class RowPartitioner {
} }
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(), dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<OpDataT>), h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
cudaMemcpyDefault)); cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
// Temporary arrays // Temporary arrays
auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size(), 0); auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size());
// Must initialize with 0, as a zero count is not written by the kernel.
dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0); dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
// Partition the rows according to the operator // Partition the rows according to the operator
SortPositionBatch<UpdatePositionOpT, OpDataT>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
total_rows, op, &tmp_); total_rows, op, &tmp_);
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(), dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
cudaMemcpyDefault)); cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
// TODO(Rory): this synchronisation hurts performance a lot // TODO(Rory): this synchronisation hurts performance a lot
// Future optimisation should find a way to skip this // Future optimisation should find a way to skip this
dh::DefaultStream().Sync(); ctx->CUDACtx()->Stream().Sync();
// Update segments // Update segments
for (std::size_t i = 0; i < nidx.size(); i++) { for (std::size_t i = 0; i < nidx.size(); i++) {
@ -341,9 +354,9 @@ class RowPartitioner {
std::max(left_nidx[i], right_nidx[i]) + 1)); std::max(left_nidx[i], right_nidx[i]) + 1));
ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]}; ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]};
ridx_segments_[left_nidx[i]] = ridx_segments_[left_nidx[i]] =
NodePositionInfo{Segment(segment.begin, segment.begin + left_count)}; NodePositionInfo{Segment{segment.begin, segment.begin + left_count}};
ridx_segments_[right_nidx[i]] = ridx_segments_[right_nidx[i]] =
NodePositionInfo{Segment(segment.begin + left_count, segment.end)}; NodePositionInfo{Segment{segment.begin + left_count, segment.end}};
} }
} }

View File

@ -119,17 +119,15 @@ struct DeviceSplitCandidate {
}; };
namespace cuda_impl { namespace cuda_impl {
constexpr auto DftPrefetchBatches() { return 2; }
inline BatchParam HistBatch(TrainParam const& param) { inline BatchParam HistBatch(TrainParam const& param) {
auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()}; auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
p.prefetch_copy = true; p.prefetch_copy = true;
p.n_prefetch_batches = 1; p.n_prefetch_batches = DftPrefetchBatches();
return p; return p;
} }
inline BatchParam HistBatch(bst_bin_t max_bin) {
return {max_bin, TrainParam::DftSparseThreshold()};
}
inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hess, inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hess,
ObjInfo const& task) { ObjInfo const& task) {
return BatchParam{p.max_bin, hess, !task.const_hess}; return BatchParam{p.max_bin, hess, !task.const_hess};
@ -139,7 +137,7 @@ inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hes
inline BatchParam StaticBatch(bool prefetch_copy) { inline BatchParam StaticBatch(bool prefetch_copy) {
BatchParam p; BatchParam p;
p.prefetch_copy = prefetch_copy; p.prefetch_copy = prefetch_copy;
p.n_prefetch_batches = 1; p.n_prefetch_batches = DftPrefetchBatches();
return p; return p;
} }
} // namespace cuda_impl } // namespace cuda_impl

View File

@ -70,7 +70,6 @@ void AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub) { common::Span<bst_node_t> nodes_to_build, common::Span<bst_node_t> nodes_to_sub) {
auto const& tree = *p_tree; auto const& tree = *p_tree;
std::size_t nidx_in_set{0}; std::size_t nidx_in_set{0};
double total{0.0}, smaller{0.0};
auto p_build_nidx = nodes_to_build.data(); auto p_build_nidx = nodes_to_build.data();
auto p_sub_nidx = nodes_to_sub.data(); auto p_sub_nidx = nodes_to_sub.data();
for (auto& e : candidates) { for (auto& e : candidates) {
@ -81,15 +80,12 @@ void AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
auto left_sum = quantizer->ToFloatingPoint(e.split.left_sum); auto left_sum = quantizer->ToFloatingPoint(e.split.left_sum);
auto right_sum = quantizer->ToFloatingPoint(e.split.right_sum); auto right_sum = quantizer->ToFloatingPoint(e.split.right_sum);
bool fewer_right = right_sum.GetHess() < left_sum.GetHess(); bool fewer_right = right_sum.GetHess() < left_sum.GetHess();
total += left_sum.GetHess() + right_sum.GetHess();
if (fewer_right) { if (fewer_right) {
p_build_nidx[nidx_in_set] = tree[e.nid].RightChild(); p_build_nidx[nidx_in_set] = tree[e.nid].RightChild();
p_sub_nidx[nidx_in_set] = tree[e.nid].LeftChild(); p_sub_nidx[nidx_in_set] = tree[e.nid].LeftChild();
smaller += right_sum.GetHess();
} else { } else {
p_build_nidx[nidx_in_set] = tree[e.nid].LeftChild(); p_build_nidx[nidx_in_set] = tree[e.nid].LeftChild();
p_sub_nidx[nidx_in_set] = tree[e.nid].RightChild(); p_sub_nidx[nidx_in_set] = tree[e.nid].RightChild();
smaller += left_sum.GetHess();
} }
++nidx_in_set; ++nidx_in_set;
} }
@ -348,7 +344,7 @@ struct GPUHistMakerDevice {
// This gives much better latency in a distributed setting when processing a large batch // This gives much better latency in a distributed setting when processing a large batch
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size()); this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size());
// Perform subtraction for sibling nodes // Perform subtraction for sibling nodes
auto need_build = this->histogram_.SubtractHist(candidates, build_nidx, subtraction_nidx); auto need_build = this->histogram_.SubtractHist(ctx_, candidates, build_nidx, subtraction_nidx);
if (need_build.empty()) { if (need_build.empty()) {
this->monitor.Stop(__func__); this->monitor.Stop(__func__);
return; return;
@ -383,12 +379,14 @@ struct GPUHistMakerDevice {
BitVector decision_bits{dh::ToSpan(decision_storage)}; BitVector decision_bits{dh::ToSpan(decision_storage)};
BitVector missing_bits{dh::ToSpan(missing_storage)}; BitVector missing_bits{dh::ToSpan(missing_storage)};
auto cuctx = this->ctx_->CUDACtx();
dh::TemporaryArray<NodeSplitData> split_data_storage(num_candidates); dh::TemporaryArray<NodeSplitData> split_data_storage(num_candidates);
dh::safe_cuda(cudaMemcpyAsync(split_data_storage.data().get(), split_data.data(), dh::safe_cuda(cudaMemcpyAsync(split_data_storage.data().get(), split_data.data(),
num_candidates * sizeof(NodeSplitData), cudaMemcpyDefault)); num_candidates * sizeof(NodeSplitData), cudaMemcpyDefault,
cuctx->Stream()));
auto d_split_data = dh::ToSpan(split_data_storage); auto d_split_data = dh::ToSpan(split_data_storage);
dh::LaunchN(d_matrix.n_rows, [=] __device__(std::size_t ridx) mutable { dh::LaunchN(d_matrix.n_rows, cuctx->Stream(), [=] __device__(std::size_t ridx) mutable {
for (auto i = 0; i < num_candidates; i++) { for (auto i = 0; i < num_candidates; i++) {
auto const& data = d_split_data[i]; auto const& data = d_split_data[i];
auto const cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex()); auto const cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
@ -421,7 +419,7 @@ struct GPUHistMakerDevice {
CHECK_EQ(partitioners_.size(), 1) << "External memory with column split is not yet supported."; CHECK_EQ(partitioners_.size(), 1) << "External memory with column split is not yet supported.";
partitioners_.front()->UpdatePositionBatch( partitioners_.front()->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data, ctx_, nidx, left_nidx, right_nidx, split_data,
[=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) { [=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
auto const index = ridx * num_candidates + nidx_in_batch; auto const index = ridx * num_candidates + nidx_in_batch;
bool go_left; bool go_left;
@ -495,10 +493,11 @@ struct GPUHistMakerDevice {
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx); UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
} else { } else {
partitioners_.at(k)->UpdatePositionBatch( partitioners_.at(k)->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data, ctx_, nidx, left_nidx, right_nidx, split_data,
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/, [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
const NodeSplitData& data) { return go_left(ridx, data); }); const NodeSplitData& data) { return go_left(ridx, data); });
} }
monitor.Stop("UpdatePositionBatch"); monitor.Stop("UpdatePositionBatch");
for (auto nidx : build_nidx) { for (auto nidx : build_nidx) {
@ -556,7 +555,7 @@ struct GPUHistMakerDevice {
return; return;
} }
dh::caching_device_vector<uint32_t> categories; dh::CachingDeviceUVector<std::uint32_t> categories;
dh::CopyTo(p_tree->GetSplitCategories(), &categories, this->ctx_->CUDACtx()->Stream()); dh::CopyTo(p_tree->GetSplitCategories(), &categories, this->ctx_->CUDACtx()->Stream());
auto const& cat_segments = p_tree->GetSplitCategoriesPtr(); auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
auto d_categories = dh::ToSpan(categories); auto d_categories = dh::ToSpan(categories);
@ -575,7 +574,7 @@ struct GPUHistMakerDevice {
} }
auto go_left_op = GoLeftOp{d_matrix}; auto go_left_op = GoLeftOp{d_matrix};
dh::caching_device_vector<NodeSplitData> d_split_data; dh::CachingDeviceUVector<NodeSplitData> d_split_data;
dh::CopyTo(split_data, &d_split_data, this->ctx_->CUDACtx()->Stream()); dh::CopyTo(split_data, &d_split_data, this->ctx_->CUDACtx()->Stream());
auto s_split_data = dh::ToSpan(d_split_data); auto s_split_data = dh::ToSpan(d_split_data);
@ -610,7 +609,7 @@ struct GPUHistMakerDevice {
// Use the nodes from tree, the leaf value might be changed by the objective since the // Use the nodes from tree, the leaf value might be changed by the objective since the
// last update tree call. // last update tree call.
dh::caching_device_vector<RegTree::Node> nodes; dh::CachingDeviceUVector<RegTree::Node> nodes;
dh::CopyTo(p_tree->GetNodes(), &nodes, this->ctx_->CUDACtx()->Stream()); dh::CopyTo(p_tree->GetNodes(), &nodes, this->ctx_->CUDACtx()->Stream());
common::Span<RegTree::Node> d_nodes = dh::ToSpan(nodes); common::Span<RegTree::Node> d_nodes = dh::ToSpan(nodes);
CHECK_EQ(out_preds_d.Shape(1), 1); CHECK_EQ(out_preds_d.Shape(1), 1);
@ -820,6 +819,7 @@ class GPUHistMaker : public TreeUpdater {
} }
void InitDataOnce(TrainParam const* param, DMatrix* p_fmat) { void InitDataOnce(TrainParam const* param, DMatrix* p_fmat) {
monitor_.Start(__func__);
CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device"; CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
// Synchronise the column sampling seed // Synchronise the column sampling seed
@ -840,24 +840,22 @@ class GPUHistMaker : public TreeUpdater {
p_last_fmat_ = p_fmat; p_last_fmat_ = p_fmat;
initialised_ = true; initialised_ = true;
monitor_.Stop(__func__);
} }
void InitData(TrainParam const* param, DMatrix* dmat, RegTree const* p_tree) { void InitData(TrainParam const* param, DMatrix* dmat, RegTree const* p_tree) {
monitor_.Start(__func__);
if (!initialised_) { if (!initialised_) {
monitor_.Start("InitDataOnce");
this->InitDataOnce(param, dmat); this->InitDataOnce(param, dmat);
monitor_.Stop("InitDataOnce");
} }
p_last_tree_ = p_tree; p_last_tree_ = p_tree;
CHECK(hist_maker_param_.GetInitialised()); CHECK(hist_maker_param_.GetInitialised());
monitor_.Stop(__func__);
} }
void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, void UpdateTree(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) { RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
monitor_.Start("InitData");
this->InitData(param, p_fmat, p_tree); this->InitData(param, p_fmat, p_tree);
monitor_.Stop("InitData");
gpair->SetDevice(ctx_->Device()); gpair->SetDevice(ctx_->Device());
maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position); maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
} }

View File

@ -0,0 +1,19 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/common.h"
namespace xgboost::common {
TEST(Common, HumanMemUnit) {
auto name = HumanMemUnit(1024 * 1024 * 1024ul);
ASSERT_EQ(name, "1GB");
name = HumanMemUnit(1024 * 1024ul);
ASSERT_EQ(name, "1MB");
name = HumanMemUnit(1024);
ASSERT_EQ(name, "1KB");
name = HumanMemUnit(1);
ASSERT_EQ(name, "1B");
}
} // namespace xgboost::common
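For reference, a rough stand-in that would satisfy the assertions above; the real HumanMemUnit lives in src/common/common.h, so treat this version (including its integer-truncation behaviour) as an illustrative assumption rather than the actual implementation:

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

std::string HumanMemUnitSketch(std::size_t n_bytes) {
  std::ostringstream os;
  if (n_bytes >= (1ul << 30)) {
    os << n_bytes / (1ul << 30) << "GB";
  } else if (n_bytes >= (1ul << 20)) {
    os << n_bytes / (1ul << 20) << "MB";
  } else if (n_bytes >= (1ul << 10)) {
    os << n_bytes / (1ul << 10) << "KB";
  } else {
    os << n_bytes << "B";
  }
  return os.str();
}

int main() {
  std::cout << HumanMemUnitSketch(1ul << 30) << " " << HumanMemUnitSketch(1ul << 20) << " "
            << HumanMemUnitSketch(1ul << 10) << " " << HumanMemUnitSketch(1) << std::endl;
  // prints: 1GB 1MB 1KB 1B
}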

View File

@ -1,9 +1,9 @@
// Copyright (c) 2019 by Contributors /**
* Copyright 2019-2024, XGBoost contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include "../../../src/data/adapter.h" #include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h"
#include "../../../src/common/timer.h"
#include "../helpers.h" #include "../helpers.h"
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include "../../../src/data/device_adapter.cuh" #include "../../../src/data/device_adapter.cuh"
@ -64,7 +64,7 @@ TEST(DeviceAdapter, GetRowCounts) {
auto adapter = CupyAdapter{str_arr}; auto adapter = CupyAdapter{str_arr};
HostDeviceVector<bst_idx_t> offset(adapter.NumRows() + 1, 0); HostDeviceVector<bst_idx_t> offset(adapter.NumRows() + 1, 0);
offset.SetDevice(ctx.Device()); offset.SetDevice(ctx.Device());
auto rstride = GetRowCounts(adapter.Value(), offset.DeviceSpan(), ctx.Device(), auto rstride = GetRowCounts(&ctx, adapter.Value(), offset.DeviceSpan(), ctx.Device(),
std::numeric_limits<float>::quiet_NaN()); std::numeric_limits<float>::quiet_NaN());
ASSERT_EQ(rstride, n_features); ASSERT_EQ(rstride, n_features);
} }

View File

@ -30,13 +30,13 @@ TEST(EllpackPage, EmptyDMatrix) {
} }
TEST(EllpackPage, BuildGidxDense) { TEST(EllpackPage, BuildGidxDense) {
int constexpr kNRows = 16, kNCols = 8; bst_idx_t n_samples = 16, n_features = 8;
auto ctx = MakeCUDACtx(0); auto ctx = MakeCUDACtx(0);
auto page = BuildEllpackPage(&ctx, kNRows, kNCols); auto page = BuildEllpackPage(&ctx, n_samples, n_features);
std::vector<common::CompressedByteT> h_gidx_buffer; std::vector<common::CompressedByteT> h_gidx_buffer;
auto h_accessor = page->GetHostAccessor(&ctx, &h_gidx_buffer); auto h_accessor = page->GetHostAccessor(&ctx, &h_gidx_buffer);
ASSERT_EQ(page->row_stride, kNCols); ASSERT_EQ(page->row_stride, n_features);
std::vector<uint32_t> solution = { std::vector<uint32_t> solution = {
0, 3, 8, 9, 14, 17, 20, 21, 0, 3, 8, 9, 14, 17, 20, 21,
@ -56,8 +56,9 @@ TEST(EllpackPage, BuildGidxDense) {
2, 4, 8, 10, 14, 15, 19, 22, 2, 4, 8, 10, 14, 15, 19, 22,
1, 4, 7, 10, 14, 16, 19, 21, 1, 4, 7, 10, 14, 16, 19, 21,
}; };
for (size_t i = 0; i < kNRows * kNCols; ++i) { for (size_t i = 0; i < n_samples * n_features; ++i) {
ASSERT_EQ(solution[i], h_accessor.gidx_iter[i]); auto fidx = i % n_features;
ASSERT_EQ(solution[i], h_accessor.gidx_iter[i] + h_accessor.feature_segments[fidx]);
} }
} }
@ -263,12 +264,12 @@ class EllpackPageTest : public testing::TestWithParam<float> {
ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid); ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid);
ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows); ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows);
ASSERT_EQ(from_sparse_page->gidx_buffer.size(), from_ghist->gidx_buffer.size()); ASSERT_EQ(from_sparse_page->gidx_buffer.size(), from_ghist->gidx_buffer.size());
ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
std::vector<common::CompressedByteT> h_gidx_from_sparse, h_gidx_from_ghist; std::vector<common::CompressedByteT> h_gidx_from_sparse, h_gidx_from_ghist;
auto from_ghist_acc = from_ghist->GetHostAccessor(&gpu_ctx, &h_gidx_from_ghist); auto from_ghist_acc = from_ghist->GetHostAccessor(&gpu_ctx, &h_gidx_from_ghist);
auto from_sparse_acc = from_sparse_page->GetHostAccessor(&gpu_ctx, &h_gidx_from_sparse); auto from_sparse_acc = from_sparse_page->GetHostAccessor(&gpu_ctx, &h_gidx_from_sparse);
ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) { for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) {
EXPECT_EQ(from_ghist_acc.gidx_iter[i], from_sparse_acc.gidx_iter[i]); ASSERT_EQ(from_ghist_acc.gidx_iter[i], from_sparse_acc.gidx_iter[i]);
} }
} }
} }

View File

@ -106,9 +106,11 @@ TEST(IterativeDeviceDMatrix, RowMajor) {
common::Span<float const> s_data{static_cast<float const*>(loaded.data), cols * rows}; common::Span<float const> s_data{static_cast<float const*>(loaded.data), cols * rows};
dh::CopyDeviceSpanToVector(&h_data, s_data); dh::CopyDeviceSpanToVector(&h_data, s_data);
for(auto i = 0ull; i < rows * cols; i++) { auto cut_ptr = h_accessor.feature_segments;
for (auto i = 0ull; i < rows * cols; i++) {
int column_idx = i % cols; int column_idx = i % cols;
EXPECT_EQ(impl->Cuts().SearchBin(h_data[i], column_idx), h_accessor.gidx_iter[i]); EXPECT_EQ(impl->Cuts().SearchBin(h_data[i], column_idx),
h_accessor.gidx_iter[i] + cut_ptr[column_idx]);
} }
EXPECT_EQ(m.Info().num_col_, cols); EXPECT_EQ(m.Info().num_col_, cols);
EXPECT_EQ(m.Info().num_row_, rows); EXPECT_EQ(m.Info().num_row_, rows);

View File

@ -12,6 +12,7 @@
#include "../../../src/data/file_iterator.h" #include "../../../src/data/file_iterator.h"
#include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/simple_dmatrix.h"
#include "../../../src/data/sparse_page_dmatrix.h" #include "../../../src/data/sparse_page_dmatrix.h"
#include "../../../src/tree/param.h" // for TrainParam
#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h" #include "../helpers.h"
@ -115,6 +116,47 @@ TEST(SparsePageDMatrix, RetainSparsePage) {
TestRetainPage<SortedCSCPage>(); TestRetainPage<SortedCSCPage>();
} }
class TestGradientIndexExt : public ::testing::TestWithParam<bool> {
protected:
void Run(bool is_dense) {
constexpr bst_idx_t kRows = 64;
constexpr size_t kCols = 2;
float sparsity = is_dense ? 0.0 : 0.4;
bst_bin_t n_bins = 16;
Context ctx;
auto p_ext_fmat =
RandomDataGenerator{kRows, kCols, sparsity}.Batches(4).GenerateSparsePageDMatrix("temp",
true);
auto cuts = common::SketchOnDMatrix(&ctx, p_ext_fmat.get(), n_bins, false, {});
std::vector<std::unique_ptr<GHistIndexMatrix>> pages;
for (auto const &page : p_ext_fmat->GetBatches<SparsePage>()) {
pages.emplace_back(std::make_unique<GHistIndexMatrix>(
page, common::Span<FeatureType const>{}, cuts, n_bins, is_dense, 0.8, ctx.Threads()));
}
std::int32_t k = 0;
for (auto const &page : p_ext_fmat->GetBatches<GHistIndexMatrix>(
&ctx, BatchParam{n_bins, tree::TrainParam::DftSparseThreshold()})) {
auto const &from_sparse = pages[k];
ASSERT_TRUE(std::equal(page.index.begin(), page.index.end(), from_sparse->index.begin()));
if (is_dense) {
ASSERT_TRUE(std::equal(page.index.Offset(), page.index.Offset() + kCols,
from_sparse->index.Offset()));
} else {
ASSERT_FALSE(page.index.Offset());
ASSERT_FALSE(from_sparse->index.Offset());
}
ASSERT_TRUE(
std::equal(page.row_ptr.cbegin(), page.row_ptr.cend(), from_sparse->row_ptr.cbegin()));
++k;
}
}
};
TEST_P(TestGradientIndexExt, Basic) { this->Run(this->GetParam()); }
INSTANTIATE_TEST_SUITE_P(SparsePageDMatrix, TestGradientIndexExt, testing::Bool());
// Test GHistIndexMatrix can avoid loading sparse page after the initialization. // Test GHistIndexMatrix can avoid loading sparse page after the initialization.
TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) { TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;

View File

@ -40,10 +40,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
TEST(SparsePageDMatrix, EllpackSkipSparsePage) { TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
// Test Ellpack can avoid loading sparse page after the initialization. // Test Ellpack can avoid loading sparse page after the initialization.
dmlc::TemporaryDirectory tmpdir;
std::size_t n_batches = 6; std::size_t n_batches = 6;
auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix( auto Xy =
tmpdir.path + "/", true); RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix("temp", true);
auto ctx = MakeCUDACtx(0); auto ctx = MakeCUDACtx(0);
auto cpu = ctx.MakeCPU(); auto cpu = ctx.MakeCPU();
bst_bin_t n_bins{256}; bst_bin_t n_bins{256};
@ -117,7 +116,6 @@ TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
TEST(SparsePageDMatrix, MultipleEllpackPages) { TEST(SparsePageDMatrix, MultipleEllpackPages) {
auto ctx = MakeCUDACtx(0); auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
dmlc::TemporaryDirectory tmpdir;
auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true); auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
// Loop over the batches and count the records // Loop over the batches and count the records
@ -155,18 +153,24 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
auto const& d_src = (*it).Impl()->gidx_buffer; auto const& d_src = (*it).Impl()->gidx_buffer;
dh::safe_cuda(cudaMemcpyAsync(d_dst, d_src.data(), d_src.size_bytes(), cudaMemcpyDefault)); dh::safe_cuda(cudaMemcpyAsync(d_dst, d_src.data(), d_src.size_bytes(), cudaMemcpyDefault));
} }
ASSERT_GE(iterators.size(), 2); ASSERT_EQ(iterators.size(), 8);
for (size_t i = 0; i < iterators.size(); ++i) { for (size_t i = 0; i < iterators.size(); ++i) {
std::vector<common::CompressedByteT> h_buf; std::vector<common::CompressedByteT> h_buf;
[[maybe_unused]] auto h_acc = (*iterators[i]).Impl()->GetHostAccessor(&ctx, &h_buf); [[maybe_unused]] auto h_acc = (*iterators[i]).Impl()->GetHostAccessor(&ctx, &h_buf);
ASSERT_EQ(h_buf, gidx_buffers.at(i).HostVector()); ASSERT_EQ(h_buf, gidx_buffers.at(i).HostVector());
ASSERT_EQ(iterators[i].use_count(), 1); // The last page is still kept in the DMatrix until Reset is called.
if (i == iterators.size() - 1) {
ASSERT_EQ(iterators[i].use_count(), 2);
} else {
ASSERT_EQ(iterators[i].use_count(), 1);
}
} }
// make sure it's const and the caller cannot modify the content of the page. // make sure it's const and the caller cannot modify the content of the page.
for (auto& page : m->GetBatches<EllpackPage>(&ctx, param)) { for (auto& page : m->GetBatches<EllpackPage>(&ctx, param)) {
static_assert(std::is_const_v<std::remove_reference_t<decltype(page)>>); static_assert(std::is_const_v<std::remove_reference_t<decltype(page)>>);
break;
} }
// The above iteration clears out all references inside DMatrix. // The above iteration clears out all references inside DMatrix.
@ -190,13 +194,10 @@ class TestEllpackPageExt : public ::testing::TestWithParam<std::tuple<bool, bool
auto p_fmat = RandomDataGenerator{kRows, kCols, sparsity}.GenerateDMatrix(true); auto p_fmat = RandomDataGenerator{kRows, kCols, sparsity}.GenerateDMatrix(true);
// Create a DMatrix with multiple batches. // Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
auto prefix = tmpdir.path + "/cache";
auto p_ext_fmat = RandomDataGenerator{kRows, kCols, sparsity} auto p_ext_fmat = RandomDataGenerator{kRows, kCols, sparsity}
.Batches(4) .Batches(4)
.OnHost(on_host) .OnHost(on_host)
.GenerateSparsePageDMatrix(prefix, true); .GenerateSparsePageDMatrix("temp", true);
auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()}; auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()};
auto impl = (*p_fmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl(); auto impl = (*p_fmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();

View File

@ -73,13 +73,13 @@ TEST(Histogram, SubtractionTrack) {
histogram.AllocateHistograms(&ctx, {0, 1, 2}); histogram.AllocateHistograms(&ctx, {0, 1, 2});
GPUExpandEntry root; GPUExpandEntry root;
root.nid = 0; root.nid = 0;
auto need_build = histogram.SubtractHist({root}, {0}, {1}); auto need_build = histogram.SubtractHist(&ctx, {root}, {0}, {1});
std::vector<GPUExpandEntry> candidates(2); std::vector<GPUExpandEntry> candidates(2);
candidates[0].nid = 1; candidates[0].nid = 1;
candidates[1].nid = 2; candidates[1].nid = 2;
need_build = histogram.SubtractHist(candidates, {3, 5}, {4, 6}); need_build = histogram.SubtractHist(&ctx, candidates, {3, 5}, {4, 6});
ASSERT_EQ(need_build.size(), 2); ASSERT_EQ(need_build.size(), 2);
ASSERT_EQ(need_build[0], 4); ASSERT_EQ(need_build[0], 4);
ASSERT_EQ(need_build[1], 6); ASSERT_EQ(need_build[1], 6);

View File

@ -33,9 +33,9 @@ void TestUpdatePositionBatch() {
std::vector<int> extra_data = {0}; std::vector<int> extra_data = {0};
// Send the first five training instances to the right node // Send the first five training instances to the right node
// and the second 5 to the left node // and the second 5 to the left node
rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) { rp.UpdatePositionBatch(
return ridx > 4; &ctx, {0}, {1}, {2}, extra_data,
}); [=] __device__(RowPartitioner::RowIndexT ridx, int, int) { return ridx > 4; });
rows = rp.GetRowsHost(1); rows = rp.GetRowsHost(1);
for (auto r : rows) { for (auto r : rows) {
EXPECT_GT(r, 4); EXPECT_GT(r, 4);
@ -46,9 +46,9 @@ void TestUpdatePositionBatch() {
} }
// Split the left node again // Split the left node again
rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int, int) { rp.UpdatePositionBatch(
return ridx < 7; &ctx, {1}, {3}, {4}, extra_data,
}); [=] __device__(RowPartitioner::RowIndexT ridx, int, int) { return ridx < 7; });
EXPECT_EQ(rp.GetRows(3).size(), 2); EXPECT_EQ(rp.GetRows(3).size(), 2);
EXPECT_EQ(rp.GetRows(4).size(), 3); EXPECT_EQ(rp.GetRows(4).size(), 3);
} }
@ -56,6 +56,7 @@ void TestUpdatePositionBatch() {
TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); } TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); }
void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Segment>& segments) { void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Segment>& segments) {
auto ctx = MakeCUDACtx(0);
thrust::device_vector<cuda_impl::RowIndexT> ridx = ridx_in; thrust::device_vector<cuda_impl::RowIndexT> ridx = ridx_in;
thrust::device_vector<cuda_impl::RowIndexT> ridx_tmp(ridx_in.size()); thrust::device_vector<cuda_impl::RowIndexT> ridx_tmp(ridx_in.size());
thrust::device_vector<cuda_impl::RowIndexT> counts(segments.size()); thrust::device_vector<cuda_impl::RowIndexT> counts(segments.size());
@ -74,7 +75,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault, h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
nullptr)); nullptr));
dh::DeviceUVector<int8_t> tmp; dh::DeviceUVector<int8_t> tmp;
SortPositionBatch<decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx), SortPositionBatch<decltype(op), int>(&ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op, dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op,
&tmp); &tmp);
@ -145,7 +146,7 @@ void TestExternalMemory() {
std::vector<RegTree::Node> splits{tree[0]}; std::vector<RegTree::Node> splits{tree[0]};
auto acc = page.Impl()->GetDeviceAccessor(&ctx); auto acc = page.Impl()->GetDeviceAccessor(&ctx);
partitioners.back()->UpdatePositionBatch( partitioners.back()->UpdatePositionBatch(
{0}, {1}, {2}, splits, &ctx, {0}, {1}, {2}, splits,
[=] __device__(bst_idx_t ridx, std::int32_t nidx_in_batch, RegTree::Node const& node) { [=] __device__(bst_idx_t ridx, std::int32_t nidx_in_batch, RegTree::Node const& node) {
auto fvalue = acc.GetFvalue(ridx, node.SplitIndex()); auto fvalue = acc.GetFvalue(ridx, node.SplitIndex());
return fvalue <= node.SplitCond(); return fvalue <= node.SplitCond();