diff --git a/doc/c.rst b/doc/c.rst
index d63e779e1..9a9d7b557 100644
--- a/doc/c.rst
+++ b/doc/c.rst
@@ -33,6 +33,8 @@ DMatrix
 .. doxygengroup:: DMatrix
    :project: xgboost
+.. _c_streaming:
+
 Streaming
 ---------

diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
index 3562015e2..fa487f1c8 100644
--- a/doc/tutorials/dask.rst
+++ b/doc/tutorials/dask.rst
@@ -54,6 +54,9 @@ on a dask cluster:
     y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
     dtrain = xgb.dask.DaskDMatrix(client, X, y)
+    # or
+    # dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
+    # `DaskQuantileDMatrix` is available for the `hist` and `gpu_hist` tree methods.
     output = xgb.dask.train(
         client,
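For reference, here is a minimal end-to-end sketch of the ``DaskQuantileDMatrix`` path suggested in the comment above. The local cluster setup, data shapes, and parameter values are illustrative assumptions, not part of the patch:

.. code-block:: python

    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        # Random data purely for illustration.
        X = da.random.random(size=(100_000, 10), chunks=(1000, 10))
        y = da.random.random(size=(100_000, 1), chunks=(1000, 1))
        # DaskQuantileDMatrix skips one staging copy of the data, but it
        # only supports the hist and gpu_hist tree methods.
        dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
        output = xgb.dask.train(
            client, {"tree_method": "hist"}, dtrain, num_boost_round=10
        )
        booster = output["booster"]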
diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst
index f5b6132c7..832d13edd 100644
--- a/doc/tutorials/external_memory.rst
+++ b/doc/tutorials/external_memory.rst
@@ -22,6 +22,15 @@
 GPU-based training algorithm. We will introduce them in the following sections.
 The feature is still experimental as of 2.0. The performance is not well optimized.
+The external memory support has gone through multiple iterations and is still under heavy
+development. Like the :py:class:`~xgboost.QuantileDMatrix` with
+:py:class:`~xgboost.DataIter`, XGBoost loads data batch-by-batch using a custom iterator
+supplied by the user. However, unlike the :py:class:`~xgboost.QuantileDMatrix`, external
+memory does not concatenate the batches unless GPU is used (it uses a hybrid approach,
+more details follow). Instead, it caches all batches in external memory and fetches them
+on demand. Go to the end of the document to see a comparison between `QuantileDMatrix`
+and external memory.
+
 *************
 Data Iterator
 *************
@@ -113,10 +122,11 @@
 External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set to
 ``gpu_hist``). However, the algorithm used for GPU is different from the one used for
 CPU. When training on a CPU, the tree method iterates through all batches from external
 memory for each step of the tree construction algorithm. On the other hand, the GPU
-algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
-memory usage, users can utilize subsampling. The good news is that the GPU hist tree
-method supports gradient-based sampling, enabling users to set a low sampling rate without
-compromising accuracy.
+algorithm uses a hybrid approach. It iterates through the data at the beginning of each
+iteration and concatenates all batches into one in GPU memory. To reduce overall memory
+usage, users can utilize subsampling. The GPU hist tree method supports `gradient-based
+sampling`, enabling users to set a low sampling rate without compromising accuracy.
 .. code-block:: python
@@ -134,6 +144,8 @@
 see `this paper `_. When the GPU runs out of memory during iteration on external
 memory, users might receive a segfault instead of an OOM exception.
+.. _ext_remarks:
+
 *******
 Remarks
 *******
@@ -142,17 +154,64 @@
 When using external memory with XGBoost, data is divided into smaller chunks so that only
 a fraction of it needs to be stored in memory at any given time. It's important to note
 that this method only applies to the predictor data (``X``), while other data, like labels
 and internal runtime structures are concatenated. This means that memory reduction is most
-effective when dealing with wide datasets where ``X`` is larger compared to other data
-like ``y``, while it has little impact on slim datasets.
+effective when dealing with wide datasets where ``X`` is significantly larger in size
+compared to other data like ``y``, while it has little impact on slim datasets.
+
+As one might expect, fetching data on demand puts significant pressure on the storage
+device. Today's computing devices can process far more data in a unit of time than a
+storage device can read; the gap is orders of magnitude. A GPU is capable of processing
+hundreds of gigabytes of floating-point data in a split second. On the other hand, a
+four-lane NVMe storage device connected to a PCIe-4 slot usually has about 6GB/s of data
+transfer rate. As a result, training is likely to be severely bounded by your storage
+device. Before adopting the external memory solution, some back-of-envelope calculations
+might help you see whether it's viable. For instance, suppose your NVMe drive can transfer
+4GB of data per second (a fairly practical number) and you have 100GB of data in the
+compressed XGBoost cache, which corresponds to a dense float32 numpy array of roughly
+200GB. A tree of depth 8 needs at least 16 iterations through the data when the parameters
+are set appropriately, so you need about 14 minutes to train a single tree, not accounting
+for other overheads and assuming the computation overlaps with the IO. If your dataset
+happens to be TB-scale, you might then need thousands of trees to obtain a model that
+generalizes. These calculations can help you estimate the expected training time.
+
+However, sometimes we can ameliorate this limitation. One should also consider that the OS
+(here, primarily the Linux kernel) can usually cache the data in host memory; it only
+evicts pages when new data comes in and there's no room left. In practice, at least some
+portion of the data can persist in host memory throughout the entire training session. We
+take this cache into account when optimizing the external memory fetcher. The compressed
+cache is usually smaller than the raw input data, especially when the input is dense
+without any missing values. If the host memory can fit a significant portion of this
+compressed cache, the performance should be decent after initialization. Our development
+so far focuses on two fronts of optimization for external memory:
+
+- Avoid iterating through the data whenever appropriate.
+- If the OS can cache the data, the performance should be close to in-core training.
 Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
-yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's
-worth noting that most tests have been conducted on Linux distributions.
+tested against system errors like disconnected network devices (`SIGBUS`). In the face of
+a bus error, you will see a hard crash and need to clean up the cache files. If the
+training session might take a long time and you are using solutions like NVMe-oF, we
+recommend checkpointing your model periodically. Also, it's worth noting that most tests
+have been conducted on Linux distributions.
+
 Another important point to keep in mind is that creating the initial cache for XGBoost may
-take some time. The interface to external memory is through custom iterators, which may or
-may not be thread-safe. Therefore, initialization is performed sequentially.
+take some time. The interface to external memory is through custom iterators, which we
+cannot assume to be thread-safe. Therefore, initialization is performed sequentially. Using
+the `xgboost.config_context` with `verbosity=2` can give you some information on what
+XGBoost is doing during the wait, if you don't mind the extra output.
+
+*******************************
+Compared to the QuantileDMatrix
+*******************************
+
+Passing an iterator to the :py:class:`~xgboost.QuantileDMatrix` enables direct
+construction of a `QuantileDMatrix` from data chunks. On the other hand, if it's passed to
+a :py:class:`~xgboost.DMatrix`, it instead enables the external memory feature. The
+:py:class:`~xgboost.QuantileDMatrix` concatenates the data in memory after compression and
+doesn't fetch data during training. In contrast, the external memory `DMatrix` fetches
+data batches from external memory on demand. Use the `QuantileDMatrix` (with an iterator
+if necessary) when you can fit most of your data in memory. The training would be an order
+of magnitude faster than using external memory.
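To make the contrast concrete, the same user-defined iterator can drive both constructions. Below is a minimal sketch; the ``BatchIter`` class, the random data, and the cache path are illustrative assumptions, not part of the patch:

.. code-block:: python

    import numpy as np
    import xgboost

    class BatchIter(xgboost.DataIter):
        """Yield in-memory batches one at a time, for illustration only."""

        def __init__(self, batches, cache_prefix=None):
            self._batches = batches
            self._it = 0
            # cache_prefix is only used by the external-memory DMatrix.
            super().__init__(cache_prefix=cache_prefix)

        def next(self, input_data):
            if self._it == len(self._batches):
                return 0  # signal the end of iteration
            X, y = self._batches[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return 1

        def reset(self):
            self._it = 0

    batches = [(np.random.rand(1024, 8), np.random.rand(1024)) for _ in range(4)]

    # In-memory: batches are compressed and concatenated, nothing is fetched later.
    Xy = xgboost.QuantileDMatrix(BatchIter(batches))
    # External memory: batches are cached on disk and fetched on demand.
    Xy_ext = xgboost.DMatrix(BatchIter(batches, cache_prefix="./cache"))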
 ****************
 Text File Inputs

diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index eb8c23726..7693173e9 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -11,22 +11,22 @@ See `Awesome XGBoost `_ for mo
 model
 saving_model
+ learning_to_rank
+ dart
+ monotonic
+ feature_interaction_constraint
+ aft_survival_analysis
+ categorical
+ multioutput
+ rf
 kubernetes
 Distributed XGBoost with XGBoost4J-Spark
 Distributed XGBoost with XGBoost4J-Spark-GPU
 dask
 spark_estimator
 ray
- dart
- monotonic
- rf
- feature_interaction_constraint
- learning_to_rank
- aft_survival_analysis
+ external_memory
 c_api_tutorial
 input_format
 param_tuning
- external_memory
 custom_metric_obj
- categorical
- multioutput

diff --git a/doc/tutorials/param_tuning.rst b/doc/tutorials/param_tuning.rst
index cce145444..cb58fcc20 100644
--- a/doc/tutorials/param_tuning.rst
+++ b/doc/tutorials/param_tuning.rst
@@ -58,3 +58,46 @@ This can affect the training of XGBoost model, and there are two ways to improve
 - In such a case, you cannot re-balance the dataset
 - Set parameter ``max_delta_step`` to a finite number (say 1) to help convergence
+
+
+*********************
+Reducing Memory Usage
+*********************
+
+If you are using an HPO library like :py:class:`sklearn.model_selection.GridSearchCV`,
+please control the number of threads it can use. It's best to let XGBoost run in parallel
+instead of asking `GridSearchCV` to run multiple experiments at the same time. For
+instance, creating a fold of data for cross validation can consume a significant amount of
+memory:
+
+.. code-block:: python
+
+    # This creates a copy of the dataset. X and X_train are both in memory at the same time.
+
+    # This happens for every thread at the same time if you run `GridSearchCV` with
+    # `n_jobs` larger than 1.
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+.. code-block:: python
+
+    df = pd.DataFrame()
+    # This creates a new copy of the dataframe, even if you specify the inplace parameter.
+    new_df = df.drop(...)
+
+.. code-block:: python
+
+    array = np.array(...)
+    # `astype` returns a copy by default; pass `copy=False` to avoid one when the
+    # dtype already matches.
+    array.astype(np.float32)
+
+.. code-block:: python
+
+    # np by default uses double precision, do you actually need it?
+    array = np.array(...)
+
+You can find some more specific memory-reduction practices scattered through the
+documents. For instance: :doc:`/tutorials/dask`, :doc:`/gpu/index`,
+:doc:`/contrib/scaling`. However, before going into these, being conscious about making
+data copies is a good starting point, as the short sketch below illustrates.
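To see why copies dominate, consider this small sketch; the array sizes and the random data are arbitrary assumptions for illustration:

.. code-block:: python

    import numpy as np
    from sklearn.model_selection import train_test_split

    rng = np.random.default_rng(0)
    # About 400MB as float32; the same array in float64 would be about 800MB.
    X = rng.random((1_000_000, 100), dtype=np.float32)
    y = rng.random(1_000_000, dtype=np.float32)

    # The split materializes copies, so X coexists with X_train and X_test,
    # roughly doubling peak memory usage.
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    # Drop the original references once they are no longer needed so the
    # memory can be reclaimed.
    del X, y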
It usually consumes a lot more memory than people
+expect.
diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h
index a12e1decd..d93f32ff9 100644
--- a/rabit/include/rabit/internal/io.h
+++ b/rabit/include/rabit/internal/io.h
@@ -19,8 +19,7 @@
 #include "rabit/internal/utils.h"
 #include "rabit/serializable.h"
-namespace rabit {
-namespace utils {
+namespace rabit::utils {
 /*! \brief re-use definition of dmlc::SeekStream */
 using SeekStream = dmlc::SeekStream;
 /**
@@ -31,9 +30,6 @@ struct MemoryFixSizeBuffer : public SeekStream {
 // similar to SEEK_END in libc
 static std::size_t constexpr kSeekEnd = std::numeric_limits::max();
- protected:
- MemoryFixSizeBuffer() = default;
-
 public:
 /**
 * @brief Ctor
@@ -68,7 +64,7 @@ struct MemoryFixSizeBuffer : public SeekStream {
 * @brief Current position in the buffer (stream).
 */
 std::size_t Tell() override { return curr_ptr_; }
- virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
+ [[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
 protected:
 /*! \brief in memory buffer */
@@ -119,6 +115,5 @@ struct MemoryBufferStream : public SeekStream {
 /*! \brief current pointer */
 size_t curr_ptr_;
}; // class MemoryBufferStream
-} // namespace utils
-} // namespace rabit
+} // namespace rabit::utils
#endif // RABIT_INTERNAL_IO_H_
diff --git a/src/common/column_matrix.cc b/src/common/column_matrix.cc
index d8acfa7a5..1d44f1840 100644
--- a/src/common/column_matrix.cc
+++ b/src/common/column_matrix.cc
@@ -1,16 +1,27 @@
-/*!
- * Copyright 2017-2022 by XGBoost Contributors
+/**
+ * Copyright 2017-2023, XGBoost Contributors
 * \brief Utility for fast column-wise access
 */
 #include "column_matrix.h"
-namespace xgboost {
-namespace common {
+#include // for transform
+#include // for size_t
+#include // for uint64_t, uint8_t
+#include // for numeric_limits
+#include // for remove_reference_t
+#include // for vector
+
+#include "../data/gradient_index.h" // for GHistIndexMatrix
+#include "io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
+#include "xgboost/base.h" // for bst_feature_t
+#include "xgboost/span.h" // for Span
+
+namespace xgboost::common {
 void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold) {
 auto const nfeature = gmat.Features();
 const size_t nrow = gmat.Size();
 // identify type of each column
- type_.resize(nfeature);
+ type_ = common::MakeFixedVecWithMalloc(nfeature, ColumnType{});
 uint32_t max_val = std::numeric_limits::max();
 for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
@@ -34,7 +45,7 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
 // want to compute storage boundary for each feature
 // using variants of prefix sum scan
- feature_offsets_.resize(nfeature + 1);
+ feature_offsets_ = common::MakeFixedVecWithMalloc(nfeature + 1, std::size_t{0});
 size_t accum_index = 0;
 feature_offsets_[0] = accum_index;
 for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
@@ -49,9 +60,11 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
 SetTypeSize(gmat.MaxNumBinPerFeat());
 auto storage_size =
 feature_offsets_.back() * static_cast>(bins_type_size_);
- index_.resize(storage_size, 0);
+
+ index_ = common::MakeFixedVecWithMalloc(storage_size, std::uint8_t{0});
+
 if (!all_dense_column) {
- row_ind_.resize(feature_offsets_[nfeature]);
+ row_ind_ = common::MakeFixedVecWithMalloc(feature_offsets_[nfeature], std::size_t{0});
 }
 // store least bin id for each feature
@@ -59,7 +72,51
@@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres any_missing_ = !gmat.IsDense(); - missing_flags_.clear(); + missing_ = MissingIndicator{0, false}; } -} // namespace common -} // namespace xgboost + +// IO procedures for external memory. +bool ColumnMatrix::Read(AlignedResourceReadStream* fi, uint32_t const* index_base) { + if (!common::ReadVec(fi, &index_)) { + return false; + } + if (!common::ReadVec(fi, &type_)) { + return false; + } + if (!common::ReadVec(fi, &row_ind_)) { + return false; + } + if (!common::ReadVec(fi, &feature_offsets_)) { + return false; + } + + if (!common::ReadVec(fi, &missing_.storage)) { + return false; + } + missing_.InitView(); + + index_base_ = index_base; + if (!fi->Read(&bins_type_size_)) { + return false; + } + if (!fi->Read(&any_missing_)) { + return false; + } + return true; +} + +std::size_t ColumnMatrix::Write(AlignedFileWriteStream* fo) const { + std::size_t bytes{0}; + + bytes += common::WriteVec(fo, index_); + bytes += common::WriteVec(fo, type_); + bytes += common::WriteVec(fo, row_ind_); + bytes += common::WriteVec(fo, feature_offsets_); + bytes += common::WriteVec(fo, missing_.storage); + + bytes += fo->Write(bins_type_size_); + bytes += fo->Write(any_missing_); + + return bytes; +} +} // namespace xgboost::common diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index f121b7a46..78361744d 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2017-2022 by Contributors +/** + * Copyright 2017-2023, XGBoost Contributors * \file column_matrix.h * \brief Utility for fast column-wise access * \author Philip Cho @@ -8,25 +8,30 @@ #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_ #define XGBOOST_COMMON_COLUMN_MATRIX_H_ -#include - #include +#include // for size_t +#include // for uint8_t #include #include -#include // std::move +#include // for move #include #include "../data/adapter.h" #include "../data/gradient_index.h" #include "algorithm.h" +#include "bitfield.h" // for RBitField8 #include "hist_util.h" +#include "ref_resource_view.h" // for RefResourceView +#include "xgboost/base.h" // for bst_bin_t +#include "xgboost/span.h" // for Span -namespace xgboost { -namespace common { - +namespace xgboost::common { class ColumnMatrix; +class AlignedFileWriteStream; +class AlignedResourceReadStream; + /*! \brief column type */ -enum ColumnType : uint8_t { kDenseColumn, kSparseColumn }; +enum ColumnType : std::uint8_t { kDenseColumn, kSparseColumn }; /*! \brief a column storage, to be used with ApplySplit. Note that each bin id is stored as index[i] + index_base. 
@@ -41,12 +46,12 @@ class Column { : index_(index), index_base_(least_bin_idx) {} virtual ~Column() = default; - bst_bin_t GetGlobalBinIdx(size_t idx) const { + [[nodiscard]] bst_bin_t GetGlobalBinIdx(size_t idx) const { return index_base_ + static_cast(index_[idx]); } /* returns number of elements in column */ - size_t Size() const { return index_.size(); } + [[nodiscard]] size_t Size() const { return index_.size(); } private: /* bin indexes in range [0, max_bins - 1] */ @@ -63,7 +68,7 @@ class SparseColumnIter : public Column { common::Span row_ind_; size_t idx_; - size_t const* RowIndices() const { return row_ind_.data(); } + [[nodiscard]] size_t const* RowIndices() const { return row_ind_.data(); } public: SparseColumnIter(common::Span index, bst_bin_t least_bin_idx, @@ -81,7 +86,7 @@ class SparseColumnIter : public Column { SparseColumnIter(SparseColumnIter const&) = delete; SparseColumnIter(SparseColumnIter&&) = default; - size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; } + [[nodiscard]] size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; } bst_bin_t operator[](size_t rid) { const size_t column_size = this->Size(); if (!((idx_) < column_size)) { @@ -101,6 +106,10 @@ class SparseColumnIter : public Column { } }; +/** + * @brief Column stored as a dense vector. It might still contain missing values as + * indicated by the missing flags. + */ template class DenseColumnIter : public Column { public: @@ -109,17 +118,19 @@ class DenseColumnIter : public Column { private: using Base = Column; /* flags for missing values in dense columns */ - std::vector const& missing_flags_; + LBitField32 missing_flags_; size_t feature_offset_; public: explicit DenseColumnIter(common::Span index, bst_bin_t index_base, - std::vector const& missing_flags, size_t feature_offset) + LBitField32 missing_flags, size_t feature_offset) : Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {} DenseColumnIter(DenseColumnIter const&) = delete; DenseColumnIter(DenseColumnIter&&) = default; - bool IsMissing(size_t ridx) const { return missing_flags_[feature_offset_ + ridx]; } + [[nodiscard]] bool IsMissing(size_t ridx) const { + return missing_flags_.Check(feature_offset_ + ridx); + } bst_bin_t operator[](size_t ridx) const { if (any_missing) { @@ -131,12 +142,54 @@ class DenseColumnIter : public Column { }; /** - * \brief Column major matrix for gradient index. This matrix contains both dense column - * and sparse column, the type of the column is controlled by sparse threshold. When the - * number of missing values in a column is below the threshold it's classified as dense - * column. + * @brief Column major matrix for gradient index on CPU. + * + * This matrix contains both dense columns and sparse columns, the type of the column + * is controlled by the sparse threshold parameter. When the number of missing values + * in a column is below the threshold it's classified as dense column. */ class ColumnMatrix { + /** + * @brief A bit set for indicating whether an element in a dense column is missing. + */ + struct MissingIndicator { + LBitField32 missing; + RefResourceView storage; + + MissingIndicator() = default; + /** + * @param n_elements Size of the bit set + * @param init Initialize the indicator to true or false. + */ + MissingIndicator(std::size_t n_elements, bool init) { + auto m_size = missing.ComputeStorageSize(n_elements); + storage = common::MakeFixedVecWithMalloc(m_size, init ? 
~std::uint32_t{0} : std::uint32_t{0}); + this->InitView(); + } + /** @brief Set the i^th element to be a valid element (instead of missing). */ + void SetValid(typename LBitField32::index_type i) { missing.Clear(i); } + /** @brief assign the storage to the view. */ + void InitView() { + missing = LBitField32{Span{storage.data(), storage.size()}}; + } + + void GrowTo(std::size_t n_elements, bool init) { + CHECK(storage.Resource()->Type() == ResourceHandler::kMalloc) + << "[Internal Error]: Cannot grow the vector when external memory is used."; + auto m_size = missing.ComputeStorageSize(n_elements); + CHECK_GE(m_size, storage.size()); + if (m_size == storage.size()) { + return; + } + + auto new_storage = + common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0}); + std::copy_n(storage.cbegin(), storage.size(), new_storage.begin()); + storage = std::move(new_storage); + this->InitView(); + } + }; + void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold); template @@ -144,9 +197,10 @@ class ColumnMatrix { if (type_[fid] == kDenseColumn) { ColumnBinT* begin = &local_index[feature_offsets_[fid]]; begin[rid] = bin_id - index_base_[fid]; - // not thread-safe with bool vector. FIXME(jiamingy): We can directly assign - // kMissingId to the index to avoid missing flags. - missing_flags_[feature_offsets_[fid] + rid] = false; + // not thread-safe with bit field. + // FIXME(jiamingy): We can directly assign kMissingId to the index to avoid missing + // flags. + missing_.SetValid(feature_offsets_[fid] + rid); } else { ColumnBinT* begin = &local_index[feature_offsets_[fid]]; begin[num_nonzeros_[fid]] = bin_id - index_base_[fid]; @@ -158,7 +212,9 @@ class ColumnMatrix { public: using ByteType = bool; // get number of features - bst_feature_t GetNumFeature() const { return static_cast(type_.size()); } + [[nodiscard]] bst_feature_t GetNumFeature() const { + return static_cast(type_.size()); + } ColumnMatrix() = default; ColumnMatrix(GHistIndexMatrix const& gmat, double sparse_threshold) { @@ -166,7 +222,7 @@ class ColumnMatrix { } /** - * \brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original + * @brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original * SparsePage. */ void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, @@ -178,8 +234,8 @@ class ColumnMatrix { } /** - * \brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual - * data. + * @brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual + * data. * * This function requires a binary search for each bin to get back the feature index * for those bins. @@ -199,7 +255,7 @@ class ColumnMatrix { } } - bool IsInitialized() const { return !type_.empty(); } + [[nodiscard]] bool IsInitialized() const { return !type_.empty(); } /** * \brief Push batch of data for Quantile DMatrix support. 
@@ -257,7 +313,7 @@ class ColumnMatrix { reinterpret_cast(&index_[feature_offset * bins_type_size_]), column_size}; return std::move(DenseColumnIter{ - bin_index, static_cast(index_base_[fidx]), missing_flags_, feature_offset}); + bin_index, static_cast(index_base_[fidx]), missing_.missing, feature_offset}); } // all columns are dense column and has no missing value @@ -265,7 +321,8 @@ class ColumnMatrix { template void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples, const size_t n_features, int32_t n_threads) { - missing_flags_.resize(feature_offsets_[n_features], false); + missing_.GrowTo(feature_offsets_[n_features], false); + DispatchBinType(bins_type_size_, [&](auto t) { using ColumnBinT = decltype(t); auto column_index = Span{reinterpret_cast(index_.data()), @@ -290,9 +347,15 @@ class ColumnMatrix { void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat, float missing) { auto n_features = gmat.Features(); - missing_flags_.resize(feature_offsets_[n_features], true); - auto const* row_index = gmat.index.data() + gmat.row_ptr[base_rowid]; - num_nonzeros_.resize(n_features, 0); + + missing_.GrowTo(feature_offsets_[n_features], true); + auto const* row_index = gmat.index.data() + gmat.row_ptr[base_rowid]; + if (num_nonzeros_.empty()) { + num_nonzeros_ = common::MakeFixedVecWithMalloc(n_features, std::size_t{0}); + } else { + CHECK_EQ(num_nonzeros_.size(), n_features); + } + auto is_valid = data::IsValidFunctor{missing}; DispatchBinType(bins_type_size_, [&](auto t) { @@ -321,8 +384,9 @@ class ColumnMatrix { */ void SetIndexMixedColumns(const GHistIndexMatrix& gmat) { auto n_features = gmat.Features(); - missing_flags_.resize(feature_offsets_[n_features], true); - num_nonzeros_.resize(n_features, 0); + + missing_ = MissingIndicator{feature_offsets_[n_features], true}; + num_nonzeros_ = common::MakeFixedVecWithMalloc(n_features, std::size_t{0}); DispatchBinType(bins_type_size_, [&](auto t) { using ColumnBinT = decltype(t); @@ -335,106 +399,34 @@ class ColumnMatrix { }); } - BinTypeSize GetTypeSize() const { return bins_type_size_; } - auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; } + [[nodiscard]] BinTypeSize GetTypeSize() const { return bins_type_size_; } + [[nodiscard]] auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; } // And this returns part of state - bool AnyMissing() const { return any_missing_; } + [[nodiscard]] bool AnyMissing() const { return any_missing_; } // IO procedures for external memory. 
- bool Read(dmlc::SeekStream* fi, uint32_t const* index_base) { - fi->Read(&index_); -#if !DMLC_LITTLE_ENDIAN - // s390x - std::vector::type> int_types; - fi->Read(&int_types); - type_.resize(int_types.size()); - std::transform( - int_types.begin(), int_types.end(), type_.begin(), - [](std::underlying_type::type i) { return static_cast(i); }); -#else - fi->Read(&type_); -#endif // !DMLC_LITTLE_ENDIAN - - fi->Read(&row_ind_); - fi->Read(&feature_offsets_); - - std::vector missing; - fi->Read(&missing); - missing_flags_.resize(missing.size()); - std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(), - [](std::uint8_t flag) { return !!flag; }); - - index_base_ = index_base; -#if !DMLC_LITTLE_ENDIAN - std::underlying_type::type v; - fi->Read(&v); - bins_type_size_ = static_cast(v); -#else - fi->Read(&bins_type_size_); -#endif - - fi->Read(&any_missing_); - return true; - } - - size_t Write(dmlc::Stream* fo) const { - size_t bytes{0}; - - auto write_vec = [&](auto const& vec) { - fo->Write(vec); - bytes += vec.size() * sizeof(typename std::remove_reference_t::value_type) + - sizeof(uint64_t); - }; - write_vec(index_); -#if !DMLC_LITTLE_ENDIAN - // s390x - std::vector::type> int_types(type_.size()); - std::transform(type_.begin(), type_.end(), int_types.begin(), [](ColumnType t) { - return static_cast::type>(t); - }); - write_vec(int_types); -#else - write_vec(type_); -#endif // !DMLC_LITTLE_ENDIAN - write_vec(row_ind_); - write_vec(feature_offsets_); - // dmlc can not handle bool vector - std::vector missing(missing_flags_.size()); - std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(), - [](bool flag) { return static_cast(flag); }); - write_vec(missing); - -#if !DMLC_LITTLE_ENDIAN - auto v = static_cast::type>(bins_type_size_); - fo->Write(v); -#else - fo->Write(bins_type_size_); -#endif // DMLC_LITTLE_ENDIAN - bytes += sizeof(bins_type_size_); - fo->Write(any_missing_); - bytes += sizeof(any_missing_); - - return bytes; - } + [[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base); + [[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const; private: - std::vector index_; + RefResourceView index_; - std::vector type_; - /* indptr of a CSC matrix. */ - std::vector row_ind_; - /* indicate where each column's index and row_ind is stored. */ - std::vector feature_offsets_; - /* The number of nnz of each column. */ - std::vector num_nonzeros_; + RefResourceView type_; + /** @brief indptr of a CSC matrix. */ + RefResourceView row_ind_; + /** @brief indicate where each column's index and row_ind is stored. */ + RefResourceView feature_offsets_; + /** @brief The number of nnz of each column. */ + RefResourceView num_nonzeros_; // index_base_[fid]: least bin id for feature fid - uint32_t const* index_base_; - std::vector missing_flags_; + std::uint32_t const* index_base_; + + MissingIndicator missing_; + BinTypeSize bins_type_size_; bool any_missing_; }; -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_COLUMN_MATRIX_H_ diff --git a/src/common/hist_util.h b/src/common/hist_util.h index d2edf2ec8..2781da8e0 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -203,13 +203,33 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) { } /** - * \brief Optionally compressed gradient index. The compression works only with dense + * @brief Optionally compressed gradient index. The compression works only with dense * data. 
* * The main body of construction code is in gradient_index.cc, this struct is only a - * storage class. + * view class. */ -struct Index { +class Index { + private: + void SetBinTypeSize(BinTypeSize binTypeSize) { + binTypeSize_ = binTypeSize; + switch (binTypeSize) { + case kUint8BinsTypeSize: + func_ = &GetValueFromUint8; + break; + case kUint16BinsTypeSize: + func_ = &GetValueFromUint16; + break; + case kUint32BinsTypeSize: + func_ = &GetValueFromUint32; + break; + default: + CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize || + binTypeSize == kUint32BinsTypeSize); + } + } + + public: // Inside the compressor, bin_idx is the index for cut value across all features. By // subtracting it with starting pointer of each feature, we can reduce it to smaller // value and store it with smaller types. Usable only with dense data. @@ -233,10 +253,24 @@ struct Index { } Index() { SetBinTypeSize(binTypeSize_); } - Index(const Index& i) = delete; - Index& operator=(Index i) = delete; + + Index(Index const& i) = delete; + Index& operator=(Index const& i) = delete; Index(Index&& i) = delete; - Index& operator=(Index&& i) = delete; + + /** @brief Move assignment for lazy initialization. */ + Index& operator=(Index&& i) = default; + + /** + * @brief Construct the index from data. + * + * @param data Storage for compressed histogram bin. + * @param bin_size Number of bytes for each bin. + */ + Index(Span data, BinTypeSize bin_size) : data_{data} { + this->SetBinTypeSize(bin_size); + } + uint32_t operator[](size_t i) const { if (!bin_offset_.empty()) { // dense, compressed @@ -247,26 +281,7 @@ struct Index { return func_(data_.data(), i); } } - void SetBinTypeSize(BinTypeSize binTypeSize) { - binTypeSize_ = binTypeSize; - switch (binTypeSize) { - case kUint8BinsTypeSize: - func_ = &GetValueFromUint8; - break; - case kUint16BinsTypeSize: - func_ = &GetValueFromUint16; - break; - case kUint32BinsTypeSize: - func_ = &GetValueFromUint32; - break; - default: - CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize || - binTypeSize == kUint32BinsTypeSize); - } - } - BinTypeSize GetBinTypeSize() const { - return binTypeSize_; - } + [[nodiscard]] BinTypeSize GetBinTypeSize() const { return binTypeSize_; } template T const* data() const { // NOLINT return reinterpret_cast(data_.data()); @@ -275,30 +290,27 @@ struct Index { T* data() { // NOLINT return reinterpret_cast(data_.data()); } - uint32_t const* Offset() const { return bin_offset_.data(); } - size_t OffsetSize() const { return bin_offset_.size(); } - size_t Size() const { return data_.size() / (binTypeSize_); } + [[nodiscard]] std::uint32_t const* Offset() const { return bin_offset_.data(); } + [[nodiscard]] std::size_t OffsetSize() const { return bin_offset_.size(); } + [[nodiscard]] std::size_t Size() const { return data_.size() / (binTypeSize_); } - void Resize(const size_t n_bytes) { - data_.resize(n_bytes); - } // set the offset used in compression, cut_ptrs is the CSC indptr in HistogramCuts void SetBinOffset(std::vector const& cut_ptrs) { bin_offset_.resize(cut_ptrs.size() - 1); // resize to number of features. 
std::copy_n(cut_ptrs.begin(), bin_offset_.size(), bin_offset_.begin()); } - std::vector::const_iterator begin() const { // NOLINT - return data_.begin(); + auto begin() const { // NOLINT + return data_.data(); } - std::vector::const_iterator end() const { // NOLINT - return data_.end(); + auto end() const { // NOLINT + return data_.data() + data_.size(); } - std::vector::iterator begin() { // NOLINT - return data_.begin(); + auto begin() { // NOLINT + return data_.data(); } - std::vector::iterator end() { // NOLINT - return data_.end(); + auto end() { // NOLINT + return data_.data() + data_.size(); } private: @@ -313,12 +325,12 @@ struct Index { using Func = uint32_t (*)(uint8_t const*, size_t); - std::vector data_; + Span data_; // starting position of each feature inside the cut values (the indptr of the CSC cut matrix // HistogramCuts without the last entry.) Used for bin compression. std::vector bin_offset_; - BinTypeSize binTypeSize_ {kUint8BinsTypeSize}; + BinTypeSize binTypeSize_{kUint8BinsTypeSize}; Func func_; }; diff --git a/src/common/io.cc b/src/common/io.cc index ba97db574..db1624b95 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -200,21 +200,43 @@ std::string FileExtension(std::string fname, bool lower) { } } -struct PrivateMmapConstStream::MMAPFile { +// For some reason, NVCC 12.1 marks the function deleted if we expose it in the header. +// NVCC 11.8 doesn't allow `noexcept(false) = default` altogether. +ResourceHandler::~ResourceHandler() noexcept(false) {} // NOLINT + +struct MMAPFile { #if defined(xgboost_IS_WIN) HANDLE fd{INVALID_HANDLE_VALUE}; HANDLE file_map{INVALID_HANDLE_VALUE}; #else std::int32_t fd{0}; #endif - char* base_ptr{nullptr}; + std::byte* base_ptr{nullptr}; std::size_t base_size{0}; + std::size_t delta{0}; std::string path; + + MMAPFile() = default; + +#if defined(xgboost_IS_WIN) + MMAPFile(HANDLE fd, HANDLE fm, std::byte* base_ptr, std::size_t base_size, std::size_t delta, + std::string path) + : fd{fd}, + file_map{fm}, + base_ptr{base_ptr}, + base_size{base_size}, + delta{delta}, + path{std::move(path)} {} +#else + MMAPFile(std::int32_t fd, std::byte* base_ptr, std::size_t base_size, std::size_t delta, + std::string path) + : fd{fd}, base_ptr{base_ptr}, base_size{base_size}, delta{delta}, path{std::move(path)} {} +#endif }; -char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::size_t length) { +std::unique_ptr Open(std::string path, std::size_t offset, std::size_t length) { if (length == 0) { - return nullptr; + return std::make_unique(); } #if defined(xgboost_IS_WIN) @@ -226,16 +248,18 @@ char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::si CHECK_GE(fd, 0) << "Failed to open:" << path << ". " << SystemErrorMsg(); #endif - char* ptr{nullptr}; + std::byte* ptr{nullptr}; // Round down for alignment. auto view_start = offset / GetMmapAlignment() * GetMmapAlignment(); auto view_size = length + (offset - view_start); #if defined(__linux__) || defined(__GLIBC__) int prot{PROT_READ}; - ptr = reinterpret_cast(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start)); + ptr = reinterpret_cast(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start)); + madvise(ptr, view_size, MADV_WILLNEED); CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". 
" << SystemErrorMsg(); - handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)}); + auto handle = + std::make_unique(fd, ptr, view_size, offset - view_start, std::move(path)); #elif defined(xgboost_IS_WIN) auto file_size = GetFileSize(fd, nullptr); DWORD access = PAGE_READONLY; @@ -244,33 +268,32 @@ char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::si std::uint32_t loff = static_cast(view_start); std::uint32_t hoff = view_start >> 32; CHECK(map_file) << "Failed to map: " << path << ". " << SystemErrorMsg(); - ptr = reinterpret_cast(MapViewOfFile(map_file, access, hoff, loff, view_size)); + ptr = reinterpret_cast(MapViewOfFile(map_file, access, hoff, loff, view_size)); CHECK_NE(ptr, nullptr) << "Failed to map: " << path << ". " << SystemErrorMsg(); - handle_.reset(new MMAPFile{fd, map_file, ptr, view_size, std::move(path)}); + auto handle = std::make_unique(fd, map_file, ptr, view_size, offset - view_start, + std::move(path)); #else CHECK_LE(offset, std::numeric_limits::max()) << "File size has exceeded the limit on the current system."; int prot{PROT_READ}; - ptr = reinterpret_cast(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start)); + ptr = reinterpret_cast(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start)); CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg(); - handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)}); + auto handle = + std::make_unique(fd, ptr, view_size, offset - view_start, std::move(path)); #endif // defined(__linux__) - ptr += (offset - view_start); - return ptr; + return handle; } -PrivateMmapConstStream::PrivateMmapConstStream(std::string path, std::size_t offset, - std::size_t length) - : MemoryFixSizeBuffer{}, handle_{nullptr} { - this->p_buffer_ = Open(std::move(path), offset, length); - this->buffer_size_ = length; -} +MmapResource::MmapResource(std::string path, std::size_t offset, std::size_t length) + : ResourceHandler{kMmap}, handle_{Open(std::move(path), offset, length)}, n_{length} {} -PrivateMmapConstStream::~PrivateMmapConstStream() { - CHECK(handle_); +MmapResource::~MmapResource() noexcept(false) { + if (!handle_) { + return; + } #if defined(xgboost_IS_WIN) - if (p_buffer_) { + if (handle_->base_ptr) { CHECK(UnmapViewOfFile(handle_->base_ptr)) "Faled to call munmap: " << SystemErrorMsg(); } if (handle_->fd != INVALID_HANDLE_VALUE) { @@ -290,6 +313,43 @@ PrivateMmapConstStream::~PrivateMmapConstStream() { } #endif } + +[[nodiscard]] void* MmapResource::Data() { + if (!handle_) { + return nullptr; + } + return handle_->base_ptr + handle_->delta; +} + +[[nodiscard]] std::size_t MmapResource::Size() const { return n_; } + +// For some reason, NVCC 12.1 marks the function deleted if we expose it in the header. +// NVCC 11.8 doesn't allow `noexcept(false) = default` altogether. 
+AlignedResourceReadStream::~AlignedResourceReadStream() noexcept(false) {} // NOLINT +PrivateMmapConstStream::~PrivateMmapConstStream() noexcept(false) {} // NOLINT + +AlignedFileWriteStream::AlignedFileWriteStream(StringView path, StringView flags) + : pimpl_{dmlc::Stream::Create(path.c_str(), flags.c_str())} {} + +[[nodiscard]] std::size_t AlignedFileWriteStream::DoWrite(const void* ptr, + std::size_t n_bytes) noexcept(true) { + pimpl_->Write(ptr, n_bytes); + return n_bytes; +} + +AlignedMemWriteStream::AlignedMemWriteStream(std::string* p_buf) + : pimpl_{std::make_unique(p_buf)} {} +AlignedMemWriteStream::~AlignedMemWriteStream() = default; + +[[nodiscard]] std::size_t AlignedMemWriteStream::DoWrite(const void* ptr, + std::size_t n_bytes) noexcept(true) { + this->pimpl_->Write(ptr, n_bytes); + return n_bytes; +} + +[[nodiscard]] std::size_t AlignedMemWriteStream::Tell() const noexcept(true) { + return this->pimpl_->Tell(); +} } // namespace xgboost::common #if defined(xgboost_IS_WIN) diff --git a/src/common/io.h b/src/common/io.h index ab408dec1..baf518aa5 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -4,22 +4,29 @@ * \brief general stream interface for serialization, I/O * \author Tianqi Chen */ - #ifndef XGBOOST_COMMON_IO_H_ #define XGBOOST_COMMON_IO_H_ #include #include -#include -#include -#include // for unique_ptr -#include // for string +#include // for min +#include // for array +#include // for byte, size_t +#include // for malloc, realloc, free +#include // for memcpy +#include // for ifstream +#include // for numeric_limits +#include // for unique_ptr +#include // for string +#include // for alignment_of_v, enable_if_t +#include // for move +#include // for vector #include "common.h" +#include "xgboost/string_view.h" // for StringView -namespace xgboost { -namespace common { +namespace xgboost::common { using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer; using MemoryBufferStream = rabit::utils::MemoryBufferStream; @@ -58,8 +65,8 @@ class FixedSizeStream : public PeekableInStream { size_t Read(void* dptr, size_t size) override; size_t PeekRead(void* dptr, size_t size) override; - size_t Size() const { return buffer_.size(); } - size_t Tell() const { return pointer_; } + [[nodiscard]] std::size_t Size() const { return buffer_.size(); } + [[nodiscard]] std::size_t Tell() const { return pointer_; } void Seek(size_t pos); void Write(const void*, size_t) override { @@ -129,18 +136,245 @@ inline std::string ReadAll(std::string const &path) { return content; } +struct MMAPFile; + +/** + * @brief Handler for one-shot resource. Unlike `std::pmr::*`, the resource handler is + * fixed once it's constructed. Users cannot use mutable operations like resize + * without acquiring the specific resource first. + */ +class ResourceHandler { + public: + // RTTI + enum Kind : std::uint8_t { + kMalloc = 0, + kMmap = 1, + }; + + private: + Kind kind_{kMalloc}; + + public: + virtual void* Data() = 0; + template + [[nodiscard]] T* DataAs() { + return reinterpret_cast(this->Data()); + } + + [[nodiscard]] virtual std::size_t Size() const = 0; + [[nodiscard]] auto Type() const { return kind_; } + + // Allow exceptions for cleaning up resource. + virtual ~ResourceHandler() noexcept(false); + + explicit ResourceHandler(Kind kind) : kind_{kind} {} + // Use shared_ptr to manage a pool like resource handler. All copy and assignment + // operators are disabled. 
+ ResourceHandler(ResourceHandler const& that) = delete;
+ ResourceHandler& operator=(ResourceHandler const& that) = delete;
+ ResourceHandler(ResourceHandler&& that) = delete;
+ ResourceHandler& operator=(ResourceHandler&& that) = delete;
+ /**
+ * @brief Whether two resources have the same type (both malloc or both mmap).
+ */
+ [[nodiscard]] bool IsSameType(ResourceHandler const& that) const {
+ return this->Type() == that.Type();
+ }
+};
+
+class MallocResource : public ResourceHandler {
+ void* ptr_{nullptr};
+ std::size_t n_{0};
+
+ void Clear() noexcept(true) {
+ std::free(ptr_);
+ ptr_ = nullptr;
+ n_ = 0;
+ }
+
+ public:
+ explicit MallocResource(std::size_t n_bytes) : ResourceHandler{kMalloc} { this->Resize(n_bytes); }
+ ~MallocResource() noexcept(true) override { this->Clear(); }
+
+ void* Data() override { return ptr_; }
+ [[nodiscard]] std::size_t Size() const override { return n_; }
+ /**
+ * @brief Resize the resource to n_bytes. Unlike std::vector::resize, it prefers realloc
+ * over malloc.
+ *
+ * @tparam force_malloc Force the use of malloc over realloc. Used for testing.
+ *
+ * @param n_bytes The new size.
+ */
+ template
+ void Resize(std::size_t n_bytes) {
+ // realloc(ptr, 0) works, but is deprecated.
+ if (n_bytes == 0) {
+ this->Clear();
+ return;
+ }
+
+ // If realloc fails, we need to copy the data ourselves.
+ bool need_copy{false};
+ void* new_ptr{nullptr};
+ // use realloc first, it can handle nullptr.
+ if constexpr (!force_malloc) {
+ new_ptr = std::realloc(ptr_, n_bytes);
+ }
+ // retry with malloc if realloc fails
+ if (!new_ptr) {
+ // ptr_ is preserved if realloc fails
+ new_ptr = std::malloc(n_bytes);
+ need_copy = true;
+ }
+ if (!new_ptr) {
+ // malloc fails
+ LOG(FATAL) << "bad_malloc: Failed to allocate " << n_bytes << " bytes.";
+ }
+
+ if (need_copy) {
+ std::copy_n(reinterpret_cast(ptr_), n_, reinterpret_cast(new_ptr));
+ }
+ // default initialize
+ std::memset(reinterpret_cast(new_ptr) + n_, '\0', n_bytes - n_);
+ // free the old ptr if malloc is used.
+ if (need_copy) {
+ this->Clear();
+ }
+
+ ptr_ = new_ptr;
+ n_ = n_bytes;
+ }
+};
+
+/**
+ * @brief A class for wrapping mmap as a resource for RAII.
+ */
+class MmapResource : public ResourceHandler {
+ std::unique_ptr handle_;
+ std::size_t n_;
+
+ public:
+ MmapResource(std::string path, std::size_t offset, std::size_t length);
+ ~MmapResource() noexcept(false) override;
+
+ [[nodiscard]] void* Data() override;
+ [[nodiscard]] std::size_t Size() const override;
+};
+
+/**
+ * @brief Alignment for resource read stream and aligned write stream.
+ */
+constexpr std::size_t IOAlignment() {
+ // For most of the pod types in XGBoost, 8 bytes is sufficient.
+ return 8;
+}
+
+/**
+ * @brief Wrap resource into a dmlc stream.
+ *
+ * This class is to facilitate the use of mmap. Caller can optionally use the `Read()`
+ * method or the `Consume()` method. The former copies data into output, while the latter
+ * makes a copy only if it's a primitive type.
+ *
+ * Input is required to be aligned to IOAlignment().
+ */
+class AlignedResourceReadStream {
+ std::shared_ptr resource_;
+ std::size_t curr_ptr_{0};
+
+ // Similar to SEEK_END in libc
+ static std::size_t constexpr kSeekEnd = std::numeric_limits::max();
+
+ public:
+ explicit AlignedResourceReadStream(std::shared_ptr resource)
+ : resource_{std::move(resource)} {}
+
+ [[nodiscard]] std::shared_ptr Share() noexcept(true) { return resource_; }
+ /**
+ * @brief Consume n_bytes of data, no copying is performed.
+ * + * @return A pair with the beginning pointer and the number of available bytes, which + * may be smaller than requested. + */ + [[nodiscard]] auto Consume(std::size_t n_bytes) noexcept(true) { + auto res_size = resource_->Size(); + auto data = reinterpret_cast(resource_->Data()); + auto ptr = data + curr_ptr_; + + // Move the cursor + auto aligned_n_bytes = DivRoundUp(n_bytes, IOAlignment()) * IOAlignment(); + auto aligned_forward = std::min(res_size - curr_ptr_, aligned_n_bytes); + std::size_t forward = std::min(res_size - curr_ptr_, n_bytes); + + curr_ptr_ += aligned_forward; + + return std::pair{ptr, forward}; + } + + template + [[nodiscard]] auto Consume(T* out) noexcept(false) -> std::enable_if_t, bool> { + auto [ptr, size] = this->Consume(sizeof(T)); + if (size != sizeof(T)) { + return false; + } + CHECK_EQ(reinterpret_cast(ptr) % std::alignment_of_v, 0); + *out = *reinterpret_cast(ptr); + return true; + } + + [[nodiscard]] virtual std::size_t Tell() noexcept(true) { return curr_ptr_; } + /** + * @brief Read n_bytes of data, output is copied into ptr. + */ + [[nodiscard]] std::size_t Read(void* ptr, std::size_t n_bytes) noexcept(true) { + auto [res_ptr, forward] = this->Consume(n_bytes); + if (forward != 0) { + std::memcpy(ptr, res_ptr, forward); + } + return forward; + } + /** + * @brief Read a primitive type. + * + * @return Whether the read is successful. + */ + template + [[nodiscard]] auto Read(T* out) noexcept(false) -> std::enable_if_t, bool> { + return this->Consume(out); + } + /** + * @brief Read a vector. + * + * @return Whether the read is successful. + */ + template + [[nodiscard]] bool Read(std::vector* out) noexcept(true) { + std::uint64_t n{0}; + if (!this->Consume(&n)) { + return false; + } + out->resize(n); + + auto n_bytes = sizeof(T) * n; + if (this->Read(out->data(), n_bytes) != n_bytes) { + return false; + } + return true; + } + + virtual ~AlignedResourceReadStream() noexcept(false); +}; + /** * @brief Private mmap file as a read-only stream. * * It can calculate alignment automatically based on system page size (or allocation * granularity on Windows). + * + * The file is required to be aligned by IOAlignment(). */ -class PrivateMmapConstStream : public MemoryFixSizeBuffer { - struct MMAPFile; - std::unique_ptr handle_; - - char* Open(std::string path, std::size_t offset, std::size_t length); - +class PrivateMmapConstStream : public AlignedResourceReadStream { public: /** * @brief Construct a private mmap stream. @@ -149,11 +383,71 @@ class PrivateMmapConstStream : public MemoryFixSizeBuffer { * @param offset See the `offset` parameter of `mmap` for details. * @param length See the `length` parameter of `mmap` for details. */ - explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length); - void Write(void const*, std::size_t) override { LOG(FATAL) << "Read-only stream."; } - - ~PrivateMmapConstStream() override; + explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length) + : AlignedResourceReadStream{std::make_shared(path, offset, length)} {} + ~PrivateMmapConstStream() noexcept(false) override; }; -} // namespace common -} // namespace xgboost + +/** + * @brief Base class for write stream with alignment defined by IOAlignment(). 
+ */ +class AlignedWriteStream { + protected: + [[nodiscard]] virtual std::size_t DoWrite(const void* ptr, + std::size_t n_bytes) noexcept(true) = 0; + + public: + virtual ~AlignedWriteStream() = default; + + [[nodiscard]] std::size_t Write(const void* ptr, std::size_t n_bytes) noexcept(false) { + auto aligned_n_bytes = DivRoundUp(n_bytes, IOAlignment()) * IOAlignment(); + auto w_n_bytes = this->DoWrite(ptr, n_bytes); + CHECK_EQ(w_n_bytes, n_bytes); + auto remaining = aligned_n_bytes - n_bytes; + if (remaining > 0) { + std::array padding; + std::memset(padding.data(), '\0', padding.size()); + w_n_bytes = this->DoWrite(padding.data(), remaining); + CHECK_EQ(w_n_bytes, remaining); + } + return aligned_n_bytes; + } + + template + [[nodiscard]] std::enable_if_t, std::size_t> Write(T const& v) { + return this->Write(&v, sizeof(T)); + } +}; + +/** + * @brief Output stream backed by a file. Aligned to IOAlignment() bytes. + */ +class AlignedFileWriteStream : public AlignedWriteStream { + std::unique_ptr pimpl_; + + protected: + [[nodiscard]] std::size_t DoWrite(const void* ptr, std::size_t n_bytes) noexcept(true) override; + + public: + AlignedFileWriteStream() = default; + AlignedFileWriteStream(StringView path, StringView flags); + ~AlignedFileWriteStream() override = default; +}; + +/** + * @brief Output stream backed by memory buffer. Aligned to IOAlignment() bytes. + */ +class AlignedMemWriteStream : public AlignedFileWriteStream { + std::unique_ptr pimpl_; + + protected: + [[nodiscard]] std::size_t DoWrite(const void* ptr, std::size_t n_bytes) noexcept(true) override; + + public: + explicit AlignedMemWriteStream(std::string* p_buf); + ~AlignedMemWriteStream() override; + + [[nodiscard]] std::size_t Tell() const noexcept(true); +}; +} // namespace xgboost::common #endif // XGBOOST_COMMON_IO_H_ diff --git a/src/common/ref_resource_view.h b/src/common/ref_resource_view.h new file mode 100644 index 000000000..2804d79eb --- /dev/null +++ b/src/common/ref_resource_view.h @@ -0,0 +1,158 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#ifndef XGBOOST_COMMON_REF_RESOURCE_VIEW_H_ +#define XGBOOST_COMMON_REF_RESOURCE_VIEW_H_ + +#include // for fill_n +#include // for uint64_t +#include // for memcpy +#include // for shared_ptr, make_shared +#include // for is_reference_v, remove_reference_t, is_same_v +#include // for swap, move + +#include "io.h" // for ResourceHandler, AlignedResourceReadStream, MallocResource +#include "xgboost/logging.h" +#include "xgboost/span.h" // for Span + +namespace xgboost::common { +/** + * @brief A vector-like type that holds a reference counted resource. + * + * The vector size is immutable after construction. This way we can swap the underlying + * resource when needed. + */ +template +class RefResourceView { + static_assert(!std::is_reference_v); + + public: + using value_type = T; // NOLINT + using size_type = std::uint64_t; // NOLINT + + private: + value_type* ptr_{nullptr}; + size_type size_{0}; + std::shared_ptr mem_{nullptr}; + + public: + RefResourceView(value_type* ptr, size_type n, std::shared_ptr mem) + : ptr_{ptr}, size_{n}, mem_{std::move(mem)} { + CHECK_GE(mem_->Size(), n); + } + /** + * @brief Construct a view on ptr with length n. The ptr is held by the mem resource. + * + * @param ptr The pointer to view. + * @param n The length of the view. + * @param mem The owner of the pointer. + * @param init Initialize the view with this value. 
+ */
+ RefResourceView(value_type* ptr, size_type n, std::shared_ptr mem,
+ T const& init)
+ : RefResourceView{ptr, n, mem} {
+ if (n != 0) {
+ std::fill_n(ptr_, n, init);
+ }
+ }
+
+ ~RefResourceView() = default;
+
+ RefResourceView() = default;
+ RefResourceView(RefResourceView const& that) = delete;
+ RefResourceView(RefResourceView&& that) = delete;
+ RefResourceView& operator=(RefResourceView const& that) = delete;
+ /**
+ * @brief We allow move assignment for lazy initialization.
+ */
+ RefResourceView& operator=(RefResourceView&& that) = default;
+
+ [[nodiscard]] size_type size() const { return size_; } // NOLINT
+ [[nodiscard]] size_type size_bytes() const { // NOLINT
+ return Span{data(), size()}.size_bytes();
+ }
+ [[nodiscard]] value_type* data() { return ptr_; }; // NOLINT
+ [[nodiscard]] value_type const* data() const { return ptr_; }; // NOLINT
+ [[nodiscard]] bool empty() const { return size() == 0; } // NOLINT
+
+ [[nodiscard]] auto cbegin() const { return data(); } // NOLINT
+ [[nodiscard]] auto begin() { return data(); } // NOLINT
+ [[nodiscard]] auto begin() const { return cbegin(); } // NOLINT
+ [[nodiscard]] auto cend() const { return data() + size(); } // NOLINT
+ [[nodiscard]] auto end() { return data() + size(); } // NOLINT
+ [[nodiscard]] auto end() const { return cend(); } // NOLINT
+
+ [[nodiscard]] auto const& front() const { return data()[0]; } // NOLINT
+ [[nodiscard]] auto& front() { return data()[0]; } // NOLINT
+ [[nodiscard]] auto const& back() const { return data()[size() - 1]; } // NOLINT
+ [[nodiscard]] auto& back() { return data()[size() - 1]; } // NOLINT
+
+ [[nodiscard]] value_type& operator[](size_type i) { return ptr_[i]; }
+ [[nodiscard]] value_type const& operator[](size_type i) const { return ptr_[i]; }
+
+ /**
+ * @brief Get the underlying resource.
+ */
+ auto Resource() const { return mem_; }
+};
+
+/**
+ * @brief Read a vector from stream. Accepts both `std::vector` and `RefResourceView`.
+ *
+ * If the output vector is a reference-counted view, no copying occurs.
+ */
+template
+[[nodiscard]] bool ReadVec(common::AlignedResourceReadStream* fi, Vec* vec) {
+ std::uint64_t n{0};
+ if (!fi->Read(&n)) {
+ return false;
+ }
+ if (n == 0) {
+ return true;
+ }
+
+ using T = typename Vec::value_type;
+ auto expected_bytes = sizeof(T) * n;
+
+ auto [ptr, n_bytes] = fi->Consume(expected_bytes);
+ if (n_bytes != expected_bytes) {
+ return false;
+ }
+
+ if constexpr (std::is_same_v>) {
+ *vec = RefResourceView{reinterpret_cast(ptr), n, fi->Share()};
+ } else {
+ vec->resize(n);
+ std::memcpy(vec->data(), ptr, n_bytes);
+ }
+ return true;
+}
+
+/**
+ * @brief Write a vector to stream. Accepts both `std::vector` and `RefResourceView`.
+ */
+template
+[[nodiscard]] std::size_t WriteVec(AlignedFileWriteStream* fo, Vec const& vec) {
+ std::size_t bytes{0};
+ auto n = static_cast(vec.size());
+ bytes += fo->Write(n);
+ if (n == 0) {
+ return sizeof(n);
+ }
+
+ using T = typename std::remove_reference_t::value_type;
+ bytes += fo->Write(vec.data(), vec.size() * sizeof(T));
+
+ return bytes;
+}
+
+/**
+ * @brief Make a fixed size `RefResourceView` with malloc resource.
+ */ +template +[[nodiscard]] RefResourceView MakeFixedVecWithMalloc(std::size_t n_elements, T const& init) { + auto resource = std::make_shared(n_elements * sizeof(T)); + return RefResourceView{resource->DataAs(), n_elements, resource, init}; +} +} // namespace xgboost::common +#endif // XGBOOST_COMMON_REF_RESOURCE_VIEW_H_ diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 2f54b91c9..8316368ba 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -1,60 +1,59 @@ -/*! - * Copyright 2019-2021 XGBoost contributors +/** + * Copyright 2019-2023, XGBoost contributors */ -#include #include +#include // for size_t + +#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream +#include "../common/ref_resource_view.h" // for ReadVec, WriteVec #include "ellpack_page.cuh" -#include "sparse_page_writer.h" -#include "histogram_cut_format.h" - -namespace xgboost { -namespace data { +#include "histogram_cut_format.h" // for ReadHistogramCuts, WriteHistogramCuts +#include "sparse_page_writer.h" // for SparsePageFormat +namespace xgboost::data { DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format); - class EllpackPageRawFormat : public SparsePageFormat { public: - bool Read(EllpackPage* page, dmlc::SeekStream* fi) override { + bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override { auto* impl = page->Impl(); if (!ReadHistogramCuts(&impl->Cuts(), fi)) { return false; } - fi->Read(&impl->n_rows); - fi->Read(&impl->is_dense); - fi->Read(&impl->row_stride); - fi->Read(&impl->gidx_buffer.HostVector()); + if (!fi->Read(&impl->n_rows)) { + return false; + } + if (!fi->Read(&impl->is_dense)) { + return false; + } + if (!fi->Read(&impl->row_stride)) { + return false; + } + if (!common::ReadVec(fi, &impl->gidx_buffer.HostVector())) { + return false; + } if (!fi->Read(&impl->base_rowid)) { return false; } return true; } - size_t Write(const EllpackPage& page, dmlc::Stream* fo) override { - size_t bytes = 0; + size_t Write(const EllpackPage& page, common::AlignedFileWriteStream* fo) override { + std::size_t bytes{0}; auto* impl = page.Impl(); bytes += WriteHistogramCuts(impl->Cuts(), fo); - fo->Write(impl->n_rows); - bytes += sizeof(impl->n_rows); - fo->Write(impl->is_dense); - bytes += sizeof(impl->is_dense); - fo->Write(impl->row_stride); - bytes += sizeof(impl->row_stride); + bytes += fo->Write(impl->n_rows); + bytes += fo->Write(impl->is_dense); + bytes += fo->Write(impl->row_stride); CHECK(!impl->gidx_buffer.ConstHostVector().empty()); - fo->Write(impl->gidx_buffer.HostVector()); - bytes += impl->gidx_buffer.ConstHostSpan().size_bytes() + sizeof(uint64_t); - fo->Write(impl->base_rowid); - bytes += sizeof(impl->base_rowid); + bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector()); + bytes += fo->Write(impl->base_rowid); return bytes; } }; XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw) .describe("Raw ELLPACK binary data format.") - .set_body([]() { - return new EllpackPageRawFormat(); - }); - -} // namespace data -} // namespace xgboost + .set_body([]() { return new EllpackPageRawFormat(); }); +} // namespace xgboost::data diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 11e9a4bec..1d47ae9e6 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -29,7 +29,7 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_ cut = common::SketchOnDMatrix(ctx, p_fmat, max_bins_per_feat, sorted_sketch, hess); const uint32_t nbins = 
       cut.Ptrs().back();
-  hit_count.resize(nbins, 0);
+  hit_count = common::MakeFixedVecWithMalloc(nbins, std::size_t{0});
   hit_count_tloc_.resize(ctx->Threads() * nbins, 0);

   size_t new_size = 1;
@@ -37,8 +37,7 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_
     new_size += batch.Size();
   }

-  row_ptr.resize(new_size);
-  row_ptr[0] = 0;
+  row_ptr = common::MakeFixedVecWithMalloc(new_size, std::size_t{0});

   const bool isDense = p_fmat->IsDense();
   this->isDense_ = isDense;
@@ -61,8 +60,8 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_

 GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &&cuts,
                                    bst_bin_t max_bin_per_feat)
-    : row_ptr(info.num_row_ + 1, 0),
-      hit_count(cuts.TotalBins(), 0),
+    : row_ptr{common::MakeFixedVecWithMalloc(info.num_row_ + 1, std::size_t{0})},
+      hit_count{common::MakeFixedVecWithMalloc(cuts.TotalBins(), std::size_t{0})},
       cut{std::forward<common::HistogramCuts>(cuts)},
       max_numeric_bins_per_feat(max_bin_per_feat),
       isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
@@ -95,12 +94,10 @@ GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<FeatureType const> ft,
   this->PushBatch(batch, ft, n_threads);
@@ -128,20 +125,45 @@ INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
 #undef INSTANTIATION_PUSH

 void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
+  auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
+    // Must resize instead of allocating a new one. This function is called every time a
+    // new batch is pushed, and we grow the size accordingly without losing the data from
+    // the previous batches.
+    using T = decltype(t);
+    std::size_t n_bytes = sizeof(T) * n_index;
+    CHECK_GE(n_bytes, this->data.size());
+
+    auto resource = this->data.Resource();
+    decltype(this->data) new_vec;
+    if (!resource) {
+      CHECK(this->data.empty());
+      new_vec = common::MakeFixedVecWithMalloc(n_bytes, std::uint8_t{0});
+    } else {
+      CHECK(resource->Type() == common::ResourceHandler::kMalloc);
+      auto malloc_resource = std::dynamic_pointer_cast<common::MallocResource>(resource);
+      CHECK(malloc_resource);
+      malloc_resource->Resize(n_bytes);
+
+      // gcc-11.3 doesn't work if DataAs is used.
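+      // Re-wrap the resized buffer in a fresh view; the same malloc resource retains
+      // ownership of the memory.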
+      std::uint8_t *new_ptr = reinterpret_cast<std::uint8_t *>(malloc_resource->Data());
+      new_vec = {new_ptr, n_bytes / sizeof(std::uint8_t), malloc_resource};
+    }
+    this->data = std::move(new_vec);
+    this->index = common::Index{common::Span{data.data(), data.size()}, t_size};
+  };
+
   if ((MaxNumBinPerFeat() - 1 <= static_cast<int>(std::numeric_limits<std::uint8_t>::max())) &&
       isDense) {
     // compress dense index to uint8
-    index.SetBinTypeSize(common::kUint8BinsTypeSize);
-    index.Resize((sizeof(uint8_t)) * n_index);
+    make_index(std::uint8_t{}, common::kUint8BinsTypeSize);
   } else if ((MaxNumBinPerFeat() - 1 > static_cast<int>(std::numeric_limits<std::uint8_t>::max()) &&
              MaxNumBinPerFeat() - 1 <= static_cast<int>(std::numeric_limits<std::uint16_t>::max())) &&
             isDense) {
     // compress dense index to uint16
-    index.SetBinTypeSize(common::kUint16BinsTypeSize);
-    index.Resize((sizeof(uint16_t)) * n_index);
+    make_index(std::uint16_t{}, common::kUint16BinsTypeSize);
   } else {
-    index.SetBinTypeSize(common::kUint32BinsTypeSize);
-    index.Resize((sizeof(uint32_t)) * n_index);
+    // no compression
+    make_index(std::uint32_t{}, common::kUint32BinsTypeSize);
   }
 }

@@ -214,11 +236,11 @@ float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
   return std::numeric_limits<float>::quiet_NaN();
 }

-bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
+bool GHistIndexMatrix::ReadColumnPage(common::AlignedResourceReadStream *fi) {
   return this->columns_->Read(fi, this->cut.Ptrs().data());
 }

-size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const {
+std::size_t GHistIndexMatrix::WriteColumnPage(common::AlignedFileWriteStream *fo) const {
   return this->columns_->Write(fo);
 }
 }  // namespace xgboost
diff --git a/src/data/gradient_index.cu b/src/data/gradient_index.cu
index af5b0f67b..42018eab4 100644
--- a/src/data/gradient_index.cu
+++ b/src/data/gradient_index.cu
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2022 by XGBoost Contributors
+/**
+ * Copyright 2022-2023, XGBoost Contributors
  */

 #include <memory>  // std::unique_ptr

@@ -41,9 +41,9 @@ void SetIndexData(Context const* ctx, EllpackPageImpl const* page,
 }

 void GetRowPtrFromEllpack(Context const* ctx, EllpackPageImpl const* page,
-                          std::vector<size_t>* p_out) {
+                          common::RefResourceView<std::size_t>* p_out) {
   auto& row_ptr = *p_out;
-  row_ptr.resize(page->Size() + 1, 0);
+  row_ptr = common::MakeFixedVecWithMalloc(page->Size() + 1, std::size_t{0});
   if (page->is_dense) {
     std::fill(row_ptr.begin() + 1, row_ptr.end(), page->row_stride);
   } else {
@@ -95,7 +95,7 @@ GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info,
         ctx, page, &hit_count_tloc_, [&](auto bin_idx, auto) { return bin_idx; }, this);
   }

-  this->hit_count.resize(n_bins_total, 0);
+  this->hit_count = common::MakeFixedVecWithMalloc(n_bins_total, std::size_t{0});
   this->GatherHitCount(ctx->Threads(), n_bins_total);

   // sanity checks
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index d36373d6b..840be4b06 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -9,13 +9,14 @@
 #include <atomic>   // for atomic
 #include <cstdint>  // for uint32_t
 #include <cstddef>  // for size_t
-#include <memory>
+#include <memory>   // for make_unique
 #include <vector>

 #include "../common/categorical.h"
 #include "../common/error_msg.h"  // for InfInData
 #include "../common/hist_util.h"
 #include "../common/numeric.h"
+#include "../common/ref_resource_view.h"  // for RefResourceView
 #include "../common/threading_utils.h"
 #include "../common/transform_iterator.h"  // for MakeIndexTransformIter
 #include "adapter.h"
@@ -25,9 +26,11 @@ namespace xgboost {
 namespace common {
 class ColumnMatrix;
+class AlignedFileWriteStream;
 }  // namespace common

-/*!
- * \brief preprocessed global index matrix, in CSR format
+
+/**
+ * @brief Preprocessed global index matrix, in CSR format.
  *
  * Transform floating values to integer index in histogram This is a global histogram
  * index for CPU histogram. On GPU ellpack page is used.
@@ -133,20 +136,22 @@ class GHistIndexMatrix {
   }

  public:
-  /*! \brief row pointer to rows by element position */
-  std::vector<size_t> row_ptr;
-  /*! \brief The index data */
+  /** @brief row pointer to rows by element position */
+  common::RefResourceView<std::size_t> row_ptr;
+  /** @brief data storage for index. */
+  common::RefResourceView<std::uint8_t> data;
+  /** @brief The histogram index. */
   common::Index index;
-  /*! \brief hit count of each index, used for constructing the ColumnMatrix */
-  std::vector<size_t> hit_count;
-  /*! \brief The corresponding cuts */
+  /** @brief hit count of each index, used for constructing the ColumnMatrix */
+  common::RefResourceView<std::size_t> hit_count;
+  /** @brief The corresponding cuts */
   common::HistogramCuts cut;
-  /** \brief max_bin for each feature. */
+  /** @brief max_bin for each feature. */
   bst_bin_t max_numeric_bins_per_feat;
-  /*! \brief base row index for current page (used by external memory) */
-  size_t base_rowid{0};
+  /** @brief base row index for current page (used by external memory) */
+  bst_row_t base_rowid{0};

-  bst_bin_t MaxNumBinPerFeat() const {
+  [[nodiscard]] bst_bin_t MaxNumBinPerFeat() const {
     return std::max(static_cast<bst_bin_t>(cut.MaxCategory() + 1), max_numeric_bins_per_feat);
   }
@@ -218,29 +223,27 @@ class GHistIndexMatrix {
     }
   }

-  bool IsDense() const {
-    return isDense_;
-  }
+  [[nodiscard]] bool IsDense() const { return isDense_; }
   void SetDense(bool is_dense) { isDense_ = is_dense; }
   /**
-   * \brief Get the local row index.
+   * @brief Get the local row index.
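+   *
+   * In external memory, a page may not start at row 0 of the dataset, so the global row
+   * index `ridx` is shifted by `base_rowid` before indexing into this page.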
   */
-  size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }
+  [[nodiscard]] std::size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }

-  bst_row_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
-  bst_feature_t Features() const { return cut.Ptrs().size() - 1; }
+  [[nodiscard]] bst_row_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
+  [[nodiscard]] bst_feature_t Features() const { return cut.Ptrs().size() - 1; }

-  bool ReadColumnPage(dmlc::SeekStream* fi);
-  size_t WriteColumnPage(dmlc::Stream* fo) const;
+  [[nodiscard]] bool ReadColumnPage(common::AlignedResourceReadStream* fi);
+  [[nodiscard]] std::size_t WriteColumnPage(common::AlignedFileWriteStream* fo) const;

-  common::ColumnMatrix const& Transpose() const;
+  [[nodiscard]] common::ColumnMatrix const& Transpose() const;

-  bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
+  [[nodiscard]] bst_bin_t GetGindex(size_t ridx, size_t fidx) const;

-  float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
-  float GetFvalue(std::vector<std::uint32_t> const& ptrs, std::vector<float> const& values,
-                  std::vector<float> const& mins, bst_row_t ridx, bst_feature_t fidx,
-                  bool is_cat) const;
+  [[nodiscard]] float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
+  [[nodiscard]] float GetFvalue(std::vector<std::uint32_t> const& ptrs,
+                                std::vector<float> const& values, std::vector<float> const& mins,
+                                bst_row_t ridx, bst_feature_t fidx, bool is_cat) const;

  private:
   std::unique_ptr<common::ColumnMatrix> columns_;
@@ -294,5 +297,5 @@ void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
     }
   });
 }
-}  // namespace xgboost
+}  // namespace xgboost
 #endif  // XGBOOST_DATA_GRADIENT_INDEX_H_
diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc
index 204157682..ac52c0697 100644
--- a/src/data/gradient_index_format.cc
+++ b/src/data/gradient_index_format.cc
@@ -1,38 +1,49 @@
-/*!
- * Copyright 2021-2022 XGBoost contributors
+/**
+ * Copyright 2021-2023 XGBoost contributors
  */
-#include "sparse_page_writer.h"
-#include "gradient_index.h"
-#include "histogram_cut_format.h"
+#include <cstddef>      // for size_t
+#include <cstdint>      // for uint8_t
+#include <type_traits>  // for underlying_type_t
+#include <vector>       // for vector

-namespace xgboost {
-namespace data {
+#include "../common/io.h"                 // for AlignedResourceReadStream
+#include "../common/ref_resource_view.h"  // for ReadVec, WriteVec
+#include "gradient_index.h"               // for GHistIndexMatrix
+#include "histogram_cut_format.h"         // for ReadHistogramCuts
+#include "sparse_page_writer.h"           // for SparsePageFormat
+
+namespace xgboost::data {

 class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
  public:
-  bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
+  bool Read(GHistIndexMatrix* page, common::AlignedResourceReadStream* fi) override {
+    CHECK(fi);
+
     if (!ReadHistogramCuts(&page->cut, fi)) {
       return false;
     }
+
     // indptr
-    fi->Read(&page->row_ptr);
-    // data
-    std::vector<uint8_t> data;
-    if (!fi->Read(&data)) {
+    if (!common::ReadVec(fi, &page->row_ptr)) {
       return false;
     }
-    page->index.Resize(data.size());
-    std::copy(data.cbegin(), data.cend(), page->index.begin());
-    // bin type
+
+    // data
+    // - bin type
     // Old gcc doesn't support reading from enum.
     std::underlying_type_t<common::BinTypeSize> uint_bin_type{0};
     if (!fi->Read(&uint_bin_type)) {
       return false;
     }
-    common::BinTypeSize size_type =
-        static_cast<common::BinTypeSize>(uint_bin_type);
-    page->index.SetBinTypeSize(size_type);
+    common::BinTypeSize size_type = static_cast<common::BinTypeSize>(uint_bin_type);
+    // - index buffer
+    if (!common::ReadVec(fi, &page->data)) {
+      return false;
+    }
+    // - index
+    page->index = common::Index{common::Span{page->data.data(), page->data.size()}, size_type};
+
     // hit count
-    if (!fi->Read(&page->hit_count)) {
+    if (!common::ReadVec(fi, &page->hit_count)) {
       return false;
     }
     if (!fi->Read(&page->max_numeric_bins_per_feat)) {
@@ -50,38 +61,33 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
       page->index.SetBinOffset(page->cut.Ptrs());
     }

-    page->ReadColumnPage(fi);
+    if (!page->ReadColumnPage(fi)) {
+      return false;
+    }
     return true;
   }

-  size_t Write(GHistIndexMatrix const &page, dmlc::Stream *fo) override {
-    size_t bytes = 0;
+  std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
+    std::size_t bytes = 0;
     bytes += WriteHistogramCuts(page.cut, fo);
     // indptr
-    fo->Write(page.row_ptr);
-    bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) +
-             sizeof(uint64_t);
+    bytes += common::WriteVec(fo, page.row_ptr);
+
     // data
-    std::vector<uint8_t> data(page.index.begin(), page.index.end());
-    fo->Write(data);
-    bytes += data.size() * sizeof(decltype(data)::value_type) + sizeof(uint64_t);
-    // bin type
-    std::underlying_type_t<common::BinTypeSize> uint_bin_type =
-        page.index.GetBinTypeSize();
-    fo->Write(uint_bin_type);
-    bytes += sizeof(page.index.GetBinTypeSize());
+    // - bin type
+    std::underlying_type_t<common::BinTypeSize> uint_bin_type = page.index.GetBinTypeSize();
+    bytes += fo->Write(uint_bin_type);
+    // - index buffer
+    std::vector<std::uint8_t> data(page.index.begin(), page.index.end());
+    bytes += fo->Write(static_cast<std::uint64_t>(data.size()));
+    bytes += fo->Write(data.data(), data.size());
+
     // hit count
-    fo->Write(page.hit_count);
-    bytes +=
-        page.hit_count.size() * sizeof(decltype(page.hit_count)::value_type) +
-        sizeof(uint64_t);
+    bytes += common::WriteVec(fo, page.hit_count);
+
     // max_bins, base row, is_dense
-    fo->Write(page.max_numeric_bins_per_feat);
-    bytes += sizeof(page.max_numeric_bins_per_feat);
-    fo->Write(page.base_rowid);
-    bytes += sizeof(page.base_rowid);
-    fo->Write(page.IsDense());
-    bytes += sizeof(page.IsDense());
+    bytes += fo->Write(page.max_numeric_bins_per_feat);
+    bytes += fo->Write(page.base_rowid);
+    bytes += fo->Write(page.IsDense());

     bytes += page.WriteColumnPage(fo);
     return bytes;
@@ -93,6 +99,4 @@ DMLC_REGISTRY_FILE_TAG(gradient_index_format);
 XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(raw)
     .describe("Raw GHistIndex binary data format.")
     .set_body([]() { return new GHistIndexRawFormat(); });
-
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
diff --git a/src/data/histogram_cut_format.h b/src/data/histogram_cut_format.h
index 39961c4a2..45a96134f 100644
--- a/src/data/histogram_cut_format.h
+++ b/src/data/histogram_cut_format.h
@@ -1,36 +1,38 @@
-/*!
- * Copyright 2021 XGBoost contributors
+/**
+ * Copyright 2021-2023, XGBoost contributors
  */
 #ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
 #define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_

-#include "../common/hist_util.h"
+#include <dmlc/io.h>  // for Stream

-namespace xgboost {
-namespace data {
-inline bool ReadHistogramCuts(common::HistogramCuts *cuts, dmlc::SeekStream *fi) {
-  if (!fi->Read(&cuts->cut_values_.HostVector())) {
+#include <cstddef>  // for size_t
+
+#include "../common/hist_util.h"          // for HistogramCuts
+#include "../common/io.h"                 // for AlignedResourceReadStream, AlignedFileWriteStream
+#include "../common/ref_resource_view.h"  // for WriteVec, ReadVec
+
+namespace xgboost::data {
+inline bool ReadHistogramCuts(common::HistogramCuts *cuts, common::AlignedResourceReadStream *fi) {
+  if (!common::ReadVec(fi, &cuts->cut_values_.HostVector())) {
     return false;
   }
-  if (!fi->Read(&cuts->cut_ptrs_.HostVector())) {
+  if (!common::ReadVec(fi, &cuts->cut_ptrs_.HostVector())) {
     return false;
   }
-  if (!fi->Read(&cuts->min_vals_.HostVector())) {
+  if (!common::ReadVec(fi, &cuts->min_vals_.HostVector())) {
     return false;
   }
   return true;
 }

-inline size_t WriteHistogramCuts(common::HistogramCuts const &cuts, dmlc::Stream *fo) {
-  size_t bytes = 0;
-  fo->Write(cuts.cut_values_.ConstHostVector());
-  bytes += cuts.cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
-  fo->Write(cuts.cut_ptrs_.ConstHostVector());
-  bytes += cuts.cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
-  fo->Write(cuts.min_vals_.ConstHostVector());
-  bytes += cuts.min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
+inline std::size_t WriteHistogramCuts(common::HistogramCuts const &cuts,
+                                      common::AlignedFileWriteStream *fo) {
+  std::size_t bytes = 0;
+  bytes += common::WriteVec(fo, cuts.Values());
+  bytes += common::WriteVec(fo, cuts.Ptrs());
+  bytes += common::WriteVec(fo, cuts.MinValues());
   return bytes;
 }
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
 #endif  // XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc
index 627606aa3..c2c9a1d70 100644
--- a/src/data/iterative_dmatrix.cc
+++ b/src/data/iterative_dmatrix.cc
@@ -240,9 +240,9 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
    * Generate gradient index.
    */
   this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
-  size_t rbegin = 0;
-  size_t prev_sum = 0;
-  size_t i = 0;
+  std::size_t rbegin = 0;
+  std::size_t prev_sum = 0;
+  std::size_t i = 0;
   while (iter.Next()) {
     HostAdapterDispatch(proxy, [&](auto const& batch) {
       proxy->Info().num_nonzero_ = batch_nnz[i];
diff --git a/src/data/sparse_page_raw_format.cc b/src/data/sparse_page_raw_format.cc
index 1e5d1ec71..1edf27c46 100644
--- a/src/data/sparse_page_raw_format.cc
+++ b/src/data/sparse_page_raw_format.cc
@@ -1,59 +1,57 @@
-/*!
- * Copyright (c) 2015-2021 by Contributors
+/**
+ * Copyright 2015-2023, XGBoost Contributors
  * \file sparse_page_raw_format.cc
  * Raw binary format of sparse page.
  */
-#include <xgboost/data.h>
 #include <dmlc/registry.h>

-#include "xgboost/logging.h"
+#include "../common/io.h"                 // for AlignedResourceReadStream, AlignedFileWriteStream
+#include "../common/ref_resource_view.h"  // for WriteVec
 #include "./sparse_page_writer.h"
+#include "xgboost/data.h"
+#include "xgboost/logging.h"

-namespace xgboost {
-namespace data {
-
+namespace xgboost::data {
 DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);

-template<typename T>
+template <typename T>
 class SparsePageRawFormat : public SparsePageFormat<T> {
  public:
-  bool Read(T* page, dmlc::SeekStream* fi) override {
+  bool Read(T* page, common::AlignedResourceReadStream* fi) override {
     auto& offset_vec = page->offset.HostVector();
-    if (!fi->Read(&offset_vec)) {
+    if (!common::ReadVec(fi, &offset_vec)) {
       return false;
     }
     auto& data_vec = page->data.HostVector();
     CHECK_NE(page->offset.Size(), 0U) << "Invalid SparsePage file";
     data_vec.resize(offset_vec.back());
     if (page->data.Size() != 0) {
-      size_t n_bytes = fi->Read(dmlc::BeginPtr(data_vec),
-                                (page->data).Size() * sizeof(Entry));
-      CHECK_EQ(n_bytes, (page->data).Size() * sizeof(Entry))
-          << "Invalid SparsePage file";
+      if (!common::ReadVec(fi, &data_vec)) {
+        return false;
+      }
+    }
+    if (!fi->Read(&page->base_rowid, sizeof(page->base_rowid))) {
+      return false;
     }
-    fi->Read(&page->base_rowid, sizeof(page->base_rowid));
     return true;
   }

-  size_t Write(const T& page, dmlc::Stream* fo) override {
+  std::size_t Write(const T& page, common::AlignedFileWriteStream* fo) override {
     const auto& offset_vec = page.offset.HostVector();
     const auto& data_vec = page.data.HostVector();
     CHECK(page.offset.Size() != 0 && offset_vec[0] == 0);
     CHECK_EQ(offset_vec.back(), page.data.Size());
-    fo->Write(offset_vec);
-    auto bytes = page.MemCostBytes();
-    bytes += sizeof(uint64_t);
+
+    std::size_t bytes{0};
+    bytes += common::WriteVec(fo, offset_vec);
     if (page.data.Size() != 0) {
-      fo->Write(dmlc::BeginPtr(data_vec), page.data.Size() * sizeof(Entry));
+      bytes += common::WriteVec(fo, data_vec);
     }
-    fo->Write(&page.base_rowid, sizeof(page.base_rowid));
-    bytes += sizeof(page.base_rowid);
+    bytes += fo->Write(&page.base_rowid, sizeof(page.base_rowid));
     return bytes;
   }

  private:
-  /*! \brief external memory column offset */
-  std::vector<size_t> disk_offset_;
 };

 XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
@@ -74,5 +72,4 @@ XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(raw)
       return new SparsePageRawFormat<SortedCSCPage>();
     });

-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h
index 9f7bee521..b32c536af 100644
--- a/src/data/sparse_page_source.h
+++ b/src/data/sparse_page_source.h
@@ -6,9 +6,11 @@
 #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_

 #include <algorithm>  // for min
+#include <atomic>     // for atomic
 #include <future>     // for async
 #include <map>
 #include <memory>
+#include <mutex>      // for mutex
 #include <string>
 #include <thread>
 #include <utility>    // for pair, move
@@ -18,7 +20,6 @@
 #include "../common/io.h"     // for PrivateMmapConstStream
 #include "../common/timer.h"  // for Monitor, Timer
 #include "adapter.h"
-#include "dmlc/common.h"  // for OMPException
 #include "proxy_dmatrix.h"       // for DMatrixProxy
 #include "sparse_page_writer.h"  // for SparsePageFormat
 #include "xgboost/base.h"
@@ -93,6 +94,47 @@ class TryLockGuard {
   }
 };

+// Similar to `dmlc::OMPException`, but doesn't need the threads to be joined before rethrow.
+class ExceHandler {
+  std::mutex mutex_;
+  std::atomic<bool> flag_{false};
+  std::exception_ptr curr_exce_{nullptr};
+
+ public:
+  template <typename Fn>
+  decltype(auto) Run(Fn&& fn) noexcept(true) {
+    try {
+      return fn();
+    } catch (dmlc::Error const& e) {
+      std::lock_guard<std::mutex> guard{mutex_};
+      if (!curr_exce_) {
+        curr_exce_ = std::current_exception();
+      }
+      flag_ = true;
+    } catch (std::exception const& e) {
+      std::lock_guard<std::mutex> guard{mutex_};
+      if (!curr_exce_) {
+        curr_exce_ = std::current_exception();
+      }
+      flag_ = true;
+    } catch (...) {
+      std::lock_guard<std::mutex> guard{mutex_};
+      if (!curr_exce_) {
+        curr_exce_ = std::current_exception();
+      }
+      flag_ = true;
+    }
+    return std::invoke_result_t<Fn>();
+  }
+
+  void Rethrow() noexcept(false) {
+    if (flag_) {
+      CHECK(curr_exce_);
+      std::rethrow_exception(curr_exce_);
+    }
+  }
+};
+
 /**
  * @brief Base class for all page sources. Handles fetching, writing, and iteration.
  */
@@ -122,7 +164,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   // Catch exceptions in pre-fetch threads to prevent segfaults. This doesn't always work
   // though, as an OOM error can be delayed due to lazy commit. On the bright side, if
   // mmap is used then OOM errors should be rare.
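+  // All pre-fetch workers spawned through std::async below share this handler; the
+  // first captured exception wins and is rethrown on the consumer thread.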
-  dmlc::OMPException exec_;
+  ExceHandler exce_;
   common::Monitor monitor_;

   bool ReadCache() {
@@ -141,7 +183,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
     std::size_t fetch_it = count_;

-    exec_.Rethrow();
+    exce_.Rethrow();

     for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
       fetch_it %= n_batches_;  // ring
@@ -152,7 +194,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
       CHECK_LT(fetch_it, cache_info_->offset.size());
       ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
         auto page = std::make_shared<S>();
-        this->exec_.Run([&] {
+        this->exce_.Run([&] {
           std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
           auto name = self->cache_info_->ShardName();
           auto [offset, length] = self->cache_info_->View(fetch_it);
@@ -172,7 +214,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     CHECK(!(*ring_)[count_].valid());
     monitor_.Stop("Wait");

-    exec_.Rethrow();
+    exce_.Rethrow();

     return true;
   }
@@ -184,11 +226,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
     auto name = cache_info_->ShardName();

-    std::unique_ptr<dmlc::Stream> fo;
+    std::unique_ptr<common::AlignedFileWriteStream> fo;
     if (this->Iter() == 0) {
-      fo.reset(dmlc::Stream::Create(name.c_str(), "wb"));
+      fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "wb");
     } else {
-      fo.reset(dmlc::Stream::Create(name.c_str(), "ab"));
+      fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "ab");
     }
     auto bytes = fmt->Write(*page_, fo.get());
diff --git a/src/data/sparse_page_writer.h b/src/data/sparse_page_writer.h
index 91a6504fe..c909d817d 100644
--- a/src/data/sparse_page_writer.h
+++ b/src/data/sparse_page_writer.h
@@ -1,52 +1,44 @@
-/*!
- * Copyright (c) 2014-2019 by Contributors
+/**
+ * Copyright 2014-2023, XGBoost Contributors
  * \file sparse_page_writer.h
  * \author Tianqi Chen
  */
 #ifndef XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
 #define XGBOOST_DATA_SPARSE_PAGE_WRITER_H_

-#include <xgboost/base.h>
-#include <xgboost/data.h>
-#include <dmlc/io.h>
-#include <dmlc/registry.h>
-#include <algorithm>
-#include <cstring>
-#include <string>
-#include <utility>
-#include <vector>
+#include <functional>  // for function
+#include <string>      // for string

-#if DMLC_ENABLE_STD_THREAD
-#include <dmlc/concurrency.h>
-#include <thread>
-#endif  // DMLC_ENABLE_STD_THREAD
-
-namespace xgboost {
-namespace data {
+#include "../common/io.h"   // for AlignedResourceReadStream, AlignedFileWriteStream
+#include "dmlc/io.h"        // for Stream
+#include "dmlc/registry.h"  // for Registry, FunctionRegEntryBase
+#include "xgboost/data.h"   // for SparsePage,CSCPage,SortedCSCPage,EllpackPage ...

+namespace xgboost::data {
 template <typename T>
 struct SparsePageFormatReg;

-/*!
- * \brief Format specification of SparsePage.
+/**
+ * @brief Format specification of various data formats like SparsePage.
  */
-template<typename T>
+template <typename T>
 class SparsePageFormat {
  public:
-  /*! \brief virtual destructor */
   virtual ~SparsePageFormat() = default;
-  /*!
-   * \brief Load all the segments into page, advance fi to end of the block.
-   * \param page The data to read page into.
-   * \param fi the input stream of the file
-   * \return true of the loading as successful, false if end of file was reached
+  /**
+   * @brief Load all the segments into page, advance fi to end of the block.
+   *
+   * @param page The data to read page into.
+   * @param fi the input stream of the file
+   * @return true if the loading was successful, false if end of file was reached
    */
-  virtual bool Read(T* page, dmlc::SeekStream* fi) = 0;
-  /*!
-   * \brief save the data to fo, when a page was written.
-   * \param fo output stream
+  virtual bool Read(T* page, common::AlignedResourceReadStream* fi) = 0;
+  /**
+   * @brief Save the data to fo, when a page was written.
+   *
+   * @param fo output stream
    */
-  virtual size_t Write(const T& page, dmlc::Stream* fo) = 0;
+  virtual size_t Write(const T& page, common::AlignedFileWriteStream* fo) = 0;
 };

 /*!
@@ -105,6 +97,5 @@ struct SparsePageFormatReg
   DMLC_REGISTRY_REGISTER(SparsePageFormatReg<GHistIndexMatrix>, \
                          GHistIndexPageFmt, Name)

-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
 #endif  // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index f67c05344..9d268b8d7 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -634,6 +634,22 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
     return cpu_predictor_;
   }

+  // Data comes from SparsePageDMatrix. Since we are loading data in pages, no need to
+  // prevent data copy.
+  if (f_dmat && !f_dmat->SingleColBlock()) {
+    if (ctx_->IsCPU()) {
+      return cpu_predictor_;
+    } else {
+#if defined(XGBOOST_USE_CUDA)
+      CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
+      return gpu_predictor_;
+#else
+      common::AssertGPUSupport();
+      return cpu_predictor_;
+#endif  // defined(XGBOOST_USE_CUDA)
+    }
+  }
+
   // Data comes from Device DMatrix.
   auto is_ellpack =
       f_dmat && f_dmat->PageExists<EllpackPage>() && !f_dmat->PageExists<SparsePage>();
diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc
index a64b60b80..986e58c5a 100644
--- a/tests/cpp/common/test_io.cc
+++ b/tests/cpp/common/test_io.cc
@@ -3,11 +3,12 @@
  */
 #include <gtest/gtest.h>

-#include <fstream>
+#include <cstddef>  // for size_t
+#include <fstream>  // for ofstream

 #include "../../../src/common/io.h"
-#include "../helpers.h"
 #include "../filesystem.h"  // dmlc::TemporaryDirectory
+#include "../helpers.h"

 namespace xgboost::common {
 TEST(MemoryFixSizeBuffer, Seek) {
@@ -89,6 +90,57 @@ TEST(IO, LoadSequentialFile) {
   ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
 }

+TEST(IO, Resource) {
+  {
+    // test malloc basic
+    std::size_t n = 128;
+    std::shared_ptr<ResourceHandler> resource = std::make_shared<MallocResource>(n);
+    ASSERT_EQ(resource->Size(), n);
+    ASSERT_EQ(resource->Type(), ResourceHandler::kMalloc);
+  }
+
+  // test malloc resize
+  auto test_malloc_resize = [](bool force_malloc) {
+    std::size_t n = 64;
+    std::shared_ptr<ResourceHandler> resource = std::make_shared<MallocResource>(n);
+    auto ptr = reinterpret_cast<std::uint8_t*>(resource->Data());
+    std::iota(ptr, ptr + n, 0);
+
+    auto malloc_resource = std::dynamic_pointer_cast<MallocResource>(resource);
+    ASSERT_TRUE(malloc_resource);
+    if (force_malloc) {
+      malloc_resource->Resize<true>(n * 2);
+    } else {
+      malloc_resource->Resize(n * 2);
+    }
+    for (std::size_t i = 0; i < n; ++i) {
+      ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], i) << force_malloc;
+    }
+    for (std::size_t i = n; i < 2 * n; ++i) {
+      ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 0);
+    }
+  };
+  test_malloc_resize(true);
+  test_malloc_resize(false);
+
+  {
+    // test mmap
+    dmlc::TemporaryDirectory tmpdir;
+    auto path = tmpdir.path + "/testfile";
+
+    std::ofstream fout(path, std::ios::binary);
+    double val{1.0};
+    fout.write(reinterpret_cast<char const*>(&val), sizeof(val));
+    fout << 1.0 << std::endl;
+    fout.close();
+
+    auto resource = std::make_shared<MmapResource>(path, 0, sizeof(double));
+    ASSERT_EQ(resource->Size(), sizeof(double));
+    ASSERT_EQ(resource->Type(), ResourceHandler::kMmap);
+    ASSERT_EQ(resource->DataAs<double>()[0], val);
+  }
+}
+
 TEST(IO, PrivateMmapStream) {
   dmlc::TemporaryDirectory tempdir;
   auto path = tempdir.path + "/testfile";
@@ -124,17 +176,35 @@ TEST(IO, PrivateMmapStream) {
   // Turn sizes into offsets
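+  // After the partial sum, offset[i] is the byte offset of batch i in the file and
+  // offset.at(i + 1) - offset[i] is its length, matching the loops below.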
   std::partial_sum(offset.begin(), offset.end(), offset.begin());

+  // Test read
   for (std::size_t i = 0; i < n_batches; ++i) {
     std::size_t off = offset[i];
     std::size_t n = offset.at(i + 1) - offset[i];
-    std::unique_ptr<PrivateMmapConstStream> fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
+    auto fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
     std::vector<T> data;

     std::uint64_t size{0};
-    fi->Read(&size);
+    ASSERT_TRUE(fi->Read(&size));
+    ASSERT_EQ(fi->Tell(), sizeof(size));
     data.resize(size);

-    fi->Read(data.data(), size * sizeof(T));
+    ASSERT_EQ(fi->Read(data.data(), size * sizeof(T)), size * sizeof(T));
+    ASSERT_EQ(data, batches[i]);
+  }
+
+  // Test consume
+  for (std::size_t i = 0; i < n_batches; ++i) {
+    std::size_t off = offset[i];
+    std::size_t n = offset.at(i + 1) - offset[i];
+    std::unique_ptr<PrivateMmapConstStream> fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
+    std::vector<T> data;
+
+    std::uint64_t size{0};
+    ASSERT_TRUE(fi->Consume(&size));
+    ASSERT_EQ(fi->Tell(), sizeof(size));
+    data.resize(size);
+
+    ASSERT_EQ(fi->Read(data.data(), size * sizeof(T)), sizeof(T) * size);
     ASSERT_EQ(data, batches[i]);
   }
 }
diff --git a/tests/cpp/common/test_ref_resource_view.cc b/tests/cpp/common/test_ref_resource_view.cc
new file mode 100644
index 000000000..9ae55fdec
--- /dev/null
+++ b/tests/cpp/common/test_ref_resource_view.cc
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2023, XGBoost Contributors
+ */
+#include <gtest/gtest.h>
+
+#include <cstddef>  // for size_t
+#include <memory>   // for make_shared, make_unique
+#include <numeric>  // for iota
+#include <vector>   // for vector
+
+#include "../../../src/common/ref_resource_view.h"
+#include "dmlc/filesystem.h"  // for TemporaryDirectory
+
+namespace xgboost::common {
+TEST(RefResourceView, Basic) {
+  std::size_t n_bytes = 1024;
+  auto mem = std::make_shared<MallocResource>(n_bytes);
+  {
+    RefResourceView view{reinterpret_cast<float*>(mem->Data()), mem->Size() / sizeof(float), mem};
+
+    RefResourceView kview{reinterpret_cast<float const*>(mem->Data()), mem->Size() / sizeof(float),
+                          mem};
+    ASSERT_EQ(mem.use_count(), 3);
+    ASSERT_EQ(view.size(), n_bytes / sizeof(float));
+    ASSERT_EQ(kview.size(), n_bytes / sizeof(float));
+  }
+  {
+    RefResourceView view{reinterpret_cast<float*>(mem->Data()), mem->Size() / sizeof(float), mem,
+                         1.5f};
+    for (auto v : view) {
+      ASSERT_EQ(v, 1.5f);
+    }
+    std::iota(view.begin(), view.end(), 0.0f);
+    ASSERT_EQ(view.front(), 0.0f);
+    ASSERT_EQ(view.back(), static_cast<float>(view.size() - 1));
+
+    view.front() = 1.0f;
+    view.back() = 2.0f;
+    ASSERT_EQ(view.front(), 1.0f);
+    ASSERT_EQ(view.back(), 2.0f);
+  }
+  ASSERT_EQ(mem.use_count(), 1);
+}
+
+TEST(RefResourceView, IO) {
+  dmlc::TemporaryDirectory tmpdir;
+  auto path = tmpdir.path + "/testfile";
+  auto data = MakeFixedVecWithMalloc(123, std::size_t{1});
+
+  {
+    auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
+    ASSERT_EQ(fo->Write(data.data(), data.size_bytes()), data.size_bytes());
+  }
+  {
+    auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
+    ASSERT_EQ(WriteVec(fo.get(), data),
+              data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type));
+  }
+  {
+    auto fi = std::make_unique<PrivateMmapConstStream>(
+        path, 0, data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type));
+    auto read = MakeFixedVecWithMalloc(123, std::size_t{1});
+    ASSERT_TRUE(ReadVec(fi.get(), &read));
+    for (auto v : read) {
+      ASSERT_EQ(v, 1ul);
+    }
+  }
+}
+
+TEST(RefResourceView, IOAligned) {
+  dmlc::TemporaryDirectory tmpdir;
+  auto path = tmpdir.path + "/testfile";
+  auto data = MakeFixedVecWithMalloc(123, 1.0f);
+
+  {
+    auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
+    // + sizeof(float) for alignment
+    ASSERT_EQ(WriteVec(fo.get(), data),
+              data.size_bytes() + sizeof(RefResourceView<float>::size_type) + sizeof(float));
+  }
+  {
+    auto fi = std::make_unique<PrivateMmapConstStream>(
+        path, 0, data.size_bytes() + sizeof(RefResourceView<float>::size_type));
+    // wrong type, float vs. double
+    auto read = MakeFixedVecWithMalloc(123, 2.0);
+    ASSERT_FALSE(ReadVec(fi.get(), &read));
+  }
+  {
+    auto fi = std::make_unique<PrivateMmapConstStream>(
+        path, 0, data.size_bytes() + sizeof(RefResourceView<float>::size_type));
+    auto read = MakeFixedVecWithMalloc(123, 2.0f);
+    ASSERT_TRUE(ReadVec(fi.get(), &read));
+    for (auto v : read) {
+      ASSERT_EQ(v, 1.0f);
+    }
+  }
+  {
+    // Test std::vector
+    std::vector<float> data(123);
+    std::iota(data.begin(), data.end(), 0.0f);
+    auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
+    // + sizeof(float) for alignment
+    ASSERT_EQ(WriteVec(fo.get(), data), data.size() * sizeof(float) +
+                                            sizeof(RefResourceView<float>::size_type) +
+                                            sizeof(float));
+  }
+}
+}  // namespace xgboost::common
diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu
index 66d4024ec..f69b7b63a 100644
--- a/tests/cpp/data/test_ellpack_page_raw_format.cu
+++ b/tests/cpp/data/test_ellpack_page_raw_format.cu
@@ -4,14 +4,14 @@
 #include <gtest/gtest.h>
 #include <xgboost/data.h>

+#include "../../../src/common/io.h"  // for PrivateMmapConstStream, AlignedResourceReadStream...
 #include "../../../src/data/ellpack_page.cuh"
 #include "../../../src/data/sparse_page_source.h"
 #include "../../../src/tree/param.h"  // TrainParam
 #include "../filesystem.h"            // dmlc::TemporaryDirectory
 #include "../helpers.h"

-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 TEST(EllpackPageRawFormat, IO) {
   Context ctx{MakeCUDACtx(0)};
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
@@ -22,15 +22,17 @@ TEST(EllpackPageRawFormat, IO) {
   dmlc::TemporaryDirectory tmpdir;
   std::string path = tmpdir.path + "/ellpack.page";

+  std::size_t n_bytes{0};
   {
-    std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
+    auto fo = std::make_unique<common::AlignedFileWriteStream>(StringView{path}, "wb");
     for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
-      format->Write(ellpack, fo.get());
+      n_bytes += format->Write(ellpack, fo.get());
     }
   }

   EllpackPage page;
-  std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
+  std::unique_ptr<common::AlignedResourceReadStream> fi{
+      std::make_unique<common::PrivateMmapConstStream>(path.c_str(), 0, n_bytes)};
   format->Read(&page, fi.get());

   for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
@@ -44,5 +46,4 @@ TEST(EllpackPageRawFormat, IO) {
     ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector());
   }
 }
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc
index 22eb849ee..bd29c87b0 100644
--- a/tests/cpp/data/test_gradient_index.cc
+++ b/tests/cpp/data/test_gradient_index.cc
@@ -26,8 +26,7 @@
 #include "xgboost/context.h"             // for Context
 #include "xgboost/host_device_vector.h"  // for HostDeviceVector

-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 TEST(GradientIndex, ExternalMemory) {
   Context ctx;
   std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
@@ -171,7 +170,7 @@ class GHistIndexMatrixTest : public testing::TestWithParam
     for (auto const &page : Xy->GetBatches<EllpackPage>(
             &gpu_ctx, BatchParam{kBins, tree::TrainParam::DftSparseThreshold()})) {
-      from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p});
+      from_ellpack = std::make_unique<GHistIndexMatrix>(&ctx, Xy->Info(), page, p);
     }

     for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(&ctx, p)) {
@@ -199,13 +198,15 @@ class GHistIndexMatrixTest : public testing::TestWithParam
diff --git a/tests/cpp/data/test_gradient_index_page_raw_format.cc b/tests/cpp/data/test_gradient_index_page_raw_format.cc
--- a/tests/cpp/data/test_gradient_index_page_raw_format.cc
+++ b/tests/cpp/data/test_gradient_index_page_raw_format.cc
 #include <gtest/gtest.h>
+#include <xgboost/context.h>  // for Context
+
+#include <cstddef>  // for size_t
+#include <memory>   // for unique_ptr
"../../../src/common/column_matrix.h" -#include "../../../src/data/gradient_index.h" +#include "../../../src/common/io.h" // for MmapResource, AlignedResourceReadStream... +#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix #include "../../../src/data/sparse_page_source.h" -#include "../helpers.h" +#include "../helpers.h" // for RandomDataGenerator -namespace xgboost { -namespace data { +namespace xgboost::data { TEST(GHistIndexPageRawFormat, IO) { Context ctx; @@ -20,15 +24,18 @@ TEST(GHistIndexPageRawFormat, IO) { std::string path = tmpdir.path + "/ghistindex.page"; auto batch = BatchParam{256, 0.5}; + std::size_t bytes{0}; { - std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; + auto fo = std::make_unique(StringView{path}, "wb"); for (auto const &index : m->GetBatches(&ctx, batch)) { - format->Write(index, fo.get()); + bytes += format->Write(index, fo.get()); } } GHistIndexMatrix page; - std::unique_ptr fi{dmlc::SeekStream::CreateForRead(path.c_str())}; + + std::unique_ptr fi{ + std::make_unique(path, 0, bytes)}; format->Read(&page, fi.get()); for (auto const &gidx : m->GetBatches(&ctx, batch)) { @@ -37,6 +44,8 @@ TEST(GHistIndexPageRawFormat, IO) { ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues()); ASSERT_EQ(loaded.cut.Values(), page.cut.Values()); ASSERT_EQ(loaded.base_rowid, page.base_rowid); + ASSERT_EQ(loaded.row_ptr.size(), page.row_ptr.size()); + ASSERT_TRUE(std::equal(loaded.row_ptr.cbegin(), loaded.row_ptr.cend(), page.row_ptr.cbegin())); ASSERT_EQ(loaded.IsDense(), page.IsDense()); ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin())); ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(), @@ -45,5 +54,4 @@ TEST(GHistIndexPageRawFormat, IO) { ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize()); } } -} // namespace data -} // namespace xgboost +} // namespace xgboost::data diff --git a/tests/cpp/data/test_sparse_page_raw_format.cc b/tests/cpp/data/test_sparse_page_raw_format.cc index 722655880..bd0f97dcc 100644 --- a/tests/cpp/data/test_sparse_page_raw_format.cc +++ b/tests/cpp/data/test_sparse_page_raw_format.cc @@ -2,20 +2,20 @@ * Copyright 2021-2023, XGBoost contributors */ #include -#include // for CSCPage, SortedCSCPage, SparsePage +#include // for CSCPage, SortedCSCPage, SparsePage -#include // for allocator, unique_ptr, __shared_ptr_ac... -#include // for char_traits, operator+, basic_string +#include // for allocator, unique_ptr, __shared_ptr_ac... +#include // for char_traits, operator+, basic_string +#include "../../../src/common/io.h" // for PrivateMmapConstStream, AlignedResourceReadStream... 
#include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat #include "../helpers.h" // for RandomDataGenerator #include "dmlc/filesystem.h" // for TemporaryDirectory -#include "dmlc/io.h" // for SeekStream, Stream +#include "dmlc/io.h" // for Stream #include "gtest/gtest_pred_impl.h" // for Test, AssertionResult, ASSERT_EQ, TEST #include "xgboost/context.h" // for Context -namespace xgboost { -namespace data { +namespace xgboost::data { template void TestSparsePageRawFormat() { std::unique_ptr> format{CreatePageFormat("raw")}; Context ctx; @@ -25,17 +25,19 @@ template void TestSparsePageRawFormat() { dmlc::TemporaryDirectory tmpdir; std::string path = tmpdir.path + "/sparse.page"; S orig; + std::size_t n_bytes{0}; { // block code to flush the stream - std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; + auto fo = std::make_unique(StringView{path}, "wb"); for (auto const &page : m->GetBatches(&ctx)) { orig.Push(page); - format->Write(page, fo.get()); + n_bytes = format->Write(page, fo.get()); } } S page; - std::unique_ptr fi{dmlc::SeekStream::CreateForRead(path.c_str())}; + std::unique_ptr fi{ + std::make_unique(path.c_str(), 0, n_bytes)}; format->Read(&page, fi.get()); for (size_t i = 0; i < orig.data.Size(); ++i) { ASSERT_EQ(page.data.HostVector()[i].fvalue, @@ -59,5 +61,4 @@ TEST(SparsePageRawFormat, CSCPage) { TEST(SparsePageRawFormat, SortedCSCPage) { TestSparsePageRawFormat(); } -} // namespace data -} // namespace xgboost +} // namespace xgboost::data