diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 0cabffcad..5fbf479c5 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -72,6 +72,7 @@ OBJECTS= \ $(PKGROOT)/src/data/gradient_index_page_source.o \ $(PKGROOT)/src/data/gradient_index_format.o \ $(PKGROOT)/src/data/sparse_page_dmatrix.o \ + $(PKGROOT)/src/data/sparse_page_source.o \ $(PKGROOT)/src/data/proxy_dmatrix.o \ $(PKGROOT)/src/data/iterative_dmatrix.o \ $(PKGROOT)/src/predictor/predictor.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index c49006c5e..a5a5c131e 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -72,6 +72,7 @@ OBJECTS= \ $(PKGROOT)/src/data/gradient_index_page_source.o \ $(PKGROOT)/src/data/gradient_index_format.o \ $(PKGROOT)/src/data/sparse_page_dmatrix.o \ + $(PKGROOT)/src/data/sparse_page_source.o \ $(PKGROOT)/src/data/proxy_dmatrix.o \ $(PKGROOT)/src/data/iterative_dmatrix.o \ $(PKGROOT)/src/predictor/predictor.o \ diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 6f8c818c8..6319e6480 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -50,7 +50,7 @@ class MetaInfo { static constexpr uint64_t kNumField = 12; /*! \brief number of rows in the data */ - uint64_t num_row_{0}; // NOLINT + bst_idx_t num_row_{0}; // NOLINT /*! \brief number of columns in the data */ uint64_t num_col_{0}; // NOLINT /*! \brief number of nonzero entries in the data */ @@ -535,10 +535,11 @@ class DMatrix { template [[nodiscard]] bool PageExists() const; - // the following are column meta data, should be able to answer them fast. - /*! \return Whether the data columns single column block. */ + /** + * @return Whether the data columns single column block. + */ [[nodiscard]] virtual bool SingleColBlock() const = 0; - /*! \brief virtual destructor */ + virtual ~DMatrix(); /** @@ -600,34 +601,34 @@ class DMatrix { int nthread, bst_bin_t max_bin); /** - * \brief Create an external memory DMatrix with callbacks. + * @brief Create an external memory DMatrix with callbacks. * - * \tparam DataIterHandle External iterator type, defined in C API. - * \tparam DMatrixHandle DMatrix handle, defined in C API. - * \tparam DataIterResetCallback Callback for reset, prototype defined in C API. - * \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API. + * @tparam DataIterHandle External iterator type, defined in C API. + * @tparam DMatrixHandle DMatrix handle, defined in C API. + * @tparam DataIterResetCallback Callback for reset, prototype defined in C API. + * @tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API. * - * \param iter External data iterator - * \param proxy A hanlde to ProxyDMatrix - * \param reset Callback for reset - * \param next Callback for next - * \param missing Value that should be treated as missing. - * \param nthread number of threads used for initialization. - * \param cache Prefix of cache file path. + * @param iter External data iterator + * @param proxy A hanlde to ProxyDMatrix + * @param reset Callback for reset + * @param next Callback for next + * @param missing Value that should be treated as missing. + * @param nthread number of threads used for initialization. + * @param cache Prefix of cache file path. + * @param on_host Used for GPU, whether the data should be cached on host memory. * - * \return A created external memory DMatrix. + * @return A created external memory DMatrix. */ - template - static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy, - DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, - int32_t nthread, std::string cache); + template + static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset, + XGDMatrixCallbackNext* next, float missing, int32_t nthread, + std::string cache, bool on_host); virtual DMatrix *Slice(common::Span ridxs) = 0; /** - * \brief Slice a DMatrix by columns. + * @brief Slice a DMatrix by columns. * * @param num_slices Total number of slices * @param slice_id Index of the current slice diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 40433c28e..177448e0e 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -503,18 +503,29 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes ---------- cache_prefix : Prefix to the cache files, only used in external memory. + release_data : Whether the iterator should release the data during iteration. Set it to True if the data transformation (converting data to np.float32 type) is memory intensive. Otherwise, if the transformation is computation intensive then we can keep the cache. + on_host : + Whether the data should be cached on host memory instead of harddrive when using + GPU with external memory. If set to true, then the "external memory" would + simply be CPU (host) memory. This is still working in progress, not ready for + test yet. + """ def __init__( - self, cache_prefix: Optional[str] = None, release_data: bool = True + self, + cache_prefix: Optional[str] = None, + release_data: bool = True, + on_host: bool = False, ) -> None: self.cache_prefix = cache_prefix + self.on_host = on_host self._handle = _ProxyDMatrix() self._exception: Optional[Exception] = None @@ -905,12 +916,12 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: it = iterator - args = { - "missing": self.missing, - "nthread": self.nthread, - "cache_prefix": it.cache_prefix if it.cache_prefix else "", - } - args_cstr = from_pystr_to_cstr(json.dumps(args)) + args = make_jcargs( + missing=self.missing, + nthread=self.nthread, + cache_prefix=it.cache_prefix if it.cache_prefix else "", + on_host=it.on_host, + ) handle = ctypes.c_void_p() reset_callback, next_callback = it.get_callbacks(enable_categorical) ret = _LIB.XGDMatrixCreateFromCallback( @@ -918,7 +929,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m it.proxy.handle, reset_callback, next_callback, - args_cstr, + args, ctypes.byref(handle), ) it.reraise() diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 64e2a9170..482da68c9 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -198,19 +198,20 @@ def skip_win() -> PytestSkip: class IteratorForTest(xgb.core.DataIter): """Iterator for testing streaming DMatrix. (external memory, quantile)""" - def __init__( + def __init__( # pylint: disable=too-many-arguments self, X: Sequence, y: Sequence, w: Optional[Sequence], cache: Optional[str], + on_host: bool = False, ) -> None: assert len(X) == len(y) self.X = X self.y = y self.w = w self.it = 0 - super().__init__(cache_prefix=cache) + super().__init__(cache_prefix=cache, on_host=on_host) def next(self, input_data: Callable) -> int: if self.it == len(self.X): @@ -367,7 +368,11 @@ class TestDataset: weight.append(w) it = IteratorForTest( - predictor, response, weight if weight else None, cache="cache" + predictor, + response, + weight if weight else None, + cache="cache", + on_host=False, ) return xgb.DMatrix(it) diff --git a/python-package/xgboost/testing/data_iter.py b/python-package/xgboost/testing/data_iter.py index 42a9dfca0..f51b303d5 100644 --- a/python-package/xgboost/testing/data_iter.py +++ b/python-package/xgboost/testing/data_iter.py @@ -22,7 +22,7 @@ def run_mixed_sparsity(device: str) -> None: X = [cp.array(batch) for batch in X] - it = tm.IteratorForTest(X, y, None, None) + it = tm.IteratorForTest(X, y, None, None, on_host=False) Xy_0 = xgboost.QuantileDMatrix(it) X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True) diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py index c0c014167..1e0b9b0d1 100644 --- a/python-package/xgboost/testing/updater.py +++ b/python-package/xgboost/testing/updater.py @@ -207,6 +207,7 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None: it = tm.IteratorForTest( *tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy), cache="cache", + on_host=False, ) Xy: xgb.DMatrix = xgb.DMatrix(it) xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 45160baea..3559660dd 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -298,13 +298,14 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy auto missing = GetMissing(jconfig); std::string cache = RequiredArg(jconfig, "cache_prefix", __func__); auto n_threads = OptionalArg(jconfig, "nthread", 0); + auto on_host = OptionalArg(jconfig, "on_host", false); xgboost_CHECK_C_ARG_PTR(next); xgboost_CHECK_C_ARG_PTR(reset); xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr{ - xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)}; + xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache, on_host)}; API_END(); } diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index f4fce42f8..7cd00f6f6 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -429,7 +429,7 @@ struct XGBDefaultDeviceAllocatorImpl : XGBBaseDeviceAllocator { } #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 XGBDefaultDeviceAllocatorImpl() - : SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()) {} + : SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()) {} #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 }; @@ -484,8 +484,8 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator { } #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 XGBCachingDeviceAllocatorImpl() - : SuperT(rmm::cuda_stream_default, rmm::mr::get_current_device_resource()), - use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {} + : SuperT(rmm::cuda_stream_per_thread, rmm::mr::get_current_device_resource()), + use_cub_allocator_(!xgboost::GlobalConfigThreadLocalStore::Get()->use_rmm) {} #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 XGBOOST_DEVICE void construct(T *) {} // NOLINT private: diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 67114320b..601e63526 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -6,7 +6,7 @@ #ifndef XGBOOST_COMMON_ERROR_MSG_H_ #define XGBOOST_COMMON_ERROR_MSG_H_ -#include // for uint64_t +#include // for uint64_t #include // for numeric_limits #include // for string @@ -103,5 +103,11 @@ inline auto NoFederated() { return "XGBoost is not compiled with federated learn inline auto NoCategorical(std::string name) { return name + " doesn't support categorical features."; } + +inline void NoOnHost(bool on_host) { + if (on_host) { + LOG(FATAL) << "Caching on host memory is only available for GPU."; + } +} } // namespace xgboost::error #endif // XGBOOST_COMMON_ERROR_MSG_H_ diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 867d671e2..2e24f68ff 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -163,7 +163,7 @@ class HistogramCuts { return vals[bin_idx - 1]; } - void SetDevice(DeviceOrd d) const { + void SetDevice(DeviceOrd d) { this->cut_ptrs_.SetDevice(d); this->cut_ptrs_.ConstDevicePointer(); diff --git a/src/data/data.cc b/src/data/data.cc index f37a10fa3..3f9c13fa5 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -901,15 +901,12 @@ DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_p return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin); } -template -DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, - DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, - int32_t n_threads, - std::string cache) { - return new data::SparsePageDMatrix(iter, proxy, reset, next, missing, n_threads, - cache); +template +DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback* reset, + XGDMatrixCallbackNext* next, float missing, int32_t n_threads, + std::string cache, bool on_host) { + return new data::SparsePageDMatrix{iter, proxy, reset, next, missing, n_threads, cache, on_host}; } template DMatrix* DMatrix::Create( - DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string); +template DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, + DataIterResetCallback* reset, + XGDMatrixCallbackNext* next, float missing, + int32_t n_threads, std::string, bool); template DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&, diff --git a/src/data/ellpack_page.cc b/src/data/ellpack_page.cc index d58364635..dc6d370e5 100644 --- a/src/data/ellpack_page.cc +++ b/src/data/ellpack_page.cc @@ -36,7 +36,7 @@ void EllpackPage::SetBaseRowId(std::size_t) { LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but " "EllpackPage is required"; } -size_t EllpackPage::Size() const { +bst_idx_t EllpackPage::Size() const { LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but " "EllpackPage is required"; return 0; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index bfbb7f076..81656284e 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -29,7 +29,7 @@ EllpackPage::~EllpackPage() = default; EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); } -size_t EllpackPage::Size() const { return impl_->Size(); } +[[nodiscard]] bst_idx_t EllpackPage::Size() const { return impl_->Size(); } void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); } @@ -91,13 +91,13 @@ __global__ void CompressBinEllpackKernel( // Construct an ELLPACK matrix with the given number of empty rows. EllpackPageImpl::EllpackPageImpl(DeviceOrd device, std::shared_ptr cuts, bool is_dense, - size_t row_stride, size_t n_rows) - : is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) { + bst_idx_t row_stride, bst_idx_t n_rows) + : is_dense(is_dense), cuts_(std::move(cuts)), row_stride{row_stride}, n_rows{n_rows} { monitor_.Init("ellpack_page"); dh::safe_cuda(cudaSetDevice(device.ordinal)); monitor_.Start("InitCompressedData"); - InitCompressedData(device); + this->InitCompressedData(device); monitor_.Stop("InitCompressedData"); } @@ -403,7 +403,7 @@ struct CopyPage { // Copy the data from the given EllpackPage to the current page. size_t EllpackPageImpl::Copy(DeviceOrd device, EllpackPageImpl const* page, size_t offset) { monitor_.Start("Copy"); - size_t num_elements = page->n_rows * page->row_stride; + bst_idx_t num_elements = page->n_rows * page->row_stride; CHECK_EQ(row_stride, page->row_stride); CHECK_EQ(NumSymbols(), page->NumSymbols()); CHECK_GE(n_rows * row_stride, offset + num_elements); @@ -461,16 +461,17 @@ struct CompactPage { }; // Compacts the data from the given EllpackPage into the current page. -void EllpackPageImpl::Compact(DeviceOrd device, EllpackPageImpl const* page, +void EllpackPageImpl::Compact(Context const* ctx, EllpackPageImpl const* page, common::Span row_indexes) { - monitor_.Start("Compact"); + monitor_.Start(__func__); CHECK_EQ(row_stride, page->row_stride); CHECK_EQ(NumSymbols(), page->NumSymbols()); CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size()); - gidx_buffer.SetDevice(device); - page->gidx_buffer.SetDevice(device); - dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes)); - monitor_.Stop("Compact"); + gidx_buffer.SetDevice(ctx->Device()); + page->gidx_buffer.SetDevice(ctx->Device()); + auto cuctx = ctx->CUDACtx(); + dh::LaunchN(page->n_rows, cuctx->Stream(), CompactPage(this, page, row_indexes)); + monitor_.Stop(__func__); } // Initialize the buffer to stored compressed features. @@ -551,7 +552,7 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device, } // Return the number of rows contained in this page. -size_t EllpackPageImpl::Size() const { return n_rows; } +[[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; } // Return the memory cost for storing the compressed features. size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride, diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index a0fafbe74..04960458f 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -143,7 +143,7 @@ class EllpackPageImpl { * and the given number of rows. */ EllpackPageImpl(DeviceOrd device, std::shared_ptr cuts, - bool is_dense, size_t row_stride, size_t n_rows); + bool is_dense, bst_idx_t row_stride, bst_idx_t n_rows); /*! * \brief Constructor used for external memory. */ @@ -181,14 +181,14 @@ class EllpackPageImpl { /*! \brief Compact the given ELLPACK page into the current page. * - * @param device The GPU device to use. + * @param context The GPU context. * @param page The ELLPACK page to compact from. * @param row_indexes Row indexes for the compacted page. */ - void Compact(DeviceOrd device, EllpackPageImpl const* page, common::Span row_indexes); + void Compact(Context const* ctx, EllpackPageImpl const* page, common::Span row_indexes); /*! \return Number of instances in the page. */ - [[nodiscard]] size_t Size() const; + [[nodiscard]] bst_idx_t Size() const; /*! \brief Set the base row id for this page. */ void SetBaseRowId(std::size_t row_id) { @@ -231,7 +231,7 @@ class EllpackPageImpl { /*! \brief Whether or not if the matrix is dense. */ bool is_dense; /*! \brief Row length for ELLPACK. */ - size_t row_stride; + bst_idx_t row_stride; bst_idx_t base_rowid{0}; bst_idx_t n_rows{}; /*! \brief global index of histogram, which is stored in ELLPACK format. */ diff --git a/src/data/ellpack_page.h b/src/data/ellpack_page.h index 77d1124e0..246b48296 100644 --- a/src/data/ellpack_page.h +++ b/src/data/ellpack_page.h @@ -41,7 +41,7 @@ class EllpackPage { EllpackPage(EllpackPage&& that); /*! \return Number of instances in the page. */ - [[nodiscard]] size_t Size() const; + [[nodiscard]] bst_idx_t Size() const; /*! \brief Set the base row id for this page. */ void SetBaseRowId(std::size_t row_id); diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 3bf528ea8..059dd9f21 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -10,6 +10,7 @@ #include "../common/ref_resource_view.h" // for ReadVec, WriteVec #include "ellpack_page.cuh" // for EllpackPage #include "ellpack_page_raw_format.h" +#include "ellpack_page_source.h" namespace xgboost::data { DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format); @@ -32,7 +33,6 @@ template return false; } - vec->SetDevice(DeviceOrd::CUDA(0)); vec->Resize(n); auto d_vec = vec->DeviceSpan(); dh::safe_cuda( @@ -54,6 +54,7 @@ template if (!fi->Read(&impl->row_stride)) { return false; } + impl->gidx_buffer.SetDevice(device_); if (!ReadDeviceVec(fi, &impl->gidx_buffer)) { return false; } @@ -73,6 +74,65 @@ template CHECK(!impl->gidx_buffer.ConstHostVector().empty()); bytes += common::WriteVec(fo, impl->gidx_buffer.HostVector()); bytes += fo->Write(impl->base_rowid); + dh::DefaultStream().Sync(); + return bytes; +} + +[[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const { + auto* impl = page->Impl(); + CHECK(this->cuts_->cut_values_.DeviceCanRead()); + impl->SetCuts(this->cuts_); + if (!fi->Read(&impl->n_rows)) { + return false; + } + if (!fi->Read(&impl->is_dense)) { + return false; + } + if (!fi->Read(&impl->row_stride)) { + return false; + } + + // Read vec + bst_idx_t n{0}; + if (!fi->Read(&n)) { + return false; + } + if (n != 0) { + impl->gidx_buffer.SetDevice(device_); + impl->gidx_buffer.Resize(n); + auto span = impl->gidx_buffer.DeviceSpan(); + if (!fi->Read(span.data(), span.size_bytes())) { + return false; + } + } + + if (!fi->Read(&impl->base_rowid)) { + return false; + } + + dh::DefaultStream().Sync(); + return true; +} + +[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page, + EllpackHostCacheStream* fo) const { + bst_idx_t bytes{0}; + auto* impl = page.Impl(); + bytes += fo->Write(impl->n_rows); + bytes += fo->Write(impl->is_dense); + bytes += fo->Write(impl->row_stride); + + // Write vector + bst_idx_t n = impl->gidx_buffer.Size(); + bytes += fo->Write(n); + + if (!impl->gidx_buffer.Empty()) { + auto span = impl->gidx_buffer.ConstDeviceSpan(); + bytes += fo->Write(span.data(), span.size_bytes()); + } + bytes += fo->Write(impl->base_rowid); + + dh::DefaultStream().Sync(); return bytes; } } // namespace xgboost::data diff --git a/src/data/ellpack_page_raw_format.h b/src/data/ellpack_page_raw_format.h index 5825b4896..8c3f89f0c 100644 --- a/src/data/ellpack_page_raw_format.h +++ b/src/data/ellpack_page_raw_format.h @@ -20,15 +20,22 @@ class HistogramCuts; } namespace xgboost::data { + +class EllpackHostCacheStream; + class EllpackPageRawFormat : public SparsePageFormat { std::shared_ptr cuts_; + DeviceOrd device_; public: - explicit EllpackPageRawFormat(std::shared_ptr cuts) - : cuts_{std::move(cuts)} {} + explicit EllpackPageRawFormat(std::shared_ptr cuts, DeviceOrd device) + : cuts_{std::move(cuts)}, device_{device} {} [[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override; [[nodiscard]] std::size_t Write(const EllpackPage& page, common::AlignedFileWriteStream* fo) override; + + [[nodiscard]] bool Read(EllpackPage* page, EllpackHostCacheStream* fi) const; + [[nodiscard]] std::size_t Write(const EllpackPage& page, EllpackHostCacheStream* fo) const; }; #if !defined(XGBOOST_USE_CUDA) diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 66500d58b..f53ae3ef1 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -1,29 +1,161 @@ /** * Copyright 2019-2024, XGBoost contributors */ -#include +#include // for host_vector -#include "ellpack_page.cuh" -#include "ellpack_page.h" // for EllpackPage +#include // for size_t +#include // for int8_t, uint64_t, uint32_t +#include // for shared_ptr, make_unique, make_shared +#include // for move + +#include "../common/common.h" // for safe_cuda +#include "../common/cuda_pinned_allocator.h" // for pinned_allocator +#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream +#include "ellpack_page.cuh" // for EllpackPageImpl +#include "ellpack_page.h" // for EllpackPage #include "ellpack_page_source.h" +#include "xgboost/base.h" // for bst_idx_t namespace xgboost::data { -void EllpackPageSource::Fetch() { - dh::safe_cuda(cudaSetDevice(device_.ordinal)); +struct EllpackHostCache { + thrust::host_vector> cache; + + void Resize(std::size_t n, dh::CUDAStreamView stream) { + stream.Sync(); // Prevent partial copy inside resize. + cache.resize(n); + } +}; + +class EllpackHostCacheStreamImpl { + std::shared_ptr cache_; + bst_idx_t cur_ptr_{0}; + bst_idx_t bound_{0}; + + public: + explicit EllpackHostCacheStreamImpl(std::shared_ptr cache) + : cache_{std::move(cache)} {} + + [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes) { + auto n = cur_ptr_ + n_bytes; + if (n > cache_->cache.size()) { + cache_->Resize(n, dh::DefaultStream()); + } + dh::safe_cuda(cudaMemcpyAsync(cache_->cache.data() + cur_ptr_, ptr, n_bytes, cudaMemcpyDefault, + dh::DefaultStream())); + cur_ptr_ = n; + return n_bytes; + } + + [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes) { + CHECK_LE(cur_ptr_ + n_bytes, bound_); + dh::safe_cuda(cudaMemcpyAsync(ptr, cache_->cache.data() + cur_ptr_, n_bytes, cudaMemcpyDefault, + dh::DefaultStream())); + cur_ptr_ += n_bytes; + return true; + } + + [[nodiscard]] bst_idx_t Tell() const { return cur_ptr_; } + void Seek(bst_idx_t offset_bytes) { cur_ptr_ = offset_bytes; } + void Bound(bst_idx_t offset_bytes) { + CHECK_LE(offset_bytes, cache_->cache.size()); + this->bound_ = offset_bytes; + } +}; + +/** + * EllpackHostCacheStream + */ + +EllpackHostCacheStream::EllpackHostCacheStream(std::shared_ptr cache) + : p_impl_{std::make_unique(std::move(cache))} {} + +EllpackHostCacheStream::~EllpackHostCacheStream() = default; + +[[nodiscard]] bst_idx_t EllpackHostCacheStream::Write(void const* ptr, bst_idx_t n_bytes) { + return this->p_impl_->Write(ptr, n_bytes); +} + +[[nodiscard]] bool EllpackHostCacheStream::Read(void* ptr, bst_idx_t n_bytes) { + return this->p_impl_->Read(ptr, n_bytes); +} + +[[nodiscard]] bst_idx_t EllpackHostCacheStream::Tell() const { return this->p_impl_->Tell(); } + +void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); } + +void EllpackHostCacheStream::Bound(bst_idx_t offset_bytes) { this->p_impl_->Bound(offset_bytes); } + +/** + * EllpackFormatType + */ + +template typename F> +EllpackFormatStreamPolicy::EllpackFormatStreamPolicy() + : p_cache_{std::make_shared()} {} + +template typename F> +[[nodiscard]] std::unique_ptr::WriterT> +EllpackFormatStreamPolicy::CreateWriter(StringView, std::uint32_t iter) { + auto fo = std::make_unique(this->p_cache_); + if (iter == 0) { + CHECK(this->p_cache_->cache.empty()); + } else { + fo->Seek(this->p_cache_->cache.size()); + } + return fo; +} + +template typename F> +[[nodiscard]] std::unique_ptr::ReaderT> +EllpackFormatStreamPolicy::CreateReader(StringView, bst_idx_t offset, + bst_idx_t length) const { + auto fi = std::make_unique(this->p_cache_); + fi->Seek(offset); + fi->Bound(offset + length); + CHECK_EQ(fi->Tell(), offset); + return fi; +} + +// Instantiation +template EllpackFormatStreamPolicy::EllpackFormatStreamPolicy(); + +template std::unique_ptr< + typename EllpackFormatStreamPolicy::WriterT> +EllpackFormatStreamPolicy::CreateWriter(StringView name, + std::uint32_t iter); + +template std::unique_ptr< + typename EllpackFormatStreamPolicy::ReaderT> +EllpackFormatStreamPolicy::CreateReader( + StringView name, std::uint64_t offset, std::uint64_t length) const; + +/** + * EllpackPageSourceImpl + */ +template +void EllpackPageSourceImpl::Fetch() { + dh::safe_cuda(cudaSetDevice(this->Device().ordinal)); if (!this->ReadCache()) { - if (count_ != 0 && !sync_) { + if (this->count_ != 0 && !this->sync_) { // source is initialized to be the 0th page during construction, so when count_ is 0 // there's no need to increment the source. - ++(*source_); + ++(*this->source_); } // This is not read from cache so we still need it to be synced with sparse page source. - CHECK_EQ(count_, source_->Iter()); - auto const &csr = source_->Page(); + CHECK_EQ(this->count_, this->source_->Iter()); + auto const& csr = this->source_->Page(); this->page_.reset(new EllpackPage{}); - auto *impl = this->page_->Impl(); - *impl = EllpackPageImpl(device_, cuts_, *csr, is_dense_, row_stride_, feature_types_); - page_->SetBaseRowId(csr->base_rowid); + auto* impl = this->page_->Impl(); + *impl = EllpackPageImpl{this->Device(), this->GetCuts(), *csr, + is_dense_, row_stride_, feature_types_}; + this->page_->SetBaseRowId(csr->base_rowid); this->WriteCache(); } } + +// Instantiation +template void +EllpackPageSourceImpl>::Fetch(); +template void +EllpackPageSourceImpl>::Fetch(); } // namespace xgboost::data diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index f9aa128c7..7f50899b9 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -19,46 +19,127 @@ #include "xgboost/span.h" // for Span namespace xgboost::data { -class EllpackPageSource : public PageSourceIncMixIn { +// We need to decouple the storage and the view of the storage so that we can implement +// concurrent read. + +// Dummy type to hide CUDA calls from the host compiler. +struct EllpackHostCache; +// Pimpl to hide CUDA calls from the host compiler. +class EllpackHostCacheStreamImpl; + +// A view onto the actual cache implemented by `EllpackHostCache`. +class EllpackHostCacheStream { + std::unique_ptr p_impl_; + + public: + explicit EllpackHostCacheStream(std::shared_ptr cache); + ~EllpackHostCacheStream(); + + [[nodiscard]] bst_idx_t Write(void const* ptr, bst_idx_t n_bytes); + template + [[nodiscard]] std::enable_if_t, bst_idx_t> Write(T const& v) { + return this->Write(&v, sizeof(T)); + } + + [[nodiscard]] bool Read(void* ptr, bst_idx_t n_bytes); + + template + [[nodiscard]] auto Read(T* ptr) -> std::enable_if_t, bool> { + return this->Read(ptr, sizeof(T)); + } + + [[nodiscard]] bst_idx_t Tell() const; + void Seek(bst_idx_t offset_bytes); + // Limit the size of read. offset_bytes is the maximum offset that this stream can read + // to. An error is raised if the limited is exceeded. + void Bound(bst_idx_t offset_bytes); +}; + +template +class EllpackFormatPolicy { + std::shared_ptr cuts_{nullptr}; + DeviceOrd device_; + + public: + using FormatT = EllpackPageRawFormat; + + public: + [[nodiscard]] auto CreatePageFormat() const { + CHECK_EQ(cuts_->cut_values_.Device(), device_); + std::unique_ptr fmt{new EllpackPageRawFormat{cuts_, device_}}; + return fmt; + } + + void SetCuts(std::shared_ptr cuts, DeviceOrd device) { + std::swap(cuts_, cuts); + device_ = device; + CHECK(this->device_.IsCUDA()); + } + [[nodiscard]] auto GetCuts() { + CHECK(cuts_); + return cuts_; + } + [[nodiscard]] auto Device() const { return device_; } +}; + +template typename F> +class EllpackFormatStreamPolicy : public F { + std::shared_ptr p_cache_; + + public: + using WriterT = EllpackHostCacheStream; + using ReaderT = EllpackHostCacheStream; + + public: + EllpackFormatStreamPolicy(); + [[nodiscard]] std::unique_ptr CreateWriter(StringView name, std::uint32_t iter); + + [[nodiscard]] std::unique_ptr CreateReader(StringView name, bst_idx_t offset, + bst_idx_t length) const; +}; + +template +class EllpackPageSourceImpl : public PageSourceIncMixIn { + using Super = PageSourceIncMixIn; bool is_dense_; bst_idx_t row_stride_; BatchParam param_; common::Span feature_types_; - std::shared_ptr cuts_; - DeviceOrd device_; - - protected: - [[nodiscard]] SparsePageFormat* CreatePageFormat() const override { - cuts_->SetDevice(this->device_); - return new EllpackPageRawFormat{cuts_}; - } public: - EllpackPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features, - size_t n_batches, std::shared_ptr cache, BatchParam param, - std::shared_ptr cuts, bool is_dense, - bst_idx_t row_stride, common::Span feature_types, - std::shared_ptr source, DeviceOrd device) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false), + EllpackPageSourceImpl(float missing, std::int32_t nthreads, bst_feature_t n_features, + std::size_t n_batches, std::shared_ptr cache, BatchParam param, + std::shared_ptr cuts, bool is_dense, + bst_idx_t row_stride, common::Span feature_types, + std::shared_ptr source, DeviceOrd device) + : Super{missing, nthreads, n_features, n_batches, cache, false}, is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)}, - feature_types_{feature_types}, - cuts_{std::move(cuts)}, - device_{device} { + feature_types_{feature_types} { this->source_ = source; + cuts->SetDevice(device); + this->SetCuts(std::move(cuts), device); this->Fetch(); } void Fetch() final; }; +// Cache to host +using EllpackPageHostSource = + EllpackPageSourceImpl>; + +// Cache to disk +using EllpackPageSource = + EllpackPageSourceImpl>; + #if !defined(XGBOOST_USE_CUDA) -inline void EllpackPageSource::Fetch() { +template +inline void EllpackPageSourceImpl::Fetch() { // silent the warning about unused variables. (void)(row_stride_); (void)(is_dense_); - (void)(device_); common::AssertGPUSupport(); } #endif // !defined(XGBOOST_USE_CUDA) diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index c525d51d1..fad6a3215 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -17,20 +17,35 @@ #include "xgboost/data.h" // for BatchParam, FeatureType #include "xgboost/span.h" // for Span -namespace xgboost { -namespace data { -class GradientIndexPageSource : public PageSourceIncMixIn { +namespace xgboost::data { +/** + * @brief Policy for creating ghist index format. The storage is default (disk). + */ +template +class GHistIndexFormatPolicy { + protected: common::HistogramCuts cuts_; + + public: + using FormatT = SparsePageFormat; + + public: + [[nodiscard]] auto CreatePageFormat() const { + std::unique_ptr fmt{new GHistIndexRawFormat{cuts_}}; + return fmt; + } + + void SetCuts(common::HistogramCuts cuts) { std::swap(cuts_, cuts); } +}; + +class GradientIndexPageSource + : public PageSourceIncMixIn< + GHistIndexMatrix, DefaultFormatStreamPolicy> { bool is_dense_; std::int32_t max_bin_per_feat_; common::Span feature_types_; double sparse_thresh_; - protected: - [[nodiscard]] SparsePageFormat* CreatePageFormat() const override { - return new GHistIndexRawFormat{cuts_}; - } - public: GradientIndexPageSource(float missing, std::int32_t nthreads, bst_feature_t n_features, size_t n_batches, std::shared_ptr cache, BatchParam param, @@ -39,17 +54,16 @@ class GradientIndexPageSource : public PageSourceIncMixIn { std::shared_ptr source) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, std::isnan(param.sparse_thresh)), - cuts_{std::move(cuts)}, is_dense_{is_dense}, max_bin_per_feat_{param.max_bin}, feature_types_{feature_types}, sparse_thresh_{param.sparse_thresh} { this->source_ = source; + this->SetCuts(std::move(cuts)); this->Fetch(); } void Fetch() final; }; -} // namespace data -} // namespace xgboost +} // namespace xgboost::data #endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index ae4927a30..2e51ebb51 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -38,13 +38,17 @@ std::size_t NFeaturesDevice(DMatrixProxy *) // NOLINT #endif } // namespace detail - SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle, - DataIterResetCallback *reset, - XGDMatrixCallbackNext *next, float missing, - int32_t nthreads, std::string cache_prefix) - : proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing}, - cache_prefix_{std::move(cache_prefix)} { + DataIterResetCallback *reset, XGDMatrixCallbackNext *next, + float missing, int32_t nthreads, std::string cache_prefix, + bool on_host) + : proxy_{proxy_handle}, + iter_{iter_handle}, + reset_{reset}, + next_{next}, + missing_{missing}, + cache_prefix_{std::move(cache_prefix)}, + on_host_{on_host} { Context ctx; ctx.nthread = nthreads; @@ -103,8 +107,26 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p fmat_ctx_ = ctx; } +SparsePageDMatrix::~SparsePageDMatrix() { + // Clear out all resources before deleting the cache file. + sparse_page_source_.reset(); + std::visit([](auto &&ptr) { ptr.reset(); }, ellpack_page_source_); + column_source_.reset(); + sorted_column_source_.reset(); + ghist_index_source_.reset(); + + for (auto const &kv : cache_info_) { + CHECK(kv.second); + auto n = kv.second->ShardName(); + if (kv.second->OnHost()) { + continue; + } + TryDeleteCacheFile(n); + } +} + void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) { - auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_); + auto id = MakeCache(this, ".row.page", false, cache_prefix_, &cache_info_); // Don't use proxy DMatrix once this is already initialized, this allows users to // release the iterator and data. if (cache_info_.at(id)->written) { @@ -132,8 +154,9 @@ BatchSet SparsePageDMatrix::GetRowBatches() { } BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { - auto id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_); + auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); + error::NoOnHost(on_host_); this->InitializeSparsePage(ctx); if (!column_source_) { column_source_ = @@ -146,8 +169,9 @@ BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { } BatchSet SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) { - auto id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_); + auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); + error::NoOnHost(on_host_); this->InitializeSparsePage(ctx); if (!sorted_column_source_) { sorted_column_source_ = std::make_shared( @@ -165,11 +189,12 @@ BatchSet SparsePageDMatrix::GetGradientIndex(Context const *ct CHECK_GE(param.max_bin, 2); } detail::CheckEmpty(batch_param_, param); - auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); + error::NoOnHost(on_host_); + auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_); if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) { this->InitializeSparsePage(ctx); cache_info_.erase(id); - MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); + MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_); LOG(INFO) << "Generating new Gradient Index."; // Use sorted sketch for approx. auto sorted_sketch = param.regen; @@ -193,7 +218,7 @@ BatchSet SparsePageDMatrix::GetGradientIndex(Context const *ct #if !defined(XGBOOST_USE_CUDA) BatchSet SparsePageDMatrix::GetEllpackBatches(Context const *, const BatchParam &) { common::AssertGPUSupport(); - return BatchSet{BatchIterator{this->ellpack_page_source_}}; + return BatchSet{BatchIterator{nullptr}}; } #endif // !defined(XGBOOST_USE_CUDA) } // namespace xgboost::data diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 14a99370a..cfb35cdd5 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -1,7 +1,9 @@ /** * Copyright 2021-2024, XGBoost contributors */ -#include // for shared_ptr +#include // for shared_ptr +#include // for move +#include // for visit #include "../common/hist_util.cuh" #include "../common/hist_util.h" // for HistogramCuts @@ -19,13 +21,15 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, CHECK_GE(param.max_bin, 2); } detail::CheckEmpty(batch_param_, param); - auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); - size_t row_stride = 0; + auto id = MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_); + + bst_idx_t row_stride = 0; if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) { this->InitializeSparsePage(ctx); // reinitialize the cache cache_info_.erase(id); - MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_); + MakeCache(this, ".ellpack.page", on_host_, cache_prefix_, &cache_info_); + LOG(INFO) << "Generating new a Ellpack page."; std::shared_ptr cuts; if (!param.hess.empty()) { cuts = std::make_shared( @@ -41,17 +45,28 @@ BatchSet SparsePageDMatrix::GetEllpackBatches(Context const* ctx, CHECK_NE(row_stride, 0); batch_param_ = param; - auto ft = this->info_.feature_types.ConstDeviceSpan(); - ellpack_page_source_.reset(); // make sure resource is released before making new ones. - ellpack_page_source_ = std::make_shared( - this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id), - param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, - ctx->Device()); + auto ft = this->Info().feature_types.ConstDeviceSpan(); + if (on_host_ && std::get_if(&ellpack_page_source_) == nullptr) { + ellpack_page_source_.emplace(nullptr); + } + std::visit( + [&](auto&& ptr) { + ptr.reset(); // make sure resource is released before making new ones. + using SourceT = typename std::remove_reference_t::element_type; + ptr = std::make_shared(this->missing_, ctx->Threads(), this->Info().num_col_, + this->n_batches_, cache_info_.at(id), param, + std::move(cuts), this->IsDense(), row_stride, ft, + this->sparse_page_source_, ctx->Device()); + }, + ellpack_page_source_); } else { CHECK(sparse_page_source_); - ellpack_page_source_->Reset(); + std::visit([&](auto&& ptr) { ptr->Reset(); }, this->ellpack_page_source_); } - return BatchSet{BatchIterator{this->ellpack_page_source_}}; + auto batch_set = + std::visit([this](auto&& ptr) { return BatchSet{BatchIterator{ptr}}; }, + this->ellpack_page_source_); + return batch_set; } } // namespace xgboost::data diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index fd31bc661..89c011f66 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -7,16 +7,20 @@ #ifndef XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_ #define XGBOOST_DATA_SPARSE_PAGE_DMATRIX_H_ -#include -#include -#include -#include +#include // for uint32_t, int32_t +#include // for map +#include // for shared_ptr +#include // for stringstream +#include // for string +#include // for variant, visit -#include "ellpack_page_source.h" -#include "gradient_index_page_source.h" -#include "sparse_page_source.h" -#include "xgboost/data.h" +#include "ellpack_page_source.h" // for EllpackPageSource, EllpackPageHostSource +#include "gradient_index_page_source.h" // for GradientIndexPageSource +#include "sparse_page_source.h" // for SparsePageSource, Cache +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for DMatrix, MetaInfo #include "xgboost/logging.h" +#include "xgboost/span.h" // for Span namespace xgboost::data { /** @@ -70,6 +74,7 @@ class SparsePageDMatrix : public DMatrix { float missing_; Context fmat_ctx_; std::string cache_prefix_; + bool on_host_{false}; std::uint32_t n_batches_{0}; // sparse page is the source to other page types, we make a special member function. void InitializeSparsePage(Context const *ctx); @@ -79,29 +84,16 @@ class SparsePageDMatrix : public DMatrix { public: explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, int32_t nthreads, - std::string cache_prefix); + std::string cache_prefix, bool on_host = false); - ~SparsePageDMatrix() override { - // Clear out all resources before deleting the cache file. - sparse_page_source_.reset(); - ellpack_page_source_.reset(); - column_source_.reset(); - sorted_column_source_.reset(); - ghist_index_source_.reset(); - - for (auto const &kv : cache_info_) { - CHECK(kv.second); - auto n = kv.second->ShardName(); - TryDeleteCacheFile(n); - } - } + ~SparsePageDMatrix() override; [[nodiscard]] MetaInfo &Info() override; [[nodiscard]] const MetaInfo &Info() const override; [[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; } // The only DMatrix implementation that returns false. [[nodiscard]] bool SingleColBlock() const override { return false; } - DMatrix *Slice(common::Span) override { + DMatrix *Slice(common::Span) override { LOG(FATAL) << "Slicing DMatrix is not supported for external memory."; return nullptr; } @@ -111,7 +103,7 @@ class SparsePageDMatrix : public DMatrix { } [[nodiscard]] bool EllpackExists() const override { - return static_cast(ellpack_page_source_); + return std::visit([](auto &&ptr) { return static_cast(ptr); }, ellpack_page_source_); } [[nodiscard]] bool GHistIndexExists() const override { return static_cast(ghist_index_source_); @@ -138,7 +130,9 @@ class SparsePageDMatrix : public DMatrix { private: // source data pointers. std::shared_ptr sparse_page_source_; - std::shared_ptr ellpack_page_source_; + using EllpackDiskPtr = std::shared_ptr; + using EllpackHostPtr = std::shared_ptr; + std::variant ellpack_page_source_; std::shared_ptr column_source_; std::shared_ptr sorted_column_source_; std::shared_ptr ghist_index_source_; @@ -153,15 +147,16 @@ class SparsePageDMatrix : public DMatrix { /** * @brief Make cache if it doesn't exist yet. */ -inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, std::string prefix, +inline std::string MakeCache(SparsePageDMatrix *ptr, std::string format, bool on_host, + std::string prefix, std::map> *out) { auto &cache_info = *out; auto name = MakeId(prefix, ptr); auto id = name + format; auto it = cache_info.find(id); if (it == cache_info.cend()) { - cache_info[id].reset(new Cache{false, name, format}); - LOG(INFO) << "Make cache:" << cache_info[id]->ShardName() << std::endl; + cache_info[id].reset(new Cache{false, name, format, on_host}); + LOG(INFO) << "Make cache:" << cache_info[id]->ShardName(); } return id; } diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc new file mode 100644 index 000000000..363c46f2d --- /dev/null +++ b/src/data/sparse_page_source.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2021-2024, XGBoost Contributors + */ +#include "sparse_page_source.h" + +#include // for exists +#include // for string +#include // for remove +#include // for partial_sum + +namespace xgboost::data { +void Cache::Commit() { + if (!written) { + std::partial_sum(offset.begin(), offset.end(), offset.begin()); + written = true; + } +} + +void TryDeleteCacheFile(const std::string& file) { + // Don't throw, this is called in a destructor. + auto exists = std::filesystem::exists(file); + if (!exists) { + LOG(WARNING) << "External memory cache file " << file << " is missing."; + } + if (std::remove(file.c_str()) != 0) { + LOG(WARNING) << "Couldn't remove external memory cache file " << file + << "; you may want to remove it manually"; + } +} +} // namespace xgboost::data diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 427325a74..89aa86ace 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -8,11 +8,9 @@ #include // for min #include // for atomic #include // for uint64_t -#include // for remove #include // for future #include // for unique_ptr #include // for mutex -#include // for partial_sum #include // for string #include // for pair, move #include // for vector @@ -27,18 +25,12 @@ #include "proxy_dmatrix.h" // for DMatrixProxy #include "sparse_page_writer.h" // for SparsePageFormat #include "xgboost/base.h" // for bst_feature_t -#include "xgboost/data.h" // for SparsePage, CSCPage +#include "xgboost/data.h" // for SparsePage, CSCPage, SortedCSCPage #include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore #include "xgboost/logging.h" // for CHECK_EQ namespace xgboost::data { -inline void TryDeleteCacheFile(const std::string& file) { - if (std::remove(file.c_str()) != 0) { - // Don't throw, this is called in a destructor. - LOG(WARNING) << "Couldn't remove external memory cache file " << file - << "; you may want to remove it manually"; - } -} +void TryDeleteCacheFile(const std::string& file); /** * @brief Information about the cache including path and page offsets. @@ -46,13 +38,14 @@ inline void TryDeleteCacheFile(const std::string& file) { struct Cache { // whether the write to the cache is complete bool written; + bool on_host; std::string name; std::string format; // offset into binary cache file. std::vector offset; - Cache(bool w, std::string n, std::string fmt) - : written{w}, name{std::move(n)}, format{std::move(fmt)} { + Cache(bool w, std::string n, std::string fmt, bool on_host) + : written{w}, on_host{on_host}, name{std::move(n)}, format{std::move(fmt)} { offset.push_back(0); } @@ -64,6 +57,7 @@ struct Cache { [[nodiscard]] std::string ShardName() const { return ShardName(this->name, this->format); } + [[nodiscard]] bool OnHost() const { return on_host; } /** * @brief Record a page with size of n_bytes. */ @@ -83,12 +77,7 @@ struct Cache { /** * @brief Call this once the write for the cache is complete. */ - void Commit() { - if (!written) { - std::partial_sum(offset.begin(), offset.end(), offset.begin()); - written = true; - } - } + void Commit(); }; // Prevents multi-threaded call to `GetBatches`. @@ -146,10 +135,59 @@ class ExceHandler { }; /** - * @brief Base class for all page sources. Handles fetching, writing, and iteration. + * @brief Default implementation of the stream creater. + */ +template typename F> +class DefaultFormatStreamPolicy : public F { + public: + using WriterT = common::AlignedFileWriteStream; + using ReaderT = common::AlignedResourceReadStream; + + public: + std::unique_ptr CreateWriter(StringView name, std::uint32_t iter) { + std::unique_ptr fo; + if (iter == 0) { + fo = std::make_unique(name, "wb"); + } else { + fo = std::make_unique(name, "ab"); + } + return fo; + } + + std::unique_ptr CreateReader(StringView name, std::uint64_t offset, + std::uint64_t length) const { + return std::make_unique(std::string{name}, offset, length); + } +}; + +/** + * @brief Default implementatioin of the format creator. */ template -class SparsePageSourceImpl : public BatchIteratorImpl { +class DefaultFormatPolicy { + public: + using FormatT = SparsePageFormat; + + public: + auto CreatePageFormat() const { + std::unique_ptr fmt{::xgboost::data::CreatePageFormat("raw")}; + return fmt; + } +}; + +/** + * @brief Base class for all page sources. Handles fetching, writing, and iteration. + * + * The interface to external storage is divided into two types. The first one is the + * format, representing how to read and write the binary. The second part is where to + * store the binary cache. These policies are implemented in the `FormatStreamPolicy` + * policy class. The format policy controls how to create the format (the first part), and + * the stream policy decides where the stream should read from and write to (the second + * part). This way we can compose the polices and page types with ease. + */ +template > +class SparsePageSourceImpl : public BatchIteratorImpl, public FormatStreamPolicy { protected: // Prevents calling this iterator from multiple places(or threads). std::mutex single_threaded_; @@ -165,7 +203,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { // Index to the current page. std::uint32_t count_{0}; // Total number of batches. - std::uint32_t n_batches_{0}; + bst_idx_t n_batches_{0}; std::shared_ptr cache_info_; @@ -179,10 +217,6 @@ class SparsePageSourceImpl : public BatchIteratorImpl { ExceHandler exce_; common::Monitor monitor_; - [[nodiscard]] virtual SparsePageFormat* CreatePageFormat() const { - return ::xgboost::data::CreatePageFormat("raw"); - } - [[nodiscard]] bool ReadCache() { CHECK(!at_end_); if (!cache_info_->written) { @@ -196,8 +230,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { std::int32_t kPrefetches = 3; std::int32_t n_prefetches = std::min(nthreads_, kPrefetches); n_prefetches = std::max(n_prefetches, 1); - std::int32_t n_prefetch_batches = - std::min(static_cast(n_prefetches), n_batches_); + std::int32_t n_prefetch_batches = std::min(static_cast(n_prefetches), n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; CHECK_LE(n_prefetch_batches, kPrefetches); std::size_t fetch_it = count_; @@ -216,10 +249,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl { *GlobalConfigThreadLocalStore::Get() = config; auto page = std::make_shared(); this->exce_.Run([&] { - std::unique_ptr> fmt{this->CreatePageFormat()}; + std::unique_ptr fmt{this->CreatePageFormat()}; auto name = self->cache_info_->ShardName(); auto [offset, length] = self->cache_info_->View(fetch_it); - auto fi = std::make_unique(name, offset, length); + std::unique_ptr fi{ + this->CreateReader(name, offset, length)}; CHECK(fmt->Read(page.get(), fi.get())); }); return page; @@ -243,16 +277,11 @@ class SparsePageSourceImpl : public BatchIteratorImpl { CHECK(!cache_info_->written); common::Timer timer; timer.Start(); - std::unique_ptr> fmt{this->CreatePageFormat()}; + auto fmt{this->CreatePageFormat()}; auto name = cache_info_->ShardName(); - std::unique_ptr fo; - if (this->Iter() == 0) { - fo = std::make_unique(StringView{name}, "wb"); - } else { - fo = std::make_unique(StringView{name}, "ab"); - } - + std::unique_ptr fo{ + this->CreateWriter(StringView{name}, this->Iter())}; auto bytes = fmt->Write(*page_, fo.get()); timer.Stop(); @@ -265,9 +294,9 @@ class SparsePageSourceImpl : public BatchIteratorImpl { virtual void Fetch() = 0; public: - SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches, std::shared_ptr cache) - : workers_{nthreads}, + : workers_{std::max(2, std::min(nthreads, 16))}, // Don't use too many threads. missing_{missing}, nthreads_{nthreads}, n_features_{n_features}, @@ -403,18 +432,19 @@ class SparsePageSource : public SparsePageSourceImpl { }; // A mixin for advancing the iterator. -template -class PageSourceIncMixIn : public SparsePageSourceImpl { +template > +class PageSourceIncMixIn : public SparsePageSourceImpl { protected: std::shared_ptr source_; - using Super = SparsePageSourceImpl; + using Super = SparsePageSourceImpl; // synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page // so we avoid fetching it. bool sync_{true}; public: PageSourceIncMixIn(float missing, std::int32_t nthreads, bst_feature_t n_features, - std::uint32_t n_batches, std::shared_ptr cache, bool sync) + bst_idx_t n_batches, std::shared_ptr cache, bool sync) : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {} [[nodiscard]] PageSourceIncMixIn& operator++() final { diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 7aefebeb6..d2031ca21 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -234,7 +234,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(Context const* ctx, // Compact the ELLPACK pages into the single sample page. thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); for (auto& batch : batch_iterator) { - page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_)); + page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_)); } return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; @@ -252,7 +252,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx, auto cuctx = ctx->CUDACtx(); size_t n_rows = dmat->Info().num_row_; size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( - gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); + ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); auto page = (*dmat->GetBatches(ctx, batch_param_).begin()).Impl(); @@ -279,21 +279,18 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c auto cuctx = ctx->CUDACtx(); bst_idx_t n_rows = dmat->Info().num_row_; size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( - gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); - + ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); // Perform Poisson sampling in place. thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), thrust::counting_iterator(0), dh::tbegin(gpair), PoissonSampling(dh::ToSpan(threshold_), threshold_index, RandomWeight(common::GlobalRandom()()))); - // Count the sampled rows. - size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); - + size_t sample_rows = + thrust::count_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), IsNonZero()); // Compact gradient pairs. gpair_.resize(sample_rows); thrust::copy_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero()); - // Index the sample rows. thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero()); @@ -301,18 +298,16 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c sample_row_index_.begin()); thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), sample_row_index_.begin(), ClearEmptyRows()); - auto batch_iterator = dmat->GetBatches(ctx, batch_param_); auto first_page = (*batch_iterator.begin()).Impl(); // Create a new ELLPACK page with empty rows. page_.reset(); // Release the device memory first before reallocating - page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), first_page->is_dense, + page_.reset(new EllpackPageImpl(ctx->Device(), first_page->CutsShared(), dmat->IsDense(), first_page->row_stride, sample_rows)); - // Compact the ELLPACK pages into the single sample page. - thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); + thrust::fill(cuctx->CTP(), dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0); for (auto& batch : batch_iterator) { - page_->Compact(ctx->Device(), batch.Impl(), dh::ToSpan(sample_row_index_)); + page_->Compact(ctx, batch.Impl(), dh::ToSpan(sample_row_index_)); } return {sample_rows, page_.get(), dh::ToSpan(gpair_)}; @@ -363,21 +358,24 @@ GradientBasedSample GradientBasedSampler::Sample(Context const* ctx, return sample; } -size_t GradientBasedSampler::CalculateThresholdIndex(common::Span gpair, +size_t GradientBasedSampler::CalculateThresholdIndex(Context const* ctx, + common::Span gpair, common::Span threshold, common::Span grad_sum, size_t sample_rows) { - thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits::max()); - thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold), - CombineGradientPair()); - thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1); - thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, + auto cuctx = ctx->CUDACtx(); + thrust::fill(cuctx->CTP(), dh::tend(threshold) - 1, dh::tend(threshold), + std::numeric_limits::max()); + thrust::transform(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold), + CombineGradientPair{}); + thrust::sort(cuctx->TP(), dh::tbegin(threshold), dh::tend(threshold) - 1); + thrust::inclusive_scan(cuctx->CTP(), dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum)); - thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum), + thrust::transform(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum), thrust::counting_iterator(0), dh::tbegin(grad_sum), SampleRateDelta(threshold, gpair.size(), sample_rows)); thrust::device_ptr min = - thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum)); + thrust::min_element(cuctx->CTP(), dh::tbegin(grad_sum), dh::tend(grad_sum)); return thrust::distance(dh::tbegin(grad_sum), min) + 1; } }; // namespace tree diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index f89bf242e..5a57e2ae8 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -129,9 +129,8 @@ class GradientBasedSampler { GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat); /*! \brief Calculate the threshold used to normalize sampling probabilities. */ - static size_t CalculateThresholdIndex(common::Span gpair, - common::Span threshold, - common::Span grad_sum, + static size_t CalculateThresholdIndex(Context const* ctx, common::Span gpair, + common::Span threshold, common::Span grad_sum, size_t sample_rows); private: diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 924a458d7..9d9687dda 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -200,7 +200,7 @@ TEST(EllpackPage, Compact) { auto page = (*dmat->GetBatches(&ctx, param).begin()).Impl(); // Create an empty result page. - EllpackPageImpl result(FstCU(), page->CutsShared(), page->is_dense, page->row_stride, + EllpackPageImpl result(ctx.Device(), page->CutsShared(), page->is_dense, page->row_stride, kCompactedRows); // Compact batch pages into the result page. @@ -210,7 +210,7 @@ TEST(EllpackPage, Compact) { thrust::device_vector row_indexes_d = row_indexes_h; common::Span row_indexes_span(row_indexes_d.data().get(), kRows); for (auto& batch : dmat->GetBatches(&ctx, param)) { - result.Compact(FstCU(), batch.Impl(), row_indexes_span); + result.Compact(&ctx, batch.Impl(), row_indexes_span); } size_t current_row = 0; diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu index c50c3bee2..d5ff721f8 100644 --- a/tests/cpp/data/test_ellpack_page_raw_format.cu +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -4,15 +4,19 @@ #include #include -#include "../../../src/common/io.h" // for PrivateMmapConstStream, AlignedResourceReadStream... -#include "../../../src/data/ellpack_page.cuh" +#include "../../../src/data/ellpack_page.cuh" // for EllpackPage #include "../../../src/data/ellpack_page_raw_format.h" // for EllpackPageRawFormat -#include "../../../src/tree/param.h" // TrainParam +#include "../../../src/data/ellpack_page_source.h" // for EllpackFormatStreamPolicy +#include "../../../src/tree/param.h" // for TrainParam #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" namespace xgboost::data { -TEST(EllpackPageRawFormat, IO) { +namespace { +template +void TestEllpackPageRawFormat() { + FormatStreamPolicy policy; + Context ctx{MakeCUDACtx(0)}; auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; @@ -21,24 +25,26 @@ TEST(EllpackPageRawFormat, IO) { std::string path = tmpdir.path + "/ellpack.page"; std::shared_ptr cuts; - for (auto const& page : m->GetBatches(&ctx, param)) { + for (auto const &page : m->GetBatches(&ctx, param)) { cuts = page.Impl()->CutsShared(); } - cuts->SetDevice(ctx.Device()); - auto format = std::make_unique(cuts); + ASSERT_EQ(cuts->cut_values_.Device(), ctx.Device()); + ASSERT_TRUE(cuts->cut_values_.DeviceCanRead()); + policy.SetCuts(cuts, ctx.Device()); + + std::unique_ptr format{policy.CreatePageFormat()}; std::size_t n_bytes{0}; { - auto fo = std::make_unique(StringView{path}, "wb"); + auto fo = policy.CreateWriter(StringView{path}, 0); for (auto const &ellpack : m->GetBatches(&ctx, param)) { n_bytes += format->Write(ellpack, fo.get()); } } EllpackPage page; - std::unique_ptr fi{ - std::make_unique(path.c_str(), 0, n_bytes)}; + auto fi = policy.CreateReader(StringView{path}, static_cast(0), n_bytes); ASSERT_TRUE(format->Read(&page, fi.get())); for (auto const &ellpack : m->GetBatches(&ctx, param)) { @@ -52,4 +58,13 @@ TEST(EllpackPageRawFormat, IO) { ASSERT_EQ(loaded->gidx_buffer.HostVector(), orig->gidx_buffer.HostVector()); } } +} // anonymous namespace + +TEST(EllpackPageRawFormat, DiskIO) { + TestEllpackPageRawFormat>(); +} + +TEST(EllpackPageRawFormat, HostIO) { + TestEllpackPageRawFormat>(); +} } // namespace xgboost::data diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index 5783caa37..7200b96a9 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include // for DMatrix @@ -29,14 +29,10 @@ TEST(SparsePageDMatrix, EllpackPage) { EXPECT_EQ(n, dmat->Info().num_row_); auto path = - data::MakeId(tmp_file + ".cache", - dynamic_cast(dmat)) + - ".row.page"; + data::MakeId(tmp_file + ".cache", dynamic_cast(dmat)) + ".row.page"; EXPECT_TRUE(FileExists(path)); - path = - data::MakeId(tmp_file + ".cache", - dynamic_cast(dmat)) + - ".ellpack.page"; + path = data::MakeId(tmp_file + ".cache", dynamic_cast(dmat)) + + ".ellpack.page"; EXPECT_TRUE(FileExists(path)); delete dmat; @@ -82,8 +78,8 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) { std::unique_ptr dmat = CreateSparsePageDMatrix(kEntries, filename); // Loop over the batches and count the records - int64_t batch_count = 0; - int64_t row_count = 0; + std::int64_t batch_count = 0; + bst_idx_t row_count = 0; for (const auto& batch : dmat->GetBatches(&ctx, param)) { EXPECT_LT(batch.Size(), dmat->Info().num_row_); batch_count++; @@ -138,50 +134,85 @@ TEST(SparsePageDMatrix, RetainEllpackPage) { } } -TEST(SparsePageDMatrix, EllpackPageContent) { - auto ctx = MakeCUDACtx(0); - constexpr size_t kRows = 6; - constexpr size_t kCols = 2; - constexpr size_t kPageSize = 1; +namespace { +// Test comparing external DMatrix with in-core DMatrix +class TestEllpackPageExt : public ::testing::TestWithParam> { + protected: + void Run(bool on_host, bool is_dense) { + float sparsity = is_dense ? 0.0 : 0.2; - // Create an in-memory DMatrix. - std::unique_ptr dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true)); + auto ctx = MakeCUDACtx(0); + constexpr bst_idx_t kRows = 64; + constexpr size_t kCols = 2; - // Create a DMatrix with multiple batches. - dmlc::TemporaryDirectory tmpdir; - std::unique_ptr - dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir)); + // Create an in-memory DMatrix. + auto p_fmat = RandomDataGenerator{kRows, kCols, sparsity}.GenerateDMatrix(true); - auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()}; - auto impl = (*dmat->GetBatches(&ctx, param).begin()).Impl(); - EXPECT_EQ(impl->base_rowid, 0); - EXPECT_EQ(impl->n_rows, kRows); - EXPECT_FALSE(impl->is_dense); - EXPECT_EQ(impl->row_stride, 2); - EXPECT_EQ(impl->Cuts().TotalBins(), 4); + // Create a DMatrix with multiple batches. + dmlc::TemporaryDirectory tmpdir; + auto prefix = tmpdir.path + "/cache"; - std::unique_ptr impl_ext; - size_t offset = 0; - for (auto& batch : dmat_ext->GetBatches(&ctx, param)) { - if (!impl_ext) { - impl_ext = std::make_unique( - batch.Impl()->gidx_buffer.Device(), batch.Impl()->CutsShared(), batch.Impl()->is_dense, - batch.Impl()->row_stride, kRows); + auto p_ext_fmat = RandomDataGenerator{kRows, kCols, sparsity} + .Batches(4) + .OnHost(on_host) + .GenerateSparsePageDMatrix(prefix, true); + + auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()}; + auto impl = (*p_fmat->GetBatches(&ctx, param).begin()).Impl(); + ASSERT_EQ(impl->base_rowid, 0); + ASSERT_EQ(impl->n_rows, kRows); + ASSERT_EQ(impl->is_dense, is_dense); + ASSERT_EQ(impl->row_stride, 2); + ASSERT_EQ(impl->Cuts().TotalBins(), 4); + + std::unique_ptr impl_ext; + size_t offset = 0; + for (auto& batch : p_ext_fmat->GetBatches(&ctx, param)) { + if (!impl_ext) { + impl_ext = std::make_unique( + batch.Impl()->gidx_buffer.Device(), batch.Impl()->CutsShared(), batch.Impl()->is_dense, + batch.Impl()->row_stride, kRows); + } + auto n_elems = impl_ext->Copy(ctx.Device(), batch.Impl(), offset); + offset += n_elems; } - auto n_elems = impl_ext->Copy(ctx.Device(), batch.Impl(), offset); - offset += n_elems; - } - EXPECT_EQ(impl_ext->base_rowid, 0); - EXPECT_EQ(impl_ext->n_rows, kRows); - EXPECT_FALSE(impl_ext->is_dense); - EXPECT_EQ(impl_ext->row_stride, 2); - EXPECT_EQ(impl_ext->Cuts().TotalBins(), 4); + ASSERT_EQ(impl_ext->base_rowid, 0); + ASSERT_EQ(impl_ext->n_rows, kRows); + ASSERT_EQ(impl_ext->is_dense, is_dense); + ASSERT_EQ(impl_ext->row_stride, 2); + ASSERT_EQ(impl_ext->Cuts().TotalBins(), 4); - std::vector buffer(impl->gidx_buffer.HostVector()); - std::vector buffer_ext(impl_ext->gidx_buffer.HostVector()); - EXPECT_EQ(buffer, buffer_ext); + std::vector buffer(impl->gidx_buffer.HostVector()); + std::vector buffer_ext(impl_ext->gidx_buffer.HostVector()); + ASSERT_EQ(buffer, buffer_ext); + } +}; +} // anonymous namespace + +TEST_P(TestEllpackPageExt, Data) { + auto [on_host, is_dense] = this->GetParam(); + this->Run(on_host, is_dense); } +INSTANTIATE_TEST_SUITE_P(EllpackPageExt, TestEllpackPageExt, ::testing::ValuesIn([]() { + std::vector> values; + for (auto on_host : {true, false}) { + for (auto is_dense : {true, false}) { + values.emplace_back(on_host, is_dense); + } + } + return values; + }()), + [](::testing::TestParamInfo const& info) { + auto on_host = std::get<0>(info.param); + auto is_dense = std::get<1>(info.param); + std::stringstream ss; + ss << (on_host ? "host" : "ext"); + ss << "_"; + ss << (is_dense ? "dense" : "sparse"); + return ss.str(); + }); + struct ReadRowFunction { EllpackDeviceAccessor matrix; int row; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index fc5ec3034..9b988f960 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -437,9 +437,9 @@ void RandomDataGenerator::GenerateCSR( #endif // defined(XGBOOST_USE_CUDA) } - std::unique_ptr dmat{ - DMatrix::Create(static_cast(iter.get()), iter->Proxy(), Reset, Next, - std::numeric_limits::quiet_NaN(), Context{}.Threads(), prefix)}; + std::unique_ptr dmat{DMatrix::Create( + static_cast(iter.get()), iter->Proxy(), Reset, Next, + std::numeric_limits::quiet_NaN(), Context{}.Threads(), prefix, on_host_)}; auto row_page_path = data::MakeId(prefix, dynamic_cast(dmat.get())) + ".row.page"; @@ -520,9 +520,9 @@ std::unique_ptr CreateSparsePageDMatrix(bst_idx_t n_samples, bst_featur CHECK_GE(n_samples, n_batches); NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches); - std::unique_ptr dmat{ - DMatrix::Create(static_cast(&iter), iter.Proxy(), Reset, Next, - std::numeric_limits::quiet_NaN(), omp_get_max_threads(), prefix)}; + std::unique_ptr dmat{DMatrix::Create( + static_cast(&iter), iter.Proxy(), Reset, Next, + std::numeric_limits::quiet_NaN(), omp_get_max_threads(), prefix, false)}; auto row_page_path = data::MakeId(prefix, dynamic_cast(dmat.get())) + ".row.page"; @@ -549,7 +549,7 @@ std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, std::unique_ptr dmat{ DMatrix::Create(static_cast(&iter), iter.Proxy(), Reset, Next, - std::numeric_limits::quiet_NaN(), 0, prefix)}; + std::numeric_limits::quiet_NaN(), 0, prefix, false)}; auto row_page_path = data::MakeId(prefix, dynamic_cast(dmat.get())) + @@ -568,9 +568,9 @@ std::unique_ptr CreateSparsePageDMatrix(size_t n_entries, return dmat; } -std::unique_ptr CreateSparsePageDMatrixWithRC( - size_t n_rows, size_t n_cols, size_t page_size, bool deterministic, - const dmlc::TemporaryDirectory& tempdir) { +std::unique_ptr CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols, + size_t page_size, bool deterministic, + const dmlc::TemporaryDirectory& tempdir) { if (!n_rows || !n_cols) { return nullptr; } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index cb8852e1b..2211b2d00 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -241,6 +241,7 @@ class RandomDataGenerator { bst_bin_t bins_{0}; std::vector ft_; bst_cat_t max_cat_{32}; + bool on_host_{false}; Json ArrayInterfaceImpl(HostDeviceVector* storage, size_t rows, size_t cols) const; @@ -266,6 +267,10 @@ class RandomDataGenerator { n_batches_ = n_batches; return *this; } + RandomDataGenerator& OnHost(bool on_host) { + on_host_ = on_host; + return *this; + } RandomDataGenerator& Seed(uint64_t s) { seed_ = s; lcg_.Seed(seed_); diff --git a/tests/cpp/test_helpers.cc b/tests/cpp/test_helpers.cc index f582ba564..529f94e24 100644 --- a/tests/cpp/test_helpers.cc +++ b/tests/cpp/test_helpers.cc @@ -67,4 +67,30 @@ TEST(RandomDataGenerator, GenerateArrayInterfaceBatch) { CHECK_EQ(get(j_array["shape"][0]), kRows); CHECK_EQ(get(j_array["shape"][1]), kCols); } + +TEST(RandomDataGenerator, SparseDMatrix) { + bst_idx_t constexpr kCols{100}, kBatches{13}; + bst_idx_t n_samples{kBatches * 128}; + dmlc::TemporaryDirectory tmpdir; + auto prefix = tmpdir.path + "/cache"; + auto p_ext_fmat = + RandomDataGenerator{n_samples, kCols, 0.0}.Batches(kBatches).GenerateSparsePageDMatrix(prefix, + true); + + auto p_fmat = RandomDataGenerator{n_samples, kCols, 0.0}.GenerateDMatrix(true); + + SparsePage concat; + std::int32_t n_batches{0}; + for (auto const& page : p_ext_fmat->GetBatches()) { + concat.Push(page); + ++n_batches; + } + ASSERT_EQ(n_batches, kBatches); + ASSERT_EQ(concat.Size(), n_samples); + + for (auto const& page : p_fmat->GetBatches()) { + ASSERT_EQ(page.data.ConstHostVector(), concat.data.ConstHostVector()); + ASSERT_EQ(page.offset.ConstHostVector(), concat.offset.ConstHostVector()); + } +} } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 4325b6308..3a432fe67 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -21,20 +21,38 @@ def test_gpu_single_batch() -> None: strategies.integers(0, 8), strategies.booleans(), strategies.booleans(), + strategies.booleans(), ) -@settings(deadline=None, max_examples=10, print_blob=True) +@settings(deadline=None, max_examples=16, print_blob=True) def test_gpu_data_iterator( n_samples_per_batch: int, n_features: int, n_batches: int, subsample: bool, use_cupy: bool, + on_host: bool, ) -> None: run_data_iterator( - n_samples_per_batch, n_features, n_batches, "gpu_hist", subsample, use_cupy + n_samples_per_batch, + n_features, + n_batches, + "hist", + subsample=subsample, + device="cuda", + use_cupy=use_cupy, + on_host=on_host, ) def test_cpu_data_iterator() -> None: """Make sure CPU algorithm can handle GPU inputs""" - run_data_iterator(1024, 2, 3, "approx", False, True) + run_data_iterator( + 1024, + 2, + 3, + "approx", + device="cuda", + subsample=False, + use_cupy=True, + on_host=False, + ) diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index e665bcb10..1cc34f346 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -73,7 +73,9 @@ def run_data_iterator( n_batches: int, tree_method: str, subsample: bool, + device: str, use_cupy: bool, + on_host: bool, ) -> None: n_rounds = 2 # The test is more difficult to pass if the subsample rate is smaller as the root_sum @@ -83,7 +85,8 @@ def run_data_iterator( it = IteratorForTest( *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy), - cache="cache" + cache="cache", + on_host=on_host, ) if n_batches == 0: with pytest.raises(ValueError, match="1 batch"): @@ -98,10 +101,11 @@ def run_data_iterator( "tree_method": tree_method, "max_depth": 2, "subsample": subsample_rate, + "device": device, "seed": 0, } - if tree_method == "gpu_hist": + if device.find("cuda") != -1: parameters["sampling_method"] = "gradient_based" results_from_it: Dict[str, Dict[str, List[float]]] = {} @@ -167,10 +171,24 @@ def test_data_iterator( subsample: bool, ) -> None: run_data_iterator( - n_samples_per_batch, n_features, n_batches, "approx", subsample, False + n_samples_per_batch, + n_features, + n_batches, + "approx", + subsample, + "cpu", + False, + False, ) run_data_iterator( - n_samples_per_batch, n_features, n_batches, "hist", subsample, False + n_samples_per_batch, + n_features, + n_batches, + "hist", + subsample, + "cpu", + False, + False, ) @@ -241,7 +259,7 @@ def test_cat_check() -> None: batches.append((X, y)) X, y = list(zip(*batches)) - it = tm.IteratorForTest(X, y, None, cache=None) + it = tm.IteratorForTest(X, y, None, cache=None, on_host=False) Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True) with pytest.raises(ValueError, match="categorical features"): @@ -254,7 +272,7 @@ def test_cat_check() -> None: with tempfile.TemporaryDirectory() as tmpdir: cache_path = os.path.join(tmpdir, "cache") - it = tm.IteratorForTest(X, y, None, cache=cache_path) + it = tm.IteratorForTest(X, y, None, cache=cache_path, on_host=False) Xy = xgb.DMatrix(it, enable_categorical=True) with pytest.raises(ValueError, match="categorical features"): xgb.train({"booster": "gblinear"}, Xy)