diff --git a/.gitignore b/.gitignore
index 8a2df2a9b..88996f330 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,6 +33,7 @@ ipch
 *.filters
 *.user
 *log
+rmm_log.txt
 Debug
 *suo
 .Rhistory
diff --git a/src/common/cuda_pinned_allocator.h b/src/common/cuda_pinned_allocator.h
index 90c34668a..4d7fa3158 100644
--- a/src/common/cuda_pinned_allocator.h
+++ b/src/common/cuda_pinned_allocator.h
@@ -10,6 +10,7 @@
 
 #include <cstddef>  // for size_t
 #include <limits>   // for numeric_limits
+#include <new>      // for bad_array_new_length
 
 #include "common.h"
@@ -28,14 +29,14 @@ struct PinnedAllocPolicy {
   using size_type = std::size_t;  // NOLINT: The type used for the size of the allocation
   using value_type = T;           // NOLINT: The type of the elements in the allocator
 
-  size_type max_size() const {  // NOLINT
+  [[nodiscard]] constexpr size_type max_size() const {  // NOLINT
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
   }
 
   [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
-      throw std::bad_alloc{};
-    }  // end if
+      throw std::bad_array_new_length{};
+    }
 
     pointer result(nullptr);
     dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
@@ -52,14 +53,14 @@ struct ManagedAllocPolicy {
   using size_type = std::size_t;  // NOLINT: The type used for the size of the allocation
   using value_type = T;           // NOLINT: The type of the elements in the allocator
 
-  size_type max_size() const {  // NOLINT
+  [[nodiscard]] constexpr size_type max_size() const {  // NOLINT
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
   }
 
   [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
-      throw std::bad_alloc{};
-    }  // end if
+      throw std::bad_array_new_length{};
+    }
 
     pointer result(nullptr);
     dh::safe_cuda(cudaMallocManaged(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
@@ -78,14 +79,14 @@ struct SamAllocPolicy {
   using size_type = std::size_t;  // NOLINT: The type used for the size of the allocation
   using value_type = T;           // NOLINT: The type of the elements in the allocator
 
-  size_type max_size() const {  // NOLINT
+  [[nodiscard]] constexpr size_type max_size() const {  // NOLINT
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
  }
 
   [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
-      throw std::bad_alloc{};
-    }  // end if
+      throw std::bad_array_new_length{};
+    }
 
     size_type n_bytes = cnt * sizeof(value_type);
     pointer result = reinterpret_cast<pointer>(std::malloc(n_bytes));
@@ -139,10 +140,10 @@ class CudaHostAllocatorImpl : public Policy {
 };
 
 template <typename T>
-using PinnedAllocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy<T>>;  // NOLINT
+using PinnedAllocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy<T>>;
 
 template <typename T>
-using ManagedAllocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy<T>>;  // NOLINT
+using ManagedAllocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy<T>>;
 
 template <typename T>
 using SamAllocator = CudaHostAllocatorImpl<T, SamAllocPolicy<T>>;
diff --git a/src/common/device_vector.cuh b/src/common/device_vector.cuh
index 46265c765..b2065d333 100644
--- a/src/common/device_vector.cuh
+++ b/src/common/device_vector.cuh
@@ -177,8 +177,10 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator {
     pointer thrust_ptr;
     if (use_cub_allocator_) {
       T *raw_ptr{nullptr};
+      // NOLINTBEGIN(clang-analyzer-unix.BlockInCriticalSection)
       auto errc = GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&raw_ptr),
                                                              n * sizeof(T));
+      // NOLINTEND(clang-analyzer-unix.BlockInCriticalSection)
       if (errc != cudaSuccess) {
         detail::ThrowOOMError("Caching allocator", n * sizeof(T));
       }
@@ -290,13 +292,13 @@ LoggingResource *GlobalLoggingResource();
 /**
  * @brief Container class that doesn't initialize the data when RMM is used.
  */
-template <typename T>
-class DeviceUVector {
+template <typename T, bool kCaching>
+class DeviceUVectorImpl {
  private:
 #if defined(XGBOOST_USE_RMM)
   rmm::device_uvector<T> data_{0, rmm::cuda_stream_per_thread, GlobalLoggingResource()};
 #else
-  ::dh::device_vector<T> data_;
+  std::conditional_t<kCaching, ::dh::caching_device_vector<T>, ::dh::device_vector<T>> data_;
 #endif  // defined(XGBOOST_USE_RMM)
 
  public:
@@ -307,12 +309,12 @@ class DeviceUVector {
   using const_reference = value_type const &;  // NOLINT
 
  public:
-  DeviceUVector() = default;
-  explicit DeviceUVector(std::size_t n) { this->resize(n); }
-  DeviceUVector(DeviceUVector const &that) = delete;
-  DeviceUVector &operator=(DeviceUVector const &that) = delete;
-  DeviceUVector(DeviceUVector &&that) = default;
-  DeviceUVector &operator=(DeviceUVector &&that) = default;
+  DeviceUVectorImpl() = default;
+  explicit DeviceUVectorImpl(std::size_t n) { this->resize(n); }
+  DeviceUVectorImpl(DeviceUVectorImpl const &that) = delete;
+  DeviceUVectorImpl &operator=(DeviceUVectorImpl const &that) = delete;
+  DeviceUVectorImpl(DeviceUVectorImpl &&that) = default;
+  DeviceUVectorImpl &operator=(DeviceUVectorImpl &&that) = default;
 
   void resize(std::size_t n) {  // NOLINT
 #if defined(XGBOOST_USE_RMM)
@@ -356,4 +358,10 @@ class DeviceUVector {
   [[nodiscard]] auto data() { return thrust::raw_pointer_cast(data_.data()); }        // NOLINT
   [[nodiscard]] auto data() const { return thrust::raw_pointer_cast(data_.data()); }  // NOLINT
 };
+
+template <typename T>
+using DeviceUVector = DeviceUVectorImpl<T, false>;
+
+template <typename T>
+using CachingDeviceUVector = DeviceUVectorImpl<T, true>;
 }  // namespace dh
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index bc012fd9b..18747ab99 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023 by XGBoost Contributors
+ * Copyright 2019-2024, XGBoost Contributors
  * \file device_adapter.cuh
  */
 #ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_
@@ -7,13 +7,12 @@
 #include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
 #include <thrust/logical.h>                     // for none_of
 
-#include <cstddef>  // for size_t
+#include <cstddef>  // for size_t
 #include
-#include
 #include
 
+#include "../common/cuda_context.cuh"
 #include "../common/device_helpers.cuh"
-#include "../common/math.h"
 #include "adapter.h"
 #include "array_interface.h"
 
@@ -208,11 +207,12 @@ class CupyAdapter : public detail::SingleBatchDataIter {
 
 // Returns maximum row length
 template <typename AdapterBatchT>
-bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span offset, DeviceOrd device,
-                       float missing) {
+bst_idx_t GetRowCounts(Context const* ctx, const AdapterBatchT batch,
+                       common::Span offset, DeviceOrd device, float missing) {
   dh::safe_cuda(cudaSetDevice(device.ordinal));
   IsValidFunctor is_valid(missing);
-  dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes()));
+  dh::safe_cuda(
+      cudaMemsetAsync(offset.data(), '\0', offset.size_bytes(), ctx->CUDACtx()->Stream()));
 
   auto n_samples = batch.NumRows();
   bst_feature_t n_features = batch.NumCols();
@@ -230,7 +230,7 @@ bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span offset
   }
 
   // Count elements per row
-  dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) {
+  dh::LaunchN(n_samples * stride, ctx->CUDACtx()->Stream(), [=] __device__(std::size_t idx) {
     bst_idx_t cnt{0};
     auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride);
     SPAN_CHECK(ridx < n_samples);
@@ -244,9 +244,8 @@ bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span offset
                   &offset[ridx]),
               static_cast(cnt));  // NOLINT
   });
-  dh::XGBCachingDeviceAllocator<char> alloc;
   bst_idx_t row_stride =
-      dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
+      dh::Reduce(ctx->CUDACtx()->CTP(), thrust::device_pointer_cast(offset.data()),
                  thrust::device_pointer_cast(offset.data()) + offset.size(),
                  static_cast(0), thrust::maximum());
   return row_stride;
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 8f8ab0af7..dc3f10c4e 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -9,6 +9,7 @@
 #include <utility>  // for move
 #include <vector>   // for vector
 
+#include "../common/algorithm.cuh"  // for InclusiveScan
 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"
 #include "../common/cuda_rt_utils.h"  // for SetDevice
@@ -45,6 +46,7 @@ void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id)
 [[nodiscard]] bool EllpackPage::IsDense() const { return this->Impl()->IsDense(); }
 
 // Bin each input data entry, store the bin indices in compressed form.
+template <bool kIsDense>
 __global__ void CompressBinEllpackKernel(
     common::CompressedBufferWriter wr,
     common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
@@ -73,12 +75,11 @@ __global__ void CompressBinEllpackKernel(
     // Assigning the bin in current entry.
     // S.t.: fvalue < feature_cuts[bin]
     if (is_cat) {
-      auto it = dh::MakeTransformIterator(
-          feature_cuts, [](float v) { return common::AsCat(v); });
+      auto it =
+          dh::MakeTransformIterator(feature_cuts, [](float v) { return common::AsCat(v); });
       bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it;
     } else {
-      bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
-                                fvalue) -
+      bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts, fvalue) -
             feature_cuts;
     }
 
@@ -86,24 +87,54 @@ __global__ void CompressBinEllpackKernel(
       bin = ncuts - 1;
     }
     // Add the number of bins in previous features.
-    bin += cut_ptrs[feature];
+    if (!kIsDense) {
+      bin += cut_ptrs[feature];
+    }
   }
   // Write to gidx buffer.
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }
 
-[[nodiscard]] std::size_t CalcNumSymbols(Context const*, bool /*is_dense*/,
+namespace {
+// Calculate the number of symbols for the compressed ellpack. Similar to what the CPU
+// implementation does, we compress the dense data by subtracting the bin values with the
+// starting bin of its feature.
+[[nodiscard]] std::size_t CalcNumSymbols(Context const* ctx, bool is_dense,
                                          std::shared_ptr<common::HistogramCuts> cuts) {
-  // Return the total number of symbols (total number of bins plus 1 for not found)
-  return cuts->cut_values_.Size() + 1;
+  // Cut values can be empty when the input data is empty.
+ if (!is_dense || cuts->cut_values_.Empty()) { + // Return the total number of symbols (total number of bins plus 1 for not found) + return cuts->cut_values_.Size() + 1; + } + + cuts->cut_ptrs_.SetDevice(ctx->Device()); + common::Span dptrs = cuts->cut_ptrs_.ConstDeviceSpan(); + auto cuctx = ctx->CUDACtx(); + using PtrT = typename decltype(dptrs)::value_type; + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(1ul), + [=] XGBOOST_DEVICE(std::size_t i) { return dptrs[i] - dptrs[i - 1]; }); + CHECK_GE(dptrs.size(), 2); + auto max_it = thrust::max_element(cuctx->CTP(), it, it + dptrs.size() - 1); + dh::CachingDeviceUVector max_element(1); + auto d_me = max_element.data(); + dh::LaunchN(1, cuctx->Stream(), [=] XGBOOST_DEVICE(std::size_t i) { d_me[i] = *max_it; }); + PtrT h_me{0}; + dh::safe_cuda( + cudaMemcpyAsync(&h_me, d_me, sizeof(PtrT), cudaMemcpyDeviceToHost, cuctx->Stream())); + cuctx->Stream().Sync(); + // No missing, hence no null value, hence no + 1 symbol. + // FIXME(jiamingy): When we extend this to use a sparsity threshold, +1 is needed back. + return h_me; } +} // namespace // Construct an ELLPACK matrix with the given number of empty rows. EllpackPageImpl::EllpackPageImpl(Context const* ctx, std::shared_ptr cuts, bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows) - : is_dense(is_dense), - cuts_(std::move(cuts)), + : is_dense{is_dense}, + cuts_{std::move(cuts)}, row_stride{row_stride}, n_rows{n_rows}, n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} { @@ -117,11 +148,14 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, std::shared_ptr cuts, const SparsePage& page, bool is_dense, size_t row_stride, common::Span feature_types) - : cuts_(std::move(cuts)), - is_dense(is_dense), - n_rows(page.Size()), - row_stride(row_stride), - n_symbols_(CalcNumSymbols(ctx, this->is_dense, this->cuts_)) { + : cuts_{std::move(cuts)}, + is_dense{is_dense}, + n_rows{page.Size()}, + row_stride{row_stride}, + n_symbols_{CalcNumSymbols(ctx, this->is_dense, this->cuts_)} { + monitor_.Init("ellpack_page"); + common::SetDevice(ctx->Ordinal()); + this->InitCompressedData(ctx); this->CreateHistIndices(ctx, page, feature_types); } @@ -147,8 +181,8 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* p_fmat, const Batc auto ft = p_fmat->Info().feature_types.ConstDeviceSpan(); monitor_.Start("BinningCompression"); CHECK(p_fmat->SingleColBlock()); - for (const auto& batch : p_fmat->GetBatches()) { - CreateHistIndices(ctx, batch, ft); + for (auto const& page : p_fmat->GetBatches()) { + this->CreateHistIndices(ctx, page, ft); } monitor_.Stop("BinningCompression"); } @@ -186,6 +220,9 @@ struct WriteCompressedEllpackFunctor { } else { bin_idx = accessor.SearchBin(e.value, e.column_idx); } + if (kIsDense) { + bin_idx -= accessor.feature_segments[e.column_idx]; + } writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position); } return 0; @@ -257,7 +294,8 @@ void CopyDataToEllpack(Context const* ctx, const AdapterBatchT& batch, common::InclusiveScan(ctx, key_value_index_iter, out, TupleScanOp{}, batch.Size()); } -void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span row_counts) { +void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, + common::Span row_counts) { // Write the null values auto device_accessor = dst->GetDeviceAccessor(ctx); common::CompressedBufferWriter writer(dst->NumSymbols()); @@ -276,7 +314,7 @@ void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span EllpackPageImpl::EllpackPageImpl(Context const* 
ctx, AdapterBatch batch, float missing, - bool is_dense, common::Span row_counts_span, + bool is_dense, common::Span row_counts_span, common::Span feature_types, size_t row_stride, bst_idx_t n_rows, std::shared_ptr cuts) @@ -292,10 +330,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, AdapterBatch batch, float m WriteNullValues(ctx, this, row_counts_span); } -#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \ - template EllpackPageImpl::EllpackPageImpl( \ - Context const* ctx, __BATCH_T batch, float missing, bool is_dense, \ - common::Span row_counts_span, common::Span feature_types, \ +#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T) \ + template EllpackPageImpl::EllpackPageImpl( \ + Context const* ctx, __BATCH_T batch, float missing, bool is_dense, \ + common::Span row_counts_span, common::Span feature_types, \ size_t row_stride, size_t n_rows, std::shared_ptr cuts); ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch) @@ -303,18 +341,15 @@ ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch) namespace { void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page, - common::Span d_row_ptr, size_t row_stride, - common::CompressedByteT* d_compressed_buffer, size_t null) { + common::Span d_row_ptr, bst_idx_t row_stride, + bst_bin_t null, bst_idx_t n_symbols, + common::CompressedByteT* d_compressed_buffer) { dh::device_vector data(page.index.begin(), page.index.end()); auto d_data = dh::ToSpan(data); - dh::device_vector csc_indptr(page.index.Offset(), - page.index.Offset() + page.index.OffsetSize()); - auto d_csc_indptr = dh::ToSpan(csc_indptr); - + // GPU employs the same dense compression as CPU, no need to handle page.index.Offset() auto bin_type = page.index.GetBinTypeSize(); - common::CompressedBufferWriter writer{page.cut.TotalBins() + - static_cast(1)}; // +1 for null value + common::CompressedBufferWriter writer{n_symbols}; auto cuctx = ctx->CUDACtx(); dh::LaunchN(row_stride * page.Size(), cuctx->Stream(), [=] __device__(bst_idx_t idx) mutable { @@ -323,22 +358,17 @@ void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page, auto r_begin = d_row_ptr[ridx]; auto r_end = d_row_ptr[ridx + 1]; - size_t r_size = r_end - r_begin; + auto r_size = r_end - r_begin; if (ifeature >= r_size) { writer.AtomicWriteSymbol(d_compressed_buffer, null, idx); return; } - bst_idx_t offset = 0; - if (!d_csc_indptr.empty()) { - // is dense, ifeature is the actual feature index. 
- offset = d_csc_indptr[ifeature]; - } common::cuda::DispatchBinType(bin_type, [&](auto t) { using T = decltype(t); auto ptr = reinterpret_cast(d_data.data()); - auto bin_idx = ptr[r_begin + ifeature] + offset; + auto bin_idx = ptr[r_begin + ifeature]; writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx); }); }); @@ -348,14 +378,16 @@ void CopyGHistToEllpack(Context const* ctx, GHistIndexMatrix const& page, EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, common::Span ft) : is_dense{page.IsDense()}, + row_stride{[&] { + auto it = common::MakeIndexTransformIter( + [&](bst_idx_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; }); + return *std::max_element(it, it + page.Size()); + }()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{std::make_shared(page.cut)}, n_symbols_{CalcNumSymbols(ctx, page.IsDense(), cuts_)} { - auto it = common::MakeIndexTransformIter( - [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; }); - row_stride = *std::max_element(it, it + page.Size()); - + this->monitor_.Init("ellpack_page"); CHECK(ctx->IsCUDA()); this->InitCompressedData(ctx); @@ -367,12 +399,17 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream())); auto accessor = this->GetDeviceAccessor(ctx, ft); - auto null = accessor.NullValue(); this->monitor_.Start("CopyGHistToEllpack"); - CopyGHistToEllpack(ctx, page, d_row_ptr, row_stride, d_compressed_buffer, null); + CopyGHistToEllpack(ctx, page, d_row_ptr, row_stride, accessor.NullValue(), this->NumSymbols(), + d_compressed_buffer); this->monitor_.Stop("CopyGHistToEllpack"); } +EllpackPageImpl::~EllpackPageImpl() noexcept(false) { + // Sync the stream to make sure all running CUDA kernels finish before deallocation. + dh::DefaultStream().Sync(); +} + // A functor that copies the data from one EllpackPage to another. struct CopyPage { common::CompressedBufferWriter cbw; @@ -385,7 +422,7 @@ struct CopyPage { : cbw{dst->NumSymbols()}, dst_data_d{dst->gidx_buffer.data()}, src_iterator_d{src->gidx_buffer.data(), src->NumSymbols()}, - offset(offset) {} + offset{offset} {} __device__ void operator()(size_t element_id) { cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], element_id + offset); @@ -393,7 +430,7 @@ struct CopyPage { }; // Copy the data from the given EllpackPage to the current page. -size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) { +bst_idx_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) { monitor_.Start(__func__); bst_idx_t num_elements = page->n_rows * page->row_stride; CHECK_EQ(this->row_stride, page->row_stride); @@ -482,10 +519,12 @@ void EllpackPageImpl::InitCompressedData(Context const* ctx) { void EllpackPageImpl::CreateHistIndices(Context const* ctx, const SparsePage& row_batch, common::Span feature_types) { - if (row_batch.Size() == 0) return; - std::uint32_t null_gidx_value = NumSymbols() - 1; + if (row_batch.Size() == 0) { + return; + } + auto null_gidx_value = this->GetDeviceAccessor(ctx, feature_types).NullValue(); - const auto& offset_vec = row_batch.offset.ConstHostVector(); + auto const& offset_vec = row_batch.offset.ConstHostVector(); // bin and compress entries in batches of rows size_t gpu_batch_nrows = @@ -504,35 +543,46 @@ void EllpackPageImpl::CreateHistIndices(Context const* ctx, const auto ent_cnt_end = offset_vec[batch_row_end]; /*! \brief row offset in SparsePage (the input data). 
*/ - dh::device_vector row_ptrs(batch_nrows + 1); - thrust::copy(offset_vec.data() + batch_row_begin, - offset_vec.data() + batch_row_end + 1, row_ptrs.begin()); + using OffT = typename std::remove_reference_t::value_type; + dh::DeviceUVector row_ptrs(batch_nrows + 1); + auto size = + std::distance(offset_vec.data() + batch_row_begin, offset_vec.data() + batch_row_end + 1); + dh::safe_cuda(cudaMemcpyAsync(row_ptrs.data(), offset_vec.data() + batch_row_begin, + size * sizeof(OffT), cudaMemcpyDefault, + ctx->CUDACtx()->Stream())); // number of entries in this batch. size_t n_entries = ent_cnt_end - ent_cnt_begin; - dh::device_vector entries_d(n_entries); + dh::DeviceUVector entries_d(n_entries); // copy data entries to device. if (row_batch.data.DeviceCanRead()) { auto const& d_data = row_batch.data.ConstDeviceSpan(); - dh::safe_cuda(cudaMemcpyAsync( - entries_d.data().get(), d_data.data() + ent_cnt_begin, - n_entries * sizeof(Entry), cudaMemcpyDefault)); + dh::safe_cuda(cudaMemcpyAsync(entries_d.data(), d_data.data() + ent_cnt_begin, + n_entries * sizeof(Entry), cudaMemcpyDefault, + ctx->CUDACtx()->Stream())); } else { const std::vector& data_vec = row_batch.data.ConstHostVector(); - dh::safe_cuda(cudaMemcpyAsync( - entries_d.data().get(), data_vec.data() + ent_cnt_begin, - n_entries * sizeof(Entry), cudaMemcpyDefault)); + dh::safe_cuda(cudaMemcpyAsync(entries_d.data(), data_vec.data() + ent_cnt_begin, + n_entries * sizeof(Entry), cudaMemcpyDefault, + ctx->CUDACtx()->Stream())); } const dim3 block3(32, 8, 1); // 256 threads const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), common::DivRoundUp(row_stride, block3.y), 1); auto device_accessor = this->GetDeviceAccessor(ctx); - dh::LaunchKernel{grid3, block3}( // NOLINT - CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), gidx_buffer.data(), - row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), - device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows, - row_stride, null_gidx_value); + auto launcher = [&](auto kernel) { + dh::LaunchKernel{grid3, block3, 0, ctx->CUDACtx()->Stream()}( // NOLINT + kernel, common::CompressedBufferWriter(this->NumSymbols()), gidx_buffer.data(), + row_ptrs.data(), entries_d.data(), device_accessor.gidx_fvalue_map.data(), + device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows, + row_stride, null_gidx_value); + }; + if (this->IsDense()) { + launcher(CompressBinEllpackKernel); + } else { + launcher(CompressBinEllpackKernel); + } } } diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index a9766e347..78641c5ac 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -85,6 +85,7 @@ struct EllpackDeviceAccessor { bst_bin_t gidx = -1; if (is_dense) { gidx = gidx_iter[row_begin + fidx]; + gidx += this->feature_segments[fidx]; } else { gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx], feature_segments[fidx + 1]); @@ -175,7 +176,7 @@ class EllpackPageImpl { */ template explicit EllpackPageImpl(Context const* ctx, AdapterBatch batch, float missing, bool is_dense, - common::Span row_counts_span, + common::Span row_counts_span, common::Span feature_types, size_t row_stride, bst_idx_t n_rows, std::shared_ptr cuts); /** @@ -184,6 +185,14 @@ class EllpackPageImpl { explicit EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, common::Span ft); + EllpackPageImpl(EllpackPageImpl const& that) = delete; + EllpackPageImpl& 
operator=(EllpackPageImpl const& that) = delete; + + EllpackPageImpl(EllpackPageImpl&& that) = default; + EllpackPageImpl& operator=(EllpackPageImpl&& that) = default; + + ~EllpackPageImpl() noexcept(false); + /** * @brief Copy the elements of the given ELLPACK page into this page. * diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 5f6b50f50..9b1de14cb 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -9,16 +9,17 @@ #include // for accumulate #include // for move -#include "../common/common.h" // for safe_cuda -#include "../common/ref_resource_view.cuh" -#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream -#include "../common/resource.cuh" // for PrivateCudaMmapConstStream -#include "ellpack_page.cuh" // for EllpackPageImpl -#include "ellpack_page.h" // for EllpackPage +#include "../common/common.h" // for safe_cuda +#include "../common/cuda_rt_utils.h" // for SetDevice +#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream +#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc +#include "../common/resource.cuh" // for PrivateCudaMmapConstStream +#include "../common/transform_iterator.h" // for MakeIndexTransformIter +#include "ellpack_page.cuh" // for EllpackPageImpl +#include "ellpack_page.h" // for EllpackPage #include "ellpack_page_source.h" #include "proxy_dmatrix.cuh" // for Dispatch #include "xgboost/base.h" // for bst_idx_t -#include "../common/transform_iterator.h" // for MakeIndexTransformIter namespace xgboost::data { /** @@ -201,7 +202,7 @@ EllpackMmapStreamPolicy::CreateReader(StringVi */ template void EllpackPageSourceImpl::Fetch() { - dh::safe_cuda(cudaSetDevice(this->Device().ordinal)); + common::SetDevice(this->Device().ordinal); if (!this->ReadCache()) { if (this->count_ != 0 && !this->sync_) { // source is initialized to be the 0th page during construction, so when count_ is 0 @@ -235,7 +236,7 @@ EllpackPageSourceImpl> */ template void ExtEllpackPageSourceImpl::Fetch() { - dh::safe_cuda(cudaSetDevice(this->Device().ordinal)); + common::SetDevice(this->Device().ordinal); if (!this->ReadCache()) { auto iter = this->source_->Iter(); CHECK_EQ(this->count_, iter); @@ -250,7 +251,8 @@ void ExtEllpackPageSourceImpl::Fetch() { dh::device_vector row_counts(n_samples + 1, 0); common::Span row_counts_span(row_counts.data().get(), row_counts.size()); cuda_impl::Dispatch(proxy_, [=](auto const& value) { - return GetRowCounts(value, row_counts_span, dh::GetDevice(this->ctx_), this->missing_); + return GetRowCounts(this->ctx_, value, row_counts_span, dh::GetDevice(this->ctx_), + this->missing_); }); this->page_.reset(new EllpackPage{}); diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 6b9f571be..14d3c7c64 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -94,12 +94,12 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span ft, - common::HistogramCuts cuts, int32_t max_bins_per_feat, - bool isDense, double sparse_thresh, int32_t n_threads) + common::HistogramCuts cuts, bst_bin_t max_bins_per_feat, + bool is_dense, double sparse_thresh, std::int32_t n_threads) : cut{std::move(cuts)}, max_numeric_bins_per_feat{max_bins_per_feat}, base_rowid{batch.base_rowid}, - isDense_{isDense} { + isDense_{is_dense} { CHECK_GE(n_threads, 1); CHECK_EQ(row_ptr.size(), 0); row_ptr = common::MakeFixedVecWithMalloc(batch.Size() + 1, std::size_t{0}); diff --git a/src/data/gradient_index.cu 
b/src/data/gradient_index.cu index f8c8f8d48..ebdc99051 100644 --- a/src/data/gradient_index.cu +++ b/src/data/gradient_index.cu @@ -12,9 +12,9 @@ namespace xgboost { // Similar to GHistIndexMatrix::SetIndexData, but without the need for adaptor or bin // searching. Is there a way to unify the code? -template +template void SetIndexData(Context const* ctx, EllpackPageImpl const* page, - std::vector* p_hit_count_tloc, CompressOffset&& get_offset, + std::vector* p_hit_count_tloc, DecompressOffset&& get_offset, GHistIndexMatrix* out) { std::vector h_gidx_buffer; auto accessor = page->GetHostAccessor(ctx, &h_gidx_buffer); @@ -35,8 +35,8 @@ void SetIndexData(Context const* ctx, EllpackPageImpl const* page, for (size_t j = 0; j < r_size; ++j) { auto bin_idx = accessor.gidx_iter[in_rbegin + j]; assert(bin_idx != kNull); - index_data_span[out_rbegin + j] = get_offset(bin_idx, j); - ++hit_count_tloc[tid * n_bins_total + bin_idx]; + index_data_span[out_rbegin + j] = bin_idx; + ++hit_count_tloc[tid * n_bins_total + get_offset(bin_idx, j)]; } }); } @@ -86,10 +86,13 @@ GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info, auto n_bins_total = page->Cuts().TotalBins(); GetRowPtrFromEllpack(ctx, page, &this->row_ptr); - if (page->is_dense) { + if (page->IsDense()) { + auto offset = index.Offset(); common::DispatchBinType(this->index.GetBinTypeSize(), [&](auto dtype) { using T = decltype(dtype); - ::xgboost::SetIndexData(ctx, page, &hit_count_tloc_, index.MakeCompressor(), this); + ::xgboost::SetIndexData( + ctx, page, &hit_count_tloc_, + [offset](bst_bin_t bin_idx, bst_feature_t fidx) { return bin_idx + offset[fidx]; }, this); }); } else { // no compression diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 6c1a89079..3f17c97fb 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -189,7 +189,7 @@ class GHistIndexMatrix { * @brief Constructor for external memory. */ GHistIndexMatrix(SparsePage const& page, common::Span ft, - common::HistogramCuts cuts, int32_t max_bins_per_feat, bool is_dense, + common::HistogramCuts cuts, bst_bin_t max_bins_per_feat, bool is_dense, double sparse_thresh, std::int32_t n_threads); GHistIndexMatrix(); // also for ext mem, empty ctor so that we can read the cache back. diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index d46f044ae..145349359 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -12,18 +12,17 @@ namespace xgboost::data { void GradientIndexPageSource::Fetch() { if (!this->ReadCache()) { - if (count_ != 0 && !sync_) { - // source is initialized to be the 0th page during construction, so when count_ is 0 - // there's no need to increment the source. - // + // source is initialized to be the 0th page during construction, so when count_ is 0 + // there's no need to increment the source. + if (this->count_ != 0 && !this->sync_) { // The mixin doesn't sync the source if `sync_` is false, we need to sync it // ourselves. ++(*source_); } // This is not read from cache so we still need it to be synced with sparse page source. 
- CHECK_EQ(count_, source_->Iter()); - auto const& csr = source_->Page(); - CHECK_NE(cuts_.Values().size(), 0); + CHECK_EQ(this->count_, this->source_->Iter()); + auto const& csr = this->source_->Page(); + CHECK_NE(this->cuts_.Values().size(), 0); this->page_.reset(new GHistIndexMatrix{*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_, nthreads_}); this->WriteCache(); diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index 843dacbfa..f7588fe98 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -68,7 +68,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, dh::device_vector row_counts(rows + 1, 0); common::Span row_counts_span(row_counts.data().get(), row_counts.size()); cuda_impl::Dispatch(proxy, [=](auto const& value) { - return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); + return GetRowCounts(ctx, value, row_counts_span, dh::GetDevice(ctx), missing); }); auto is_dense = this->IsDense(); diff --git a/src/data/quantile_dmatrix.cu b/src/data/quantile_dmatrix.cu index ed7066412..605040ef0 100644 --- a/src/data/quantile_dmatrix.cu +++ b/src/data/quantile_dmatrix.cu @@ -72,7 +72,7 @@ void MakeSketches(Context const* ctx, collective::Op::kMax); SafeColl(rc); } else { - CHECK_EQ(ext_info.n_features, ::xgboost::data::BatchColumns(proxy)) + CHECK_EQ(ext_info.n_features, data::BatchColumns(proxy)) << "Inconsistent number of columns."; } @@ -97,7 +97,7 @@ void MakeSketches(Context const* ctx, lazy_init_sketch(); // Add a new level. } proxy->Info().weights_.SetDevice(dh::GetDevice(ctx)); - cuda_impl::Dispatch(proxy, [&](auto const& value) { + Dispatch(proxy, [&](auto const& value) { common::AdapterDeviceSketch(p_ctx, value, p.max_bin, proxy->Info(), missing, sketches.back().first.get()); sketches.back().second++; @@ -110,8 +110,8 @@ void MakeSketches(Context const* ctx, dh::device_vector row_counts(batch_rows + 1, 0); common::Span row_counts_span(row_counts.data().get(), row_counts.size()); ext_info.row_stride = - std::max(ext_info.row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) { - return GetRowCounts(value, row_counts_span, dh::GetDevice(ctx), missing); + std::max(ext_info.row_stride, Dispatch(proxy, [=](auto const& value) { + return GetRowCounts(ctx, value, row_counts_span, dh::GetDevice(ctx), missing); })); ext_info.nnz += thrust::reduce(ctx->CUDACtx()->CTP(), row_counts.begin(), row_counts.end()); ext_info.n_batches++; diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc index 6247d66b3..724260512 100644 --- a/src/data/sparse_page_source.cc +++ b/src/data/sparse_page_source.cc @@ -10,9 +10,9 @@ namespace xgboost::data { void Cache::Commit() { - if (!written) { - std::partial_sum(offset.begin(), offset.end(), offset.begin()); - written = true; + if (!this->written) { + std::partial_sum(this->offset.begin(), this->offset.end(), this->offset.begin()); + this->written = true; } } diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 2f37aa413..471a84d60 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -241,6 +241,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol float missing_; std::int32_t nthreads_; bst_feature_t n_features_; + bst_idx_t fetch_cnt_{0}; // Used for sanity check. // Index to the current page. std::uint32_t count_{0}; // Total number of batches. 
@@ -267,8 +268,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol if (ring_->empty()) { ring_->resize(n_batches_); } - // An heuristic for number of pre-fetched batches. We can make it part of BatchParam - // to let user adjust number of pre-fetched batches when needed. + std::int32_t n_prefetches = std::min(nthreads_, this->param_.n_prefetch_batches); n_prefetches = std::max(n_prefetches, 1); std::int32_t n_prefetch_batches = std::min(static_cast(n_prefetches), n_batches_); @@ -277,14 +277,23 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol std::size_t fetch_it = count_; exce_.Rethrow(); + // Clear out the existing page before loading new ones. This helps reduce memory usage + // when page is not loaded with mmap, in addition, it triggers necessary CUDA + // synchronizations by freeing memory. + page_.reset(); for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { + bool restart = fetch_it == n_batches_; fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { continue; } auto const* self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); + // Make sure the new iteration starts with a copy to avoid spilling configuration. + if (restart) { + this->param_.prefetch_copy = true; + } ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, this] { auto page = std::make_shared(); this->exce_.Run([&] { @@ -298,17 +307,17 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol }); return page; }); + this->fetch_cnt_++; } CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; - monitor_.Start("Wait"); + monitor_.Start("Wait-" + std::to_string(count_)); CHECK((*ring_)[count_].valid()); page_ = (*ring_)[count_].get(); - CHECK(!(*ring_)[count_].valid()); - monitor_.Stop("Wait"); + monitor_.Stop("Wait-" + std::to_string(count_)); exce_.Rethrow(); @@ -328,8 +337,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol timer.Stop(); // Not entirely accurate, the kernels doesn't have to flush the data. - LOG(INFO) << static_cast(bytes) / 1024.0 / 1024.0 << " MB written in " - << timer.ElapsedSeconds() << " seconds."; + LOG(INFO) << common::HumanMemUnit(bytes) << " written in " << timer.ElapsedSeconds() + << " seconds."; cache_info_->Push(bytes); } @@ -373,7 +382,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol return at_end_; } - // Call this at the last iteration. + // Call this at the last iteration (it == n_batches). void EndIter() { CHECK_EQ(this->cache_info_->offset.size(), this->n_batches_ + 1); this->cache_info_->Commit(); @@ -387,18 +396,22 @@ class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPol virtual void Reset(BatchParam const& param) { TryLockGuard guard{single_threaded_}; - this->at_end_ = false; - auto cnt = this->count_; - this->count_ = 0; + auto at_end = false; + std::swap(this->at_end_, at_end); + bool changed = this->param_.n_prefetch_batches != param.n_prefetch_batches; this->param_ = param; - if (cnt != 0 || changed) { + this->count_ = 0; + + if (!at_end || changed) { // The last iteration did not get to the end, clear the ring to start from 0. this->ring_ = std::make_unique(); - this->Fetch(); } + this->Fetch(); // Get the 0^th page, prefetch the next page. 
} + + [[nodiscard]] auto FetchCount() const { return this->fetch_cnt_; } }; #if defined(XGBOOST_USE_CUDA) @@ -413,10 +426,8 @@ class SparsePageSource : public SparsePageSourceImpl { DataIterProxy iter_; DMatrixProxy* proxy_; std::size_t base_row_id_{0}; - bst_idx_t fetch_cnt_{0}; // Used for sanity check. void Fetch() final { - fetch_cnt_++; page_ = std::make_shared(); // The first round of reading, this is responsible for initialization. if (!this->ReadCache()) { @@ -467,9 +478,10 @@ class SparsePageSource : public SparsePageSourceImpl { if (at_end_) { this->EndIter(); this->proxy_ = nullptr; + } else { + this->Fetch(); } - this->Fetch(); return *this; } @@ -481,13 +493,13 @@ class SparsePageSource : public SparsePageSourceImpl { SparsePageSourceImpl::Reset(param); TryLockGuard guard{single_threaded_}; - base_row_id_ = 0; + this->base_row_id_ = 0; } - - [[nodiscard]] auto FetchCount() const { return fetch_cnt_; } }; -// A mixin for advancing the iterator. +/** + * @brief A mixin for advancing the iterator with a sparse page source. + */ template > class PageSourceIncMixIn : public SparsePageSourceImpl { @@ -496,7 +508,7 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { using Super = SparsePageSourceImpl; // synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page // so we avoid fetching it. - bool sync_{true}; + bool const sync_; public: PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features, @@ -506,8 +518,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { // can assume the source to be ready. [[nodiscard]] PageSourceIncMixIn& operator++() final { TryLockGuard guard{this->single_threaded_}; + // Increment the source. - if (sync_) { + if (this->sync_) { ++(*source_); } // Increment self. @@ -516,24 +529,16 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { this->at_end_ = this->count_ == this->n_batches_; if (this->at_end_) { - // If this is the first round of iterations, we have just built the binary cache - // from soruce. For a non-sync page type, the source hasn't been updated to the end - // iteration yet due to skipped increment. We increment the source here and it will - // call the `EndIter` method itself. - bool src_need_inc = !sync_ && this->source_->Iter() != 0; - if (src_need_inc) { - CHECK_EQ(this->source_->Iter(), this->count_ - 1); - ++(*source_); - } this->EndIter(); - - if (src_need_inc) { - CHECK(this->cache_info_->written); + CHECK(this->cache_info_->written); + if (!this->sync_) { + source_.reset(); // Make sure no unnecessary fetch. } + } else { + this->Fetch(); } - this->Fetch(); - if (sync_) { + if (this->sync_) { // Sanity check. 
CHECK_EQ(source_->Iter(), this->count_); } @@ -541,7 +546,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { } void Reset(BatchParam const& param) final { - this->source_->Reset(param); + if (this->sync_ || !this->cache_info_->written) { + this->source_->Reset(param); + } Super::Reset(param); } }; @@ -625,8 +632,9 @@ class ExtQantileSourceMixin : public SparsePageSourceImpl CHECK(this->cache_info_->written); source_ = nullptr; // release the source + } else { + this->Fetch(); } - this->Fetch(); return *this; } diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 7f1f79dee..d50f7284e 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -24,16 +24,20 @@ __host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) { return {lhs.first + rhs.first, lhs.second + rhs.second}; } +XGBOOST_DEV_INLINE bst_feature_t FeatIdx(FeatureGroup const& group, bst_idx_t idx, + std::int32_t feature_stride) { + auto fidx = group.start_feature + idx % feature_stride; + return fidx; +} + XGBOOST_DEV_INLINE bst_idx_t IterIdx(EllpackDeviceAccessor const& matrix, - RowPartitioner::RowIndexT ridx, FeatureGroup const& group, - bst_idx_t idx, std::int32_t feature_stride) { + RowPartitioner::RowIndexT ridx, bst_feature_t fidx) { // ridx_local = ridx - base_rowid <== Row index local to each batch // entry_idx = ridx_local * row_stride <== Starting entry index for this row in the matrix // entry_idx += start_feature <== Inside a row, first column inside this feature group // idx % feature_stride <== The feaature index local to the current feature group // entry_idx += idx % feature_stride <== Final index. - return (ridx - matrix.base_rowid) * matrix.row_stride + group.start_feature + - idx % feature_stride; + return (ridx - matrix.base_rowid) * matrix.row_stride + fidx; } } // anonymous namespace @@ -134,7 +138,7 @@ XGBOOST_DEV_INLINE void AtomicAddGpairGlobal(xgboost::GradientPairInt64* dest, *reinterpret_cast(&h)); } -template class HistogramAgent { GradientPairInt64* smem_arr_; @@ -159,7 +163,7 @@ class HistogramAgent { d_ridx_(d_ridx.data()), group_(group), matrix_(matrix), - feature_stride_(matrix.is_dense ? group.num_features : matrix.row_stride), + feature_stride_(kIsDense ? group.num_features : matrix.row_stride), n_elements_(feature_stride_ * d_ridx.size()), rounding_(rounding), d_gpair_(d_gpair) {} @@ -169,12 +173,19 @@ class HistogramAgent { idx < std::min(offset + kBlockThreads * kItemsPerTile, n_elements_); idx += kBlockThreads) { Idx ridx = d_ridx_[idx / feature_stride_]; - bst_bin_t gidx = matrix_.gidx_iter[IterIdx(matrix_, ridx, group_, idx, feature_stride_)]; - if (matrix_.is_dense || gidx != matrix_.NullValue()) { + auto fidx = FeatIdx(group_, idx, feature_stride_); + bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)]; + if (kIsDense || compressed_bin != matrix_.NullValue()) { auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]); // Subtract start_bin to write to group-local histogram. If this is not a dense // matrix, then start_bin is 0 since featuregrouping doesn't support sparse data. 
- AtomicAddGpairShared(smem_arr_ + gidx - group_.start_bin, adjusted); + if (kIsDense) { + AtomicAddGpairShared( + smem_arr_ + compressed_bin + this->matrix_.feature_segments[fidx] - group_.start_bin, + adjusted); + } else { + AtomicAddGpairShared(smem_arr_ + compressed_bin - group_.start_bin, adjusted); + } } } } @@ -185,7 +196,7 @@ class HistogramAgent { __device__ void ProcessFullTileShared(std::size_t offset) { std::size_t idx[kItemsPerThread]; Idx ridx[kItemsPerThread]; - int gidx[kItemsPerThread]; + bst_bin_t gidx[kItemsPerThread]; GradientPair gpair[kItemsPerThread]; #pragma unroll for (int i = 0; i < kItemsPerThread; i++) { @@ -198,11 +209,17 @@ class HistogramAgent { #pragma unroll for (int i = 0; i < kItemsPerThread; i++) { gpair[i] = d_gpair_[ridx[i]]; - gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], group_, idx[i], feature_stride_)]; + auto fidx = FeatIdx(group_, idx[i], feature_stride_); + if (kIsDense) { + gidx[i] = + matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)] + matrix_.feature_segments[fidx]; + } else { + gidx[i] = matrix_.gidx_iter[IterIdx(matrix_, ridx[i], fidx)]; + } } #pragma unroll for (int i = 0; i < kItemsPerThread; i++) { - if ((matrix_.is_dense || gidx[i] != matrix_.NullValue())) { + if ((kIsDense || gidx[i] != matrix_.NullValue())) { auto adjusted = rounding_.ToFixedPoint(gpair[i]); AtomicAddGpairShared(smem_arr_ + gidx[i] - group_.start_bin, adjusted); } @@ -229,16 +246,22 @@ class HistogramAgent { __device__ void BuildHistogramWithGlobal() { for (auto idx : dh::GridStrideRange(static_cast(0), n_elements_)) { Idx ridx = d_ridx_[idx / feature_stride_]; - bst_bin_t gidx = matrix_.gidx_iter[IterIdx(matrix_, ridx, group_, idx, feature_stride_)]; - if (matrix_.is_dense || gidx != matrix_.NullValue()) { + auto fidx = FeatIdx(group_, idx, feature_stride_); + bst_bin_t compressed_bin = matrix_.gidx_iter[IterIdx(matrix_, ridx, fidx)]; + if (kIsDense || compressed_bin != matrix_.NullValue()) { auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]); - AtomicAddGpairGlobal(d_node_hist_ + gidx, adjusted); + if (kIsDense) { + auto start_bin = this->matrix_.feature_segments[fidx]; + AtomicAddGpairGlobal(d_node_hist_ + compressed_bin + start_bin, adjusted); + } else { + AtomicAddGpairGlobal(d_node_hist_ + compressed_bin, adjusted); + } } } } }; -template +template __global__ void __launch_bounds__(kBlockThreads) SharedMemHistKernel(const EllpackDeviceAccessor matrix, const FeatureGroupsAccessor feature_groups, @@ -249,8 +272,8 @@ __global__ void __launch_bounds__(kBlockThreads) extern __shared__ char smem[]; const FeatureGroup group = feature_groups[blockIdx.y]; auto smem_arr = reinterpret_cast(smem); - auto agent = HistogramAgent(smem_arr, d_node_hist, group, matrix, - d_ridx, rounding, d_gpair); + auto agent = HistogramAgent( + smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair); if (use_shared_memory_histograms) { agent.BuildHistogramWithShared(); } else { @@ -265,11 +288,22 @@ constexpr std::int32_t ItemsPerTile() { return kBlockThreads * kItemsPerThread; } // namespace // Use auto deduction guide to workaround compiler error. -template , - auto Shared = SharedMemHistKernel> +template , + auto Global = SharedMemHistKernel, + auto SharedDense = SharedMemHistKernel, + auto Shared = SharedMemHistKernel> struct HistogramKernel { - decltype(Global) global_kernel{SharedMemHistKernel}; - decltype(Shared) shared_kernel{SharedMemHistKernel}; + // Kernel for working with dense Ellpack using the global memory. 
+ decltype(Global) global_dense_kernel{ + SharedMemHistKernel}; + // Kernel for working with sparse Ellpack using the global memory. + decltype(Global) global_kernel{SharedMemHistKernel}; + // Kernel for working with dense Ellpack using the shared memory. + decltype(Shared) shared_dense_kernel{ + SharedMemHistKernel}; + // Kernel for working with sparse Ellpack using the shared memory. + decltype(Shared) shared_kernel{SharedMemHistKernel}; + bool shared{false}; std::uint32_t grid_size{0}; std::size_t smem_size{0}; @@ -303,28 +337,30 @@ struct HistogramKernel { // maximum number of blocks this->grid_size = n_blocks_per_mp * n_mps; }; - - init(this->global_kernel); - init(this->shared_kernel); + // Initialize all kernel instantiations + for (auto& kernel : {global_dense_kernel, global_kernel, shared_dense_kernel, shared_kernel}) { + init(kernel); + } } }; class DeviceHistogramBuilderImpl { std::unique_ptr> kernel_{nullptr}; - bool force_global_memory_{false}; public: void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups, bool force_global_memory) { this->kernel_ = std::make_unique>(ctx, feature_groups, force_global_memory); - this->force_global_memory_ = force_global_memory; + if (force_global_memory) { + CHECK(!this->kernel_->shared); + } } void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span d_ridx, - common::Span histogram, GradientQuantiser rounding) { + common::Span histogram, GradientQuantiser rounding) const { CHECK(kernel_); // Otherwise launch blocks such that each block has a minimum amount of work to do // There are fixed costs to launching each block, e.g. zeroing shared memory @@ -338,17 +374,26 @@ class DeviceHistogramBuilderImpl { auto constexpr kMinItemsPerBlock = ItemsPerTile(); auto grid_size = std::min(kernel_->grid_size, static_cast(common::DivRoundUp( items_per_group, kMinItemsPerBlock))); + auto launcher = [&](auto kernel) { + dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT + static_cast(kBlockThreads), kernel_->smem_size, ctx->Stream()}( + kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding); + }; - if (this->force_global_memory_ || !this->kernel_->shared) { - dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT - static_cast(kBlockThreads), kernel_->smem_size, - ctx->Stream()}(kernel_->global_kernel, matrix, feature_groups, d_ridx, - histogram.data(), gpair.data(), rounding); + if (!this->kernel_->shared) { + CHECK_EQ(this->kernel_->smem_size, 0); + if (matrix.is_dense) { + launcher(this->kernel_->global_dense_kernel); + } else { + launcher(this->kernel_->global_kernel); + } } else { - dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT - static_cast(kBlockThreads), kernel_->smem_size, - ctx->Stream()}(kernel_->shared_kernel, matrix, feature_groups, d_ridx, - histogram.data(), gpair.data(), rounding); + CHECK_NE(this->kernel_->smem_size, 0); + if (matrix.is_dense) { + launcher(this->kernel_->shared_dense_kernel); + } else { + launcher(this->kernel_->shared_kernel); + } } } }; diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index 95a00fd79..55e398e1b 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -172,8 +172,8 @@ class DeviceHistogramBuilder { // Attempt to do subtraction trick // return true if succeeded - [[nodiscard]] bool SubtractionTrick(bst_node_t nidx_parent, bst_node_t 
nidx_histogram, - bst_node_t nidx_subtraction) { + [[nodiscard]] bool SubtractionTrick(Context const* ctx, bst_node_t nidx_parent, + bst_node_t nidx_histogram, bst_node_t nidx_subtraction) { if (!hist_.HistogramExists(nidx_histogram) || !hist_.HistogramExists(nidx_parent)) { return false; } @@ -181,13 +181,13 @@ class DeviceHistogramBuilder { auto d_node_hist_histogram = hist_.GetNodeHistogram(nidx_histogram); auto d_node_hist_subtraction = hist_.GetNodeHistogram(nidx_subtraction); - dh::LaunchN(d_node_hist_parent.size(), [=] __device__(size_t idx) { + dh::LaunchN(d_node_hist_parent.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); return true; } - [[nodiscard]] auto SubtractHist(std::vector const& candidates, + [[nodiscard]] auto SubtractHist(Context const* ctx, std::vector const& candidates, std::vector const& build_nidx, std::vector const& subtraction_nidx) { this->monitor_.Start(__func__); @@ -197,7 +197,7 @@ class DeviceHistogramBuilder { auto subtraction_trick_nidx = subtraction_nidx.at(i); auto parent_nidx = candidates.at(i).nid; - if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { + if (!this->SubtractionTrick(ctx, parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { need_build.push_back(subtraction_trick_nidx); } } diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 0101be085..8eb5fb7f7 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -129,7 +129,7 @@ struct WriteResultsFunctor { * @param d_batch_info Node data, with the size of the input number of nodes. */ template -void SortPositionBatch(common::Span> d_batch_info, +void SortPositionBatch(Context const* ctx, common::Span> d_batch_info, common::Span ridx, common::Span ridx_tmp, common::Span d_counts, bst_idx_t total_rows, OpT op, @@ -150,17 +150,28 @@ void SortPositionBatch(common::Span> d_batch_info, return IndexFlagTuple{static_cast(item_idx), go_left, nidx_in_batch, go_left}; }); - std::size_t temp_bytes = 0; - // Restriction imposed by cub. - CHECK_LE(total_rows, static_cast(std::numeric_limits::max())); + // Avoid using int as the offset type + std::size_t n_bytes = 0; if (tmp->empty()) { - dh::safe_cuda(cub::DeviceScan::InclusiveScan( - nullptr, temp_bytes, input_iterator, discard_write_iterator, IndexFlagOp{}, total_rows)); - tmp->resize(temp_bytes); + auto ret = + cub::DispatchScan::Dispatch(nullptr, n_bytes, input_iterator, + discard_write_iterator, + IndexFlagOp{}, cub::NullType{}, + total_rows, + ctx->CUDACtx()->Stream()); + dh::safe_cuda(ret); + tmp->resize(n_bytes); } - temp_bytes = tmp->size(); - dh::safe_cuda(cub::DeviceScan::InclusiveScan(tmp->data(), temp_bytes, input_iterator, - discard_write_iterator, IndexFlagOp{}, total_rows)); + n_bytes = tmp->size(); + auto ret = + cub::DispatchScan::Dispatch(tmp->data(), n_bytes, input_iterator, + discard_write_iterator, + IndexFlagOp{}, cub::NullType{}, + total_rows, + ctx->CUDACtx()->Stream()); + dh::safe_cuda(ret); constexpr int kBlockSize = 256; @@ -169,7 +180,8 @@ void SortPositionBatch(common::Span> d_batch_info, const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread); SortPositionCopyKernel - <<>>(batch_info_itr, ridx, ridx_tmp, total_rows); + <<CUDACtx()->Stream()>>>(batch_info_itr, ridx, ridx_tmp, + total_rows); } struct NodePositionInfo { @@ -293,7 +305,7 @@ class RowPartitioner { * second. 
   Returns true if this training instance goes on the left partition. */
  template 
-  void UpdatePositionBatch(const std::vector& nidx,
+  void UpdatePositionBatch(Context const* ctx, const std::vector& nidx,
                            const std::vector& left_nidx,
                            const std::vector& right_nidx,
                            const std::vector& op_data, UpdatePositionOpT op) {
@@ -316,21 +328,22 @@ class RowPartitioner {
     }
     dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
                                   h_batch_info.size() * sizeof(PerNodeData),
-                                  cudaMemcpyDefault));
+                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
 
     // Temporary arrays
-    auto h_counts = pinned_.GetSpan(nidx.size(), 0);
+    auto h_counts = pinned_.GetSpan(nidx.size());
+    // Must initialize with 0 as 0 count is not written in the kernel.
     dh::TemporaryArray d_counts(nidx.size(), 0);
 
     // Partition the rows according to the operator
-    SortPositionBatch(dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
+    SortPositionBatch(ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
                       dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts), total_rows, op, &tmp_);
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
-                                  cudaMemcpyDefault));
+                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
     // TODO(Rory): this synchronisation hurts performance a lot
     // Future optimisation should find a way to skip this
-    dh::DefaultStream().Sync();
+    ctx->CUDACtx()->Stream().Sync();
 
     // Update segments
     for (std::size_t i = 0; i < nidx.size(); i++) {
@@ -341,9 +354,9 @@ class RowPartitioner {
                                std::max(left_nidx[i], right_nidx[i]) + 1));
       ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]};
       ridx_segments_[left_nidx[i]] =
-          NodePositionInfo{Segment(segment.begin, segment.begin + left_count)};
+          NodePositionInfo{Segment{segment.begin, segment.begin + left_count}};
       ridx_segments_[right_nidx[i]] =
-          NodePositionInfo{Segment(segment.begin + left_count, segment.end)};
+          NodePositionInfo{Segment{segment.begin + left_count, segment.end}};
     }
   }
 
diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index f0e353e22..0fdc30822 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -119,17 +119,15 @@ struct DeviceSplitCandidate {
 };
 
 namespace cuda_impl {
+constexpr auto DftPrefetchBatches() { return 2; }
+
 inline BatchParam HistBatch(TrainParam const& param) {
   auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
   p.prefetch_copy = true;
-  p.n_prefetch_batches = 1;
+  p.n_prefetch_batches = DftPrefetchBatches();
   return p;
 }
 
-inline BatchParam HistBatch(bst_bin_t max_bin) {
-  return {max_bin, TrainParam::DftSparseThreshold()};
-}
-
 inline BatchParam ApproxBatch(TrainParam const& p, common::Span hess,
                               ObjInfo const& task) {
   return BatchParam{p.max_bin, hess, !task.const_hess};
@@ -139,7 +137,7 @@
 inline BatchParam StaticBatch(bool prefetch_copy) {
   BatchParam p;
   p.prefetch_copy = prefetch_copy;
-  p.n_prefetch_batches = 1;
+  p.n_prefetch_batches = DftPrefetchBatches();
   return p;
 }
 }  // namespace cuda_impl
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 283a8af1b..390422ce1 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -70,7 +70,6 @@ void AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
                  common::Span nodes_to_build, common::Span nodes_to_sub) {
   auto const& tree = *p_tree;
   std::size_t nidx_in_set{0};
-  double total{0.0}, smaller{0.0};
   auto p_build_nidx = nodes_to_build.data();
   auto p_sub_nidx = nodes_to_sub.data();
   for (auto& e : candidates) {
@@ -81,15 +80,12 @@ void AssignNodes(RegTree const* p_tree, GradientQuantiser const* quantizer,
     auto left_sum = quantizer->ToFloatingPoint(e.split.left_sum);
     auto right_sum = quantizer->ToFloatingPoint(e.split.right_sum);
     bool fewer_right = right_sum.GetHess() < left_sum.GetHess();
-    total += left_sum.GetHess() + right_sum.GetHess();
     if (fewer_right) {
       p_build_nidx[nidx_in_set] = tree[e.nid].RightChild();
       p_sub_nidx[nidx_in_set] = tree[e.nid].LeftChild();
-      smaller += right_sum.GetHess();
     } else {
       p_build_nidx[nidx_in_set] = tree[e.nid].LeftChild();
       p_sub_nidx[nidx_in_set] = tree[e.nid].RightChild();
-      smaller += left_sum.GetHess();
     }
     ++nidx_in_set;
   }
@@ -348,7 +344,7 @@ struct GPUHistMakerDevice {
     // This gives much better latency in a distributed setting when processing a large batch
     this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size());
     // Perform subtraction for sibiling nodes
-    auto need_build = this->histogram_.SubtractHist(candidates, build_nidx, subtraction_nidx);
+    auto need_build = this->histogram_.SubtractHist(ctx_, candidates, build_nidx, subtraction_nidx);
     if (need_build.empty()) {
       this->monitor.Stop(__func__);
       return;
@@ -383,12 +379,14 @@ struct GPUHistMakerDevice {
     BitVector decision_bits{dh::ToSpan(decision_storage)};
     BitVector missing_bits{dh::ToSpan(missing_storage)};
 
+    auto cuctx = this->ctx_->CUDACtx();
     dh::TemporaryArray split_data_storage(num_candidates);
     dh::safe_cuda(cudaMemcpyAsync(split_data_storage.data().get(), split_data.data(),
-                                  num_candidates * sizeof(NodeSplitData), cudaMemcpyDefault));
+                                  num_candidates * sizeof(NodeSplitData), cudaMemcpyDefault,
+                                  cuctx->Stream()));
     auto d_split_data = dh::ToSpan(split_data_storage);
 
-    dh::LaunchN(d_matrix.n_rows, [=] __device__(std::size_t ridx) mutable {
+    dh::LaunchN(d_matrix.n_rows, cuctx->Stream(), [=] __device__(std::size_t ridx) mutable {
       for (auto i = 0; i < num_candidates; i++) {
         auto const& data = d_split_data[i];
         auto const cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
@@ -421,7 +419,7 @@ struct GPUHistMakerDevice {
     CHECK_EQ(partitioners_.size(), 1)
         << "External memory with column split is not yet supported.";
     partitioners_.front()->UpdatePositionBatch(
-        nidx, left_nidx, right_nidx, split_data,
+        ctx_, nidx, left_nidx, right_nidx, split_data,
         [=] __device__(bst_uint ridx, int nidx_in_batch, NodeSplitData const& data) {
           auto const index = ridx * num_candidates + nidx_in_batch;
          bool go_left;
@@ -495,10 +493,11 @@ struct GPUHistMakerDevice {
       UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
     } else {
       partitioners_.at(k)->UpdatePositionBatch(
-          nidx, left_nidx, right_nidx, split_data,
+          ctx_, nidx, left_nidx, right_nidx, split_data,
          [=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
                         const NodeSplitData& data) { return go_left(ridx, data); });
     }
+
     monitor.Stop("UpdatePositionBatch");
 
     for (auto nidx : build_nidx) {
@@ -556,7 +555,7 @@ struct GPUHistMakerDevice {
       return;
     }
 
-    dh::caching_device_vector categories;
+    dh::CachingDeviceUVector categories;
     dh::CopyTo(p_tree->GetSplitCategories(), &categories, this->ctx_->CUDACtx()->Stream());
     auto const& cat_segments = p_tree->GetSplitCategoriesPtr();
     auto d_categories = dh::ToSpan(categories);
@@ -575,7 +574,7 @@ struct GPUHistMakerDevice {
     }
 
     auto go_left_op = GoLeftOp{d_matrix};
-    dh::caching_device_vector d_split_data;
+    dh::CachingDeviceUVector d_split_data;
     dh::CopyTo(split_data, &d_split_data, this->ctx_->CUDACtx()->Stream());
     auto s_split_data = dh::ToSpan(d_split_data);
 
@@ -610,7 +609,7 @@ struct GPUHistMakerDevice {
 
     // Use the nodes from tree, the leaf value might be changed by the objective since the
    // last update tree call.
-    dh::caching_device_vector nodes;
+    dh::CachingDeviceUVector nodes;
     dh::CopyTo(p_tree->GetNodes(), &nodes, this->ctx_->CUDACtx()->Stream());
     common::Span d_nodes = dh::ToSpan(nodes);
     CHECK_EQ(out_preds_d.Shape(1), 1);
@@ -820,6 +819,7 @@ class GPUHistMaker : public TreeUpdater {
   }
 
   void InitDataOnce(TrainParam const* param, DMatrix* p_fmat) {
+    monitor_.Start(__func__);
     CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
 
     // Synchronise the column sampling seed
@@ -840,24 +840,22 @@ class GPUHistMaker : public TreeUpdater {
     p_last_fmat_ = p_fmat;
     initialised_ = true;
+    monitor_.Stop(__func__);
   }
 
   void InitData(TrainParam const* param, DMatrix* dmat, RegTree const* p_tree) {
+    monitor_.Start(__func__);
     if (!initialised_) {
-      monitor_.Start("InitDataOnce");
       this->InitDataOnce(param, dmat);
-      monitor_.Stop("InitDataOnce");
     }
     p_last_tree_ = p_tree;
     CHECK(hist_maker_param_.GetInitialised());
+    monitor_.Stop(__func__);
   }
 
   void UpdateTree(TrainParam const* param, HostDeviceVector* gpair, DMatrix* p_fmat,
                   RegTree* p_tree, HostDeviceVector* p_out_position) {
-    monitor_.Start("InitData");
     this->InitData(param, p_fmat, p_tree);
-    monitor_.Stop("InitData");
-
     gpair->SetDevice(ctx_->Device());
     maker->UpdateTree(gpair, p_fmat, task_, p_tree, p_out_position);
   }
diff --git a/tests/cpp/common/test_common.cc b/tests/cpp/common/test_common.cc
new file mode 100644
index 000000000..abc760ec2
--- /dev/null
+++ b/tests/cpp/common/test_common.cc
@@ -0,0 +1,19 @@
+/**
+ * Copyright 2024, XGBoost Contributors
+ */
+#include 
+
+#include "../../../src/common/common.h"
+
+namespace xgboost::common {
+TEST(Common, HumanMemUnit) {
+  auto name = HumanMemUnit(1024 * 1024 * 1024ul);
+  ASSERT_EQ(name, "1GB");
+  name = HumanMemUnit(1024 * 1024ul);
+  ASSERT_EQ(name, "1MB");
+  name = HumanMemUnit(1024);
+  ASSERT_EQ(name, "1KB");
+  name = HumanMemUnit(1);
+  ASSERT_EQ(name, "1B");
+}
+}  // namespace xgboost::common
diff --git a/tests/cpp/data/test_device_adapter.cu b/tests/cpp/data/test_device_adapter.cu
index 61cc9463c..f0bb2b7d5 100644
--- a/tests/cpp/data/test_device_adapter.cu
+++ b/tests/cpp/data/test_device_adapter.cu
@@ -1,9 +1,9 @@
-// Copyright (c) 2019 by Contributors
+/**
+ * Copyright 2019-2024, XGBoost contributors
+ */
 #include 
 #include 
 #include "../../../src/data/adapter.h"
-#include "../../../src/data/simple_dmatrix.h"
-#include "../../../src/common/timer.h"
 #include "../helpers.h"
 #include 
 #include "../../../src/data/device_adapter.cuh"
@@ -64,7 +64,7 @@ TEST(DeviceAdapter, GetRowCounts) {
   auto adapter = CupyAdapter{str_arr};
   HostDeviceVector offset(adapter.NumRows() + 1, 0);
   offset.SetDevice(ctx.Device());
-  auto rstride = GetRowCounts(adapter.Value(), offset.DeviceSpan(), ctx.Device(),
+  auto rstride = GetRowCounts(&ctx, adapter.Value(), offset.DeviceSpan(), ctx.Device(),
                               std::numeric_limits::quiet_NaN());
   ASSERT_EQ(rstride, n_features);
 }
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index 8a441d6ce..55375a5a7 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -30,13 +30,13 @@ TEST(EllpackPage, EmptyDMatrix) {
 }
 
 TEST(EllpackPage, BuildGidxDense) {
-  int constexpr kNRows = 16, kNCols = 8;
+  bst_idx_t n_samples = 16, n_features = 8;
   auto ctx = MakeCUDACtx(0);
-  auto page = BuildEllpackPage(&ctx, kNRows, kNCols);
+  auto page = BuildEllpackPage(&ctx, n_samples, n_features);
   std::vector h_gidx_buffer;
   auto h_accessor = page->GetHostAccessor(&ctx, &h_gidx_buffer);
-  ASSERT_EQ(page->row_stride, kNCols);
+  ASSERT_EQ(page->row_stride, n_features);
 
   std::vector solution = {
     0, 3, 8, 9, 14, 17, 20, 21,
@@ -56,8 +56,9 @@ TEST(EllpackPage, BuildGidxDense) {
     2, 4, 8, 10, 14, 15, 19, 22,
     1, 4, 7, 10, 14, 16, 19, 21,
   };
-  for (size_t i = 0; i < kNRows * kNCols; ++i) {
-    ASSERT_EQ(solution[i], h_accessor.gidx_iter[i]);
+  for (size_t i = 0; i < n_samples * n_features; ++i) {
+    auto fidx = i % n_features;
+    ASSERT_EQ(solution[i], h_accessor.gidx_iter[i] + h_accessor.feature_segments[fidx]);
   }
 }
 
@@ -263,12 +264,12 @@ class EllpackPageTest : public testing::TestWithParam {
     ASSERT_EQ(from_sparse_page->base_rowid, from_ghist->base_rowid);
     ASSERT_EQ(from_sparse_page->n_rows, from_ghist->n_rows);
     ASSERT_EQ(from_sparse_page->gidx_buffer.size(), from_ghist->gidx_buffer.size());
+    ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
     std::vector h_gidx_from_sparse, h_gidx_from_ghist;
     auto from_ghist_acc = from_ghist->GetHostAccessor(&gpu_ctx, &h_gidx_from_ghist);
     auto from_sparse_acc = from_sparse_page->GetHostAccessor(&gpu_ctx, &h_gidx_from_sparse);
-    ASSERT_EQ(from_sparse_page->NumSymbols(), from_ghist->NumSymbols());
     for (size_t i = 0; i < from_ghist->n_rows * from_ghist->row_stride; ++i) {
-      EXPECT_EQ(from_ghist_acc.gidx_iter[i], from_sparse_acc.gidx_iter[i]);
+      ASSERT_EQ(from_ghist_acc.gidx_iter[i], from_sparse_acc.gidx_iter[i]);
     }
   }
 }
diff --git a/tests/cpp/data/test_iterative_dmatrix.cu b/tests/cpp/data/test_iterative_dmatrix.cu
index 8797fc18d..8d2e837ff 100644
--- a/tests/cpp/data/test_iterative_dmatrix.cu
+++ b/tests/cpp/data/test_iterative_dmatrix.cu
@@ -106,9 +106,11 @@ TEST(IterativeDeviceDMatrix, RowMajor) {
   common::Span s_data{static_cast(loaded.data), cols * rows};
   dh::CopyDeviceSpanToVector(&h_data, s_data);
 
-  for(auto i = 0ull; i < rows * cols; i++) {
+  auto cut_ptr = h_accessor.feature_segments;
+  for (auto i = 0ull; i < rows * cols; i++) {
     int column_idx = i % cols;
-    EXPECT_EQ(impl->Cuts().SearchBin(h_data[i], column_idx), h_accessor.gidx_iter[i]);
+    EXPECT_EQ(impl->Cuts().SearchBin(h_data[i], column_idx),
+              h_accessor.gidx_iter[i] + cut_ptr[column_idx]);
   }
   EXPECT_EQ(m.Info().num_col_, cols);
   EXPECT_EQ(m.Info().num_row_, rows);
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc
index a557b7f62..f6991cfd5 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cc
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cc
@@ -12,6 +12,7 @@
 #include "../../../src/data/file_iterator.h"
 #include "../../../src/data/simple_dmatrix.h"
 #include "../../../src/data/sparse_page_dmatrix.h"
+#include "../../../src/tree/param.h"  // for TrainParam
 #include "../filesystem.h"  // dmlc::TemporaryDirectory
 #include "../helpers.h"
 
@@ -115,6 +116,47 @@ TEST(SparsePageDMatrix, RetainSparsePage) {
   TestRetainPage();
 }
 
+class TestGradientIndexExt : public ::testing::TestWithParam {
+ protected:
+  void Run(bool is_dense) {
+    constexpr bst_idx_t kRows = 64;
+    constexpr size_t kCols = 2;
+    float sparsity = is_dense ? 0.0 : 0.4;
+    bst_bin_t n_bins = 16;
+    Context ctx;
+    auto p_ext_fmat =
+        RandomDataGenerator{kRows, kCols, sparsity}.Batches(4).GenerateSparsePageDMatrix("temp",
+                                                                                         true);
+
+    auto cuts = common::SketchOnDMatrix(&ctx, p_ext_fmat.get(), n_bins, false, {});
+    std::vector> pages;
+    for (auto const &page : p_ext_fmat->GetBatches()) {
+      pages.emplace_back(std::make_unique(
+          page, common::Span{}, cuts, n_bins, is_dense, 0.8, ctx.Threads()));
+    }
+    std::int32_t k = 0;
+    for (auto const &page : p_ext_fmat->GetBatches(
+             &ctx, BatchParam{n_bins, tree::TrainParam::DftSparseThreshold()})) {
+      auto const &from_sparse = pages[k];
+      ASSERT_TRUE(std::equal(page.index.begin(), page.index.end(), from_sparse->index.begin()));
+      if (is_dense) {
+        ASSERT_TRUE(std::equal(page.index.Offset(), page.index.Offset() + kCols,
+                               from_sparse->index.Offset()));
+      } else {
+        ASSERT_FALSE(page.index.Offset());
+        ASSERT_FALSE(from_sparse->index.Offset());
+      }
+      ASSERT_TRUE(
+          std::equal(page.row_ptr.cbegin(), page.row_ptr.cend(), from_sparse->row_ptr.cbegin()));
+      ++k;
+    }
+  }
+};
+
+TEST_P(TestGradientIndexExt, Basic) { this->Run(this->GetParam()); }
+
+INSTANTIATE_TEST_SUITE_P(SparsePageDMatrix, TestGradientIndexExt, testing::Bool());
+
 // Test GHistIndexMatrix can avoid loading sparse page after the initialization.
 TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) {
   dmlc::TemporaryDirectory tmpdir;
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu
index 55151c807..ff65b6ae5 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cu
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cu
@@ -40,10 +40,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
 
 TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
   // Test Ellpack can avoid loading sparse page after the initialization.
-  dmlc::TemporaryDirectory tmpdir;
   std::size_t n_batches = 6;
-  auto Xy = RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix(
-      tmpdir.path + "/", true);
+  auto Xy =
+      RandomDataGenerator{180, 12, 0.0}.Batches(n_batches).GenerateSparsePageDMatrix("temp", true);
   auto ctx = MakeCUDACtx(0);
   auto cpu = ctx.MakeCPU();
   bst_bin_t n_bins{256};
@@ -117,7 +116,6 @@ TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
 TEST(SparsePageDMatrix, MultipleEllpackPages) {
   auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
-  dmlc::TemporaryDirectory tmpdir;
   auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
 
   // Loop over the batches and count the records
@@ -155,18 +153,24 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
     auto const& d_src = (*it).Impl()->gidx_buffer;
     dh::safe_cuda(cudaMemcpyAsync(d_dst, d_src.data(), d_src.size_bytes(), cudaMemcpyDefault));
   }
-  ASSERT_GE(iterators.size(), 2);
+  ASSERT_EQ(iterators.size(), 8);
 
   for (size_t i = 0; i < iterators.size(); ++i) {
     std::vector h_buf;
     [[maybe_unused]] auto h_acc = (*iterators[i]).Impl()->GetHostAccessor(&ctx, &h_buf);
     ASSERT_EQ(h_buf, gidx_buffers.at(i).HostVector());
-    ASSERT_EQ(iterators[i].use_count(), 1);
+    // The last page is still kept in the DMatrix until Reset is called.
+    if (i == iterators.size() - 1) {
+      ASSERT_EQ(iterators[i].use_count(), 2);
+    } else {
+      ASSERT_EQ(iterators[i].use_count(), 1);
+    }
   }
 
   // make sure it's const and the caller can not modify the content of page.
   for (auto& page : m->GetBatches(&ctx, param)) {
     static_assert(std::is_const_v>);
+    break;
   }
 
   // The above iteration clears out all references inside DMatrix.
@@ -190,13 +194,10 @@ class TestEllpackPageExt : public ::testing::TestWithParam
 GetBatches(&ctx, param).begin()).Impl();
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 5dee9c909..e26c8b980 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -73,13 +73,13 @@ TEST(Histogram, SubtractionTrack) {
   histogram.AllocateHistograms(&ctx, {0, 1, 2});
   GPUExpandEntry root;
   root.nid = 0;
-  auto need_build = histogram.SubtractHist({root}, {0}, {1});
+  auto need_build = histogram.SubtractHist(&ctx, {root}, {0}, {1});
 
   std::vector candidates(2);
   candidates[0].nid = 1;
   candidates[1].nid = 2;
 
-  need_build = histogram.SubtractHist(candidates, {3, 5}, {4, 6});
+  need_build = histogram.SubtractHist(&ctx, candidates, {3, 5}, {4, 6});
   ASSERT_EQ(need_build.size(), 2);
   ASSERT_EQ(need_build[0], 4);
   ASSERT_EQ(need_build[1], 6);
diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index 48e916efb..76d3c7d07 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -33,9 +33,9 @@ void TestUpdatePositionBatch() {
   std::vector extra_data = {0};
   // Send the first five training instances to the right node
  // and the second 5 to the left node
-  rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
-    return ridx > 4;
-  });
+  rp.UpdatePositionBatch(
+      &ctx, {0}, {1}, {2}, extra_data,
+      [=] __device__(RowPartitioner::RowIndexT ridx, int, int) { return ridx > 4; });
   rows = rp.GetRowsHost(1);
   for (auto r : rows) {
     EXPECT_GT(r, 4);
@@ -46,9 +46,9 @@
   }
 
   // Split the left node again
-  rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
-    return ridx < 7;
-  });
+  rp.UpdatePositionBatch(
+      &ctx, {1}, {3}, {4}, extra_data,
+      [=] __device__(RowPartitioner::RowIndexT ridx, int, int) { return ridx < 7; });
   EXPECT_EQ(rp.GetRows(3).size(), 2);
   EXPECT_EQ(rp.GetRows(4).size(), 3);
 }
@@ -56,6 +56,7 @@
 TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); }
 
 void TestSortPositionBatch(const std::vector& ridx_in, const std::vector& segments) {
+  auto ctx = MakeCUDACtx(0);
   thrust::device_vector ridx = ridx_in;
   thrust::device_vector ridx_tmp(ridx_in.size());
   thrust::device_vector counts(segments.size());
@@ -74,7 +75,7 @@ void TestSortPositionBatch(const std::vector& ridx_in, const std::vector
 ), cudaMemcpyDefault, nullptr));
   dh::DeviceUVector tmp;
-  SortPositionBatch(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
+  SortPositionBatch(&ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
                     dh::ToSpan(ridx_tmp), dh::ToSpan(counts), total_rows, op, &tmp);
 
@@ -145,7 +146,7 @@ void TestExternalMemory() {
   std::vector splits{tree[0]};
   auto acc = page.Impl()->GetDeviceAccessor(&ctx);
   partitioners.back()->UpdatePositionBatch(
-      {0}, {1}, {2}, splits,
+      &ctx, {0}, {1}, {2}, splits,
       [=] __device__(bst_idx_t ridx, std::int32_t nidx_in_batch, RegTree::Node const& node) {
         auto fvalue = acc.GetFvalue(ridx, node.SplitIndex());
         return fvalue <= node.SplitCond();