Use mmap for external memory. (#9282)

- Have basic infrastructure for mmap.
- Release file write handle.
Jiaming Yuan 2023-06-19 18:52:55 +08:00 committed by GitHub
parent d8beb517ed
commit ee6809e642
16 changed files with 599 additions and 275 deletions

View File

@ -82,10 +82,10 @@ def main(tmpdir: str) -> xgboost.Booster:
missing = np.NaN missing = np.NaN
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False) Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
# Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in # Other tree methods including ``approx``, and ``gpu_hist`` are supported. GPU
# doc for details. # behaves differently than CPU tree methods. See tutorial in doc for details.
booster = xgboost.train( booster = xgboost.train(
{"tree_method": "approx", "max_depth": 2}, {"tree_method": "hist", "max_depth": 4},
Xy, Xy,
evals=[(Xy, "Train")], evals=[(Xy, "Train")],
num_boost_round=10, num_boost_round=10,

View File

@ -2,11 +2,25 @@
Using XGBoost External Memory Version Using XGBoost External Memory Version
##################################### #####################################
XGBoost supports loading data from external memory using builtin data parser. And When working with large datasets, training XGBoost models can be challenging as the entire
starting from version 1.5, users can also define a custom iterator to load data in chunks. dataset needs to be loaded into memory. This can be costly and sometimes
The feature is still experimental and not yet ready for production use. In this tutorial infeasible. Starting from 1.5, users can define a custom iterator to load data in chunks
we will introduce both methods. Please note that training on data from external memory is for running XGBoost algorithms. External memory can be used for both training and
not supported by ``exact`` tree method. prediction, but training is the primary use case and it will be our focus in this
tutorial. For prediction and evaluation, users can iterate through the data themseleves
while training requires the full dataset to be loaded into the memory.
During training, there are two different modes for external memory support available in
XGBoost: one for CPU-based algorithms like ``hist`` and ``approx``, and another for the
GPU-based training algorithm. We will introduce them in the following sections.
.. note::
Training on data from external memory is not supported by the ``exact`` tree method.
.. note::
The feature is still experimental as of 2.0. The performance is not well optimized.
************* *************
Data Iterator Data Iterator
@ -15,8 +29,8 @@ Data Iterator
Starting from XGBoost 1.5, users can define their own data loader using Python or C Starting from XGBoost 1.5, users can define their own data loader using Python or C
interface. There are some examples in the ``demo`` directory for quick start. This is a interface. There are some examples in the ``demo`` directory for quick start. This is a
generalized version of text input external memory, where users no longer need to prepare a generalized version of text input external memory, where users no longer need to prepare a
text file that XGBoost recognizes. To enable the feature, user need to define a data text file that XGBoost recognizes. To enable the feature, users need to define a data
iterator with 2 class methods ``next`` and ``reset`` then pass it into ``DMatrix`` iterator with 2 class methods: ``next`` and ``reset``, then pass it into the ``DMatrix``
constructor. constructor.
.. code-block:: python .. code-block:: python
@ -60,20 +74,96 @@ constructor.
# Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some caveats # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some caveats
# as noted in following sections. # as noted in following sections.
booster = xgboost.train({"tree_method": "approx"}, Xy) booster = xgboost.train({"tree_method": "hist"}, Xy)
The above snippet is a simplified version of ``demo/guide-python/external_memory.py``. For The above snippet is a simplified version of :ref:`sphx_glr_python_examples_external_memory.py`.
an example in C, please see ``demo/c-api/external-memory/``. For an example in C, please see ``demo/c-api/external-memory/``. The iterator is the
common interface for using external memory with XGBoost; you can pass the resulting
``DMatrix`` object for training, prediction, and evaluation.
It is important to set the batch size based on the memory available. A good starting point
is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not*
recommended to set small batch sizes like 32 samples per batch, as this can seriously hurt
performance in gradient boosting.
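As a rough illustration, the following sketch yields a handful of large synthetic batches rather than many small ones. The class name, its parameters, and the chosen batch sizes are made up for illustration only and are not part of XGBoost:

.. code-block:: python

    import numpy as np
    import xgboost

    class LargeBatchIter(xgboost.DataIter):
        """Illustrative iterator that yields a few large batches."""

        def __init__(self, n_batches: int, rows_per_batch: int, n_features: int) -> None:
            self.n_batches = n_batches
            self.rows_per_batch = rows_per_batch
            self.n_features = n_features
            self.it = 0
            # ``cache_prefix`` tells XGBoost where to place the on-disk cache.
            super().__init__(cache_prefix="cache")

        def next(self, input_data) -> int:
            if self.it == self.n_batches:
                return 0  # no more batches, stop iteration
            rng = np.random.default_rng(self.it)
            X = rng.normal(size=(self.rows_per_batch, self.n_features))
            y = rng.normal(size=self.rows_per_batch)
            input_data(data=X, label=y)
            self.it += 1
            return 1

        def reset(self) -> None:
            self.it = 0

    # A few large batches instead of many tiny ones.
    it = LargeBatchIter(n_batches=4, rows_per_batch=1_000_000, n_features=16)
    Xy = xgboost.DMatrix(it, missing=np.nan)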
***********
CPU Version
***********
In the previous section, we demonstrated how to train a tree-based model using the
``hist`` tree method on a CPU. This method involves iterating through data batches stored
in a cache during tree construction. For optimal performance, we recommend using the
``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer of tree
nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy
requires XGBoost to iterate over the data set for each tree node, resulting in slower
performance.
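As a minimal sketch (assuming ``Xy`` is the external-memory ``DMatrix`` constructed earlier), the recommended settings might look like:

.. code-block:: python

    booster = xgboost.train(
        {
            "tree_method": "hist",
            # Depth-wise growth builds an entire level of the tree with only a
            # few passes over the cached batches.
            "grow_policy": "depthwise",
            "max_depth": 6,
        },
        Xy,
        num_boost_round=10,
    )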
If external memory is used, the performance of CPU training is limited by IO
(input/output) speed. This means that the disk IO speed primarily determines the training
speed. During benchmarking, we used an NVMe connected to a PCIe-4 slot; other types of
storage can be too slow for practical usage. In addition, your system may perform caching
to reduce the overhead of file reading.
**********************************
GPU Version (GPU Hist tree method)
**********************************
External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set to
``gpu_hist``). However, the algorithm used for GPU is different from the one used for
CPU. When training on a CPU, the tree method iterates through all batches from external
memory for each step of the tree construction algorithm. On the other hand, the GPU
algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
memory usage, users can utilize subsampling. The good news is that the GPU hist tree
method supports gradient-based sampling, enabling users to set a low sampling rate without
compromising accuracy.
.. code-block:: python
param = {
...
'subsample': 0.2,
'sampling_method': 'gradient_based',
}
For more information about the sampling algorithm and its use in external memory training,
see `this paper <https://arxiv.org/abs/2005.09148>`_.
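Putting it together, a hedged end-to-end sketch (assuming ``it`` is a data iterator as defined earlier) might look like the following:

.. code-block:: python

    import numpy as np
    import xgboost

    Xy = xgboost.DMatrix(it, missing=np.nan)
    booster = xgboost.train(
        {
            "tree_method": "gpu_hist",
            # Gradient-based sampling keeps only a fraction of the rows in GPU memory.
            "subsample": 0.2,
            "sampling_method": "gradient_based",
        },
        Xy,
        num_boost_round=10,
    )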
.. warning::
When the GPU runs out of memory during iteration on external memory, users might
receive a segfault instead of an OOM exception.
*******
Remarks
*******
When using external memory with XGBoost, data is divided into smaller chunks so that only
a fraction of it needs to be stored in memory at any given time. It's important to note
that this method only applies to the predictor data (``X``), while other data, like labels
and internal runtime structures, are concatenated. This means that memory reduction is most
effective when dealing with wide datasets where ``X`` is larger compared to other data
like ``y``, while it has little impact on slim datasets.
Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not
yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's
worth noting that most tests have been conducted on Linux distributions.
Another important point to keep in mind is that creating the initial cache for XGBoost may
take some time. The interface to external memory is through custom iterators, which may or
may not be thread-safe. Therefore, initialization is performed sequentially.
**************** ****************
Text File Inputs Text File Inputs
**************** ****************
There is no big difference between using external memory version and in-memory version. This is the original form of external memory support; users are encouraged to use the custom
The only difference is the filename format. data iterator instead. There is no big difference between using the external memory version of
text input and the in-memory version. The only difference is the filename format.
The external memory version takes in the following `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format: The external memory version takes in the following `URI
<https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format:
.. code-block:: none .. code-block:: none
@ -91,9 +181,8 @@ To load from csv files, use the following syntax:
where ``label_column`` should point to the csv column acting as the label. where ``label_column`` should point to the csv column acting as the label.
To provide a simple example for illustration, extracting the code from If you have a dataset stored in a file similar to ``demo/data/agaricus.txt.train`` with LIBSVM
`demo/guide-python/external_memory.py <https://github.com/dmlc/xgboost/blob/master/demo/guide-python/external_memory.py>`_. If format, the external memory support can be enabled by:
you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSVM format, the external memory support can be enabled by:
.. code-block:: python .. code-block:: python
@ -104,35 +193,3 @@ XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to
more notes about text input formats, see :doc:`/tutorials/input_format`. more notes about text input formats, see :doc:`/tutorials/input_format`.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``. For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.
**********************************
GPU Version (GPU Hist tree method)
**********************************
External memory is supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``).
If you are still getting out-of-memory errors after enabling external memory, try subsampling the
data to further reduce GPU memory usage:
.. code-block:: python
param = {
...
'subsample': 0.1,
'sampling_method': 'gradient_based',
}
For more information, see `this paper <https://arxiv.org/abs/2005.09148>`_. Internally
the tree method still concatenate all the chunks into 1 final histogram index due to
performance reason, but in compressed format. So its scalability has an upper bound but
still has lower memory cost in general.
***********
CPU Version
***********
For CPU histogram based tree methods (``approx``, ``hist``) it's recommended to use
``grow_policy=depthwise`` for performance reason. Iterating over data batches is slow,
with ``depthwise`` policy XGBoost can build a entire layer of tree nodes with a few
iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each
tree node.

View File

@ -198,14 +198,14 @@ class IteratorForTest(xgb.core.DataIter):
X: Sequence, X: Sequence,
y: Sequence, y: Sequence,
w: Optional[Sequence], w: Optional[Sequence],
cache: Optional[str] = "./", cache: Optional[str],
) -> None: ) -> None:
assert len(X) == len(y) assert len(X) == len(y)
self.X = X self.X = X
self.y = y self.y = y
self.w = w self.w = w
self.it = 0 self.it = 0
super().__init__(cache) super().__init__(cache_prefix=cache)
def next(self, input_data: Callable) -> int: def next(self, input_data: Callable) -> int:
if self.it == len(self.X): if self.it == len(self.X):
@ -347,7 +347,9 @@ class TestDataset:
if w is not None: if w is not None:
weight.append(w) weight.append(w)
it = IteratorForTest(predictor, response, weight if weight else None) it = IteratorForTest(
predictor, response, weight if weight else None, cache="cache"
)
return xgb.DMatrix(it) return xgb.DMatrix(it)
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@ -1,18 +1,21 @@
/*! /**
* Copyright (c) 2014-2019 by Contributors * Copyright 2014-2023, XGBoost Contributors
* \file io.h * \file io.h
* \brief utilities with different serializable implementations * \brief utilities with different serializable implementations
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef RABIT_INTERNAL_IO_H_ #ifndef RABIT_INTERNAL_IO_H_
#define RABIT_INTERNAL_IO_H_ #define RABIT_INTERNAL_IO_H_
#include <cstdio>
#include <vector>
#include <cstring>
#include <string>
#include <algorithm> #include <algorithm>
#include <numeric> #include <cstddef> // for size_t
#include <cstdio>
#include <cstring> // for memcpy
#include <limits> #include <limits>
#include <numeric>
#include <string>
#include <vector>
#include "rabit/internal/utils.h" #include "rabit/internal/utils.h"
#include "rabit/serializable.h" #include "rabit/serializable.h"
@ -20,54 +23,61 @@ namespace rabit {
namespace utils { namespace utils {
/*! \brief re-use definition of dmlc::SeekStream */ /*! \brief re-use definition of dmlc::SeekStream */
using SeekStream = dmlc::SeekStream; using SeekStream = dmlc::SeekStream;
/*! \brief fixed size memory buffer */ /**
* @brief Fixed size memory buffer as a stream.
*/
struct MemoryFixSizeBuffer : public SeekStream { struct MemoryFixSizeBuffer : public SeekStream {
public: public:
// similar to SEEK_END in libc // similar to SEEK_END in libc
static size_t constexpr kSeekEnd = std::numeric_limits<size_t>::max(); static std::size_t constexpr kSeekEnd = std::numeric_limits<std::size_t>::max();
protected:
MemoryFixSizeBuffer() = default;
public: public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) /**
: p_buffer_(reinterpret_cast<char*>(p_buffer)), * @brief Ctor
buffer_size_(buffer_size) { *
curr_ptr_ = 0; * @param p_buffer Pointer to the source buffer with size `buffer_size`.
} * @param buffer_size Size of the source buffer
*/
MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size)
: p_buffer_(reinterpret_cast<char *>(p_buffer)), buffer_size_(buffer_size) {}
~MemoryFixSizeBuffer() override = default; ~MemoryFixSizeBuffer() override = default;
size_t Read(void *ptr, size_t size) override {
size_t nread = std::min(buffer_size_ - curr_ptr_, size); std::size_t Read(void *ptr, std::size_t size) override {
std::size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread; curr_ptr_ += nread;
return nread; return nread;
} }
void Write(const void *ptr, size_t size) override { void Write(const void *ptr, std::size_t size) override {
if (size == 0) return; if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_, CHECK_LE(curr_ptr_ + size, buffer_size_);
"write position exceed fixed buffer size");
std::memcpy(p_buffer_ + curr_ptr_, ptr, size); std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size; curr_ptr_ += size;
} }
void Seek(size_t pos) override { void Seek(std::size_t pos) override {
if (pos == kSeekEnd) { if (pos == kSeekEnd) {
curr_ptr_ = buffer_size_; curr_ptr_ = buffer_size_;
} else { } else {
curr_ptr_ = static_cast<size_t>(pos); curr_ptr_ = static_cast<std::size_t>(pos);
} }
} }
size_t Tell() override { /**
return curr_ptr_; * @brief Current position in the buffer (stream).
} */
virtual bool AtEnd() const { std::size_t Tell() override { return curr_ptr_; }
return curr_ptr_ == buffer_size_; virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; }
}
private: protected:
/*! \brief in memory buffer */ /*! \brief in memory buffer */
char *p_buffer_; char *p_buffer_{nullptr};
/*! \brief current pointer */ /*! \brief current pointer */
size_t buffer_size_; std::size_t buffer_size_{0};
/*! \brief current pointer */ /*! \brief current pointer */
size_t curr_ptr_; std::size_t curr_ptr_{0};
}; // class MemoryFixSizeBuffer };
/*! \brief a in memory buffer that can be read and write as stream interface */ /*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public SeekStream { struct MemoryBufferStream : public SeekStream {

View File

@ -1,24 +1,47 @@
/*! /**
* Copyright (c) by XGBoost Contributors 2019-2022 * Copyright 2019-2023, by XGBoost Contributors
*/ */
#if defined(__unix__) #if !defined(NOMINMAX) && defined(_WIN32)
#include <sys/stat.h> #define NOMINMAX
#include <fcntl.h> #endif // !defined(NOMINMAX)
#include <unistd.h>
#if !defined(xgboost_IS_WIN)
#if defined(_MSC_VER) || defined(__MINGW32__)
#define xgboost_IS_WIN 1
#endif // defined(_MSC_VER) || defined(__MINGW32__)
#endif // !defined(xgboost_IS_WIN)
#if defined(__unix__) || defined(__APPLE__)
#include <fcntl.h> // for open, O_RDONLY
#include <sys/mman.h> // for mmap, mmap64, munmap
#include <unistd.h> // for close, getpagesize
#elif defined(xgboost_IS_WIN)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif // defined(__unix__) #endif // defined(__unix__)
#include <algorithm>
#include <fstream>
#include <string>
#include <memory>
#include <utility>
#include <cstdio>
#include "xgboost/logging.h" #include <algorithm> // for copy, transform
#include <cctype> // for tolower
#include <cerrno> // for errno
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint32_t
#include <cstring> // for memcpy
#include <fstream> // for ifstream
#include <iterator> // for distance
#include <limits> // for numeric_limits
#include <memory> // for unique_ptr
#include <string> // for string
#include <system_error> // for error_code, system_category
#include <utility> // for move
#include <vector> // for vector
#include "io.h" #include "io.h"
#include "xgboost/collective/socket.h" // for LastError
#include "xgboost/logging.h"
namespace xgboost { namespace xgboost::common {
namespace common {
size_t PeekableInStream::Read(void* dptr, size_t size) { size_t PeekableInStream::Read(void* dptr, size_t size) {
size_t nbuffer = buffer_.length() - buffer_ptr_; size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer == 0) return strm_->Read(dptr, size); if (nbuffer == 0) return strm_->Read(dptr, size);
@ -94,11 +117,32 @@ void FixedSizeStream::Take(std::string* out) {
*out = std::move(buffer_); *out = std::move(buffer_);
} }
namespace {
// Get system alignment value for IO with mmap.
std::size_t GetMmapAlignment() {
#if defined(xgboost_IS_WIN)
SYSTEM_INFO sys_info;
GetSystemInfo(&sys_info);
// During testing, `sys_info.dwPageSize` is of size 4096 while `dwAllocationGranularity` is of
// size 65536.
return sys_info.dwAllocationGranularity;
#else
return getpagesize();
#endif
}
auto SystemErrorMsg() {
std::int32_t errsv = system::LastError();
auto err = std::error_code{errsv, std::system_category()};
return err.message();
}
} // anonymous namespace
std::string LoadSequentialFile(std::string uri, bool stream) { std::string LoadSequentialFile(std::string uri, bool stream) {
auto OpenErr = [&uri]() { auto OpenErr = [&uri]() {
std::string msg; std::string msg;
msg = "Opening " + uri + " failed: "; msg = "Opening " + uri + " failed: ";
msg += strerror(errno); msg += SystemErrorMsg();
LOG(FATAL) << msg; LOG(FATAL) << msg;
}; };
@ -155,5 +199,99 @@ std::string FileExtension(std::string fname, bool lower) {
return ""; return "";
} }
} }
} // namespace common
} // namespace xgboost struct PrivateMmapConstStream::MMAPFile {
#if defined(xgboost_IS_WIN)
HANDLE fd{INVALID_HANDLE_VALUE};
HANDLE file_map{INVALID_HANDLE_VALUE};
#else
std::int32_t fd{0};
#endif
char* base_ptr{nullptr};
std::size_t base_size{0};
std::string path;
};
char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::size_t length) {
if (length == 0) {
return nullptr;
}
#if defined(xgboost_IS_WIN)
HANDLE fd = CreateFile(path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING,
FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, nullptr);
CHECK_NE(fd, INVALID_HANDLE_VALUE) << "Failed to open:" << path << ". " << SystemErrorMsg();
#else
auto fd = open(path.c_str(), O_RDONLY);
CHECK_GE(fd, 0) << "Failed to open:" << path << ". " << SystemErrorMsg();
#endif
char* ptr{nullptr};
// Round down for alignment.
auto view_start = offset / GetMmapAlignment() * GetMmapAlignment();
auto view_size = length + (offset - view_start);
#if defined(__linux__) || defined(__GLIBC__)
int prot{PROT_READ};
ptr = reinterpret_cast<char*>(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)});
#elif defined(xgboost_IS_WIN)
auto file_size = GetFileSize(fd, nullptr);
DWORD access = PAGE_READONLY;
auto map_file = CreateFileMapping(fd, nullptr, access, 0, file_size, nullptr);
access = FILE_MAP_READ;
std::uint32_t loff = static_cast<std::uint32_t>(view_start);
std::uint32_t hoff = view_start >> 32;
CHECK(map_file) << "Failed to map: " << path << ". " << SystemErrorMsg();
ptr = reinterpret_cast<char*>(MapViewOfFile(map_file, access, hoff, loff, view_size));
CHECK_NE(ptr, nullptr) << "Failed to map: " << path << ". " << SystemErrorMsg();
handle_.reset(new MMAPFile{fd, map_file, ptr, view_size, std::move(path)});
#else
CHECK_LE(offset, std::numeric_limits<off_t>::max())
<< "File size has exceeded the limit on the current system.";
int prot{PROT_READ};
ptr = reinterpret_cast<char*>(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start));
CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg();
handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)});
#endif // defined(__linux__)
ptr += (offset - view_start);
return ptr;
}
PrivateMmapConstStream::PrivateMmapConstStream(std::string path, std::size_t offset,
std::size_t length)
: MemoryFixSizeBuffer{}, handle_{nullptr} {
this->p_buffer_ = Open(std::move(path), offset, length);
this->buffer_size_ = length;
}
PrivateMmapConstStream::~PrivateMmapConstStream() {
CHECK(handle_);
#if defined(xgboost_IS_WIN)
if (p_buffer_) {
CHECK(UnmapViewOfFile(handle_->base_ptr)) << "Failed to call UnmapViewOfFile: " << SystemErrorMsg();
}
if (handle_->fd != INVALID_HANDLE_VALUE) {
CHECK(CloseHandle(handle_->fd)) << "Failed to close handle: " << SystemErrorMsg();
}
if (handle_->file_map != INVALID_HANDLE_VALUE) {
CHECK(CloseHandle(handle_->file_map)) << "Failed to close mapping object: " << SystemErrorMsg();
}
#else
if (handle_->base_ptr) {
CHECK_NE(munmap(handle_->base_ptr, handle_->base_size), -1)
<< "Faled to call munmap: " << handle_->path << ". " << SystemErrorMsg();
}
if (handle_->fd != 0) {
CHECK_NE(close(handle_->fd), -1)
<< "Faled to close: " << handle_->path << ". " << SystemErrorMsg();
}
#endif
}
} // namespace xgboost::common
#if defined(xgboost_IS_WIN)
#undef xgboost_IS_WIN
#endif // defined(xgboost_IS_WIN)

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright by XGBoost Contributors 2014-2022 * Copyright 2014-2023, XGBoost Contributors
* \file io.h * \file io.h
* \brief general stream interface for serialization, I/O * \brief general stream interface for serialization, I/O
* \author Tianqi Chen * \author Tianqi Chen
@ -10,9 +10,11 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <string>
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
#include <memory> // for unique_ptr
#include <string> // for string
#include "common.h" #include "common.h"
@ -127,6 +129,31 @@ inline std::string ReadAll(std::string const &path) {
return content; return content;
} }
/**
* @brief Private mmap file as a read-only stream.
*
* It can calculate alignment automatically based on system page size (or allocation
* granularity on Windows).
*/
class PrivateMmapConstStream : public MemoryFixSizeBuffer {
struct MMAPFile;
std::unique_ptr<MMAPFile> handle_;
char* Open(std::string path, std::size_t offset, std::size_t length);
public:
/**
* @brief Construct a private mmap stream.
*
* @param path File path.
* @param offset See the `offset` parameter of `mmap` for details.
* @param length See the `length` parameter of `mmap` for details.
*/
explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length);
void Write(void const*, std::size_t) override { LOG(FATAL) << "Read-only stream."; }
~PrivateMmapConstStream() override;
};
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_ #endif // XGBOOST_COMMON_IO_H_

View File

@ -1,35 +1,34 @@
/*! /**
* Copyright 2014-2022 by XGBoost Contributors * Copyright 2014-2023, XGBoost Contributors
* \file sparse_page_source.h * \file sparse_page_source.h
*/ */
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#include <algorithm> // std::min #include <algorithm> // for min
#include <string> #include <future> // async
#include <utility>
#include <vector>
#include <future>
#include <thread>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "../common/common.h"
#include "../common/io.h" // for PrivateMmapStream, PadPageForMMAP
#include "../common/timer.h" // for Monitor, Timer
#include "adapter.h"
#include "dmlc/common.h" // OMPException
#include "proxy_dmatrix.h"
#include "sparse_page_writer.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "adapter.h" namespace xgboost::data {
#include "sparse_page_writer.h"
#include "proxy_dmatrix.h"
#include "../common/common.h"
#include "../common/timer.h"
namespace xgboost {
namespace data {
inline void TryDeleteCacheFile(const std::string& file) { inline void TryDeleteCacheFile(const std::string& file) {
if (std::remove(file.c_str()) != 0) { if (std::remove(file.c_str()) != 0) {
LOG(WARNING) << "Couldn't remove external memory cache file " << file LOG(WARNING) << "Couldn't remove external memory cache file " << file
<< "; you may want to remove it manually"; << "; you may want to remove it manually";
} }
} }
@ -54,6 +53,9 @@ struct Cache {
std::string ShardName() { std::string ShardName() {
return ShardName(this->name, this->format); return ShardName(this->name, this->format);
} }
void Push(std::size_t n_bytes) {
offset.push_back(n_bytes);
}
// The write is completed. // The write is completed.
void Commit() { void Commit() {
@ -95,56 +97,72 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
uint32_t n_batches_ {0}; uint32_t n_batches_ {0};
std::shared_ptr<Cache> cache_info_; std::shared_ptr<Cache> cache_info_;
std::unique_ptr<dmlc::Stream> fo_;
using Ring = std::vector<std::future<std::shared_ptr<S>>>; using Ring = std::vector<std::future<std::shared_ptr<S>>>;
// A ring storing futures to data. Since the DMatrix iterator is forward only, so we // A ring storing futures to data. Since the DMatrix iterator is forward only, so we
// can pre-fetch data in a ring. // can pre-fetch data in a ring.
std::unique_ptr<Ring> ring_{new Ring}; std::unique_ptr<Ring> ring_{new Ring};
dmlc::OMPException exec_;
common::Monitor monitor_;
bool ReadCache() { bool ReadCache() {
CHECK(!at_end_); CHECK(!at_end_);
if (!cache_info_->written) { if (!cache_info_->written) {
return false; return false;
} }
if (fo_) { if (ring_->empty()) {
fo_.reset(); // flush the data to disk.
ring_->resize(n_batches_); ring_->resize(n_batches_);
} }
// An heuristic for number of pre-fetched batches. We can make it part of BatchParam // An heuristic for number of pre-fetched batches. We can make it part of BatchParam
// to let user adjust number of pre-fetched batches when needed. // to let user adjust number of pre-fetched batches when needed.
uint32_t constexpr kPreFetch = 4; uint32_t constexpr kPreFetch = 3;
size_t n_prefetch_batches = std::min(kPreFetch, n_batches_); size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
size_t fetch_it = count_; std::size_t fetch_it = count_;
for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { exec_.Rethrow();
monitor_.Start("launch");
for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_; // ring fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) { if (ring_->at(fetch_it).valid()) {
continue; continue;
} }
auto const *self = this; // make sure it's const auto const* self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size()); CHECK_LT(fetch_it, cache_info_->offset.size());
ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() { ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() {
common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
auto n = self->cache_info_->ShardName();
size_t offset = self->cache_info_->offset.at(fetch_it);
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(n.c_str())};
fi->Seek(offset);
CHECK_EQ(fi->Tell(), offset);
auto page = std::make_shared<S>(); auto page = std::make_shared<S>();
CHECK(fmt->Read(page.get(), fi.get())); this->exec_.Run([&] {
LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds."; common::Timer timer;
timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
auto n = self->cache_info_->ShardName();
std::uint64_t offset = self->cache_info_->offset.at(fetch_it);
std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset;
auto fi = std::make_unique<common::PrivateMmapConstStream>(n, offset, length);
CHECK(fmt->Read(page.get(), fi.get()));
timer.Stop();
LOG(INFO) << "Read a page `" << typeid(S).name() << "` in " << timer.ElapsedSeconds()
<< " seconds.";
});
return page; return page;
}); });
} }
monitor_.Stop("launch");
CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }),
n_prefetch_batches) n_prefetch_batches)
<< "Sparse DMatrix assumes forward iteration."; << "Sparse DMatrix assumes forward iteration.";
monitor_.Start("Wait");
page_ = (*ring_)[count_].get(); page_ = (*ring_)[count_].get();
monitor_.Stop("Wait");
CHECK(!(*ring_)[count_].valid());
exec_.Rethrow();
return true; return true;
} }
@ -153,25 +171,35 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
common::Timer timer; common::Timer timer;
timer.Start(); timer.Start();
std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")}; std::unique_ptr<SparsePageFormat<S>> fmt{CreatePageFormat<S>("raw")};
if (!fo_) {
auto n = cache_info_->ShardName();
fo_.reset(dmlc::Stream::Create(n.c_str(), "w"));
}
auto bytes = fmt->Write(*page_, fo_.get());
timer.Stop();
auto name = cache_info_->ShardName();
std::unique_ptr<dmlc::Stream> fo;
if (this->Iter() == 0) {
fo.reset(dmlc::Stream::Create(name.c_str(), "wb"));
} else {
fo.reset(dmlc::Stream::Create(name.c_str(), "ab"));
}
auto bytes = fmt->Write(*page_, fo.get());
timer.Stop();
LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in " LOG(INFO) << static_cast<double>(bytes) / 1024.0 / 1024.0 << " MB written in "
<< timer.ElapsedSeconds() << " seconds."; << timer.ElapsedSeconds() << " seconds.";
cache_info_->offset.push_back(bytes); cache_info_->Push(bytes);
} }
virtual void Fetch() = 0; virtual void Fetch() = 0;
public: public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches,
uint32_t n_batches, std::shared_ptr<Cache> cache) std::shared_ptr<Cache> cache)
: missing_{missing}, nthreads_{nthreads}, n_features_{n_features}, : missing_{missing},
n_batches_{n_batches}, cache_info_{std::move(cache)} {} nthreads_{nthreads},
n_features_{n_features},
n_batches_{n_batches},
cache_info_{std::move(cache)} {
monitor_.Init(typeid(S).name()); // not pretty, but works for basic profiling
}
SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete; SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete;
@ -244,7 +272,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
iter_{iter}, proxy_{proxy} { iter_{iter}, proxy_{proxy} {
if (!cache_info_->written) { if (!cache_info_->written) {
iter_.Reset(); iter_.Reset();
CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch."; CHECK(iter_.Next()) << "Must have at least 1 batch.";
} }
this->Fetch(); this->Fetch();
} }
@ -259,6 +287,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
} }
if (at_end_) { if (at_end_) {
CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1);
cache_info_->Commit(); cache_info_->Commit();
if (n_batches_ != 0) { if (n_batches_ != 0) {
CHECK_EQ(count_, n_batches_); CHECK_EQ(count_, n_batches_);
@ -371,6 +400,5 @@ class SortedCSCPageSource : public PageSourceIncMixIn<SortedCSCPage> {
this->Fetch(); this->Fetch();
} }
}; };
} // namespace data } // namespace xgboost::data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_

View File

@ -146,27 +146,30 @@ class PoissonSampling : public thrust::binary_function<GradientPair, size_t, Gra
CombineGradientPair combine_; CombineGradientPair combine_;
}; };
NoSampling::NoSampling(EllpackPageImpl const* page) : page_(page) {} NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_param)) {}
GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair, GradientBasedSample NoSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) { DMatrix* dmat) {
return {dmat->Info().num_row_, page_, gpair}; auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
return {dmat->Info().num_row_, page, gpair};
} }
ExternalMemoryNoSampling::ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
size_t n_rows, BatchParam batch_param) : batch_param_{std::move(batch_param)} {}
: batch_param_{std::move(batch_param)},
page_(new EllpackPageImpl(ctx->gpu_id, page->Cuts(), page->is_dense, page->row_stride,
n_rows)) {}
GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx, GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
common::Span<GradientPair> gpair, common::Span<GradientPair> gpair,
DMatrix* dmat) { DMatrix* dmat) {
if (!page_concatenated_) { if (!page_concatenated_) {
// Concatenate all the external memory ELLPACK pages into a single in-memory page. // Concatenate all the external memory ELLPACK pages into a single in-memory page.
page_.reset(nullptr);
size_t offset = 0; size_t offset = 0;
for (auto& batch : dmat->GetBatches<EllpackPage>(ctx, batch_param_)) { for (auto& batch : dmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
auto page = batch.Impl(); auto page = batch.Impl();
if (!page_) {
page_ = std::make_unique<EllpackPageImpl>(ctx->gpu_id, page->Cuts(), page->is_dense,
page->row_stride, dmat->Info().num_row_);
}
size_t num_elements = page_->Copy(ctx->gpu_id, page, offset); size_t num_elements = page_->Copy(ctx->gpu_id, page, offset);
offset += num_elements; offset += num_elements;
} }
@ -175,8 +178,8 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
return {dmat->Info().num_row_, page_.get(), gpair}; return {dmat->Info().num_row_, page_.get(), gpair};
} }
UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample) UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
: page_(page), subsample_(subsample) {} : batch_param_{std::move(batch_param)}, subsample_(subsample) {}
GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair, GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) { DMatrix* dmat) {
@ -185,7 +188,8 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<Gra
thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<std::size_t>(0), thrust::counting_iterator<std::size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair()); BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
return {dmat->Info().num_row_, page_, gpair}; auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
return {dmat->Info().num_row_, page, gpair};
} }
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows, ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
@ -236,12 +240,10 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx,
return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
} }
GradientBasedSampling::GradientBasedSampling(EllpackPageImpl const* page, GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param,
size_t n_rows,
const BatchParam&,
float subsample) float subsample)
: page_(page), : subsample_(subsample),
subsample_(subsample), batch_param_{std::move(batch_param)},
threshold_(n_rows + 1, 0.0f), threshold_(n_rows + 1, 0.0f),
grad_sum_(n_rows, 0.0f) {} grad_sum_(n_rows, 0.0f) {}
@ -252,18 +254,19 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
// Perform Poisson sampling in place. // Perform Poisson sampling in place.
thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair), thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
PoissonSampling(dh::ToSpan(threshold_), threshold_index, PoissonSampling(dh::ToSpan(threshold_), threshold_index,
RandomWeight(common::GlobalRandom()()))); RandomWeight(common::GlobalRandom()())));
return {n_rows, page_, gpair}; return {n_rows, page, gpair};
} }
ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling( ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
size_t n_rows, BatchParam batch_param,
BatchParam batch_param, float subsample)
float subsample)
: batch_param_(std::move(batch_param)), : batch_param_(std::move(batch_param)),
subsample_(subsample), subsample_(subsample),
threshold_(n_rows + 1, 0.0f), threshold_(n_rows + 1, 0.0f),
@ -273,16 +276,15 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* ctx, GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* ctx,
common::Span<GradientPair> gpair, common::Span<GradientPair> gpair,
DMatrix* dmat) { DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_; auto cuctx = ctx->CUDACtx();
bst_row_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);
// Perform Poisson sampling in place. // Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0), thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
dh::tbegin(gpair), PoissonSampling(dh::ToSpan(threshold_), threshold_index,
PoissonSampling(dh::ToSpan(threshold_),
threshold_index,
RandomWeight(common::GlobalRandom()()))); RandomWeight(common::GlobalRandom()())));
// Count the sampled rows. // Count the sampled rows.
@ -290,16 +292,15 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
// Compact gradient pairs. // Compact gradient pairs.
gpair_.resize(sample_rows); gpair_.resize(sample_rows);
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());
// Index the sample rows. // Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero()); thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(), IsNonZero());
sample_row_index_.begin()); thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(),
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin());
sample_row_index_.begin(), thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(),
sample_row_index_.begin(), sample_row_index_.begin(), ClearEmptyRows());
ClearEmptyRows());
auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_); auto batch_iterator = dmat->GetBatches<EllpackPage>(ctx, batch_param_);
auto first_page = (*batch_iterator.begin()).Impl(); auto first_page = (*batch_iterator.begin()).Impl();
@ -317,13 +318,13 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c
return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
} }
GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
size_t n_rows, const BatchParam& batch_param, const BatchParam& batch_param, float subsample,
float subsample, int sampling_method) { int sampling_method, bool is_external_memory) {
// The ctx is kept here for future development of stream-based operations.
monitor_.Init("gradient_based_sampler"); monitor_.Init("gradient_based_sampler");
bool is_sampling = subsample < 1.0; bool is_sampling = subsample < 1.0;
bool is_external_memory = page->n_rows != n_rows;
if (is_sampling) { if (is_sampling) {
switch (sampling_method) { switch (sampling_method) {
@ -331,24 +332,24 @@ GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl c
if (is_external_memory) { if (is_external_memory) {
strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample)); strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
} else { } else {
strategy_.reset(new UniformSampling(page, subsample)); strategy_.reset(new UniformSampling(batch_param, subsample));
} }
break; break;
case TrainParam::kGradientBased: case TrainParam::kGradientBased:
if (is_external_memory) { if (is_external_memory) {
strategy_.reset( strategy_.reset(new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
} else { } else {
strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample)); strategy_.reset(new GradientBasedSampling(n_rows, batch_param, subsample));
} }
break; break;
default:LOG(FATAL) << "unknown sampling method"; default:
LOG(FATAL) << "unknown sampling method";
} }
} else { } else {
if (is_external_memory) { if (is_external_memory) {
strategy_.reset(new ExternalMemoryNoSampling(ctx, page, n_rows, batch_param)); strategy_.reset(new ExternalMemoryNoSampling(batch_param));
} else { } else {
strategy_.reset(new NoSampling(page)); strategy_.reset(new NoSampling(batch_param));
} }
} }
} }
@ -362,11 +363,11 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
return sample; return sample;
} }
size_t GradientBasedSampler::CalculateThresholdIndex( size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<GradientPair> gpair, common::Span<float> threshold, common::Span<float> threshold,
common::Span<float> grad_sum, size_t sample_rows) { common::Span<float> grad_sum,
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), size_t sample_rows) {
std::numeric_limits<float>::max()); thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold), thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
CombineGradientPair()); CombineGradientPair());
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1); thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
@ -379,6 +380,5 @@ size_t GradientBasedSampler::CalculateThresholdIndex(
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum)); thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
return thrust::distance(dh::tbegin(grad_sum), min) + 1; return thrust::distance(dh::tbegin(grad_sum), min) + 1;
} }
}; // namespace tree }; // namespace tree
}; // namespace xgboost }; // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /**
* Copyright 2019 by XGBoost Contributors * Copyright 2019-2023, XGBoost Contributors
*/ */
#pragma once #pragma once
#include <xgboost/base.h> #include <xgboost/base.h>
@ -32,37 +32,36 @@ class SamplingStrategy {
/*! \brief No sampling in in-memory mode. */ /*! \brief No sampling in in-memory mode. */
class NoSampling : public SamplingStrategy { class NoSampling : public SamplingStrategy {
public: public:
explicit NoSampling(EllpackPageImpl const* page); explicit NoSampling(BatchParam batch_param);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
EllpackPageImpl const* page_;
};
/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
public:
ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, size_t n_rows,
BatchParam batch_param);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override; DMatrix* dmat) override;
private: private:
BatchParam batch_param_; BatchParam batch_param_;
std::unique_ptr<EllpackPageImpl> page_; };
/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
public:
explicit ExternalMemoryNoSampling(BatchParam batch_param);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;
private:
BatchParam batch_param_;
std::unique_ptr<EllpackPageImpl> page_{nullptr};
bool page_concatenated_{false}; bool page_concatenated_{false};
}; };
/*! \brief Uniform sampling in in-memory mode. */ /*! \brief Uniform sampling in in-memory mode. */
class UniformSampling : public SamplingStrategy { class UniformSampling : public SamplingStrategy {
public: public:
UniformSampling(EllpackPageImpl const* page, float subsample); UniformSampling(BatchParam batch_param, float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override; DMatrix* dmat) override;
private: private:
EllpackPageImpl const* page_; BatchParam batch_param_;
float subsample_; float subsample_;
}; };
@ -84,13 +83,12 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
/*! \brief Gradient-based sampling in in-memory mode.. */ /*! \brief Gradient-based sampling in in-memory mode.. */
class GradientBasedSampling : public SamplingStrategy { class GradientBasedSampling : public SamplingStrategy {
public: public:
GradientBasedSampling(EllpackPageImpl const* page, size_t n_rows, const BatchParam& batch_param, GradientBasedSampling(std::size_t n_rows, BatchParam batch_param, float subsample);
float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override; DMatrix* dmat) override;
private: private:
EllpackPageImpl const* page_; BatchParam batch_param_;
float subsample_; float subsample_;
dh::caching_device_vector<float> threshold_; dh::caching_device_vector<float> threshold_;
dh::caching_device_vector<float> grad_sum_; dh::caching_device_vector<float> grad_sum_;
@ -106,11 +104,11 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
private: private:
BatchParam batch_param_; BatchParam batch_param_;
float subsample_; float subsample_;
dh::caching_device_vector<float> threshold_; dh::device_vector<float> threshold_;
dh::caching_device_vector<float> grad_sum_; dh::device_vector<float> grad_sum_;
std::unique_ptr<EllpackPageImpl> page_; std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_; dh::device_vector<GradientPair> gpair_;
dh::caching_device_vector<size_t> sample_row_index_; dh::device_vector<size_t> sample_row_index_;
}; };
/*! \brief Draw a sample of rows from a DMatrix. /*! \brief Draw a sample of rows from a DMatrix.
@ -124,8 +122,8 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
*/ */
class GradientBasedSampler { class GradientBasedSampler {
public: public:
GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, size_t n_rows, GradientBasedSampler(Context const* ctx, size_t n_rows, const BatchParam& batch_param,
const BatchParam& batch_param, float subsample, int sampling_method); float subsample, int sampling_method, bool is_external_memory);
/*! \brief Sample from a DMatrix based on the given gradient pairs. */ /*! \brief Sample from a DMatrix based on the given gradient pairs. */
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat); GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);

View File

@ -176,7 +176,7 @@ struct GPUHistMakerDevice {
Context const* ctx_; Context const* ctx_;
public: public:
EllpackPageImpl const* page; EllpackPageImpl const* page{nullptr};
common::Span<FeatureType const> feature_types; common::Span<FeatureType const> feature_types;
BatchParam batch_param; BatchParam batch_param;
@ -205,41 +205,41 @@ struct GPUHistMakerDevice {
std::unique_ptr<FeatureGroups> feature_groups; std::unique_ptr<FeatureGroups> feature_groups;
GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
BatchParam _batch_param) BatchParam _batch_param)
: evaluator_{_param, n_features, ctx->gpu_id}, : evaluator_{_param, n_features, ctx->gpu_id},
ctx_(ctx), ctx_(ctx),
page(_page),
feature_types{_feature_types}, feature_types{_feature_types},
param(std::move(_param)), param(std::move(_param)),
column_sampler(column_sampler_seed), column_sampler(column_sampler_seed),
interaction_constraints(param, n_features), interaction_constraints(param, n_features),
batch_param(std::move(_batch_param)) { batch_param(std::move(_batch_param)) {
sampler.reset(new GradientBasedSampler(ctx, page, _n_rows, batch_param, param.subsample, sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample,
param.sampling_method)); param.sampling_method, is_external_memory));
if (!param.monotone_constraints.empty()) { if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds // Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints; monotone_constraints = param.monotone_constraints;
} }
// Init histogram
hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
dh::MaxSharedMemoryOptin(ctx_->gpu_id),
sizeof(GradientSumT)));
} }
~GPUHistMakerDevice() { // NOLINT ~GPUHistMakerDevice() { // NOLINT
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
} }
void InitFeatureGroupsOnce() {
if (!feature_groups) {
CHECK(page);
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
dh::MaxSharedMemoryOptin(ctx_->gpu_id),
sizeof(GradientSumT)));
}
}
// Reset values for each update iteration // Reset values for each update iteration
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) { void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
auto const& info = dmat->Info(); auto const& info = dmat->Info();
this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(), this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(),
@ -247,26 +247,30 @@ struct GPUHistMakerDevice {
param.colsample_bytree); param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
ctx_->gpu_id);
this->interaction_constraints.Reset(); this->interaction_constraints.Reset();
if (d_gpair.size() != dh_gpair->Size()) { if (d_gpair.size() != dh_gpair->Size()) {
d_gpair.resize(dh_gpair->Size()); d_gpair.resize(dh_gpair->Size());
} }
dh::safe_cuda(cudaMemcpyAsync( dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(),
d_gpair.data().get(), dh_gpair->ConstDevicePointer(), dh_gpair->Size() * sizeof(GradientPair),
dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); cudaMemcpyDeviceToDevice));
auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat); auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat);
page = sample.page; page = sample.page;
gpair = sample.gpair; gpair = sample.gpair;
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
quantiser.reset(new GradientQuantiser(this->gpair)); quantiser.reset(new GradientQuantiser(this->gpair));
row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows)); row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows));
// Init histogram
hist.Init(ctx_->gpu_id, page->Cuts().TotalBins());
hist.Reset(); hist.Reset();
this->InitFeatureGroupsOnce();
} }
GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) { GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) {
@ -808,12 +812,11 @@ class GPUHistMaker : public TreeUpdater {
collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0);
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
auto page = (*dmat->GetBatches<EllpackPage>(ctx_, batch_param).begin()).Impl();
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
info_->feature_types.SetDevice(ctx_->gpu_id); info_->feature_types.SetDevice(ctx_->gpu_id);
maker.reset(new GPUHistMakerDevice<GradientSumT>( maker.reset(new GPUHistMakerDevice<GradientSumT>(
ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, *param, ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
column_sampling_seed, info_->num_col_, batch_param)); *param, column_sampling_seed, info_->num_col_, batch_param));
p_last_fmat_ = dmat; p_last_fmat_ = dmat;
initialised_ = true; initialised_ = true;

View File

@@ -1,5 +1,5 @@
-/*!
- * Copyright (c) by XGBoost Contributors 2019
+/**
+ * Copyright 2019-2023, XGBoost Contributors
 */
#include <gtest/gtest.h>
@@ -9,8 +9,7 @@
#include "../helpers.h"
#include "../filesystem.h"  // dmlc::TemporaryDirectory
-namespace xgboost {
-namespace common {
+namespace xgboost::common {
TEST(MemoryFixSizeBuffer, Seek) {
  size_t constexpr kSize { 64 };
  std::vector<int32_t> memory( kSize );
@@ -89,5 +88,54 @@ TEST(IO, LoadSequentialFile) {
  ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
}
-}  // namespace common
-}  // namespace xgboost
+TEST(IO, PrivateMmapStream) {
+  dmlc::TemporaryDirectory tempdir;
+  auto path = tempdir.path + "/testfile";
+  // The page size on Linux is usually 4096 bytes, while the allocation granularity on
+  // the Windows machine where this test was written is 65536 bytes. The batch sizes are
+  // chosen to span both.
+  std::size_t n_batches{64};
+  std::size_t multiplier{2048};
+  std::vector<std::vector<std::int32_t>> batches;
+  std::vector<std::size_t> offset{0ul};
+  using T = std::int32_t;
+  {
+    std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
+    for (std::size_t i = 0; i < n_batches; ++i) {
+      std::size_t size = (i + 1) * multiplier;
+      std::vector<T> data(size, 0);
+      std::iota(data.begin(), data.end(), i * i);
+      fo->Write(static_cast<std::uint64_t>(data.size()));
+      fo->Write(data.data(), data.size() * sizeof(T));
+      std::size_t bytes = sizeof(std::uint64_t) + data.size() * sizeof(T);
+      offset.push_back(bytes);
+      batches.emplace_back(std::move(data));
+    }
+  }
+  // Turn the per-batch sizes into cumulative byte offsets.
+  std::partial_sum(offset.begin(), offset.end(), offset.begin());
+  for (std::size_t i = 0; i < n_batches; ++i) {
+    std::size_t off = offset[i];
+    std::size_t n = offset.at(i + 1) - offset[i];
+    std::unique_ptr<dmlc::Stream> fi{std::make_unique<PrivateMmapConstStream>(path, off, n)};
+    std::vector<T> data;
+    std::uint64_t size{0};
+    fi->Read(&size);
+    data.resize(size);
+    fi->Read(data.data(), size * sizeof(T));
+    ASSERT_EQ(data, batches[i]);
+  }
+}
+}  // namespace xgboost::common
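The alignment concern mentioned in the test comment is what makes an mmap-backed stream slightly subtle: a mapping offset must be a multiple of the page size (or of the allocation granularity on Windows). As an illustration of that POSIX pattern only, and not the implementation added in this commit, a hypothetical ReadRange helper could look like the following sketch:

// Illustrative sketch only: map the byte range [offset, offset + length) of a
// file privately and copy it out. The mapping must start on a page-aligned
// offset, so we map from the aligned address below `offset` and skip the padding.
#include <fcntl.h>     // open
#include <sys/mman.h>  // mmap, munmap
#include <unistd.h>    // close, sysconf

#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <vector>

std::vector<char> ReadRange(char const* path, std::size_t offset, std::size_t length) {
  int fd = open(path, O_RDONLY);
  if (fd < 0) { throw std::runtime_error("open failed"); }
  auto page = static_cast<std::size_t>(sysconf(_SC_PAGESIZE));
  std::size_t aligned = offset / page * page;  // round down to the page boundary
  std::size_t padding = offset - aligned;
  void* ptr = mmap(nullptr, length + padding, PROT_READ, MAP_PRIVATE, fd,
                   static_cast<off_t>(aligned));
  close(fd);  // the mapping stays valid after the descriptor is closed
  if (ptr == MAP_FAILED) { throw std::runtime_error("mmap failed"); }
  std::vector<char> out(length);
  std::memcpy(out.data(), static_cast<char*>(ptr) + padding, length);
  munmap(ptr, length + padding);
  return out;
}

A real mmap-backed stream would keep the mapping alive for the lifetime of the reader instead of copying the bytes out, which is the point of using mmap for external memory in the first place.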
View File
@@ -2,6 +2,10 @@
#include "../../src/data/ellpack_page.cuh"
#endif
+#include <xgboost/data.h>  // for SparsePage
+#include "./helpers.h"     // for RandomDataGenerator
namespace xgboost {
#if defined(__CUDACC__)
namespace {
View File
@@ -39,7 +39,8 @@ void VerifySampling(size_t page_size,
    EXPECT_NE(page->n_rows, kRows);
  }
- GradientBasedSampler sampler(&ctx, page, kRows, param, subsample, sampling_method);
+ GradientBasedSampler sampler(&ctx, kRows, param, subsample, sampling_method,
+                              !fixed_size_sampling);
  auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
  if (fixed_size_sampling) {
@@ -93,7 +94,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
  auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
  EXPECT_NE(page->n_rows, kRows);
- GradientBasedSampler sampler(&ctx, page, kRows, param, kSubsample, TrainParam::kUniform);
+ GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
  auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
  auto sampled_page = sample.page;
  EXPECT_EQ(sample.sample_rows, kRows);
@@ -141,7 +142,8 @@ TEST(GradientBasedSampler, GradientBasedSampling) {
  constexpr size_t kPageSize = 0;
  constexpr float kSubsample = 0.8;
  constexpr int kSamplingMethod = TrainParam::kGradientBased;
- VerifySampling(kPageSize, kSubsample, kSamplingMethod);
+ constexpr bool kFixedSizeSampling = true;
+ VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
}

TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) {
View File
@@ -92,8 +92,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  auto page = BuildEllpackPage(kNRows, kNCols);
  BatchParam batch_param{};
  Context ctx{MakeCUDACtx(0)};
- GPUHistMakerDevice<GradientSumT> maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols,
-                                        batch_param);
+ GPUHistMakerDevice<GradientSumT> maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param,
+                                        kNCols, kNCols, batch_param);
  xgboost::SimpleLCG gen;
  xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
  HostDeviceVector<GradientPair> gpair(kNRows);
@@ -106,9 +106,15 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
  maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
+ maker.hist.Init(0, page->Cuts().TotalBins());
  maker.hist.AllocateHistograms({0});
  maker.gpair = gpair.DeviceSpan();
  maker.quantiser.reset(new GradientQuantiser(maker.gpair));
+ maker.page = page.get();
+ maker.InitFeatureGroupsOnce();
  BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
                         maker.feature_groups->DeviceAccessor(0), gpair.DeviceSpan(),
@@ -126,8 +132,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
  std::vector<GradientPairPrecise> solution = GetHostHistGpair();
  for (size_t i = 0; i < h_result.size(); ++i) {
    auto result = maker.quantiser->ToFloatingPoint(h_result[i]);
-   EXPECT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
-   EXPECT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
+   ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f);
+   ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f);
  }
}
View File
@@ -305,7 +305,7 @@ class IterForDMatrixTest(xgb.core.DataIter):
        self._labels = [rng.randn(self.rows)] * self.BATCHES
        self.it = 0  # set iterator to 0
-       super().__init__()
+       super().__init__(cache_prefix=None)

    def as_array(self):
        import cudf
View File
@@ -64,7 +64,8 @@ def run_data_iterator(
    subsample_rate = 0.8 if subsample else 1.0
    it = IteratorForTest(
-       *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy)
+       *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
+       cache="cache"
    )
    if n_batches == 0:
        with pytest.raises(ValueError, match="1 batch"):