Use ptr from mmap for GHistIndexMatrix and ColumnMatrix. (#9315)
* Use ptr from mmap for `GHistIndexMatrix` and `ColumnMatrix`. - Define a resource for holding various types of memory pointers. - Define ref vector for holding resources. - Swap the underlying resources for GHist and ColumnM. - Add documentation for current status. - s390x support is removed. It should work if you can compile XGBoost, all the old workaround code does is to get GCC to compile.
This commit is contained in:
parent
96c3071a8a
commit
bc267dd729
@ -33,6 +33,8 @@ DMatrix
|
|||||||
.. doxygengroup:: DMatrix
|
.. doxygengroup:: DMatrix
|
||||||
:project: xgboost
|
:project: xgboost
|
||||||
|
|
||||||
|
.. _c_streaming:
|
||||||
|
|
||||||
Streaming
|
Streaming
|
||||||
---------
|
---------
|
||||||
|
|
||||||
|
|||||||
@ -54,6 +54,9 @@ on a dask cluster:
|
|||||||
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
|
y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
|
||||||
|
|
||||||
dtrain = xgb.dask.DaskDMatrix(client, X, y)
|
dtrain = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
# or
|
||||||
|
# dtrain = xgb.dask.DaskQuantileDMatrix(client, X, y)
|
||||||
|
# `DaskQuantileDMatrix` is available for the `hist` and `gpu_hist` tree method.
|
||||||
|
|
||||||
output = xgb.dask.train(
|
output = xgb.dask.train(
|
||||||
client,
|
client,
|
||||||
|
|||||||
@ -22,6 +22,15 @@ GPU-based training algorithm. We will introduce them in the following sections.
|
|||||||
|
|
||||||
The feature is still experimental as of 2.0. The performance is not well optimized.
|
The feature is still experimental as of 2.0. The performance is not well optimized.
|
||||||
|
|
||||||
|
The external memory support has gone through multiple iterations and is still under heavy
|
||||||
|
development. Like the :py:class:`~xgboost.QuantileDMatrix` with
|
||||||
|
:py:class:`~xgboost.DataIter`, XGBoost loads data batch-by-batch using a custom iterator
|
||||||
|
supplied by the user. However, unlike the :py:class:`~xgboost.QuantileDMatrix`, external
|
||||||
|
memory will not concatenate the batches unless GPU is used (it uses a hybrid approach,
|
||||||
|
more details follow). Instead, it will cache all batches on the external memory and fetch
|
||||||
|
them on-demand. Go to the end of the document to see a comparison between
|
||||||
|
`QuantileDMatrix` and external memory.
|
||||||
|
|
||||||
*************
|
*************
|
||||||
Data Iterator
|
Data Iterator
|
||||||
*************
|
*************
|
||||||
@ -113,10 +122,11 @@ External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set
|
|||||||
``gpu_hist``). However, the algorithm used for GPU is different from the one used for
|
``gpu_hist``). However, the algorithm used for GPU is different from the one used for
|
||||||
CPU. When training on a CPU, the tree method iterates through all batches from external
|
CPU. When training on a CPU, the tree method iterates through all batches from external
|
||||||
memory for each step of the tree construction algorithm. On the other hand, the GPU
|
memory for each step of the tree construction algorithm. On the other hand, the GPU
|
||||||
algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
|
algorithm uses a hybrid approach. It iterates through the data during the beginning of
|
||||||
memory usage, users can utilize subsampling. The good news is that the GPU hist tree
|
each iteration and concatenates all batches into one in GPU memory. To reduce overall
|
||||||
method supports gradient-based sampling, enabling users to set a low sampling rate without
|
memory usage, users can utilize subsampling. The GPU hist tree method supports
|
||||||
compromising accuracy.
|
`gradient-based sampling`, enabling users to set a low sampling rate without compromising
|
||||||
|
accuracy.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -134,6 +144,8 @@ see `this paper <https://arxiv.org/abs/2005.09148>`_.
|
|||||||
When GPU is running out of memory during iteration on external memory, user might
|
When GPU is running out of memory during iteration on external memory, user might
|
||||||
recieve a segfault instead of an OOM exception.
|
recieve a segfault instead of an OOM exception.
|
||||||
|
|
||||||
|
.. _ext_remarks:
|
||||||
|
|
||||||
*******
|
*******
|
||||||
Remarks
|
Remarks
|
||||||
*******
|
*******
|
||||||
@ -142,17 +154,64 @@ When using external memory with XBGoost, data is divided into smaller chunks so
|
|||||||
a fraction of it needs to be stored in memory at any given time. It's important to note
|
a fraction of it needs to be stored in memory at any given time. It's important to note
|
||||||
that this method only applies to the predictor data (``X``), while other data, like labels
|
that this method only applies to the predictor data (``X``), while other data, like labels
|
||||||
and internal runtime structures are concatenated. This means that memory reduction is most
|
and internal runtime structures are concatenated. This means that memory reduction is most
|
||||||
effective when dealing with wide datasets where ``X`` is larger compared to other data
|
effective when dealing with wide datasets where ``X`` is significantly larger in size
|
||||||
like ``y``, while it has little impact on slim datasets.
|
compared to other data like ``y``, while it has little impact on slim datasets.
|
||||||
|
|
||||||
|
As one might expect, fetching data on-demand puts significant pressure on the storage
|
||||||
|
device. Today's computing device can process way more data than a storage can read in a
|
||||||
|
single unit of time. The ratio is at order of magnitudes. An GPU is capable of processing
|
||||||
|
hundred of Gigabytes of floating-point data in a split second. On the other hand, a
|
||||||
|
four-lane NVMe storage connected to a PCIe-4 slot usually has about 6GB/s of data transfer
|
||||||
|
rate. As a result, the training is likely to be severely bounded by your storage
|
||||||
|
device. Before adopting the external memory solution, some back-of-envelop calculations
|
||||||
|
might help you see whether it's viable. For instance, if your NVMe drive can transfer 4GB
|
||||||
|
(a fairly practical number) of data per second and you have a 100GB of data in compressed
|
||||||
|
XGBoost cache (which corresponds to a dense float32 numpy array with the size of 200GB,
|
||||||
|
give or take). A tree with depth 8 needs at least 16 iterations through the data when the
|
||||||
|
parameter is right. You need about 14 minutes to train a single tree without accounting
|
||||||
|
for some other overheads and assume the computation overlaps with the IO. If your dataset
|
||||||
|
happens to have TB-level size, then you might need thousands of trees to get a generalized
|
||||||
|
model. These calculations can help you get an estimate on the expected training time.
|
||||||
|
|
||||||
|
However, sometimes we can ameliorate this limitation. One should also consider that the OS
|
||||||
|
(mostly talking about the Linux kernel) can usually cache the data on host memory. It only
|
||||||
|
evicts pages when new data comes in and there's no room left. In practice, at least some
|
||||||
|
portion of the data can persist on the host memory throughout the entire training
|
||||||
|
session. We are aware of this cache when optimizing the external memory fetcher. The
|
||||||
|
compressed cache is usually smaller than the raw input data, especially when the input is
|
||||||
|
dense without any missing value. If the host memory can fit a significant portion of this
|
||||||
|
compressed cache, then the performance should be decent after initialization. Our
|
||||||
|
development so far focus on two fronts of optimization for external memory:
|
||||||
|
|
||||||
|
- Avoid iterating through the data whenever appropriate.
|
||||||
|
- If the OS can cache the data, the performance should be close to in-core training.
|
||||||
|
|
||||||
Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
|
Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
|
||||||
yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's
|
tested against system errors like disconnected network devices (`SIGBUS`). In the face of
|
||||||
worth noting that most tests have been conducted on Linux distributions.
|
a bus error, you will see a hard crash and need to clean up the cache files. If the
|
||||||
|
training session might take a long time and you are using solutions like NVMe-oF, we
|
||||||
|
recommend checkpointing your model periodically. Also, it's worth noting that most tests
|
||||||
|
have been conducted on Linux distributions.
|
||||||
|
|
||||||
|
|
||||||
Another important point to keep in mind is that creating the initial cache for XGBoost may
|
Another important point to keep in mind is that creating the initial cache for XGBoost may
|
||||||
take some time. The interface to external memory is through custom iterators, which may or
|
take some time. The interface to external memory is through custom iterators, which we can
|
||||||
may not be thread-safe. Therefore, initialization is performed sequentially.
|
not assume to be thread-safe. Therefore, initialization is performed sequentially. Using
|
||||||
|
the `xgboost.config_context` with `verbosity=2` can give you some information on what
|
||||||
|
XGBoost is doing during the wait if you don't mind the extra output.
|
||||||
|
|
||||||
|
*******************************
|
||||||
|
Compared to the QuantileDMatrix
|
||||||
|
*******************************
|
||||||
|
|
||||||
|
Passing an iterator to the :py:class:`~xgboost.QuantileDmatrix` enables direct
|
||||||
|
construction of `QuantileDmatrix` with data chunks. On the other hand, if it's passed to
|
||||||
|
:py:class:`~xgboost.DMatrix`, it instead enables the external memory feature. The
|
||||||
|
:py:class:`~xgboost.QuantileDmatrix` concatenates the data on memory after compression and
|
||||||
|
doesn't fetch data during training. On the other hand, the external memory `DMatrix`
|
||||||
|
fetches data batches from external memory on-demand. Use the `QuantileDMatrix` (with
|
||||||
|
iterator if necessary) when you can fit most of your data in memory. The training would be
|
||||||
|
an order of magnitute faster than using external memory.
|
||||||
|
|
||||||
****************
|
****************
|
||||||
Text File Inputs
|
Text File Inputs
|
||||||
|
|||||||
@ -11,22 +11,22 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
|
|||||||
|
|
||||||
model
|
model
|
||||||
saving_model
|
saving_model
|
||||||
|
learning_to_rank
|
||||||
|
dart
|
||||||
|
monotonic
|
||||||
|
feature_interaction_constraint
|
||||||
|
aft_survival_analysis
|
||||||
|
categorical
|
||||||
|
multioutput
|
||||||
|
rf
|
||||||
kubernetes
|
kubernetes
|
||||||
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
|
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
|
||||||
Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
|
Distributed XGBoost with XGBoost4J-Spark-GPU <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_gpu_tutorial.html>
|
||||||
dask
|
dask
|
||||||
spark_estimator
|
spark_estimator
|
||||||
ray
|
ray
|
||||||
dart
|
external_memory
|
||||||
monotonic
|
|
||||||
rf
|
|
||||||
feature_interaction_constraint
|
|
||||||
learning_to_rank
|
|
||||||
aft_survival_analysis
|
|
||||||
c_api_tutorial
|
c_api_tutorial
|
||||||
input_format
|
input_format
|
||||||
param_tuning
|
param_tuning
|
||||||
external_memory
|
|
||||||
custom_metric_obj
|
custom_metric_obj
|
||||||
categorical
|
|
||||||
multioutput
|
|
||||||
|
|||||||
@ -58,3 +58,46 @@ This can affect the training of XGBoost model, and there are two ways to improve
|
|||||||
|
|
||||||
- In such a case, you cannot re-balance the dataset
|
- In such a case, you cannot re-balance the dataset
|
||||||
- Set parameter ``max_delta_step`` to a finite number (say 1) to help convergence
|
- Set parameter ``max_delta_step`` to a finite number (say 1) to help convergence
|
||||||
|
|
||||||
|
|
||||||
|
*********************
|
||||||
|
Reducing Memory Usage
|
||||||
|
*********************
|
||||||
|
|
||||||
|
If you are using a HPO library like :py:class:`sklearn.model_selection.GridSearchCV`,
|
||||||
|
please control the number of threads it can use. It's best to let XGBoost to run in
|
||||||
|
parallel instead of asking `GridSearchCV` to run multiple experiments at the same
|
||||||
|
time. For instance, creating a fold of data for cross validation can consume a significant
|
||||||
|
amount of memory:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# This creates a copy of dataset. X and X_train are both in memory at the same time.
|
||||||
|
|
||||||
|
# This happens for every thread at the same time if you run `GridSearchCV` with
|
||||||
|
# `n_jobs` larger than 1
|
||||||
|
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
df = pd.DataFrame()
|
||||||
|
# This creates a new copy of the dataframe, even if you specify the inplace parameter
|
||||||
|
new_df = df.drop(...)
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
array = np.array(...)
|
||||||
|
# This may or may not make a copy of the data, depending on the type of the data
|
||||||
|
array.astype(np.float32)
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
# np by default uses double, do you actually need it?
|
||||||
|
array = np.array(...)
|
||||||
|
|
||||||
|
You can find some more specific memory reduction practices scattered through the documents
|
||||||
|
For instances: :doc:`/tutorials/dask`, :doc:`/gpu/index`,
|
||||||
|
:doc:`/contrib/scaling`. However, before going into these, being conscious about making
|
||||||
|
data copies is a good starting point. It usually consumes a lot more memory than people
|
||||||
|
expect.
|
||||||
|
|||||||
@ -19,8 +19,7 @@
|
|||||||
#include "rabit/internal/utils.h"
|
#include "rabit/internal/utils.h"
|
||||||
#include "rabit/serializable.h"
|
#include "rabit/serializable.h"
|
||||||
|
|
||||||
namespace rabit {
|
namespace rabit::utils {
|
||||||
namespace utils {
|
|
||||||
/*! \brief re-use definition of dmlc::SeekStream */
|
/*! \brief re-use definition of dmlc::SeekStream */
|
||||||
using SeekStream = dmlc::SeekStream;
|
using SeekStream = dmlc::SeekStream;
|
||||||
/**
|
/**
|
||||||
@ -31,9 +30,6 @@ struct MemoryFixSizeBuffer : public SeekStream {
|
|||||||
// similar to SEEK_END in libc
|
// similar to SEEK_END in libc
|
||||||
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
|
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
|
||||||
|
|
||||||
protected:
|
|
||||||
MemoryFixSizeBuffer() = default;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Ctor
|
* @brief Ctor
|
||||||
@ -68,7 +64,7 @@ struct MemoryFixSizeBuffer : public SeekStream {
|
|||||||
* @brief Current position in the buffer (stream).
|
* @brief Current position in the buffer (stream).
|
||||||
*/
|
*/
|
||||||
std::size_t Tell() override { return curr_ptr_; }
|
std::size_t Tell() override { return curr_ptr_; }
|
||||||
virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
|
[[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/*! \brief in memory buffer */
|
/*! \brief in memory buffer */
|
||||||
@ -119,6 +115,5 @@ struct MemoryBufferStream : public SeekStream {
|
|||||||
/*! \brief current pointer */
|
/*! \brief current pointer */
|
||||||
size_t curr_ptr_;
|
size_t curr_ptr_;
|
||||||
}; // class MemoryBufferStream
|
}; // class MemoryBufferStream
|
||||||
} // namespace utils
|
} // namespace rabit::utils
|
||||||
} // namespace rabit
|
|
||||||
#endif // RABIT_INTERNAL_IO_H_
|
#endif // RABIT_INTERNAL_IO_H_
|
||||||
|
|||||||
@ -1,16 +1,27 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by XGBoost Contributors
|
* Copyright 2017-2023, XGBoost Contributors
|
||||||
* \brief Utility for fast column-wise access
|
* \brief Utility for fast column-wise access
|
||||||
*/
|
*/
|
||||||
#include "column_matrix.h"
|
#include "column_matrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
#include <algorithm> // for transform
|
||||||
namespace common {
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for uint64_t, uint8_t
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <type_traits> // for remove_reference_t
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
||||||
|
#include "io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
|
||||||
|
#include "xgboost/base.h" // for bst_feaature_t
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::common {
|
||||||
void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold) {
|
void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold) {
|
||||||
auto const nfeature = gmat.Features();
|
auto const nfeature = gmat.Features();
|
||||||
const size_t nrow = gmat.Size();
|
const size_t nrow = gmat.Size();
|
||||||
// identify type of each column
|
// identify type of each column
|
||||||
type_.resize(nfeature);
|
type_ = common::MakeFixedVecWithMalloc(nfeature, ColumnType{});
|
||||||
|
|
||||||
uint32_t max_val = std::numeric_limits<uint32_t>::max();
|
uint32_t max_val = std::numeric_limits<uint32_t>::max();
|
||||||
for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
|
for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
|
||||||
@ -34,7 +45,7 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
|
|||||||
|
|
||||||
// want to compute storage boundary for each feature
|
// want to compute storage boundary for each feature
|
||||||
// using variants of prefix sum scan
|
// using variants of prefix sum scan
|
||||||
feature_offsets_.resize(nfeature + 1);
|
feature_offsets_ = common::MakeFixedVecWithMalloc(nfeature + 1, std::size_t{0});
|
||||||
size_t accum_index = 0;
|
size_t accum_index = 0;
|
||||||
feature_offsets_[0] = accum_index;
|
feature_offsets_[0] = accum_index;
|
||||||
for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
|
for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
|
||||||
@ -49,9 +60,11 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
|
|||||||
SetTypeSize(gmat.MaxNumBinPerFeat());
|
SetTypeSize(gmat.MaxNumBinPerFeat());
|
||||||
auto storage_size =
|
auto storage_size =
|
||||||
feature_offsets_.back() * static_cast<std::underlying_type_t<BinTypeSize>>(bins_type_size_);
|
feature_offsets_.back() * static_cast<std::underlying_type_t<BinTypeSize>>(bins_type_size_);
|
||||||
index_.resize(storage_size, 0);
|
|
||||||
|
index_ = common::MakeFixedVecWithMalloc(storage_size, std::uint8_t{0});
|
||||||
|
|
||||||
if (!all_dense_column) {
|
if (!all_dense_column) {
|
||||||
row_ind_.resize(feature_offsets_[nfeature]);
|
row_ind_ = common::MakeFixedVecWithMalloc(feature_offsets_[nfeature], std::size_t{0});
|
||||||
}
|
}
|
||||||
|
|
||||||
// store least bin id for each feature
|
// store least bin id for each feature
|
||||||
@ -59,7 +72,51 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
|
|||||||
|
|
||||||
any_missing_ = !gmat.IsDense();
|
any_missing_ = !gmat.IsDense();
|
||||||
|
|
||||||
missing_flags_.clear();
|
missing_ = MissingIndicator{0, false};
|
||||||
}
|
}
|
||||||
} // namespace common
|
|
||||||
} // namespace xgboost
|
// IO procedures for external memory.
|
||||||
|
bool ColumnMatrix::Read(AlignedResourceReadStream* fi, uint32_t const* index_base) {
|
||||||
|
if (!common::ReadVec(fi, &index_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!common::ReadVec(fi, &type_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!common::ReadVec(fi, &row_ind_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!common::ReadVec(fi, &feature_offsets_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!common::ReadVec(fi, &missing_.storage)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
missing_.InitView();
|
||||||
|
|
||||||
|
index_base_ = index_base;
|
||||||
|
if (!fi->Read(&bins_type_size_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!fi->Read(&any_missing_)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t ColumnMatrix::Write(AlignedFileWriteStream* fo) const {
|
||||||
|
std::size_t bytes{0};
|
||||||
|
|
||||||
|
bytes += common::WriteVec(fo, index_);
|
||||||
|
bytes += common::WriteVec(fo, type_);
|
||||||
|
bytes += common::WriteVec(fo, row_ind_);
|
||||||
|
bytes += common::WriteVec(fo, feature_offsets_);
|
||||||
|
bytes += common::WriteVec(fo, missing_.storage);
|
||||||
|
|
||||||
|
bytes += fo->Write(bins_type_size_);
|
||||||
|
bytes += fo->Write(any_missing_);
|
||||||
|
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
} // namespace xgboost::common
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by Contributors
|
* Copyright 2017-2023, XGBoost Contributors
|
||||||
* \file column_matrix.h
|
* \file column_matrix.h
|
||||||
* \brief Utility for fast column-wise access
|
* \brief Utility for fast column-wise access
|
||||||
* \author Philip Cho
|
* \author Philip Cho
|
||||||
@ -8,25 +8,30 @@
|
|||||||
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
|
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||||
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
|
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||||
|
|
||||||
#include <dmlc/endian.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for uint8_t
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility> // std::move
|
#include <utility> // for move
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../data/adapter.h"
|
#include "../data/adapter.h"
|
||||||
#include "../data/gradient_index.h"
|
#include "../data/gradient_index.h"
|
||||||
#include "algorithm.h"
|
#include "algorithm.h"
|
||||||
|
#include "bitfield.h" // for RBitField8
|
||||||
#include "hist_util.h"
|
#include "hist_util.h"
|
||||||
|
#include "ref_resource_view.h" // for RefResourceView
|
||||||
|
#include "xgboost/base.h" // for bst_bin_t
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::common {
|
||||||
namespace common {
|
|
||||||
|
|
||||||
class ColumnMatrix;
|
class ColumnMatrix;
|
||||||
|
class AlignedFileWriteStream;
|
||||||
|
class AlignedResourceReadStream;
|
||||||
|
|
||||||
/*! \brief column type */
|
/*! \brief column type */
|
||||||
enum ColumnType : uint8_t { kDenseColumn, kSparseColumn };
|
enum ColumnType : std::uint8_t { kDenseColumn, kSparseColumn };
|
||||||
|
|
||||||
/*! \brief a column storage, to be used with ApplySplit. Note that each
|
/*! \brief a column storage, to be used with ApplySplit. Note that each
|
||||||
bin id is stored as index[i] + index_base.
|
bin id is stored as index[i] + index_base.
|
||||||
@ -41,12 +46,12 @@ class Column {
|
|||||||
: index_(index), index_base_(least_bin_idx) {}
|
: index_(index), index_base_(least_bin_idx) {}
|
||||||
virtual ~Column() = default;
|
virtual ~Column() = default;
|
||||||
|
|
||||||
bst_bin_t GetGlobalBinIdx(size_t idx) const {
|
[[nodiscard]] bst_bin_t GetGlobalBinIdx(size_t idx) const {
|
||||||
return index_base_ + static_cast<bst_bin_t>(index_[idx]);
|
return index_base_ + static_cast<bst_bin_t>(index_[idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* returns number of elements in column */
|
/* returns number of elements in column */
|
||||||
size_t Size() const { return index_.size(); }
|
[[nodiscard]] size_t Size() const { return index_.size(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/* bin indexes in range [0, max_bins - 1] */
|
/* bin indexes in range [0, max_bins - 1] */
|
||||||
@ -63,7 +68,7 @@ class SparseColumnIter : public Column<BinIdxT> {
|
|||||||
common::Span<const size_t> row_ind_;
|
common::Span<const size_t> row_ind_;
|
||||||
size_t idx_;
|
size_t idx_;
|
||||||
|
|
||||||
size_t const* RowIndices() const { return row_ind_.data(); }
|
[[nodiscard]] size_t const* RowIndices() const { return row_ind_.data(); }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
SparseColumnIter(common::Span<const BinIdxT> index, bst_bin_t least_bin_idx,
|
SparseColumnIter(common::Span<const BinIdxT> index, bst_bin_t least_bin_idx,
|
||||||
@ -81,7 +86,7 @@ class SparseColumnIter : public Column<BinIdxT> {
|
|||||||
SparseColumnIter(SparseColumnIter const&) = delete;
|
SparseColumnIter(SparseColumnIter const&) = delete;
|
||||||
SparseColumnIter(SparseColumnIter&&) = default;
|
SparseColumnIter(SparseColumnIter&&) = default;
|
||||||
|
|
||||||
size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; }
|
[[nodiscard]] size_t GetRowIdx(size_t idx) const { return RowIndices()[idx]; }
|
||||||
bst_bin_t operator[](size_t rid) {
|
bst_bin_t operator[](size_t rid) {
|
||||||
const size_t column_size = this->Size();
|
const size_t column_size = this->Size();
|
||||||
if (!((idx_) < column_size)) {
|
if (!((idx_) < column_size)) {
|
||||||
@ -101,6 +106,10 @@ class SparseColumnIter : public Column<BinIdxT> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Column stored as a dense vector. It might still contain missing values as
|
||||||
|
* indicated by the missing flags.
|
||||||
|
*/
|
||||||
template <typename BinIdxT, bool any_missing>
|
template <typename BinIdxT, bool any_missing>
|
||||||
class DenseColumnIter : public Column<BinIdxT> {
|
class DenseColumnIter : public Column<BinIdxT> {
|
||||||
public:
|
public:
|
||||||
@ -109,17 +118,19 @@ class DenseColumnIter : public Column<BinIdxT> {
|
|||||||
private:
|
private:
|
||||||
using Base = Column<BinIdxT>;
|
using Base = Column<BinIdxT>;
|
||||||
/* flags for missing values in dense columns */
|
/* flags for missing values in dense columns */
|
||||||
std::vector<ByteType> const& missing_flags_;
|
LBitField32 missing_flags_;
|
||||||
size_t feature_offset_;
|
size_t feature_offset_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
|
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
|
||||||
std::vector<ByteType> const& missing_flags, size_t feature_offset)
|
LBitField32 missing_flags, size_t feature_offset)
|
||||||
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
|
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
|
||||||
DenseColumnIter(DenseColumnIter const&) = delete;
|
DenseColumnIter(DenseColumnIter const&) = delete;
|
||||||
DenseColumnIter(DenseColumnIter&&) = default;
|
DenseColumnIter(DenseColumnIter&&) = default;
|
||||||
|
|
||||||
bool IsMissing(size_t ridx) const { return missing_flags_[feature_offset_ + ridx]; }
|
[[nodiscard]] bool IsMissing(size_t ridx) const {
|
||||||
|
return missing_flags_.Check(feature_offset_ + ridx);
|
||||||
|
}
|
||||||
|
|
||||||
bst_bin_t operator[](size_t ridx) const {
|
bst_bin_t operator[](size_t ridx) const {
|
||||||
if (any_missing) {
|
if (any_missing) {
|
||||||
@ -131,12 +142,54 @@ class DenseColumnIter : public Column<BinIdxT> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Column major matrix for gradient index. This matrix contains both dense column
|
* @brief Column major matrix for gradient index on CPU.
|
||||||
* and sparse column, the type of the column is controlled by sparse threshold. When the
|
*
|
||||||
* number of missing values in a column is below the threshold it's classified as dense
|
* This matrix contains both dense columns and sparse columns, the type of the column
|
||||||
* column.
|
* is controlled by the sparse threshold parameter. When the number of missing values
|
||||||
|
* in a column is below the threshold it's classified as dense column.
|
||||||
*/
|
*/
|
||||||
class ColumnMatrix {
|
class ColumnMatrix {
|
||||||
|
/**
|
||||||
|
* @brief A bit set for indicating whether an element in a dense column is missing.
|
||||||
|
*/
|
||||||
|
struct MissingIndicator {
|
||||||
|
LBitField32 missing;
|
||||||
|
RefResourceView<std::uint32_t> storage;
|
||||||
|
|
||||||
|
MissingIndicator() = default;
|
||||||
|
/**
|
||||||
|
* @param n_elements Size of the bit set
|
||||||
|
* @param init Initialize the indicator to true or false.
|
||||||
|
*/
|
||||||
|
MissingIndicator(std::size_t n_elements, bool init) {
|
||||||
|
auto m_size = missing.ComputeStorageSize(n_elements);
|
||||||
|
storage = common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
|
||||||
|
this->InitView();
|
||||||
|
}
|
||||||
|
/** @brief Set the i^th element to be a valid element (instead of missing). */
|
||||||
|
void SetValid(typename LBitField32::index_type i) { missing.Clear(i); }
|
||||||
|
/** @brief assign the storage to the view. */
|
||||||
|
void InitView() {
|
||||||
|
missing = LBitField32{Span{storage.data(), storage.size()}};
|
||||||
|
}
|
||||||
|
|
||||||
|
void GrowTo(std::size_t n_elements, bool init) {
|
||||||
|
CHECK(storage.Resource()->Type() == ResourceHandler::kMalloc)
|
||||||
|
<< "[Internal Error]: Cannot grow the vector when external memory is used.";
|
||||||
|
auto m_size = missing.ComputeStorageSize(n_elements);
|
||||||
|
CHECK_GE(m_size, storage.size());
|
||||||
|
if (m_size == storage.size()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_storage =
|
||||||
|
common::MakeFixedVecWithMalloc(m_size, init ? ~std::uint32_t{0} : std::uint32_t{0});
|
||||||
|
std::copy_n(storage.cbegin(), storage.size(), new_storage.begin());
|
||||||
|
storage = std::move(new_storage);
|
||||||
|
this->InitView();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold);
|
void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold);
|
||||||
|
|
||||||
template <typename ColumnBinT, typename BinT, typename RIdx>
|
template <typename ColumnBinT, typename BinT, typename RIdx>
|
||||||
@ -144,9 +197,10 @@ class ColumnMatrix {
|
|||||||
if (type_[fid] == kDenseColumn) {
|
if (type_[fid] == kDenseColumn) {
|
||||||
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
|
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
|
||||||
begin[rid] = bin_id - index_base_[fid];
|
begin[rid] = bin_id - index_base_[fid];
|
||||||
// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
|
// not thread-safe with bit field.
|
||||||
// kMissingId to the index to avoid missing flags.
|
// FIXME(jiamingy): We can directly assign kMissingId to the index to avoid missing
|
||||||
missing_flags_[feature_offsets_[fid] + rid] = false;
|
// flags.
|
||||||
|
missing_.SetValid(feature_offsets_[fid] + rid);
|
||||||
} else {
|
} else {
|
||||||
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
|
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
|
||||||
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
|
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
|
||||||
@ -158,7 +212,9 @@ class ColumnMatrix {
|
|||||||
public:
|
public:
|
||||||
using ByteType = bool;
|
using ByteType = bool;
|
||||||
// get number of features
|
// get number of features
|
||||||
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
|
[[nodiscard]] bst_feature_t GetNumFeature() const {
|
||||||
|
return static_cast<bst_feature_t>(type_.size());
|
||||||
|
}
|
||||||
|
|
||||||
ColumnMatrix() = default;
|
ColumnMatrix() = default;
|
||||||
ColumnMatrix(GHistIndexMatrix const& gmat, double sparse_threshold) {
|
ColumnMatrix(GHistIndexMatrix const& gmat, double sparse_threshold) {
|
||||||
@ -166,7 +222,7 @@ class ColumnMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original
|
* @brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original
|
||||||
* SparsePage.
|
* SparsePage.
|
||||||
*/
|
*/
|
||||||
void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
|
void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
|
||||||
@ -178,8 +234,8 @@ class ColumnMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual
|
* @brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual
|
||||||
* data.
|
* data.
|
||||||
*
|
*
|
||||||
* This function requires a binary search for each bin to get back the feature index
|
* This function requires a binary search for each bin to get back the feature index
|
||||||
* for those bins.
|
* for those bins.
|
||||||
@ -199,7 +255,7 @@ class ColumnMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsInitialized() const { return !type_.empty(); }
|
[[nodiscard]] bool IsInitialized() const { return !type_.empty(); }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Push batch of data for Quantile DMatrix support.
|
* \brief Push batch of data for Quantile DMatrix support.
|
||||||
@ -257,7 +313,7 @@ class ColumnMatrix {
|
|||||||
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
|
reinterpret_cast<const BinIdxType*>(&index_[feature_offset * bins_type_size_]),
|
||||||
column_size};
|
column_size};
|
||||||
return std::move(DenseColumnIter<BinIdxType, any_missing>{
|
return std::move(DenseColumnIter<BinIdxType, any_missing>{
|
||||||
bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_flags_, feature_offset});
|
bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_.missing, feature_offset});
|
||||||
}
|
}
|
||||||
|
|
||||||
// all columns are dense column and has no missing value
|
// all columns are dense column and has no missing value
|
||||||
@ -265,7 +321,8 @@ class ColumnMatrix {
|
|||||||
template <typename RowBinIdxT>
|
template <typename RowBinIdxT>
|
||||||
void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
|
void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
|
||||||
const size_t n_features, int32_t n_threads) {
|
const size_t n_features, int32_t n_threads) {
|
||||||
missing_flags_.resize(feature_offsets_[n_features], false);
|
missing_.GrowTo(feature_offsets_[n_features], false);
|
||||||
|
|
||||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||||
using ColumnBinT = decltype(t);
|
using ColumnBinT = decltype(t);
|
||||||
auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
|
auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
|
||||||
@ -290,9 +347,15 @@ class ColumnMatrix {
|
|||||||
void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat,
|
void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat,
|
||||||
float missing) {
|
float missing) {
|
||||||
auto n_features = gmat.Features();
|
auto n_features = gmat.Features();
|
||||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
|
||||||
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[base_rowid];
|
missing_.GrowTo(feature_offsets_[n_features], true);
|
||||||
num_nonzeros_.resize(n_features, 0);
|
auto const* row_index = gmat.index.data<std::uint32_t>() + gmat.row_ptr[base_rowid];
|
||||||
|
if (num_nonzeros_.empty()) {
|
||||||
|
num_nonzeros_ = common::MakeFixedVecWithMalloc(n_features, std::size_t{0});
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(num_nonzeros_.size(), n_features);
|
||||||
|
}
|
||||||
|
|
||||||
auto is_valid = data::IsValidFunctor{missing};
|
auto is_valid = data::IsValidFunctor{missing};
|
||||||
|
|
||||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||||
@ -321,8 +384,9 @@ class ColumnMatrix {
|
|||||||
*/
|
*/
|
||||||
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
|
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
|
||||||
auto n_features = gmat.Features();
|
auto n_features = gmat.Features();
|
||||||
missing_flags_.resize(feature_offsets_[n_features], true);
|
|
||||||
num_nonzeros_.resize(n_features, 0);
|
missing_ = MissingIndicator{feature_offsets_[n_features], true};
|
||||||
|
num_nonzeros_ = common::MakeFixedVecWithMalloc(n_features, std::size_t{0});
|
||||||
|
|
||||||
DispatchBinType(bins_type_size_, [&](auto t) {
|
DispatchBinType(bins_type_size_, [&](auto t) {
|
||||||
using ColumnBinT = decltype(t);
|
using ColumnBinT = decltype(t);
|
||||||
@ -335,106 +399,34 @@ class ColumnMatrix {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
BinTypeSize GetTypeSize() const { return bins_type_size_; }
|
[[nodiscard]] BinTypeSize GetTypeSize() const { return bins_type_size_; }
|
||||||
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
|
[[nodiscard]] auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
|
||||||
|
|
||||||
// And this returns part of state
|
// And this returns part of state
|
||||||
bool AnyMissing() const { return any_missing_; }
|
[[nodiscard]] bool AnyMissing() const { return any_missing_; }
|
||||||
|
|
||||||
// IO procedures for external memory.
|
// IO procedures for external memory.
|
||||||
bool Read(dmlc::SeekStream* fi, uint32_t const* index_base) {
|
[[nodiscard]] bool Read(AlignedResourceReadStream* fi, uint32_t const* index_base);
|
||||||
fi->Read(&index_);
|
[[nodiscard]] std::size_t Write(AlignedFileWriteStream* fo) const;
|
||||||
#if !DMLC_LITTLE_ENDIAN
|
|
||||||
// s390x
|
|
||||||
std::vector<std::underlying_type<ColumnType>::type> int_types;
|
|
||||||
fi->Read(&int_types);
|
|
||||||
type_.resize(int_types.size());
|
|
||||||
std::transform(
|
|
||||||
int_types.begin(), int_types.end(), type_.begin(),
|
|
||||||
[](std::underlying_type<ColumnType>::type i) { return static_cast<ColumnType>(i); });
|
|
||||||
#else
|
|
||||||
fi->Read(&type_);
|
|
||||||
#endif // !DMLC_LITTLE_ENDIAN
|
|
||||||
|
|
||||||
fi->Read(&row_ind_);
|
|
||||||
fi->Read(&feature_offsets_);
|
|
||||||
|
|
||||||
std::vector<std::uint8_t> missing;
|
|
||||||
fi->Read(&missing);
|
|
||||||
missing_flags_.resize(missing.size());
|
|
||||||
std::transform(missing.cbegin(), missing.cend(), missing_flags_.begin(),
|
|
||||||
[](std::uint8_t flag) { return !!flag; });
|
|
||||||
|
|
||||||
index_base_ = index_base;
|
|
||||||
#if !DMLC_LITTLE_ENDIAN
|
|
||||||
std::underlying_type<BinTypeSize>::type v;
|
|
||||||
fi->Read(&v);
|
|
||||||
bins_type_size_ = static_cast<BinTypeSize>(v);
|
|
||||||
#else
|
|
||||||
fi->Read(&bins_type_size_);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
fi->Read(&any_missing_);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t Write(dmlc::Stream* fo) const {
|
|
||||||
size_t bytes{0};
|
|
||||||
|
|
||||||
auto write_vec = [&](auto const& vec) {
|
|
||||||
fo->Write(vec);
|
|
||||||
bytes += vec.size() * sizeof(typename std::remove_reference_t<decltype(vec)>::value_type) +
|
|
||||||
sizeof(uint64_t);
|
|
||||||
};
|
|
||||||
write_vec(index_);
|
|
||||||
#if !DMLC_LITTLE_ENDIAN
|
|
||||||
// s390x
|
|
||||||
std::vector<std::underlying_type<ColumnType>::type> int_types(type_.size());
|
|
||||||
std::transform(type_.begin(), type_.end(), int_types.begin(), [](ColumnType t) {
|
|
||||||
return static_cast<std::underlying_type<ColumnType>::type>(t);
|
|
||||||
});
|
|
||||||
write_vec(int_types);
|
|
||||||
#else
|
|
||||||
write_vec(type_);
|
|
||||||
#endif // !DMLC_LITTLE_ENDIAN
|
|
||||||
write_vec(row_ind_);
|
|
||||||
write_vec(feature_offsets_);
|
|
||||||
// dmlc can not handle bool vector
|
|
||||||
std::vector<std::uint8_t> missing(missing_flags_.size());
|
|
||||||
std::transform(missing_flags_.cbegin(), missing_flags_.cend(), missing.begin(),
|
|
||||||
[](bool flag) { return static_cast<std::uint8_t>(flag); });
|
|
||||||
write_vec(missing);
|
|
||||||
|
|
||||||
#if !DMLC_LITTLE_ENDIAN
|
|
||||||
auto v = static_cast<std::underlying_type<BinTypeSize>::type>(bins_type_size_);
|
|
||||||
fo->Write(v);
|
|
||||||
#else
|
|
||||||
fo->Write(bins_type_size_);
|
|
||||||
#endif // DMLC_LITTLE_ENDIAN
|
|
||||||
bytes += sizeof(bins_type_size_);
|
|
||||||
fo->Write(any_missing_);
|
|
||||||
bytes += sizeof(any_missing_);
|
|
||||||
|
|
||||||
return bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<uint8_t> index_;
|
RefResourceView<std::uint8_t> index_;
|
||||||
|
|
||||||
std::vector<ColumnType> type_;
|
RefResourceView<ColumnType> type_;
|
||||||
/* indptr of a CSC matrix. */
|
/** @brief indptr of a CSC matrix. */
|
||||||
std::vector<size_t> row_ind_;
|
RefResourceView<std::size_t> row_ind_;
|
||||||
/* indicate where each column's index and row_ind is stored. */
|
/** @brief indicate where each column's index and row_ind is stored. */
|
||||||
std::vector<size_t> feature_offsets_;
|
RefResourceView<std::size_t> feature_offsets_;
|
||||||
/* The number of nnz of each column. */
|
/** @brief The number of nnz of each column. */
|
||||||
std::vector<size_t> num_nonzeros_;
|
RefResourceView<std::size_t> num_nonzeros_;
|
||||||
|
|
||||||
// index_base_[fid]: least bin id for feature fid
|
// index_base_[fid]: least bin id for feature fid
|
||||||
uint32_t const* index_base_;
|
std::uint32_t const* index_base_;
|
||||||
std::vector<ByteType> missing_flags_;
|
|
||||||
|
MissingIndicator missing_;
|
||||||
|
|
||||||
BinTypeSize bins_type_size_;
|
BinTypeSize bins_type_size_;
|
||||||
bool any_missing_;
|
bool any_missing_;
|
||||||
};
|
};
|
||||||
} // namespace common
|
} // namespace xgboost::common
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_COMMON_COLUMN_MATRIX_H_
|
#endif // XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||||
|
|||||||
@ -203,13 +203,33 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Optionally compressed gradient index. The compression works only with dense
|
* @brief Optionally compressed gradient index. The compression works only with dense
|
||||||
* data.
|
* data.
|
||||||
*
|
*
|
||||||
* The main body of construction code is in gradient_index.cc, this struct is only a
|
* The main body of construction code is in gradient_index.cc, this struct is only a
|
||||||
* storage class.
|
* view class.
|
||||||
*/
|
*/
|
||||||
struct Index {
|
class Index {
|
||||||
|
private:
|
||||||
|
void SetBinTypeSize(BinTypeSize binTypeSize) {
|
||||||
|
binTypeSize_ = binTypeSize;
|
||||||
|
switch (binTypeSize) {
|
||||||
|
case kUint8BinsTypeSize:
|
||||||
|
func_ = &GetValueFromUint8;
|
||||||
|
break;
|
||||||
|
case kUint16BinsTypeSize:
|
||||||
|
func_ = &GetValueFromUint16;
|
||||||
|
break;
|
||||||
|
case kUint32BinsTypeSize:
|
||||||
|
func_ = &GetValueFromUint32;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize ||
|
||||||
|
binTypeSize == kUint32BinsTypeSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
// Inside the compressor, bin_idx is the index for cut value across all features. By
|
// Inside the compressor, bin_idx is the index for cut value across all features. By
|
||||||
// subtracting it with starting pointer of each feature, we can reduce it to smaller
|
// subtracting it with starting pointer of each feature, we can reduce it to smaller
|
||||||
// value and store it with smaller types. Usable only with dense data.
|
// value and store it with smaller types. Usable only with dense data.
|
||||||
@ -233,10 +253,24 @@ struct Index {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Index() { SetBinTypeSize(binTypeSize_); }
|
Index() { SetBinTypeSize(binTypeSize_); }
|
||||||
Index(const Index& i) = delete;
|
|
||||||
Index& operator=(Index i) = delete;
|
Index(Index const& i) = delete;
|
||||||
|
Index& operator=(Index const& i) = delete;
|
||||||
Index(Index&& i) = delete;
|
Index(Index&& i) = delete;
|
||||||
Index& operator=(Index&& i) = delete;
|
|
||||||
|
/** @brief Move assignment for lazy initialization. */
|
||||||
|
Index& operator=(Index&& i) = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct the index from data.
|
||||||
|
*
|
||||||
|
* @param data Storage for compressed histogram bin.
|
||||||
|
* @param bin_size Number of bytes for each bin.
|
||||||
|
*/
|
||||||
|
Index(Span<std::uint8_t> data, BinTypeSize bin_size) : data_{data} {
|
||||||
|
this->SetBinTypeSize(bin_size);
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t operator[](size_t i) const {
|
uint32_t operator[](size_t i) const {
|
||||||
if (!bin_offset_.empty()) {
|
if (!bin_offset_.empty()) {
|
||||||
// dense, compressed
|
// dense, compressed
|
||||||
@ -247,26 +281,7 @@ struct Index {
|
|||||||
return func_(data_.data(), i);
|
return func_(data_.data(), i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void SetBinTypeSize(BinTypeSize binTypeSize) {
|
[[nodiscard]] BinTypeSize GetBinTypeSize() const { return binTypeSize_; }
|
||||||
binTypeSize_ = binTypeSize;
|
|
||||||
switch (binTypeSize) {
|
|
||||||
case kUint8BinsTypeSize:
|
|
||||||
func_ = &GetValueFromUint8;
|
|
||||||
break;
|
|
||||||
case kUint16BinsTypeSize:
|
|
||||||
func_ = &GetValueFromUint16;
|
|
||||||
break;
|
|
||||||
case kUint32BinsTypeSize:
|
|
||||||
func_ = &GetValueFromUint32;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
CHECK(binTypeSize == kUint8BinsTypeSize || binTypeSize == kUint16BinsTypeSize ||
|
|
||||||
binTypeSize == kUint32BinsTypeSize);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BinTypeSize GetBinTypeSize() const {
|
|
||||||
return binTypeSize_;
|
|
||||||
}
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T const* data() const { // NOLINT
|
T const* data() const { // NOLINT
|
||||||
return reinterpret_cast<T const*>(data_.data());
|
return reinterpret_cast<T const*>(data_.data());
|
||||||
@ -275,30 +290,27 @@ struct Index {
|
|||||||
T* data() { // NOLINT
|
T* data() { // NOLINT
|
||||||
return reinterpret_cast<T*>(data_.data());
|
return reinterpret_cast<T*>(data_.data());
|
||||||
}
|
}
|
||||||
uint32_t const* Offset() const { return bin_offset_.data(); }
|
[[nodiscard]] std::uint32_t const* Offset() const { return bin_offset_.data(); }
|
||||||
size_t OffsetSize() const { return bin_offset_.size(); }
|
[[nodiscard]] std::size_t OffsetSize() const { return bin_offset_.size(); }
|
||||||
size_t Size() const { return data_.size() / (binTypeSize_); }
|
[[nodiscard]] std::size_t Size() const { return data_.size() / (binTypeSize_); }
|
||||||
|
|
||||||
void Resize(const size_t n_bytes) {
|
|
||||||
data_.resize(n_bytes);
|
|
||||||
}
|
|
||||||
// set the offset used in compression, cut_ptrs is the CSC indptr in HistogramCuts
|
// set the offset used in compression, cut_ptrs is the CSC indptr in HistogramCuts
|
||||||
void SetBinOffset(std::vector<uint32_t> const& cut_ptrs) {
|
void SetBinOffset(std::vector<uint32_t> const& cut_ptrs) {
|
||||||
bin_offset_.resize(cut_ptrs.size() - 1); // resize to number of features.
|
bin_offset_.resize(cut_ptrs.size() - 1); // resize to number of features.
|
||||||
std::copy_n(cut_ptrs.begin(), bin_offset_.size(), bin_offset_.begin());
|
std::copy_n(cut_ptrs.begin(), bin_offset_.size(), bin_offset_.begin());
|
||||||
}
|
}
|
||||||
std::vector<uint8_t>::const_iterator begin() const { // NOLINT
|
auto begin() const { // NOLINT
|
||||||
return data_.begin();
|
return data_.data();
|
||||||
}
|
}
|
||||||
std::vector<uint8_t>::const_iterator end() const { // NOLINT
|
auto end() const { // NOLINT
|
||||||
return data_.end();
|
return data_.data() + data_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint8_t>::iterator begin() { // NOLINT
|
auto begin() { // NOLINT
|
||||||
return data_.begin();
|
return data_.data();
|
||||||
}
|
}
|
||||||
std::vector<uint8_t>::iterator end() { // NOLINT
|
auto end() { // NOLINT
|
||||||
return data_.end();
|
return data_.data() + data_.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -313,12 +325,12 @@ struct Index {
|
|||||||
|
|
||||||
using Func = uint32_t (*)(uint8_t const*, size_t);
|
using Func = uint32_t (*)(uint8_t const*, size_t);
|
||||||
|
|
||||||
std::vector<uint8_t> data_;
|
Span<std::uint8_t> data_;
|
||||||
// starting position of each feature inside the cut values (the indptr of the CSC cut matrix
|
// starting position of each feature inside the cut values (the indptr of the CSC cut matrix
|
||||||
// HistogramCuts without the last entry.) Used for bin compression.
|
// HistogramCuts without the last entry.) Used for bin compression.
|
||||||
std::vector<uint32_t> bin_offset_;
|
std::vector<uint32_t> bin_offset_;
|
||||||
|
|
||||||
BinTypeSize binTypeSize_ {kUint8BinsTypeSize};
|
BinTypeSize binTypeSize_{kUint8BinsTypeSize};
|
||||||
Func func_;
|
Func func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
104
src/common/io.cc
104
src/common/io.cc
@ -200,21 +200,43 @@ std::string FileExtension(std::string fname, bool lower) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct PrivateMmapConstStream::MMAPFile {
|
// For some reason, NVCC 12.1 marks the function deleted if we expose it in the header.
|
||||||
|
// NVCC 11.8 doesn't allow `noexcept(false) = default` altogether.
|
||||||
|
ResourceHandler::~ResourceHandler() noexcept(false) {} // NOLINT
|
||||||
|
|
||||||
|
struct MMAPFile {
|
||||||
#if defined(xgboost_IS_WIN)
|
#if defined(xgboost_IS_WIN)
|
||||||
HANDLE fd{INVALID_HANDLE_VALUE};
|
HANDLE fd{INVALID_HANDLE_VALUE};
|
||||||
HANDLE file_map{INVALID_HANDLE_VALUE};
|
HANDLE file_map{INVALID_HANDLE_VALUE};
|
||||||
#else
|
#else
|
||||||
std::int32_t fd{0};
|
std::int32_t fd{0};
|
||||||
#endif
|
#endif
|
||||||
char* base_ptr{nullptr};
|
std::byte* base_ptr{nullptr};
|
||||||
std::size_t base_size{0};
|
std::size_t base_size{0};
|
||||||
|
std::size_t delta{0};
|
||||||
std::string path;
|
std::string path;
|
||||||
|
|
||||||
|
MMAPFile() = default;
|
||||||
|
|
||||||
|
#if defined(xgboost_IS_WIN)
|
||||||
|
MMAPFile(HANDLE fd, HANDLE fm, std::byte* base_ptr, std::size_t base_size, std::size_t delta,
|
||||||
|
std::string path)
|
||||||
|
: fd{fd},
|
||||||
|
file_map{fm},
|
||||||
|
base_ptr{base_ptr},
|
||||||
|
base_size{base_size},
|
||||||
|
delta{delta},
|
||||||
|
path{std::move(path)} {}
|
||||||
|
#else
|
||||||
|
MMAPFile(std::int32_t fd, std::byte* base_ptr, std::size_t base_size, std::size_t delta,
|
||||||
|
std::string path)
|
||||||
|
: fd{fd}, base_ptr{base_ptr}, base_size{base_size}, delta{delta}, path{std::move(path)} {}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::size_t length) {
|
std::unique_ptr<MMAPFile> Open(std::string path, std::size_t offset, std::size_t length) {
|
||||||
if (length == 0) {
|
if (length == 0) {
|
||||||
return nullptr;
|
return std::make_unique<MMAPFile>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(xgboost_IS_WIN)
|
#if defined(xgboost_IS_WIN)
|
||||||
@ -226,16 +248,18 @@ char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::si
|
|||||||
CHECK_GE(fd, 0) << "Failed to open:" << path << ". " << SystemErrorMsg();
|
CHECK_GE(fd, 0) << "Failed to open:" << path << ". " << SystemErrorMsg();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
char* ptr{nullptr};
|
std::byte* ptr{nullptr};
|
||||||
// Round down for alignment.
|
// Round down for alignment.
|
||||||
auto view_start = offset / GetMmapAlignment() * GetMmapAlignment();
|
auto view_start = offset / GetMmapAlignment() * GetMmapAlignment();
|
||||||
auto view_size = length + (offset - view_start);
|
auto view_size = length + (offset - view_start);
|
||||||
|
|
||||||
#if defined(__linux__) || defined(__GLIBC__)
|
#if defined(__linux__) || defined(__GLIBC__)
|
||||||
int prot{PROT_READ};
|
int prot{PROT_READ};
|
||||||
ptr = reinterpret_cast<char*>(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
|
ptr = reinterpret_cast<std::byte*>(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
|
||||||
|
madvise(ptr, view_size, MADV_WILLNEED);
|
||||||
CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
||||||
handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)});
|
auto handle =
|
||||||
|
std::make_unique<MMAPFile>(fd, ptr, view_size, offset - view_start, std::move(path));
|
||||||
#elif defined(xgboost_IS_WIN)
|
#elif defined(xgboost_IS_WIN)
|
||||||
auto file_size = GetFileSize(fd, nullptr);
|
auto file_size = GetFileSize(fd, nullptr);
|
||||||
DWORD access = PAGE_READONLY;
|
DWORD access = PAGE_READONLY;
|
||||||
@ -244,33 +268,32 @@ char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::si
|
|||||||
std::uint32_t loff = static_cast<std::uint32_t>(view_start);
|
std::uint32_t loff = static_cast<std::uint32_t>(view_start);
|
||||||
std::uint32_t hoff = view_start >> 32;
|
std::uint32_t hoff = view_start >> 32;
|
||||||
CHECK(map_file) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
CHECK(map_file) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
||||||
ptr = reinterpret_cast<char*>(MapViewOfFile(map_file, access, hoff, loff, view_size));
|
ptr = reinterpret_cast<std::byte*>(MapViewOfFile(map_file, access, hoff, loff, view_size));
|
||||||
CHECK_NE(ptr, nullptr) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
CHECK_NE(ptr, nullptr) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
||||||
handle_.reset(new MMAPFile{fd, map_file, ptr, view_size, std::move(path)});
|
auto handle = std::make_unique<MMAPFile>(fd, map_file, ptr, view_size, offset - view_start,
|
||||||
|
std::move(path));
|
||||||
#else
|
#else
|
||||||
CHECK_LE(offset, std::numeric_limits<off_t>::max())
|
CHECK_LE(offset, std::numeric_limits<off_t>::max())
|
||||||
<< "File size has exceeded the limit on the current system.";
|
<< "File size has exceeded the limit on the current system.";
|
||||||
int prot{PROT_READ};
|
int prot{PROT_READ};
|
||||||
ptr = reinterpret_cast<char*>(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
|
ptr = reinterpret_cast<std::byte*>(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
|
||||||
CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
|
||||||
handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)});
|
auto handle =
|
||||||
|
std::make_unique<MMAPFile>(fd, ptr, view_size, offset - view_start, std::move(path));
|
||||||
#endif // defined(__linux__)
|
#endif // defined(__linux__)
|
||||||
|
|
||||||
ptr += (offset - view_start);
|
return handle;
|
||||||
return ptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PrivateMmapConstStream::PrivateMmapConstStream(std::string path, std::size_t offset,
|
MmapResource::MmapResource(std::string path, std::size_t offset, std::size_t length)
|
||||||
std::size_t length)
|
: ResourceHandler{kMmap}, handle_{Open(std::move(path), offset, length)}, n_{length} {}
|
||||||
: MemoryFixSizeBuffer{}, handle_{nullptr} {
|
|
||||||
this->p_buffer_ = Open(std::move(path), offset, length);
|
|
||||||
this->buffer_size_ = length;
|
|
||||||
}
|
|
||||||
|
|
||||||
PrivateMmapConstStream::~PrivateMmapConstStream() {
|
MmapResource::~MmapResource() noexcept(false) {
|
||||||
CHECK(handle_);
|
if (!handle_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
#if defined(xgboost_IS_WIN)
|
#if defined(xgboost_IS_WIN)
|
||||||
if (p_buffer_) {
|
if (handle_->base_ptr) {
|
||||||
CHECK(UnmapViewOfFile(handle_->base_ptr)) "Faled to call munmap: " << SystemErrorMsg();
|
CHECK(UnmapViewOfFile(handle_->base_ptr)) "Faled to call munmap: " << SystemErrorMsg();
|
||||||
}
|
}
|
||||||
if (handle_->fd != INVALID_HANDLE_VALUE) {
|
if (handle_->fd != INVALID_HANDLE_VALUE) {
|
||||||
@ -290,6 +313,43 @@ PrivateMmapConstStream::~PrivateMmapConstStream() {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] void* MmapResource::Data() {
|
||||||
|
if (!handle_) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return handle_->base_ptr + handle_->delta;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t MmapResource::Size() const { return n_; }
|
||||||
|
|
||||||
|
// For some reason, NVCC 12.1 marks the function deleted if we expose it in the header.
|
||||||
|
// NVCC 11.8 doesn't allow `noexcept(false) = default` altogether.
|
||||||
|
AlignedResourceReadStream::~AlignedResourceReadStream() noexcept(false) {} // NOLINT
|
||||||
|
PrivateMmapConstStream::~PrivateMmapConstStream() noexcept(false) {} // NOLINT
|
||||||
|
|
||||||
|
AlignedFileWriteStream::AlignedFileWriteStream(StringView path, StringView flags)
|
||||||
|
: pimpl_{dmlc::Stream::Create(path.c_str(), flags.c_str())} {}
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t AlignedFileWriteStream::DoWrite(const void* ptr,
|
||||||
|
std::size_t n_bytes) noexcept(true) {
|
||||||
|
pimpl_->Write(ptr, n_bytes);
|
||||||
|
return n_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
AlignedMemWriteStream::AlignedMemWriteStream(std::string* p_buf)
|
||||||
|
: pimpl_{std::make_unique<MemoryBufferStream>(p_buf)} {}
|
||||||
|
AlignedMemWriteStream::~AlignedMemWriteStream() = default;
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t AlignedMemWriteStream::DoWrite(const void* ptr,
|
||||||
|
std::size_t n_bytes) noexcept(true) {
|
||||||
|
this->pimpl_->Write(ptr, n_bytes);
|
||||||
|
return n_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t AlignedMemWriteStream::Tell() const noexcept(true) {
|
||||||
|
return this->pimpl_->Tell();
|
||||||
|
}
|
||||||
} // namespace xgboost::common
|
} // namespace xgboost::common
|
||||||
|
|
||||||
#if defined(xgboost_IS_WIN)
|
#if defined(xgboost_IS_WIN)
|
||||||
|
|||||||
336
src/common/io.h
336
src/common/io.h
@ -4,22 +4,29 @@
|
|||||||
* \brief general stream interface for serialization, I/O
|
* \brief general stream interface for serialization, I/O
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef XGBOOST_COMMON_IO_H_
|
#ifndef XGBOOST_COMMON_IO_H_
|
||||||
#define XGBOOST_COMMON_IO_H_
|
#define XGBOOST_COMMON_IO_H_
|
||||||
|
|
||||||
#include <dmlc/io.h>
|
#include <dmlc/io.h>
|
||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
|
|
||||||
#include <cstring>
|
#include <algorithm> // for min
|
||||||
#include <fstream>
|
#include <array> // for array
|
||||||
#include <memory> // for unique_ptr
|
#include <cstddef> // for byte, size_t
|
||||||
#include <string> // for string
|
#include <cstdlib> // for malloc, realloc, free
|
||||||
|
#include <cstring> // for memcpy
|
||||||
|
#include <fstream> // for ifstream
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <memory> // for unique_ptr
|
||||||
|
#include <string> // for string
|
||||||
|
#include <type_traits> // for alignment_of_v, enable_if_t
|
||||||
|
#include <utility> // for move
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "xgboost/string_view.h" // for StringView
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::common {
|
||||||
namespace common {
|
|
||||||
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
|
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
|
||||||
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
|
using MemoryBufferStream = rabit::utils::MemoryBufferStream;
|
||||||
|
|
||||||
@ -58,8 +65,8 @@ class FixedSizeStream : public PeekableInStream {
|
|||||||
|
|
||||||
size_t Read(void* dptr, size_t size) override;
|
size_t Read(void* dptr, size_t size) override;
|
||||||
size_t PeekRead(void* dptr, size_t size) override;
|
size_t PeekRead(void* dptr, size_t size) override;
|
||||||
size_t Size() const { return buffer_.size(); }
|
[[nodiscard]] std::size_t Size() const { return buffer_.size(); }
|
||||||
size_t Tell() const { return pointer_; }
|
[[nodiscard]] std::size_t Tell() const { return pointer_; }
|
||||||
void Seek(size_t pos);
|
void Seek(size_t pos);
|
||||||
|
|
||||||
void Write(const void*, size_t) override {
|
void Write(const void*, size_t) override {
|
||||||
@ -129,18 +136,245 @@ inline std::string ReadAll(std::string const &path) {
|
|||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MMAPFile;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Handler for one-shot resource. Unlike `std::pmr::*`, the resource handler is
|
||||||
|
* fixed once it's constructed. Users cannot use mutable operations like resize
|
||||||
|
* without acquiring the specific resource first.
|
||||||
|
*/
|
||||||
|
class ResourceHandler {
|
||||||
|
public:
|
||||||
|
// RTTI
|
||||||
|
enum Kind : std::uint8_t {
|
||||||
|
kMalloc = 0,
|
||||||
|
kMmap = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
Kind kind_{kMalloc};
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual void* Data() = 0;
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] T* DataAs() {
|
||||||
|
return reinterpret_cast<T*>(this->Data());
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] virtual std::size_t Size() const = 0;
|
||||||
|
[[nodiscard]] auto Type() const { return kind_; }
|
||||||
|
|
||||||
|
// Allow exceptions for cleaning up resource.
|
||||||
|
virtual ~ResourceHandler() noexcept(false);
|
||||||
|
|
||||||
|
explicit ResourceHandler(Kind kind) : kind_{kind} {}
|
||||||
|
// Use shared_ptr to manage a pool like resource handler. All copy and assignment
|
||||||
|
// operators are disabled.
|
||||||
|
ResourceHandler(ResourceHandler const& that) = delete;
|
||||||
|
ResourceHandler& operator=(ResourceHandler const& that) = delete;
|
||||||
|
ResourceHandler(ResourceHandler&& that) = delete;
|
||||||
|
ResourceHandler& operator=(ResourceHandler&& that) = delete;
|
||||||
|
/**
|
||||||
|
* @brief Wether two resources have the same type. (both malloc or both mmap).
|
||||||
|
*/
|
||||||
|
[[nodiscard]] bool IsSameType(ResourceHandler const& that) const {
|
||||||
|
return this->Type() == that.Type();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MallocResource : public ResourceHandler {
|
||||||
|
void* ptr_{nullptr};
|
||||||
|
std::size_t n_{0};
|
||||||
|
|
||||||
|
void Clear() noexcept(true) {
|
||||||
|
std::free(ptr_);
|
||||||
|
ptr_ = nullptr;
|
||||||
|
n_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit MallocResource(std::size_t n_bytes) : ResourceHandler{kMalloc} { this->Resize(n_bytes); }
|
||||||
|
~MallocResource() noexcept(true) override { this->Clear(); }
|
||||||
|
|
||||||
|
void* Data() override { return ptr_; }
|
||||||
|
[[nodiscard]] std::size_t Size() const override { return n_; }
|
||||||
|
/**
|
||||||
|
* @brief Resize the resource to n_bytes. Unlike std::vector::resize, it prefers realloc
|
||||||
|
* over malloc.
|
||||||
|
*
|
||||||
|
* @tparam force_malloc Force the use of malloc over realloc. Used for testing.
|
||||||
|
*
|
||||||
|
* @param n_bytes The new size.
|
||||||
|
*/
|
||||||
|
template <bool force_malloc = false>
|
||||||
|
void Resize(std::size_t n_bytes) {
|
||||||
|
// realloc(ptr, 0) works, but is deprecated.
|
||||||
|
if (n_bytes == 0) {
|
||||||
|
this->Clear();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If realloc fails, we need to copy the data ourselves.
|
||||||
|
bool need_copy{false};
|
||||||
|
void* new_ptr{nullptr};
|
||||||
|
// use realloc first, it can handle nullptr.
|
||||||
|
if constexpr (!force_malloc) {
|
||||||
|
new_ptr = std::realloc(ptr_, n_bytes);
|
||||||
|
}
|
||||||
|
// retry with malloc if realloc fails
|
||||||
|
if (!new_ptr) {
|
||||||
|
// ptr_ is preserved if realloc fails
|
||||||
|
new_ptr = std::malloc(n_bytes);
|
||||||
|
need_copy = true;
|
||||||
|
}
|
||||||
|
if (!new_ptr) {
|
||||||
|
// malloc fails
|
||||||
|
LOG(FATAL) << "bad_malloc: Failed to allocate " << n_bytes << " bytes.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (need_copy) {
|
||||||
|
std::copy_n(reinterpret_cast<std::byte*>(ptr_), n_, reinterpret_cast<std::byte*>(new_ptr));
|
||||||
|
}
|
||||||
|
// default initialize
|
||||||
|
std::memset(reinterpret_cast<std::byte*>(new_ptr) + n_, '\0', n_bytes - n_);
|
||||||
|
// free the old ptr if malloc is used.
|
||||||
|
if (need_copy) {
|
||||||
|
this->Clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr_ = new_ptr;
|
||||||
|
n_ = n_bytes;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A class for wrapping mmap as a resource for RAII.
|
||||||
|
*/
|
||||||
|
class MmapResource : public ResourceHandler {
|
||||||
|
std::unique_ptr<MMAPFile> handle_;
|
||||||
|
std::size_t n_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MmapResource(std::string path, std::size_t offset, std::size_t length);
|
||||||
|
~MmapResource() noexcept(false) override;
|
||||||
|
|
||||||
|
[[nodiscard]] void* Data() override;
|
||||||
|
[[nodiscard]] std::size_t Size() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param Alignment for resource read stream and aligned write stream.
|
||||||
|
*/
|
||||||
|
constexpr std::size_t IOAlignment() {
|
||||||
|
// For most of the pod types in XGBoost, 8 byte is sufficient.
|
||||||
|
return 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Wrap resource into a dmlc stream.
|
||||||
|
*
|
||||||
|
* This class is to facilitate the use of mmap. Caller can optionally use the `Read()`
|
||||||
|
* method or the `Consume()` method. The former copies data into output, while the latter
|
||||||
|
* makes copy only if it's a primitive type.
|
||||||
|
*
|
||||||
|
* Input is required to be aligned to IOAlignment().
|
||||||
|
*/
|
||||||
|
class AlignedResourceReadStream {
|
||||||
|
std::shared_ptr<ResourceHandler> resource_;
|
||||||
|
std::size_t curr_ptr_{0};
|
||||||
|
|
||||||
|
// Similar to SEEK_END in libc
|
||||||
|
static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit AlignedResourceReadStream(std::shared_ptr<ResourceHandler> resource)
|
||||||
|
: resource_{std::move(resource)} {}
|
||||||
|
|
||||||
|
[[nodiscard]] std::shared_ptr<ResourceHandler> Share() noexcept(true) { return resource_; }
|
||||||
|
/**
|
||||||
|
* @brief Consume n_bytes of data, no copying is performed.
|
||||||
|
*
|
||||||
|
* @return A pair with the beginning pointer and the number of available bytes, which
|
||||||
|
* may be smaller than requested.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] auto Consume(std::size_t n_bytes) noexcept(true) {
|
||||||
|
auto res_size = resource_->Size();
|
||||||
|
auto data = reinterpret_cast<std::byte*>(resource_->Data());
|
||||||
|
auto ptr = data + curr_ptr_;
|
||||||
|
|
||||||
|
// Move the cursor
|
||||||
|
auto aligned_n_bytes = DivRoundUp(n_bytes, IOAlignment()) * IOAlignment();
|
||||||
|
auto aligned_forward = std::min(res_size - curr_ptr_, aligned_n_bytes);
|
||||||
|
std::size_t forward = std::min(res_size - curr_ptr_, n_bytes);
|
||||||
|
|
||||||
|
curr_ptr_ += aligned_forward;
|
||||||
|
|
||||||
|
return std::pair{ptr, forward};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] auto Consume(T* out) noexcept(false) -> std::enable_if_t<std::is_pod_v<T>, bool> {
|
||||||
|
auto [ptr, size] = this->Consume(sizeof(T));
|
||||||
|
if (size != sizeof(T)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
CHECK_EQ(reinterpret_cast<std::uintptr_t>(ptr) % std::alignment_of_v<T>, 0);
|
||||||
|
*out = *reinterpret_cast<T*>(ptr);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] virtual std::size_t Tell() noexcept(true) { return curr_ptr_; }
|
||||||
|
/**
|
||||||
|
* @brief Read n_bytes of data, output is copied into ptr.
|
||||||
|
*/
|
||||||
|
[[nodiscard]] std::size_t Read(void* ptr, std::size_t n_bytes) noexcept(true) {
|
||||||
|
auto [res_ptr, forward] = this->Consume(n_bytes);
|
||||||
|
if (forward != 0) {
|
||||||
|
std::memcpy(ptr, res_ptr, forward);
|
||||||
|
}
|
||||||
|
return forward;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief Read a primitive type.
|
||||||
|
*
|
||||||
|
* @return Whether the read is successful.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] auto Read(T* out) noexcept(false) -> std::enable_if_t<std::is_pod_v<T>, bool> {
|
||||||
|
return this->Consume(out);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief Read a vector.
|
||||||
|
*
|
||||||
|
* @return Whether the read is successful.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] bool Read(std::vector<T>* out) noexcept(true) {
|
||||||
|
std::uint64_t n{0};
|
||||||
|
if (!this->Consume(&n)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
out->resize(n);
|
||||||
|
|
||||||
|
auto n_bytes = sizeof(T) * n;
|
||||||
|
if (this->Read(out->data(), n_bytes) != n_bytes) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~AlignedResourceReadStream() noexcept(false);
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Private mmap file as a read-only stream.
|
* @brief Private mmap file as a read-only stream.
|
||||||
*
|
*
|
||||||
* It can calculate alignment automatically based on system page size (or allocation
|
* It can calculate alignment automatically based on system page size (or allocation
|
||||||
* granularity on Windows).
|
* granularity on Windows).
|
||||||
|
*
|
||||||
|
* The file is required to be aligned by IOAlignment().
|
||||||
*/
|
*/
|
||||||
class PrivateMmapConstStream : public MemoryFixSizeBuffer {
|
class PrivateMmapConstStream : public AlignedResourceReadStream {
|
||||||
struct MMAPFile;
|
|
||||||
std::unique_ptr<MMAPFile> handle_;
|
|
||||||
|
|
||||||
char* Open(std::string path, std::size_t offset, std::size_t length);
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Construct a private mmap stream.
|
* @brief Construct a private mmap stream.
|
||||||
@ -149,11 +383,71 @@ class PrivateMmapConstStream : public MemoryFixSizeBuffer {
|
|||||||
* @param offset See the `offset` parameter of `mmap` for details.
|
* @param offset See the `offset` parameter of `mmap` for details.
|
||||||
* @param length See the `length` parameter of `mmap` for details.
|
* @param length See the `length` parameter of `mmap` for details.
|
||||||
*/
|
*/
|
||||||
explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length);
|
explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length)
|
||||||
void Write(void const*, std::size_t) override { LOG(FATAL) << "Read-only stream."; }
|
: AlignedResourceReadStream{std::make_shared<MmapResource>(path, offset, length)} {}
|
||||||
|
~PrivateMmapConstStream() noexcept(false) override;
|
||||||
~PrivateMmapConstStream() override;
|
|
||||||
};
|
};
|
||||||
} // namespace common
|
|
||||||
} // namespace xgboost
|
/**
|
||||||
|
* @brief Base class for write stream with alignment defined by IOAlignment().
|
||||||
|
*/
|
||||||
|
class AlignedWriteStream {
|
||||||
|
protected:
|
||||||
|
[[nodiscard]] virtual std::size_t DoWrite(const void* ptr,
|
||||||
|
std::size_t n_bytes) noexcept(true) = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual ~AlignedWriteStream() = default;
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t Write(const void* ptr, std::size_t n_bytes) noexcept(false) {
|
||||||
|
auto aligned_n_bytes = DivRoundUp(n_bytes, IOAlignment()) * IOAlignment();
|
||||||
|
auto w_n_bytes = this->DoWrite(ptr, n_bytes);
|
||||||
|
CHECK_EQ(w_n_bytes, n_bytes);
|
||||||
|
auto remaining = aligned_n_bytes - n_bytes;
|
||||||
|
if (remaining > 0) {
|
||||||
|
std::array<std::uint8_t, IOAlignment()> padding;
|
||||||
|
std::memset(padding.data(), '\0', padding.size());
|
||||||
|
w_n_bytes = this->DoWrite(padding.data(), remaining);
|
||||||
|
CHECK_EQ(w_n_bytes, remaining);
|
||||||
|
}
|
||||||
|
return aligned_n_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] std::enable_if_t<std::is_pod_v<T>, std::size_t> Write(T const& v) {
|
||||||
|
return this->Write(&v, sizeof(T));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output stream backed by a file. Aligned to IOAlignment() bytes.
|
||||||
|
*/
|
||||||
|
class AlignedFileWriteStream : public AlignedWriteStream {
|
||||||
|
std::unique_ptr<dmlc::Stream> pimpl_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
[[nodiscard]] std::size_t DoWrite(const void* ptr, std::size_t n_bytes) noexcept(true) override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AlignedFileWriteStream() = default;
|
||||||
|
AlignedFileWriteStream(StringView path, StringView flags);
|
||||||
|
~AlignedFileWriteStream() override = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output stream backed by memory buffer. Aligned to IOAlignment() bytes.
|
||||||
|
*/
|
||||||
|
class AlignedMemWriteStream : public AlignedFileWriteStream {
|
||||||
|
std::unique_ptr<MemoryBufferStream> pimpl_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
[[nodiscard]] std::size_t DoWrite(const void* ptr, std::size_t n_bytes) noexcept(true) override;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit AlignedMemWriteStream(std::string* p_buf);
|
||||||
|
~AlignedMemWriteStream() override;
|
||||||
|
|
||||||
|
[[nodiscard]] std::size_t Tell() const noexcept(true);
|
||||||
|
};
|
||||||
|
} // namespace xgboost::common
|
||||||
#endif // XGBOOST_COMMON_IO_H_
|
#endif // XGBOOST_COMMON_IO_H_
|
||||||
|
|||||||
158
src/common/ref_resource_view.h
Normal file
158
src/common/ref_resource_view.h
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
|
||||||
|
#define XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
|
||||||
|
|
||||||
|
#include <algorithm> // for fill_n
|
||||||
|
#include <cstdint> // for uint64_t
|
||||||
|
#include <cstring> // for memcpy
|
||||||
|
#include <memory> // for shared_ptr, make_shared
|
||||||
|
#include <type_traits> // for is_reference_v, remove_reference_t, is_same_v
|
||||||
|
#include <utility> // for swap, move
|
||||||
|
|
||||||
|
#include "io.h" // for ResourceHandler, AlignedResourceReadStream, MallocResource
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
#include "xgboost/span.h" // for Span
|
||||||
|
|
||||||
|
namespace xgboost::common {
|
||||||
|
/**
|
||||||
|
* @brief A vector-like type that holds a reference counted resource.
|
||||||
|
*
|
||||||
|
* The vector size is immutable after construction. This way we can swap the underlying
|
||||||
|
* resource when needed.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
class RefResourceView {
|
||||||
|
static_assert(!std::is_reference_v<T>);
|
||||||
|
|
||||||
|
public:
|
||||||
|
using value_type = T; // NOLINT
|
||||||
|
using size_type = std::uint64_t; // NOLINT
|
||||||
|
|
||||||
|
private:
|
||||||
|
value_type* ptr_{nullptr};
|
||||||
|
size_type size_{0};
|
||||||
|
std::shared_ptr<common::ResourceHandler> mem_{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
RefResourceView(value_type* ptr, size_type n, std::shared_ptr<common::ResourceHandler> mem)
|
||||||
|
: ptr_{ptr}, size_{n}, mem_{std::move(mem)} {
|
||||||
|
CHECK_GE(mem_->Size(), n);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @brief Construct a view on ptr with length n. The ptr is held by the mem resource.
|
||||||
|
*
|
||||||
|
* @param ptr The pointer to view.
|
||||||
|
* @param n The length of the view.
|
||||||
|
* @param mem The owner of the pointer.
|
||||||
|
* @param init Initialize the view with this value.
|
||||||
|
*/
|
||||||
|
RefResourceView(value_type* ptr, size_type n, std::shared_ptr<common::ResourceHandler> mem,
|
||||||
|
T const& init)
|
||||||
|
: RefResourceView{ptr, n, mem} {
|
||||||
|
if (n != 0) {
|
||||||
|
std::fill_n(ptr_, n, init);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~RefResourceView() = default;
|
||||||
|
|
||||||
|
RefResourceView() = default;
|
||||||
|
RefResourceView(RefResourceView const& that) = delete;
|
||||||
|
RefResourceView(RefResourceView&& that) = delete;
|
||||||
|
RefResourceView& operator=(RefResourceView const& that) = delete;
|
||||||
|
/**
|
||||||
|
* @brief We allow move assignment for lazy initialization.
|
||||||
|
*/
|
||||||
|
RefResourceView& operator=(RefResourceView&& that) = default;
|
||||||
|
|
||||||
|
[[nodiscard]] size_type size() const { return size_; } // NOLINT
|
||||||
|
[[nodiscard]] size_type size_bytes() const { // NOLINT
|
||||||
|
return Span{data(), size()}.size_bytes();
|
||||||
|
}
|
||||||
|
[[nodiscard]] value_type* data() { return ptr_; }; // NOLINT
|
||||||
|
[[nodiscard]] value_type const* data() const { return ptr_; }; // NOLINT
|
||||||
|
[[nodiscard]] bool empty() const { return size() == 0; } // NOLINT
|
||||||
|
|
||||||
|
[[nodiscard]] auto cbegin() const { return data(); } // NOLINT
|
||||||
|
[[nodiscard]] auto begin() { return data(); } // NOLINT
|
||||||
|
[[nodiscard]] auto begin() const { return cbegin(); } // NOLINT
|
||||||
|
[[nodiscard]] auto cend() const { return data() + size(); } // NOLINT
|
||||||
|
[[nodiscard]] auto end() { return data() + size(); } // NOLINT
|
||||||
|
[[nodiscard]] auto end() const { return cend(); } // NOLINT
|
||||||
|
|
||||||
|
[[nodiscard]] auto const& front() const { return data()[0]; } // NOLINT
|
||||||
|
[[nodiscard]] auto& front() { return data()[0]; } // NOLINT
|
||||||
|
[[nodiscard]] auto const& back() const { return data()[size() - 1]; } // NOLINT
|
||||||
|
[[nodiscard]] auto& back() { return data()[size() - 1]; } // NOLINT
|
||||||
|
|
||||||
|
[[nodiscard]] value_type& operator[](size_type i) { return ptr_[i]; }
|
||||||
|
[[nodiscard]] value_type const& operator[](size_type i) const { return ptr_[i]; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the underlying resource.
|
||||||
|
*/
|
||||||
|
auto Resource() const { return mem_; }
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Read a vector from stream. Accepts both `std::vector` and `RefResourceView`.
|
||||||
|
*
|
||||||
|
* If the output vector is a referenced counted view, no copying occur.
|
||||||
|
*/
|
||||||
|
template <typename Vec>
|
||||||
|
[[nodiscard]] bool ReadVec(common::AlignedResourceReadStream* fi, Vec* vec) {
|
||||||
|
std::uint64_t n{0};
|
||||||
|
if (!fi->Read(&n)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (n == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
using T = typename Vec::value_type;
|
||||||
|
auto expected_bytes = sizeof(T) * n;
|
||||||
|
|
||||||
|
auto [ptr, n_bytes] = fi->Consume(expected_bytes);
|
||||||
|
if (n_bytes != expected_bytes) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (std::is_same_v<Vec, RefResourceView<T>>) {
|
||||||
|
*vec = RefResourceView<T>{reinterpret_cast<T*>(ptr), n, fi->Share()};
|
||||||
|
} else {
|
||||||
|
vec->resize(n);
|
||||||
|
std::memcpy(vec->data(), ptr, n_bytes);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Write a vector to stream. Accepts both `std::vector` and `RefResourceView`.
|
||||||
|
*/
|
||||||
|
template <typename Vec>
|
||||||
|
[[nodiscard]] std::size_t WriteVec(AlignedFileWriteStream* fo, Vec const& vec) {
|
||||||
|
std::size_t bytes{0};
|
||||||
|
auto n = static_cast<std::uint64_t>(vec.size());
|
||||||
|
bytes += fo->Write(n);
|
||||||
|
if (n == 0) {
|
||||||
|
return sizeof(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
|
||||||
|
bytes += fo->Write(vec.data(), vec.size() * sizeof(T));
|
||||||
|
|
||||||
|
return bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Make a fixed size `RefResourceView` with malloc resource.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[nodiscard]] RefResourceView<T> MakeFixedVecWithMalloc(std::size_t n_elements, T const& init) {
|
||||||
|
auto resource = std::make_shared<common::MallocResource>(n_elements * sizeof(T));
|
||||||
|
return RefResourceView{resource->DataAs<T>(), n_elements, resource, init};
|
||||||
|
}
|
||||||
|
} // namespace xgboost::common
|
||||||
|
#endif // XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
|
||||||
@ -1,60 +1,59 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019-2021 XGBoost contributors
|
* Copyright 2019-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <xgboost/data.h>
|
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
|
||||||
|
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
|
||||||
|
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
|
||||||
#include "ellpack_page.cuh"
|
#include "ellpack_page.cuh"
|
||||||
#include "sparse_page_writer.h"
|
#include "histogram_cut_format.h" // for ReadHistogramCuts, WriteHistogramCuts
|
||||||
#include "histogram_cut_format.h"
|
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||||
|
|
||||||
namespace xgboost {
|
|
||||||
namespace data {
|
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
|
DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
|
||||||
|
|
||||||
|
|
||||||
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
|
||||||
public:
|
public:
|
||||||
bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
|
bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override {
|
||||||
auto* impl = page->Impl();
|
auto* impl = page->Impl();
|
||||||
if (!ReadHistogramCuts(&impl->Cuts(), fi)) {
|
if (!ReadHistogramCuts(&impl->Cuts(), fi)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
fi->Read(&impl->n_rows);
|
if (!fi->Read(&impl->n_rows)) {
|
||||||
fi->Read(&impl->is_dense);
|
return false;
|
||||||
fi->Read(&impl->row_stride);
|
}
|
||||||
fi->Read(&impl->gidx_buffer.HostVector());
|
if (!fi->Read(&impl->is_dense)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!fi->Read(&impl->row_stride)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!common::ReadVec(fi, &impl->gidx_buffer.HostVector())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
if (!fi->Read(&impl->base_rowid)) {
|
if (!fi->Read(&impl->base_rowid)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Write(const EllpackPage& page, dmlc::Stream* fo) override {
|
size_t Write(const EllpackPage& page, common::AlignedFileWriteStream* fo) override {
|
||||||
size_t bytes = 0;
|
std::size_t bytes{0};
|
||||||
auto* impl = page.Impl();
|
auto* impl = page.Impl();
|
||||||
bytes += WriteHistogramCuts(impl->Cuts(), fo);
|
bytes += WriteHistogramCuts(impl->Cuts(), fo);
|
||||||
fo->Write(impl->n_rows);
|
bytes += fo->Write(impl->n_rows);
|
||||||
bytes += sizeof(impl->n_rows);
|
bytes += fo->Write(impl->is_dense);
|
||||||
fo->Write(impl->is_dense);
|
bytes += fo->Write(impl->row_stride);
|
||||||
bytes += sizeof(impl->is_dense);
|
|
||||||
fo->Write(impl->row_stride);
|
|
||||||
bytes += sizeof(impl->row_stride);
|
|
||||||
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
|
CHECK(!impl->gidx_buffer.ConstHostVector().empty());
|
||||||
fo->Write(impl->gidx_buffer.HostVector());
|
bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector());
|
||||||
bytes += impl->gidx_buffer.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
bytes += fo->Write(impl->base_rowid);
|
||||||
fo->Write(impl->base_rowid);
|
|
||||||
bytes += sizeof(impl->base_rowid);
|
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw)
|
XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(raw)
|
||||||
.describe("Raw ELLPACK binary data format.")
|
.describe("Raw ELLPACK binary data format.")
|
||||||
.set_body([]() {
|
.set_body([]() { return new EllpackPageRawFormat(); });
|
||||||
return new EllpackPageRawFormat();
|
} // namespace xgboost::data
|
||||||
});
|
|
||||||
|
|
||||||
} // namespace data
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_
|
|||||||
cut = common::SketchOnDMatrix(ctx, p_fmat, max_bins_per_feat, sorted_sketch, hess);
|
cut = common::SketchOnDMatrix(ctx, p_fmat, max_bins_per_feat, sorted_sketch, hess);
|
||||||
|
|
||||||
const uint32_t nbins = cut.Ptrs().back();
|
const uint32_t nbins = cut.Ptrs().back();
|
||||||
hit_count.resize(nbins, 0);
|
hit_count = common::MakeFixedVecWithMalloc(nbins, std::size_t{0});
|
||||||
hit_count_tloc_.resize(ctx->Threads() * nbins, 0);
|
hit_count_tloc_.resize(ctx->Threads() * nbins, 0);
|
||||||
|
|
||||||
size_t new_size = 1;
|
size_t new_size = 1;
|
||||||
@ -37,8 +37,7 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_
|
|||||||
new_size += batch.Size();
|
new_size += batch.Size();
|
||||||
}
|
}
|
||||||
|
|
||||||
row_ptr.resize(new_size);
|
row_ptr = common::MakeFixedVecWithMalloc(new_size, std::size_t{0});
|
||||||
row_ptr[0] = 0;
|
|
||||||
|
|
||||||
const bool isDense = p_fmat->IsDense();
|
const bool isDense = p_fmat->IsDense();
|
||||||
this->isDense_ = isDense;
|
this->isDense_ = isDense;
|
||||||
@ -61,8 +60,8 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_
|
|||||||
|
|
||||||
GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &&cuts,
|
GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &&cuts,
|
||||||
bst_bin_t max_bin_per_feat)
|
bst_bin_t max_bin_per_feat)
|
||||||
: row_ptr(info.num_row_ + 1, 0),
|
: row_ptr{common::MakeFixedVecWithMalloc(info.num_row_ + 1, std::size_t{0})},
|
||||||
hit_count(cuts.TotalBins(), 0),
|
hit_count{common::MakeFixedVecWithMalloc(cuts.TotalBins(), std::size_t{0})},
|
||||||
cut{std::forward<common::HistogramCuts>(cuts)},
|
cut{std::forward<common::HistogramCuts>(cuts)},
|
||||||
max_numeric_bins_per_feat(max_bin_per_feat),
|
max_numeric_bins_per_feat(max_bin_per_feat),
|
||||||
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
|
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
|
||||||
@ -95,12 +94,10 @@ GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<Feature
|
|||||||
isDense_{isDense} {
|
isDense_{isDense} {
|
||||||
CHECK_GE(n_threads, 1);
|
CHECK_GE(n_threads, 1);
|
||||||
CHECK_EQ(row_ptr.size(), 0);
|
CHECK_EQ(row_ptr.size(), 0);
|
||||||
// The number of threads is pegged to the batch size. If the OMP
|
row_ptr = common::MakeFixedVecWithMalloc(batch.Size() + 1, std::size_t{0});
|
||||||
// block is parallelized on anything other than the batch/block size,
|
|
||||||
// it should be reassigned
|
|
||||||
row_ptr.resize(batch.Size() + 1, 0);
|
|
||||||
const uint32_t nbins = cut.Ptrs().back();
|
const uint32_t nbins = cut.Ptrs().back();
|
||||||
hit_count.resize(nbins, 0);
|
hit_count = common::MakeFixedVecWithMalloc(nbins, std::size_t{0});
|
||||||
hit_count_tloc_.resize(n_threads * nbins, 0);
|
hit_count_tloc_.resize(n_threads * nbins, 0);
|
||||||
|
|
||||||
this->PushBatch(batch, ft, n_threads);
|
this->PushBatch(batch, ft, n_threads);
|
||||||
@ -128,20 +125,45 @@ INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
|
|||||||
#undef INSTANTIATION_PUSH
|
#undef INSTANTIATION_PUSH
|
||||||
|
|
||||||
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
|
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
|
||||||
|
auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
|
||||||
|
// Must resize instead of allocating a new one. This function is called everytime a
|
||||||
|
// new batch is pushed, and we grow the size accordingly without loosing the data the
|
||||||
|
// previous batches.
|
||||||
|
using T = decltype(t);
|
||||||
|
std::size_t n_bytes = sizeof(T) * n_index;
|
||||||
|
CHECK_GE(n_bytes, this->data.size());
|
||||||
|
|
||||||
|
auto resource = this->data.Resource();
|
||||||
|
decltype(this->data) new_vec;
|
||||||
|
if (!resource) {
|
||||||
|
CHECK(this->data.empty());
|
||||||
|
new_vec = common::MakeFixedVecWithMalloc(n_bytes, std::uint8_t{0});
|
||||||
|
} else {
|
||||||
|
CHECK(resource->Type() == common::ResourceHandler::kMalloc);
|
||||||
|
auto malloc_resource = std::dynamic_pointer_cast<common::MallocResource>(resource);
|
||||||
|
CHECK(malloc_resource);
|
||||||
|
malloc_resource->Resize(n_bytes);
|
||||||
|
|
||||||
|
// gcc-11.3 doesn't work if DataAs is used.
|
||||||
|
std::uint8_t *new_ptr = reinterpret_cast<std::uint8_t *>(malloc_resource->Data());
|
||||||
|
new_vec = {new_ptr, n_bytes / sizeof(std::uint8_t), malloc_resource};
|
||||||
|
}
|
||||||
|
this->data = std::move(new_vec);
|
||||||
|
this->index = common::Index{common::Span{data.data(), data.size()}, t_size};
|
||||||
|
};
|
||||||
|
|
||||||
if ((MaxNumBinPerFeat() - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) &&
|
if ((MaxNumBinPerFeat() - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) &&
|
||||||
isDense) {
|
isDense) {
|
||||||
// compress dense index to uint8
|
// compress dense index to uint8
|
||||||
index.SetBinTypeSize(common::kUint8BinsTypeSize);
|
make_index(std::uint8_t{}, common::kUint8BinsTypeSize);
|
||||||
index.Resize((sizeof(uint8_t)) * n_index);
|
|
||||||
} else if ((MaxNumBinPerFeat() - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
|
} else if ((MaxNumBinPerFeat() - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
|
||||||
MaxNumBinPerFeat() - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) &&
|
MaxNumBinPerFeat() - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) &&
|
||||||
isDense) {
|
isDense) {
|
||||||
// compress dense index to uint16
|
// compress dense index to uint16
|
||||||
index.SetBinTypeSize(common::kUint16BinsTypeSize);
|
make_index(std::uint16_t{}, common::kUint16BinsTypeSize);
|
||||||
index.Resize((sizeof(uint16_t)) * n_index);
|
|
||||||
} else {
|
} else {
|
||||||
index.SetBinTypeSize(common::kUint32BinsTypeSize);
|
// no compression
|
||||||
index.Resize((sizeof(uint32_t)) * n_index);
|
make_index(std::uint32_t{}, common::kUint32BinsTypeSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,11 +236,11 @@ float GHistIndexMatrix::GetFvalue(std::vector<std::uint32_t> const &ptrs,
|
|||||||
return std::numeric_limits<float>::quiet_NaN();
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
|
bool GHistIndexMatrix::ReadColumnPage(common::AlignedResourceReadStream *fi) {
|
||||||
return this->columns_->Read(fi, this->cut.Ptrs().data());
|
return this->columns_->Read(fi, this->cut.Ptrs().data());
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const {
|
std::size_t GHistIndexMatrix::WriteColumnPage(common::AlignedFileWriteStream *fo) const {
|
||||||
return this->columns_->Write(fo);
|
return this->columns_->Write(fo);
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2022 by XGBoost Contributors
|
* Copyright 2022-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <memory> // std::unique_ptr
|
#include <memory> // std::unique_ptr
|
||||||
|
|
||||||
@ -41,9 +41,9 @@ void SetIndexData(Context const* ctx, EllpackPageImpl const* page,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GetRowPtrFromEllpack(Context const* ctx, EllpackPageImpl const* page,
|
void GetRowPtrFromEllpack(Context const* ctx, EllpackPageImpl const* page,
|
||||||
std::vector<size_t>* p_out) {
|
common::RefResourceView<std::size_t>* p_out) {
|
||||||
auto& row_ptr = *p_out;
|
auto& row_ptr = *p_out;
|
||||||
row_ptr.resize(page->Size() + 1, 0);
|
row_ptr = common::MakeFixedVecWithMalloc(page->Size() + 1, std::size_t{0});
|
||||||
if (page->is_dense) {
|
if (page->is_dense) {
|
||||||
std::fill(row_ptr.begin() + 1, row_ptr.end(), page->row_stride);
|
std::fill(row_ptr.begin() + 1, row_ptr.end(), page->row_stride);
|
||||||
} else {
|
} else {
|
||||||
@ -95,7 +95,7 @@ GHistIndexMatrix::GHistIndexMatrix(Context const* ctx, MetaInfo const& info,
|
|||||||
ctx, page, &hit_count_tloc_, [&](auto bin_idx, auto) { return bin_idx; }, this);
|
ctx, page, &hit_count_tloc_, [&](auto bin_idx, auto) { return bin_idx; }, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
this->hit_count.resize(n_bins_total, 0);
|
this->hit_count = common::MakeFixedVecWithMalloc(n_bins_total, std::size_t{0});
|
||||||
this->GatherHitCount(ctx->Threads(), n_bins_total);
|
this->GatherHitCount(ctx->Threads(), n_bins_total);
|
||||||
|
|
||||||
// sanity checks
|
// sanity checks
|
||||||
|
|||||||
@ -9,13 +9,14 @@
|
|||||||
#include <atomic> // for atomic
|
#include <atomic> // for atomic
|
||||||
#include <cinttypes> // for uint32_t
|
#include <cinttypes> // for uint32_t
|
||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <memory>
|
#include <memory> // for make_unique
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../common/categorical.h"
|
#include "../common/categorical.h"
|
||||||
#include "../common/error_msg.h" // for InfInData
|
#include "../common/error_msg.h" // for InfInData
|
||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../common/numeric.h"
|
#include "../common/numeric.h"
|
||||||
|
#include "../common/ref_resource_view.h" // for RefResourceView
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
@ -25,9 +26,11 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
class ColumnMatrix;
|
class ColumnMatrix;
|
||||||
|
class AlignedFileWriteStream;
|
||||||
} // namespace common
|
} // namespace common
|
||||||
/*!
|
|
||||||
* \brief preprocessed global index matrix, in CSR format
|
/**
|
||||||
|
* @brief preprocessed global index matrix, in CSR format.
|
||||||
*
|
*
|
||||||
* Transform floating values to integer index in histogram This is a global histogram
|
* Transform floating values to integer index in histogram This is a global histogram
|
||||||
* index for CPU histogram. On GPU ellpack page is used.
|
* index for CPU histogram. On GPU ellpack page is used.
|
||||||
@ -133,20 +136,22 @@ class GHistIndexMatrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/*! \brief row pointer to rows by element position */
|
/** @brief row pointer to rows by element position */
|
||||||
std::vector<size_t> row_ptr;
|
common::RefResourceView<std::size_t> row_ptr;
|
||||||
/*! \brief The index data */
|
/** @brief data storage for index. */
|
||||||
|
common::RefResourceView<std::uint8_t> data;
|
||||||
|
/** @brief The histogram index. */
|
||||||
common::Index index;
|
common::Index index;
|
||||||
/*! \brief hit count of each index, used for constructing the ColumnMatrix */
|
/** @brief hit count of each index, used for constructing the ColumnMatrix */
|
||||||
std::vector<size_t> hit_count;
|
common::RefResourceView<std::size_t> hit_count;
|
||||||
/*! \brief The corresponding cuts */
|
/** @brief The corresponding cuts */
|
||||||
common::HistogramCuts cut;
|
common::HistogramCuts cut;
|
||||||
/** \brief max_bin for each feature. */
|
/** @brief max_bin for each feature. */
|
||||||
bst_bin_t max_numeric_bins_per_feat;
|
bst_bin_t max_numeric_bins_per_feat;
|
||||||
/*! \brief base row index for current page (used by external memory) */
|
/** @brief base row index for current page (used by external memory) */
|
||||||
size_t base_rowid{0};
|
bst_row_t base_rowid{0};
|
||||||
|
|
||||||
bst_bin_t MaxNumBinPerFeat() const {
|
[[nodiscard]] bst_bin_t MaxNumBinPerFeat() const {
|
||||||
return std::max(static_cast<bst_bin_t>(cut.MaxCategory() + 1), max_numeric_bins_per_feat);
|
return std::max(static_cast<bst_bin_t>(cut.MaxCategory() + 1), max_numeric_bins_per_feat);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,29 +223,27 @@ class GHistIndexMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsDense() const {
|
[[nodiscard]] bool IsDense() const { return isDense_; }
|
||||||
return isDense_;
|
|
||||||
}
|
|
||||||
void SetDense(bool is_dense) { isDense_ = is_dense; }
|
void SetDense(bool is_dense) { isDense_ = is_dense; }
|
||||||
/**
|
/**
|
||||||
* \brief Get the local row index.
|
* @brief Get the local row index.
|
||||||
*/
|
*/
|
||||||
size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }
|
[[nodiscard]] std::size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; }
|
||||||
|
|
||||||
bst_row_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
|
[[nodiscard]] bst_row_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; }
|
||||||
bst_feature_t Features() const { return cut.Ptrs().size() - 1; }
|
[[nodiscard]] bst_feature_t Features() const { return cut.Ptrs().size() - 1; }
|
||||||
|
|
||||||
bool ReadColumnPage(dmlc::SeekStream* fi);
|
[[nodiscard]] bool ReadColumnPage(common::AlignedResourceReadStream* fi);
|
||||||
size_t WriteColumnPage(dmlc::Stream* fo) const;
|
[[nodiscard]] std::size_t WriteColumnPage(common::AlignedFileWriteStream* fo) const;
|
||||||
|
|
||||||
common::ColumnMatrix const& Transpose() const;
|
[[nodiscard]] common::ColumnMatrix const& Transpose() const;
|
||||||
|
|
||||||
bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
|
[[nodiscard]] bst_bin_t GetGindex(size_t ridx, size_t fidx) const;
|
||||||
|
|
||||||
float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
|
[[nodiscard]] float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
|
||||||
float GetFvalue(std::vector<std::uint32_t> const& ptrs, std::vector<float> const& values,
|
[[nodiscard]] float GetFvalue(std::vector<std::uint32_t> const& ptrs,
|
||||||
std::vector<float> const& mins, bst_row_t ridx, bst_feature_t fidx,
|
std::vector<float> const& values, std::vector<float> const& mins,
|
||||||
bool is_cat) const;
|
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<common::ColumnMatrix> columns_;
|
std::unique_ptr<common::ColumnMatrix> columns_;
|
||||||
@ -294,5 +297,5 @@ void AssignColumnBinIndex(GHistIndexMatrix const& page, Fn&& assign) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
|
#endif // XGBOOST_DATA_GRADIENT_INDEX_H_
|
||||||
|
|||||||
@ -1,38 +1,49 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022 XGBoost contributors
|
* Copyright 2021-2023 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include "sparse_page_writer.h"
|
#include <cstddef> // for size_t
|
||||||
#include "gradient_index.h"
|
#include <cstdint> // for uint8_t
|
||||||
#include "histogram_cut_format.h"
|
#include <type_traits> // for underlying_type_t
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
namespace xgboost {
|
#include "../common/io.h" // for AlignedResourceReadStream
|
||||||
namespace data {
|
#include "../common/ref_resource_view.h" // for ReadVec, WriteVec
|
||||||
|
#include "gradient_index.h" // for GHistIndexMatrix
|
||||||
|
#include "histogram_cut_format.h" // for ReadHistogramCuts
|
||||||
|
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
||||||
public:
|
public:
|
||||||
bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
|
bool Read(GHistIndexMatrix* page, common::AlignedResourceReadStream* fi) override {
|
||||||
|
CHECK(fi);
|
||||||
|
|
||||||
if (!ReadHistogramCuts(&page->cut, fi)) {
|
if (!ReadHistogramCuts(&page->cut, fi)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// indptr
|
// indptr
|
||||||
fi->Read(&page->row_ptr);
|
if (!common::ReadVec(fi, &page->row_ptr)) {
|
||||||
// data
|
|
||||||
std::vector<uint8_t> data;
|
|
||||||
if (!fi->Read(&data)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
page->index.Resize(data.size());
|
|
||||||
std::copy(data.cbegin(), data.cend(), page->index.begin());
|
// data
|
||||||
// bin type
|
// - bin type
|
||||||
// Old gcc doesn't support reading from enum.
|
// Old gcc doesn't support reading from enum.
|
||||||
std::underlying_type_t<common::BinTypeSize> uint_bin_type{0};
|
std::underlying_type_t<common::BinTypeSize> uint_bin_type{0};
|
||||||
if (!fi->Read(&uint_bin_type)) {
|
if (!fi->Read(&uint_bin_type)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
common::BinTypeSize size_type =
|
common::BinTypeSize size_type = static_cast<common::BinTypeSize>(uint_bin_type);
|
||||||
static_cast<common::BinTypeSize>(uint_bin_type);
|
// - index buffer
|
||||||
page->index.SetBinTypeSize(size_type);
|
if (!common::ReadVec(fi, &page->data)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// - index
|
||||||
|
page->index = common::Index{common::Span{page->data.data(), page->data.size()}, size_type};
|
||||||
|
|
||||||
// hit count
|
// hit count
|
||||||
if (!fi->Read(&page->hit_count)) {
|
if (!common::ReadVec(fi, &page->hit_count)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!fi->Read(&page->max_numeric_bins_per_feat)) {
|
if (!fi->Read(&page->max_numeric_bins_per_feat)) {
|
||||||
@ -50,38 +61,33 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
|
|||||||
page->index.SetBinOffset(page->cut.Ptrs());
|
page->index.SetBinOffset(page->cut.Ptrs());
|
||||||
}
|
}
|
||||||
|
|
||||||
page->ReadColumnPage(fi);
|
if (!page->ReadColumnPage(fi)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Write(GHistIndexMatrix const &page, dmlc::Stream *fo) override {
|
std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
|
||||||
size_t bytes = 0;
|
std::size_t bytes = 0;
|
||||||
bytes += WriteHistogramCuts(page.cut, fo);
|
bytes += WriteHistogramCuts(page.cut, fo);
|
||||||
// indptr
|
// indptr
|
||||||
fo->Write(page.row_ptr);
|
bytes += common::WriteVec(fo, page.row_ptr);
|
||||||
bytes += page.row_ptr.size() * sizeof(decltype(page.row_ptr)::value_type) +
|
|
||||||
sizeof(uint64_t);
|
|
||||||
// data
|
// data
|
||||||
std::vector<uint8_t> data(page.index.begin(), page.index.end());
|
// - bin type
|
||||||
fo->Write(data);
|
std::underlying_type_t<common::BinTypeSize> uint_bin_type = page.index.GetBinTypeSize();
|
||||||
bytes += data.size() * sizeof(decltype(data)::value_type) + sizeof(uint64_t);
|
bytes += fo->Write(uint_bin_type);
|
||||||
// bin type
|
// - index buffer
|
||||||
std::underlying_type_t<common::BinTypeSize> uint_bin_type =
|
std::vector<std::uint8_t> data(page.index.begin(), page.index.end());
|
||||||
page.index.GetBinTypeSize();
|
bytes += fo->Write(static_cast<std::uint64_t>(data.size()));
|
||||||
fo->Write(uint_bin_type);
|
bytes += fo->Write(data.data(), data.size());
|
||||||
bytes += sizeof(page.index.GetBinTypeSize());
|
|
||||||
// hit count
|
// hit count
|
||||||
fo->Write(page.hit_count);
|
bytes += common::WriteVec(fo, page.hit_count);
|
||||||
bytes +=
|
|
||||||
page.hit_count.size() * sizeof(decltype(page.hit_count)::value_type) +
|
|
||||||
sizeof(uint64_t);
|
|
||||||
// max_bins, base row, is_dense
|
// max_bins, base row, is_dense
|
||||||
fo->Write(page.max_numeric_bins_per_feat);
|
bytes += fo->Write(page.max_numeric_bins_per_feat);
|
||||||
bytes += sizeof(page.max_numeric_bins_per_feat);
|
bytes += fo->Write(page.base_rowid);
|
||||||
fo->Write(page.base_rowid);
|
bytes += fo->Write(page.IsDense());
|
||||||
bytes += sizeof(page.base_rowid);
|
|
||||||
fo->Write(page.IsDense());
|
|
||||||
bytes += sizeof(page.IsDense());
|
|
||||||
|
|
||||||
bytes += page.WriteColumnPage(fo);
|
bytes += page.WriteColumnPage(fo);
|
||||||
return bytes;
|
return bytes;
|
||||||
@ -93,6 +99,4 @@ DMLC_REGISTRY_FILE_TAG(gradient_index_format);
|
|||||||
XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(raw)
|
XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(raw)
|
||||||
.describe("Raw GHistIndex binary data format.")
|
.describe("Raw GHistIndex binary data format.")
|
||||||
.set_body([]() { return new GHistIndexRawFormat(); });
|
.set_body([]() { return new GHistIndexRawFormat(); });
|
||||||
|
} // namespace xgboost::data
|
||||||
} // namespace data
|
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -1,36 +1,38 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021 XGBoost contributors
|
* Copyright 2021-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
||||||
#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
||||||
|
|
||||||
#include "../common/hist_util.h"
|
#include <dmlc/io.h> // for Stream
|
||||||
|
|
||||||
namespace xgboost {
|
#include <cstddef> // for size_t
|
||||||
namespace data {
|
|
||||||
inline bool ReadHistogramCuts(common::HistogramCuts *cuts, dmlc::SeekStream *fi) {
|
#include "../common/hist_util.h" // for HistogramCuts
|
||||||
if (!fi->Read(&cuts->cut_values_.HostVector())) {
|
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
|
||||||
|
#include "../common/ref_resource_view.h" // for WriteVec, ReadVec
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
|
inline bool ReadHistogramCuts(common::HistogramCuts *cuts, common::AlignedResourceReadStream *fi) {
|
||||||
|
if (!common::ReadVec(fi, &cuts->cut_values_.HostVector())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!fi->Read(&cuts->cut_ptrs_.HostVector())) {
|
if (!common::ReadVec(fi, &cuts->cut_ptrs_.HostVector())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!fi->Read(&cuts->min_vals_.HostVector())) {
|
if (!common::ReadVec(fi, &cuts->min_vals_.HostVector())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline size_t WriteHistogramCuts(common::HistogramCuts const &cuts, dmlc::Stream *fo) {
|
inline std::size_t WriteHistogramCuts(common::HistogramCuts const &cuts,
|
||||||
size_t bytes = 0;
|
common::AlignedFileWriteStream *fo) {
|
||||||
fo->Write(cuts.cut_values_.ConstHostVector());
|
std::size_t bytes = 0;
|
||||||
bytes += cuts.cut_values_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
bytes += common::WriteVec(fo, cuts.Values());
|
||||||
fo->Write(cuts.cut_ptrs_.ConstHostVector());
|
bytes += common::WriteVec(fo, cuts.Ptrs());
|
||||||
bytes += cuts.cut_ptrs_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
bytes += common::WriteVec(fo, cuts.MinValues());
|
||||||
fo->Write(cuts.min_vals_.ConstHostVector());
|
|
||||||
bytes += cuts.min_vals_.ConstHostSpan().size_bytes() + sizeof(uint64_t);
|
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
#endif // XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
|
||||||
|
|||||||
@ -240,9 +240,9 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
|
|||||||
* Generate gradient index.
|
* Generate gradient index.
|
||||||
*/
|
*/
|
||||||
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
|
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
|
||||||
size_t rbegin = 0;
|
std::size_t rbegin = 0;
|
||||||
size_t prev_sum = 0;
|
std::size_t prev_sum = 0;
|
||||||
size_t i = 0;
|
std::size_t i = 0;
|
||||||
while (iter.Next()) {
|
while (iter.Next()) {
|
||||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||||
|
|||||||
@ -1,59 +1,57 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2015-2021 by Contributors
|
* Copyright 2015-2023, XGBoost Contributors
|
||||||
* \file sparse_page_raw_format.cc
|
* \file sparse_page_raw_format.cc
|
||||||
* Raw binary format of sparse page.
|
* Raw binary format of sparse page.
|
||||||
*/
|
*/
|
||||||
#include <xgboost/data.h>
|
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
#include "xgboost/logging.h"
|
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
|
||||||
|
#include "../common/ref_resource_view.h" // for WriteVec
|
||||||
#include "./sparse_page_writer.h"
|
#include "./sparse_page_writer.h"
|
||||||
|
#include "xgboost/data.h"
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);
|
DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);
|
||||||
|
|
||||||
template<typename T>
|
template <typename T>
|
||||||
class SparsePageRawFormat : public SparsePageFormat<T> {
|
class SparsePageRawFormat : public SparsePageFormat<T> {
|
||||||
public:
|
public:
|
||||||
bool Read(T* page, dmlc::SeekStream* fi) override {
|
bool Read(T* page, common::AlignedResourceReadStream* fi) override {
|
||||||
auto& offset_vec = page->offset.HostVector();
|
auto& offset_vec = page->offset.HostVector();
|
||||||
if (!fi->Read(&offset_vec)) {
|
if (!common::ReadVec(fi, &offset_vec)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto& data_vec = page->data.HostVector();
|
auto& data_vec = page->data.HostVector();
|
||||||
CHECK_NE(page->offset.Size(), 0U) << "Invalid SparsePage file";
|
CHECK_NE(page->offset.Size(), 0U) << "Invalid SparsePage file";
|
||||||
data_vec.resize(offset_vec.back());
|
data_vec.resize(offset_vec.back());
|
||||||
if (page->data.Size() != 0) {
|
if (page->data.Size() != 0) {
|
||||||
size_t n_bytes = fi->Read(dmlc::BeginPtr(data_vec),
|
if (!common::ReadVec(fi, &data_vec)) {
|
||||||
(page->data).Size() * sizeof(Entry));
|
return false;
|
||||||
CHECK_EQ(n_bytes, (page->data).Size() * sizeof(Entry))
|
}
|
||||||
<< "Invalid SparsePage file";
|
}
|
||||||
|
if (!fi->Read(&page->base_rowid, sizeof(page->base_rowid))) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
fi->Read(&page->base_rowid, sizeof(page->base_rowid));
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Write(const T& page, dmlc::Stream* fo) override {
|
std::size_t Write(const T& page, common::AlignedFileWriteStream* fo) override {
|
||||||
const auto& offset_vec = page.offset.HostVector();
|
const auto& offset_vec = page.offset.HostVector();
|
||||||
const auto& data_vec = page.data.HostVector();
|
const auto& data_vec = page.data.HostVector();
|
||||||
CHECK(page.offset.Size() != 0 && offset_vec[0] == 0);
|
CHECK(page.offset.Size() != 0 && offset_vec[0] == 0);
|
||||||
CHECK_EQ(offset_vec.back(), page.data.Size());
|
CHECK_EQ(offset_vec.back(), page.data.Size());
|
||||||
fo->Write(offset_vec);
|
|
||||||
auto bytes = page.MemCostBytes();
|
std::size_t bytes{0};
|
||||||
bytes += sizeof(uint64_t);
|
bytes += common::WriteVec(fo, offset_vec);
|
||||||
if (page.data.Size() != 0) {
|
if (page.data.Size() != 0) {
|
||||||
fo->Write(dmlc::BeginPtr(data_vec), page.data.Size() * sizeof(Entry));
|
bytes += common::WriteVec(fo, data_vec);
|
||||||
}
|
}
|
||||||
fo->Write(&page.base_rowid, sizeof(page.base_rowid));
|
bytes += fo->Write(&page.base_rowid, sizeof(page.base_rowid));
|
||||||
bytes += sizeof(page.base_rowid);
|
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief external memory column offset */
|
|
||||||
std::vector<size_t> disk_offset_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
|
XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
|
||||||
@ -74,5 +72,4 @@ XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(raw)
|
|||||||
return new SparsePageRawFormat<SortedCSCPage>();
|
return new SparsePageRawFormat<SortedCSCPage>();
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -6,9 +6,11 @@
|
|||||||
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||||
|
|
||||||
#include <algorithm> // for min
|
#include <algorithm> // for min
|
||||||
|
#include <atomic> // for atomic
|
||||||
#include <future> // for async
|
#include <future> // for async
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <mutex> // for mutex
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <utility> // for pair, move
|
#include <utility> // for pair, move
|
||||||
@ -18,7 +20,6 @@
|
|||||||
#include "../common/io.h" // for PrivateMmapConstStream
|
#include "../common/io.h" // for PrivateMmapConstStream
|
||||||
#include "../common/timer.h" // for Monitor, Timer
|
#include "../common/timer.h" // for Monitor, Timer
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
#include "dmlc/common.h" // for OMPException
|
|
||||||
#include "proxy_dmatrix.h" // for DMatrixProxy
|
#include "proxy_dmatrix.h" // for DMatrixProxy
|
||||||
#include "sparse_page_writer.h" // for SparsePageFormat
|
#include "sparse_page_writer.h" // for SparsePageFormat
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
@ -93,6 +94,47 @@ class TryLockGuard {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Similar to `dmlc::OMPException`, but doesn't need the threads to be joined before rethrow
|
||||||
|
class ExceHandler {
|
||||||
|
std::mutex mutex_;
|
||||||
|
std::atomic<bool> flag_{false};
|
||||||
|
std::exception_ptr curr_exce_{nullptr};
|
||||||
|
|
||||||
|
public:
|
||||||
|
template <typename Fn>
|
||||||
|
decltype(auto) Run(Fn&& fn) noexcept(true) {
|
||||||
|
try {
|
||||||
|
return fn();
|
||||||
|
} catch (dmlc::Error const& e) {
|
||||||
|
std::lock_guard<std::mutex> guard{mutex_};
|
||||||
|
if (!curr_exce_) {
|
||||||
|
curr_exce_ = std::current_exception();
|
||||||
|
}
|
||||||
|
flag_ = true;
|
||||||
|
} catch (std::exception const& e) {
|
||||||
|
std::lock_guard<std::mutex> guard{mutex_};
|
||||||
|
if (!curr_exce_) {
|
||||||
|
curr_exce_ = std::current_exception();
|
||||||
|
}
|
||||||
|
flag_ = true;
|
||||||
|
} catch (...) {
|
||||||
|
std::lock_guard<std::mutex> guard{mutex_};
|
||||||
|
if (!curr_exce_) {
|
||||||
|
curr_exce_ = std::current_exception();
|
||||||
|
}
|
||||||
|
flag_ = true;
|
||||||
|
}
|
||||||
|
return std::invoke_result_t<Fn>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Rethrow() noexcept(false) {
|
||||||
|
if (flag_) {
|
||||||
|
CHECK(curr_exce_);
|
||||||
|
std::rethrow_exception(curr_exce_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
|
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
|
||||||
*/
|
*/
|
||||||
@ -122,7 +164,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
// Catching exception in pre-fetch threads to prevent segfault. Not always work though,
|
// Catching exception in pre-fetch threads to prevent segfault. Not always work though,
|
||||||
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
|
// OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then
|
||||||
// OOM error should be rare.
|
// OOM error should be rare.
|
||||||
dmlc::OMPException exec_;
|
ExceHandler exce_;
|
||||||
common::Monitor monitor_;
|
common::Monitor monitor_;
|
||||||
|
|
||||||
bool ReadCache() {
|
bool ReadCache() {
|
||||||
@ -141,7 +183,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
|
||||||
std::size_t fetch_it = count_;
|
std::size_t fetch_it = count_;
|
||||||
|
|
||||||
exec_.Rethrow();
|
exce_.Rethrow();
|
||||||
|
|
||||||
for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
|
||||||
fetch_it %= n_batches_; // ring
|
fetch_it %= n_batches_; // ring
|
||||||
@ -152,7 +194,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
CHECK_LT(fetch_it, cache_info_->offset.size());
|
CHECK_LT(fetch_it, cache_info_->offset.size());
|
||||||
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
|
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
|
||||||
auto page = std::make_shared<S>();
|
auto page = std::make_shared<S>();
|
||||||
this->exec_.Run([&] {
|
this->exce_.Run([&] {
|
||||||
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
||||||
auto name = self->cache_info_->ShardName();
|
auto name = self->cache_info_->ShardName();
|
||||||
auto [offset, length] = self->cache_info_->View(fetch_it);
|
auto [offset, length] = self->cache_info_->View(fetch_it);
|
||||||
@ -172,7 +214,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
CHECK(!(*ring_)[count_].valid());
|
CHECK(!(*ring_)[count_].valid());
|
||||||
monitor_.Stop("Wait");
|
monitor_.Stop("Wait");
|
||||||
|
|
||||||
exec_.Rethrow();
|
exce_.Rethrow();
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -184,11 +226,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
|
|||||||
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
|
||||||
|
|
||||||
auto name = cache_info_->ShardName();
|
auto name = cache_info_->ShardName();
|
||||||
std::unique_ptr<dmlc::Stream> fo;
|
std::unique_ptr<common::AlignedFileWriteStream> fo;
|
||||||
if (this->Iter() == 0) {
|
if (this->Iter() == 0) {
|
||||||
fo.reset(dmlc::Stream::Create(name.c_str(), "wb"));
|
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "wb");
|
||||||
} else {
|
} else {
|
||||||
fo.reset(dmlc::Stream::Create(name.c_str(), "ab"));
|
fo = std::make_unique<common::AlignedFileWriteStream>(StringView{name}, "ab");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto bytes = fmt->Write(*page_, fo.get());
|
auto bytes = fmt->Write(*page_, fo.get());
|
||||||
|
|||||||
@ -1,52 +1,44 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) 2014-2019 by Contributors
|
* Copyright 2014-2023, XGBoost Contributors
|
||||||
* \file sparse_page_writer.h
|
* \file sparse_page_writer.h
|
||||||
* \author Tianqi Chen
|
* \author Tianqi Chen
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
|
#ifndef XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
|
||||||
#define XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
|
#define XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
|
||||||
|
|
||||||
#include <xgboost/data.h>
|
#include <functional> // for function
|
||||||
#include <dmlc/io.h>
|
#include <string> // for string
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstring>
|
|
||||||
#include <string>
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
#include <functional>
|
|
||||||
|
|
||||||
#if DMLC_ENABLE_STD_THREAD
|
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
|
||||||
#include <dmlc/concurrency.h>
|
#include "dmlc/io.h" // for Stream
|
||||||
#include <thread>
|
#include "dmlc/registry.h" // for Registry, FunctionRegEntryBase
|
||||||
#endif // DMLC_ENABLE_STD_THREAD
|
#include "xgboost/data.h" // for SparsePage,CSCPage,SortedCSCPage,EllpackPage ...
|
||||||
|
|
||||||
namespace xgboost {
|
|
||||||
namespace data {
|
|
||||||
|
|
||||||
|
namespace xgboost::data {
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct SparsePageFormatReg;
|
struct SparsePageFormatReg;
|
||||||
|
|
||||||
/*!
|
/**
|
||||||
* \brief Format specification of SparsePage.
|
* @brief Format specification of various data formats like SparsePage.
|
||||||
*/
|
*/
|
||||||
template<typename T>
|
template <typename T>
|
||||||
class SparsePageFormat {
|
class SparsePageFormat {
|
||||||
public:
|
public:
|
||||||
/*! \brief virtual destructor */
|
|
||||||
virtual ~SparsePageFormat() = default;
|
virtual ~SparsePageFormat() = default;
|
||||||
/*!
|
/**
|
||||||
* \brief Load all the segments into page, advance fi to end of the block.
|
* @brief Load all the segments into page, advance fi to end of the block.
|
||||||
* \param page The data to read page into.
|
*
|
||||||
* \param fi the input stream of the file
|
* @param page The data to read page into.
|
||||||
* \return true of the loading as successful, false if end of file was reached
|
* @param fi the input stream of the file
|
||||||
|
* @return true of the loading as successful, false if end of file was reached
|
||||||
*/
|
*/
|
||||||
virtual bool Read(T* page, dmlc::SeekStream* fi) = 0;
|
virtual bool Read(T* page, common::AlignedResourceReadStream* fi) = 0;
|
||||||
/*!
|
/**
|
||||||
* \brief save the data to fo, when a page was written.
|
* @brief save the data to fo, when a page was written.
|
||||||
* \param fo output stream
|
*
|
||||||
|
* @param fo output stream
|
||||||
*/
|
*/
|
||||||
virtual size_t Write(const T& page, dmlc::Stream* fo) = 0;
|
virtual size_t Write(const T& page, common::AlignedFileWriteStream* fo) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
@ -105,6 +97,5 @@ struct SparsePageFormatReg
|
|||||||
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<GHistIndexMatrix>, \
|
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<GHistIndexMatrix>, \
|
||||||
GHistIndexPageFmt, Name)
|
GHistIndexPageFmt, Name)
|
||||||
|
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
|
#endif // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
|
||||||
|
|||||||
@ -634,6 +634,22 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
|||||||
return cpu_predictor_;
|
return cpu_predictor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Data comes from SparsePageDMatrix. Since we are loading data in pages, no need to
|
||||||
|
// prevent data copy.
|
||||||
|
if (f_dmat && !f_dmat->SingleColBlock()) {
|
||||||
|
if (ctx_->IsCPU()) {
|
||||||
|
return cpu_predictor_;
|
||||||
|
} else {
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
|
||||||
|
return gpu_predictor_;
|
||||||
|
#else
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
return cpu_predictor_;
|
||||||
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Data comes from Device DMatrix.
|
// Data comes from Device DMatrix.
|
||||||
auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
|
auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
|
||||||
!f_dmat->PageExists<SparsePage>();
|
!f_dmat->PageExists<SparsePage>();
|
||||||
|
|||||||
@ -3,11 +3,12 @@
|
|||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <cstddef> // for size_t
|
||||||
|
#include <fstream> // for ofstream
|
||||||
|
|
||||||
#include "../../../src/common/io.h"
|
#include "../../../src/common/io.h"
|
||||||
#include "../helpers.h"
|
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost::common {
|
namespace xgboost::common {
|
||||||
TEST(MemoryFixSizeBuffer, Seek) {
|
TEST(MemoryFixSizeBuffer, Seek) {
|
||||||
@ -89,6 +90,57 @@ TEST(IO, LoadSequentialFile) {
|
|||||||
ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
|
ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(IO, Resource) {
|
||||||
|
{
|
||||||
|
// test malloc basic
|
||||||
|
std::size_t n = 128;
|
||||||
|
std::shared_ptr<ResourceHandler> resource = std::make_shared<MallocResource>(n);
|
||||||
|
ASSERT_EQ(resource->Size(), n);
|
||||||
|
ASSERT_EQ(resource->Type(), ResourceHandler::kMalloc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// test malloc resize
|
||||||
|
auto test_malloc_resize = [](bool force_malloc) {
|
||||||
|
std::size_t n = 64;
|
||||||
|
std::shared_ptr<ResourceHandler> resource = std::make_shared<MallocResource>(n);
|
||||||
|
auto ptr = reinterpret_cast<std::uint8_t *>(resource->Data());
|
||||||
|
std::iota(ptr, ptr + n, 0);
|
||||||
|
|
||||||
|
auto malloc_resource = std::dynamic_pointer_cast<MallocResource>(resource);
|
||||||
|
ASSERT_TRUE(malloc_resource);
|
||||||
|
if (force_malloc) {
|
||||||
|
malloc_resource->Resize<true>(n * 2);
|
||||||
|
} else {
|
||||||
|
malloc_resource->Resize<false>(n * 2);
|
||||||
|
}
|
||||||
|
for (std::size_t i = 0; i < n; ++i) {
|
||||||
|
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], i) << force_malloc;
|
||||||
|
}
|
||||||
|
for (std::size_t i = n; i < 2 * n; ++i) {
|
||||||
|
ASSERT_EQ(malloc_resource->DataAs<std::uint8_t>()[i], 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
test_malloc_resize(true);
|
||||||
|
test_malloc_resize(false);
|
||||||
|
|
||||||
|
{
|
||||||
|
// test mmap
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
auto path = tmpdir.path + "/testfile";
|
||||||
|
|
||||||
|
std::ofstream fout(path, std::ios::binary);
|
||||||
|
double val{1.0};
|
||||||
|
fout.write(reinterpret_cast<char const *>(&val), sizeof(val));
|
||||||
|
fout << 1.0 << std::endl;
|
||||||
|
fout.close();
|
||||||
|
|
||||||
|
auto resource = std::make_shared<MmapResource>(path, 0, sizeof(double));
|
||||||
|
ASSERT_EQ(resource->Size(), sizeof(double));
|
||||||
|
ASSERT_EQ(resource->Type(), ResourceHandler::kMmap);
|
||||||
|
ASSERT_EQ(resource->DataAs<double>()[0], val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(IO, PrivateMmapStream) {
|
TEST(IO, PrivateMmapStream) {
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
auto path = tempdir.path + "/testfile";
|
auto path = tempdir.path + "/testfile";
|
||||||
@ -124,17 +176,35 @@ TEST(IO, PrivateMmapStream) {
|
|||||||
// Turn size info offset
|
// Turn size info offset
|
||||||
std::partial_sum(offset.begin(), offset.end(), offset.begin());
|
std::partial_sum(offset.begin(), offset.end(), offset.begin());
|
||||||
|
|
||||||
|
// Test read
|
||||||
for (std::size_t i = 0; i < n_batches; ++i) {
|
for (std::size_t i = 0; i < n_batches; ++i) {
|
||||||
std::size_t off = offset[i];
|
std::size_t off = offset[i];
|
||||||
std::size_t n = offset.at(i + 1) - offset[i];
|
std::size_t n = offset.at(i + 1) - offset[i];
|
||||||
std::unique_ptr<dmlc::Stream> fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
|
auto fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
|
||||||
std::vector<T> data;
|
std::vector<T> data;
|
||||||
|
|
||||||
std::uint64_t size{0};
|
std::uint64_t size{0};
|
||||||
fi->Read(&size);
|
ASSERT_TRUE(fi->Read(&size));
|
||||||
|
ASSERT_EQ(fi->Tell(), sizeof(size));
|
||||||
data.resize(size);
|
data.resize(size);
|
||||||
|
|
||||||
fi->Read(data.data(), size * sizeof(T));
|
ASSERT_EQ(fi->Read(data.data(), size * sizeof(T)), size * sizeof(T));
|
||||||
|
ASSERT_EQ(data, batches[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test consume
|
||||||
|
for (std::size_t i = 0; i < n_batches; ++i) {
|
||||||
|
std::size_t off = offset[i];
|
||||||
|
std::size_t n = offset.at(i + 1) - offset[i];
|
||||||
|
std::unique_ptr<AlignedResourceReadStream> fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
|
||||||
|
std::vector<T> data;
|
||||||
|
|
||||||
|
std::uint64_t size{0};
|
||||||
|
ASSERT_TRUE(fi->Consume(&size));
|
||||||
|
ASSERT_EQ(fi->Tell(), sizeof(size));
|
||||||
|
data.resize(size);
|
||||||
|
|
||||||
|
ASSERT_EQ(fi->Read(data.data(), size * sizeof(T)), sizeof(T) * size);
|
||||||
ASSERT_EQ(data, batches[i]);
|
ASSERT_EQ(data, batches[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
108
tests/cpp/common/test_ref_resource_view.cc
Normal file
108
tests/cpp/common/test_ref_resource_view.cc
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023, XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <memory> // for make_shared, make_unique
|
||||||
|
#include <numeric> // for iota
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
|
#include "../../../src/common/ref_resource_view.h"
|
||||||
|
#include "dmlc/filesystem.h" // for TemporaryDirectory
|
||||||
|
|
||||||
|
namespace xgboost::common {
|
||||||
|
TEST(RefResourceView, Basic) {
|
||||||
|
std::size_t n_bytes = 1024;
|
||||||
|
auto mem = std::make_shared<MallocResource>(n_bytes);
|
||||||
|
{
|
||||||
|
RefResourceView view{reinterpret_cast<float*>(mem->Data()), mem->Size() / sizeof(float), mem};
|
||||||
|
|
||||||
|
RefResourceView kview{reinterpret_cast<float const*>(mem->Data()), mem->Size() / sizeof(float),
|
||||||
|
mem};
|
||||||
|
ASSERT_EQ(mem.use_count(), 3);
|
||||||
|
ASSERT_EQ(view.size(), n_bytes / sizeof(1024));
|
||||||
|
ASSERT_EQ(kview.size(), n_bytes / sizeof(1024));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
RefResourceView view{reinterpret_cast<float*>(mem->Data()), mem->Size() / sizeof(float), mem,
|
||||||
|
1.5f};
|
||||||
|
for (auto v : view) {
|
||||||
|
ASSERT_EQ(v, 1.5f);
|
||||||
|
}
|
||||||
|
std::iota(view.begin(), view.end(), 0.0f);
|
||||||
|
ASSERT_EQ(view.front(), 0.0f);
|
||||||
|
ASSERT_EQ(view.back(), static_cast<float>(view.size() - 1));
|
||||||
|
|
||||||
|
view.front() = 1.0f;
|
||||||
|
view.back() = 2.0f;
|
||||||
|
ASSERT_EQ(view.front(), 1.0f);
|
||||||
|
ASSERT_EQ(view.back(), 2.0f);
|
||||||
|
}
|
||||||
|
ASSERT_EQ(mem.use_count(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefResourceView, IO) {
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
auto path = tmpdir.path + "/testfile";
|
||||||
|
auto data = MakeFixedVecWithMalloc(123, std::size_t{1});
|
||||||
|
|
||||||
|
{
|
||||||
|
auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
|
ASSERT_EQ(fo->Write(data.data(), data.size_bytes()), data.size_bytes());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
|
ASSERT_EQ(WriteVec(fo.get(), data),
|
||||||
|
data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto fi = std::make_unique<PrivateMmapConstStream>(
|
||||||
|
path, 0, data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type));
|
||||||
|
auto read = MakeFixedVecWithMalloc(123, std::size_t{1});
|
||||||
|
ASSERT_TRUE(ReadVec(fi.get(), &read));
|
||||||
|
for (auto v : read) {
|
||||||
|
ASSERT_EQ(v, 1ul);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RefResourceView, IOAligned) {
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
auto path = tmpdir.path + "/testfile";
|
||||||
|
auto data = MakeFixedVecWithMalloc(123, 1.0f);
|
||||||
|
|
||||||
|
{
|
||||||
|
auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
|
// + sizeof(float) for alignment
|
||||||
|
ASSERT_EQ(WriteVec(fo.get(), data),
|
||||||
|
data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type) + sizeof(float));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto fi = std::make_unique<PrivateMmapConstStream>(
|
||||||
|
path, 0, data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type));
|
||||||
|
// wrong type, float vs. double
|
||||||
|
auto read = MakeFixedVecWithMalloc(123, 2.0);
|
||||||
|
ASSERT_FALSE(ReadVec(fi.get(), &read));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto fi = std::make_unique<PrivateMmapConstStream>(
|
||||||
|
path, 0, data.size_bytes() + sizeof(RefResourceView<std::size_t>::size_type));
|
||||||
|
auto read = MakeFixedVecWithMalloc(123, 2.0f);
|
||||||
|
ASSERT_TRUE(ReadVec(fi.get(), &read));
|
||||||
|
for (auto v : read) {
|
||||||
|
ASSERT_EQ(v, 1ul);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Test std::vector
|
||||||
|
std::vector<float> data(123);
|
||||||
|
std::iota(data.begin(), data.end(), 0.0f);
|
||||||
|
auto fo = std::make_unique<AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
|
// + sizeof(float) for alignment
|
||||||
|
ASSERT_EQ(WriteVec(fo.get(), data), data.size() * sizeof(float) +
|
||||||
|
sizeof(RefResourceView<std::size_t>::size_type) +
|
||||||
|
sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace xgboost::common
|
||||||
@ -4,14 +4,14 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
|
#include "../../../src/common/io.h" // for PrivateMmapConstStream, AlignedResourceReadStream...
|
||||||
#include "../../../src/data/ellpack_page.cuh"
|
#include "../../../src/data/ellpack_page.cuh"
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include "../../../src/data/sparse_page_source.h"
|
||||||
#include "../../../src/tree/param.h" // TrainParam
|
#include "../../../src/tree/param.h" // TrainParam
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
TEST(EllpackPageRawFormat, IO) {
|
TEST(EllpackPageRawFormat, IO) {
|
||||||
Context ctx{MakeCUDACtx(0)};
|
Context ctx{MakeCUDACtx(0)};
|
||||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||||
@ -22,15 +22,17 @@ TEST(EllpackPageRawFormat, IO) {
|
|||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::string path = tmpdir.path + "/ellpack.page";
|
std::string path = tmpdir.path + "/ellpack.page";
|
||||||
|
|
||||||
|
std::size_t n_bytes{0};
|
||||||
{
|
{
|
||||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
auto fo = std::make_unique<common::AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
format->Write(ellpack, fo.get());
|
n_bytes += format->Write(ellpack, fo.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EllpackPage page;
|
EllpackPage page;
|
||||||
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
std::unique_ptr<common::AlignedResourceReadStream> fi{
|
||||||
|
std::make_unique<common::PrivateMmapConstStream>(path.c_str(), 0, n_bytes)};
|
||||||
format->Read(&page, fi.get());
|
format->Read(&page, fi.get());
|
||||||
|
|
||||||
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
for (auto const &ellpack : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||||
@ -44,5 +46,4 @@ TEST(EllpackPageRawFormat, IO) {
|
|||||||
ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector());
|
ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -26,8 +26,7 @@
|
|||||||
#include "xgboost/context.h" // for Context
|
#include "xgboost/context.h" // for Context
|
||||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
TEST(GradientIndex, ExternalMemory) {
|
TEST(GradientIndex, ExternalMemory) {
|
||||||
Context ctx;
|
Context ctx;
|
||||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
|
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
|
||||||
@ -171,7 +170,7 @@ class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, flo
|
|||||||
gpu_ctx.gpu_id = 0;
|
gpu_ctx.gpu_id = 0;
|
||||||
for (auto const &page : Xy->GetBatches<EllpackPage>(
|
for (auto const &page : Xy->GetBatches<EllpackPage>(
|
||||||
&gpu_ctx, BatchParam{kBins, tree::TrainParam::DftSparseThreshold()})) {
|
&gpu_ctx, BatchParam{kBins, tree::TrainParam::DftSparseThreshold()})) {
|
||||||
from_ellpack.reset(new GHistIndexMatrix{&ctx, Xy->Info(), page, p});
|
from_ellpack = std::make_unique<GHistIndexMatrix>(&ctx, Xy->Info(), page, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(&ctx, p)) {
|
for (auto const &from_sparse_page : Xy->GetBatches<GHistIndexMatrix>(&ctx, p)) {
|
||||||
@ -199,13 +198,15 @@ class GHistIndexMatrixTest : public testing::TestWithParam<std::tuple<float, flo
|
|||||||
|
|
||||||
std::string from_sparse_buf;
|
std::string from_sparse_buf;
|
||||||
{
|
{
|
||||||
common::MemoryBufferStream fo{&from_sparse_buf};
|
common::AlignedMemWriteStream fo{&from_sparse_buf};
|
||||||
columns_from_sparse.Write(&fo);
|
auto n_bytes = columns_from_sparse.Write(&fo);
|
||||||
|
ASSERT_EQ(fo.Tell(), n_bytes);
|
||||||
}
|
}
|
||||||
std::string from_ellpack_buf;
|
std::string from_ellpack_buf;
|
||||||
{
|
{
|
||||||
common::MemoryBufferStream fo{&from_ellpack_buf};
|
common::AlignedMemWriteStream fo{&from_ellpack_buf};
|
||||||
columns_from_sparse.Write(&fo);
|
auto n_bytes = columns_from_sparse.Write(&fo);
|
||||||
|
ASSERT_EQ(fo.Tell(), n_bytes);
|
||||||
}
|
}
|
||||||
ASSERT_EQ(from_sparse_buf, from_ellpack_buf);
|
ASSERT_EQ(from_sparse_buf, from_ellpack_buf);
|
||||||
}
|
}
|
||||||
@ -229,5 +230,4 @@ INSTANTIATE_TEST_SUITE_P(GHistIndexMatrix, GHistIndexMatrixTest,
|
|||||||
std::make_tuple(.6f, .4))); // dense columns
|
std::make_tuple(.6f, .4))); // dense columns
|
||||||
|
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -2,14 +2,18 @@
|
|||||||
* Copyright 2021-2023, XGBoost contributors
|
* Copyright 2021-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/context.h> // for Context
|
||||||
|
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <memory> // for unique_ptr
|
||||||
|
|
||||||
#include "../../../src/common/column_matrix.h"
|
#include "../../../src/common/column_matrix.h"
|
||||||
#include "../../../src/data/gradient_index.h"
|
#include "../../../src/common/io.h" // for MmapResource, AlignedResourceReadStream...
|
||||||
|
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||||
#include "../../../src/data/sparse_page_source.h"
|
#include "../../../src/data/sparse_page_source.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h" // for RandomDataGenerator
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
TEST(GHistIndexPageRawFormat, IO) {
|
TEST(GHistIndexPageRawFormat, IO) {
|
||||||
Context ctx;
|
Context ctx;
|
||||||
|
|
||||||
@ -20,15 +24,18 @@ TEST(GHistIndexPageRawFormat, IO) {
|
|||||||
std::string path = tmpdir.path + "/ghistindex.page";
|
std::string path = tmpdir.path + "/ghistindex.page";
|
||||||
auto batch = BatchParam{256, 0.5};
|
auto batch = BatchParam{256, 0.5};
|
||||||
|
|
||||||
|
std::size_t bytes{0};
|
||||||
{
|
{
|
||||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
auto fo = std::make_unique<common::AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
for (auto const &index : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
for (auto const &index : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||||
format->Write(index, fo.get());
|
bytes += format->Write(index, fo.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GHistIndexMatrix page;
|
GHistIndexMatrix page;
|
||||||
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
|
||||||
|
std::unique_ptr<common::AlignedResourceReadStream> fi{
|
||||||
|
std::make_unique<common::PrivateMmapConstStream>(path, 0, bytes)};
|
||||||
format->Read(&page, fi.get());
|
format->Read(&page, fi.get());
|
||||||
|
|
||||||
for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
|
||||||
@ -37,6 +44,8 @@ TEST(GHistIndexPageRawFormat, IO) {
|
|||||||
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
|
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
|
||||||
ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
|
ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
|
||||||
ASSERT_EQ(loaded.base_rowid, page.base_rowid);
|
ASSERT_EQ(loaded.base_rowid, page.base_rowid);
|
||||||
|
ASSERT_EQ(loaded.row_ptr.size(), page.row_ptr.size());
|
||||||
|
ASSERT_TRUE(std::equal(loaded.row_ptr.cbegin(), loaded.row_ptr.cend(), page.row_ptr.cbegin()));
|
||||||
ASSERT_EQ(loaded.IsDense(), page.IsDense());
|
ASSERT_EQ(loaded.IsDense(), page.IsDense());
|
||||||
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin()));
|
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin()));
|
||||||
ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(),
|
ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(),
|
||||||
@ -45,5 +54,4 @@ TEST(GHistIndexPageRawFormat, IO) {
|
|||||||
ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize());
|
ASSERT_EQ(loaded.Transpose().GetTypeSize(), loaded.Transpose().GetTypeSize());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -2,20 +2,20 @@
|
|||||||
* Copyright 2021-2023, XGBoost contributors
|
* Copyright 2021-2023, XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/data.h> // for CSCPage, SortedCSCPage, SparsePage
|
#include <xgboost/data.h> // for CSCPage, SortedCSCPage, SparsePage
|
||||||
|
|
||||||
#include <memory> // for allocator, unique_ptr, __shared_ptr_ac...
|
#include <memory> // for allocator, unique_ptr, __shared_ptr_ac...
|
||||||
#include <string> // for char_traits, operator+, basic_string
|
#include <string> // for char_traits, operator+, basic_string
|
||||||
|
|
||||||
|
#include "../../../src/common/io.h" // for PrivateMmapConstStream, AlignedResourceReadStream...
|
||||||
#include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat
|
#include "../../../src/data/sparse_page_writer.h" // for CreatePageFormat
|
||||||
#include "../helpers.h" // for RandomDataGenerator
|
#include "../helpers.h" // for RandomDataGenerator
|
||||||
#include "dmlc/filesystem.h" // for TemporaryDirectory
|
#include "dmlc/filesystem.h" // for TemporaryDirectory
|
||||||
#include "dmlc/io.h" // for SeekStream, Stream
|
#include "dmlc/io.h" // for Stream
|
||||||
#include "gtest/gtest_pred_impl.h" // for Test, AssertionResult, ASSERT_EQ, TEST
|
#include "gtest/gtest_pred_impl.h" // for Test, AssertionResult, ASSERT_EQ, TEST
|
||||||
#include "xgboost/context.h" // for Context
|
#include "xgboost/context.h" // for Context
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::data {
|
||||||
namespace data {
|
|
||||||
template <typename S> void TestSparsePageRawFormat() {
|
template <typename S> void TestSparsePageRawFormat() {
|
||||||
std::unique_ptr<SparsePageFormat<S>> format{CreatePageFormat<S>("raw")};
|
std::unique_ptr<SparsePageFormat<S>> format{CreatePageFormat<S>("raw")};
|
||||||
Context ctx;
|
Context ctx;
|
||||||
@ -25,17 +25,19 @@ template <typename S> void TestSparsePageRawFormat() {
|
|||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::string path = tmpdir.path + "/sparse.page";
|
std::string path = tmpdir.path + "/sparse.page";
|
||||||
S orig;
|
S orig;
|
||||||
|
std::size_t n_bytes{0};
|
||||||
{
|
{
|
||||||
// block code to flush the stream
|
// block code to flush the stream
|
||||||
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
|
auto fo = std::make_unique<common::AlignedFileWriteStream>(StringView{path}, "wb");
|
||||||
for (auto const &page : m->GetBatches<S>(&ctx)) {
|
for (auto const &page : m->GetBatches<S>(&ctx)) {
|
||||||
orig.Push(page);
|
orig.Push(page);
|
||||||
format->Write(page, fo.get());
|
n_bytes = format->Write(page, fo.get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
S page;
|
S page;
|
||||||
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
|
std::unique_ptr<common::AlignedResourceReadStream> fi{
|
||||||
|
std::make_unique<common::PrivateMmapConstStream>(path.c_str(), 0, n_bytes)};
|
||||||
format->Read(&page, fi.get());
|
format->Read(&page, fi.get());
|
||||||
for (size_t i = 0; i < orig.data.Size(); ++i) {
|
for (size_t i = 0; i < orig.data.Size(); ++i) {
|
||||||
ASSERT_EQ(page.data.HostVector()[i].fvalue,
|
ASSERT_EQ(page.data.HostVector()[i].fvalue,
|
||||||
@ -59,5 +61,4 @@ TEST(SparsePageRawFormat, CSCPage) {
|
|||||||
TEST(SparsePageRawFormat, SortedCSCPage) {
|
TEST(SparsePageRawFormat, SortedCSCPage) {
|
||||||
TestSparsePageRawFormat<SortedCSCPage>();
|
TestSparsePageRawFormat<SortedCSCPage>();
|
||||||
}
|
}
|
||||||
} // namespace data
|
} // namespace xgboost::data
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user