diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py index e1bcbe99a..4c4d8d156 100644 --- a/demo/guide-python/external_memory.py +++ b/demo/guide-python/external_memory.py @@ -10,8 +10,13 @@ instead of Quantile DMatrix. The feature is not ready for production use yet. See :doc:`the tutorial ` for more details. + .. versionchanged:: 3.0.0 + + Added :py:class:`~xgboost.ExtMemQuantileDMatrix`. + """ +import argparse import os import tempfile from typing import Callable, List, Tuple @@ -43,30 +48,40 @@ def make_batches( class Iterator(xgboost.DataIter): """A custom iterator for loading files in batches.""" - def __init__(self, file_paths: List[Tuple[str, str]]) -> None: + def __init__(self, device: str, file_paths: List[Tuple[str, str]]) -> None: + self.device = device + self._file_paths = file_paths self._it = 0 - # XGBoost will generate some cache files under current directory with the prefix - # "cache" + # XGBoost will generate some cache files under the current directory with the + # prefix "cache" super().__init__(cache_prefix=os.path.join(".", "cache")) def load_file(self) -> Tuple[np.ndarray, np.ndarray]: + """Load a single batch of data.""" X_path, y_path = self._file_paths[self._it] - X = np.load(X_path) - y = np.load(y_path) + # When the `ExtMemQuantileDMatrix` is used, the device must match. This + # constraint will be relaxed in the future. + if self.device == "cpu": + X = np.load(X_path) + y = np.load(y_path) + else: + X = cp.load(X_path) + y = cp.load(y_path) + assert X.shape[0] == y.shape[0] return X, y def next(self, input_data: Callable) -> int: - """Advance the iterator by 1 step and pass the data to XGBoost. This function is - called by XGBoost during the construction of ``DMatrix`` + """Advance the iterator by 1 step and pass the data to XGBoost. This function + is called by XGBoost during the construction of ``DMatrix`` """ if self._it == len(self._file_paths): # return 0 to let XGBoost know this is the end of iteration return 0 - # input_data is a function passed in by XGBoost who has the similar signature to + # input_data is a function passed in by XGBoost and has the similar signature to # the ``DMatrix`` constructor. X, y = self.load_file() input_data(data=X, label=y) @@ -78,27 +93,74 @@ class Iterator(xgboost.DataIter): self._it = 0 -def main(tmpdir: str) -> xgboost.Booster: - # generate some random data for demo - files = make_batches(1024, 17, 31, tmpdir) - it = Iterator(files) +def hist_train(it: Iterator) -> None: + """The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for + faster initialization and lower memory usage. + + .. versionadded:: 3.0.0 + + """ # For non-data arguments, specify it here once instead of passing them by the `next` # method. - missing = np.nan - Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False) - - # ``approx`` is also supported, but less efficient due to sketching. GPU behaves - # differently than CPU tree methods as it uses a hybrid approach. See tutorial in - # doc for details. + Xy = xgboost.ExtMemQuantileDMatrix(it, missing=np.nan, enable_categorical=False) booster = xgboost.train( - {"tree_method": "hist", "max_depth": 4}, + {"tree_method": "hist", "max_depth": 4, "device": it.device}, Xy, evals=[(Xy, "Train")], num_boost_round=10, ) - return booster + booster.predict(Xy) + + +def approx_train(it: Iterator) -> None: + """The approx tree method uses the basic `DMatrix`.""" + + # For non-data arguments, specify it here once instead of passing them by the `next` + # method. + Xy = xgboost.DMatrix(it, missing=np.nan, enable_categorical=False) + # ``approx`` is also supported, but less efficient due to sketching. It's + # recommended to use `hist` instead. + booster = xgboost.train( + {"tree_method": "approx", "max_depth": 4, "device": it.device}, + Xy, + evals=[(Xy, "Train")], + num_boost_round=10, + ) + booster.predict(Xy) + + +def main(tmpdir: str, args: argparse.Namespace) -> None: + """Entry point for training.""" + + # generate some random data for demo + files = make_batches( + n_samples_per_batch=1024, n_features=17, n_batches=31, tmpdir=tmpdir + ) + it = Iterator(args.device, files) + + hist_train(it) + approx_train(it) if __name__ == "__main__": - with tempfile.TemporaryDirectory() as tmpdir: - main(tmpdir) + parser = argparse.ArgumentParser() + parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu") + args = parser.parse_args() + if args.device == "cuda": + import cupy as cp + import rmm + from rmm.allocators.cupy import rmm_cupy_allocator + + # It's important to use RMM for GPU-based external memory to improve performance. + # If XGBoost is not built with RMM support, a warning will be raised. + mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) + rmm.mr.set_current_device_resource(mr) + # Set the allocator for cupy as well. + cp.cuda.set_allocator(rmm_cupy_allocator) + # Make sure XGBoost is using RMM for all allocations. + with xgboost.config_context(use_rmm=True): + with tempfile.TemporaryDirectory() as tmpdir: + main(tmpdir, args) + else: + with tempfile.TemporaryDirectory() as tmpdir: + main(tmpdir, args) diff --git a/doc/jvm/xgboost_spark_migration.rst b/doc/jvm/xgboost_spark_migration.rst index cf291f83f..5d75457ec 100644 --- a/doc/jvm/xgboost_spark_migration.rst +++ b/doc/jvm/xgboost_spark_migration.rst @@ -55,9 +55,9 @@ When submitting the XGBoost application to the Spark cluster, you only need to s --jars xgboost-spark_2.12-3.0.0.jar \ ... \ -************** +*************** XGBoost Ranking -************** +*************** Learning to rank using XGBoostRegressor has been replaced by a dedicated `XGBoostRanker`, which is specifically designed to support ranking algorithms. diff --git a/doc/parameter.rst b/doc/parameter.rst index 49d42f838..5f1298808 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -230,15 +230,35 @@ Parameters for Tree Booster - ``one_output_per_tree``: One model for each target. - ``multi_output_tree``: Use multi-target trees. + +Parameters for Non-Exact Tree Methods +===================================== + * ``max_cached_hist_node``, [default = 65536] - Maximum number of cached nodes for histogram. + Maximum number of cached nodes for histogram. This can be used with the ``hist`` and the + ``approx`` tree methods. .. versionadded:: 2.0.0 - For most of the cases this parameter should not be set except for growing deep trees. After 3.0, this parameter affects GPU algorithms as well. + +* ``extmem_concat_pages``, [default = ``false``] + + This parameter is only used for the ``hist`` tree method with ``device=cuda`` and + ``subsample != 1.0``. Before 3.0, pages were always concatenated. + + .. versionadded:: 3.0.0 + + Whether the GPU-based ``hist`` tree method should concatenate the training data into a + single batch instead of fetching data on-demand when external memory is used. For GPU + devices that don't support address translation services, external memory training is + expensive. This parameter can be used in combination with subsampling to reduce overall + memory usage without significant overhead. See :doc:`/tutorials/external_memory` for + more information. + .. _cat-param: Parameters for Categorical Feature diff --git a/doc/python/python_api.rst b/doc/python/python_api.rst index 86da4fda0..11de9385b 100644 --- a/doc/python/python_api.rst +++ b/doc/python/python_api.rst @@ -26,6 +26,12 @@ Core Data Structure .. autoclass:: xgboost.QuantileDMatrix :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: xgboost.ExtMemQuantileDMatrix + :members: + :inherited-members: :show-inheritance: .. autoclass:: xgboost.Booster diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst index 99dea7aae..652b60685 100644 --- a/doc/tutorials/external_memory.rst +++ b/doc/tutorials/external_memory.rst @@ -4,15 +4,13 @@ Using XGBoost External Memory Version When working with large datasets, training XGBoost models can be challenging as the entire dataset needs to be loaded into memory. This can be costly and sometimes -infeasible. Staring from 1.5, users can define a custom iterator to load data in chunks -for running XGBoost algorithms. External memory can be used for both training and -prediction, but training is the primary use case and it will be our focus in this -tutorial. For prediction and evaluation, users can iterate through the data themselves -while training requires the full dataset to be loaded into the memory. - -During training, there are two different modes for external memory support available in -XGBoost, one for CPU-based algorithms like ``hist`` and ``approx``, another one for the -GPU-based training algorithm. We will introduce them in the following sections. +infeasible. Starting from 1.5, users can define a custom iterator to load data in chunks +for running XGBoost algorithms. External memory can be used for training and prediction, +but training is the primary use case and it will be our focus in this tutorial. For +prediction and evaluation, users can iterate through the data themselves, whereas training +requires the entire dataset to be loaded into the memory. Significant progress was made in +the 3.0 release for the GPU implementation. We will introduce the difference between CPU +and GPU in the following sections. .. note:: @@ -20,27 +18,33 @@ GPU-based training algorithm. We will introduce them in the following sections. .. note:: - The feature is still experimental as of 2.0. The performance is not well optimized. + The feature is considered experimental but ready for public testing in 3.0. Vector-leaf + is not yet supported. -The external memory support has gone through multiple iterations and is still under heavy -development. Like the :py:class:`~xgboost.QuantileDMatrix` with -:py:class:`~xgboost.DataIter`, XGBoost loads data batch-by-batch using a custom iterator -supplied by the user. However, unlike the :py:class:`~xgboost.QuantileDMatrix`, external -memory will not concatenate the batches unless GPU is used (it uses a hybrid approach, -more details follow). Instead, it will cache all batches on the external memory and fetch -them on-demand. Go to the end of the document to see a comparison between -:py:class:`~xgboost.QuantileDMatrix` and external memory. +The external memory support has undergone multiple development iterations. Like the +:py:class:`~xgboost.QuantileDMatrix` with :py:class:`~xgboost.DataIter`, XGBoost loads +data batch-by-batch using a custom iterator supplied by the user. However, unlike the +:py:class:`~xgboost.QuantileDMatrix`, external memory does not concatenate the batches +(unless specified by the ``extmem_concat_pages``) . Instead, it caches all batches in the +external memory and fetch them on-demand. Go to the end of the document to see a +comparison between :py:class:`~xgboost.QuantileDMatrix` and the external memory version of +:py:class:`~xgboost.ExtMemQuantileDMatrix`. + +**Contents** + +.. contents:: + :backlinks: none + :local: ************* Data Iterator ************* -Starting from XGBoost 1.5, users can define their own data loader using Python or C -interface. There are some examples in the ``demo`` directory for quick start. This is a -generalized version of text input external memory, where users no longer need to prepare a -text file that XGBoost recognizes. To enable the feature, users need to define a data -iterator with 2 class methods: ``next`` and ``reset``, then pass it into the -:py:class:`~xgboost.DMatrix` constructor. +Starting with XGBoost 1.5, users can define their own data loader using Python or C +interface. Some examples are in the ``demo`` directory for a quick start. To enable +external memory training, users need to define a data iterator with 2 class methods: +``next`` and ``reset``, then pass it into the :py:class:`~xgboost.DMatrix` or the +:py:class:`~xgboost.ExtMemQuantileDMatrix` constructor. .. code-block:: python @@ -53,20 +57,20 @@ iterator with 2 class methods: ``next`` and ``reset``, then pass it into the def __init__(self, svm_file_paths: List[str]): self._file_paths = svm_file_paths self._it = 0 - # XGBoost will generate some cache files under current directory with the prefix + # XGBoost will generate some cache files under the current directory with the prefix # "cache" super().__init__(cache_prefix=os.path.join(".", "cache")) def next(self, input_data: Callable): - """Advance the iterator by 1 step and pass the data to XGBoost. This function is + """Advance the iterator by 1 step and pass the data to XGBoost. This function is called by XGBoost during the construction of ``DMatrix`` """ if self._it == len(self._file_paths): - # return 0 to let XGBoost know this is the end of iteration + # return 0 to let XGBoost know this is the end of the iteration return 0 - # input_data is a function passed in by XGBoost who has the exact same signature of + # input_data is a function passed in by XGBoost and has the exact same signature of # ``DMatrix`` X, y = load_svmlight_file(self._file_paths[self._it]) input_data(data=X, label=y) @@ -79,59 +83,106 @@ iterator with 2 class methods: ``next`` and ``reset``, then pass it into the self._it = 0 it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"]) - Xy = xgboost.DMatrix(it) - # The ``approx`` also work, but with low performance. GPU implementation is different from CPU. - # as noted in following sections. + Xy = xgboost.ExtMemQuantileDMatrix(it) booster = xgboost.train({"tree_method": "hist"}, Xy) + # The ``approx`` tree method also works, but with lower performance and cannot be used + # with the quantile DMatrix. + + Xy = xgboost.DMatrix(it) + booster = xgboost.train({"tree_method": "approx"}, Xy) The above snippet is a simplified version of :ref:`sphx_glr_python_examples_external_memory.py`. For an example in C, please see ``demo/c-api/external-memory/``. The iterator is the common interface for using external memory with XGBoost, you can pass the resulting -:py:class:`DMatrix` object for training, prediction, and evaluation. +:py:class:`~xgboost.DMatrix` object for training, prediction, and evaluation. + +The :py:class:`~xgboost.ExtMemQuantileDMatrix` is an external memory version of the +:py:class:`~xgboost.QuantileDMatrix`. These two classes are specifically designed for the +``hist`` tree method for reduced memory usage and data loading overhead. See respective +references for more info. It is important to set the batch size based on the memory available. A good starting point -is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not* -recommended to set small batch sizes like 32 samples per batch, as this can seriously hurt -performance in gradient boosting. - -*********** -CPU Version -*********** - -In the previous section, we demonstrated how to train a tree-based model using the -``hist`` tree method on a CPU. This method involves iterating through data batches stored -in a cache during tree construction. For optimal performance, we recommend using the -``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer of tree -nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy -requires XGBoost to iterate over the data set for each tree node, resulting in slower -performance. - -If external memory is used, the performance of CPU training is limited by IO -(input/output) speed. This means that the disk IO speed primarily determines the training -speed. During benchmarking, we used an NVMe connected to a PCIe-4 slot, other types of -storage can be too slow for practical usage. In addition, your system may perform caching -to reduce the overhead of file reading. +for CPU is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not* +recommended to set small batch sizes like 32 samples per batch, as this can severely hurt +performance in gradient boosting. See below sections for information about the GPU version +and other best practices. ********************************** GPU Version (GPU Hist tree method) ********************************** -External memory is supported by GPU algorithms (i.e. when ``device`` is set to -``cuda``). However, the algorithm used for GPU is different from the one used for -CPU. When training on a CPU, the tree method iterates through all batches from external -memory for each step of the tree construction algorithm. On the other hand, the GPU -algorithm uses a hybrid approach. It iterates through the data during the beginning of -each iteration and concatenates all batches into one in GPU memory for performance -reasons. To reduce overall memory usage, users can utilize subsampling. The GPU hist tree -method supports `gradient-based sampling`, enabling users to set a low sampling rate -without compromising accuracy. +External memory is supported by GPU algorithms (i.e., when ``device`` is set to +``cuda``). Starting with 3.0, the default GPU implementation is similar to what the CPU +version does. It also supports the use of :py:class:`~xgboost.ExtMemQuantileDMatrix` when +the ``hist`` tree method is employed. For a GPU device, the main memory is the device +memory, whereas the external memory can be either a disk or the CPU memory. XGBoost stages +the cache on CPU memory by default. Users can change the backing storage to disk by +specifying the ``on_host`` parameter in the :py:class:`~xgboost.DataIter`. However, using +the disk is not recommended. It's likely to make the GPU slower than the CPU. The option is +here for experimental purposes only. + +Inputs to the :py:class:`~xgboost.ExtMemQuantileDMatrix` (through the iterator) must be on +the GPU. This is a current limitation we aim to address in the future. + +.. code-block:: python + + import cupy as cp + import rmm + from rmm.allocators.cupy import rmm_cupy_allocator + + # It's important to use RMM for GPU-based external memory to improve performance. + # If XGBoost is not built with RMM support, a warning will be raised. + mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource()) + rmm.mr.set_current_device_resource(mr) + # Set the allocator for cupy as well. + cp.cuda.set_allocator(rmm_cupy_allocator) + # Make sure XGBoost is using RMM for all allocations. + with xgboost.config_context(use_rmm=True): + # Construct the iterators for ExtMemQuantileDMatrix + # ... + # Build the ExtMemQuantileDMatrix and start training + Xy_train = xgboost.ExtMemQuantileDMatrix(it_train, max_bin=n_bins) + Xy_valid = xgboost.ExtMemQuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train) + booster = xgboost.train( + { + "tree_method": "hist", + "max_depth": 6, + "max_bin": n_bins, + "device": device, + }, + Xy_train, + num_boost_round=n_rounds, + evals=[(Xy_train, "Train"), (Xy_valid, "Valid")] + ) + +It's crucial to use `RAPIDS Memory Manager (RMM) `__ for +all memory allocation when training with external memory. XGBoost relies on the memory +pool to reduce the overhead for data fetching. The size of each batch should be slightly +smaller than a quarter of the available GPU memory. In addition, the open source `NVIDIA +Linux driver +`__ +is required for ``Heterogeneous memory management (HMM)`` support. + +In addition to the batch-based data fetching, the GPU version supports concatenating +batches into a single blob for the training data to improve performance. For GPUs +connected via PCIe instead of nvlink, the performance overhead with batch-based training +is significant, particularly for non-dense data. Overall, it can be at least five times +slower than in-core training. Concatenating pages can be used to get the performance +closer to in-core training. This option should be used in combination with subsampling to +reduce the memory usage. During concatenation, subsampling removes a portion of samples, +reducing the training dataset size. The GPU hist tree method supports `gradient-based +sampling`, enabling users to set a low sampling rate without compromising accuracy. Before +3.0, concatenation with subsampling was the only option for GPU-based external +memory. After 3.0, XGBoost uses the regular batch fetching as the default while the page +concatenation can be enabled by: .. code-block:: python param = { - ... + "device": "cuda", + "extmem_concat_pages": true, 'subsample': 0.2, 'sampling_method': 'gradient_based', } @@ -139,10 +190,70 @@ without compromising accuracy. For more information about the sampling algorithm and its use in external memory training, see `this paper `_. -.. warning:: +========== +NVLink-C2C +========== - When GPU is running out of memory during iteration on external memory, user might - receive a segfault instead of an OOM exception. +The newer NVIDIA platforms like `Grace-Hopper +`__ use `NVLink-C2C +`__, which facilitates a fast +interconnect between the CPU and the GPU. With the host memory serving as the data cache, +XGBoost can retrieve data with significantly lower overhead. When the input data is dense, +there's minimal to no performance loss for training, except for the initial construction +of the :py:class:`~xgboost.ExtMemQuantileDMatrix`. The initial construction iterates +through the input data twice, as a result, the most significantly overhead compared to +in-core training is one additional data read when the data is dense. + +To run experiments on these platforms, the open source `NVIDIA Linux driver +`__ +with version ``>=565.47`` is required. + +************** +Best Practices +************** + +In previous sections, we demonstrated how to train a tree-based model with data residing +on an external memory and made some recommendations for batch size. Here are some other +configurations we find useful. The external memory feature involves iterating through data +batches stored in a cache during tree construction. For optimal performance, we recommend +using the ``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer +of tree nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy +requires XGBoost to iterate over the data set for each tree node, resulting in +significantly slower performance. + +In addition, this ``hist`` tree method should be preferred over the ``approx`` tree method +as the former doesn't recreate the histogram bins for every iteration. Creating the +histogram bins requires loading the raw input data, which is prohibitively expensive. The +:py:class:`~xgboost.ExtMemQuantileDMatrix` designed for the ``hist`` tree method can speed +up the initial data construction and the evaluation significantly for external memory. + +Since the external memory implementation focuses on training where XGBoost needs to access +the entire dataset, only the ``X`` is divided into batches while everything else is +concatenated. As a result, it's recommended for users to define their own management code +to iterate through the data for inference, especially for SHAP value computation. The size +of SHAP results can be larger than ``X``, making external memory in XGBoost less +effective. Some frameworks like ``dask`` can help with the data chunking and iterate +through the data for inference with memory spilling. + +When external memory is used, the performance of CPU training is limited by disk IO +(input/output) speed. This means that the disk IO speed primarily determines the training +speed. Similarly, PCIe bandwidth limits the GPU performance, assuming the CPU memory is +used as a cache and address translation services (ATS) is unavailable. We recommend using +regular :py:class:`~xgboost.QuantileDMatrix` over +:py:class:`~xgboost.ExtMemQuantileDMatrix` for constructing the validation dataset when +feasible. Running inference is much less computation-intensive than training and, hence, +much faster. For GPU, the time it takes to read the data from host to device completely +determines the time it takes to run inference, even if a C2C link is available. + +.. code-block:: python + + # Try to use `QuantileDMatrix` for the validation if it can be fit into the GPU memory. + Xy_train = xgboost.ExtMemQuantileDMatrix(it_train, max_bin=n_bins) + Xy_valid = xgboost.QuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train) + +During CPU benchmarking, we used an NVMe connected to a PCIe-4 slot. Other types of +storage can be too slow for practical usage. However, your system will likely perform some +caching to reduce the overhead of the file read. See the following sections for remarks. .. _ext_remarks: @@ -157,43 +268,43 @@ and internal runtime structures are concatenated. This means that memory reducti effective when dealing with wide datasets where ``X`` is significantly larger in size compared to other data like ``y``, while it has little impact on slim datasets. -As one might expect, fetching data on-demand puts significant pressure on the storage -device. Today's computing device can process way more data than a storage can read in a -single unit of time. The ratio is at order of magnitudes. An GPU is capable of processing -hundred of Gigabytes of floating-point data in a split second. On the other hand, a -four-lane NVMe storage connected to a PCIe-4 slot usually has about 6GB/s of data transfer -rate. As a result, the training is likely to be severely bounded by your storage +As one might expect, fetching data on demand puts significant pressure on the storage +device. Today's computing devices can process way more data than storage devices can read +in a single unit of time. The ratio is in the order of magnitudes. A GPU is capable of +processing hundreds of Gigabytes of floating-point data in a split second. On the other +hand, a four-lane NVMe storage connected to a PCIe-4 slot usually has about 6GB/s of data +transfer rate. As a result, the training is likely to be severely bounded by your storage device. Before adopting the external memory solution, some back-of-envelop calculations -might help you see whether it's viable. For instance, if your NVMe drive can transfer 4GB -(a fairly practical number) of data per second and you have a 100GB of data in compressed -XGBoost cache (which corresponds to a dense float32 numpy array with the size of 200GB, -give or take). A tree with depth 8 needs at least 16 iterations through the data when the -parameter is right. You need about 14 minutes to train a single tree without accounting +might help you determine its viability. For instance, if your NVMe drive can transfer 4GB +(a reasonably practical number) of data per second, and you have a 100GB of data in a +compressed XGBoost cache (corresponding to a dense float32 numpy array with 200GB, give or +take). A tree with depth 8 needs at least 16 iterations through the data when the +parameter is optimal. You need about 14 minutes to train a single tree without accounting for some other overheads and assume the computation overlaps with the IO. If your dataset -happens to have TB-level size, then you might need thousands of trees to get a generalized -model. These calculations can help you get an estimate on the expected training time. +happens to have a TB-level size, you might need thousands of trees to get a generalized +model. These calculations can help you get an estimate of the expected training time. -However, sometimes we can ameliorate this limitation. One should also consider that the OS -(mostly talking about the Linux kernel) can usually cache the data on host memory. It only -evicts pages when new data comes in and there's no room left. In practice, at least some -portion of the data can persist on the host memory throughout the entire training +However, sometimes, we can ameliorate this limitation. One should also consider that the +OS (mainly talking about the Linux kernel) can usually cache the data on host memory. It +only evicts pages when new data comes in and there's no room left. In practice, at least +some portion of the data can persist in the host memory throughout the entire training session. We are aware of this cache when optimizing the external memory fetcher. The compressed cache is usually smaller than the raw input data, especially when the input is dense without any missing value. If the host memory can fit a significant portion of this -compressed cache, then the performance should be decent after initialization. Our -development so far focus on two fronts of optimization for external memory: +compressed cache, the performance should be decent after initialization. Our development +so far focuses on following fronts of optimization for external memory: - Avoid iterating through the data whenever appropriate. - If the OS can cache the data, the performance should be close to in-core training. +- For GPU, the actual computation should overlap with memory copy as much as possible. -Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It is not -tested against system errors like disconnected network devices (`SIGBUS`). In the face of -a bus error, you will see a hard crash and need to clean up the cache files. If the -training session might take a long time and you are using solutions like NVMe-oF, we +Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It has not +been tested against system errors like disconnected network devices (`SIGBUS`). In the +face of a bus error, you will see a hard crash and need to clean up the cache files. If +the training session might take a long time and you use solutions like NVMe-oF, we recommend checkpointing your model periodically. Also, it's worth noting that most tests have been conducted on Linux distributions. - Another important point to keep in mind is that creating the initial cache for XGBoost may take some time. The interface to external memory is through custom iterators, which we can not assume to be thread-safe. Therefore, initialization is performed sequentially. Using @@ -206,13 +317,30 @@ Compared to the QuantileDMatrix Passing an iterator to the :py:class:`~xgboost.QuantileDMatrix` enables direct construction of :py:class:`~xgboost.QuantileDMatrix` with data chunks. On the other hand, -if it's passed to :py:class:`~xgboost.DMatrix`, it instead enables the external memory -feature. The :py:class:`~xgboost.QuantileDMatrix` concatenates the data on memory after +if it's passed to the :py:class:`~xgboost.DMatrix` or the +:py:class:`~xgboost.ExtMemQuantileDMatrix`, it instead enables the external memory +feature. The :py:class:`~xgboost.QuantileDMatrix` concatenates the data in memory after compression and doesn't fetch data during training. On the other hand, the external memory -:py:class:`~xgboost.DMatrix` fetches data batches from external memory on-demand. Use the -:py:class:`~xgboost.QuantileDMatrix` (with iterator if necessary) when you can fit most of -your data in memory. The training would be an order of magnitude faster than using -external memory. +:py:class:`~xgboost.DMatrix` (:py:class:`~xgboost.ExtMemQuantileDMatrix`) fetches data +batches from external memory on demand. Use the :py:class:`~xgboost.QuantileDMatrix` (with +iterator if necessary) when you can fit most of your data in memory. For many platforms, +the training speed can be an order of magnitude faster than external memory. + +************* +Brief History +************* + +For a long time, external memory support has been an experimental feature and has +undergone multiple development iterations. Here's a brief summary of major changes: + +- Gradient-based sampling was introduced to the GPU hist in 1.1. +- The iterator interface was introduced in 1.5, along with a major rewrite for the + internal framework. +- 2.0 introduced the use of ``mmap``, along with optimization in XBGoost to enable + zero-copy data fetching. +- 3.0 reworked the GPU implementation to support caching data on the host and disk, + introduced the :py:class:`~xgboost.ExtMemQuantileDMatrix` class, added quantile-based + objectives support. **************** Text File Inputs @@ -220,11 +348,11 @@ Text File Inputs .. warning:: - This is the original form of external memory support before 1.5, users are encouraged - to use custom data iterator instead. + This is the original form of external memory support before 1.5 and is now deprecated, + users are encouraged to use a custom data iterator instead. -There is no big difference between using external memory version of text input and the -in-memory version. The only difference is the filename format. +There is no significant difference between using the external memory version of text input +and the in-memory version of text input. The only difference is the filename format. The external memory version takes in the following `URI `_ format: @@ -233,7 +361,7 @@ The external memory version takes in the following `URI filename?format=libsvm#cacheprefix -The ``filename`` is the normal path to LIBSVM format file you want to load in, and +The ``filename`` is the typical path to LIBSVM format file you want to load in, and ``cacheprefix`` is a path to a cache file that XGBoost will use for caching preprocessed data in binary form. @@ -253,7 +381,7 @@ format, the external memory support can be enabled by: dtrain = DMatrix('../data/agaricus.txt.train?format=libsvm#dtrain.cache') XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to a new file named -``dtrain.cache`` as an on disk cache for storing preprocessed data in an internal binary format. For +``dtrain.cache`` as an on disk cache for storing preprocessed data in an internal binary format. For more notes about text input formats, see :doc:`/tutorials/input_format`. -For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``. +For the CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``. diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 39ab5846b..97242889a 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -504,8 +504,8 @@ def _prediction_output( class DataIter(ABC): # pylint: disable=too-many-instance-attributes """The interface for user defined data iterator. The iterator facilitates distributed training, :py:class:`QuantileDMatrix`, and external memory support using - :py:class:`DMatrix`. Most of time, users don't need to interact with this class - directly. + :py:class:`DMatrix` or :py:class:`ExtMemQuantileDMatrix`. Most of time, users don't + need to interact with this class directly. .. note:: @@ -525,15 +525,16 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes keep the cache. on_host : - Whether the data should be cached on host memory instead of harddrive when using - GPU with external memory. If set to true, then the "external memory" would - simply be CPU (host) memory. + Whether the data should be cached on the host memory instead of the file system + when using GPU with external memory. When set to true (the default), the + "external memory" is the CPU (host) memory. See + :doc:`/tutorials/external_memory` for more info. .. versionadded:: 3.0.0 .. warning:: - This is still working in progress, not ready for test yet. + This is an experimental parameter. """ @@ -541,7 +542,7 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes self, cache_prefix: Optional[str] = None, release_data: bool = True, - on_host: bool = False, + on_host: bool = True, ) -> None: self.cache_prefix = cache_prefix self.on_host = on_host @@ -1681,9 +1682,12 @@ class QuantileDMatrix(DMatrix): class ExtMemQuantileDMatrix(DMatrix): """The external memory version of the :py:class:`QuantileDMatrix`. + See :doc:`/tutorials/external_memory` for explanation and usage examples, and + :py:class:`QuantileDMatrix` for parameter document. + .. warning:: - This is still working in progress, not ready for test yet. + This is an experimental feature. .. versionadded:: 3.0.0 @@ -1699,6 +1703,13 @@ class ExtMemQuantileDMatrix(DMatrix): ref: Optional[DMatrix] = None, enable_categorical: bool = False, ) -> None: + """ + Parameters + ---------- + data : + A user-defined :py:class:`DataIter` for loading data. + + """ self.max_bin = max_bin self.missing = missing if missing is not None else np.nan self.nthread = nthread if nthread is not None else -1 diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 02fc6f55c..c2ee4a058 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -106,9 +106,10 @@ inline auto NoCategorical(std::string name) { return name + " doesn't support categorical features."; } -inline void NoOnHost(bool on_host) { - if (on_host) { - LOG(FATAL) << "Caching on host memory is only available for GPU."; +inline void NoPageConcat(bool concat_pages) { + if (concat_pages) { + LOG(FATAL) << "`extmem_concat_pages` must be false when there's no sampling or when it's " + "running on the CPU."; } } } // namespace xgboost::error diff --git a/src/data/batch_utils.cuh b/src/data/batch_utils.cuh new file mode 100644 index 000000000..9f05e73c9 --- /dev/null +++ b/src/data/batch_utils.cuh @@ -0,0 +1,24 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#pragma once + +#include "xgboost/data.h" // for BatchParam + +namespace xgboost::data::cuda_impl { +// Use two batch for prefecting. There's always one batch being worked on, while the other +// batch being transferred. +constexpr auto DftPrefetchBatches() { return 2; } + +// Empty parameter to prevent regen, only used to control external memory prefetching. +// +// Both the approx and hist initializes the DMatrix before creating the actual +// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty +// parameter to avoid any regen. +inline BatchParam StaticBatch(bool prefetch_copy) { + BatchParam p; + p.prefetch_copy = prefetch_copy; + p.n_prefetch_batches = DftPrefetchBatches(); + return p; +} +} // namespace xgboost::data::cuda_impl diff --git a/src/data/data.cc b/src/data/data.cc index b71820a96..b1b25f707 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -920,7 +920,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s data::fileiter::Next, std::numeric_limits::quiet_NaN(), 1, - cache_file}; + cache_file, + false}; } return dmat; diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 3c121b13c..8d28b71d4 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -10,7 +10,7 @@ #include // for move #include // for vector -#include "../common/cuda_rt_utils.h" // for SupportsPageableMem +#include "../common/cuda_rt_utils.h" // for SupportsPageableMem, SupportsAts #include "../common/hist_util.h" // for HistogramCuts #include "ellpack_page.h" // for EllpackPage #include "ellpack_page_raw_format.h" // for EllpackPageRawFormat @@ -67,7 +67,20 @@ class EllpackFormatPolicy { using FormatT = EllpackPageRawFormat; public: - EllpackFormatPolicy() = default; + EllpackFormatPolicy() { + StringView msg{" The overhead of iterating through external memory might be significant."}; + if (!has_hmm_) { + LOG(WARNING) << "CUDA heterogeneous memory management is not available." << msg; + } else if (!common::SupportsAts()) { + LOG(WARNING) << "CUDA address translation service is not available." << msg; + } +#if !defined(XGBOOST_USE_RMM) + LOG(WARNING) << "XGBoost is not built with RMM support." << msg; +#endif + if (!GlobalConfigThreadLocalStore::Get()->use_rmm) { + LOG(WARNING) << "`use_rmm` is set to false." << msg; + } + } // For testing with the HMM flag. explicit EllpackFormatPolicy(bool has_hmm) : has_hmm_{has_hmm} {} @@ -135,6 +148,9 @@ class EllpackMmapStreamPolicy : public F { bst_idx_t length) const; }; +/** + * @brief Ellpack source with sparse pages as the underlying source. + */ template class EllpackPageSourceImpl : public PageSourceIncMixIn { using Super = PageSourceIncMixIn; @@ -171,6 +187,9 @@ using EllpackPageHostSource = using EllpackPageSource = EllpackPageSourceImpl>; +/** + * @brief Ellpack source directly interfaces with user-defined iterators. + */ template class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin { using Super = ExtQantileSourceMixin; @@ -201,6 +220,7 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixinSetDevice(ctx->Device()); this->SetCuts(std::move(cuts), ctx->Device()); this->Fetch(); } diff --git a/src/data/extmem_quantile_dmatrix.cc b/src/data/extmem_quantile_dmatrix.cc index 0bdab8f02..e3659f205 100644 --- a/src/data/extmem_quantile_dmatrix.cc +++ b/src/data/extmem_quantile_dmatrix.cc @@ -13,6 +13,7 @@ #include "proxy_dmatrix.h" // for DataIterProxy, HostAdapterDispatch #include "quantile_dmatrix.h" // for GetDataShape, MakeSketches #include "simple_batch_iterator.h" // for SimpleBatchIteratorImpl +#include "sparse_page_source.h" // for MakeCachePrefix #if !defined(XGBOOST_USE_CUDA) #include "../common/common.h" // for AssertGPUSupport @@ -26,6 +27,7 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix std::int32_t n_threads, std::string cache, bst_bin_t max_bin, bool on_host) : cache_prefix_{std::move(cache)}, on_host_{on_host} { + cache_prefix_ = MakeCachePrefix(cache_prefix_); auto iter = std::make_shared>( iter_handle, reset, next); iter->Reset(); diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 7cabfbd14..352810541 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -13,9 +13,9 @@ #include // for move #include // for visit -#include "../collective/communicator-inl.h" -#include "batch_utils.h" // for RegenGHist -#include "gradient_index.h" +#include "batch_utils.h" // for RegenGHist +#include "gradient_index.h" // for GHistIndexMatrix +#include "sparse_page_source.h" // for MakeCachePrefix namespace xgboost::data { MetaInfo &SparsePageDMatrix::Info() { return info_; } @@ -34,12 +34,9 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p cache_prefix_{std::move(cache_prefix)}, on_host_{on_host} { Context ctx; - ctx.nthread = nthreads; + ctx.Init(Args{{"nthread", std::to_string(nthreads)}}); + cache_prefix_ = MakeCachePrefix(cache_prefix_); - cache_prefix_ = cache_prefix_.empty() ? "DMatrix" : cache_prefix_; - if (collective::IsDistributed()) { - cache_prefix_ += ("-r" + std::to_string(collective::GetRank())); - } DMatrixProxy *proxy = MakeProxy(proxy_); auto iter = DataIterProxy{ iter_, reset_, next_}; @@ -107,7 +104,6 @@ BatchSet SparsePageDMatrix::GetRowBatches() { BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); - error::NoOnHost(on_host_); this->InitializeSparsePage(ctx); if (!column_source_) { column_source_ = @@ -122,7 +118,6 @@ BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { BatchSet SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) { auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); - error::NoOnHost(on_host_); this->InitializeSparsePage(ctx); if (!sorted_column_source_) { sorted_column_source_ = std::make_shared( @@ -140,7 +135,6 @@ BatchSet SparsePageDMatrix::GetGradientIndex(Context const *ct CHECK_GE(param.max_bin, 2); } detail::CheckEmpty(batch_param_, param); - error::NoOnHost(on_host_); auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_); if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) { this->InitializeSparsePage(ctx); diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index f40c16f72..9f2eed918 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -70,10 +70,10 @@ class SparsePageDMatrix : public DMatrix { DataIterResetCallback *reset_; XGDMatrixCallbackNext *next_; - float missing_; + float const missing_; Context fmat_ctx_; std::string cache_prefix_; - bool on_host_{false}; + bool const on_host_; std::uint32_t n_batches_{0}; // sparse page is the source to other page types, we make a special member function. void InitializeSparsePage(Context const *ctx); @@ -83,7 +83,7 @@ class SparsePageDMatrix : public DMatrix { public: explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, int32_t nthreads, - std::string cache_prefix, bool on_host = false); + std::string cache_prefix, bool on_host); ~SparsePageDMatrix() override; diff --git a/src/data/sparse_page_raw_format.cc b/src/data/sparse_page_raw_format.cc index 1edf27c46..13a468d9b 100644 --- a/src/data/sparse_page_raw_format.cc +++ b/src/data/sparse_page_raw_format.cc @@ -54,22 +54,18 @@ class SparsePageRawFormat : public SparsePageFormat { private: }; -XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw) -.describe("Raw binary data format.") -.set_body([]() { - return new SparsePageRawFormat(); - }); +#define SparsePageFmt SparsePageFormat +DMLC_REGISTRY_REGISTER(SparsePageFormatReg, SparsePageFmt, raw) + .describe("Raw binary data format.") + .set_body([]() { return new SparsePageRawFormat(); }); -XGBOOST_REGISTER_CSC_PAGE_FORMAT(raw) -.describe("Raw binary data format.") -.set_body([]() { - return new SparsePageRawFormat(); - }); - -XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(raw) -.describe("Raw binary data format.") -.set_body([]() { - return new SparsePageRawFormat(); - }); +#define CSCPageFmt SparsePageFormat +DMLC_REGISTRY_REGISTER(SparsePageFormatReg, CSCPageFmt, raw) + .describe("Raw binary data format.") + .set_body([]() { return new SparsePageRawFormat(); }); +#define SortedCSCPageFmt SparsePageFormat +DMLC_REGISTRY_REGISTER(SparsePageFormatReg, SortedCSCPageFmt, raw) + .describe("Raw binary data format.") + .set_body([]() { return new SparsePageRawFormat(); }); } // namespace xgboost::data diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc index 724260512..dd4050a71 100644 --- a/src/data/sparse_page_source.cc +++ b/src/data/sparse_page_source.cc @@ -8,6 +8,8 @@ #include // for partial_sum #include // for string +#include "../collective/communicator-inl.h" // for IsDistributed, GetRank + namespace xgboost::data { void Cache::Commit() { if (!this->written) { @@ -28,6 +30,14 @@ void TryDeleteCacheFile(const std::string& file) { } } +std::string MakeCachePrefix(std::string cache_prefix) { + cache_prefix = cache_prefix.empty() ? "DMatrix" : cache_prefix; + if (collective::IsDistributed()) { + cache_prefix += ("-r" + std::to_string(collective::GetRank())); + } + return cache_prefix; +} + #if !defined(XGBOOST_USE_CUDA) void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; } #endif diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 471a84d60..cefd13ad7 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -33,6 +33,8 @@ namespace xgboost::data { void TryDeleteCacheFile(const std::string& file); +std::string MakeCachePrefix(std::string cache_prefix); + /** * @brief Information about the cache including path and page offsets. */ diff --git a/src/data/sparse_page_writer.h b/src/data/sparse_page_writer.h index 989c03d33..526126d29 100644 --- a/src/data/sparse_page_writer.h +++ b/src/data/sparse_page_writer.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file sparse_page_writer.h * \author Tianqi Chen */ @@ -11,7 +11,6 @@ #include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream #include "dmlc/registry.h" // for Registry, FunctionRegEntryBase -#include "xgboost/data.h" // for SparsePage,CSCPage,SortedCSCPage,EllpackPage ... namespace xgboost::data { template @@ -54,47 +53,13 @@ inline SparsePageFormat* CreatePageFormat(const std::string& name) { return (e->body)(); } -/*! - * \brief Registry entry for sparse page format. +/** + * @brief Registry entry for sparse page format. */ template struct SparsePageFormatReg : public dmlc::FunctionRegEntryBase, std::function* ()>> { }; - -/*! - * \brief Macro to register sparse page format. - * - * \code - * // example of registering a objective - * XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw) - * .describe("Raw binary data format.") - * .set_body([]() { - * return new RawFormat(); - * }); - * \endcode - */ -#define SparsePageFmt SparsePageFormat -#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \ - DMLC_REGISTRY_REGISTER(SparsePageFormatReg, SparsePageFmt, Name) - -#define CSCPageFmt SparsePageFormat -#define XGBOOST_REGISTER_CSC_PAGE_FORMAT(Name) \ - DMLC_REGISTRY_REGISTER(SparsePageFormatReg, CSCPageFmt, Name) - -#define SortedCSCPageFmt SparsePageFormat -#define XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(Name) \ - DMLC_REGISTRY_REGISTER(SparsePageFormatReg, SortedCSCPageFmt, Name) - -#define EllpackPageFmt SparsePageFormat -#define XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(Name) \ - DMLC_REGISTRY_REGISTER(SparsePageFormatReg, EllpackPageFmt, Name) - -#define GHistIndexPageFmt SparsePageFormat -#define XGBOOST_REGISTER_GHIST_INDEX_PAGE_FORMAT(Name) \ - DMLC_REGISTRY_REGISTER(SparsePageFormatReg, \ - GHistIndexPageFmt, Name) - } // namespace xgboost::data #endif // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_ diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 325f67eda..115d30e7a 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -14,9 +14,10 @@ #include "../common/categorical.h" #include "../common/common.h" #include "../common/cuda_context.cuh" // for CUDAContext -#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs +#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs, SetDevice #include "../common/device_helpers.cuh" -#include "../common/error_msg.h" // for InplacePredictProxy +#include "../common/error_msg.h" // for InplacePredictProxy +#include "../data/batch_utils.cuh" // for StaticBatch #include "../data/device_adapter.cuh" #include "../data/ellpack_page.cuh" #include "../data/proxy_dmatrix.h" @@ -31,6 +32,8 @@ namespace xgboost::predictor { DMLC_REGISTRY_FILE_TAG(gpu_predictor); +using data::cuda_impl::StaticBatch; + struct TreeView { RegTree::CategoricalSplitMatrix cats; common::Span d_tree; @@ -475,15 +478,14 @@ struct PathInfo { }; // Transform model into path element form for GPUTreeShap -void ExtractPaths( - dh::device_vector> *paths, - DeviceModel *model, dh::device_vector *path_categories, - DeviceOrd device) { - dh::safe_cuda(cudaSetDevice(device.ordinal)); +void ExtractPaths(Context const* ctx, + dh::device_vector>* paths, + DeviceModel* model, dh::device_vector* path_categories, + DeviceOrd device) { + common::SetDevice(device.ordinal); auto& device_model = *model; dh::caching_device_vector info(device_model.nodes.Size()); - dh::XGBCachingDeviceAllocator alloc; auto d_nodes = device_model.nodes.ConstDeviceSpan(); auto d_tree_segments = device_model.tree_segments.ConstDeviceSpan(); auto nodes_transform = dh::MakeTransformIterator( @@ -502,17 +504,15 @@ void ExtractPaths( } return PathInfo{static_cast(idx), path_length, tree_idx}; }); - auto end = thrust::copy_if( - thrust::cuda::par(alloc), nodes_transform, - nodes_transform + d_nodes.size(), info.begin(), - [=] __device__(const PathInfo& e) { return e.leaf_position != -1; }); + auto end = thrust::copy_if(ctx->CUDACtx()->CTP(), nodes_transform, + nodes_transform + d_nodes.size(), info.begin(), + [=] __device__(const PathInfo& e) { return e.leaf_position != -1; }); info.resize(end - info.begin()); auto length_iterator = dh::MakeTransformIterator( info.begin(), [=] __device__(const PathInfo& info) { return info.length; }); dh::caching_device_vector path_segments(info.size() + 1); - thrust::exclusive_scan(thrust::cuda::par(alloc), length_iterator, - length_iterator + info.size() + 1, + thrust::exclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size() + 1, path_segments.begin()); paths->resize(path_segments.back()); @@ -528,19 +528,17 @@ void ExtractPaths( auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan(); size_t max_cat = 0; - if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types), + if (thrust::any_of(ctx->CUDACtx()->CTP(), dh::tbegin(d_split_types), dh::tend(d_split_types), common::IsCatOp{})) { dh::PinnedMemory pinned; auto h_max_cat = pinned.GetSpan(1); auto max_elem_it = dh::MakeTransformIterator( dh::tbegin(d_cat_node_segments), [] __device__(RegTree::CategoricalSplitMatrix::Segment seg) { return seg.size; }); - size_t max_cat_it = - thrust::max_element(thrust::device, max_elem_it, - max_elem_it + d_cat_node_segments.size()) - - max_elem_it; - dh::safe_cuda(cudaMemcpy(h_max_cat.data(), - d_cat_node_segments.data() + max_cat_it, + size_t max_cat_it = thrust::max_element(ctx->CUDACtx()->CTP(), max_elem_it, + max_elem_it + d_cat_node_segments.size()) - + max_elem_it; + dh::safe_cuda(cudaMemcpy(h_max_cat.data(), d_cat_node_segments.data() + max_cat_it, h_max_cat.size_bytes(), cudaMemcpyDeviceToHost)); max_cat = h_max_cat[0].size; CHECK_GE(max_cat, 1); @@ -550,7 +548,7 @@ void ExtractPaths( auto d_model_categories = device_model.categories.DeviceSpan(); common::Span d_path_categories = dh::ToSpan(*path_categories); - dh::LaunchN(info.size(), [=] __device__(size_t idx) { + dh::LaunchN(info.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) { auto path_info = d_info[idx]; size_t tree_offset = d_tree_segments[path_info.tree_idx]; TreeView tree{0, path_info.tree_idx, d_nodes, @@ -864,7 +862,7 @@ class GPUPredictor : public xgboost::Predictor { SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); auto const kernel = [&](auto predict_fn) { - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}( + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes, ctx_->CUDACtx()->Stream()}( predict_fn, data, model.nodes.ConstDeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), @@ -888,7 +886,7 @@ class GPUPredictor : public xgboost::Predictor { DeviceModel d_model; bool use_shared = false; - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS}( + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, 0, ctx_->CUDACtx()->Stream()}( PredictKernel, batch, model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), @@ -924,7 +922,7 @@ class GPUPredictor : public xgboost::Predictor { } } else { bst_idx_t batch_offset = 0; - for (auto const& page : dmat->GetBatches(ctx_, BatchParam{})) { + for (auto const& page : dmat->GetBatches(ctx_, StaticBatch(true))) { dmat->Info().feature_types.SetDevice(ctx_->Device()); auto feature_types = dmat->Info().feature_types.ConstDeviceSpan(); this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_, feature_types), d_model, @@ -989,7 +987,7 @@ class GPUPredictor : public xgboost::Predictor { bool use_shared = shared_memory_bytes != 0; - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}( + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes, ctx_->CUDACtx()->Stream()}( PredictKernel, m->Value(), d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(), d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(), d_model.split_types.ConstDeviceSpan(), @@ -1055,7 +1053,7 @@ class GPUPredictor : public xgboost::Predictor { DeviceModel d_model; d_model.Init(model, 0, tree_end, ctx_->Device()); dh::device_vector categories; - ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device()); + ExtractPaths(ctx_, &device_paths, &d_model, &categories, ctx_->Device()); if (p_fmat->PageExists()) { for (auto& batch : p_fmat->GetBatches()) { batch.data.SetDevice(ctx_->Device()); @@ -1067,7 +1065,7 @@ class GPUPredictor : public xgboost::Predictor { X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); } } else { - for (auto& batch : p_fmat->GetBatches(ctx_, {})) { + for (auto& batch : p_fmat->GetBatches(ctx_, StaticBatch(true))) { EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_)}; auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(), std::numeric_limits::quiet_NaN()}; @@ -1083,7 +1081,7 @@ class GPUPredictor : public xgboost::Predictor { auto base_score = model.learner_model_param->BaseScore(ctx_); dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, - [=] __device__(size_t idx) { + ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) { phis[(idx + 1) * contributions_columns - 1] += margin.empty() ? base_score(0) : margin[idx]; }); @@ -1125,7 +1123,7 @@ class GPUPredictor : public xgboost::Predictor { DeviceModel d_model; d_model.Init(model, 0, tree_end, ctx_->Device()); dh::device_vector categories; - ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device()); + ExtractPaths(ctx_, &device_paths, &d_model, &categories, ctx_->Device()); if (p_fmat->PageExists()) { for (auto const& batch : p_fmat->GetBatches()) { batch.data.SetDevice(ctx_->Device()); @@ -1137,7 +1135,7 @@ class GPUPredictor : public xgboost::Predictor { X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis)); } } else { - for (auto const& batch : p_fmat->GetBatches(ctx_, {})) { + for (auto const& batch : p_fmat->GetBatches(ctx_, StaticBatch(true))) { auto impl = batch.Impl(); auto acc = impl->GetDeviceAccessor(ctx_, p_fmat->Info().feature_types.ConstDeviceSpan()); auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size; @@ -1155,7 +1153,7 @@ class GPUPredictor : public xgboost::Predictor { auto base_score = model.learner_model_param->BaseScore(ctx_); size_t n_features = model.learner_model_param->num_feature; dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, - [=] __device__(size_t idx) { + ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) { size_t group = idx % ngroup; size_t row_idx = idx / ngroup; phis[gpu_treeshap::IndexPhiInteractions(row_idx, ngroup, group, n_features, @@ -1199,7 +1197,7 @@ class GPUPredictor : public xgboost::Predictor { bst_feature_t num_features = info.num_col_; auto launch = [&](auto fn, std::uint32_t grid, auto data, bst_idx_t batch_offset) { - dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes}( + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()}( fn, data, d_model.nodes.ConstDeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(), @@ -1223,7 +1221,7 @@ class GPUPredictor : public xgboost::Predictor { } } else { bst_idx_t batch_offset = 0; - for (auto const& batch : p_fmat->GetBatches(ctx_, BatchParam{})) { + for (auto const& batch : p_fmat->GetBatches(ctx_, StaticBatch(true))) { EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_)}; auto grid = static_cast(common::DivRoundUp(batch.Size(), kBlockThreads)); launch(PredictLeafKernel, grid, data, batch_offset); diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 077cc2c72..46a52a8ea 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -148,19 +148,8 @@ class PoissonSampling : public thrust::binary_function gpair, - DMatrix* dmat) { - return {dmat, gpair}; -} - -ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param) - : batch_param_{std::move(batch_param)} {} - -GradientBasedSample ExternalMemoryNoSampling::Sample(Context const*, - common::Span gpair, - DMatrix* p_fmat) { + DMatrix* p_fmat) { return {p_fmat, gpair}; } @@ -246,9 +235,10 @@ GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batc grad_sum_(n_rows, 0.0f) {} GradientBasedSample GradientBasedSampling::Sample(Context const* ctx, - common::Span gpair, DMatrix* dmat) { + common::Span gpair, + DMatrix* p_fmat) { auto cuctx = ctx->CUDACtx(); - size_t n_rows = dmat->Info().num_row_; + size_t n_rows = p_fmat->Info().num_row_; size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); @@ -257,7 +247,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx, thrust::counting_iterator(0), dh::tbegin(gpair), PoissonSampling(dh::ToSpan(threshold_), threshold_index, RandomWeight(common::GlobalRandom()()))); - return {dmat, gpair}; + return {p_fmat, gpair}; } ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows, @@ -323,46 +313,46 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows, const BatchParam& batch_param, float subsample, - int sampling_method, bool is_external_memory) { + int sampling_method, bool concat_pages) { // The ctx is kept here for future development of stream-based operations. - monitor_.Init("gradient_based_sampler"); + monitor_.Init(__func__); bool is_sampling = subsample < 1.0; - if (is_sampling) { - switch (sampling_method) { - case TrainParam::kUniform: - if (is_external_memory) { - strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample)); - } else { - strategy_.reset(new UniformSampling(batch_param, subsample)); - } - break; - case TrainParam::kGradientBased: - if (is_external_memory) { - strategy_.reset(new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample)); - } else { - strategy_.reset(new GradientBasedSampling(n_rows, batch_param, subsample)); - } - break; - default: - LOG(FATAL) << "unknown sampling method"; + if (!is_sampling) { + strategy_.reset(new NoSampling{}); + error::NoPageConcat(concat_pages); + return; + } + + switch (sampling_method) { + case TrainParam::kUniform: { + if (concat_pages) { + strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample)); + } else { + strategy_.reset(new UniformSampling(batch_param, subsample)); + } + break; } - } else { - if (is_external_memory) { - strategy_.reset(new ExternalMemoryNoSampling(batch_param)); - } else { - strategy_.reset(new NoSampling(batch_param)); + case TrainParam::kGradientBased: { + if (concat_pages) { + strategy_.reset(new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample)); + } else { + strategy_.reset(new GradientBasedSampling(n_rows, batch_param, subsample)); + } + break; } + default: + LOG(FATAL) << "unknown sampling method"; } } // Sample a DMatrix based on the given gradient pairs. GradientBasedSample GradientBasedSampler::Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) { - monitor_.Start("Sample"); + monitor_.Start(__func__); GradientBasedSample sample = strategy_->Sample(ctx, gpair, dmat); - monitor_.Stop("Sample"); + monitor_.Stop(__func__); return sample; } diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh index ea3d10cd0..12c094866 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cuh +++ b/src/tree/gpu_hist/gradient_based_sampler.cuh @@ -24,31 +24,29 @@ class SamplingStrategy { virtual GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) = 0; virtual ~SamplingStrategy() = default; + /** + * @brief Whether pages are concatenated after sampling. + */ + [[nodiscard]] virtual bool ConcatPages() const { return false; } }; -/*! \brief No sampling in in-memory mode. */ +class ExtMemSamplingStrategy : public SamplingStrategy { + public: + [[nodiscard]] bool ConcatPages() const final { return true; } +}; + +/** + * @brief No-op. + */ class NoSampling : public SamplingStrategy { public: - explicit NoSampling(BatchParam batch_param); GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat) override; - - private: - BatchParam batch_param_; }; -/*! \brief No sampling in external memory mode. */ -class ExternalMemoryNoSampling : public SamplingStrategy { - public: - explicit ExternalMemoryNoSampling(BatchParam batch_param); - GradientBasedSample Sample(Context const* ctx, common::Span gpair, - DMatrix* dmat) override; - - private: - BatchParam batch_param_; -}; - -/*! \brief Uniform sampling in in-memory mode. */ +/** + * @brief Uniform sampling in in-memory mode. + */ class UniformSampling : public SamplingStrategy { public: UniformSampling(BatchParam batch_param, float subsample); @@ -61,7 +59,7 @@ class UniformSampling : public SamplingStrategy { }; /*! \brief No sampling in external memory mode. */ -class ExternalMemoryUniformSampling : public SamplingStrategy { +class ExternalMemoryUniformSampling : public ExtMemSamplingStrategy { public: ExternalMemoryUniformSampling(size_t n_rows, BatchParam batch_param, float subsample); GradientBasedSample Sample(Context const* ctx, common::Span gpair, @@ -91,7 +89,7 @@ class GradientBasedSampling : public SamplingStrategy { }; /*! \brief Gradient-based sampling in external memory mode.. */ -class ExternalMemoryGradientBasedSampling : public SamplingStrategy { +class ExternalMemoryGradientBasedSampling : public ExtMemSamplingStrategy { public: ExternalMemoryGradientBasedSampling(size_t n_rows, BatchParam batch_param, float subsample); GradientBasedSample Sample(Context const* ctx, common::Span gpair, @@ -120,7 +118,7 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy { class GradientBasedSampler { public: GradientBasedSampler(Context const* ctx, size_t n_rows, const BatchParam& batch_param, - float subsample, int sampling_method, bool is_external_memory); + float subsample, int sampling_method, bool concat_pages); /*! \brief Sample from a DMatrix based on the given gradient pairs. */ GradientBasedSample Sample(Context const* ctx, common::Span gpair, DMatrix* dmat); @@ -130,6 +128,8 @@ class GradientBasedSampler { common::Span threshold, common::Span grad_sum, size_t sample_rows); + [[nodiscard]] bool ConcatPages() const { return this->strategy_->ConcatPages(); } + private: common::Monitor monitor_; std::unique_ptr strategy_; diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h index e981e886a..e06eff027 100644 --- a/src/tree/hist/param.h +++ b/src/tree/hist/param.h @@ -23,6 +23,7 @@ struct HistMakerTrainParam : public XGBoostParameter { constexpr static std::size_t CudaDefaultNodes() { return static_cast(1) << 12; } bool debug_synchronize{false}; + bool extmem_concat_pages{false}; void CheckTreesSynchronized(Context const* ctx, RegTree const* local_tree) const; @@ -42,6 +43,7 @@ struct HistMakerTrainParam : public XGBoostParameter { .set_default(NotSet()) .set_lower_bound(1) .describe("Maximum number of nodes in histogram cache."); + DMLC_DECLARE_FIELD(extmem_concat_pages).set_default(false); } }; } // namespace xgboost::tree diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index fe5637f4a..51c8a5b21 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -278,7 +278,7 @@ class GlobalApproxUpdater : public TreeUpdater { *sampled = linalg::Empty(ctx_, gpair->Size(), 1); auto in = gpair->HostView().Values(); std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data()); - + error::NoPageConcat(this->hist_param_.extmem_concat_pages); SampleGradient(ctx_, param, sampled->HostView()); } diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index 0fdc30822..dfbf42743 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -5,10 +5,11 @@ #include // for numeric_limits #include // for ostream -#include "gpu_hist/quantiser.cuh" // for GradientQuantiser -#include "param.h" // for TrainParam -#include "xgboost/base.h" // for bst_bin_t -#include "xgboost/task.h" // for ObjInfo +#include "../data/batch_utils.cuh" // for DftPrefetchBatches, StaticBatch +#include "gpu_hist/quantiser.cuh" // for GradientQuantiser +#include "param.h" // for TrainParam +#include "xgboost/base.h" // for bst_bin_t +#include "xgboost/task.h" // for ObjInfo namespace xgboost::tree { struct GPUTrainingParam { @@ -119,26 +120,19 @@ struct DeviceSplitCandidate { }; namespace cuda_impl { -constexpr auto DftPrefetchBatches() { return 2; } - inline BatchParam HistBatch(TrainParam const& param) { auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()}; p.prefetch_copy = true; - p.n_prefetch_batches = DftPrefetchBatches(); + p.n_prefetch_batches = data::cuda_impl::DftPrefetchBatches(); return p; } inline BatchParam ApproxBatch(TrainParam const& p, common::Span hess, ObjInfo const& task) { - return BatchParam{p.max_bin, hess, !task.const_hess}; -} - -// Empty parameter to prevent regen, only used to control external memory prefetching. -inline BatchParam StaticBatch(bool prefetch_copy) { - BatchParam p; - p.prefetch_copy = prefetch_copy; - p.n_prefetch_batches = DftPrefetchBatches(); - return p; + auto batch = BatchParam{p.max_bin, hess, !task.const_hess}; + batch.prefetch_copy = true; + batch.n_prefetch_batches = data::cuda_impl::DftPrefetchBatches(); + return batch; } } // namespace cuda_impl diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 390422ce1..a30f624fd 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -21,6 +21,7 @@ #include "../common/hist_util.h" // for HistogramCuts #include "../common/random.h" // for ColumnSampler, GlobalRandom #include "../common/timer.h" +#include "../data/batch_utils.cuh" // for StaticBatch #include "../data/ellpack_page.cuh" #include "../data/ellpack_page.h" #include "constraints.cuh" @@ -50,11 +51,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); using cuda_impl::ApproxBatch; using cuda_impl::HistBatch; - -// Both the approx and hist initializes the DMatrix before creating the actual -// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty -// parameter to avoid any regen. -using cuda_impl::StaticBatch; +using data::cuda_impl::StaticBatch; // Extra data for each node that is passed to the update position function struct NodeSplitData { @@ -102,11 +99,11 @@ struct GPUHistMakerDevice { std::vector> partitioners_; DeviceHistogramBuilder histogram_; - std::vector batch_ptr_; + std::vector const batch_ptr_; // node idx for each sample dh::device_vector positions_; HistMakerTrainParam const* hist_param_; - std::shared_ptr cuts_{nullptr}; + std::shared_ptr const cuts_; auto CreatePartitionNodes(RegTree const* p_tree, std::vector const& candidates) { std::vector nidx(candidates.size()); @@ -135,35 +132,35 @@ struct GPUHistMakerDevice { dh::device_vector monotone_constraints; - TrainParam param; + TrainParam const param; std::unique_ptr quantiser; dh::PinnedMemory pinned; dh::PinnedMemory pinned2; - common::Monitor monitor; FeatureInteractionConstraintDevice interaction_constraints; std::unique_ptr sampler; std::unique_ptr feature_groups; + common::Monitor monitor; GPUHistMakerDevice(Context const* ctx, TrainParam _param, HistMakerTrainParam const* hist_param, std::shared_ptr column_sampler, BatchParam batch_param, MetaInfo const& info, std::vector batch_ptr, std::shared_ptr cuts) : evaluator_{_param, static_cast(info.num_col_), ctx->Device()}, - ctx_(ctx), - param(std::move(_param)), - column_sampler_(std::move(column_sampler)), - interaction_constraints(param, static_cast(info.num_col_)), + ctx_{ctx}, + column_sampler_{std::move(column_sampler)}, batch_ptr_{std::move(batch_ptr)}, hist_param_{hist_param}, - cuts_{std::move(cuts)} { - this->sampler = - std::make_unique(ctx, info.num_row_, batch_param, param.subsample, - param.sampling_method, batch_ptr_.size() > 2); + cuts_{std::move(cuts)}, + param{std::move(_param)}, + interaction_constraints(param, static_cast(info.num_col_)), + sampler{std::make_unique( + ctx, info.num_row_, batch_param, param.subsample, param.sampling_method, + batch_ptr_.size() > 2 && this->hist_param_->extmem_concat_pages)} { if (!param.monotone_constraints.empty()) { // Copy assigning an empty vector causes an exception in MSVC debug builds monotone_constraints = param.monotone_constraints; @@ -185,33 +182,31 @@ struct GPUHistMakerDevice { } // Reset values for each update iteration - [[nodiscard]] DMatrix* Reset(HostDeviceVector* dh_gpair, DMatrix* p_fmat) { + [[nodiscard]] DMatrix* Reset(HostDeviceVector const* dh_gpair, DMatrix* p_fmat) { this->monitor.Start(__func__); common::SetDevice(ctx_->Ordinal()); auto const& info = p_fmat->Info(); - // backup the gradient - dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair, ctx_->CUDACtx()->Stream()); - this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(), - param.colsample_bynode, param.colsample_bylevel, - param.colsample_bytree); - this->interaction_constraints.Reset(ctx_); - this->evaluator_.Reset(this->ctx_, *cuts_, p_fmat->Info().feature_types.ConstDeviceSpan(), - p_fmat->Info().num_col_, this->param, p_fmat->Info().IsColumnSplit()); - // Sampling + /** + * Sampling + */ + dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair, ctx_->CUDACtx()->Stream()); auto sample = this->sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat); this->gpair = sample.gpair; - p_fmat = sample.p_fmat; // Update p_fmat before allocating partitioners + p_fmat = sample.p_fmat; p_fmat->Info().feature_types.SetDevice(ctx_->Device()); - std::size_t n_batches = p_fmat->NumBatches(); - bool is_concat = (n_batches + 1) != this->batch_ptr_.size(); - std::vector batch_ptr{batch_ptr_}; + + /** + * Initialize the partitioners + */ + bool is_concat = sampler->ConcatPages(); + std::size_t n_batches = is_concat ? 1 : p_fmat->NumBatches(); + std::vector batch_ptr{this->batch_ptr_}; if (is_concat) { // Concatenate the batch ptrs as well. batch_ptr = {static_cast(0), p_fmat->Info().num_row_}; } - // Initialize partitions if (!partitioners_.empty()) { CHECK_EQ(partitioners_.size(), n_batches); } @@ -230,8 +225,20 @@ struct GPUHistMakerDevice { CHECK_EQ(partitioners_.front()->Size(), p_fmat->Info().num_row_); } - // Other initializations - quantiser = std::make_unique(ctx_, this->gpair, p_fmat->Info()); + /** + * Initialize the evaluator + */ + this->column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.HostVector(), + param.colsample_bynode, param.colsample_bylevel, + param.colsample_bytree); + this->interaction_constraints.Reset(ctx_); + this->evaluator_.Reset(this->ctx_, *cuts_, info.feature_types.ConstDeviceSpan(), info.num_col_, + this->param, info.IsColumnSplit()); + + /** + * Other initializations + */ + this->quantiser = std::make_unique(ctx_, this->gpair, p_fmat->Info()); this->InitFeatureGroupsOnce(info); @@ -327,8 +334,8 @@ struct GPUHistMakerDevice { auto d_ridx = partitioners_.at(k)->GetRows(nidx); this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc, - feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx, - d_node_hist, *quantiser); + feature_groups->DeviceAccessor(ctx_->Device()), this->gpair, + d_ridx, d_node_hist, *quantiser); monitor.Stop(__func__); } @@ -678,11 +685,11 @@ struct GPUHistMakerDevice { constexpr bst_node_t kRootNIdx = RegTree::kRoot; auto quantiser = *this->quantiser; auto gpair_it = dh::MakeTransformIterator( - dh::tbegin(gpair), + dh::tbegin(this->gpair), [=] __device__(auto const& gpair) { return quantiser.ToFixedPoint(gpair); }); GradientPairInt64 root_sum_quantised = - dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{}, - thrust::plus{}); + dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + this->gpair.size(), + GradientPairInt64{}, thrust::plus{}); using ReduceT = typename decltype(root_sum_quantised)::ValueT; auto rc = collective::GlobalSum( ctx_, p_fmat->Info(), linalg::MakeVec(reinterpret_cast(&root_sum_quantised), 2)); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 724ecf87b..bafe52591 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -539,6 +539,7 @@ class QuantileHistMaker : public TreeUpdater { // Copy gradient into buffer for sampling. This converts C-order to F-order. std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out)); } + error::NoPageConcat(this->hist_param_.extmem_concat_pages); SampleGradient(ctx_, *param, h_sample_out); auto *h_out_position = &out_position[tree_it - trees.begin()]; if ((*tree_it)->IsMultiTarget()) { diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 8729eba82..0117cc8f2 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -496,7 +496,7 @@ auto MakeExtMemForTest(bst_idx_t n_samples, bst_feature_t n_features, Json dconf NumpyArrayIterForTest iter_1{0.0f, n_samples, n_features, n_batches}; auto Xy = std::make_shared( - &iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits::quiet_NaN(), 0, ""); + &iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits::quiet_NaN(), 0, "", false); MakeLabelForTest(Xy, p_fmat); return std::pair{p_fmat, Xy}; } diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index f6991cfd5..a7c1bb3af 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -37,7 +37,8 @@ void TestSparseDMatrixLoadFile(Context const* ctx) { data::fileiter::Next, std::numeric_limits::quiet_NaN(), n_threads, - tmpdir.path + "cache"}; + tmpdir.path + "cache", + false}; ASSERT_EQ(AllThreadsForTest(), m.Ctx()->Threads()); ASSERT_EQ(m.Info().num_col_, 5); ASSERT_EQ(m.Info().num_row_, 64); @@ -364,9 +365,9 @@ auto TestSparsePageDMatrixDeterminism(int32_t threads) { CreateBigTestData(filename, 1 << 16); data::FileIterator iter(filename + "?format=libsvm", 0, 1); - std::unique_ptr sparse{ - new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, - std::numeric_limits::quiet_NaN(), threads, filename}}; + std::unique_ptr sparse{new data::SparsePageDMatrix{ + &iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next, + std::numeric_limits::quiet_NaN(), threads, filename, false}}; CHECK(sparse->Ctx()->Threads() == threads || sparse->Ctx()->Threads() == AllThreadsForTest()); DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids); diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 2c3bcdd88..45b3f7967 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -81,10 +81,11 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) { auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; - GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true); - auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get()); - auto p_fmat = sample.p_fmat; - ASSERT_EQ(p_fmat, dmat.get()); + ASSERT_THAT( + [&] { + GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true); + }, + GMockThrow("extmem_concat_pages")); } TEST(GradientBasedSampler, UniformSampling) { @@ -120,4 +121,4 @@ TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) { constexpr bool kFixedSizeSampling = false; VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } -}; // namespace xgboost::tree +} // namespace xgboost::tree diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ebd92510d..77ff69a97 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -23,7 +23,7 @@ namespace xgboost::tree { namespace { void UpdateTree(Context const* ctx, linalg::Matrix* gpair, DMatrix* dmat, RegTree* tree, HostDeviceVector* preds, float subsample, - const std::string& sampling_method, bst_bin_t max_bin) { + const std::string& sampling_method, bst_bin_t max_bin, bool concat_pages) { Args args{ {"max_depth", "2"}, {"max_bin", std::to_string(max_bin)}, @@ -38,13 +38,17 @@ void UpdateTree(Context const* ctx, linalg::Matrix* gpair, DMatrix ObjInfo task{ObjInfo::kRegression}; std::unique_ptr hist_maker{TreeUpdater::Create("grow_gpu_hist", ctx, &task)}; - hist_maker->Configure(Args{}); + if (subsample < 1.0) { + hist_maker->Configure(Args{{"extmem_concat_pages", std::to_string(concat_pages)}}); + } else { + hist_maker->Configure(Args{}); + } std::vector> position(1); hist_maker->Update(¶m, gpair, dmat, common::Span>{position}, {tree}); auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1); - if (subsample < 1.0 && !dmat->SingleColBlock()) { + if (subsample < 1.0 && !dmat->SingleColBlock() && concat_pages) { ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache)); } else { ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache)); @@ -69,12 +73,12 @@ TEST(GpuHist, UniformSampling) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, ctx.Device()); - UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows); + UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows, false); // Build another tree using sampling. RegTree tree_sampling; HostDeviceVector preds_sampling(kRows, 0.0, ctx.Device()); UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample, "uniform", - kRows); + kRows, false); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); @@ -100,13 +104,13 @@ TEST(GpuHist, GradientBasedSampling) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, ctx.Device()); - UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows); + UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows, false); // Build another tree using sampling. RegTree tree_sampling; HostDeviceVector preds_sampling(kRows, 0.0, ctx.Device()); UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample, - "gradient_based", kRows); + "gradient_based", kRows, false); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); @@ -137,11 +141,11 @@ TEST(GpuHist, ExternalMemory) { // Build a tree using the in-memory DMatrix. RegTree tree; HostDeviceVector preds(kRows, 0.0, ctx.Device()); - UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows); + UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows, true); // Build another tree using multiple ELLPACK pages. RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, ctx.Device()); - UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows); + UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows, true); // Make sure the predictions are the same. auto preds_h = preds.ConstHostVector(); @@ -181,14 +185,14 @@ TEST(GpuHist, ExternalMemoryWithSampling) { RegTree tree; HostDeviceVector preds(kRows, 0.0, ctx.Device()); - UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows); + UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows, true); // Build another tree using multiple ELLPACK pages. common::GlobalRandom() = rng; RegTree tree_ext; HostDeviceVector preds_ext(kRows, 0.0, ctx.Device()); UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, kSubsample, kSamplingMethod, - kRows); + kRows, true); Json jtree{Object{}}; Json jtree_ext{Object{}}; @@ -228,6 +232,42 @@ TEST(GpuHist, MaxDepth) { ASSERT_THROW({learner->UpdateOneIter(0, p_mat);}, dmlc::Error); } +TEST(GpuHist, PageConcatConfig) { + auto ctx = MakeCUDACtx(0); + bst_idx_t n_samples = 64, n_features = 32; + auto p_fmat = RandomDataGenerator{n_samples, n_features, 0}.Batches(2).GenerateSparsePageDMatrix( + "temp", true); + + auto learner = std::unique_ptr(Learner::Create({p_fmat})); + learner->SetParam("device", ctx.DeviceName()); + learner->SetParam("extmem_concat_pages", "true"); + learner->SetParam("subsample", "0.8"); + learner->Configure(); + + learner->UpdateOneIter(0, p_fmat); + learner->SetParam("extmem_concat_pages", "false"); + learner->Configure(); + // GPU Hist rebuilds the updater after configuration. Training continues + learner->UpdateOneIter(1, p_fmat); + + learner->SetParam("extmem_concat_pages", "true"); + learner->SetParam("subsample", "1.0"); + ASSERT_THAT([&] { learner->UpdateOneIter(2, p_fmat); }, GMockThrow("extmem_concat_pages")); + + // Throws error on CPU. + { + auto learner = std::unique_ptr(Learner::Create({p_fmat})); + learner->SetParam("extmem_concat_pages", "true"); + ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_concat_pages")); + } + { + auto learner = std::unique_ptr(Learner::Create({p_fmat})); + learner->SetParam("extmem_concat_pages", "true"); + learner->SetParam("tree_method", "approx"); + ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_concat_pages")); + } +} + namespace { RegTree GetHistTree(Context const* ctx, DMatrix* dmat) { ObjInfo task{ObjInfo::kRegression}; diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py index 76811675b..7198941cd 100644 --- a/tests/python-gpu/test_gpu_data_iterator.py +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -3,6 +3,8 @@ import sys import pytest from hypothesis import given, settings, strategies +import xgboost as xgb +from xgboost import testing as tm from xgboost.testing import no_cupy from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem @@ -72,6 +74,22 @@ def test_extmem_qdm( check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host) +def test_concat_pages() -> None: + it = tm.IteratorForTest(*tm.make_batches(64, 16, 4, use_cupy=True), cache=None) + Xy = xgb.ExtMemQuantileDMatrix(it) + with pytest.raises(ValueError, match="can not be used with concatenated pages"): + booster = xgb.train( + { + "device": "cuda", + "subsample": 0.5, + "sampling_method": "gradient_based", + "extmem_concat_pages": True, + "objective": "reg:absoluteerror", + }, + Xy, + ) + + @given( strategies.integers(1, 64), strategies.integers(1, 8), diff --git a/tests/python-gpu/test_gpu_demos.py b/tests/python-gpu/test_gpu_demos.py index d3b6089a3..61315cbe1 100644 --- a/tests/python-gpu/test_gpu_demos.py +++ b/tests/python-gpu/test_gpu_demos.py @@ -6,24 +6,32 @@ import pytest from xgboost import testing as tm -sys.path.append("tests/python") -import test_demos as td # noqa +DEMO_DIR = tm.demo_dir(__file__) +PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, "guide-python") @pytest.mark.skipif(**tm.no_cupy()) def test_data_iterator(): - script = os.path.join(td.PYTHON_DEMO_DIR, "quantile_data_iterator.py") + script = os.path.join(PYTHON_DEMO_DIR, "quantile_data_iterator.py") cmd = ["python", script] subprocess.check_call(cmd) def test_update_process_demo(): - script = os.path.join(td.PYTHON_DEMO_DIR, "update_process.py") + script = os.path.join(PYTHON_DEMO_DIR, "update_process.py") cmd = ["python", script] subprocess.check_call(cmd) def test_categorical_demo(): - script = os.path.join(td.PYTHON_DEMO_DIR, "categorical.py") + script = os.path.join(PYTHON_DEMO_DIR, "categorical.py") + cmd = ["python", script] + subprocess.check_call(cmd) + + +@pytest.mark.skipif(**tm.no_rmm()) +@pytest.mark.skipif(**tm.no_cupy()) +def test_external_memory_demo(): + script = os.path.join(PYTHON_DEMO_DIR, "external_memory.py") cmd = ["python", script] subprocess.check_call(cmd)