Use mmap for external memory. (#9282)
- Have basic infrastructure for mmap.
- Release file write handle.
parent d8beb517ed
commit ee6809e642
@@ -82,10 +82,10 @@ def main(tmpdir: str) -> xgboost.Booster:
     missing = np.NaN
     Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
 
-    # Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in
-    # doc for details.
+    # Other tree methods including ``approx`` and ``gpu_hist`` are supported. GPU
+    # behaves differently than CPU tree methods. See tutorial in doc for details.
     booster = xgboost.train(
-        {"tree_method": "approx", "max_depth": 2},
+        {"tree_method": "hist", "max_depth": 4},
         Xy,
         evals=[(Xy, "Train")],
         num_boost_round=10,
@@ -2,11 +2,25 @@
 Using XGBoost External Memory Version
 #####################################
 
-XGBoost supports loading data from external memory using builtin data parser. And
-starting from version 1.5, users can also define a custom iterator to load data in chunks.
-The feature is still experimental and not yet ready for production use. In this tutorial
-we will introduce both methods. Please note that training on data from external memory is
-not supported by ``exact`` tree method.
+When working with large datasets, training XGBoost models can be challenging as the entire
+dataset needs to be loaded into memory. This can be costly and sometimes
+infeasible. Starting from 1.5, users can define a custom iterator to load data in chunks
+for running XGBoost algorithms. External memory can be used for both training and
+prediction, but training is the primary use case and it will be our focus in this
+tutorial. For prediction and evaluation, users can iterate through the data themselves,
+while training requires the full dataset to be loaded into memory.
+
+During training, there are two different modes for external memory support available in
+XGBoost, one for CPU-based algorithms like ``hist`` and ``approx``, another one for the
+GPU-based training algorithm. We will introduce them in the following sections.
+
+.. note::
+
+   Training on data from external memory is not supported by the ``exact`` tree method.
+
+.. note::
+
+   The feature is still experimental as of 2.0. The performance is not well optimized.
 
 *************
 Data Iterator
@@ -15,8 +29,8 @@ Data Iterator
 Starting from XGBoost 1.5, users can define their own data loader using Python or C
 interface. There are some examples in the ``demo`` directory for quick start. This is a
 generalized version of text input external memory, where users no longer need to prepare a
-text file that XGBoost recognizes. To enable the feature, user need to define a data
-iterator with 2 class methods ``next`` and ``reset`` then pass it into ``DMatrix``
+text file that XGBoost recognizes. To enable the feature, users need to define a data
+iterator with 2 class methods: ``next`` and ``reset``, then pass it into the ``DMatrix``
 constructor.
 
 .. code-block:: python
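To make the iterator contract above concrete, here is a minimal sketch (an editor's illustration, not part of this diff) of a ``xgboost.DataIter`` subclass. The ``load_batch`` helper and the ``.npz`` layout are hypothetical stand-ins for whatever chunked loader you have:

.. code-block:: python

  import os
  from typing import Callable, List, Tuple

  import numpy as np
  import xgboost

  def load_batch(path: str) -> Tuple[np.ndarray, np.ndarray]:
      """Hypothetical loader: one (X, y) chunk per .npz file."""
      arrays = np.load(path)
      return arrays["X"], arrays["y"]

  class Iterator(xgboost.DataIter):
      """Stream pre-chunked data into XGBoost one batch at a time."""

      def __init__(self, file_paths: List[str]) -> None:
          self._file_paths = file_paths
          self._it = 0
          # XGBoost writes its transformed batch cache next to this prefix.
          super().__init__(cache_prefix=os.path.join(".", "cache"))

      def next(self, input_data: Callable) -> int:
          """Feed the next batch; return 0 once the data set is exhausted."""
          if self._it == len(self._file_paths):
              return 0
          X, y = load_batch(self._file_paths[self._it])
          input_data(data=X, label=y)
          self._it += 1
          return 1

      def reset(self) -> None:
          """Rewind to the first batch; XGBoost calls this between passes."""
          self._it = 0

Passing an instance of this iterator to the ``DMatrix`` constructor is what triggers the on-disk cache construction described below.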
@@ -60,20 +74,96 @@ constructor.
 
   # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some caveats
   # as noted in following sections.
-  booster = xgboost.train({"tree_method": "approx"}, Xy)
+  booster = xgboost.train({"tree_method": "hist"}, Xy)
 
 
-The above snippet is a simplified version of ``demo/guide-python/external_memory.py``. For
-an example in C, please see ``demo/c-api/external-memory/``.
+The above snippet is a simplified version of :ref:`sphx_glr_python_examples_external_memory.py`.
+For an example in C, please see ``demo/c-api/external-memory/``. The iterator is the
+common interface for using external memory with XGBoost; you can pass the resulting
+``DMatrix`` object for training, prediction, and evaluation.
+
+It is important to set the batch size based on the memory available. A good starting point
+is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not*
+recommended to set small batch sizes like 32 samples per batch, as this can seriously hurt
+performance in gradient boosting.
+
+***********
+CPU Version
+***********
+
+In the previous section, we demonstrated how to train a tree-based model using the
+``hist`` tree method on a CPU. This method involves iterating through data batches stored
+in a cache during tree construction. For optimal performance, we recommend using the
+``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer of tree
+nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy
+requires XGBoost to iterate over the data set for each tree node, resulting in slower
+performance.
+
+If external memory is used, the performance of CPU training is limited by IO
+(input/output) speed. This means that the disk IO speed primarily determines the training
+speed. During benchmarking, we used an NVMe connected to a PCIe-4 slot; other types of
+storage can be too slow for practical usage. In addition, your system may perform caching
+to reduce the overhead of file reading.
+
+**********************************
+GPU Version (GPU Hist tree method)
+**********************************
+
+External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set to
+``gpu_hist``). However, the algorithm used for GPU is different from the one used for
+CPU. When training on a CPU, the tree method iterates through all batches from external
+memory for each step of the tree construction algorithm. On the other hand, the GPU
+algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
+memory usage, users can utilize subsampling. The good news is that the GPU hist tree
+method supports gradient-based sampling, enabling users to set a low sampling rate without
+compromising accuracy.
+
+.. code-block:: python
+
+  param = {
+    ...
+    'subsample': 0.2,
+    'sampling_method': 'gradient_based',
+  }
+
+For more information about the sampling algorithm and its use in external memory training,
+see `this paper <https://arxiv.org/abs/2005.09148>`_.
+
+.. warning::
+
+   When GPU is running out of memory during iteration on external memory, user might
+   receive a segfault instead of an OOM exception.
+
+*******
+Remarks
+*******
+
+When using external memory with XGBoost, data is divided into smaller chunks so that only
+a fraction of it needs to be stored in memory at any given time. It's important to note
+that this method only applies to the predictor data (``X``), while other data, like labels
+and internal runtime structures, are concatenated. This means that memory reduction is most
+effective when dealing with wide datasets where ``X`` is larger compared to other data
+like ``y``, while it has little impact on slim datasets.
+
+Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
+yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's
+worth noting that most tests have been conducted on Linux distributions.
+
+Another important point to keep in mind is that creating the initial cache for XGBoost may
+take some time. The interface to external memory is through custom iterators, which may or
+may not be thread-safe. Therefore, initialization is performed sequentially.
+
 
 ****************
 Text File Inputs
 ****************
 
-There is no big difference between using external memory version and in-memory version.
-The only difference is the filename format.
+This is the original form of external memory support; users are encouraged to use custom
+data iterator instead. There is no big difference between using external memory version of
+text input and the in-memory version. The only difference is the filename format.
 
-The external memory version takes in the following `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format:
+The external memory version takes in the following `URI
+<https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format:
 
 .. code-block:: none
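Putting the advice from the CPU and GPU sections together, a hedged sketch (editor's illustration; ``Iterator`` and ``file_paths`` are from the earlier sketch, not this diff) of training on the same external-memory ``DMatrix`` might look like:

.. code-block:: python

  import xgboost

  Xy = xgboost.DMatrix(Iterator(file_paths))

  # CPU: depthwise growth builds a whole tree level per sweep over the batches.
  cpu_booster = xgboost.train(
      {"tree_method": "hist", "grow_policy": "depthwise"},
      Xy,
      num_boost_round=10,
  )

  # GPU: pages are concatenated in device memory; gradient-based sampling
  # keeps only a fraction of the rows resident without hurting accuracy much.
  gpu_booster = xgboost.train(
      {
          "tree_method": "gpu_hist",
          "subsample": 0.2,
          "sampling_method": "gradient_based",
      },
      Xy,
      num_boost_round=10,
  )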
@@ -91,9 +181,8 @@ To load from csv files, use the following syntax:
 
 where ``label_column`` should point to the csv column acting as the label.
 
-To provide a simple example for illustration, extracting the code from
-`demo/guide-python/external_memory.py <https://github.com/dmlc/xgboost/blob/master/demo/guide-python/external_memory.py>`_. If
-you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSVM format, the external memory support can be enabled by:
+If you have a dataset stored in a file similar to ``demo/data/agaricus.txt.train`` with LIBSVM
+format, the external memory support can be enabled by:
 
 .. code-block:: python
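The code block elided between these hunks is not reproduced here; as a hedged illustration of the URI form it describes, loading the LIBSVM file with an external-memory cache reduces to appending the cache suffix:

.. code-block:: python

  import xgboost

  # The fragment after '#' names the on-disk cache XGBoost builds on first load.
  dtrain = xgboost.DMatrix("demo/data/agaricus.txt.train?format=libsvm#dtrain.cache")
  booster = xgboost.train({"tree_method": "hist"}, dtrain, num_boost_round=10)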
@@ -104,35 +193,3 @@ XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to
 more notes about text input formats, see :doc:`/tutorials/input_format`.
 
 For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.
-
-
-**********************************
-GPU Version (GPU Hist tree method)
-**********************************
-External memory is supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``).
-
-If you are still getting out-of-memory errors after enabling external memory, try subsampling the
-data to further reduce GPU memory usage:
-
-.. code-block:: python
-
-  param = {
-    ...
-    'subsample': 0.1,
-    'sampling_method': 'gradient_based',
-  }
-
-For more information, see `this paper <https://arxiv.org/abs/2005.09148>`_. Internally
-the tree method still concatenate all the chunks into 1 final histogram index due to
-performance reason, but in compressed format. So its scalability has an upper bound but
-still has lower memory cost in general.
-
-***********
-CPU Version
-***********
-
-For CPU histogram based tree methods (``approx``, ``hist``) it's recommended to use
-``grow_policy=depthwise`` for performance reason. Iterating over data batches is slow,
-with ``depthwise`` policy XGBoost can build a entire layer of tree nodes with a few
-iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each
-tree node.
@@ -198,14 +198,14 @@ class IteratorForTest(xgb.core.DataIter):
         X: Sequence,
         y: Sequence,
         w: Optional[Sequence],
-        cache: Optional[str] = "./",
+        cache: Optional[str],
     ) -> None:
         assert len(X) == len(y)
         self.X = X
         self.y = y
         self.w = w
         self.it = 0
-        super().__init__(cache)
+        super().__init__(cache_prefix=cache)
 
     def next(self, input_data: Callable) -> int:
         if self.it == len(self.X):
@@ -347,7 +347,9 @@ class TestDataset:
         if w is not None:
             weight.append(w)
 
-        it = IteratorForTest(predictor, response, weight if weight else None)
+        it = IteratorForTest(
+            predictor, response, weight if weight else None, cache="cache"
+        )
         return xgb.DMatrix(it)
 
     def __repr__(self) -> str:
@@ -1,18 +1,21 @@
-/*!
- * Copyright (c) 2014-2019 by Contributors
+/**
+ * Copyright 2014-2023, XGBoost Contributors
  * \file io.h
  * \brief utilities with different serializable implementations
  * \author Tianqi Chen
  */
 #ifndef RABIT_INTERNAL_IO_H_
 #define RABIT_INTERNAL_IO_H_
-#include <cstdio>
-#include <vector>
-#include <cstring>
-#include <string>
 #include <algorithm>
-#include <numeric>
+#include <cstddef>  // for size_t
+#include <cstdio>
+#include <cstring>  // for memcpy
 #include <limits>
+#include <numeric>
+#include <string>
+#include <vector>
 
 #include "rabit/internal/utils.h"
 #include "rabit/serializable.h"
@@ -20,54 +23,61 @@ namespace rabit {
 namespace utils {
 /*! \brief re-use definition of dmlc::SeekStream */
 using SeekStream = dmlc::SeekStream;
-/*! \brief fixed size memory buffer */
+/**
+ * @brief Fixed size memory buffer as a stream.
+ */
 struct MemoryFixSizeBuffer : public SeekStream {
  public:
   // similar to SEEK_END in libc
-  static size_t constexpr kSeekEnd = std::numeric_limits<size_t>::max();
+  static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
+
+ protected:
+  MemoryFixSizeBuffer() = default;
 
  public:
-  MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
-      : p_buffer_(reinterpret_cast<char*>(p_buffer)),
-        buffer_size_(buffer_size) {
-    curr_ptr_ = 0;
-  }
+  /**
+   * @brief Ctor
+   *
+   * @param p_buffer Pointer to the source buffer with size `buffer_size`.
+   * @param buffer_size Size of the source buffer
+   */
+  MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
+      : p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
   ~MemoryFixSizeBuffer() override = default;
-  size_t Read(void *ptr, size_t size) override {
-    size_t nread = std::min(buffer_size_ - curr_ptr_, size);
+
+  std::size_t Read(void *ptr, std::size_t size) override {
+    std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
     if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
     curr_ptr_ += nread;
     return nread;
   }
-  void Write(const void *ptr, size_t size) override {
+  void Write(const void *ptr, std::size_t size) override {
     if (size == 0) return;
-    utils::Assert(curr_ptr_ + size <= buffer_size_,
-                  "write position exceed fixed buffer size");
+    CHECK_LE(curr_ptr_ + size, buffer_size_);
     std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
     curr_ptr_ += size;
   }
-  void Seek(size_t pos) override {
+  void Seek(std::size_t pos) override {
     if (pos == kSeekEnd) {
       curr_ptr_ = buffer_size_;
     } else {
-      curr_ptr_ = static_cast<size_t>(pos);
+      curr_ptr_ = static_cast<std::size_t>(pos);
     }
   }
-  size_t Tell() override {
-    return curr_ptr_;
-  }
-  virtual bool AtEnd() const {
-    return curr_ptr_ == buffer_size_;
-  }
+  /**
+   * @brief Current position in the buffer (stream).
+   */
+  std::size_t Tell() override { return curr_ptr_; }
+  virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
 
- private:
+ protected:
   /*! \brief in memory buffer */
-  char *p_buffer_;
+  char *p_buffer_{nullptr};
   /*! \brief current pointer */
-  size_t buffer_size_;
+  std::size_t buffer_size_{0};
   /*! \brief current pointer */
-  size_t curr_ptr_;
-};  // class MemoryFixSizeBuffer
+  std::size_t curr_ptr_{0};
+};
 
 /*! \brief a in memory buffer that can be read and write as stream interface */
 struct MemoryBufferStream : public SeekStream {
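For readers more comfortable in Python, a rough analogue of the buffer's stream semantics (an editor's sketch, not XGBoost or rabit API) looks like this:

.. code-block:: python

  class FixedSizeBuffer:
      """Read/write/seek over one pre-allocated buffer, mirroring MemoryFixSizeBuffer."""

      SEEK_END = 2**64 - 1  # sentinel playing the role of kSeekEnd

      def __init__(self, buf: bytearray) -> None:
          self._buf = buf
          self._ptr = 0

      def read(self, size: int) -> bytes:
          n = min(len(self._buf) - self._ptr, size)  # clamp at the buffer end
          out = bytes(self._buf[self._ptr : self._ptr + n])
          self._ptr += n
          return out

      def write(self, data: bytes) -> None:
          assert self._ptr + len(data) <= len(self._buf), "write exceeds fixed buffer"
          self._buf[self._ptr : self._ptr + len(data)] = data
          self._ptr += len(data)

      def seek(self, pos: int) -> None:
          self._ptr = len(self._buf) if pos == self.SEEK_END else pos

      def tell(self) -> int:
          return self._ptr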
src/common/io.cc (176 changed lines)

@@ -1,24 +1,47 @@
-/*!
- * Copyright (c) by XGBoost Contributors 2019-2022
+/**
+ * Copyright 2019-2023, by XGBoost Contributors
  */
-#if defined(__unix__)
-#include <sys/stat.h>
-#include <fcntl.h>
-#include <unistd.h>
+#if !defined(NOMINMAX) && defined(_WIN32)
+#define NOMINMAX
+#endif  // !defined(NOMINMAX)
+
+#if !defined(xgboost_IS_WIN)
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#define xgboost_IS_WIN 1
+#endif  // defined(_MSC_VER) || defined(__MINGW32__)
+
+#endif  // !defined(xgboost_IS_WIN)
+
+#if defined(__unix__) || defined(__APPLE__)
+#include <fcntl.h>     // for open, O_RDONLY
+#include <sys/mman.h>  // for mmap, mmap64, munmap
+#include <unistd.h>    // for close, getpagesize
+#elif defined(xgboost_IS_WIN)
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
 #endif  // defined(__unix__)
-#include <algorithm>
-#include <fstream>
-#include <string>
-#include <memory>
-#include <utility>
-#include <cstdio>
 
-#include "xgboost/logging.h"
+#include <algorithm>     // for copy, transform
+#include <cctype>        // for tolower
+#include <cerrno>        // for errno
+#include <cstddef>       // for size_t
+#include <cstdint>       // for int32_t, uint32_t
+#include <cstring>       // for memcpy
+#include <fstream>       // for ifstream
+#include <iterator>      // for distance
+#include <limits>        // for numeric_limits
+#include <memory>        // for unique_ptr
+#include <string>        // for string
+#include <system_error>  // for error_code, system_category
+#include <utility>       // for move
+#include <vector>        // for vector
+
 #include "io.h"
+#include "xgboost/collective/socket.h"  // for LastError
+#include "xgboost/logging.h"
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 
 size_t PeekableInStream::Read(void* dptr, size_t size) {
   size_t nbuffer = buffer_.length() - buffer_ptr_;
   if (nbuffer == 0) return strm_->Read(dptr, size);
@@ -94,11 +117,32 @@ void FixedSizeStream::Take(std::string* out) {
   *out = std::move(buffer_);
 }
 
+namespace {
+// Get system alignment value for IO with mmap.
+std::size_t GetMmapAlignment() {
+#if defined(xgboost_IS_WIN)
+  SYSTEM_INFO sys_info;
+  GetSystemInfo(&sys_info);
+  // During testing, `sys_info.dwPageSize` is of size 4096 while `dwAllocationGranularity` is of
+  // size 65536.
+  return sys_info.dwAllocationGranularity;
+#else
+  return getpagesize();
+#endif
+}
+
+auto SystemErrorMsg() {
+  std::int32_t errsv = system::LastError();
+  auto err = std::error_code{errsv, std::system_category()};
+  return err.message();
+}
+}  // anonymous namespace
+
 std::string LoadSequentialFile(std::string uri, bool stream) {
   auto OpenErr = [&uri]() {
     std::string msg;
     msg = "Opening " + uri + " failed: ";
-    msg += strerror(errno);
+    msg += SystemErrorMsg();
     LOG(FATAL) << msg;
   };
 
@@ -155,5 +199,99 @@ std::string FileExtension(std::string fname, bool lower) {
     return "";
   }
 }
-}  // namespace common
-}  // namespace xgboost
+
+struct PrivateMmapConstStream::MMAPFile {
+#if defined(xgboost_IS_WIN)
+  HANDLE fd{INVALID_HANDLE_VALUE};
+  HANDLE file_map{INVALID_HANDLE_VALUE};
+#else
+  std::int32_t fd{0};
+#endif
+  char* base_ptr{nullptr};
+  std::size_t base_size{0};
+  std::string path;
+};
+
+char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::size_t length) {
+  if (length == 0) {
+    return nullptr;
+  }
+
+#if defined(xgboost_IS_WIN)
+  HANDLE fd = CreateFile(path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING,
+                         FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, nullptr);
+  CHECK_NE(fd, INVALID_HANDLE_VALUE) << "Failed to open:" << path << ". " << SystemErrorMsg();
+#else
+  auto fd = open(path.c_str(), O_RDONLY);
+  CHECK_GE(fd, 0) << "Failed to open:" << path << ". " << SystemErrorMsg();
+#endif
+
+  char* ptr{nullptr};
+  // Round down for alignment.
+  auto view_start = offset / GetMmapAlignment() * GetMmapAlignment();
+  auto view_size = length + (offset - view_start);
+
+#if defined(__linux__) || defined(__GLIBC__)
+  int prot{PROT_READ};
+  ptr = reinterpret_cast<char*>(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
+  CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
+  handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)});
+#elif defined(xgboost_IS_WIN)
+  auto file_size = GetFileSize(fd, nullptr);
+  DWORD access = PAGE_READONLY;
+  auto map_file = CreateFileMapping(fd, nullptr, access, 0, file_size, nullptr);
+  access = FILE_MAP_READ;
+  std::uint32_t loff = static_cast<std::uint32_t>(view_start);
+  std::uint32_t hoff = view_start >> 32;
+  CHECK(map_file) << "Failed to map: " << path << ". " << SystemErrorMsg();
+  ptr = reinterpret_cast<char*>(MapViewOfFile(map_file, access, hoff, loff, view_size));
+  CHECK_NE(ptr, nullptr) << "Failed to map: " << path << ". " << SystemErrorMsg();
+  handle_.reset(new MMAPFile{fd, map_file, ptr, view_size, std::move(path)});
+#else
+  CHECK_LE(offset, std::numeric_limits<off_t>::max())
+      << "File size has exceeded the limit on the current system.";
+  int prot{PROT_READ};
+  ptr = reinterpret_cast<char*>(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
+  CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
+  handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)});
+#endif  // defined(__linux__)
+
+  ptr += (offset - view_start);
+  return ptr;
+}
+
+PrivateMmapConstStream::PrivateMmapConstStream(std::string path, std::size_t offset,
+                                               std::size_t length)
+    : MemoryFixSizeBuffer{}, handle_{nullptr} {
+  this->p_buffer_ = Open(std::move(path), offset, length);
+  this->buffer_size_ = length;
+}
+
+PrivateMmapConstStream::~PrivateMmapConstStream() {
+  CHECK(handle_);
+#if defined(xgboost_IS_WIN)
+  if (p_buffer_) {
+    CHECK(UnmapViewOfFile(handle_->base_ptr)) << "Failed to call munmap: " << SystemErrorMsg();
+  }
+  if (handle_->fd != INVALID_HANDLE_VALUE) {
+    CHECK(CloseHandle(handle_->fd)) << "Failed to close handle: " << SystemErrorMsg();
+  }
+  if (handle_->file_map != INVALID_HANDLE_VALUE) {
+    CHECK(CloseHandle(handle_->file_map)) << "Failed to close mapping object: " << SystemErrorMsg();
+  }
+#else
+  if (handle_->base_ptr) {
+    CHECK_NE(munmap(handle_->base_ptr, handle_->base_size), -1)
+        << "Failed to call munmap: " << handle_->path << ". " << SystemErrorMsg();
+  }
+  if (handle_->fd != 0) {
+    CHECK_NE(close(handle_->fd), -1)
+        << "Failed to close: " << handle_->path << ". " << SystemErrorMsg();
+  }
+#endif
+}
+}  // namespace xgboost::common
+
+#if defined(xgboost_IS_WIN)
+#undef xgboost_IS_WIN
+#endif  // defined(xgboost_IS_WIN)
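The heart of ``PrivateMmapConstStream::Open`` is the alignment dance: the mapped view must start on an allocation boundary, so the requested offset is rounded down and the rounding delta is added back to the returned pointer. A Python sketch of the same arithmetic (editor's illustration, using the standard ``mmap`` module):

.. code-block:: python

  import mmap

  def read_page(path: str, offset: int, length: int) -> bytes:
      """Map [offset, offset + length) of a file, honoring mmap's alignment rule."""
      align = mmap.ALLOCATIONGRANULARITY  # page size; allocation granularity on Windows
      view_start = offset // align * align        # round down for alignment
      view_size = length + (offset - view_start)  # grow the view by the rounding delta
      with open(path, "rb") as fd:
          with mmap.mmap(fd.fileno(), view_size, access=mmap.ACCESS_READ,
                         offset=view_start) as view:
              delta = offset - view_start
              return view[delta : delta + length]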
@@ -1,5 +1,5 @@
-/*!
- * Copyright by XGBoost Contributors 2014-2022
+/**
+ * Copyright 2014-2023, XGBoost Contributors
  * \file io.h
  * \brief general stream interface for serialization, I/O
  * \author Tianqi Chen
@@ -10,9 +10,11 @@
 
 #include <dmlc/io.h>
 #include <rabit/rabit.h>
-#include <string>
+
 #include <cstring>
 #include <fstream>
+#include <memory>  // for unique_ptr
+#include <string>  // for string
 
 #include "common.h"
 
@@ -127,6 +129,31 @@ inline std::string ReadAll(std::string const &path) {
   return content;
 }
 
+/**
+ * @brief Private mmap file as a read-only stream.
+ *
+ * It can calculate alignment automatically based on system page size (or allocation
+ * granularity on Windows).
+ */
+class PrivateMmapConstStream : public MemoryFixSizeBuffer {
+  struct MMAPFile;
+  std::unique_ptr<MMAPFile> handle_;
+
+  char* Open(std::string path, std::size_t offset, std::size_t length);
+
+ public:
+  /**
+   * @brief Construct a private mmap stream.
+   *
+   * @param path   File path.
+   * @param offset See the `offset` parameter of `mmap` for details.
+   * @param length See the `length` parameter of `mmap` for details.
+   */
+  explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length);
+  void Write(void const*, std::size_t) override { LOG(FATAL) << "Read-only stream."; }
+
+  ~PrivateMmapConstStream() override;
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_IO_H_
@@ -1,35 +1,34 @@
-/*!
- * Copyright 2014-2022 by XGBoost Contributors
+/**
+ * Copyright 2014-2023, XGBoost Contributors
  * \file sparse_page_source.h
  */
 #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
 #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
 
-#include <algorithm>  // std::min
-#include <string>
-#include <utility>
-#include <vector>
-#include <future>
-#include <thread>
+#include <algorithm>  // for min
+#include <future>     // async
 #include <map>
 #include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include "../common/common.h"
+#include "../common/io.h"     // for PrivateMmapStream, PadPageForMMAP
+#include "../common/timer.h"  // for Monitor, Timer
+#include "adapter.h"
+#include "dmlc/common.h"      // OMPException
+#include "proxy_dmatrix.h"
+#include "sparse_page_writer.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 
-#include "adapter.h"
-#include "sparse_page_writer.h"
-#include "proxy_dmatrix.h"
-
-#include "../common/common.h"
-#include "../common/timer.h"
-
-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 inline void TryDeleteCacheFile(const std::string& file) {
   if (std::remove(file.c_str()) != 0) {
     LOG(WARNING) << "Couldn't remove external memory cache file " << file
                  << "; you may want to remove it manually";
   }
 }
 
@@ -54,6 +53,9 @@ struct Cache {
   std::string ShardName() {
     return ShardName(this->name, this->format);
   }
+  void Push(std::size_t n_bytes) {
+    offset.push_back(n_bytes);
+  }
+
   // The write is completed.
   void Commit() {
@@ -95,56 +97,72 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
   uint32_t n_batches_ {0};
 
   std::shared_ptr<Cache> cache_info_;
-  std::unique_ptr<dmlc::Stream> fo_;
 
   using Ring = std::vector<std::future<std::shared_ptr<S>>>;
   // A ring storing futures to data. Since the DMatrix iterator is forward only, so we
   // can pre-fetch data in a ring.
   std::unique_ptr<Ring> ring_{new Ring};
+  dmlc::OMPException exec_;
+  common::Monitor monitor_;
 
   bool ReadCache() {
     CHECK(!at_end_);
     if (!cache_info_->written) {
       return false;
     }
-    if (fo_) {
-      fo_.reset();  // flush the data to disk.
+    if (ring_->empty()) {
       ring_->resize(n_batches_);
     }
     // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
     // to let user adjust number of pre-fetched batches when needed.
-    uint32_t constexpr kPreFetch = 4;
+    uint32_t constexpr kPreFetch = 3;
 
     size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
     CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
-    size_t fetch_it = count_;
-    for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
+    std::size_t fetch_it = count_;
+
+    exec_.Rethrow();
+
+    monitor_.Start("launch");
+    for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
       fetch_it %= n_batches_;  // ring
       if (ring_->at(fetch_it).valid()) {
        continue;
      }
-      auto const *self = this;  // make sure it's const
+      auto const* self = this;  // make sure it's const
       CHECK_LT(fetch_it, cache_info_->offset.size());
-      ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() {
-        common::Timer timer;
-        timer.Start();
-        std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
-        auto n = self->cache_info_->ShardName();
-        size_t offset = self->cache_info_->offset.at(fetch_it);
-        std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
-        fi->Seek(offset);
-        CHECK_EQ(fi->Tell(), offset);
+      ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
         auto page = std::make_shared<S>();
-        CHECK(fmt->Read(page.get(), fi.get()));
-        LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds.";
+        this->exec_.Run([&] {
+          common::Timer timer;
+          timer.Start();
+          std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
+          auto n = self->cache_info_->ShardName();
+
+          std::uint64_t offset = self->cache_info_->offset.at(fetch_it);
+          std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset;
+
+          auto fi = std::make_unique<common::PrivateMmapConstStream>(n, offset, length);
+          CHECK(fmt->Read(page.get(), fi.get()));
+          timer.Stop();
+
+          LOG(INFO) << "Read a page `" << typeid(S).name() << "` in " << timer.ElapsedSeconds()
+                    << " seconds.";
+        });
         return page;
       });
     }
+    monitor_.Stop("launch");
 
     CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
              n_prefetch_batches)
         << "Sparse DMatrix assumes forward iteration.";
+    monitor_.Start("Wait");
     page_ = (*ring_)[count_].get();
+    monitor_.Stop("Wait");
+    CHECK(!(*ring_)[count_].valid());
+    exec_.Rethrow();
+
     return true;
   }
 
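``ReadCache`` relies on forward-only iteration: up to ``kPreFetch`` reads are launched asynchronously, and batch ``count_`` must be consumed before its slot is reused. A compact Python rendering of that prefetch pattern (editor's sketch; ``read_page`` is a hypothetical per-batch loader):

.. code-block:: python

  from concurrent.futures import ThreadPoolExecutor

  def iter_pages(read_page, n_batches: int, n_prefetch: int = 3):
      """Forward-only iteration with a small window of in-flight reads."""
      with ThreadPoolExecutor(max_workers=n_prefetch) as pool:
          pending = {}
          for count in range(n_batches):
              # Keep the next `n_prefetch` batches in flight, like the future ring.
              for it in range(count, min(count + n_prefetch, n_batches)):
                  if it not in pending:
                      pending[it] = pool.submit(read_page, it)
              # Block on the batch we need right now, exactly once.
              yield pending.pop(count).result()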
@@ -153,25 +171,35 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
     common::Timer timer;
     timer.Start();
     std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
-    if (!fo_) {
-      auto n = cache_info_->ShardName();
-      fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
-    }
-    auto bytes = fmt->Write(*page_, fo_.get());
-    timer.Stop();
+
+    auto name = cache_info_->ShardName();
+    std::unique_ptr<dmlc::Stream> fo;
+    if (this->Iter() == 0) {
+      fo.reset(dmlc::Stream::Create(name.c_str(), "wb"));
+    } else {
+      fo.reset(dmlc::Stream::Create(name.c_str(), "ab"));
+    }
+
+    auto bytes = fmt->Write(*page_, fo.get());
+
+    timer.Stop();
     LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
               << timer.ElapsedSeconds() << " seconds.";
-    cache_info_->offset.push_back(bytes);
+    cache_info_->Push(bytes);
   }
 
   virtual void Fetch() = 0;
 
  public:
-  SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features,
-                       uint32_t n_batches, std::shared_ptr<Cache> cache)
-      : missing_{missing}, nthreads_{nthreads}, n_features_{n_features},
-        n_batches_{n_batches}, cache_info_{std::move(cache)} {}
+  SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
+                       std::shared_ptr<Cache> cache)
+      : missing_{missing},
+        nthreads_{nthreads},
+        n_features_{n_features},
+        n_batches_{n_batches},
+        cache_info_{std::move(cache)} {
+    monitor_.Init(typeid(S).name());  // not pretty, but works for basic profiling
+  }
 
   SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete;
 
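The rewritten ``WriteCache`` replaces the long-lived write handle (the old ``fo_`` member) with a scoped stream: truncate on the first page, append afterwards, and record each page's size. In Python terms (editor's sketch):

.. code-block:: python

  def write_page(shard_path: str, page_bytes: bytes, iter_no: int, offsets: list) -> None:
      """Append one serialized page to the cache shard and record its size."""
      mode = "wb" if iter_no == 0 else "ab"  # truncate first, then append
      with open(shard_path, mode) as fo:     # the handle is released on every return
          fo.write(page_bytes)
      offsets.append(len(page_bytes))  # per-page sizes; see Cache::Push above

Releasing the handle after each page is what allows the reader side to ``mmap`` the shard without holding a writable descriptor open.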
@@ -244,7 +272,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
         iter_{iter}, proxy_{proxy} {
     if (!cache_info_->written) {
       iter_.Reset();
-      CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch.";
+      CHECK(iter_.Next()) << "Must have at least 1 batch.";
     }
     this->Fetch();
   }
@@ -259,6 +287,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
     }
 
     if (at_end_) {
+      CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
       cache_info_->Commit();
       if (n_batches_ != 0) {
         CHECK_EQ(count_, n_batches_);
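The ``offset.size() == n_batches_ + 1`` invariant exists because ``ReadCache`` computes each page's byte range as ``offset[i + 1] - offset[i]``; the per-page sizes pushed during writing are presumably turned into absolute offsets when the cache is committed. In sketch form (editor's illustration, not the actual ``Commit`` implementation):

.. code-block:: python

  from itertools import accumulate

  sizes = [4096, 8192, 2048]               # bytes pushed per page by Push()
  offsets = [0] + list(accumulate(sizes))  # n_batches pages -> n_batches + 1 entries

  def page_range(i: int):
      """(offset, length) of page i inside the cache shard."""
      return offsets[i], offsets[i + 1] - offsets[i]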
@@ -371,6 +400,5 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
     this->Fetch();
   }
 };
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
 #endif  // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@@ -146,27 +146,30 @@ class PoissonSampling : public thrust::binary_function<GradientPair, size_t, Gra
   CombineGradientPair combine_;
 };
 
-NoSampling::NoSampling(EllpackPageImpl const* page) : page_(page) {}
+NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_param)) {}
 
-GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
+GradientBasedSample NoSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
                                        DMatrix* dmat) {
-  return {dmat->Info().num_row_, page_, gpair};
+  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
+  return {dmat->Info().num_row_, page, gpair};
 }
 
-ExternalMemoryNoSampling::ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page,
-                                                   size_t n_rows, BatchParam batch_param)
-    : batch_param_{std::move(batch_param)},
-      page_(new EllpackPageImpl(ctx->gpu_id, page->Cuts(), page->is_dense, page->row_stride,
-                                n_rows)) {}
+ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
+    : batch_param_{std::move(batch_param)} {}
 
 GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
                                                      common::Span<GradientPair> gpair,
                                                      DMatrix* dmat) {
   if (!page_concatenated_) {
     // Concatenate all the external memory ELLPACK pages into a single in-memory page.
+    page_.reset(nullptr);
     size_t offset = 0;
     for (auto& batch : dmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
       auto page = batch.Impl();
+      if (!page_) {
+        page_ = std::make_unique<EllpackPageImpl>(ctx->gpu_id, page->Cuts(), page->is_dense,
+                                                  page->row_stride, dmat->Info().num_row_);
+      }
       size_t num_elements = page_->Copy(ctx->gpu_id, page, offset);
       offset += num_elements;
     }
@@ -175,8 +178,8 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
   return {dmat->Info().num_row_, page_.get(), gpair};
 }
 
-UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample)
-    : page_(page), subsample_(subsample) {}
+UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
+    : batch_param_{std::move(batch_param)}, subsample_(subsample) {}
 
 GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
                                             DMatrix* dmat) {
@@ -185,7 +188,8 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<Gra
   thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                      thrust::counting_iterator<std::size_t>(0),
                      BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
-  return {dmat->Info().num_row_, page_, gpair};
+  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
+  return {dmat->Info().num_row_, page, gpair};
 }
 
 ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@@ -236,12 +240,10 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
   return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
 }
 
-GradientBasedSampling::GradientBasedSampling(EllpackPageImpl const* page,
-                                             size_t n_rows,
-                                             const BatchParam&,
+GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
                                              float subsample)
-    : page_(page),
-      subsample_(subsample),
+    : subsample_(subsample),
+      batch_param_{std::move(batch_param)},
       threshold_(n_rows + 1, 0.0f),
       grad_sum_(n_rows, 0.0f) {}
 
@@ -252,18 +254,19 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
   size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
       gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
 
+  auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
+
   // Perform Poisson sampling in place.
   thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
                     thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
                     PoissonSampling(dh::ToSpan(threshold_), threshold_index,
                                     RandomWeight(common::GlobalRandom()())));
-  return {n_rows, page_, gpair};
+  return {n_rows, page, gpair};
 }
 
-ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
-    size_t n_rows,
-    BatchParam batch_param,
-    float subsample)
+ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
+                                                                         BatchParam batch_param,
+                                                                         float subsample)
     : batch_param_(std::move(batch_param)),
       subsample_(subsample),
       threshold_(n_rows + 1, 0.0f),
@@ -273,16 +276,15 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
 GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* ctx,
                                                                 common::Span<GradientPair> gpair,
                                                                 DMatrix* dmat) {
-  size_t n_rows = dmat->Info().num_row_;
+  auto cuctx = ctx->CUDACtx();
+  bst_row_t n_rows = dmat->Info().num_row_;
   size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
       gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
 
   // Perform Poisson sampling in place.
-  thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
-                    thrust::counting_iterator<size_t>(0),
-                    dh::tbegin(gpair),
-                    PoissonSampling(dh::ToSpan(threshold_),
-                                    threshold_index,
+  thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
+                    thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
+                    PoissonSampling(dh::ToSpan(threshold_), threshold_index,
                                     RandomWeight(common::GlobalRandom()())));
 
   // Count the sampled rows.
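For intuition about the in-place Poisson sampling above: each row is kept with a probability tied to its gradient magnitude and, if kept, reweighted by the inverse probability so the expected gradient sum is preserved. A simplified NumPy sketch (editor's illustration; the exact score XGBoost uses follows the paper cited in the tutorial, not the plain ``|g|`` used here):

.. code-block:: python

  import numpy as np

  def poisson_sample(grad: np.ndarray, sample_rows: int, rng=np.random.default_rng(0)):
      """Keep row i with probability p_i and rescale by 1 / p_i (unbiased in expectation)."""
      p = np.minimum(1.0, sample_rows * np.abs(grad) / np.abs(grad).sum())
      keep = rng.random(grad.shape) < p
      out = np.zeros_like(grad)
      out[keep] = grad[keep] / p[keep]  # zero entries mark dropped rows
      return out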
@@ -290,16 +292,15 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
 
   // Compact gradient pairs.
   gpair_.resize(sample_rows);
-  thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
+  thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
 
   // Index the sample rows.
-  thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
-  thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
-                         sample_row_index_.begin());
-  thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
-                    sample_row_index_.begin(),
-                    sample_row_index_.begin(),
-                    ClearEmptyRows());
+  thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
+                    IsNonZero());
+  thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(),
+                         sample_row_index_.begin());
+  thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
+                    sample_row_index_.begin(), ClearEmptyRows());
 
   auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
   auto first_page = (*batch_iterator.begin()).Impl();
@@ -317,13 +318,13 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
   return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
 }
 
-GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page,
-                                           size_t n_rows, const BatchParam& batch_param,
-                                           float subsample, int sampling_method) {
+GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
+                                           const BatchParam& batch_param, float subsample,
+                                           int sampling_method, bool is_external_memory) {
+  // The ctx is kept here for future development of stream-based operations.
   monitor_.Init("gradient_based_sampler");
 
   bool is_sampling = subsample < 1.0;
-  bool is_external_memory = page->n_rows != n_rows;
 
   if (is_sampling) {
     switch (sampling_method) {
@@ -331,24 +332,24 @@
         if (is_external_memory) {
           strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
         } else {
-          strategy_.reset(new UniformSampling(page, subsample));
+          strategy_.reset(new UniformSampling(batch_param, subsample));
         }
         break;
       case TrainParam::kGradientBased:
         if (is_external_memory) {
-          strategy_.reset(
-              new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
+          strategy_.reset(new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
         } else {
-          strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample));
+          strategy_.reset(new GradientBasedSampling(n_rows, batch_param, subsample));
        }
        break;
-      default:LOG(FATAL) << "unknown sampling method";
+      default:
+        LOG(FATAL) << "unknown sampling method";
     }
   } else {
     if (is_external_memory) {
-      strategy_.reset(new ExternalMemoryNoSampling(ctx, page, n_rows, batch_param));
+      strategy_.reset(new ExternalMemoryNoSampling(batch_param));
     } else {
-      strategy_.reset(new NoSampling(page));
+      strategy_.reset(new NoSampling(batch_param));
     }
   }
 }
@@ -362,11 +363,11 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
   return sample;
 }
 
-size_t GradientBasedSampler::CalculateThresholdIndex(
-    common::Span<GradientPair> gpair, common::Span<float> threshold,
-    common::Span<float> grad_sum, size_t sample_rows) {
-  thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold),
-               std::numeric_limits<float>::max());
+size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
+                                                     common::Span<float> threshold,
+                                                     common::Span<float> grad_sum,
+                                                     size_t sample_rows) {
+  thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
   thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
                     CombineGradientPair());
   thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
@@ -379,6 +380,5 @@ size_t GradientBasedSampler::CalculateThresholdIndex(
       thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
   return thrust::distance(dh::tbegin(grad_sum), min) + 1;
 }
-
 };  // namespace tree
 };  // namespace xgboost
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2019 by XGBoost Contributors
|
* Copyright 2019-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
@@ -32,37 +32,36 @@ class SamplingStrategy {
 /*! \brief No sampling in in-memory mode. */
 class NoSampling : public SamplingStrategy {
  public:
-  explicit NoSampling(EllpackPageImpl const* page);
-  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
-                             DMatrix* dmat) override;
-
- private:
-  EllpackPageImpl const* page_;
-};
-
-/*! \brief No sampling in external memory mode. */
-class ExternalMemoryNoSampling : public SamplingStrategy {
- public:
-  ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
-                           BatchParam batch_param);
+  explicit NoSampling(BatchParam batch_param);
   GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                              DMatrix* dmat) override;
 
  private:
   BatchParam batch_param_;
-  std::unique_ptr<EllpackPageImpl> page_;
+};
+
+/*! \brief No sampling in external memory mode. */
+class ExternalMemoryNoSampling : public SamplingStrategy {
+ public:
+  explicit ExternalMemoryNoSampling(BatchParam batch_param);
+  GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
+                             DMatrix* dmat) override;
+
+ private:
+  BatchParam batch_param_;
+  std::unique_ptr<EllpackPageImpl> page_{nullptr};
   bool page_concatenated_{false};
 };
 
 /*! \brief Uniform sampling in in-memory mode. */
 class UniformSampling : public SamplingStrategy {
  public:
-  UniformSampling(EllpackPageImpl const* page, float subsample);
+  UniformSampling(BatchParam batch_param, float subsample);
   GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                              DMatrix* dmat) override;
 
  private:
-  EllpackPageImpl const* page_;
+  BatchParam batch_param_;
   float subsample_;
 };
 

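The notable change here is that `ExternalMemoryNoSampling` no longer receives a page or row count at construction; it keeps only a `BatchParam`, a lazily built `page_`, and a `page_concatenated_` flag. A rough sketch of the idea behind that pair of members, with simplified hypothetical types standing in for `EllpackPageImpl` and `DMatrix`:

    #include <memory>
    #include <vector>

    // Simplified stand-in for EllpackPageImpl: something that can absorb batches.
    struct Page {
      std::vector<int> data;
      void Append(Page const& other) {
        data.insert(data.end(), other.data.begin(), other.data.end());
      }
    };

    class LazyConcatNoSampling {
     public:
      // `batches` stands in for iterating dmat->GetBatches<EllpackPage>(...).
      Page const* Sample(std::vector<Page> const& batches) {
        if (!page_concatenated_) {
          // Build the full page once, on first use; later calls reuse it.
          page_ = std::make_unique<Page>();
          for (auto const& batch : batches) {
            page_->Append(batch);
          }
          page_concatenated_ = true;
        }
        return page_.get();
      }

     private:
      std::unique_ptr<Page> page_{nullptr};
      bool page_concatenated_{false};
    };
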
@@ -84,13 +83,12 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
 /*! \brief Gradient-based sampling in in-memory mode. */
 class GradientBasedSampling : public SamplingStrategy {
  public:
-  GradientBasedSampling(EllpackPageImpl const* page, size_t n_rows, const BatchParam& batch_param,
-                        float subsample);
+  GradientBasedSampling(std::size_t n_rows, BatchParam batch_param, float subsample);
   GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
                              DMatrix* dmat) override;
 
  private:
-  EllpackPageImpl const* page_;
+  BatchParam batch_param_;
   float subsample_;
   dh::caching_device_vector<float> threshold_;
   dh::caching_device_vector<float> grad_sum_;

@@ -106,11 +104,11 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
  private:
   BatchParam batch_param_;
   float subsample_;
-  dh::caching_device_vector<float> threshold_;
-  dh::caching_device_vector<float> grad_sum_;
+  dh::device_vector<float> threshold_;
+  dh::device_vector<float> grad_sum_;
   std::unique_ptr<EllpackPageImpl> page_;
   dh::device_vector<GradientPair> gpair_;
-  dh::caching_device_vector<size_t> sample_row_index_;
+  dh::device_vector<size_t> sample_row_index_;
 };
 
 /*! \brief Draw a sample of rows from a DMatrix.

@@ -124,8 +122,8 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
  */
 class GradientBasedSampler {
  public:
-  GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
-                       const BatchParam& batch_param, float subsample, int sampling_method);
+  GradientBasedSampler(Context const* ctx, size_t n_rows, const BatchParam& batch_param,
+                       float subsample, int sampling_method, bool is_external_memory);
 
   /*! \brief Sample from a DMatrix based on the given gradient pairs. */
   GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);

@@ -176,7 +176,7 @@ struct GPUHistMakerDevice {
   Context const* ctx_;
 
  public:
-  EllpackPageImpl const* page;
+  EllpackPageImpl const* page{nullptr};
   common::Span<FeatureType const> feature_types;
   BatchParam batch_param;
 

@@ -205,41 +205,41 @@ struct GPUHistMakerDevice {
 
   std::unique_ptr<FeatureGroups> feature_groups;
 
-  GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
-                     common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
+  GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
+                     common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
                      TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
                      BatchParam _batch_param)
       : evaluator_{_param, n_features, ctx->gpu_id},
         ctx_(ctx),
-        page(_page),
         feature_types{_feature_types},
         param(std::move(_param)),
         column_sampler(column_sampler_seed),
         interaction_constraints(param, n_features),
         batch_param(std::move(_batch_param)) {
-    sampler.reset(new GradientBasedSampler(ctx, page, _n_rows, batch_param, param.subsample,
-                                           param.sampling_method));
+    sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample,
+                                           param.sampling_method, is_external_memory));
     if (!param.monotone_constraints.empty()) {
       // Copy assigning an empty vector causes an exception in MSVC debug builds
       monotone_constraints = param.monotone_constraints;
     }
 
-    // Init histogram
-    hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
     monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
-    feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
-                                           dh::MaxSharedMemoryOptin(ctx_->gpu_id),
-                                           sizeof(GradientSumT)));
   }
 
   ~GPUHistMakerDevice() {  // NOLINT
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
   }
 
+  void InitFeatureGroupsOnce() {
+    if (!feature_groups) {
+      CHECK(page);
+      feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
+                                             dh::MaxSharedMemoryOptin(ctx_->gpu_id),
+                                             sizeof(GradientSumT)));
+    }
+  }
+
   // Reset values for each update iteration
-  // Note that the column sampler must be passed by value because it is not
-  // thread safe
   void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
     auto const& info = dmat->Info();
     this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(),

@@ -247,26 +247,30 @@ struct GPUHistMakerDevice {
                               param.colsample_bytree);
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
 
-    this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
-                           ctx_->gpu_id);
-
     this->interaction_constraints.Reset();
 
     if (d_gpair.size() != dh_gpair->Size()) {
       d_gpair.resize(dh_gpair->Size());
     }
-    dh::safe_cuda(cudaMemcpyAsync(
-        d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
-        dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice));
+    dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
+                                  dh_gpair->Size() * sizeof(GradientPair),
+                                  cudaMemcpyDeviceToDevice));
     auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat);
     page = sample.page;
     gpair = sample.gpair;
 
+    this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
+
     quantiser.reset(new GradientQuantiser(this->gpair));
 
     row_partitioner.reset();  // Release the device memory first before reallocating
     row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
 
+    // Init histogram
+    hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
     hist.Reset();
+
+    this->InitFeatureGroupsOnce();
   }
 
   GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {

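With the page no longer available at construction time, anything derived from it (histogram size, feature groups, the evaluator's cuts) has to wait until `Reset()` has pulled the first sample, and the guard in `InitFeatureGroupsOnce()` makes that initialisation idempotent across boosting iterations. A minimal sketch of the call-once pattern, with hypothetical placeholder types:

    #include <cassert>
    #include <memory>

    struct FeatureGroupsLike {};  // placeholder for state derived from the page

    class DeviceMakerLike {
     public:
      void Reset(void const* page) {
        page_ = page;
        InitFeatureGroupsOnce();  // a cheap no-op on every call after the first
      }

     private:
      void InitFeatureGroupsOnce() {
        if (!feature_groups_) {
          assert(page_ != nullptr);  // mirrors the CHECK(page) above
          feature_groups_ = std::make_unique<FeatureGroupsLike>();
        }
      }

      void const* page_{nullptr};
      std::unique_ptr<FeatureGroupsLike> feature_groups_{nullptr};
    };
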
@@ -808,12 +812,11 @@ class GPUHistMaker : public TreeUpdater {
     collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
 
     auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
-    auto page = (*dmat->GetBatches<EllpackPage>(ctx_, batch_param).begin()).Impl();
     dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     info_->feature_types.SetDevice(ctx_->gpu_id);
     maker.reset(new GPUHistMakerDevice<GradientSumT>(
-        ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, *param,
-        column_sampling_seed, info_->num_col_, batch_param));
+        ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
+        *param, column_sampling_seed, info_->num_col_, batch_param));
 
     p_last_fmat_ = dmat;
     initialised_ = true;

@@ -1,5 +1,5 @@
-/*!
- * Copyright (c) by XGBoost Contributors 2019
+/**
+ * Copyright 2019-2023, XGBoost Contributors
  */
 #include <gtest/gtest.h>
 
@@ -9,8 +9,7 @@
 #include "../helpers.h"
 #include "../filesystem.h"  // dmlc::TemporaryDirectory
 
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
 TEST(MemoryFixSizeBuffer, Seek) {
   size_t constexpr kSize { 64 };
   std::vector<int32_t> memory( kSize );

@@ -89,5 +88,54 @@ TEST(IO, LoadSequentialFile) {
 
   ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
 }
-}  // namespace common
-}  // namespace xgboost
+
+TEST(IO, PrivateMmapStream) {
+  dmlc::TemporaryDirectory tempdir;
+  auto path = tempdir.path + "/testfile";
+
+  // The page size on Linux is usually 4096 bytes, while the allocation
+  // granularity on the Windows machine where this test was written is 65536
+  // bytes. We size the test to cover both.
+  std::size_t n_batches{64};
+  std::size_t multiplier{2048};
+
+  std::vector<std::vector<std::int32_t>> batches;
+  std::vector<std::size_t> offset{0ul};
+
+  using T = std::int32_t;
+
+  {
+    std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
+    for (std::size_t i = 0; i < n_batches; ++i) {
+      std::size_t size = (i + 1) * multiplier;
+      std::vector<T> data(size, 0);
+      std::iota(data.begin(), data.end(), i * i);
+
+      fo->Write(static_cast<std::uint64_t>(data.size()));
+      fo->Write(data.data(), data.size() * sizeof(T));
+
+      std::size_t bytes = sizeof(std::uint64_t) + data.size() * sizeof(T);
+      offset.push_back(bytes);
+
+      batches.emplace_back(std::move(data));
+    }
+  }
+
+  // Turn sizes into offsets.
+  std::partial_sum(offset.begin(), offset.end(), offset.begin());
+
+  for (std::size_t i = 0; i < n_batches; ++i) {
+    std::size_t off = offset[i];
+    std::size_t n = offset.at(i + 1) - offset[i];
+    std::unique_ptr<dmlc::Stream> fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
+    std::vector<T> data;
+
+    std::uint64_t size{0};
+    fi->Read(&size);
+    data.resize(size);
+
+    fi->Read(data.data(), size * sizeof(T));
+    ASSERT_EQ(data, batches[i]);
+  }
+}
+}  // namespace xgboost::common

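The new test round-trips each batch through `PrivateMmapConstStream`, which exposes a byte range of a file via a private, read-only memory mapping rather than buffered reads. As a generic POSIX illustration of what such a mapping involves (this is a sketch of the underlying mechanism, not XGBoost's implementation; note that `mmap` requires the file offset to be page-aligned, which is why the test comment above cares about page size and allocation granularity):

    #include <fcntl.h>     // open
    #include <sys/mman.h>  // mmap, munmap
    #include <unistd.h>    // close, sysconf

    #include <cstddef>

    // Map `length` bytes of `path` starting at `offset` as a private,
    // read-only view. mmap requires the offset to be a multiple of the page
    // size, so we map from the containing page boundary and return a pointer
    // into the mapping.
    const void* MapRegion(const char* path, std::size_t offset, std::size_t length,
                          void** mapping, std::size_t* mapped_bytes) {
      int fd = open(path, O_RDONLY);
      if (fd < 0) return nullptr;

      std::size_t page = static_cast<std::size_t>(sysconf(_SC_PAGESIZE));
      std::size_t aligned = offset - offset % page;  // round down to page boundary
      std::size_t extra = offset - aligned;          // bytes before the requested region

      *mapped_bytes = length + extra;
      *mapping = mmap(nullptr, *mapped_bytes, PROT_READ, MAP_PRIVATE, fd,
                      static_cast<off_t>(aligned));
      close(fd);  // the established mapping stays valid after the fd is closed
      if (*mapping == MAP_FAILED) return nullptr;
      return static_cast<const char*>(*mapping) + extra;
    }

    // Callers eventually release the view with: munmap(mapping, mapped_bytes);
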
@@ -2,6 +2,10 @@
 #include "../../src/data/ellpack_page.cuh"
 #endif
 
+#include <xgboost/data.h>  // for SparsePage
+
+#include "./helpers.h"  // for RandomDataGenerator
+
 namespace xgboost {
 #if defined(__CUDACC__)
 namespace {

@@ -39,7 +39,8 @@ void VerifySampling(size_t page_size,
     EXPECT_NE(page->n_rows, kRows);
   }
 
-  GradientBasedSampler sampler(&ctx, page, kRows, param, subsample, sampling_method);
+  GradientBasedSampler sampler(&ctx, kRows, param, subsample, sampling_method,
+                               !fixed_size_sampling);
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
 
   if (fixed_size_sampling) {

@@ -93,7 +94,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
   EXPECT_NE(page->n_rows, kRows);
 
-  GradientBasedSampler sampler(&ctx, page, kRows, param, kSubsample, TrainParam::kUniform);
+  GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
   auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
   auto sampled_page = sample.page;
   EXPECT_EQ(sample.sample_rows, kRows);

@@ -141,7 +142,8 @@ TEST(GradientBasedSampler, GradientBasedSampling) {
   constexpr size_t kPageSize = 0;
   constexpr float kSubsample = 0.8;
   constexpr int kSamplingMethod = TrainParam::kGradientBased;
-  VerifySampling(kPageSize, kSubsample, kSamplingMethod);
+  constexpr bool kFixedSizeSampling = true;
+  VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
 }
 
 TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) {

@@ -92,8 +92,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
   Context ctx{MakeCUDACtx(0)};
-  GPUHistMakerDevice<GradientSumT> maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols,
-                                         batch_param);
+  GPUHistMakerDevice<GradientSumT> maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param,
+                                         kNCols, kNCols, batch_param);
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   HostDeviceVector<GradientPair> gpair(kNRows);

@@ -106,9 +106,15 @@ void TestBuildHist(bool use_shared_memory_histograms) {
 
   thrust::host_vector<common::CompressedByteT> h_gidx_buffer(page->gidx_buffer.HostVector());
   maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
+
+  maker.hist.Init(0, page->Cuts().TotalBins());
   maker.hist.AllocateHistograms({0});
 
   maker.gpair = gpair.DeviceSpan();
   maker.quantiser.reset(new GradientQuantiser(maker.gpair));
+  maker.page = page.get();
+
+  maker.InitFeatureGroupsOnce();
 
   BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
                          maker.feature_groups->DeviceAccessor(0), gpair.DeviceSpan(),

@@ -126,8 +132,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   std::vector<GradientPairPrecise> solution = GetHostHistGpair();
   for (size_t i = 0; i < h_result.size(); ++i) {
     auto result = maker.quantiser->ToFloatingPoint(h_result[i]);
-    EXPECT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
-    EXPECT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
+    ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
+    ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
   }
 }

@@ -305,7 +305,7 @@ class IterForDMatrixTest(xgb.core.DataIter):
         self._labels = [rng.randn(self.rows)] * self.BATCHES

         self.it = 0  # set iterator to 0
-        super().__init__()
+        super().__init__(cache_prefix=None)

     def as_array(self):
         import cudf

@@ -64,7 +64,8 @@ def run_data_iterator(
     subsample_rate = 0.8 if subsample else 1.0

     it = IteratorForTest(
-        *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy)
+        *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
+        cache="cache"
     )
     if n_batches == 0:
         with pytest.raises(ValueError, match="1 batch"):