diff --git a/include/xgboost/data.h b/include/xgboost/data.h index ef8734136..6f8c818c8 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -113,7 +113,10 @@ class MetaInfo { MetaInfo Slice(common::Span ridxs) const; MetaInfo Copy() const; - + /** + * @brief Whether the matrix is dense. + */ + bool IsDense() const { return num_col_ * num_row_ == num_nonzero_; } /*! * \brief Get weight of each instances. * \param i Instance index. @@ -538,10 +541,10 @@ class DMatrix { /*! \brief virtual destructor */ virtual ~DMatrix(); - /*! \brief Whether the matrix is dense. */ - [[nodiscard]] bool IsDense() const { - return Info().num_nonzero_ == Info().num_row_ * Info().num_col_; - } + /** + * @brief Whether the matrix is dense. + */ + [[nodiscard]] bool IsDense() const { return this->Info().IsDense(); } /** * \brief Load DMatrix from URI. diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index a328a6120..2ceb35821 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -9,7 +9,6 @@ #include #include #include -#include #include "allreduce.h" #include "broadcast.h" diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 8f940500f..867d671e2 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -162,6 +162,17 @@ class HistogramCuts { } return vals[bin_idx - 1]; } + + void SetDevice(DeviceOrd d) const { + this->cut_ptrs_.SetDevice(d); + this->cut_ptrs_.ConstDevicePointer(); + + this->cut_values_.SetDevice(d); + this->cut_values_.ConstDevicePointer(); + + this->min_vals_.SetDevice(d); + this->min_vals_.ConstDevicePointer(); + } }; /** diff --git a/src/data/ellpack_page.cc b/src/data/ellpack_page.cc index 59cfd1943..d58364635 100644 --- a/src/data/ellpack_page.cc +++ b/src/data/ellpack_page.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023, XGBoost contributors + * Copyright 2019-2024, XGBoost contributors */ #ifndef XGBOOST_USE_CUDA @@ -7,15 +7,17 @@ #include +#include // for shared_ptr + // dummy implementation of EllpackPage in case CUDA is not used namespace xgboost { class EllpackPageImpl { - common::HistogramCuts cuts_; + std::shared_ptr cuts_; public: - [[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; } - [[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; } + [[nodiscard]] common::HistogramCuts const& Cuts() const { return *cuts_; } + [[nodiscard]] std::shared_ptr CutsShared() const { return cuts_; } }; EllpackPage::EllpackPage() = default; @@ -40,12 +42,6 @@ size_t EllpackPage::Size() const { return 0; } -[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() { - LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but " - "EllpackPage is required"; - return impl_->Cuts(); -} - [[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const { LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but " "EllpackPage is required"; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index d9ea85919..bfbb7f076 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -12,8 +12,8 @@ #include "../common/cuda_context.cuh" #include "../common/hist_util.cuh" #include "../common/transform_iterator.h" // MakeIndexTransformIter -#include "./ellpack_page.cuh" -#include "device_adapter.cuh" // for NoInfInData +#include "device_adapter.cuh" // for NoInfInData +#include "ellpack_page.cuh" #include "ellpack_page.h" #include "gradient_index.h" #include "xgboost/data.h" @@ -33,11 +33,6 @@ size_t EllpackPage::Size() const { return impl_->Size(); } void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); } -[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() { - CHECK(impl_); - return impl_->Cuts(); -} - [[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const { CHECK(impl_); return impl_->Cuts(); @@ -94,7 +89,8 @@ __global__ void CompressBinEllpackKernel( } // Construct an ELLPACK matrix with the given number of empty rows. -EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, bool is_dense, +EllpackPageImpl::EllpackPageImpl(DeviceOrd device, + std::shared_ptr cuts, bool is_dense, size_t row_stride, size_t n_rows) : is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) { monitor_.Init("ellpack_page"); @@ -105,12 +101,11 @@ EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, b monitor_.Stop("InitCompressedData"); } -EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, - const SparsePage &page, bool is_dense, - size_t row_stride, +EllpackPageImpl::EllpackPageImpl(DeviceOrd device, + std::shared_ptr cuts, + const SparsePage& page, bool is_dense, size_t row_stride, common::Span feature_types) - : cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()), - row_stride(row_stride) { + : cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()), row_stride(row_stride) { this->InitCompressedData(device); this->CreateHistIndices(device, page, feature_types); } @@ -127,9 +122,10 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP // Create the quantile sketches for the dmatrix and initialize HistogramCuts. row_stride = GetRowStride(dmat); if (!param.hess.empty()) { - cuts_ = common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess); + cuts_ = std::make_shared( + common::DeviceSketchWithHessian(ctx, dmat, param.max_bin, param.hess)); } else { - cuts_ = common::DeviceSketch(ctx, dmat, param.max_bin); + cuts_ = std::make_shared(common::DeviceSketch(ctx, dmat, param.max_bin)); } monitor_.Stop("Quantiles"); @@ -297,7 +293,7 @@ template EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd device, bool is_dense, common::Span row_counts_span, common::Span feature_types, size_t row_stride, - size_t n_rows, common::HistogramCuts const& cuts) { + size_t n_rows, std::shared_ptr cuts) { dh::safe_cuda(cudaSetDevice(device.ordinal)); *this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows); @@ -309,7 +305,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd de template EllpackPageImpl::EllpackPageImpl( \ __BATCH_T batch, float missing, DeviceOrd device, bool is_dense, \ common::Span row_counts_span, common::Span feature_types, \ - size_t row_stride, size_t n_rows, common::HistogramCuts const& cuts); + size_t row_stride, size_t n_rows, std::shared_ptr cuts); ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch) ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch) @@ -359,7 +355,11 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page, common::Span ft) - : is_dense{page.IsDense()}, base_rowid{page.base_rowid}, n_rows{page.Size()}, cuts_{page.cut} { + : is_dense{page.IsDense()}, + base_rowid{page.base_rowid}, + n_rows{page.Size()}, + // This makes a copy of the cut values. + cuts_{std::make_shared(page.cut)} { auto it = common::MakeIndexTransformIter( [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; }); row_stride = *std::max_element(it, it + page.Size()); diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index c64462082..a0fafbe74 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -23,20 +23,20 @@ struct EllpackDeviceAccessor { bool is_dense; /*! \brief Row length for ELLPACK, equal to number of features. */ size_t row_stride; - size_t base_rowid{}; - size_t n_rows{}; - common::CompressedIterator gidx_iter; + bst_idx_t base_rowid{0}; + bst_idx_t n_rows{0}; + common::CompressedIterator gidx_iter; /*! \brief Minimum value for each feature. Size equals to number of features. */ - common::Span min_fvalue; + common::Span min_fvalue; /*! \brief Histogram cut pointers. Size equals to (number of features + 1). */ - common::Span feature_segments; + common::Span feature_segments; /*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */ - common::Span gidx_fvalue_map; + common::Span gidx_fvalue_map; common::Span feature_types; - EllpackDeviceAccessor(DeviceOrd device, const common::HistogramCuts& cuts, bool is_dense, - size_t row_stride, size_t base_rowid, size_t n_rows, + EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr cuts, + bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows, common::CompressedIterator gidx_iter, common::Span feature_types) : is_dense(is_dense), @@ -46,16 +46,16 @@ struct EllpackDeviceAccessor { gidx_iter(gidx_iter), feature_types{feature_types} { if (device.IsCPU()) { - gidx_fvalue_map = cuts.cut_values_.ConstHostSpan(); - feature_segments = cuts.cut_ptrs_.ConstHostSpan(); - min_fvalue = cuts.min_vals_.ConstHostSpan(); + gidx_fvalue_map = cuts->cut_values_.ConstHostSpan(); + feature_segments = cuts->cut_ptrs_.ConstHostSpan(); + min_fvalue = cuts->min_vals_.ConstHostSpan(); } else { - cuts.cut_values_.SetDevice(device); - cuts.cut_ptrs_.SetDevice(device); - cuts.min_vals_.SetDevice(device); - gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan(); - feature_segments = cuts.cut_ptrs_.ConstDeviceSpan(); - min_fvalue = cuts.min_vals_.ConstDeviceSpan(); + cuts->cut_values_.SetDevice(device); + cuts->cut_ptrs_.SetDevice(device); + cuts->min_vals_.SetDevice(device); + gidx_fvalue_map = cuts->cut_values_.ConstDeviceSpan(); + feature_segments = cuts->cut_ptrs_.ConstDeviceSpan(); + min_fvalue = cuts->min_vals_.ConstDeviceSpan(); } } // Get a matrix element, uses binary search for look up Return NaN if missing @@ -142,13 +142,14 @@ class EllpackPageImpl { * This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo * and the given number of rows. */ - EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, bool is_dense, size_t row_stride, - size_t n_rows); + EllpackPageImpl(DeviceOrd device, std::shared_ptr cuts, + bool is_dense, size_t row_stride, size_t n_rows); /*! * \brief Constructor used for external memory. */ - EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, const SparsePage& page, - bool is_dense, size_t row_stride, common::Span feature_types); + EllpackPageImpl(DeviceOrd device, std::shared_ptr cuts, + const SparsePage& page, bool is_dense, size_t row_stride, + common::Span feature_types); /*! * \brief Constructor from an existing DMatrix. @@ -162,7 +163,7 @@ class EllpackPageImpl { explicit EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd device, bool is_dense, common::Span row_counts_span, common::Span feature_types, size_t row_stride, - size_t n_rows, common::HistogramCuts const& cuts); + size_t n_rows, std::shared_ptr cuts); /** * \brief Constructor from an existing CPU gradient index. */ @@ -194,8 +195,9 @@ class EllpackPageImpl { base_rowid = row_id; } - [[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; } - [[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; } + [[nodiscard]] common::HistogramCuts const& Cuts() const { return *cuts_; } + [[nodiscard]] std::shared_ptr CutsShared() const { return cuts_; } + void SetCuts(std::shared_ptr cuts) { cuts_ = cuts; } /*! \return Estimation of memory cost of this page. */ static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ; @@ -203,7 +205,7 @@ class EllpackPageImpl { /*! \brief Return the total number of symbols (total number of bins plus 1 for * not found). */ - [[nodiscard]] std::size_t NumSymbols() const { return cuts_.TotalBins() + 1; } + [[nodiscard]] std::size_t NumSymbols() const { return cuts_->TotalBins() + 1; } [[nodiscard]] EllpackDeviceAccessor GetDeviceAccessor( DeviceOrd device, common::Span feature_types = {}) const; @@ -225,19 +227,18 @@ class EllpackPageImpl { */ void InitCompressedData(DeviceOrd device); - -public: + public: /*! \brief Whether or not if the matrix is dense. */ bool is_dense; /*! \brief Row length for ELLPACK. */ size_t row_stride; - size_t base_rowid{0}; - size_t n_rows{}; + bst_idx_t base_rowid{0}; + bst_idx_t n_rows{}; /*! \brief global index of histogram, which is stored in ELLPACK format. */ HostDeviceVector gidx_buffer; private: - common::HistogramCuts cuts_; + std::shared_ptr cuts_; common::Monitor monitor_; }; diff --git a/src/data/ellpack_page.h b/src/data/ellpack_page.h index 07d6949b1..77d1124e0 100644 --- a/src/data/ellpack_page.h +++ b/src/data/ellpack_page.h @@ -49,7 +49,6 @@ class EllpackPage { [[nodiscard]] const EllpackPageImpl* Impl() const { return impl_.get(); } EllpackPageImpl* Impl() { return impl_.get(); } - [[nodiscard]] common::HistogramCuts& Cuts(); [[nodiscard]] common::HistogramCuts const& Cuts() const; private: diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 62b29640c..3bf528ea8 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -1,60 +1,78 @@ /** - * Copyright 2019-2023, XGBoost contributors + * Copyright 2019-2024, XGBoost contributors */ #include #include // for size_t +#include // for uint64_t #include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream #include "../common/ref_resource_view.h" // for ReadVec, WriteVec -#include "ellpack_page.cuh" -#include "histogram_cut_format.h" // for ReadHistogramCuts, WriteHistogramCuts -#include "sparse_page_writer.h" // for SparsePageFormat +#include "ellpack_page.cuh" // for EllpackPage +#include "ellpack_page_raw_format.h" namespace xgboost::data { DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format); -class EllpackPageRawFormat : public SparsePageFormat { - public: - bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override { - auto* impl = page->Impl(); - if (!ReadHistogramCuts(&impl->Cuts(), fi)) { - return false; - } - if (!fi->Read(&impl->n_rows)) { - return false; - } - if (!fi->Read(&impl->is_dense)) { - return false; - } - if (!fi->Read(&impl->row_stride)) { - return false; - } - if (!common::ReadVec(fi, &impl->gidx_buffer.HostVector())) { - return false; - } - if (!fi->Read(&impl->base_rowid)) { - return false; - } - dh::DefaultStream().Sync(); +namespace { +template +[[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi, HostDeviceVector* vec) { + std::uint64_t n{0}; + if (!fi->Read(&n)) { + return false; + } + if (n == 0) { return true; } - size_t Write(const EllpackPage& page, common::AlignedFileWriteStream* fo) override { - std::size_t bytes{0}; - auto* impl = page.Impl(); - bytes += WriteHistogramCuts(impl->Cuts(), fo); - bytes += fo->Write(impl->n_rows); - bytes += fo->Write(impl->is_dense); - bytes += fo->Write(impl->row_stride); - CHECK(!impl->gidx_buffer.ConstHostVector().empty()); - bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector()); - bytes += fo->Write(impl->base_rowid); - return bytes; - } -}; + auto expected_bytes = sizeof(T) * n; -XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw) - .describe("Raw ELLPACK binary data format.") - .set_body([]() { return new EllpackPageRawFormat(); }); + auto [ptr, n_bytes] = fi->Consume(expected_bytes); + if (n_bytes != expected_bytes) { + return false; + } + + vec->SetDevice(DeviceOrd::CUDA(0)); + vec->Resize(n); + auto d_vec = vec->DeviceSpan(); + dh::safe_cuda( + cudaMemcpyAsync(d_vec.data(), ptr, n_bytes, cudaMemcpyDefault, dh::DefaultStream())); + return true; +} +} // namespace + +[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, + common::AlignedResourceReadStream* fi) { + auto* impl = page->Impl(); + impl->SetCuts(this->cuts_); + if (!fi->Read(&impl->n_rows)) { + return false; + } + if (!fi->Read(&impl->is_dense)) { + return false; + } + if (!fi->Read(&impl->row_stride)) { + return false; + } + if (!ReadDeviceVec(fi, &impl->gidx_buffer)) { + return false; + } + if (!fi->Read(&impl->base_rowid)) { + return false; + } + return true; +} + +[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page, + common::AlignedFileWriteStream* fo) { + std::size_t bytes{0}; + auto* impl = page.Impl(); + bytes += fo->Write(impl->n_rows); + bytes += fo->Write(impl->is_dense); + bytes += fo->Write(impl->row_stride); + CHECK(!impl->gidx_buffer.ConstHostVector().empty()); + bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector()); + bytes += fo->Write(impl->base_rowid); + return bytes; +} } // namespace xgboost::data diff --git a/src/data/ellpack_page_raw_format.h b/src/data/ellpack_page_raw_format.h new file mode 100644 index 000000000..5825b4896 --- /dev/null +++ b/src/data/ellpack_page_raw_format.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2024, XGBoost contributors + */ +#pragma once + +#include // for size_t +#include // for shared_ptr +#include // for move + +#include "../common/io.h" // for AlignedResourceReadStream +#include "sparse_page_writer.h" // for SparsePageFormat +#include "xgboost/data.h" // for EllpackPage + +#if !defined(XGBOOST_USE_CUDA) +#include "../common/common.h" // for AssertGPUSupport +#endif // !defined(XGBOOST_USE_CUDA)` + +namespace xgboost::common { +class HistogramCuts; +} + +namespace xgboost::data { +class EllpackPageRawFormat : public SparsePageFormat { + std::shared_ptr cuts_; + + public: + explicit EllpackPageRawFormat(std::shared_ptr cuts) + : cuts_{std::move(cuts)} {} + [[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override; + [[nodiscard]] std::size_t Write(const EllpackPage& page, + common::AlignedFileWriteStream* fo) override; +}; + +#if !defined(XGBOOST_USE_CUDA) +inline bool EllpackPageRawFormat::Read(EllpackPage*, common::AlignedResourceReadStream*) { + common::AssertGPUSupport(); + return false; +} + +inline std::size_t EllpackPageRawFormat::Write(const EllpackPage&, + common::AlignedFileWriteStream*) { + common::AssertGPUSupport(); + return 0; +} +#endif // !defined(XGBOOST_USE_CUDA) +} // namespace xgboost::data diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 1144e7a2e..66500d58b 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -21,7 +21,7 @@ void EllpackPageSource::Fetch() { auto const &csr = source_->Page(); this->page_.reset(new EllpackPage{}); auto *impl = this->page_->Impl(); - *impl = EllpackPageImpl(device_, *cuts_, *csr, is_dense_, row_stride_, feature_types_); + *impl = EllpackPageImpl(device_, cuts_, *csr, is_dense_, row_stride_, feature_types_); page_->SetBaseRowId(csr->base_rowid); this->WriteCache(); } diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 53cb52233..f9aa128c7 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -1,35 +1,43 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_ #define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_ -#include +#include // for int32_t +#include // for shared_ptr +#include // for move -#include -#include -#include - -#include "../common/common.h" -#include "../common/hist_util.h" -#include "ellpack_page.h" // for EllpackPage -#include "sparse_page_source.h" +#include "../common/hist_util.h" // for HistogramCuts +#include "ellpack_page.h" // for EllpackPage +#include "ellpack_page_raw_format.h" // for EllpackPageRawFormat +#include "sparse_page_source.h" // for PageSourceIncMixIn +#include "xgboost/base.h" // for bst_idx_t +#include "xgboost/context.h" // for DeviceOrd +#include "xgboost/data.h" // for BatchParam +#include "xgboost/span.h" // for Span namespace xgboost::data { class EllpackPageSource : public PageSourceIncMixIn { bool is_dense_; - size_t row_stride_; + bst_idx_t row_stride_; BatchParam param_; common::Span feature_types_; - std::unique_ptr cuts_; + std::shared_ptr cuts_; DeviceOrd device_; + protected: + [[nodiscard]] SparsePageFormat* CreatePageFormat() const override { + cuts_->SetDevice(this->device_); + return new EllpackPageRawFormat{cuts_}; + } + public: - EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, - std::shared_ptr cache, BatchParam param, - std::unique_ptr cuts, bool is_dense, size_t row_stride, - common::Span feature_types, + EllpackPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features, + size_t n_batches, std::shared_ptr cache, BatchParam param, + std::shared_ptr cuts, bool is_dense, + bst_idx_t row_stride, common::Span feature_types, std::shared_ptr source, DeviceOrd device) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false), is_dense_{is_dense}, diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc index 542d3aaeb..cd0129372 100644 --- a/src/data/gradient_index_format.cc +++ b/src/data/gradient_index_format.cc @@ -1,106 +1,97 @@ /** - * Copyright 2021-2024 XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ -#include // for size_t -#include // for uint8_t -#include // for underlying_type_t -#include // for vector +#include "gradient_index_format.h" +#include // for size_t +#include // for uint8_t +#include // for underlying_type_t +#include // for vector + +#include "../common/hist_util.h" // for HistogramCuts #include "../common/io.h" // for AlignedResourceReadStream #include "../common/ref_resource_view.h" // for ReadVec, WriteVec #include "gradient_index.h" // for GHistIndexMatrix -#include "histogram_cut_format.h" // for ReadHistogramCuts -#include "sparse_page_writer.h" // for SparsePageFormat namespace xgboost::data { -class GHistIndexRawFormat : public SparsePageFormat { - public: - bool Read(GHistIndexMatrix* page, common::AlignedResourceReadStream* fi) override { - CHECK(fi); +[[nodiscard]] bool GHistIndexRawFormat::Read(GHistIndexMatrix* page, + common::AlignedResourceReadStream* fi) { + CHECK(fi); - if (!ReadHistogramCuts(&page->cut, fi)) { - return false; - } - - // indptr - if (!common::ReadVec(fi, &page->row_ptr)) { - return false; - } - - // data - // - bin type - // Old gcc doesn't support reading from enum. - std::underlying_type_t uint_bin_type{0}; - if (!fi->Read(&uint_bin_type)) { - return false; - } - common::BinTypeSize size_type = static_cast(uint_bin_type); - // - index buffer - if (!common::ReadVec(fi, &page->data)) { - return false; - } - // - index - page->index = - common::Index{common::Span{page->data.data(), static_cast(page->data.size())}, - size_type}; - - // hit count - if (!common::ReadVec(fi, &page->hit_count)) { - return false; - } - if (!fi->Read(&page->max_numeric_bins_per_feat)) { - return false; - } - if (!fi->Read(&page->base_rowid)) { - return false; - } - bool is_dense = false; - if (!fi->Read(&is_dense)) { - return false; - } - page->SetDense(is_dense); - if (is_dense) { - page->index.SetBinOffset(page->cut.Ptrs()); - } - - if (!page->ReadColumnPage(fi)) { - return false; - } - return true; + page->Cuts() = this->cuts_; + // indptr + if (!common::ReadVec(fi, &page->row_ptr)) { + return false; } - std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override { - std::size_t bytes = 0; - bytes += WriteHistogramCuts(page.cut, fo); - // indptr - bytes += common::WriteVec(fo, page.row_ptr); - - // data - // - bin type - std::underlying_type_t uint_bin_type = page.index.GetBinTypeSize(); - bytes += fo->Write(uint_bin_type); - // - index buffer - std::vector data(page.index.begin(), page.index.end()); - bytes += fo->Write(static_cast(data.size())); - if (!data.empty()) { - bytes += fo->Write(data.data(), data.size()); - } - - // hit count - bytes += common::WriteVec(fo, page.hit_count); - // max_bins, base row, is_dense - bytes += fo->Write(page.max_numeric_bins_per_feat); - bytes += fo->Write(page.base_rowid); - bytes += fo->Write(page.IsDense()); - - bytes += page.WriteColumnPage(fo); - return bytes; + // data + // - bin type + // Old gcc doesn't support reading from enum. + std::underlying_type_t uint_bin_type{0}; + if (!fi->Read(&uint_bin_type)) { + return false; } -}; + common::BinTypeSize size_type = static_cast(uint_bin_type); + // - index buffer + if (!common::ReadVec(fi, &page->data)) { + return false; + } + // - index + page->index = common::Index{ + common::Span{page->data.data(), static_cast(page->data.size())}, size_type}; + + // hit count + if (!common::ReadVec(fi, &page->hit_count)) { + return false; + } + if (!fi->Read(&page->max_numeric_bins_per_feat)) { + return false; + } + if (!fi->Read(&page->base_rowid)) { + return false; + } + bool is_dense = false; + if (!fi->Read(&is_dense)) { + return false; + } + page->SetDense(is_dense); + if (is_dense) { + page->index.SetBinOffset(page->cut.Ptrs()); + } + + if (!page->ReadColumnPage(fi)) { + return false; + } + return true; +} + +[[nodiscard]] std::size_t GHistIndexRawFormat::Write(GHistIndexMatrix const& page, + common::AlignedFileWriteStream* fo) { + std::size_t bytes = 0; + // indptr + bytes += common::WriteVec(fo, page.row_ptr); + + // data + // - bin type + std::underlying_type_t uint_bin_type = page.index.GetBinTypeSize(); + bytes += fo->Write(uint_bin_type); + // - index buffer + std::vector data(page.index.begin(), page.index.end()); + bytes += fo->Write(static_cast(data.size())); + if (!data.empty()) { + bytes += fo->Write(data.data(), data.size()); + } + + // hit count + bytes += common::WriteVec(fo, page.hit_count); + // max_bins, base row, is_dense + bytes += fo->Write(page.max_numeric_bins_per_feat); + bytes += fo->Write(page.base_rowid); + bytes += fo->Write(page.IsDense()); + + bytes += page.WriteColumnPage(fo); + return bytes; +} DMLC_REGISTRY_FILE_TAG(gradient_index_format); - -XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(raw) - .describe("Raw GHistIndex binary data format.") - .set_body([]() { return new GHistIndexRawFormat(); }); } // namespace xgboost::data diff --git a/src/data/gradient_index_format.h b/src/data/gradient_index_format.h new file mode 100644 index 000000000..438f189d1 --- /dev/null +++ b/src/data/gradient_index_format.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021-2024, XGBoost contributors + */ +#pragma once + +#include // for size_t +#include // for move + +#include "../common/hist_util.h" // for HistogramCuts +#include "../common/io.h" // for AlignedFileWriteStream +#include "gradient_index.h" // for GHistIndexMatrix +#include "sparse_page_writer.h" // for SparsePageFormat + +namespace xgboost::common { +class HistogramCuts; +} + +namespace xgboost::data { +class GHistIndexRawFormat : public SparsePageFormat { + common::HistogramCuts cuts_; + + public: + [[nodiscard]] bool Read(GHistIndexMatrix* page, common::AlignedResourceReadStream* fi) override; + [[nodiscard]] std::size_t Write(GHistIndexMatrix const& page, + common::AlignedFileWriteStream* fo) override; + + explicit GHistIndexRawFormat(common::HistogramCuts cuts) : cuts_{std::move(cuts)} {} +}; +} // namespace xgboost::data diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index 1b2ed3fdd..f1ceb282a 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include "gradient_index_page_source.h" diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index db71c1c6d..c525d51d1 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -1,27 +1,39 @@ -/*! - * Copyright 2021-2022 by XGBoost Contributors +/** + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ #define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ -#include -#include +#include // for isnan +#include // for int32_t +#include // for shared_ptr +#include // for move -#include "gradient_index.h" -#include "sparse_page_source.h" +#include "../common/hist_util.h" // for HistogramCuts +#include "gradient_index.h" // for GHistIndexMatrix +#include "gradient_index_format.h" // for GHistIndexRawFormat +#include "sparse_page_source.h" // for PageSourceIncMixIn +#include "xgboost/base.h" // for bst_feature_t +#include "xgboost/data.h" // for BatchParam, FeatureType +#include "xgboost/span.h" // for Span namespace xgboost { namespace data { class GradientIndexPageSource : public PageSourceIncMixIn { common::HistogramCuts cuts_; bool is_dense_; - int32_t max_bin_per_feat_; + std::int32_t max_bin_per_feat_; common::Span feature_types_; double sparse_thresh_; + protected: + [[nodiscard]] SparsePageFormat* CreatePageFormat() const override { + return new GHistIndexRawFormat{cuts_}; + } + public: - GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, - std::shared_ptr cache, BatchParam param, + GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features, + size_t n_batches, std::shared_ptr cache, BatchParam param, common::HistogramCuts cuts, bool is_dense, common::Span feature_types, std::shared_ptr source) diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index 69a7b1aa2..868875bf7 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -54,7 +54,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, /** * Generate quantiles */ - common::HistogramCuts cuts; + auto cuts = std::make_shared(); do { // We use do while here as the first batch is fetched in ctor CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs()); @@ -104,9 +104,9 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, sketch_containers.clear(); sketch_containers.shrink_to_fit(); - final_sketch.MakeCuts(ctx, &cuts, this->info_.IsColumnSplit()); + final_sketch.MakeCuts(ctx, cuts.get(), this->info_.IsColumnSplit()); } else { - GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts); + GetCutsFromRef(ctx, ref, Info().num_col_, p, cuts.get()); } this->info_.num_row_ = accumulated_rows; diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 1e76f8601..14a99370a 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -1,7 +1,7 @@ /** * Copyright 2021-2024, XGBoost contributors */ -#include // for unique_ptr +#include // for shared_ptr #include "../common/hist_util.cuh" #include "../common/hist_util.h" // for HistogramCuts @@ -26,13 +26,13 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, // reinitialize the cache cache_info_.erase(id); MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); - std::unique_ptr cuts; + std::shared_ptr cuts; if (!param.hess.empty()) { - cuts = std::make_unique( + cuts = std::make_shared( common::DeviceSketchWithHessian(ctx, this, param.max_bin, param.hess)); } else { cuts = - std::make_unique(common::DeviceSketch(ctx, this, param.max_bin)); + std::make_shared(common::DeviceSketch(ctx, this, param.max_bin)); } this->InitializeSparsePage(ctx); // reset after use. diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 6e8ebd33c..427325a74 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -7,6 +7,7 @@ #include // for min #include // for atomic +#include // for uint64_t #include // for remove #include // for future #include // for unique_ptr @@ -72,9 +73,13 @@ struct Cache { */ [[nodiscard]] auto View(std::size_t i) const { std::uint64_t off = offset.at(i); - std::uint64_t len = offset.at(i + 1) - offset[i]; + std::uint64_t len = this->Bytes(i); return std::pair{off, len}; } + /** + * @brief Get the number of bytes for the i^th page. + */ + [[nodiscard]] std::uint64_t Bytes(std::size_t i) const { return offset.at(i + 1) - offset[i]; } /** * @brief Call this once the write for the cache is complete. */ @@ -174,6 +179,10 @@ class SparsePageSourceImpl : public BatchIteratorImpl { ExceHandler exce_; common::Monitor monitor_; + [[nodiscard]] virtual SparsePageFormat* CreatePageFormat() const { + return ::xgboost::data::CreatePageFormat("raw"); + } + [[nodiscard]] bool ReadCache() { CHECK(!at_end_); if (!cache_info_->written) { @@ -207,7 +216,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { *GlobalConfigThreadLocalStore::Get() = config; auto page = std::make_shared(); this->exce_.Run([&] { - std::unique_ptr> fmt{CreatePageFormat("raw")}; + std::unique_ptr> fmt{this->CreatePageFormat()}; auto name = self->cache_info_->ShardName(); auto [offset, length] = self->cache_info_->View(fetch_it); auto fi = std::make_unique(name, offset, length); @@ -234,7 +243,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { CHECK(!cache_info_->written); common::Timer timer; timer.Start(); - std::unique_ptr> fmt{CreatePageFormat("raw")}; + std::unique_ptr> fmt{this->CreatePageFormat()}; auto name = cache_info_->ShardName(); std::unique_ptr fo; @@ -404,8 +413,8 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { bool sync_{true}; public: - PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, std::uint32_t n_batches, - std::shared_ptr cache, bool sync) + PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features, + std::uint32_t n_batches, std::shared_ptr cache, bool sync) : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {} [[nodiscard]] PageSourceIncMixIn& operator++() final { diff --git a/src/data/sparse_page_writer.h b/src/data/sparse_page_writer.h index c909d817d..989c03d33 100644 --- a/src/data/sparse_page_writer.h +++ b/src/data/sparse_page_writer.h @@ -10,7 +10,6 @@ #include // for string #include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream -#include "dmlc/io.h" // for Stream #include "dmlc/registry.h" // for Registry, FunctionRegEntryBase #include "xgboost/data.h" // for SparsePage,CSCPage,SortedCSCPage,EllpackPage ... diff --git a/src/global_config.cc b/src/global_config.cc index d342e3c3e..ec8353c56 100644 --- a/src/global_config.cc +++ b/src/global_config.cc @@ -7,7 +7,6 @@ #include #include "xgboost/global_config.h" -#include "xgboost/json.h" namespace xgboost { DMLC_REGISTER_PARAMETER(GlobalConfiguration); diff --git a/src/tree/gpu_hist/feature_groups.cuh b/src/tree/gpu_hist/feature_groups.cuh index 671272822..82df69796 100644 --- a/src/tree/gpu_hist/feature_groups.cuh +++ b/src/tree/gpu_hist/feature_groups.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #ifndef FEATURE_GROUPS_CUH_ #define FEATURE_GROUPS_CUH_ @@ -29,22 +29,23 @@ struct FeatureGroup { /** The number of features in the group. */ int num_features; /** The first bin in the group. */ - int start_bin; + bst_bin_t start_bin; /** The number of bins in the group. */ - int num_bins; + bst_bin_t num_bins; }; /** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */ struct FeatureGroupsAccessor { FeatureGroupsAccessor(common::Span feature_segments_, - common::Span bin_segments_, int max_group_bins_) : - feature_segments(feature_segments_), bin_segments(bin_segments_), - max_group_bins(max_group_bins_) {} - + common::Span bin_segments_, int max_group_bins_) + : feature_segments(feature_segments_), + bin_segments(bin_segments_), + max_group_bins(max_group_bins_) {} + common::Span feature_segments; common::Span bin_segments; int max_group_bins; - + /** \brief Gets the number of feature groups. */ __host__ __device__ int NumGroups() const { return feature_segments.size() - 1; @@ -84,7 +85,7 @@ struct FeatureGroups { /** Maximum number of bins in a group. Useful to compute the amount of dynamic shared memory when launching a kernel. */ int max_group_bins; - + /** Creates feature groups by splitting features into groups. \param cuts Histogram cuts that given the number of bins per feature. \param is_dense Whether the data matrix is dense. @@ -110,7 +111,7 @@ struct FeatureGroups { private: void InitSingle(const common::HistogramCuts& cuts); -}; +}; } // namespace tree } // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index f9a3819ad..7aefebeb6 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -167,7 +167,7 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx, for (auto& batch : dmat->GetBatches(ctx, batch_param_)) { auto page = batch.Impl(); if (!page_) { - page_ = std::make_unique(ctx->Device(), page->Cuts(), page->is_dense, + page_ = std::make_unique(ctx->Device(), page->CutsShared(), page->is_dense, page->row_stride, dmat->Info().num_row_); } size_t num_elements = page_->Copy(ctx->Device(), page, offset); @@ -228,7 +228,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, auto first_page = (*batch_iterator.begin()).Impl(); // Create a new ELLPACK page with empty rows. page_.reset(); // Release the device memory first before reallocating - page_.reset(new EllpackPageImpl(ctx->Device(), first_page->Cuts(), first_page->is_dense, + page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), first_page->is_dense, first_page->row_stride, sample_rows)); // Compact the ELLPACK pages into the single sample page. @@ -306,7 +306,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c auto first_page = (*batch_iterator.begin()).Impl(); // Create a new ELLPACK page with empty rows. page_.reset(); // Release the device memory first before reallocating - page_.reset(new EllpackPageImpl(ctx->Device(), first_page->Cuts(), first_page->is_dense, + page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), first_page->is_dense, first_page->row_stride, sample_rows)); // Compact the ELLPACK pages into the single sample page. diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index e37f02ddb..df5ed9004 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -56,8 +56,7 @@ TEST(HistUtil, DeviceSketch) { TEST(HistUtil, SketchBatchNumElements) { #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 - LOG(WARNING) << "Test not runnable with RMM enabled."; - return; + GTEST_SKIP_("Test not runnable with RMM enabled."); #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 size_t constexpr kCols = 10000; int device; diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index ab4539fd4..924a458d7 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -152,7 +152,7 @@ TEST(EllpackPage, Copy) { auto page = (*dmat->GetBatches(&ctx, param).begin()).Impl(); // Create an empty result page. - EllpackPageImpl result(FstCU(), page->Cuts(), page->is_dense, page->row_stride, kRows); + EllpackPageImpl result(FstCU(), page->CutsShared(), page->is_dense, page->row_stride, kRows); // Copy batch pages into the result page. size_t offset = 0; @@ -200,7 +200,8 @@ TEST(EllpackPage, Compact) { auto page = (*dmat->GetBatches(&ctx, param).begin()).Impl(); // Create an empty result page. - EllpackPageImpl result(FstCU(), page->Cuts(), page->is_dense, page->row_stride, kCompactedRows); + EllpackPageImpl result(FstCU(), page->CutsShared(), page->is_dense, page->row_stride, + kCompactedRows); // Compact batch pages into the result page. std::vector row_indexes_h { diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu index f69b7b63a..c50c3bee2 100644 --- a/tests/cpp/data/test_ellpack_page_raw_format.cu +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -1,14 +1,14 @@ /** - * Copyright 2021-2023, XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ #include #include #include "../../../src/common/io.h" // for PrivateMmapConstStream, AlignedResourceReadStream... #include "../../../src/data/ellpack_page.cuh" -#include "../../../src/data/sparse_page_source.h" -#include "../../../src/tree/param.h" // TrainParam -#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../../../src/data/ellpack_page_raw_format.h" // for EllpackPageRawFormat +#include "../../../src/tree/param.h" // TrainParam +#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" namespace xgboost::data { @@ -16,12 +16,18 @@ TEST(EllpackPageRawFormat, IO) { Context ctx{MakeCUDACtx(0)}; auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; - std::unique_ptr> format{CreatePageFormat("raw")}; - auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); dmlc::TemporaryDirectory tmpdir; std::string path = tmpdir.path + "/ellpack.page"; + std::shared_ptr cuts; + for (auto const& page : m->GetBatches(&ctx, param)) { + cuts = page.Impl()->CutsShared(); + } + + cuts->SetDevice(ctx.Device()); + auto format = std::make_unique(cuts); + std::size_t n_bytes{0}; { auto fo = std::make_unique(StringView{path}, "wb"); @@ -33,7 +39,7 @@ TEST(EllpackPageRawFormat, IO) { EllpackPage page; std::unique_ptr fi{ std::make_unique(path.c_str(), 0, n_bytes)}; - format->Read(&page, fi.get()); + ASSERT_TRUE(format->Read(&page, fi.get())); for (auto const &ellpack : m->GetBatches(&ctx, param)) { auto loaded = page.Impl(); diff --git a/tests/cpp/data/test_gradient_index_page_raw_format.cc b/tests/cpp/data/test_gradient_index_page_raw_format.cc index a327b319c..2c2a4b1b1 100644 --- a/tests/cpp/data/test_gradient_index_page_raw_format.cc +++ b/tests/cpp/data/test_gradient_index_page_raw_format.cc @@ -7,23 +7,28 @@ #include // for size_t #include // for unique_ptr -#include "../../../src/common/column_matrix.h" -#include "../../../src/common/io.h" // for MmapResource, AlignedResourceReadStream... -#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix -#include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat -#include "../helpers.h" // for RandomDataGenerator +#include "../../../src/common/column_matrix.h" // for common::ColumnMatrix +#include "../../../src/common/io.h" // for MmapResource, AlignedResourceReadStream... +#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix +#include "../../../src/data/gradient_index_format.h" // for GHistIndexRawFormat +#include "../helpers.h" // for RandomDataGenerator namespace xgboost::data { TEST(GHistIndexPageRawFormat, IO) { Context ctx; - std::unique_ptr> format{ - CreatePageFormat("raw")}; auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix(); dmlc::TemporaryDirectory tmpdir; std::string path = tmpdir.path + "/ghistindex.page"; auto batch = BatchParam{256, 0.5}; + common::HistogramCuts cuts; + for (auto const &index : m->GetBatches(&ctx, batch)) { + cuts = index.Cuts(); + break; + } + auto format = std::make_unique(std::move(cuts)); + std::size_t bytes{0}; { auto fo = std::make_unique(StringView{path}, "wb"); @@ -36,7 +41,7 @@ TEST(GHistIndexPageRawFormat, IO) { std::unique_ptr fi{ std::make_unique(path, 0, bytes)}; - format->Read(&page, fi.get()); + ASSERT_TRUE(format->Read(&page, fi.get())); for (auto const &gidx : m->GetBatches(&ctx, batch)) { auto const &loaded = gidx; diff --git a/tests/cpp/data/test_iterative_dmatrix.cu b/tests/cpp/data/test_iterative_dmatrix.cu index f7985df45..503cb7696 100644 --- a/tests/cpp/data/test_iterative_dmatrix.cu +++ b/tests/cpp/data/test_iterative_dmatrix.cu @@ -20,9 +20,8 @@ void TestEquivalent(float sparsity) { std::numeric_limits::quiet_NaN(), 0, 256); std::size_t offset = 0; auto first = (*m.GetEllpackBatches(&ctx, {}).begin()).Impl(); - std::unique_ptr page_concatenated { - new EllpackPageImpl(ctx.Device(), first->Cuts(), first->is_dense, - first->row_stride, 1000 * 100)}; + std::unique_ptr page_concatenated{new EllpackPageImpl( + ctx.Device(), first->CutsShared(), first->is_dense, first->row_stride, 1000 * 100)}; for (auto& batch : m.GetBatches(&ctx, {})) { auto page = batch.Impl(); size_t num_elements = page_concatenated->Copy(ctx.Device(), page, offset); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 25acb038c..33308be19 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -171,6 +171,12 @@ TEST(SparsePageDMatrix, GHistIndexSkipSparsePage) { // Restore the batch parameter by passing it in again through check_ghist check_ghist(); } + // half the pages + auto it = Xy->GetBatches(&ctx).begin(); + for (std::int32_t i = 0; i < 3; ++i) { + ++it; + } + check_ghist(); } TEST(SparsePageDMatrix, MetaInfo) { diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index 2aff98375..5783caa37 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -164,9 +164,9 @@ TEST(SparsePageDMatrix, EllpackPageContent) { size_t offset = 0; for (auto& batch : dmat_ext->GetBatches(&ctx, param)) { if (!impl_ext) { - impl_ext = std::make_unique(batch.Impl()->gidx_buffer.Device(), - batch.Impl()->Cuts(), batch.Impl()->is_dense, - batch.Impl()->row_stride, kRows); + impl_ext = std::make_unique( + batch.Impl()->gidx_buffer.Device(), batch.Impl()->CutsShared(), batch.Impl()->is_dense, + batch.Impl()->row_stride, kRows); } auto n_elems = impl_ext->Copy(ctx.Device(), batch.Impl(), offset); offset += n_elems; diff --git a/tests/cpp/histogram_helpers.h b/tests/cpp/histogram_helpers.h index 8f345484d..a33d6958a 100644 --- a/tests/cpp/histogram_helpers.h +++ b/tests/cpp/histogram_helpers.h @@ -13,31 +13,25 @@ namespace xgboost { #if defined(__CUDACC__) -namespace { +namespace detail { class HistogramCutsWrapper : public common::HistogramCuts { public: using SuperT = common::HistogramCuts; - void SetValues(std::vector cuts) { - SuperT::cut_values_.HostVector() = std::move(cuts); - } - void SetPtrs(std::vector ptrs) { - SuperT::cut_ptrs_.HostVector() = std::move(ptrs); - } - void SetMins(std::vector mins) { - SuperT::min_vals_.HostVector() = std::move(mins); - } + void SetValues(std::vector cuts) { SuperT::cut_values_.HostVector() = std::move(cuts); } + void SetPtrs(std::vector ptrs) { SuperT::cut_ptrs_.HostVector() = std::move(ptrs); } + void SetMins(std::vector mins) { SuperT::min_vals_.HostVector() = std::move(mins); } }; -} // anonymous namespace +} // namespace detail inline std::unique_ptr BuildEllpackPage(int n_rows, int n_cols, bst_float sparsity = 0) { auto dmat = RandomDataGenerator(n_rows, n_cols, sparsity).Seed(3).GenerateDMatrix(); const SparsePage& batch = *dmat->GetBatches().begin(); - HistogramCutsWrapper cmat; - cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); + auto cmat = std::make_shared(); + cmat->SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); // 24 cut fields, 3 cut fields for each feature (column). - cmat.SetValues({0.30f, 0.67f, 1.64f, + cmat->SetValues({0.30f, 0.67f, 1.64f, 0.32f, 0.77f, 1.95f, 0.29f, 0.70f, 1.80f, 0.32f, 0.75f, 1.85f, @@ -45,7 +39,7 @@ inline std::unique_ptr BuildEllpackPage(int n_rows, int n_cols, 0.25f, 0.74f, 2.00f, 0.26f, 0.74f, 1.98f, 0.26f, 0.71f, 1.83f}); - cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); + cmat->SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); bst_idx_t row_stride = 0; const auto &offset_vec = batch.offset.ConstHostVector(); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index c3a949008..cc4d9fb7f 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -150,13 +150,13 @@ TEST(GpuHist, BuildHistSharedMem) { TestBuildHist(true); } -HistogramCutsWrapper GetHostCutMatrix () { - HistogramCutsWrapper cmat; - cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); - cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); +std::shared_ptr GetHostCutMatrix () { + auto cmat = std::make_shared(); + cmat->SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); + cmat->SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); // 24 cut fields, 3 cut fields for each feature (column). // Each row of the cut represents the cuts for a data column. - cmat.SetValues({0.30f, 0.67f, 1.64f, + cmat->SetValues({0.30f, 0.67f, 1.64f, 0.32f, 0.77f, 1.95f, 0.29f, 0.70f, 1.80f, 0.32f, 0.75f, 1.85f,