diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py
index cc5527611..11a05c61c 100644
--- a/demo/guide-python/external_memory.py
+++ b/demo/guide-python/external_memory.py
@@ -82,10 +82,10 @@ def main(tmpdir: str) -> xgboost.Booster:
     missing = np.NaN
     Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
-    # Other tree methods including ``hist`` and ``gpu_hist`` also work, see tutorial in
-    # doc for details.
+    # Other tree methods including ``approx`` and ``gpu_hist`` are also supported. The GPU
+    # version behaves differently from the CPU tree methods. See tutorial in doc for details.
     booster = xgboost.train(
-        {"tree_method": "approx", "max_depth": 2},
+        {"tree_method": "hist", "max_depth": 4},
         Xy,
         evals=[(Xy, "Train")],
         num_boost_round=10,
diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst
index 006d63b43..f5b6132c7 100644
--- a/doc/tutorials/external_memory.rst
+++ b/doc/tutorials/external_memory.rst
@@ -2,11 +2,25 @@
 Using XGBoost External Memory Version
 #####################################
 
-XGBoost supports loading data from external memory using builtin data parser. And
-starting from version 1.5, users can also define a custom iterator to load data in chunks.
-The feature is still experimental and not yet ready for production use. In this tutorial
-we will introduce both methods. Please note that training on data from external memory is
-not supported by ``exact`` tree method.
+When working with large datasets, training XGBoost models can be challenging as the entire
+dataset needs to be loaded into memory. This can be costly and sometimes
+infeasible. Starting from 1.5, users can define a custom iterator to load data in chunks
+for running XGBoost algorithms. External memory can be used for both training and
+prediction, but training is the primary use case and it will be our focus in this
+tutorial. For prediction and evaluation, users can iterate through the data themselves,
+while training requires the full dataset to be loaded into memory.
+
+During training, there are two different modes for external memory support available in
+XGBoost, one for CPU-based algorithms like ``hist`` and ``approx``, and another for the
+GPU-based training algorithm. We will introduce them in the following sections.
+
+.. note::
+
+   Training on data from external memory is not supported by the ``exact`` tree method.
+
+.. note::
+
+   The feature is still experimental as of 2.0. The performance is not well optimized.
 
 *************
 Data Iterator
 *************
@@ -15,8 +29,8 @@ Data Iterator
 Starting from XGBoost 1.5, users can define their own data loader using Python or C
 interface. There are some examples in the ``demo`` directory for quick start. This is a
 generalized version of text input external memory, where users no longer need to prepare a
-text file that XGBoost recognizes. To enable the feature, user need to define a data
-iterator with 2 class methods ``next`` and ``reset`` then pass it into ``DMatrix``
+text file that XGBoost recognizes. To enable the feature, users need to define a data
+iterator with 2 class methods: ``next`` and ``reset``, then pass it into the ``DMatrix``
 constructor.
 
 .. code-block:: python
@@ -60,20 +74,96 @@ constructor.
     # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some caveats
     # as noted in following sections.
-    booster = xgboost.train({"tree_method": "approx"}, Xy)
+    booster = xgboost.train({"tree_method": "hist"}, Xy)
 
-The above snippet is a simplified version of ``demo/guide-python/external_memory.py``. For
-an example in C, please see ``demo/c-api/external-memory/``.
+The above snippet is a simplified version of :ref:`sphx_glr_python_examples_external_memory.py`.
+For an example in C, please see ``demo/c-api/external-memory/``. The iterator is the
+common interface for using external memory with XGBoost; you can pass the resulting
+``DMatrix`` object to training, prediction, and evaluation.
+
+It is important to set the batch size based on the memory available. A good starting point
+is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not*
+recommended to set small batch sizes like 32 samples per batch, as this can seriously hurt
+performance in gradient boosting.
+
+***********
+CPU Version
+***********
+
+In the previous section, we demonstrated how to train a tree-based model using the
+``hist`` tree method on a CPU. This method involves iterating through data batches stored
+in a cache during tree construction. For optimal performance, we recommend using the
+``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer of tree
+nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy
+requires XGBoost to iterate over the data set for each tree node, resulting in slower
+performance.
+
+If external memory is used, the performance of CPU training is limited by IO
+(input/output) speed. This means that the disk IO speed primarily determines the training
+speed. During benchmarking, we used an NVMe connected to a PCIe-4 slot; other types of
+storage can be too slow for practical usage. In addition, your system may perform caching
+to reduce the overhead of file reading.
+
+**********************************
+GPU Version (GPU Hist tree method)
+**********************************
+
+External memory is supported by GPU algorithms (i.e. when ``tree_method`` is set to
+``gpu_hist``). However, the algorithm used for GPU is different from the one used for
+CPU. When training on a CPU, the tree method iterates through all batches from external
+memory for each step of the tree construction algorithm. On the other hand, the GPU
+algorithm concatenates all batches into one and stores it in GPU memory. To reduce overall
+memory usage, users can utilize subsampling. The good news is that the GPU hist tree
+method supports gradient-based sampling, enabling users to set a low sampling rate without
+compromising accuracy.
+
+.. code-block:: python
+
+  param = {
+    ...
+    'subsample': 0.2,
+    'sampling_method': 'gradient_based',
+  }
+
+For more information about the sampling algorithm and its use in external memory training,
+see `this paper `_.
+
+.. warning::
+
+  When the GPU runs out of memory during iteration on external memory, users might
+  receive a segfault instead of an OOM exception.
+
+*******
+Remarks
+*******
+
+When using external memory with XGBoost, data is divided into smaller chunks so that only
+a fraction of it needs to be stored in memory at any given time. It's important to note
+that this method only applies to the predictor data (``X``), while other data, like labels
+and internal runtime structures, are concatenated. This means that memory reduction is most
+effective when dealing with wide datasets where ``X`` is large compared to other data
+like ``y``, while it has little impact on slim datasets.
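+As a rough, purely hypothetical illustration (the shapes below are made up and are not
+benchmark figures), the saving comes almost entirely from ``X``:
+
+.. code-block:: python
+
+    # Hypothetical dense float32 dataset streamed in 20 batches.
+    n_samples, n_features, n_batches = 10_000_000, 500, 20
+
+    x_total_gb = n_samples * n_features * 4 / 1e9  # X if materialized at once: ~20 GB
+    x_batch_gb = x_total_gb / n_batches             # X resident per batch:      ~1 GB
+    y_total_gb = n_samples * 4 / 1e9                # labels stay resident:      ~0.04 GB
+    print(x_total_gb, x_batch_gb, y_total_gb)
+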
+ +Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not +yet tested against system errors like disconnected network devices (`SIGBUS`). Also, it's +worth noting that most tests have been conducted on Linux distributions. + +Another important point to keep in mind is that creating the initial cache for XGBoost may +take some time. The interface to external memory is through custom iterators, which may or +may not be thread-safe. Therefore, initialization is performed sequentially. + **************** Text File Inputs **************** -There is no big difference between using external memory version and in-memory version. -The only difference is the filename format. +This is the original form of external memory support, users are encouraged to use custom +data iterator instead. There is no big difference between using external memory version of +text input and the in-memory version. The only difference is the filename format. -The external memory version takes in the following `URI `_ format: +The external memory version takes in the following `URI +`_ format: .. code-block:: none @@ -91,9 +181,8 @@ To load from csv files, use the following syntax: where ``label_column`` should point to the csv column acting as the label. -To provide a simple example for illustration, extracting the code from -`demo/guide-python/external_memory.py `_. If -you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSVM format, the external memory support can be enabled by: +If you have a dataset stored in a file similar to ``demo/data/agaricus.txt.train`` with LIBSVM +format, the external memory support can be enabled by: .. code-block:: python @@ -104,35 +193,3 @@ XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to more notes about text input formats, see :doc:`/tutorials/input_format`. For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``. - - -********************************** -GPU Version (GPU Hist tree method) -********************************** -External memory is supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``). - -If you are still getting out-of-memory errors after enabling external memory, try subsampling the -data to further reduce GPU memory usage: - -.. code-block:: python - - param = { - ... - 'subsample': 0.1, - 'sampling_method': 'gradient_based', - } - -For more information, see `this paper `_. Internally -the tree method still concatenate all the chunks into 1 final histogram index due to -performance reason, but in compressed format. So its scalability has an upper bound but -still has lower memory cost in general. - -*********** -CPU Version -*********** - -For CPU histogram based tree methods (``approx``, ``hist``) it's recommended to use -``grow_policy=depthwise`` for performance reason. Iterating over data batches is slow, -with ``depthwise`` policy XGBoost can build a entire layer of tree nodes with a few -iterations, while with ``lossguide`` XGBoost needs to iterate over the data set for each -tree node. 
diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 70e536101..f6abb867e 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -198,14 +198,14 @@ class IteratorForTest(xgb.core.DataIter): X: Sequence, y: Sequence, w: Optional[Sequence], - cache: Optional[str] = "./", + cache: Optional[str], ) -> None: assert len(X) == len(y) self.X = X self.y = y self.w = w self.it = 0 - super().__init__(cache) + super().__init__(cache_prefix=cache) def next(self, input_data: Callable) -> int: if self.it == len(self.X): @@ -347,7 +347,9 @@ class TestDataset: if w is not None: weight.append(w) - it = IteratorForTest(predictor, response, weight if weight else None) + it = IteratorForTest( + predictor, response, weight if weight else None, cache="cache" + ) return xgb.DMatrix(it) def __repr__(self) -> str: diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h index 978eebd8a..a12e1decd 100644 --- a/rabit/include/rabit/internal/io.h +++ b/rabit/include/rabit/internal/io.h @@ -1,18 +1,21 @@ -/*! - * Copyright (c) 2014-2019 by Contributors +/** + * Copyright 2014-2023, XGBoost Contributors * \file io.h * \brief utilities with different serializable implementations * \author Tianqi Chen */ #ifndef RABIT_INTERNAL_IO_H_ #define RABIT_INTERNAL_IO_H_ -#include -#include -#include -#include + #include -#include +#include // for size_t +#include +#include // for memcpy #include +#include +#include +#include + #include "rabit/internal/utils.h" #include "rabit/serializable.h" @@ -20,54 +23,61 @@ namespace rabit { namespace utils { /*! \brief re-use definition of dmlc::SeekStream */ using SeekStream = dmlc::SeekStream; -/*! \brief fixed size memory buffer */ +/** + * @brief Fixed size memory buffer as a stream. + */ struct MemoryFixSizeBuffer : public SeekStream { public: // similar to SEEK_END in libc - static size_t constexpr kSeekEnd = std::numeric_limits::max(); + static std::size_t constexpr kSeekEnd = std::numeric_limits::max(); + + protected: + MemoryFixSizeBuffer() = default; public: - MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) - : p_buffer_(reinterpret_cast(p_buffer)), - buffer_size_(buffer_size) { - curr_ptr_ = 0; - } + /** + * @brief Ctor + * + * @param p_buffer Pointer to the source buffer with size `buffer_size`. 
+ * @param buffer_size Size of the source buffer + */ + MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size) + : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) {} ~MemoryFixSizeBuffer() override = default; - size_t Read(void *ptr, size_t size) override { - size_t nread = std::min(buffer_size_ - curr_ptr_, size); + + std::size_t Read(void *ptr, std::size_t size) override { + std::size_t nread = std::min(buffer_size_ - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); curr_ptr_ += nread; return nread; } - void Write(const void *ptr, size_t size) override { + void Write(const void *ptr, std::size_t size) override { if (size == 0) return; - utils::Assert(curr_ptr_ + size <= buffer_size_, - "write position exceed fixed buffer size"); + CHECK_LE(curr_ptr_ + size, buffer_size_); std::memcpy(p_buffer_ + curr_ptr_, ptr, size); curr_ptr_ += size; } - void Seek(size_t pos) override { + void Seek(std::size_t pos) override { if (pos == kSeekEnd) { curr_ptr_ = buffer_size_; } else { - curr_ptr_ = static_cast(pos); + curr_ptr_ = static_cast(pos); } } - size_t Tell() override { - return curr_ptr_; - } - virtual bool AtEnd() const { - return curr_ptr_ == buffer_size_; - } + /** + * @brief Current position in the buffer (stream). + */ + std::size_t Tell() override { return curr_ptr_; } + virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; } - private: + protected: /*! \brief in memory buffer */ - char *p_buffer_; + char *p_buffer_{nullptr}; /*! \brief current pointer */ - size_t buffer_size_; + std::size_t buffer_size_{0}; /*! \brief current pointer */ - size_t curr_ptr_; -}; // class MemoryFixSizeBuffer + std::size_t curr_ptr_{0}; +}; /*! \brief a in memory buffer that can be read and write as stream interface */ struct MemoryBufferStream : public SeekStream { diff --git a/src/common/io.cc b/src/common/io.cc index da3a75d65..ba97db574 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -1,24 +1,47 @@ -/*! 
- * Copyright (c) by XGBoost Contributors 2019-2022 +/** + * Copyright 2019-2023, by XGBoost Contributors */ -#if defined(__unix__) -#include -#include -#include +#if !defined(NOMINMAX) && defined(_WIN32) +#define NOMINMAX +#endif // !defined(NOMINMAX) + +#if !defined(xgboost_IS_WIN) + +#if defined(_MSC_VER) || defined(__MINGW32__) +#define xgboost_IS_WIN 1 +#endif // defined(_MSC_VER) || defined(__MINGW32__) + +#endif // !defined(xgboost_IS_WIN) + +#if defined(__unix__) || defined(__APPLE__) +#include // for open, O_RDONLY +#include // for mmap, mmap64, munmap +#include // for close, getpagesize +#elif defined(xgboost_IS_WIN) +#define WIN32_LEAN_AND_MEAN +#include #endif // defined(__unix__) -#include -#include -#include -#include -#include -#include -#include "xgboost/logging.h" +#include // for copy, transform +#include // for tolower +#include // for errno +#include // for size_t +#include // for int32_t, uint32_t +#include // for memcpy +#include // for ifstream +#include // for distance +#include // for numeric_limits +#include // for unique_ptr +#include // for string +#include // for error_code, system_category +#include // for move +#include // for vector + #include "io.h" +#include "xgboost/collective/socket.h" // for LastError +#include "xgboost/logging.h" -namespace xgboost { -namespace common { - +namespace xgboost::common { size_t PeekableInStream::Read(void* dptr, size_t size) { size_t nbuffer = buffer_.length() - buffer_ptr_; if (nbuffer == 0) return strm_->Read(dptr, size); @@ -94,11 +117,32 @@ void FixedSizeStream::Take(std::string* out) { *out = std::move(buffer_); } +namespace { +// Get system alignment value for IO with mmap. +std::size_t GetMmapAlignment() { +#if defined(xgboost_IS_WIN) + SYSTEM_INFO sys_info; + GetSystemInfo(&sys_info); + // During testing, `sys_info.dwPageSize` is of size 4096 while `dwAllocationGranularity` is of + // size 65536. + return sys_info.dwAllocationGranularity; +#else + return getpagesize(); +#endif +} + +auto SystemErrorMsg() { + std::int32_t errsv = system::LastError(); + auto err = std::error_code{errsv, std::system_category()}; + return err.message(); +} +} // anonymous namespace + std::string LoadSequentialFile(std::string uri, bool stream) { auto OpenErr = [&uri]() { std::string msg; msg = "Opening " + uri + " failed: "; - msg += strerror(errno); + msg += SystemErrorMsg(); LOG(FATAL) << msg; }; @@ -155,5 +199,99 @@ std::string FileExtension(std::string fname, bool lower) { return ""; } } -} // namespace common -} // namespace xgboost + +struct PrivateMmapConstStream::MMAPFile { +#if defined(xgboost_IS_WIN) + HANDLE fd{INVALID_HANDLE_VALUE}; + HANDLE file_map{INVALID_HANDLE_VALUE}; +#else + std::int32_t fd{0}; +#endif + char* base_ptr{nullptr}; + std::size_t base_size{0}; + std::string path; +}; + +char* PrivateMmapConstStream::Open(std::string path, std::size_t offset, std::size_t length) { + if (length == 0) { + return nullptr; + } + +#if defined(xgboost_IS_WIN) + HANDLE fd = CreateFile(path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, nullptr); + CHECK_NE(fd, INVALID_HANDLE_VALUE) << "Failed to open:" << path << ". " << SystemErrorMsg(); +#else + auto fd = open(path.c_str(), O_RDONLY); + CHECK_GE(fd, 0) << "Failed to open:" << path << ". " << SystemErrorMsg(); +#endif + + char* ptr{nullptr}; + // Round down for alignment. 
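+  // Both POSIX `mmap` and Windows `MapViewOfFile` require the mapping offset to be a multiple
+  // of the page size / allocation granularity, so the requested offset is rounded down to the
+  // nearest boundary here and the extra leading bytes are skipped with pointer arithmetic
+  // after the view is created.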
+ auto view_start = offset / GetMmapAlignment() * GetMmapAlignment(); + auto view_size = length + (offset - view_start); + +#if defined(__linux__) || defined(__GLIBC__) + int prot{PROT_READ}; + ptr = reinterpret_cast(mmap64(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start)); + CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg(); + handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)}); +#elif defined(xgboost_IS_WIN) + auto file_size = GetFileSize(fd, nullptr); + DWORD access = PAGE_READONLY; + auto map_file = CreateFileMapping(fd, nullptr, access, 0, file_size, nullptr); + access = FILE_MAP_READ; + std::uint32_t loff = static_cast(view_start); + std::uint32_t hoff = view_start >> 32; + CHECK(map_file) << "Failed to map: " << path << ". " << SystemErrorMsg(); + ptr = reinterpret_cast(MapViewOfFile(map_file, access, hoff, loff, view_size)); + CHECK_NE(ptr, nullptr) << "Failed to map: " << path << ". " << SystemErrorMsg(); + handle_.reset(new MMAPFile{fd, map_file, ptr, view_size, std::move(path)}); +#else + CHECK_LE(offset, std::numeric_limits::max()) + << "File size has exceeded the limit on the current system."; + int prot{PROT_READ}; + ptr = reinterpret_cast(mmap(nullptr, view_size, prot, MAP_PRIVATE, fd, view_start)); + CHECK_NE(ptr, MAP_FAILED) << "Failed to map: " << path << ". " << SystemErrorMsg(); + handle_.reset(new MMAPFile{fd, ptr, view_size, std::move(path)}); +#endif // defined(__linux__) + + ptr += (offset - view_start); + return ptr; +} + +PrivateMmapConstStream::PrivateMmapConstStream(std::string path, std::size_t offset, + std::size_t length) + : MemoryFixSizeBuffer{}, handle_{nullptr} { + this->p_buffer_ = Open(std::move(path), offset, length); + this->buffer_size_ = length; +} + +PrivateMmapConstStream::~PrivateMmapConstStream() { + CHECK(handle_); +#if defined(xgboost_IS_WIN) + if (p_buffer_) { + CHECK(UnmapViewOfFile(handle_->base_ptr)) "Faled to call munmap: " << SystemErrorMsg(); + } + if (handle_->fd != INVALID_HANDLE_VALUE) { + CHECK(CloseHandle(handle_->fd)) << "Failed to close handle: " << SystemErrorMsg(); + } + if (handle_->file_map != INVALID_HANDLE_VALUE) { + CHECK(CloseHandle(handle_->file_map)) << "Failed to close mapping object: " << SystemErrorMsg(); + } +#else + if (handle_->base_ptr) { + CHECK_NE(munmap(handle_->base_ptr, handle_->base_size), -1) + << "Faled to call munmap: " << handle_->path << ". " << SystemErrorMsg(); + } + if (handle_->fd != 0) { + CHECK_NE(close(handle_->fd), -1) + << "Faled to close: " << handle_->path << ". " << SystemErrorMsg(); + } +#endif +} +} // namespace xgboost::common + +#if defined(xgboost_IS_WIN) +#undef xgboost_IS_WIN +#endif // defined(xgboost_IS_WIN) diff --git a/src/common/io.h b/src/common/io.h index 2dd593c60..ab408dec1 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -1,5 +1,5 @@ -/*! - * Copyright by XGBoost Contributors 2014-2022 +/** + * Copyright 2014-2023, XGBoost Contributors * \file io.h * \brief general stream interface for serialization, I/O * \author Tianqi Chen @@ -10,9 +10,11 @@ #include #include -#include + #include #include +#include // for unique_ptr +#include // for string #include "common.h" @@ -127,6 +129,31 @@ inline std::string ReadAll(std::string const &path) { return content; } +/** + * @brief Private mmap file as a read-only stream. + * + * It can calculate alignment automatically based on system page size (or allocation + * granularity on Windows). 
+ */ +class PrivateMmapConstStream : public MemoryFixSizeBuffer { + struct MMAPFile; + std::unique_ptr handle_; + + char* Open(std::string path, std::size_t offset, std::size_t length); + + public: + /** + * @brief Construct a private mmap stream. + * + * @param path File path. + * @param offset See the `offset` parameter of `mmap` for details. + * @param length See the `length` parameter of `mmap` for details. + */ + explicit PrivateMmapConstStream(std::string path, std::size_t offset, std::size_t length); + void Write(void const*, std::size_t) override { LOG(FATAL) << "Read-only stream."; } + + ~PrivateMmapConstStream() override; +}; } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_IO_H_ diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 088f1e98c..b4e42f2db 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -1,35 +1,34 @@ -/*! - * Copyright 2014-2022 by XGBoost Contributors +/** + * Copyright 2014-2023, XGBoost Contributors * \file sparse_page_source.h */ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ -#include // std::min -#include -#include -#include -#include -#include +#include // for min +#include // async #include #include +#include +#include +#include +#include +#include "../common/common.h" +#include "../common/io.h" // for PrivateMmapStream, PadPageForMMAP +#include "../common/timer.h" // for Monitor, Timer +#include "adapter.h" +#include "dmlc/common.h" // OMPException +#include "proxy_dmatrix.h" +#include "sparse_page_writer.h" #include "xgboost/base.h" #include "xgboost/data.h" -#include "adapter.h" -#include "sparse_page_writer.h" -#include "proxy_dmatrix.h" - -#include "../common/common.h" -#include "../common/timer.h" - -namespace xgboost { -namespace data { +namespace xgboost::data { inline void TryDeleteCacheFile(const std::string& file) { if (std::remove(file.c_str()) != 0) { LOG(WARNING) << "Couldn't remove external memory cache file " << file - << "; you may want to remove it manually"; + << "; you may want to remove it manually"; } } @@ -54,6 +53,9 @@ struct Cache { std::string ShardName() { return ShardName(this->name, this->format); } + void Push(std::size_t n_bytes) { + offset.push_back(n_bytes); + } // The write is completed. void Commit() { @@ -95,56 +97,72 @@ class SparsePageSourceImpl : public BatchIteratorImpl { uint32_t n_batches_ {0}; std::shared_ptr cache_info_; - std::unique_ptr fo_; using Ring = std::vector>>; // A ring storing futures to data. Since the DMatrix iterator is forward only, so we // can pre-fetch data in a ring. std::unique_ptr ring_{new Ring}; + dmlc::OMPException exec_; + common::Monitor monitor_; bool ReadCache() { CHECK(!at_end_); if (!cache_info_->written) { return false; } - if (fo_) { - fo_.reset(); // flush the data to disk. + if (ring_->empty()) { ring_->resize(n_batches_); } // An heuristic for number of pre-fetched batches. We can make it part of BatchParam // to let user adjust number of pre-fetched batches when needed. 
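+    // Each ring slot holds a std::future that produces one page; the std::async calls below
+    // overlap reading of upcoming batches with consumption of the current one, and the
+    // forward-only iteration order means every slot is produced and consumed exactly once.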
- uint32_t constexpr kPreFetch = 4; + uint32_t constexpr kPreFetch = 3; size_t n_prefetch_batches = std::min(kPreFetch, n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; - size_t fetch_it = count_; + std::size_t fetch_it = count_; - for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { + exec_.Rethrow(); + + monitor_.Start("launch"); + for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { continue; } - auto const *self = this; // make sure it's const + auto const* self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); - ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() { - common::Timer timer; - timer.Start(); - std::unique_ptr> fmt{CreatePageFormat("raw")}; - auto n = self->cache_info_->ShardName(); - size_t offset = self->cache_info_->offset.at(fetch_it); - std::unique_ptr fi{dmlc::SeekStream::CreateForRead(n.c_str())}; - fi->Seek(offset); - CHECK_EQ(fi->Tell(), offset); + ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() { auto page = std::make_shared(); - CHECK(fmt->Read(page.get(), fi.get())); - LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds."; + this->exec_.Run([&] { + common::Timer timer; + timer.Start(); + std::unique_ptr> fmt{CreatePageFormat("raw")}; + auto n = self->cache_info_->ShardName(); + + std::uint64_t offset = self->cache_info_->offset.at(fetch_it); + std::uint64_t length = self->cache_info_->offset.at(fetch_it + 1) - offset; + + auto fi = std::make_unique(n, offset, length); + CHECK(fmt->Read(page.get(), fi.get())); + timer.Stop(); + + LOG(INFO) << "Read a page `" << typeid(S).name() << "` in " << timer.ElapsedSeconds() + << " seconds."; + }); return page; }); } + monitor_.Stop("launch"); + CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; + monitor_.Start("Wait"); page_ = (*ring_)[count_].get(); + monitor_.Stop("Wait"); + CHECK(!(*ring_)[count_].valid()); + exec_.Rethrow(); + return true; } @@ -153,25 +171,35 @@ class SparsePageSourceImpl : public BatchIteratorImpl { common::Timer timer; timer.Start(); std::unique_ptr> fmt{CreatePageFormat("raw")}; - if (!fo_) { - auto n = cache_info_->ShardName(); - fo_.reset(dmlc::Stream::Create(n.c_str(), "w")); - } - auto bytes = fmt->Write(*page_, fo_.get()); - timer.Stop(); + auto name = cache_info_->ShardName(); + std::unique_ptr fo; + if (this->Iter() == 0) { + fo.reset(dmlc::Stream::Create(name.c_str(), "wb")); + } else { + fo.reset(dmlc::Stream::Create(name.c_str(), "ab")); + } + + auto bytes = fmt->Write(*page_, fo.get()); + + timer.Stop(); LOG(INFO) << static_cast(bytes) / 1024.0 / 1024.0 << " MB written in " << timer.ElapsedSeconds() << " seconds."; - cache_info_->offset.push_back(bytes); + cache_info_->Push(bytes); } virtual void Fetch() = 0; public: - SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, - uint32_t n_batches, std::shared_ptr cache) - : missing_{missing}, nthreads_{nthreads}, n_features_{n_features}, - n_batches_{n_batches}, cache_info_{std::move(cache)} {} + SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + std::shared_ptr cache) + : missing_{missing}, + nthreads_{nthreads}, + n_features_{n_features}, + n_batches_{n_batches}, + cache_info_{std::move(cache)} { + monitor_.Init(typeid(S).name()); // not 
pretty, but works for basic profiling + } SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete; @@ -244,7 +272,7 @@ class SparsePageSource : public SparsePageSourceImpl { iter_{iter}, proxy_{proxy} { if (!cache_info_->written) { iter_.Reset(); - CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch."; + CHECK(iter_.Next()) << "Must have at least 1 batch."; } this->Fetch(); } @@ -259,6 +287,7 @@ class SparsePageSource : public SparsePageSourceImpl { } if (at_end_) { + CHECK_EQ(cache_info_->offset.size(), n_batches_ + 1); cache_info_->Commit(); if (n_batches_ != 0) { CHECK_EQ(count_, n_batches_); @@ -371,6 +400,5 @@ class SortedCSCPageSource : public PageSourceIncMixIn { this->Fetch(); } }; -} // namespace data -} // namespace xgboost +} // namespace xgboost::data #endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index f22fa172f..5f763fb93 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -146,27 +146,30 @@ class PoissonSampling : public thrust::binary_function gpair, +GradientBasedSample NoSampling::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { - return {dmat->Info().num_row_, page_, gpair}; + auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); + return {dmat->Info().num_row_, page, gpair}; } -ExternalMemoryNoSampling::ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, - size_t n_rows, BatchParam batch_param) - : batch_param_{std::move(batch_param)}, - page_(new EllpackPageImpl(ctx->gpu_id, page->Cuts(), page->is_dense, page->row_stride, - n_rows)) {} +ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param) + : batch_param_{std::move(batch_param)} {} GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { if (!page_concatenated_) { // Concatenate all the external memory ELLPACK pages into a single in-memory page. 
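+      // The destination page is allocated lazily from the first fetched batch, since the
+      // histogram cuts, density, and row stride are only known once a batch is available.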
+ page_.reset(nullptr); size_t offset = 0; for (auto& batch : dmat->GetBatches(ctx, batch_param_)) { auto page = batch.Impl(); + if (!page_) { + page_ = std::make_unique(ctx->gpu_id, page->Cuts(), page->is_dense, + page->row_stride, dmat->Info().num_row_); + } size_t num_elements = page_->Copy(ctx->gpu_id, page, offset); offset += num_elements; } @@ -175,8 +178,8 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx, return {dmat->Info().num_row_, page_.get(), gpair}; } -UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample) - : page_(page), subsample_(subsample) {} +UniformSampling::UniformSampling(BatchParam batch_param, float subsample) + : batch_param_{std::move(batch_param)}, subsample_(subsample) {} GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { @@ -185,7 +188,8 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::SpanCTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair()); - return {dmat->Info().num_row_, page_, gpair}; + auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); + return {dmat->Info().num_row_, page, gpair}; } ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows, @@ -236,12 +240,10 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; } -GradientBasedSampling::GradientBasedSampling(EllpackPageImpl const* page, - size_t n_rows, - const BatchParam&, +GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batch_param, float subsample) - : page_(page), - subsample_(subsample), + : subsample_(subsample), + batch_param_{std::move(batch_param)}, threshold_(n_rows + 1, 0.0f), grad_sum_(n_rows, 0.0f) {} @@ -252,18 +254,19 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx, size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); + auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); + // Perform Poisson sampling in place. thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), dh::tbegin(gpair), PoissonSampling(dh::ToSpan(threshold_), threshold_index, RandomWeight(common::GlobalRandom()()))); - return {n_rows, page_, gpair}; + return {n_rows, page, gpair}; } -ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling( - size_t n_rows, - BatchParam batch_param, - float subsample) +ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows, + BatchParam batch_param, + float subsample) : batch_param_(std::move(batch_param)), subsample_(subsample), threshold_(n_rows + 1, 0.0f), @@ -273,16 +276,15 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling( GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { - size_t n_rows = dmat->Info().num_row_; + auto cuctx = ctx->CUDACtx(); + bst_row_t n_rows = dmat->Info().num_row_; size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); // Perform Poisson sampling in place. 
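+  // Each row is kept with a probability derived from its gradient magnitude relative to the
+  // computed threshold; rows that are not kept have their gradient pair zeroed so they can be
+  // compacted away below.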
- thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - thrust::counting_iterator(0), - dh::tbegin(gpair), - PoissonSampling(dh::ToSpan(threshold_), - threshold_index, + thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), + thrust::counting_iterator(0), dh::tbegin(gpair), + PoissonSampling(dh::ToSpan(threshold_), threshold_index, RandomWeight(common::GlobalRandom()()))); // Count the sampled rows. @@ -290,16 +292,15 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c // Compact gradient pairs. gpair_.resize(sample_rows); - thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); + thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); // Index the sample rows. - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero()); - thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(), - sample_row_index_.begin()); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), - sample_row_index_.begin(), - sample_row_index_.begin(), - ClearEmptyRows()); + thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), + IsNonZero()); + thrust::exclusive_scan(cuctx->CTP(), sample_row_index_.begin(), sample_row_index_.end(), + sample_row_index_.begin()); + thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), + sample_row_index_.begin(), ClearEmptyRows()); auto batch_iterator = dmat->GetBatches(ctx, batch_param_); auto first_page = (*batch_iterator.begin()).Impl(); @@ -317,13 +318,13 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; } -GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, - size_t n_rows, const BatchParam& batch_param, - float subsample, int sampling_method) { +GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows, + const BatchParam& batch_param, float subsample, + int sampling_method, bool is_external_memory) { + // The ctx is kept here for future development of stream-based operations. 
monitor_.Init("gradient_based_sampler"); bool is_sampling = subsample < 1.0; - bool is_external_memory = page->n_rows != n_rows; if (is_sampling) { switch (sampling_method) { @@ -331,24 +332,24 @@ GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl c if (is_external_memory) { strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample)); } else { - strategy_.reset(new UniformSampling(page, subsample)); + strategy_.reset(new UniformSampling(batch_param, subsample)); } break; case TrainParam::kGradientBased: if (is_external_memory) { - strategy_.reset( - new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample)); + strategy_.reset(new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample)); } else { - strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample)); + strategy_.reset(new GradientBasedSampling(n_rows, batch_param, subsample)); } break; - default:LOG(FATAL) << "unknown sampling method"; + default: + LOG(FATAL) << "unknown sampling method"; } } else { if (is_external_memory) { - strategy_.reset(new ExternalMemoryNoSampling(ctx, page, n_rows, batch_param)); + strategy_.reset(new ExternalMemoryNoSampling(batch_param)); } else { - strategy_.reset(new NoSampling(page)); + strategy_.reset(new NoSampling(batch_param)); } } } @@ -362,11 +363,11 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx, return sample; } -size_t GradientBasedSampler::CalculateThresholdIndex( - common::Span gpair, common::Span threshold, - common::Span grad_sum, size_t sample_rows) { - thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), - std::numeric_limits::max()); +size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair, + common::Span threshold, + common::Span grad_sum, + size_t sample_rows) { + thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits::max()); thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold), CombineGradientPair()); thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1); @@ -379,6 +380,5 @@ size_t GradientBasedSampler::CalculateThresholdIndex( thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum)); return thrust::distance(dh::tbegin(grad_sum), min) + 1; } - }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index dafb98cfd..f89bf242e 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -1,5 +1,5 @@ -/*! - * Copyright 2019 by XGBoost Contributors +/** + * Copyright 2019-2023, XGBoost Contributors */ #pragma once #include @@ -32,37 +32,36 @@ class SamplingStrategy { /*! \brief No sampling in in-memory mode. */ class NoSampling : public SamplingStrategy { public: - explicit NoSampling(EllpackPageImpl const* page); - GradientBasedSample Sample(Context const* ctx, common::Span gpair, - DMatrix* dmat) override; - - private: - EllpackPageImpl const* page_; -}; - -/*! \brief No sampling in external memory mode. */ -class ExternalMemoryNoSampling : public SamplingStrategy { - public: - ExternalMemoryNoSampling(Context const* ctx, EllpackPageImpl const* page, size_t n_rows, - BatchParam batch_param); + explicit NoSampling(BatchParam batch_param); GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) override; private: BatchParam batch_param_; - std::unique_ptr page_; +}; + +/*! \brief No sampling in external memory mode. 
*/ +class ExternalMemoryNoSampling : public SamplingStrategy { + public: + explicit ExternalMemoryNoSampling(BatchParam batch_param); + GradientBasedSample Sample(Context const* ctx, common::Span gpair, + DMatrix* dmat) override; + + private: + BatchParam batch_param_; + std::unique_ptr page_{nullptr}; bool page_concatenated_{false}; }; /*! \brief Uniform sampling in in-memory mode. */ class UniformSampling : public SamplingStrategy { public: - UniformSampling(EllpackPageImpl const* page, float subsample); + UniformSampling(BatchParam batch_param, float subsample); GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) override; private: - EllpackPageImpl const* page_; + BatchParam batch_param_; float subsample_; }; @@ -84,13 +83,12 @@ class ExternalMemoryUniformSampling : public SamplingStrategy { /*! \brief Gradient-based sampling in in-memory mode.. */ class GradientBasedSampling : public SamplingStrategy { public: - GradientBasedSampling(EllpackPageImpl const* page, size_t n_rows, const BatchParam& batch_param, - float subsample); + GradientBasedSampling(std::size_t n_rows, BatchParam batch_param, float subsample); GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) override; private: - EllpackPageImpl const* page_; + BatchParam batch_param_; float subsample_; dh::caching_device_vector threshold_; dh::caching_device_vector grad_sum_; @@ -106,11 +104,11 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { private: BatchParam batch_param_; float subsample_; - dh::caching_device_vector threshold_; - dh::caching_device_vector grad_sum_; + dh::device_vector threshold_; + dh::device_vector grad_sum_; std::unique_ptr page_; dh::device_vector gpair_; - dh::caching_device_vector sample_row_index_; + dh::device_vector sample_row_index_; }; /*! \brief Draw a sample of rows from a DMatrix. @@ -124,8 +122,8 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { */ class GradientBasedSampler { public: - GradientBasedSampler(Context const* ctx, EllpackPageImpl const* page, size_t n_rows, - const BatchParam& batch_param, float subsample, int sampling_method); + GradientBasedSampler(Context const* ctx, size_t n_rows, const BatchParam& batch_param, + float subsample, int sampling_method, bool is_external_memory); /*! \brief Sample from a DMatrix based on the given gradient pairs. 
*/ GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 5e5d2b5cb..2807dcfd7 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -176,7 +176,7 @@ struct GPUHistMakerDevice { Context const* ctx_; public: - EllpackPageImpl const* page; + EllpackPageImpl const* page{nullptr}; common::Span feature_types; BatchParam batch_param; @@ -205,41 +205,41 @@ struct GPUHistMakerDevice { std::unique_ptr feature_groups; - - GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, - common::Span _feature_types, bst_uint _n_rows, + GPUHistMakerDevice(Context const* ctx, bool is_external_memory, + common::Span _feature_types, bst_row_t _n_rows, TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, BatchParam _batch_param) : evaluator_{_param, n_features, ctx->gpu_id}, ctx_(ctx), - page(_page), feature_types{_feature_types}, param(std::move(_param)), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), batch_param(std::move(_batch_param)) { - sampler.reset(new GradientBasedSampler(ctx, page, _n_rows, batch_param, param.subsample, - param.sampling_method)); + sampler.reset(new GradientBasedSampler(ctx, _n_rows, batch_param, param.subsample, + param.sampling_method, is_external_memory)); if (!param.monotone_constraints.empty()) { // Copy assigning an empty vector causes an exception in MSVC debug builds monotone_constraints = param.monotone_constraints; } - // Init histogram - hist.Init(ctx_->gpu_id, page->Cuts().TotalBins()); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); - feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, - dh::MaxSharedMemoryOptin(ctx_->gpu_id), - sizeof(GradientSumT))); } ~GPUHistMakerDevice() { // NOLINT dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); } + void InitFeatureGroupsOnce() { + if (!feature_groups) { + CHECK(page); + feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, + dh::MaxSharedMemoryOptin(ctx_->gpu_id), + sizeof(GradientSumT))); + } + } + // Reset values for each update iteration - // Note that the column sampler must be passed by value because it is not - // thread safe void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { auto const& info = dmat->Info(); this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(), @@ -247,26 +247,30 @@ struct GPUHistMakerDevice { param.colsample_bytree); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); - this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, - ctx_->gpu_id); - this->interaction_constraints.Reset(); if (d_gpair.size() != dh_gpair->Size()) { d_gpair.resize(dh_gpair->Size()); } - dh::safe_cuda(cudaMemcpyAsync( - d_gpair.data().get(), dh_gpair->ConstDevicePointer(), - dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); + dh::safe_cuda(cudaMemcpyAsync(d_gpair.data().get(), dh_gpair->ConstDevicePointer(), + dh_gpair->Size() * sizeof(GradientPair), + cudaMemcpyDeviceToDevice)); auto sample = sampler->Sample(ctx_, dh::ToSpan(d_gpair), dmat); page = sample.page; gpair = sample.gpair; + this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id); + quantiser.reset(new GradientQuantiser(this->gpair)); row_partitioner.reset(); // Release the device memory first before reallocating - row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows)); + 
row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, sample.sample_rows)); + + // Init histogram + hist.Init(ctx_->gpu_id, page->Cuts().TotalBins()); hist.Reset(); + + this->InitFeatureGroupsOnce(); } GPUExpandEntry EvaluateRootSplit(GradientPairInt64 root_sum) { @@ -808,12 +812,11 @@ class GPUHistMaker : public TreeUpdater { collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; - auto page = (*dmat->GetBatches(ctx_, batch_param).begin()).Impl(); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); info_->feature_types.SetDevice(ctx_->gpu_id); maker.reset(new GPUHistMakerDevice( - ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, *param, - column_sampling_seed, info_->num_col_, batch_param)); + ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_, + *param, column_sampling_seed, info_->num_col_, batch_param)); p_last_fmat_ = dmat; initialised_ = true; diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc index feac8bd89..a64b60b80 100644 --- a/tests/cpp/common/test_io.cc +++ b/tests/cpp/common/test_io.cc @@ -1,5 +1,5 @@ -/*! - * Copyright (c) by XGBoost Contributors 2019 +/** + * Copyright 2019-2023, XGBoost Contributors */ #include @@ -9,8 +9,7 @@ #include "../helpers.h" #include "../filesystem.h" // dmlc::TemporaryDirectory -namespace xgboost { -namespace common { +namespace xgboost::common { TEST(MemoryFixSizeBuffer, Seek) { size_t constexpr kSize { 64 }; std::vector memory( kSize ); @@ -89,5 +88,54 @@ TEST(IO, LoadSequentialFile) { ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error); } -} // namespace common -} // namespace xgboost + +TEST(IO, PrivateMmapStream) { + dmlc::TemporaryDirectory tempdir; + auto path = tempdir.path + "/testfile"; + + // The page size on Linux is usually set to 4096, while the allocation granularity on + // the Windows machine where this test is writted is 65536. We span the test to cover + // all of them. 
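+  // Batch i holds (i + 1) * multiplier int32 values plus an 8-byte length header, so the
+  // accumulated offsets are generally not aligned to either granularity, which exercises the
+  // offset rounding inside PrivateMmapConstStream.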
+ std::size_t n_batches{64}; + std::size_t multiplier{2048}; + + std::vector> batches; + std::vector offset{0ul}; + + using T = std::int32_t; + + { + std::unique_ptr fo{dmlc::Stream::Create(path.c_str(), "w")}; + for (std::size_t i = 0; i < n_batches; ++i) { + std::size_t size = (i + 1) * multiplier; + std::vector data(size, 0); + std::iota(data.begin(), data.end(), i * i); + + fo->Write(static_cast(data.size())); + fo->Write(data.data(), data.size() * sizeof(T)); + + std::size_t bytes = sizeof(std::uint64_t) + data.size() * sizeof(T); + offset.push_back(bytes); + + batches.emplace_back(std::move(data)); + } + } + + // Turn size info offset + std::partial_sum(offset.begin(), offset.end(), offset.begin()); + + for (std::size_t i = 0; i < n_batches; ++i) { + std::size_t off = offset[i]; + std::size_t n = offset.at(i + 1) - offset[i]; + std::unique_ptr fi{std::make_unique(path, off, n)}; + std::vector data; + + std::uint64_t size{0}; + fi->Read(&size); + data.resize(size); + + fi->Read(data.data(), size * sizeof(T)); + ASSERT_EQ(data, batches[i]); + } +} +} // namespace xgboost::common diff --git a/tests/cpp/histogram_helpers.h b/tests/cpp/histogram_helpers.h index 127f6fe44..6774f531c 100644 --- a/tests/cpp/histogram_helpers.h +++ b/tests/cpp/histogram_helpers.h @@ -2,6 +2,10 @@ #include "../../src/data/ellpack_page.cuh" #endif +#include // for SparsePage + +#include "./helpers.h" // for RandomDataGenerator + namespace xgboost { #if defined(__CUDACC__) namespace { diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 95ae02aee..26ddfd8cc 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -39,7 +39,8 @@ void VerifySampling(size_t page_size, EXPECT_NE(page->n_rows, kRows); } - GradientBasedSampler sampler(&ctx, page, kRows, param, subsample, sampling_method); + GradientBasedSampler sampler(&ctx, kRows, param, subsample, sampling_method, + !fixed_size_sampling); auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get()); if (fixed_size_sampling) { @@ -93,7 +94,7 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) { auto page = (*dmat->GetBatches(&ctx, param).begin()).Impl(); EXPECT_NE(page->n_rows, kRows); - GradientBasedSampler sampler(&ctx, page, kRows, param, kSubsample, TrainParam::kUniform); + GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true); auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get()); auto sampled_page = sample.page; EXPECT_EQ(sample.sample_rows, kRows); @@ -141,7 +142,8 @@ TEST(GradientBasedSampler, GradientBasedSampling) { constexpr size_t kPageSize = 0; constexpr float kSubsample = 0.8; constexpr int kSamplingMethod = TrainParam::kGradientBased; - VerifySampling(kPageSize, kSubsample, kSamplingMethod); + constexpr bool kFixedSizeSampling = true; + VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) { diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 1bd4ece20..fd3034db5 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -92,8 +92,8 @@ void TestBuildHist(bool use_shared_memory_histograms) { auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; Context ctx{MakeCUDACtx(0)}; - GPUHistMakerDevice maker(&ctx, page.get(), {}, kNRows, param, kNCols, kNCols, - batch_param); + 
GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, + kNCols, kNCols, batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); HostDeviceVector gpair(kNRows); @@ -106,9 +106,15 @@ void TestBuildHist(bool use_shared_memory_histograms) { thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector()); maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); + + maker.hist.Init(0, page->Cuts().TotalBins()); maker.hist.AllocateHistograms({0}); + maker.gpair = gpair.DeviceSpan(); maker.quantiser.reset(new GradientQuantiser(maker.gpair)); + maker.page = page.get(); + + maker.InitFeatureGroupsOnce(); BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), gpair.DeviceSpan(), @@ -126,8 +132,8 @@ void TestBuildHist(bool use_shared_memory_histograms) { std::vector solution = GetHostHistGpair(); for (size_t i = 0; i < h_result.size(); ++i) { auto result = maker.quantiser->ToFloatingPoint(h_result[i]); - EXPECT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f); - EXPECT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f); + ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f); + ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f); } } diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 523dbf931..610c717a9 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -305,7 +305,7 @@ class IterForDMatrixTest(xgb.core.DataIter): self._labels = [rng.randn(self.rows)] * self.BATCHES self.it = 0 # set iterator to 0 - super().__init__() + super().__init__(cache_prefix=None) def as_array(self): import cudf diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 0590a4954..24c117f15 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -64,7 +64,8 @@ def run_data_iterator( subsample_rate = 0.8 if subsample else 1.0 it = IteratorForTest( - *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy) + *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy), + cache="cache" ) if n_batches == 0: with pytest.raises(ValueError, match="1 batch"):