[EM] Make page concatenation optional. (#10826)
This PR introduces a new parameter, `extmem_concat_pages`, to make page concatenation optional for the GPU `hist` tree method. In addition, the documentation is updated for the new GPU-based external memory support.
This commit is contained in: parent 215da76263, commit e228c1a121
@@ -10,8 +10,13 @@ instead of Quantile DMatrix. The feature is not ready for production use yet.

See :doc:`the tutorial </tutorials/external_memory>` for more details.

.. versionchanged:: 3.0.0

    Added :py:class:`~xgboost.ExtMemQuantileDMatrix`.

"""

import argparse
import os
import tempfile
from typing import Callable, List, Tuple

@@ -43,30 +48,40 @@ def make_batches(

class Iterator(xgboost.DataIter):
    """A custom iterator for loading files in batches."""

    def __init__(self, device: str, file_paths: List[Tuple[str, str]]) -> None:
        self.device = device

        self._file_paths = file_paths
        self._it = 0
        # XGBoost will generate some cache files under the current directory with the
        # prefix "cache"
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load a single batch of data."""
        X_path, y_path = self._file_paths[self._it]
        # When the `ExtMemQuantileDMatrix` is used, the device must match. This
        # constraint will be relaxed in the future.
        if self.device == "cpu":
            X = np.load(X_path)
            y = np.load(y_path)
        else:
            X = cp.load(X_path)
            y = cp.load(y_path)

        assert X.shape[0] == y.shape[0]
        return X, y

    def next(self, input_data: Callable) -> int:
        """Advance the iterator by 1 step and pass the data to XGBoost. This function
        is called by XGBoost during the construction of ``DMatrix``

        """
        if self._it == len(self._file_paths):
            # return 0 to let XGBoost know this is the end of iteration
            return 0

        # input_data is a function passed in by XGBoost and has a similar signature to
        # the ``DMatrix`` constructor.
        X, y = self.load_file()
        input_data(data=X, label=y)
@@ -78,27 +93,74 @@ class Iterator(xgboost.DataIter):
        self._it = 0


def hist_train(it: Iterator) -> None:
    """The hist tree method can use a special data structure `ExtMemQuantileDMatrix` for
    faster initialization and lower memory usage.

    .. versionadded:: 3.0.0

    """
    # For non-data arguments, specify them here once instead of passing them by the
    # `next` method.
    Xy = xgboost.ExtMemQuantileDMatrix(it, missing=np.nan, enable_categorical=False)
    booster = xgboost.train(
        {"tree_method": "hist", "max_depth": 4, "device": it.device},
        Xy,
        evals=[(Xy, "Train")],
        num_boost_round=10,
    )
    booster.predict(Xy)


def approx_train(it: Iterator) -> None:
    """The approx tree method uses the basic `DMatrix`."""

    # For non-data arguments, specify them here once instead of passing them by the
    # `next` method.
    Xy = xgboost.DMatrix(it, missing=np.nan, enable_categorical=False)
    # ``approx`` is also supported, but less efficient due to sketching. It's
    # recommended to use `hist` instead.
    booster = xgboost.train(
        {"tree_method": "approx", "max_depth": 4, "device": it.device},
        Xy,
        evals=[(Xy, "Train")],
        num_boost_round=10,
    )
    booster.predict(Xy)


def main(tmpdir: str, args: argparse.Namespace) -> None:
    """Entry point for training."""

    # generate some random data for demo
    files = make_batches(
        n_samples_per_batch=1024, n_features=17, n_batches=31, tmpdir=tmpdir
    )
    it = Iterator(args.device, files)

    hist_train(it)
    approx_train(it)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    args = parser.parse_args()
    if args.device == "cuda":
        import cupy as cp
        import rmm
        from rmm.allocators.cupy import rmm_cupy_allocator

        # It's important to use RMM for GPU-based external memory to improve performance.
        # If XGBoost is not built with RMM support, a warning will be raised.
        mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
        rmm.mr.set_current_device_resource(mr)
        # Set the allocator for cupy as well.
        cp.cuda.set_allocator(rmm_cupy_allocator)
        # Make sure XGBoost is using RMM for all allocations.
        with xgboost.config_context(use_rmm=True):
            with tempfile.TemporaryDirectory() as tmpdir:
                main(tmpdir, args)
    else:
        with tempfile.TemporaryDirectory() as tmpdir:
            main(tmpdir, args)
@@ -55,9 +55,9 @@ When submitting the XGBoost application to the Spark cluster, you only need to s
    --jars xgboost-spark_2.12-3.0.0.jar \
    ... \

***************
XGBoost Ranking
***************

Learning to rank using XGBoostRegressor has been replaced by a dedicated `XGBoostRanker`, which is specifically designed
to support ranking algorithms.
@@ -230,15 +230,35 @@ Parameters for Tree Booster
  - ``one_output_per_tree``: One model for each target.
  - ``multi_output_tree``: Use multi-target trees.


Parameters for Non-Exact Tree Methods
=====================================

* ``max_cached_hist_node``, [default = 65536]

  Maximum number of cached nodes for histogram. This can be used with the ``hist`` and the
  ``approx`` tree methods.

  .. versionadded:: 2.0.0

  - For most of the cases this parameter should not be set except for growing deep
    trees. After 3.0, this parameter affects GPU algorithms as well.

* ``extmem_concat_pages``, [default = ``false``]

  This parameter is only used for the ``hist`` tree method with ``device=cuda`` and
  ``subsample != 1.0``. Before 3.0, pages were always concatenated.

  .. versionadded:: 3.0.0

  Whether the GPU-based ``hist`` tree method should concatenate the training data into a
  single batch instead of fetching data on-demand when external memory is used. For GPU
  devices that don't support address translation services, external memory training is
  expensive. This parameter can be used in combination with subsampling to reduce overall
  memory usage without significant overhead. See :doc:`/tutorials/external_memory` for
  more information.
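
  As a rough sketch (not an official recipe), enabling the option from the Python package
  can look like the following; ``it`` stands for a user-defined
  :py:class:`~xgboost.DataIter` and the other values are illustrative:

  .. code-block:: python

     import xgboost

     # `it` is a user-defined xgboost.DataIter; see the external memory tutorial.
     Xy = xgboost.ExtMemQuantileDMatrix(it)
     booster = xgboost.train(
         {
             "device": "cuda",
             "tree_method": "hist",
             # Concatenate pages; combine with subsampling to reduce memory usage.
             "extmem_concat_pages": True,
             "subsample": 0.2,
             "sampling_method": "gradient_based",
         },
         Xy,
         num_boost_round=10,
     )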

.. _cat-param:

Parameters for Categorical Feature
@@ -26,6 +26,12 @@ Core Data Structure

.. autoclass:: xgboost.QuantileDMatrix
    :members:
    :inherited-members:
    :show-inheritance:

.. autoclass:: xgboost.ExtMemQuantileDMatrix
    :members:
    :inherited-members:
    :show-inheritance:

.. autoclass:: xgboost.Booster
@@ -4,15 +4,13 @@ Using XGBoost External Memory Version

When working with large datasets, training XGBoost models can be challenging as the entire
dataset needs to be loaded into memory. This can be costly and sometimes
infeasible. Starting from 1.5, users can define a custom iterator to load data in chunks
for running XGBoost algorithms. External memory can be used for training and prediction,
but training is the primary use case and it will be our focus in this tutorial. For
prediction and evaluation, users can iterate through the data themselves, whereas training
requires the entire dataset to be loaded into the memory. Significant progress was made in
the 3.0 release for the GPU implementation. We will introduce the difference between CPU
and GPU in the following sections.

.. note::
@@ -20,27 +18,33 @@ GPU-based training algorithm. We will introduce them in the following sections.

.. note::

   The feature is considered experimental but ready for public testing in 3.0. Vector-leaf
   trees are not yet supported.

The external memory support has undergone multiple development iterations. Like the
:py:class:`~xgboost.QuantileDMatrix` with :py:class:`~xgboost.DataIter`, XGBoost loads
data batch-by-batch using a custom iterator supplied by the user. However, unlike the
:py:class:`~xgboost.QuantileDMatrix`, external memory does not concatenate the batches
(unless specified by the ``extmem_concat_pages`` parameter). Instead, it caches all
batches in the external memory and fetches them on demand. Go to the end of the document
to see a comparison between :py:class:`~xgboost.QuantileDMatrix` and the external memory
version, the :py:class:`~xgboost.ExtMemQuantileDMatrix`.

**Contents**

.. contents::
  :backlinks: none
  :local:

*************
Data Iterator
*************

Starting with XGBoost 1.5, users can define their own data loader using the Python or C
interface. Some examples are in the ``demo`` directory for a quick start. To enable
external memory training, users need to define a data iterator with 2 class methods:
``next`` and ``reset``, then pass it into the :py:class:`~xgboost.DMatrix` or the
:py:class:`~xgboost.ExtMemQuantileDMatrix` constructor.

.. code-block:: python
@@ -53,7 +57,7 @@ iterator with 2 class methods: ``next`` and ``reset``, then pass it into the

    def __init__(self, svm_file_paths: List[str]):
      self._file_paths = svm_file_paths
      self._it = 0
      # XGBoost will generate some cache files under the current directory with the prefix
      # "cache"
      super().__init__(cache_prefix=os.path.join(".", "cache"))

@@ -63,10 +67,10 @@ iterator with 2 class methods: ``next`` and ``reset``, then pass it into the

      """
      if self._it == len(self._file_paths):
        # return 0 to let XGBoost know this is the end of the iteration
        return 0

      # input_data is a function passed in by XGBoost and has the exact same signature as
      # ``DMatrix``
      X, y = load_svmlight_file(self._file_paths[self._it])
      input_data(data=X, label=y)

@@ -79,59 +83,106 @@ iterator with 2 class methods: ``next`` and ``reset``, then pass it into the
      self._it = 0

  it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"])

  Xy = xgboost.ExtMemQuantileDMatrix(it)
  booster = xgboost.train({"tree_method": "hist"}, Xy)

  # The ``approx`` tree method also works, but with lower performance and cannot be used
  # with the quantile DMatrix.
  Xy = xgboost.DMatrix(it)
  booster = xgboost.train({"tree_method": "approx"}, Xy)

The above snippet is a simplified version of :ref:`sphx_glr_python_examples_external_memory.py`.
For an example in C, please see ``demo/c-api/external-memory/``. The iterator is the
common interface for using external memory with XGBoost; you can pass the resulting
:py:class:`~xgboost.DMatrix` object for training, prediction, and evaluation.

The :py:class:`~xgboost.ExtMemQuantileDMatrix` is an external memory version of the
:py:class:`~xgboost.QuantileDMatrix`. These two classes are specifically designed for the
``hist`` tree method for reduced memory usage and data loading overhead. See the
respective references for more info.

It is important to set the batch size based on the memory available. A good starting point
for CPU is to set the batch size to 10GB per batch if you have 64GB of memory. It is *not*
recommended to set small batch sizes like 32 samples per batch, as this can severely hurt
performance in gradient boosting. See the sections below for information about the GPU
version and other best practices.
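
For a rough sense of what such a budget means in rows, the following sketch assumes a
dense float32 matrix and a hypothetical 256-feature dataset; real usage also depends on
labels, sparsity, and the internal cache format, so treat it only as a starting point:

.. code-block:: python

   # Rough rows-per-batch estimate for a dense float32 matrix with a ~10GB budget.
   n_features = 256
   bytes_per_row = n_features * 4       # float32
   batch_budget = 10 * (1024 ** 3)      # ~10GB per batch
   rows_per_batch = batch_budget // bytes_per_row
   print(rows_per_batch)                # roughly 10 million rows in this configuration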

**********************************
GPU Version (GPU Hist tree method)
**********************************

External memory is supported by GPU algorithms (i.e., when ``device`` is set to
``cuda``). Starting with 3.0, the default GPU implementation is similar to what the CPU
version does. It also supports the use of :py:class:`~xgboost.ExtMemQuantileDMatrix` when
the ``hist`` tree method is employed. For a GPU device, the main memory is the device
memory, whereas the external memory can be either a disk or the CPU memory. XGBoost stages
the cache on CPU memory by default. Users can change the backing storage to disk by
specifying the ``on_host`` parameter in the :py:class:`~xgboost.DataIter`. However, using
the disk is not recommended. It's likely to make the GPU slower than the CPU. The option is
here for experimental purposes only.

Inputs to the :py:class:`~xgboost.ExtMemQuantileDMatrix` (through the iterator) must be on
the GPU. This is a current limitation we aim to address in the future.

.. code-block:: python

  import cupy as cp
  import rmm
  from rmm.allocators.cupy import rmm_cupy_allocator

  # It's important to use RMM for GPU-based external memory to improve performance.
  # If XGBoost is not built with RMM support, a warning will be raised.
  mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
  rmm.mr.set_current_device_resource(mr)
  # Set the allocator for cupy as well.
  cp.cuda.set_allocator(rmm_cupy_allocator)
  # Make sure XGBoost is using RMM for all allocations.
  with xgboost.config_context(use_rmm=True):
      # Construct the iterators for ExtMemQuantileDMatrix
      # ...
      # Build the ExtMemQuantileDMatrix and start training
      Xy_train = xgboost.ExtMemQuantileDMatrix(it_train, max_bin=n_bins)
      Xy_valid = xgboost.ExtMemQuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train)
      booster = xgboost.train(
          {
              "tree_method": "hist",
              "max_depth": 6,
              "max_bin": n_bins,
              "device": device,
          },
          Xy_train,
          num_boost_round=n_rounds,
          evals=[(Xy_train, "Train"), (Xy_valid, "Valid")],
      )
It's crucial to use `RAPIDS Memory Manager (RMM) <https://github.com/rapidsai/rmm>`__ for
all memory allocation when training with external memory. XGBoost relies on the memory
pool to reduce the overhead for data fetching. The size of each batch should be slightly
smaller than a quarter of the available GPU memory. In addition, the open source `NVIDIA
Linux driver
<https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/>`__
is required for ``Heterogeneous memory management (HMM)`` support.

In addition to the batch-based data fetching, the GPU version supports concatenating
batches into a single blob for the training data to improve performance. For GPUs
connected via PCIe instead of NVLink, the performance overhead with batch-based training
is significant, particularly for non-dense data. Overall, it can be at least five times
slower than in-core training. Concatenating pages can be used to get the performance
closer to in-core training. This option should be used in combination with subsampling to
reduce the memory usage. During concatenation, subsampling removes a portion of samples,
reducing the training dataset size. The GPU hist tree method supports `gradient-based
sampling`, enabling users to set a low sampling rate without compromising accuracy. Before
3.0, concatenation with subsampling was the only option for GPU-based external
memory. After 3.0, XGBoost uses the regular batch fetching as the default, while the page
concatenation can be enabled by:

.. code-block:: python

  param = {
      # ...
      "device": "cuda",
      "extmem_concat_pages": True,
      "subsample": 0.2,
      "sampling_method": "gradient_based",
  }
@@ -139,10 +190,70 @@ without compromising accuracy.
For more information about the sampling algorithm and its use in external memory training,
see `this paper <https://arxiv.org/abs/2005.09148>`_.

==========
NVLink-C2C
==========

The newer NVIDIA platforms like `Grace-Hopper
<https://www.nvidia.com/en-us/data-center/grace-hopper-superchip/>`__ use `NVLink-C2C
<https://www.nvidia.com/en-us/data-center/nvlink-c2c/>`__, which facilitates a fast
interconnect between the CPU and the GPU. With the host memory serving as the data cache,
XGBoost can retrieve data with significantly lower overhead. When the input data is dense,
there's minimal to no performance loss for training, except for the initial construction
of the :py:class:`~xgboost.ExtMemQuantileDMatrix`. The initial construction iterates
through the input data twice; as a result, the most significant overhead compared to
in-core training is one additional data read when the data is dense.

To run experiments on these platforms, the open source `NVIDIA Linux driver
<https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/>`__
with version ``>=565.47`` is required.
**************
Best Practices
**************

In previous sections, we demonstrated how to train a tree-based model with data residing
on external memory and made some recommendations for the batch size. Here are some other
configurations we find useful. The external memory feature involves iterating through data
batches stored in a cache during tree construction. For optimal performance, we recommend
using the ``grow_policy=depthwise`` setting, which allows XGBoost to build an entire layer
of tree nodes with only a few batch iterations. Conversely, using the ``lossguide`` policy
requires XGBoost to iterate over the data set for each tree node, resulting in
significantly slower performance.

In addition, the ``hist`` tree method should be preferred over the ``approx`` tree method
as the former doesn't recreate the histogram bins for every iteration. Creating the
histogram bins requires loading the raw input data, which is prohibitively expensive. The
:py:class:`~xgboost.ExtMemQuantileDMatrix`, designed for the ``hist`` tree method, can
speed up the initial data construction and the evaluation significantly for external
memory.

Since the external memory implementation focuses on training, where XGBoost needs to
access the entire dataset, only the ``X`` is divided into batches while everything else is
concatenated. As a result, it's recommended for users to define their own management code
to iterate through the data for inference, especially for SHAP value computation. The size
of the SHAP results can be larger than ``X``, making external memory in XGBoost less
effective. Some frameworks like ``dask`` can help with the data chunking and iterate
through the data for inference with memory spilling, as sketched below.
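
Below is a minimal sketch of such user-managed chunked inference. It assumes the feature
batches are stored as NumPy files, as in the iterator example above, and simply calls
:py:meth:`~xgboost.Booster.inplace_predict` per chunk; the file list and helper name are
illustrative only.

.. code-block:: python

   from typing import List

   import numpy as np
   import xgboost

   def predict_in_chunks(booster: xgboost.Booster, x_paths: List[str]) -> np.ndarray:
       """Run prediction one batch at a time so only a single chunk of X is in memory."""
       results = []
       for path in x_paths:
           X_chunk = np.load(path)
           # inplace_predict avoids building a DMatrix for each chunk.
           results.append(booster.inplace_predict(X_chunk))
       return np.concatenate(results, axis=0)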

When external memory is used, the performance of CPU training is limited by disk IO
(input/output) speed. This means that the disk IO speed primarily determines the training
speed. Similarly, PCIe bandwidth limits the GPU performance, assuming the CPU memory is
used as a cache and address translation services (ATS) is unavailable. We recommend using
the regular :py:class:`~xgboost.QuantileDMatrix` over
:py:class:`~xgboost.ExtMemQuantileDMatrix` for constructing the validation dataset when
feasible. Running inference is much less computation-intensive than training and, hence,
much faster. For GPU, the time it takes to read the data from host to device completely
determines the time it takes to run inference, even if a C2C link is available.

.. code-block:: python

  # Try to use `QuantileDMatrix` for the validation if it can fit into the GPU memory.
  Xy_train = xgboost.ExtMemQuantileDMatrix(it_train, max_bin=n_bins)
  Xy_valid = xgboost.QuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train)

During CPU benchmarking, we used an NVMe drive connected to a PCIe-4 slot. Other types of
storage can be too slow for practical usage. However, your system will likely perform some
caching to reduce the overhead of the file read. See the following sections for remarks.

.. _ext_remarks:

@@ -157,43 +268,43 @@ and internal runtime structures are concatenated. This means that memory reducti
effective when dealing with wide datasets where ``X`` is significantly larger in size
compared to other data like ``y``, while it has little impact on slim datasets.

As one might expect, fetching data on demand puts significant pressure on the storage
device. Today's computing devices can process way more data than storage devices can read
in a single unit of time. The ratio is in the order of magnitudes. A GPU is capable of
processing hundreds of Gigabytes of floating-point data in a split second. On the other
hand, a four-lane NVMe storage connected to a PCIe-4 slot usually has about 6GB/s of data
transfer rate. As a result, the training is likely to be severely bounded by your storage
device. Before adopting the external memory solution, some back-of-envelope calculations
might help you determine its viability. For instance, suppose your NVMe drive can transfer
4GB (a reasonably practical number) of data per second, and you have 100GB of data in a
compressed XGBoost cache (corresponding to a dense float32 numpy array of roughly 200GB,
give or take). A tree with depth 8 needs at least 16 iterations through the data when the
parameter is optimal. You need about 14 minutes to train a single tree without accounting
for some other overheads, assuming the computation overlaps with the IO. If your dataset
happens to have a TB-level size, you might need thousands of trees to get a generalized
model. These calculations can help you get an estimate of the expected training time.
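
For reference, the estimate above can be written down explicitly. The figures below are
assumptions taken from the text (a ~200GB dense-equivalent dataset, 4GB/s of transfer
rate, and at least 16 scans for a depth-8 tree); the real number depends on caching, the
compression ratio, and how much the computation overlaps with IO.

.. code-block:: python

   # One way to arrive at roughly the figure quoted above; the exact number depends on
   # what each pass actually reads and on the effective transfer rate.
   data_gb = 200          # dense float32 equivalent of the 100GB compressed cache
   bandwidth_gb_s = 4     # NVMe transfer rate assumed in the text
   passes_per_tree = 16   # at least 16 scans for a depth-8 tree

   seconds_per_tree = passes_per_tree * data_gb / bandwidth_gb_s
   print(seconds_per_tree / 60)  # ~13 minutes per tree, before other overheads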

However, sometimes we can ameliorate this limitation. One should also consider that the
OS (mainly talking about the Linux kernel) can usually cache the data on host memory. It
only evicts pages when new data comes in and there's no room left. In practice, at least
some portion of the data can persist in the host memory throughout the entire training
session. We are aware of this cache when optimizing the external memory fetcher. The
compressed cache is usually smaller than the raw input data, especially when the input is
dense without any missing value. If the host memory can fit a significant portion of this
compressed cache, the performance should be decent after initialization. Our development
so far focuses on the following fronts of optimization for external memory:

- Avoid iterating through the data whenever appropriate.
- If the OS can cache the data, the performance should be close to in-core training.
- For GPU, the actual computation should overlap with memory copy as much as possible.

Starting with XGBoost 2.0, the implementation of external memory uses ``mmap``. It has not
been tested against system errors like disconnected network devices (`SIGBUS`). In the
face of a bus error, you will see a hard crash and need to clean up the cache files. If
the training session might take a long time and you use solutions like NVMe-oF, we
recommend checkpointing your model periodically. Also, it's worth noting that most tests
have been conducted on Linux distributions.

Another important point to keep in mind is that creating the initial cache for XGBoost may
take some time. The interface to external memory is through custom iterators, which we
cannot assume to be thread-safe. Therefore, initialization is performed sequentially. Using

@@ -206,13 +317,30 @@ Compared to the QuantileDMatrix

Passing an iterator to the :py:class:`~xgboost.QuantileDMatrix` enables direct
construction of :py:class:`~xgboost.QuantileDMatrix` with data chunks. On the other hand,
if it's passed to the :py:class:`~xgboost.DMatrix` or the
:py:class:`~xgboost.ExtMemQuantileDMatrix`, it instead enables the external memory
feature. The :py:class:`~xgboost.QuantileDMatrix` concatenates the data in memory after
compression and doesn't fetch data during training. On the other hand, the external memory
:py:class:`~xgboost.DMatrix` (:py:class:`~xgboost.ExtMemQuantileDMatrix`) fetches data
batches from external memory on demand. Use the :py:class:`~xgboost.QuantileDMatrix` (with
iterator if necessary) when you can fit most of your data in memory. For many platforms,
the training speed can be an order of magnitude faster than external memory.
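
To make the distinction concrete, a small sketch follows; ``it`` is a placeholder for a
user-defined :py:class:`~xgboost.DataIter`, and the same iterator type can be passed to
each constructor:

.. code-block:: python

   # In-core: batches are concatenated (after compression) and kept in memory.
   Xy_incore = xgboost.QuantileDMatrix(it)

   # External memory: batches are cached and fetched on demand during training.
   Xy_ext = xgboost.DMatrix(it)                     # works with ``approx`` and ``hist``
   Xy_ext_hist = xgboost.ExtMemQuantileDMatrix(it)  # ``hist`` only, lower overhead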

*************
Brief History
*************

For a long time, external memory support has been an experimental feature and has
undergone multiple development iterations. Here's a brief summary of major changes:

- Gradient-based sampling was introduced to the GPU hist in 1.1.
- The iterator interface was introduced in 1.5, along with a major rewrite for the
  internal framework.
- 2.0 introduced the use of ``mmap``, along with optimization in XGBoost to enable
  zero-copy data fetching.
- 3.0 reworked the GPU implementation to support caching data on the host and disk,
  introduced the :py:class:`~xgboost.ExtMemQuantileDMatrix` class, and added support for
  quantile-based objectives.

****************
Text File Inputs
@@ -220,11 +348,11 @@ Text File Inputs

.. warning::

   This is the original form of external memory support before 1.5 and is now deprecated;
   users are encouraged to use a custom data iterator instead.

There is no significant difference between using the external memory version of text input
and the in-memory version of text input. The only difference is the filename format.

The external memory version takes in the following `URI
<https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format:

@@ -233,7 +361,7 @@ The external memory version takes in the following `URI

  filename?format=libsvm#cacheprefix

The ``filename`` is the typical path to the LIBSVM-format file you want to load in, and
``cacheprefix`` is a path to a cache file that XGBoost will use for caching preprocessed
data in binary form.

@@ -256,4 +384,4 @@ XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to
``dtrain.cache`` as an on-disk cache for storing preprocessed data in an internal binary format. For
more notes about text input formats, see :doc:`/tutorials/input_format`.

For the CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.
@@ -504,8 +504,8 @@ def _prediction_output(
class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
    """The interface for user defined data iterator. The iterator facilitates
    distributed training, :py:class:`QuantileDMatrix`, and external memory support using
    :py:class:`DMatrix` or :py:class:`ExtMemQuantileDMatrix`. Most of the time, users
    don't need to interact with this class directly.

    .. note::

@@ -525,15 +525,16 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
        keep the cache.

    on_host :
        Whether the data should be cached on the host memory instead of the file system
        when using GPU with external memory. When set to true (the default), the
        "external memory" is the CPU (host) memory. See
        :doc:`/tutorials/external_memory` for more info.

        .. versionadded:: 3.0.0

        .. warning::

            This is an experimental parameter.

    """
@@ -541,7 +542,7 @@ class DataIter(ABC):  # pylint: disable=too-many-instance-attributes
        self,
        cache_prefix: Optional[str] = None,
        release_data: bool = True,
        on_host: bool = True,
    ) -> None:
        self.cache_prefix = cache_prefix
        self.on_host = on_host

@@ -1681,9 +1682,12 @@ class QuantileDMatrix(DMatrix):
class ExtMemQuantileDMatrix(DMatrix):
    """The external memory version of the :py:class:`QuantileDMatrix`.

    See :doc:`/tutorials/external_memory` for explanation and usage examples, and
    :py:class:`QuantileDMatrix` for the parameter documentation.

    .. warning::

        This is an experimental feature.

    .. versionadded:: 3.0.0

@@ -1699,6 +1703,13 @@ class ExtMemQuantileDMatrix(DMatrix):
        ref: Optional[DMatrix] = None,
        enable_categorical: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        data :
            A user-defined :py:class:`DataIter` for loading data.

        """
        self.max_bin = max_bin
        self.missing = missing if missing is not None else np.nan
        self.nthread = nthread if nthread is not None else -1
@@ -106,9 +106,10 @@ inline auto NoCategorical(std::string name) {
  return name + " doesn't support categorical features.";
}

inline void NoPageConcat(bool concat_pages) {
  if (concat_pages) {
    LOG(FATAL) << "`extmem_concat_pages` must be false when there's no sampling or when it's "
                  "running on the CPU.";
  }
}
}  // namespace xgboost::error
src/data/batch_utils.cuh (new file)
@@ -0,0 +1,24 @@
/**
 * Copyright 2024, XGBoost Contributors
 */
#pragma once

#include "xgboost/data.h"  // for BatchParam

namespace xgboost::data::cuda_impl {
// Use two batches for prefetching. There's always one batch being worked on, while the
// other batch is being transferred.
constexpr auto DftPrefetchBatches() { return 2; }

// Empty parameter to prevent regen, only used to control external memory prefetching.
//
// Both the approx and hist tree methods initialize the DMatrix before creating the actual
// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty
// parameter to avoid any regen.
inline BatchParam StaticBatch(bool prefetch_copy) {
  BatchParam p;
  p.prefetch_copy = prefetch_copy;
  p.n_prefetch_batches = DftPrefetchBatches();
  return p;
}
}  // namespace xgboost::data::cuda_impl
@@ -920,7 +920,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
                       data::fileiter::Next,
                       std::numeric_limits<float>::quiet_NaN(),
                       1,
                       cache_file,
                       false};
  }

  return dmat;
@@ -10,7 +10,7 @@
#include <utility>  // for move
#include <vector>   // for vector

#include "../common/cuda_rt_utils.h"  // for SupportsPageableMem, SupportsAts
#include "../common/hist_util.h"      // for HistogramCuts
#include "ellpack_page.h"             // for EllpackPage
#include "ellpack_page_raw_format.h"  // for EllpackPageRawFormat

@@ -67,7 +67,20 @@ class EllpackFormatPolicy {
  using FormatT = EllpackPageRawFormat;

 public:
  EllpackFormatPolicy() {
    StringView msg{" The overhead of iterating through external memory might be significant."};
    if (!has_hmm_) {
      LOG(WARNING) << "CUDA heterogeneous memory management is not available." << msg;
    } else if (!common::SupportsAts()) {
      LOG(WARNING) << "CUDA address translation service is not available." << msg;
    }
#if !defined(XGBOOST_USE_RMM)
    LOG(WARNING) << "XGBoost is not built with RMM support." << msg;
#endif
    if (!GlobalConfigThreadLocalStore::Get()->use_rmm) {
      LOG(WARNING) << "`use_rmm` is set to false." << msg;
    }
  }
  // For testing with the HMM flag.
  explicit EllpackFormatPolicy(bool has_hmm) : has_hmm_{has_hmm} {}

@@ -135,6 +148,9 @@ class EllpackMmapStreamPolicy : public F<S> {
                  bst_idx_t length) const;
};

/**
 * @brief Ellpack source with sparse pages as the underlying source.
 */
template <typename F>
class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
  using Super = PageSourceIncMixIn<EllpackPage, F>;

@@ -171,6 +187,9 @@ using EllpackPageHostSource =
using EllpackPageSource =
    EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;

/**
 * @brief Ellpack source directly interfaces with user-defined iterators.
 */
template <typename FormatCreatePolicy>
class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, FormatCreatePolicy> {
  using Super = ExtQantileSourceMixin<EllpackPage, FormatCreatePolicy>;

@@ -201,6 +220,7 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, Forma
        info_{info},
        ext_info_{std::move(ext_info)},
        base_rows_{std::move(base_rows)} {
    cuts->SetDevice(ctx->Device());
    this->SetCuts(std::move(cuts), ctx->Device());
    this->Fetch();
  }
@@ -13,6 +13,7 @@
#include "proxy_dmatrix.h"          // for DataIterProxy, HostAdapterDispatch
#include "quantile_dmatrix.h"       // for GetDataShape, MakeSketches
#include "simple_batch_iterator.h"  // for SimpleBatchIteratorImpl
#include "sparse_page_source.h"     // for MakeCachePrefix

#if !defined(XGBOOST_USE_CUDA)
#include "../common/common.h"  // for AssertGPUSupport

@@ -26,6 +27,7 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix
                                             std::int32_t n_threads, std::string cache,
                                             bst_bin_t max_bin, bool on_host)
    : cache_prefix_{std::move(cache)}, on_host_{on_host} {
  cache_prefix_ = MakeCachePrefix(cache_prefix_);
  auto iter = std::make_shared<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>(
      iter_handle, reset, next);
  iter->Reset();
@@ -13,9 +13,9 @@
#include <utility>  // for move
#include <variant>  // for visit

#include "batch_utils.h"         // for RegenGHist
#include "gradient_index.h"      // for GHistIndexMatrix
#include "sparse_page_source.h"  // for MakeCachePrefix

namespace xgboost::data {
MetaInfo &SparsePageDMatrix::Info() { return info_; }

@@ -34,12 +34,9 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
      cache_prefix_{std::move(cache_prefix)},
      on_host_{on_host} {
  Context ctx;
  ctx.Init(Args{{"nthread", std::to_string(nthreads)}});
  cache_prefix_ = MakeCachePrefix(cache_prefix_);

  DMatrixProxy *proxy = MakeProxy(proxy_);
  auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
      iter_, reset_, next_};

@@ -107,7 +104,6 @@ BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
  auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_);
  CHECK_NE(this->Info().num_col_, 0);
  this->InitializeSparsePage(ctx);
  if (!column_source_) {
    column_source_ =

@@ -122,7 +118,6 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) {
  auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_);
  CHECK_NE(this->Info().num_col_, 0);
  this->InitializeSparsePage(ctx);
  if (!sorted_column_source_) {
    sorted_column_source_ = std::make_shared<SortedCSCPageSource>(

@@ -140,7 +135,6 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
    CHECK_GE(param.max_bin, 2);
  }
  detail::CheckEmpty(batch_param_, param);
  auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_);
  if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
    this->InitializeSparsePage(ctx);
@@ -70,10 +70,10 @@ class SparsePageDMatrix : public DMatrix {
  DataIterResetCallback *reset_;
  XGDMatrixCallbackNext *next_;

  float const missing_;
  Context fmat_ctx_;
  std::string cache_prefix_;
  bool const on_host_;
  std::uint32_t n_batches_{0};
  // sparse page is the source to other page types, we make a special member function.
  void InitializeSparsePage(Context const *ctx);

@@ -83,7 +83,7 @@ class SparsePageDMatrix : public DMatrix {
 public:
  explicit SparsePageDMatrix(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
                             XGDMatrixCallbackNext *next, float missing, int32_t nthreads,
                             std::string cache_prefix, bool on_host);

  ~SparsePageDMatrix() override;
@@ -54,22 +54,18 @@ class SparsePageRawFormat : public SparsePageFormat<T> {

 private:
};

#define SparsePageFmt SparsePageFormat<SparsePage>
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<SparsePage>, SparsePageFmt, raw)
    .describe("Raw binary data format.")
    .set_body([]() { return new SparsePageRawFormat<SparsePage>(); });

#define CSCPageFmt SparsePageFormat<CSCPage>
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<CSCPage>, CSCPageFmt, raw)
    .describe("Raw binary data format.")
    .set_body([]() { return new SparsePageRawFormat<CSCPage>(); });

#define SortedCSCPageFmt SparsePageFormat<SortedCSCPage>
DMLC_REGISTRY_REGISTER(SparsePageFormatReg<SortedCSCPage>, SortedCSCPageFmt, raw)
    .describe("Raw binary data format.")
    .set_body([]() { return new SparsePageRawFormat<SortedCSCPage>(); });
}  // namespace xgboost::data
@@ -8,6 +8,8 @@
#include <numeric>  // for partial_sum
#include <string>   // for string

#include "../collective/communicator-inl.h"  // for IsDistributed, GetRank

namespace xgboost::data {
void Cache::Commit() {
  if (!this->written) {

@@ -28,6 +30,14 @@ void TryDeleteCacheFile(const std::string& file) {
  }
}

std::string MakeCachePrefix(std::string cache_prefix) {
  cache_prefix = cache_prefix.empty() ? "DMatrix" : cache_prefix;
  if (collective::IsDistributed()) {
    cache_prefix += ("-r" + std::to_string(collective::GetRank()));
  }
  return cache_prefix;
}

#if !defined(XGBOOST_USE_CUDA)
void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
#endif

@@ -33,6 +33,8 @@
namespace xgboost::data {
void TryDeleteCacheFile(const std::string& file);

std::string MakeCachePrefix(std::string cache_prefix);

/**
 * @brief Information about the cache including path and page offsets.
 */
@@ -1,5 +1,5 @@
/**
 * Copyright 2014-2024, XGBoost Contributors
 * \file sparse_page_writer.h
 * \author Tianqi Chen
 */
@@ -11,7 +11,6 @@

#include "../common/io.h"   // for AlignedResourceReadStream, AlignedFileWriteStream
#include "dmlc/registry.h"  // for Registry, FunctionRegEntryBase

namespace xgboost::data {
template<typename T>

@@ -54,47 +53,13 @@ inline SparsePageFormat<T>* CreatePageFormat(const std::string& name) {
  return (e->body)();
}

/**
 * @brief Registry entry for sparse page format.
 */
template<typename T>
struct SparsePageFormatReg
    : public dmlc::FunctionRegEntryBase<SparsePageFormatReg<T>,
                                        std::function<SparsePageFormat<T>* ()>> {
};

}  // namespace xgboost::data
#endif  // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
@@ -14,9 +14,10 @@
#include "../common/categorical.h"
#include "../common/common.h"
#include "../common/cuda_context.cuh"  // for CUDAContext
#include "../common/cuda_rt_utils.h"   // for AllVisibleGPUs, SetDevice
#include "../common/device_helpers.cuh"
#include "../common/error_msg.h"    // for InplacePredictProxy
#include "../data/batch_utils.cuh"  // for StaticBatch
#include "../data/device_adapter.cuh"
#include "../data/ellpack_page.cuh"
#include "../data/proxy_dmatrix.h"

@@ -31,6 +32,8 @@
namespace xgboost::predictor {
DMLC_REGISTRY_FILE_TAG(gpu_predictor);

using data::cuda_impl::StaticBatch;

struct TreeView {
  RegTree::CategoricalSplitMatrix cats;
  common::Span<RegTree::Node const> d_tree;

@@ -475,15 +478,14 @@ struct PathInfo {
};

// Transform model into path element form for GPUTreeShap
void ExtractPaths(Context const* ctx,
                  dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>* paths,
                  DeviceModel* model, dh::device_vector<uint32_t>* path_categories,
                  DeviceOrd device) {
  common::SetDevice(device.ordinal);
  auto& device_model = *model;

  dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
  dh::XGBCachingDeviceAllocator<PathInfo> alloc;
  auto d_nodes = device_model.nodes.ConstDeviceSpan();
  auto d_tree_segments = device_model.tree_segments.ConstDeviceSpan();
  auto nodes_transform = dh::MakeTransformIterator<PathInfo>(

@@ -502,8 +504,7 @@ void ExtractPaths(
        }
        return PathInfo{static_cast<int64_t>(idx), path_length, tree_idx};
      });
  auto end = thrust::copy_if(ctx->CUDACtx()->CTP(), nodes_transform,
                             nodes_transform + d_nodes.size(), info.begin(),
                             [=] __device__(const PathInfo& e) { return e.leaf_position != -1; });
  info.resize(end - info.begin());

@@ -511,8 +512,7 @@ void ExtractPaths(
      info.begin(),
      [=] __device__(const PathInfo& info) { return info.length; });
  dh::caching_device_vector<size_t> path_segments(info.size() + 1);
  thrust::exclusive_scan(ctx->CUDACtx()->CTP(), length_iterator, length_iterator + info.size() + 1,
                         path_segments.begin());

  paths->resize(path_segments.back());

@@ -528,19 +528,17 @@ void ExtractPaths(
  auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan();

  size_t max_cat = 0;
  if (thrust::any_of(ctx->CUDACtx()->CTP(), dh::tbegin(d_split_types), dh::tend(d_split_types),
                     common::IsCatOp{})) {
    dh::PinnedMemory pinned;
    auto h_max_cat = pinned.GetSpan<RegTree::CategoricalSplitMatrix::Segment>(1);
    auto max_elem_it = dh::MakeTransformIterator<size_t>(
        dh::tbegin(d_cat_node_segments),
        [] __device__(RegTree::CategoricalSplitMatrix::Segment seg) { return seg.size; });
    size_t max_cat_it = thrust::max_element(ctx->CUDACtx()->CTP(), max_elem_it,
                                            max_elem_it + d_cat_node_segments.size()) -
                        max_elem_it;
    dh::safe_cuda(cudaMemcpy(h_max_cat.data(), d_cat_node_segments.data() + max_cat_it,
                             h_max_cat.size_bytes(), cudaMemcpyDeviceToHost));
    max_cat = h_max_cat[0].size;
    CHECK_GE(max_cat, 1);

@@ -550,7 +548,7 @@ void ExtractPaths(
  auto d_model_categories = device_model.categories.DeviceSpan();
  common::Span<uint32_t> d_path_categories = dh::ToSpan(*path_categories);

  dh::LaunchN(info.size(), ctx->CUDACtx()->Stream(), [=] __device__(size_t idx) {
    auto path_info = d_info[idx];
    size_t tree_offset = d_tree_segments[path_info.tree_idx];
    TreeView tree{0, path_info.tree_idx, d_nodes,

@@ -864,7 +862,7 @@ class GPUPredictor : public xgboost::Predictor {
      SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                          num_features);
      auto const kernel = [&](auto predict_fn) {
        dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes, ctx_->CUDACtx()->Stream()}(
            predict_fn, data, model.nodes.ConstDeviceSpan(),
            predictions->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
            model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),

@@ -888,7 +886,7 @@ class GPUPredictor : public xgboost::Predictor {
    DeviceModel d_model;

    bool use_shared = false;
    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, 0, ctx_->CUDACtx()->Stream()}(
        PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch, model.nodes.ConstDeviceSpan(),
        out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
        model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),

@@ -924,7 +922,7 @@ class GPUPredictor : public xgboost::Predictor {
      }
    } else {
      bst_idx_t batch_offset = 0;
      for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
        dmat->Info().feature_types.SetDevice(ctx_->Device());
        auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
        this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_, feature_types), d_model,

@@ -989,7 +987,7 @@ class GPUPredictor : public xgboost::Predictor {

    bool use_shared = shared_memory_bytes != 0;

    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes, ctx_->CUDACtx()->Stream()}(
        PredictKernel<Loader, typename Loader::BatchT>, m->Value(), d_model.nodes.ConstDeviceSpan(),
        out_preds->predictions.DeviceSpan(), d_model.tree_segments.ConstDeviceSpan(),
|
||||
d_model.tree_group.ConstDeviceSpan(), d_model.split_types.ConstDeviceSpan(),
|
||||
@ -1055,7 +1053,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, 0, tree_end, ctx_->Device());
|
||||
dh::device_vector<uint32_t> categories;
|
||||
ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device());
|
||||
ExtractPaths(ctx_, &device_paths, &d_model, &categories, ctx_->Device());
|
||||
if (p_fmat->PageExists<SparsePage>()) {
|
||||
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(ctx_->Device());
|
||||
@ -1067,7 +1065,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
|
||||
}
|
||||
} else {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
|
||||
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_)};
|
||||
auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
|
||||
std::numeric_limits<float>::quiet_NaN()};
|
||||
@ -1083,7 +1081,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
|
||||
auto base_score = model.learner_model_param->BaseScore(ctx_);
|
||||
dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
|
||||
[=] __device__(size_t idx) {
|
||||
ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) {
|
||||
phis[(idx + 1) * contributions_columns - 1] +=
|
||||
margin.empty() ? base_score(0) : margin[idx];
|
||||
});
|
||||
@ -1125,7 +1123,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, 0, tree_end, ctx_->Device());
|
||||
dh::device_vector<uint32_t> categories;
|
||||
ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device());
|
||||
ExtractPaths(ctx_, &device_paths, &d_model, &categories, ctx_->Device());
|
||||
if (p_fmat->PageExists<SparsePage>()) {
|
||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(ctx_->Device());
|
||||
@ -1137,7 +1135,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
|
||||
}
|
||||
} else {
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
auto impl = batch.Impl();
|
||||
auto acc = impl->GetDeviceAccessor(ctx_, p_fmat->Info().feature_types.ConstDeviceSpan());
|
||||
auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
|
||||
@ -1155,7 +1153,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
auto base_score = model.learner_model_param->BaseScore(ctx_);
|
||||
size_t n_features = model.learner_model_param->num_feature;
|
||||
dh::LaunchN(p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
|
||||
[=] __device__(size_t idx) {
|
||||
ctx_->CUDACtx()->Stream(), [=] __device__(size_t idx) {
|
||||
size_t group = idx % ngroup;
|
||||
size_t row_idx = idx / ngroup;
|
||||
phis[gpu_treeshap::IndexPhiInteractions(row_idx, ngroup, group, n_features,
|
||||
@ -1199,7 +1197,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
bst_feature_t num_features = info.num_col_;
|
||||
|
||||
auto launch = [&](auto fn, std::uint32_t grid, auto data, bst_idx_t batch_offset) {
|
||||
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes}(
|
||||
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()}(
|
||||
fn, data, d_model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(),
|
||||
|
||||
@ -1223,7 +1221,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
} else {
|
||||
bst_idx_t batch_offset = 0;
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
|
||||
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
|
||||
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_)};
|
||||
auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
|
||||
launch(PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, grid, data, batch_offset);
|
||||
|
||||
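The predictor hunks above route kernel launches and Thrust calls through the context's CUDA stream, pass the context into ExtractPaths, and iterate external-memory ELLPACK pages with StaticBatch instead of an empty BatchParam. From the Python side this is transparent; the following is a minimal sketch (not part of the diff) of the user-facing path these kernels serve, namely regular GPU prediction and GPUTreeShap contributions, assuming a CUDA-enabled build with a visible GPU:

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))
y = X[:, 0] + rng.normal(scale=0.1, size=256)
Xy = xgb.DMatrix(X, label=y)

booster = xgb.train({"tree_method": "hist", "device": "cuda"}, Xy, num_boost_round=4)
# Both calls go through the GPU predictor touched above; the contributions use
# the GPUTreeShap paths assembled by ExtractPaths.
preds = booster.predict(Xy)
contribs = booster.predict(Xy, pred_contribs=True)
print(preds.shape, contribs.shape)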
@ -148,18 +148,7 @@ class PoissonSampling : public thrust::binary_function<GradientPair, size_t, Gra
CombineGradientPair combine_;
};

NoSampling::NoSampling(BatchParam batch_param) : batch_param_(std::move(batch_param)) {}

GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair> gpair,
DMatrix* dmat) {
return {dmat, gpair};
}

ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
: batch_param_{std::move(batch_param)} {}

GradientBasedSample ExternalMemoryNoSampling::Sample(Context const*,
common::Span<GradientPair> gpair,
DMatrix* p_fmat) {
return {p_fmat, gpair};
}
@ -246,9 +235,10 @@ GradientBasedSampling::GradientBasedSampling(std::size_t n_rows, BatchParam batc
grad_sum_(n_rows, 0.0f) {}

GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
common::Span<GradientPair> gpair, DMatrix* dmat) {
common::Span<GradientPair> gpair,
DMatrix* p_fmat) {
auto cuctx = ctx->CUDACtx();
size_t n_rows = dmat->Info().num_row_;
size_t n_rows = p_fmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
ctx, gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);

@ -257,7 +247,7 @@ GradientBasedSample GradientBasedSampling::Sample(Context const* ctx,
thrust::counting_iterator<size_t>(0), dh::tbegin(gpair),
PoissonSampling(dh::ToSpan(threshold_), threshold_index,
RandomWeight(common::GlobalRandom()())));
return {dmat, gpair};
return {p_fmat, gpair};
}

ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(size_t n_rows,
@ -323,46 +313,46 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c

GradientBasedSampler::GradientBasedSampler(Context const* /*ctx*/, size_t n_rows,
const BatchParam& batch_param, float subsample,
int sampling_method, bool is_external_memory) {
int sampling_method, bool concat_pages) {
// The ctx is kept here for future development of stream-based operations.
monitor_.Init("gradient_based_sampler");
monitor_.Init(__func__);

bool is_sampling = subsample < 1.0;

if (is_sampling) {
if (!is_sampling) {
strategy_.reset(new NoSampling{});
error::NoPageConcat(concat_pages);
return;
}

switch (sampling_method) {
case TrainParam::kUniform:
if (is_external_memory) {
case TrainParam::kUniform: {
if (concat_pages) {
strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
} else {
strategy_.reset(new UniformSampling(batch_param, subsample));
}
break;
case TrainParam::kGradientBased:
if (is_external_memory) {
}
case TrainParam::kGradientBased: {
if (concat_pages) {
strategy_.reset(new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
} else {
strategy_.reset(new GradientBasedSampling(n_rows, batch_param, subsample));
}
break;
}
default:
LOG(FATAL) << "unknown sampling method";
}
} else {
if (is_external_memory) {
strategy_.reset(new ExternalMemoryNoSampling(batch_param));
} else {
strategy_.reset(new NoSampling(batch_param));
}
}
}

// Sample a DMatrix based on the given gradient pairs.
GradientBasedSample GradientBasedSampler::Sample(Context const* ctx,
common::Span<GradientPair> gpair, DMatrix* dmat) {
monitor_.Start("Sample");
monitor_.Start(__func__);
GradientBasedSample sample = strategy_->Sample(ctx, gpair, dmat);
monitor_.Stop("Sample");
monitor_.Stop(__func__);
return sample;
}
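The constructor above now keys the strategy choice on concat_pages rather than on whether the DMatrix is external memory, and rejects page concatenation when no subsampling is requested. The following is an illustrative Python sketch of that decision table; the function name and error text are illustrative only, not xgboost API:

def choose_strategy(subsample: float, sampling_method: str, concat_pages: bool) -> str:
    if subsample >= 1.0:
        # Mirrors error::NoPageConcat: concatenation is rejected without sampling.
        if concat_pages:
            raise ValueError("extmem_concat_pages requires subsampling to be enabled")
        return "NoSampling"
    if sampling_method == "uniform":
        return "ExternalMemoryUniformSampling" if concat_pages else "UniformSampling"
    if sampling_method == "gradient_based":
        return "ExternalMemoryGradientBasedSampling" if concat_pages else "GradientBasedSampling"
    raise ValueError("unknown sampling method")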
@ -24,31 +24,29 @@ class SamplingStrategy {
virtual GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) = 0;
virtual ~SamplingStrategy() = default;
/**
* @brief Whether pages are concatenated after sampling.
*/
[[nodiscard]] virtual bool ConcatPages() const { return false; }
};

/*! \brief No sampling in in-memory mode. */
class ExtMemSamplingStrategy : public SamplingStrategy {
public:
[[nodiscard]] bool ConcatPages() const final { return true; }
};

/**
* @brief No-op.
*/
class NoSampling : public SamplingStrategy {
public:
explicit NoSampling(BatchParam batch_param);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;

private:
BatchParam batch_param_;
};

/*! \brief No sampling in external memory mode. */
class ExternalMemoryNoSampling : public SamplingStrategy {
public:
explicit ExternalMemoryNoSampling(BatchParam batch_param);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;

private:
BatchParam batch_param_;
};

/*! \brief Uniform sampling in in-memory mode. */
/**
* @brief Uniform sampling in in-memory mode.
*/
class UniformSampling : public SamplingStrategy {
public:
UniformSampling(BatchParam batch_param, float subsample);
@ -61,7 +59,7 @@ class UniformSampling : public SamplingStrategy {
};

/*! \brief No sampling in external memory mode. */
class ExternalMemoryUniformSampling : public SamplingStrategy {
class ExternalMemoryUniformSampling : public ExtMemSamplingStrategy {
public:
ExternalMemoryUniformSampling(size_t n_rows, BatchParam batch_param, float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
@ -91,7 +89,7 @@ class GradientBasedSampling : public SamplingStrategy {
};

/*! \brief Gradient-based sampling in external memory mode.. */
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
class ExternalMemoryGradientBasedSampling : public ExtMemSamplingStrategy {
public:
ExternalMemoryGradientBasedSampling(size_t n_rows, BatchParam batch_param, float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
@ -120,7 +118,7 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
class GradientBasedSampler {
public:
GradientBasedSampler(Context const* ctx, size_t n_rows, const BatchParam& batch_param,
float subsample, int sampling_method, bool is_external_memory);
float subsample, int sampling_method, bool concat_pages);

/*! \brief Sample from a DMatrix based on the given gradient pairs. */
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair, DMatrix* dmat);
@ -130,6 +128,8 @@ class GradientBasedSampler {
common::Span<float> threshold, common::Span<float> grad_sum,
size_t sample_rows);

[[nodiscard]] bool ConcatPages() const { return this->strategy_->ConcatPages(); }

private:
common::Monitor monitor_;
std::unique_ptr<SamplingStrategy> strategy_;
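The header now exposes page concatenation as a property of the strategy: the base SamplingStrategy::ConcatPages() returns false, the external-memory strategies inherit the true override from ExtMemSamplingStrategy, and GradientBasedSampler::ConcatPages() forwards to whichever strategy is active. A small Python sketch of the same dispatch, illustrative only and not xgboost API:

class SamplingStrategy:
    def concat_pages(self) -> bool:
        # In-memory strategies do not concatenate pages.
        return False

class ExtMemSamplingStrategy(SamplingStrategy):
    def concat_pages(self) -> bool:
        # External-memory strategies write out a single concatenated page.
        return True

class GradientBasedSampler:
    def __init__(self, strategy: SamplingStrategy) -> None:
        self._strategy = strategy

    def concat_pages(self) -> bool:
        # The sampler simply delegates to the active strategy.
        return self._strategy.concat_pages()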
@ -23,6 +23,7 @@ struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
constexpr static std::size_t CudaDefaultNodes() { return static_cast<std::size_t>(1) << 12; }

bool debug_synchronize{false};
bool extmem_concat_pages{false};

void CheckTreesSynchronized(Context const* ctx, RegTree const* local_tree) const;

@ -42,6 +43,7 @@ struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
.set_default(NotSet())
.set_lower_bound(1)
.describe("Maximum number of nodes in histogram cache.");
DMLC_DECLARE_FIELD(extmem_concat_pages).set_default(false);
}
};
} // namespace xgboost::tree
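The extmem_concat_pages field declared above is surfaced as a regular booster parameter: it defaults to false, is honoured only by the GPU hist updater, and, per the sampler changes earlier in this diff, requires subsampling. The following is a hedged usage sketch with the Python package; it reuses the testing helpers that appear in the tests further down this diff, and whether a given combination trains or raises depends on the validation added in this PR, so treat it as a sketch rather than a guaranteed recipe:

import xgboost as xgb
from xgboost import testing as tm  # helper iterators used by the tests in this diff

# Build an external-memory quantile DMatrix from batched (cupy) data.
it = tm.IteratorForTest(*tm.make_batches(256, 16, 4, use_cupy=True), cache=None)
Xy = xgb.ExtMemQuantileDMatrix(it)

params = {
    "device": "cuda",
    "tree_method": "hist",
    "subsample": 0.8,                    # sampling must be enabled for concatenation
    "sampling_method": "gradient_based",
    "extmem_concat_pages": True,         # the parameter introduced by this PR
}
booster = xgb.train(params, Xy, num_boost_round=4)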
@ -278,7 +278,7 @@ class GlobalApproxUpdater : public TreeUpdater {
*sampled = linalg::Empty<GradientPair>(ctx_, gpair->Size(), 1);
auto in = gpair->HostView().Values();
std::copy(in.data(), in.data() + in.size(), sampled->HostView().Values().data());

error::NoPageConcat(this->hist_param_.extmem_concat_pages);
SampleGradient(ctx_, param, sampled->HostView());
}

@ -5,6 +5,7 @@
#include <limits> // for numeric_limits
#include <ostream> // for ostream

#include "../data/batch_utils.cuh" // for DftPrefetchBatches, StaticBatch
#include "gpu_hist/quantiser.cuh" // for GradientQuantiser
#include "param.h" // for TrainParam
#include "xgboost/base.h" // for bst_bin_t
@ -119,26 +120,19 @@ struct DeviceSplitCandidate {
};

namespace cuda_impl {
constexpr auto DftPrefetchBatches() { return 2; }

inline BatchParam HistBatch(TrainParam const& param) {
auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
p.prefetch_copy = true;
p.n_prefetch_batches = DftPrefetchBatches();
p.n_prefetch_batches = data::cuda_impl::DftPrefetchBatches();
return p;
}

inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hess,
ObjInfo const& task) {
return BatchParam{p.max_bin, hess, !task.const_hess};
}

// Empty parameter to prevent regen, only used to control external memory prefetching.
inline BatchParam StaticBatch(bool prefetch_copy) {
BatchParam p;
p.prefetch_copy = prefetch_copy;
p.n_prefetch_batches = DftPrefetchBatches();
return p;
auto batch = BatchParam{p.max_bin, hess, !task.const_hess};
batch.prefetch_copy = true;
batch.n_prefetch_batches = data::cuda_impl::DftPrefetchBatches();
return batch;
}
} // namespace cuda_impl

@ -21,6 +21,7 @@
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/random.h" // for ColumnSampler, GlobalRandom
#include "../common/timer.h"
#include "../data/batch_utils.cuh" // for StaticBatch
#include "../data/ellpack_page.cuh"
#include "../data/ellpack_page.h"
#include "constraints.cuh"
@ -50,11 +51,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);

using cuda_impl::ApproxBatch;
using cuda_impl::HistBatch;

// Both the approx and hist initializes the DMatrix before creating the actual
// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty
// parameter to avoid any regen.
using cuda_impl::StaticBatch;
using data::cuda_impl::StaticBatch;

// Extra data for each node that is passed to the update position function
struct NodeSplitData {
@ -102,11 +99,11 @@ struct GPUHistMakerDevice {
std::vector<std::unique_ptr<RowPartitioner>> partitioners_;

DeviceHistogramBuilder histogram_;
std::vector<bst_idx_t> batch_ptr_;
std::vector<bst_idx_t> const batch_ptr_;
// node idx for each sample
dh::device_vector<bst_node_t> positions_;
HistMakerTrainParam const* hist_param_;
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
std::shared_ptr<common::HistogramCuts const> const cuts_;

auto CreatePartitionNodes(RegTree const* p_tree, std::vector<GPUExpandEntry> const& candidates) {
std::vector<bst_node_t> nidx(candidates.size());
@ -135,35 +132,35 @@ struct GPUHistMakerDevice {

dh::device_vector<int> monotone_constraints;

TrainParam param;
TrainParam const param;

std::unique_ptr<GradientQuantiser> quantiser;

dh::PinnedMemory pinned;
dh::PinnedMemory pinned2;

common::Monitor monitor;
FeatureInteractionConstraintDevice interaction_constraints;

std::unique_ptr<GradientBasedSampler> sampler;

std::unique_ptr<FeatureGroups> feature_groups;
common::Monitor monitor;

GPUHistMakerDevice(Context const* ctx, TrainParam _param, HistMakerTrainParam const* hist_param,
std::shared_ptr<common::ColumnSampler> column_sampler, BatchParam batch_param,
MetaInfo const& info, std::vector<bst_idx_t> batch_ptr,
std::shared_ptr<common::HistogramCuts const> cuts)
: evaluator_{_param, static_cast<bst_feature_t>(info.num_col_), ctx->Device()},
ctx_(ctx),
param(std::move(_param)),
column_sampler_(std::move(column_sampler)),
interaction_constraints(param, static_cast<bst_feature_t>(info.num_col_)),
ctx_{ctx},
column_sampler_{std::move(column_sampler)},
batch_ptr_{std::move(batch_ptr)},
hist_param_{hist_param},
cuts_{std::move(cuts)} {
this->sampler =
std::make_unique<GradientBasedSampler>(ctx, info.num_row_, batch_param, param.subsample,
param.sampling_method, batch_ptr_.size() > 2);
cuts_{std::move(cuts)},
param{std::move(_param)},
interaction_constraints(param, static_cast<bst_feature_t>(info.num_col_)),
sampler{std::make_unique<GradientBasedSampler>(
ctx, info.num_row_, batch_param, param.subsample, param.sampling_method,
batch_ptr_.size() > 2 && this->hist_param_->extmem_concat_pages)} {
if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints;
@ -185,33 +182,31 @@ struct GPUHistMakerDevice {
}

// Reset values for each update iteration
[[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
[[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair> const* dh_gpair, DMatrix* p_fmat) {
this->monitor.Start(__func__);
common::SetDevice(ctx_->Ordinal());

auto const& info = p_fmat->Info();
// backup the gradient
dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair, ctx_->CUDACtx()->Stream());
this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
this->interaction_constraints.Reset(ctx_);
this->evaluator_.Reset(this->ctx_, *cuts_, p_fmat->Info().feature_types.ConstDeviceSpan(),
p_fmat->Info().num_col_, this->param, p_fmat->Info().IsColumnSplit());

// Sampling
/**
* Sampling
*/
dh::CopyTo(dh_gpair->ConstDeviceSpan(), &this->d_gpair, ctx_->CUDACtx()->Stream());
auto sample = this->sampler->Sample(ctx_, dh::ToSpan(d_gpair), p_fmat);
this->gpair = sample.gpair;
p_fmat = sample.p_fmat; // Update p_fmat before allocating partitioners
p_fmat = sample.p_fmat;
p_fmat->Info().feature_types.SetDevice(ctx_->Device());
std::size_t n_batches = p_fmat->NumBatches();
bool is_concat = (n_batches + 1) != this->batch_ptr_.size();
std::vector<bst_idx_t> batch_ptr{batch_ptr_};

/**
* Initialize the partitioners
*/
bool is_concat = sampler->ConcatPages();
std::size_t n_batches = is_concat ? 1 : p_fmat->NumBatches();
std::vector<bst_idx_t> batch_ptr{this->batch_ptr_};
if (is_concat) {
// Concatenate the batch ptrs as well.
batch_ptr = {static_cast<bst_idx_t>(0), p_fmat->Info().num_row_};
}
// Initialize partitions
if (!partitioners_.empty()) {
CHECK_EQ(partitioners_.size(), n_batches);
}
@ -230,8 +225,20 @@ struct GPUHistMakerDevice {
CHECK_EQ(partitioners_.front()->Size(), p_fmat->Info().num_row_);
}

// Other initializations
quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());
/**
* Initialize the evaluator
*/
this->column_sampler_->Init(ctx_, info.num_col_, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree);
this->interaction_constraints.Reset(ctx_);
this->evaluator_.Reset(this->ctx_, *cuts_, info.feature_types.ConstDeviceSpan(), info.num_col_,
this->param, info.IsColumnSplit());

/**
* Other initializations
*/
this->quantiser = std::make_unique<GradientQuantiser>(ctx_, this->gpair, p_fmat->Info());

this->InitFeatureGroupsOnce(info);
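When the sampler reports concatenated pages, Reset() above treats the sampled data as a single batch and collapses the per-page row offsets to {0, num_row}. A tiny illustrative sketch of that bookkeeping in plain Python (not xgboost API):

def reset_batches(batch_ptr: list[int], num_row: int, concat_pages: bool) -> tuple[int, list[int]]:
    # batch_ptr holds cumulative row offsets, one entry per page boundary.
    if concat_pages:
        return 1, [0, num_row]
    return len(batch_ptr) - 1, list(batch_ptr)

# e.g. three pages of 100 rows each:
print(reset_batches([0, 100, 200, 300], 300, False))  # (3, [0, 100, 200, 300])
print(reset_batches([0, 100, 200, 300], 300, True))   # (1, [0, 300])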
@ -327,8 +334,8 @@ struct GPUHistMakerDevice {

auto d_ridx = partitioners_.at(k)->GetRows(nidx);
this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc,
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
d_node_hist, *quantiser);
feature_groups->DeviceAccessor(ctx_->Device()), this->gpair,
d_ridx, d_node_hist, *quantiser);
monitor.Stop(__func__);
}

@ -678,11 +685,11 @@ struct GPUHistMakerDevice {
constexpr bst_node_t kRootNIdx = RegTree::kRoot;
auto quantiser = *this->quantiser;
auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
dh::tbegin(gpair),
dh::tbegin(this->gpair),
[=] __device__(auto const& gpair) { return quantiser.ToFixedPoint(gpair); });
GradientPairInt64 root_sum_quantised =
dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{},
thrust::plus<GradientPairInt64>{});
dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + this->gpair.size(),
GradientPairInt64{}, thrust::plus<GradientPairInt64>{});
using ReduceT = typename decltype(root_sum_quantised)::ValueT;
auto rc = collective::GlobalSum(
ctx_, p_fmat->Info(), linalg::MakeVec(reinterpret_cast<ReduceT*>(&root_sum_quantised), 2));

@ -539,6 +539,7 @@ class QuantileHistMaker : public TreeUpdater {
// Copy gradient into buffer for sampling. This converts C-order to F-order.
std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out));
}
error::NoPageConcat(this->hist_param_.extmem_concat_pages);
SampleGradient(ctx_, *param, h_sample_out);
auto *h_out_position = &out_position[tree_it - trees.begin()];
if ((*tree_it)->IsMultiTarget()) {

@ -496,7 +496,7 @@ auto MakeExtMemForTest(bst_idx_t n_samples, bst_feature_t n_features, Json dconf

NumpyArrayIterForTest iter_1{0.0f, n_samples, n_features, n_batches};
auto Xy = std::make_shared<data::SparsePageDMatrix>(
&iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, "");
&iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, "", false);
MakeLabelForTest(Xy, p_fmat);
return std::pair{p_fmat, Xy};
}

@ -37,7 +37,8 @@ void TestSparseDMatrixLoadFile(Context const* ctx) {
data::fileiter::Next,
std::numeric_limits<float>::quiet_NaN(),
n_threads,
tmpdir.path + "cache"};
tmpdir.path + "cache",
false};
ASSERT_EQ(AllThreadsForTest(), m.Ctx()->Threads());
ASSERT_EQ(m.Info().num_col_, 5);
ASSERT_EQ(m.Info().num_row_, 64);
@ -364,9 +365,9 @@ auto TestSparsePageDMatrixDeterminism(int32_t threads) {
CreateBigTestData(filename, 1 << 16);

data::FileIterator iter(filename + "?format=libsvm", 0, 1);
std::unique_ptr<DMatrix> sparse{
new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next,
std::numeric_limits<float>::quiet_NaN(), threads, filename}};
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix{
&iter, iter.Proxy(), data::fileiter::Reset, data::fileiter::Next,
std::numeric_limits<float>::quiet_NaN(), threads, filename, false}};
CHECK(sparse->Ctx()->Threads() == threads || sparse->Ctx()->Threads() == AllThreadsForTest());

DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);

@ -81,10 +81,11 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {

auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};

ASSERT_THAT(
[&] {
GradientBasedSampler sampler(&ctx, kRows, param, kSubsample, TrainParam::kUniform, true);
auto sample = sampler.Sample(&ctx, gpair.DeviceSpan(), dmat.get());
auto p_fmat = sample.p_fmat;
ASSERT_EQ(p_fmat, dmat.get());
},
GMockThrow("extmem_concat_pages"));
}

TEST(GradientBasedSampler, UniformSampling) {
@ -120,4 +121,4 @@ TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) {
constexpr bool kFixedSizeSampling = false;
VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling);
}
}; // namespace xgboost::tree
} // namespace xgboost::tree

@ -23,7 +23,7 @@ namespace xgboost::tree {
namespace {
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
RegTree* tree, HostDeviceVector<bst_float>* preds, float subsample,
const std::string& sampling_method, bst_bin_t max_bin) {
const std::string& sampling_method, bst_bin_t max_bin, bool concat_pages) {
Args args{
{"max_depth", "2"},
{"max_bin", std::to_string(max_bin)},
@ -38,13 +38,17 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix

ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> hist_maker{TreeUpdater::Create("grow_gpu_hist", ctx, &task)};
if (subsample < 1.0) {
hist_maker->Configure(Args{{"extmem_concat_pages", std::to_string(concat_pages)}});
} else {
hist_maker->Configure(Args{});
}

std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
{tree});
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
if (subsample < 1.0 && !dmat->SingleColBlock()) {
if (subsample < 1.0 && !dmat->SingleColBlock() && concat_pages) {
ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache));
} else {
ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache));
@ -69,12 +73,12 @@ TEST(GpuHist, UniformSampling) {
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows, false);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample, "uniform",
kRows);
kRows, false);

// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
@ -100,13 +104,13 @@ TEST(GpuHist, GradientBasedSampling) {
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows, false);

// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample,
"gradient_based", kRows);
"gradient_based", kRows, false);

// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
@ -137,11 +141,11 @@ TEST(GpuHist, ExternalMemory) {
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows, true);
// Build another tree using multiple ELLPACK pages.
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows);
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows, true);

// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
@ -181,14 +185,14 @@ TEST(GpuHist, ExternalMemoryWithSampling) {

RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows);
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows, true);

// Build another tree using multiple ELLPACK pages.
common::GlobalRandom() = rng;
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, kSubsample, kSamplingMethod,
kRows);
kRows, true);

Json jtree{Object{}};
Json jtree_ext{Object{}};
@ -228,6 +232,42 @@ TEST(GpuHist, MaxDepth) {
ASSERT_THROW({learner->UpdateOneIter(0, p_mat);}, dmlc::Error);
}

TEST(GpuHist, PageConcatConfig) {
auto ctx = MakeCUDACtx(0);
bst_idx_t n_samples = 64, n_features = 32;
auto p_fmat = RandomDataGenerator{n_samples, n_features, 0}.Batches(2).GenerateSparsePageDMatrix(
"temp", true);

auto learner = std::unique_ptr<Learner>(Learner::Create({p_fmat}));
learner->SetParam("device", ctx.DeviceName());
learner->SetParam("extmem_concat_pages", "true");
learner->SetParam("subsample", "0.8");
learner->Configure();

learner->UpdateOneIter(0, p_fmat);
learner->SetParam("extmem_concat_pages", "false");
learner->Configure();
// GPU Hist rebuilds the updater after configuration. Training continues
learner->UpdateOneIter(1, p_fmat);

learner->SetParam("extmem_concat_pages", "true");
learner->SetParam("subsample", "1.0");
ASSERT_THAT([&] { learner->UpdateOneIter(2, p_fmat); }, GMockThrow("extmem_concat_pages"));

// Throws error on CPU.
{
auto learner = std::unique_ptr<Learner>(Learner::Create({p_fmat}));
learner->SetParam("extmem_concat_pages", "true");
ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_concat_pages"));
}
{
auto learner = std::unique_ptr<Learner>(Learner::Create({p_fmat}));
learner->SetParam("extmem_concat_pages", "true");
learner->SetParam("tree_method", "approx");
ASSERT_THAT([&] { learner->UpdateOneIter(0, p_fmat); }, GMockThrow("extmem_concat_pages"));
}
}

namespace {
RegTree GetHistTree(Context const* ctx, DMatrix* dmat) {
ObjInfo task{ObjInfo::kRegression};

@ -3,6 +3,8 @@ import sys
import pytest
from hypothesis import given, settings, strategies

import xgboost as xgb
from xgboost import testing as tm
from xgboost.testing import no_cupy
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem

@ -72,6 +74,22 @@ def test_extmem_qdm(
check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host)


def test_concat_pages() -> None:
    it = tm.IteratorForTest(*tm.make_batches(64, 16, 4, use_cupy=True), cache=None)
    Xy = xgb.ExtMemQuantileDMatrix(it)
    with pytest.raises(ValueError, match="can not be used with concatenated pages"):
        booster = xgb.train(
            {
                "device": "cuda",
                "subsample": 0.5,
                "sampling_method": "gradient_based",
                "extmem_concat_pages": True,
                "objective": "reg:absoluteerror",
            },
            Xy,
        )


@given(
strategies.integers(1, 64),
strategies.integers(1, 8),

@ -6,24 +6,32 @@ import pytest

from xgboost import testing as tm

sys.path.append("tests/python")
import test_demos as td # noqa
DEMO_DIR = tm.demo_dir(__file__)
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, "guide-python")


@pytest.mark.skipif(**tm.no_cupy())
def test_data_iterator():
script = os.path.join(td.PYTHON_DEMO_DIR, "quantile_data_iterator.py")
script = os.path.join(PYTHON_DEMO_DIR, "quantile_data_iterator.py")
cmd = ["python", script]
subprocess.check_call(cmd)


def test_update_process_demo():
script = os.path.join(td.PYTHON_DEMO_DIR, "update_process.py")
script = os.path.join(PYTHON_DEMO_DIR, "update_process.py")
cmd = ["python", script]
subprocess.check_call(cmd)


def test_categorical_demo():
script = os.path.join(td.PYTHON_DEMO_DIR, "categorical.py")
script = os.path.join(PYTHON_DEMO_DIR, "categorical.py")
cmd = ["python", script]
subprocess.check_call(cmd)


@pytest.mark.skipif(**tm.no_rmm())
@pytest.mark.skipif(**tm.no_cupy())
def test_external_memory_demo():
script = os.path.join(PYTHON_DEMO_DIR, "external_memory.py")
cmd = ["python", script]
subprocess.check_call(cmd)