Export Python Interface for external memory. (#7070)

* Add Python iterator interface.
* Add tests.
* Add demo.
* Add documentation.
* Handle empty dataset.
Jiaming Yuan 2021-07-22 15:15:53 +08:00 committed by GitHub
parent e64ee6592f
commit e6088366df
34 changed files with 961 additions and 200 deletions

View File

@ -95,36 +95,40 @@ jobs:
cd build
cmake .. -DBUILD_STATIC_LIB=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -GNinja
ninja -v install
cd -
- name: Build and run C API demo with static
shell: bash -l {0}
run: |
pushd .
cd demo/c-api/
mkdir build
cd build
cmake .. -GNinja -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
ninja -v
ctest
cd ..
./build/api-demo
rm -rf ./build
cd ../..
popd
- name: Build and install XGBoost shared library
shell: bash -l {0}
run: |
cd build
cmake .. -DBUILD_STATIC_LIB=OFF -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -GNinja
ninja -v install
cd -
- name: Build and run C API demo with shared
shell: bash -l {0}
run: |
pushd .
cd demo/c-api/
mkdir build
cd build
cmake .. -GNinja -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
ninja -v
cd ..
./build/api-demo
cd ../../
./tests/ci_build/verify_link.sh ./demo/c-api/build/api-demo
ctest
popd
./tests/ci_build/verify_link.sh ./demo/c-api/build/basic/api-demo
./tests/ci_build/verify_link.sh ./demo/c-api/build/external-memory/external-memory-demo
lint:
runs-on: ubuntu-latest

View File

@ -92,7 +92,10 @@ endif
mypy:
cd python-package; \
mypy ./xgboost/dask.py && \
mypy ../demo/guide-python/external_memory.py && \
mypy ../tests/python-gpu/test_gpu_with_dask.py && \
mypy ../tests/python/test_data_iterator.py && \
mypy ../tests/python-gpu/test_gpu_data_iterator.py && \
mypy ./xgboost/sklearn.py || exit 1; \
mypy . || true ;

View File

@ -1,14 +1,17 @@
cmake_minimum_required(VERSION 3.13)
project(api-demo LANGUAGES C VERSION 0.0.1)
find_package(xgboost REQUIRED)
project(xgboost-c-examples)
# When xgboost is built as a static library, all CXX dependencies need to be linked
# into the executable.
if (XGBOOST_BUILD_STATIC_LIB)
enable_language(CXX)
# find again for those cxx libraries.
find_package(xgboost REQUIRED)
endif(XGBOOST_BUILD_STATIC_LIB)
add_subdirectory(basic)
add_subdirectory(external-memory)
add_executable(api-demo c-api-demo.c)
target_link_libraries(api-demo PRIVATE xgboost::xgboost)
enable_testing()
add_test(
NAME test_xgboost_demo_c_basic
COMMAND api-demo
WORKING_DIRECTORY ${xgboost-c-examples_BINARY_DIR}
)
add_test(
NAME test_xgboost_demo_c_external_memory
COMMAND external-memory-demo
WORKING_DIRECTORY ${xgboost-c-examples_BINARY_DIR}
)

View File

@ -0,0 +1,13 @@
project(api-demo LANGUAGES C VERSION 0.0.1)
find_package(xgboost REQUIRED)
# When xgboost is built as a static library, all CXX dependencies need to be linked
# into the executable.
if (XGBOOST_BUILD_STATIC_LIB)
enable_language(CXX)
# find again for those cxx libraries.
find_package(xgboost REQUIRED)
endif(XGBOOST_BUILD_STATIC_LIB)
add_executable(api-demo c-api-demo.c)
target_link_libraries(api-demo PRIVATE xgboost::xgboost)

View File

@ -24,8 +24,8 @@ int main(int argc, char** argv) {
// load the data
DMatrixHandle dtrain, dtest;
safe_xgboost(XGDMatrixCreateFromFile("../data/agaricus.txt.train", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../data/agaricus.txt.test", silent, &dtest));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
// create the booster
BoosterHandle booster;

View File

@ -0,0 +1,7 @@
cmake_minimum_required(VERSION 3.13)
project(external-memory-demo LANGUAGES C VERSION 0.0.1)
find_package(xgboost REQUIRED)
add_executable(external-memory-demo external_memory.c)
target_link_libraries(external-memory-demo PRIVATE xgboost::xgboost)

View File

@ -0,0 +1,16 @@
Defining a Custom Data Iterator to Load Data from External Memory
=================================================================
A simple demo for using a custom data iterator with XGBoost. The feature is still
**experimental** and not ready for production use. If you are not familiar with the C
API, please read its introduction in our tutorials and visit the basic demo first.
Defining Data Iterator
----------------------
In the example, we define a custom data iterator with 2 methods: `reset` and `next`. The
`next` method passes data into XGBoost and tells XGBoost whether the iterator has reached
its end, and the `reset` method restarts the iteration. One important detail when using
the C API for the data iterator is that users must keep the data passed into the `next`
method valid in memory until the next iteration or until `reset` is called. The external
memory DMatrix is not limited to training; it is also valid for other features like
prediction.

View File

@ -0,0 +1,179 @@
/*!
* Copyright 2021 XGBoost contributors
*
* \brief A simple example of using xgboost data callback API.
*/
#include <stddef.h>
#include <stdio.h> /* fprintf, printf, sprintf */
#include <stdlib.h>
#include <string.h>
#include <xgboost/c_api.h>
#define safe_xgboost(err) \
if ((err) != 0) { \
fprintf(stderr, "%s:%d: error in %s: %s\n", __FILE__, __LINE__, #err, \
XGBGetLastError()); \
exit(1); \
}
#define N_BATCHES 32
#define BATCH_LEN 512
/* Shorthands. */
typedef DMatrixHandle DMatrix;
typedef BoosterHandle Booster;
typedef struct _DataIter {
/* Data of each batch. */
float **data;
/* Labels of each batch */
float **labels;
/* Length of each batch. */
size_t *lengths;
/* Total number of batches. */
size_t n;
/* Current iteration. */
size_t cur_it;
/* Private fields */
DMatrix _proxy;
char _array[128];
} DataIter;
#define safe_malloc(ptr) \
if ((ptr) == NULL) { \
fprintf(stderr, "%s:%d: Failed to allocate memory.\n", __FILE__, \
__LINE__); \
exit(1); \
}
/**
* Initialize with random data for demo. In practice the data should be loaded
* from external memory. We just demonstrate how to use the iterator in
* XGBoost.
*
* \param batch_size Number of elements for each batch. The demo here is only using 1
* column.
* \param n_batches Number of batches.
*/
void DataIterator_Init(DataIter *self, size_t batch_size, size_t n_batches) {
self->n = n_batches;
self->lengths = (size_t *)malloc(self->n * sizeof(size_t));
safe_malloc(self->lengths);
for (size_t i = 0; i < self->n; ++i) {
self->lengths[i] = batch_size;
}
self->data = (float **)malloc(self->n * sizeof(float *));
safe_malloc(self->data);
self->labels = (float **)malloc(self->n * sizeof(float *));
safe_malloc(self->labels);
/* Generate some random data. */
for (size_t i = 0; i < self->n; ++i) {
self->data[i] = (float *)malloc(self->lengths[i] * sizeof(float));
safe_malloc(self->data[i]);
for (size_t j = 0; j < self->lengths[i]; ++j) {
float x = (float)rand() / (float)(RAND_MAX);
self->data[i][j] = x;
}
self->labels[i] = (float *)malloc(self->lengths[i] * sizeof(float));
safe_malloc(self->labels[i]);
for (size_t j = 0; j < self->lengths[i]; ++j) {
float y = (float)rand() / (float)(RAND_MAX);
self->labels[i][j] = y;
}
}
self->cur_it = 0;
safe_xgboost(XGProxyDMatrixCreate(&self->_proxy));
}
void DataIterator_Free(DataIter *self) {
for (size_t i = 0; i < self->n; ++i) {
free(self->data[i]);
free(self->labels[i]);
}
free(self->data);
free(self->labels);
free(self->lengths);
safe_xgboost(XGDMatrixFree(self->_proxy));
}
int DataIterator_Next(DataIterHandle handle) {
DataIter *self = (DataIter *)(handle);
if (self->cur_it == self->n) {
self->cur_it = 0;
return 0; /* At end */
}
/* A JSON string encoding array interface (standard from numpy). */
char array[] = "{\"data\": [%lu, false], \"shape\":[%lu, 1], \"typestr\": "
"\"<f4\", \"version\": 3}";
memset(self->_array, '\0', sizeof(self->_array));
sprintf(self->_array, array, (size_t)self->data[self->cur_it],
self->lengths[self->cur_it]);
safe_xgboost(XGProxyDMatrixSetDataDense(self->_proxy, self->_array));
/* The data passed in the iterator must remain valid (not being freed until the next
* iteration or reset) */
safe_xgboost(XGDMatrixSetDenseInfo(self->_proxy, "label",
self->labels[self->cur_it],
self->lengths[self->cur_it], 1));
self->cur_it++;
return 1; /* Continue. */
}
void DataIterator_Reset(DataIterHandle handle) {
DataIter *self = (DataIter *)(handle);
self->cur_it = 0;
}
/**
* Train a regression model and save it into a JSON model file.
*/
void TrainModel(DMatrix Xy) {
/* Create booster for training. */
Booster booster;
DMatrix cache[] = {Xy};
safe_xgboost(XGBoosterCreate(cache, 1, &booster));
/* Use approx for external memory training. */
safe_xgboost(XGBoosterSetParam(booster, "tree_method", "approx"));
safe_xgboost(XGBoosterSetParam(booster, "objective", "reg:squarederror"));
/* Start training. */
const char *validation_names[1] = {"train"};
const char *validation_result = NULL;
size_t n_rounds = 10;
for (size_t i = 0; i < n_rounds; ++i) {
safe_xgboost(XGBoosterUpdateOneIter(booster, i, Xy));
safe_xgboost(XGBoosterEvalOneIter(booster, i, cache, validation_names, 1,
&validation_result));
printf("%s\n", validation_result);
}
/* Save the model to a JSON file. */
safe_xgboost(XGBoosterSaveModel(booster, "model.json"));
safe_xgboost(XGBoosterFree(booster));
}
int main() {
DataIter iter;
DataIterator_Init(&iter, BATCH_LEN, N_BATCHES);
/* Create DMatrix from iterator. During training, some cache files with the
* prefix "cache-" will be generated in the current directory. */
char config[] = "{\"missing\": NaN, \"cache_prefix\": \"cache\"}";
DMatrix Xy;
safe_xgboost(XGDMatrixCreateFromCallback(
&iter, iter._proxy, DataIterator_Reset, DataIterator_Next, config, &Xy));
TrainModel(Xy);
safe_xgboost(XGDMatrixFree(Xy));
DataIterator_Free(&iter);
return 0;
}

View File

@ -1,22 +1,92 @@
"""Experimental support for external memory. This is similar to the one in
`quantile_data_iterator.py`, but for external memory instead of Quantile DMatrix. The
feature is not ready for production use yet.
.. versionadded:: 1.5.0
"""
import os
import xgboost as xgb
import xgboost
from typing import Callable, List, Tuple
import tempfile
import numpy as np
### simple example for using external memory version
# this is the only difference, add a # followed by a cache prefix name
# several cache file with the prefix will be generated
# currently only support convert from libsvm file
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train#dtrain.cache'))
dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test#dtest.cache'))
def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Generate random batches."""
X = []
y = []
rng = np.random.RandomState(1994)
for i in range(n_batches):
_X = rng.randn(n_samples_per_batch, n_features)
_y = rng.randn(n_samples_per_batch)
X.append(_X)
y.append(_y)
return X, y
# specify validations set to watch performance
param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'}
# performance notice: set nthread to be the number of your real cpu
# some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4
#param['nthread']=num_real_cpu
class Iterator(xgboost.DataIter):
"""A custom iterator for loading files in batches."""
def __init__(self, file_paths: List[Tuple[str, str]]):
self._file_paths = file_paths
self._it = 0
# XGBoost will generate some cache files under the current directory with the
# prefix "cache"
super().__init__(cache_prefix=os.path.join(".", "cache"))
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)
def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
X_path, y_path = self._file_paths[self._it]
X = np.loadtxt(X_path)
y = np.loadtxt(y_path)
assert X.shape[0] == y.shape[0]
return X, y
def next(self, input_data: Callable) -> int:
"""Advance the iterator by 1 step and pass the data to XGBoost. This function is
called by XGBoost during the construction of ``DMatrix``.
"""
if self._it == len(self._file_paths):
# return 0 to let XGBoost know this is the end of iteration
return 0
# input_data is a function passed in by XGBoost that has a signature similar to
# the ``DMatrix`` constructor.
X, y = self.load_file()
input_data(data=X, label=y)
self._it += 1
return 1
def reset(self) -> None:
"""Reset the iterator to its beginning"""
self._it = 0
def main(tmpdir: str) -> xgboost.Booster:
# generate some random data for demo
batches = make_batches(1024, 17, 31)
files = []
for i, (X, y) in enumerate(zip(*batches)):
X_path = os.path.join(tmpdir, "X-" + str(i) + ".txt")
np.savetxt(X_path, X)
y_path = os.path.join(tmpdir, "y-" + str(i) + ".txt")
np.savetxt(y_path, y)
files.append((X_path, y_path))
it = Iterator(files)
# For non-data arguments, specify them here once instead of passing them through
# the `next` method.
missing = np.NaN
Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False)
# Other tree methods including ``hist`` and ``gpu_hist`` also work, but have some
# caveats. This is still an experimental feature.
booster = xgboost.train({"tree_method": "approx"}, Xy)
return booster
if __name__ == "__main__":
with tempfile.TemporaryDirectory() as tmpdir:
main(tmpdir)

View File

@ -85,7 +85,7 @@ def main():
rounds = 100
it = IterForDMatrixDemo()
# Use iterator, must be `DeviceQuantileDMatrix`
# Use iterator, must be `DeviceQuantileDMatrix` for quantile DMatrix.
m_with_it = xgboost.DeviceQuantileDMatrix(it)
# Use regular DMatrix.

View File

@ -1,6 +1,6 @@
##############################
##############
C API Tutorial
##############################
##############
In this tutorial, we will install the XGBoost library and configure the CMakeLists.txt file of our C/C++ application to link XGBoost with it. Later on, we will look at some useful tips for using the C API, along with code snippets that use various C API functions to perform basic tasks like loading data, training a model, and predicting on a test dataset.
@ -31,7 +31,7 @@ Run the following commands on your terminal. The below commands will install the
# Activate the Conda environment, into which we'll install XGBoost
conda activate [env_name]
# Build the compiled version of XGBoost inside the build folder
cmake .. -DBUILD_STATIC_LIB=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
cmake .. -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX
# install XGBoost in your conda environment (usually under [your home directory]/miniconda3)
make install
@ -188,7 +188,7 @@ Sample examples along with Code snippet to use C API functions
// dmatrix variable will contain the created DMatrix using it
safe_xgboost(XGDMatrixCreateFromMat(data1, 1, 50, 0, &dmatrix));
// here -1 represents the missing value in the matrix dataset
safe_xgboost(XGDMatrixCreateFromMat(data2, ROWS, COLS, -1, &dmatrix2)(;
safe_xgboost(XGDMatrixCreateFromMat(data2, ROWS, COLS, -1, &dmatrix2));
3. Create a Booster object for training & testing on dataset using `XGBoosterCreate <https://xgboost.readthedocs.io/en/stable/dev/c__api_8h.html#ad9fe6f8c8c4901db1c7581a96a21f9ae>`_

View File

@ -1,6 +1,75 @@
#####################################
Using XGBoost External Memory Version
#####################################
XGBoost supports loading data from external memory using its built-in data parser, and
starting from version 1.5, users can also define a custom iterator to load data in
chunks. The feature is still experimental and not yet ready for production use. In this
tutorial we will introduce both methods. Please note that training on data from external
memory is not supported by the ``exact`` tree method.
*************
Data Iterator
*************
Starting from XGBoost 1.5, users can define their own data loader using the Python or C
interface. There are some examples in the ``demo`` directory for a quick start. This is
a generalized version of text-input external memory, where users no longer need to
prepare a text file that XGBoost recognizes. To enable the feature, users need to define
a data iterator with 2 class methods, ``next`` and ``reset``, then pass it into the
``DMatrix`` constructor.
.. code-block:: python
import os
from typing import List, Callable
import xgboost
from sklearn.datasets import load_svmlight_file
class Iterator(xgboost.DataIter):
def __init__(self, svm_file_paths: List[str]):
self._file_paths = svm_file_paths
self._it = 0
# XGBoost will generate some cache files under the current directory with the
# prefix "cache"
super().__init__(cache_prefix=os.path.join(".", "cache"))
def next(self, input_data: Callable):
"""Advance the iterator by 1 step and pass the data to XGBoost. This function is
called by XGBoost during the construction of ``DMatrix``.
"""
if self._it == len(self._file_paths):
# return 0 to let XGBoost know this is the end of iteration
return 0
# input_data is a function passed in by XGBoost that has the exact same signature
# as ``DMatrix``
X, y = load_svmlight_file(self._file_paths[self._it])
input_data(X, y)
self._it += 1
# Return 1 to let XGBoost know we haven't seen all the files yet.
return 1
def reset(self):
"""Reset the iterator to its beginning"""
self._it = 0
it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"])
Xy = xgboost.DMatrix(it)
# Other tree methods including ``hist`` and ``gpu_hist`` also work, but have some
# caveats as noted in the following sections.
booster = xgboost.train({"tree_method": "approx"}, Xy)
The above snippet is a simplified version of ``demo/guide-python/external_memory.py``.
For an example in C, please see ``demo/c-api/external-memory/``.
****************
Text File Inputs
****************
There is no big difference between using the external-memory version and the in-memory
version. The only difference is the filename format.
@ -36,10 +105,11 @@ more notes about text input formats, see :doc:`/tutorials/input_format`.
For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train#dtrain.cache"``.
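In Python, this looks like the following (a minimal sketch; ``agaricus.txt.train`` is the
demo dataset shipped in ``demo/data``, and the part after ``#`` is only a cache-file
prefix, so any writable path works):

.. code-block:: python

    import xgboost as xgb

    # The "#dtrain.cache" suffix asks XGBoost to page the parsed data out to
    # cache files named with the "dtrain.cache" prefix instead of keeping
    # everything in memory.
    dtrain = xgb.DMatrix("../data/agaricus.txt.train#dtrain.cache")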
***********
GPU Version
***********
External memory is fully supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``).
**********************************
GPU Version (GPU Hist tree method)
**********************************
External memory is supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``).
If you are still getting out-of-memory errors after enabling external memory, try subsampling the
data to further reduce GPU memory usage:
@ -52,23 +122,14 @@ data to further reduce GPU memory usage:
'sampling_method': 'gradient_based',
}
For more information, see `this paper <https://arxiv.org/abs/2005.09148>`_.
For more information, see `this paper <https://arxiv.org/abs/2005.09148>`_. Internally
the tree method still concatenates all the chunks into 1 final histogram index for
performance reasons, but in a compressed format. So its scalability has an upper bound,
but it still has a lower memory cost in general.
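Putting the pieces together, a training call might look like the following (a minimal
sketch, assuming the ``Iterator`` class defined in the Data Iterator section above):

.. code-block:: python

    it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"])
    Xy = xgboost.DMatrix(it)
    params = {
        "tree_method": "gpu_hist",
        # Gradient-based sampling keeps only a fraction of the rows, reducing
        # the memory needed for the concatenated histogram index.
        "subsample": 0.2,
        "sampling_method": "gradient_based",
    }
    booster = xgboost.train(params, Xy)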
*******************
Distributed Version
*******************
The external memory mode naturally works on distributed version, you can simply set path like
********
CPU Hist
********
.. code-block:: none
data = "hdfs://path-to-data/#dtrain.cache"
XGBoost will cache the data to the local position. When you run on YARN, the current folder is temporary
so that you can directly use ``dtrain.cache`` to cache to current folder.
***********
Limitations
***********
* The ``hist`` tree method hasn't been tested thoroughly with external memory support (see
`this issue <https://github.com/dmlc/xgboost/issues/4093>`_).
* OSX is not tested.
It's limited by the same factors as GPU Hist, except that gradient-based sampling is not
yet supported on the CPU.
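For example (a minimal sketch, again reusing the ``Iterator`` class from the Data
Iterator section):

.. code-block:: python

    it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"])
    Xy = xgboost.DMatrix(it)
    # ``hist`` also concatenates the chunks into one compressed histogram
    # index, like GPU Hist; ``approx`` works as well and is what the demos use.
    booster = xgboost.train({"tree_method": "hist"}, Xy)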

View File

@ -6,7 +6,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
import os
from .core import DMatrix, DeviceQuantileDMatrix, Booster
from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter
from .training import train, cv
from . import rabit # noqa
from . import tracker # noqa
@ -25,7 +25,7 @@ VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
with open(VERSION_FILE) as f:
__version__ = f.read().strip()
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter',
'train', 'cv',
'RabitTracker',
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',

View File

@ -5,7 +5,7 @@
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping
from typing import List, Optional, Any, Union, Dict
from typing import List, Optional, Any, Union, Dict, TypeVar
# pylint: enable=no-name-in-module,import-error
from typing import Callable, Tuple
import ctypes
@ -313,78 +313,130 @@ def _prediction_output(shape, dims, predts, is_cuda):
return arr_predict
class DataIter:
'''The interface for user defined data iterator. Currently is only supported by Device
DMatrix.
class DataIter: # pylint: disable=too-many-instance-attributes
"""The interface for user defined data iterator.
Parameters
----------
cache_prefix:
Prefix to the cache files, only used in external memory. It can be either an URI
or a file path.
"""
_T = TypeVar("_T")
def __init__(self, cache_prefix: Optional[str] = None) -> None:
self.cache_prefix = cache_prefix
'''
def __init__(self):
self._handle = _ProxyDMatrix()
self.exception = None
self.enable_categorical = False
self._allow_host = False
self._exception: Optional[Exception] = None
self._enable_categorical = False
self._allow_host = True
# Stage data in Python until reset or next is called to avoid data being freed.
self._temporary_data = None
def _get_callbacks(
self, allow_host: bool, enable_categorical: bool
) -> Tuple[Callable, Callable]:
assert hasattr(self, "cache_prefix"), "__init__ is not called."
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
self._reset_wrapper
)
self._next_callback = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.c_void_p,
)(self._next_wrapper)
self._allow_host = allow_host
self._enable_categorical = enable_categorical
return self._reset_callback, self._next_callback
@property
def proxy(self):
'''Handler of DMatrix proxy.'''
def proxy(self) -> "_ProxyDMatrix":
"""Handle of DMatrix proxy."""
return self._handle
def reset_wrapper(self, this): # pylint: disable=unused-argument
'''A wrapper for user defined `reset` function.'''
self.reset()
def _handle_exception(self, fn: Callable, dft_ret: _T) -> _T:
if self._exception is not None:
return dft_ret
def next_wrapper(self, this): # pylint: disable=unused-argument
'''A wrapper for user defined `next` function.
try:
return fn()
except Exception as e: # pylint: disable=broad-except
# Defer the exception in order to return 0 and stop the iteration.
# Exception inside a ctype callback function has no effect except
# for printing to stderr (doesn't stop the execution).
tb = sys.exc_info()[2]
# On dask, the worker is restarted and somehow the information is
# lost.
self._exception = e.with_traceback(tb)
return dft_ret
def _reraise(self) -> None:
self._temporary_data = None
if self._exception is not None:
# pylint 2.7.0 believes `self._exception` can be None even with `assert
# isinstance`
exc = self._exception
self._exception = None
raise exc # pylint: disable=raising-bad-type
def __del__(self) -> None:
assert self._temporary_data is None, self._temporary_data
assert self._exception is None
def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument
"""A wrapper for user defined `reset` function."""
# free the data
self._temporary_data = None
self._handle_exception(self.reset, None)
def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
"""A wrapper for user defined `next` function.
`this` is not used in Python. ctypes can handle `self` of a Python
member function automatically when converting it to a C function
pointer.
'''
if self.exception is not None:
return 0
"""
@_deprecate_positional_args
def data_handle(
data,
feature_names=None,
feature_types=None,
**kwargs
data: Any,
*,
feature_names: Optional[List[str]] = None,
feature_types: Optional[List[str]] = None,
**kwargs: Any,
):
from .data import dispatch_proxy_set_data
from .data import _proxy_transform
data, feature_names, feature_types = _proxy_transform(
data, feature_names, feature_types, self.enable_categorical,
transformed, feature_names, feature_types = _proxy_transform(
data,
feature_names,
feature_types,
self._enable_categorical,
)
dispatch_proxy_set_data(self.proxy, data, self._allow_host)
# Stage the data; meta info is copied inside C++ MetaInfo.
self._temporary_data = transformed
dispatch_proxy_set_data(self.proxy, transformed, self._allow_host)
self.proxy.set_info(
feature_names=feature_names,
feature_types=feature_types,
**kwargs,
)
try:
# Differ the exception in order to return 0 and stop the iteration.
# Exception inside a ctype callback function has no effect except
# for printing to stderr (doesn't stop the execution).
ret = self.next(data_handle) # pylint: disable=not-callable
except Exception as e: # pylint: disable=broad-except
tb = sys.exc_info()[2]
# On dask the worker is restarted and somehow the information is
# lost.
self.exception = e.with_traceback(tb)
return 0
return ret
# pylint: disable=not-callable
return self._handle_exception(lambda: self.next(data_handle), 0)
def reset(self):
'''Reset the data iterator. Prototype for user defined function.'''
def reset(self) -> None:
"""Reset the data iterator. Prototype for user defined function."""
raise NotImplementedError()
def next(self, input_data):
'''Set the next batch of data.
def next(self, input_data: Callable) -> int:
"""Set the next batch of data.
Parameters
----------
data_handle: callable
data_handle:
A function with the same data fields, such as `data` and `label`, as
`xgboost.DMatrix`.
@ -392,7 +444,7 @@ class DataIter:
-------
0 if there's no more batch, otherwise 1.
'''
"""
raise NotImplementedError()
@ -546,7 +598,12 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.handle = None
return
from .data import dispatch_data_backend
from .data import dispatch_data_backend, _is_iter
if _is_iter(data):
self._init_from_iter(data, enable_categorical)
assert self.handle is not None
return
handle, feature_names, feature_types = dispatch_data_backend(
data,
@ -575,6 +632,33 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if feature_types is not None:
self.feature_types = feature_types
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool):
it = iterator
args = {
"missing": self.missing,
"nthread": self.nthread,
"cache_prefix": it.cache_prefix if it.cache_prefix else "",
}
args = from_pystr_to_cstr(json.dumps(args))
handle = ctypes.c_void_p()
# pylint: disable=protected-access
reset_callback, next_callback = it._get_callbacks(
True, enable_categorical
)
ret = _LIB.XGDMatrixCreateFromCallback(
None,
it.proxy.handle,
reset_callback,
next_callback,
args,
ctypes.byref(handle),
)
# pylint: disable=protected-access
it._reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
def __del__(self):
if hasattr(self, "handle") and self.handle:
_check_call(_LIB.XGDMatrixFree(self.handle))
@ -907,7 +991,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col() and self.num_col() != 0:
msg = 'feature_names must have the same length as data'
msg = ("feature_names must have the same length as data, "
f"expected {self.num_col()}, got {len(feature_names)}")
raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []<
if not all(isinstance(f, str) and
@ -1001,30 +1086,44 @@ class _ProxyDMatrix(DMatrix):
inplace_predict).
"""
def __init__(self): # pylint: disable=super-init-not-called
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle)))
def _set_data_from_cuda_interface(self, data):
'''Set data from CUDA array interface.'''
"""Set data from CUDA array interface."""
interface = data.__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
_check_call(
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(
self.handle,
interface_str
)
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
)
def _set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.'''
"""Set data from CUDA columnar format."""
from .data import _cudf_array_interfaces
_, interfaces_str = _cudf_array_interfaces(data)
_check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str))
def _set_data_from_array(self, data: np.ndarray):
"""Set data from numpy array."""
from .data import _array_interface
_check_call(
_LIB.XGProxyDMatrixSetDataCudaColumnar(
self.handle,
interfaces_str
)
_LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data))
)
def _set_data_from_csr(self, csr):
"""Set data from scipy csr"""
from .data import _array_interface
_LIB.XGProxyDMatrixSetDataCSR(
self.handle,
_array_interface(csr.indptr),
_array_interface(csr.indices),
_array_interface(csr.data),
ctypes.c_size_t(csr.shape[1]),
)
@ -1110,13 +1209,14 @@ class DeviceQuantileDMatrix(DMatrix):
else:
it = SingleBatchInternalIter(data=data, **meta)
it.enable_categorical = enable_categorical
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
next_callback = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.c_void_p,
)(it.next_wrapper)
handle = ctypes.c_void_p()
# pylint: disable=protected-access
reset_callback, next_callback = it._get_callbacks(False, enable_categorical)
if it.cache_prefix is not None:
raise ValueError(
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
"in iterator to fix this error."
)
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
None,
it.proxy.handle,
@ -1127,10 +1227,8 @@ class DeviceQuantileDMatrix(DMatrix):
ctypes.c_int(self.max_bin),
ctypes.byref(handle),
)
if it.exception is not None:
# pylint 2.7.0 believes `it.exception` can be None even with `assert
# isinstace`
raise it.exception # pylint: disable=raising-bad-type
# pylint: disable=protected-access
it._reraise()
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
@ -2241,8 +2339,8 @@ class Booster(object):
# pylint: disable=too-many-locals
fmap = os.fspath(os.path.expanduser(fmap))
if not PANDAS_INSTALLED:
raise Exception(('pandas must be available to use this method.'
'Install pandas before calling again.'))
raise ImportError(('pandas must be available to use this method.'
'Install pandas before calling again.'))
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
raise ValueError('This method is not defined for Booster type {}'

View File

@ -5,7 +5,7 @@ import ctypes
import json
import warnings
import os
from typing import Any, Tuple
from typing import Any, Tuple, Callable
import numpy as np
@ -238,10 +238,13 @@ def _transform_pandas_df(data, enable_categorical,
if meta and len(data.columns) > 1:
raise ValueError(
'DataFrame for {meta} cannot have multiple columns'.format(
meta=meta))
meta=meta)
)
dtype = meta_type if meta_type else np.float32
data = np.ascontiguousarray(data.values, dtype=dtype)
data = data.values
if meta_type:
data = data.astype(meta_type)
return data, feature_names, feature_types
@ -759,19 +762,19 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
area for meta info.
'''
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
self.kwargs = kwargs
self.it = 0 # pylint: disable=invalid-name
super().__init__()
def next(self, input_data):
def next(self, input_data: Callable) -> int:
if self.it == 1:
return 0
self.it += 1
input_data(**self.kwargs)
return 1
def reset(self):
def reset(self) -> None:
self.it = 0
@ -785,6 +788,15 @@ def _proxy_transform(data, feature_names, feature_types, enable_categorical):
return data, feature_names, feature_types
if _is_dlpack(data):
return _transform_dlpack(data), feature_names, feature_types
if _is_numpy_array(data):
return data, feature_names, feature_types
if _is_scipy_csr(data):
return data, feature_names, feature_types
if _is_pandas_df(data):
arr, feature_names, feature_types = _transform_pandas_df(
data, enable_categorical, feature_names, feature_types
)
return arr, feature_names, feature_types
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
@ -803,7 +815,16 @@ def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) -
data = _transform_dlpack(data)
proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212
return
# Part of https://github.com/dmlc/xgboost/pull/7070
assert allow_host is False, "host data is not yet supported."
raise TypeError('Value type is not supported for data iterator:' +
str(type(data)))
err = TypeError("Value type is not supported for data iterator:" + str(type(data)))
if not allow_host:
raise err
if _is_numpy_array(data):
proxy._set_data_from_array(data) # pylint: disable=W0212
return
if _is_scipy_csr(data):
proxy._set_data_from_csr(data) # pylint: disable=W0212
return
raise err

View File

@ -874,8 +874,15 @@ SparsePage SparsePage::GetTranspose(int num_columns) const {
tid);
}
});
if (this->data.Empty()) {
transpose.offset.Resize(num_columns + 1);
transpose.offset.Fill(0);
}
CHECK_EQ(transpose.offset.Size(), num_columns + 1);
return transpose;
}
void SparsePage::Push(const SparsePage &batch) {
auto& data_vec = data.HostVector();
auto& offset_vec = offset.HostVector();
@ -1007,6 +1014,7 @@ void SparsePage::PushCSC(const SparsePage &batch) {
auto const& other_offset = batch.offset.ConstHostVector();
if (other_data.empty()) {
self_offset = other_offset;
return;
}
if (!self_data.empty()) {

View File

@ -19,11 +19,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
if (ptr_device >= 0) {
dh::safe_cuda(cudaSetDevice(ptr_device));
}
return ptr_device;
};
auto ptr_device = SetDeviceToPtr(column.data);
if (column.num_rows == 0) {
return;
}
out->SetDevice(ptr_device);
out->Resize(column.num_rows);
@ -123,7 +128,12 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
ArrayInterface array_interface(interface_str);
std::string key{c_key};
array_interface.AsColumnVector();
if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) ||
(array_interface.num_cols == 0 && array_interface.num_rows == 1))) {
// Not an empty column, transform it.
array_interface.AsColumnVector();
}
CHECK(!array_interface.valid.Data())
<< "Meta info " << key << " should be dense, found validity mask";
if (array_interface.num_rows == 0) {

View File

@ -154,7 +154,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
size_t NumRows() const { return num_rows_; }
size_t NumColumns() const { return columns_.size(); }
size_t DeviceIdx() const { return device_idx_; }
int32_t DeviceIdx() const { return device_idx_; }
private:
CudfAdapterBatch batch_;
@ -202,12 +202,12 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
size_t NumRows() const { return array_interface_.num_rows; }
size_t NumColumns() const { return array_interface_.num_cols; }
size_t DeviceIdx() const { return device_idx_; }
int32_t DeviceIdx() const { return device_idx_; }
private:
ArrayInterface array_interface_;
CupyAdapterBatch batch_;
int device_idx_;
int32_t device_idx_ {-1};
};
// Returns maximum row length

View File

@ -10,6 +10,7 @@
namespace xgboost {
namespace data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(param_.gpu_id));
if (!this->ReadCache()) {
auto const &csr = source_->Page();
this->page_.reset(new EllpackPage{});

View File

@ -14,6 +14,9 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
device_ = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (device_ < 0) {
CHECK_EQ(this->Info().num_row_, 0);
}
}
void DMatrixProxy::FromCudaArray(std::string interface_str) {
@ -22,6 +25,9 @@ void DMatrixProxy::FromCudaArray(std::string interface_str) {
device_ = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (device_ < 0) {
CHECK_EQ(this->Info().num_row_, 0);
}
}
} // namespace data

View File

@ -141,9 +141,8 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
} else {
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
}
auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
proxy->Adapter())->Value();
return fn(value);
return std::result_of_t<Fn(
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
}
}
} // namespace data

View File

@ -16,7 +16,10 @@ namespace data {
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
auto device =
adapter->DeviceIdx() < 0 ? dh::CurrentDevice() : adapter->DeviceIdx();
CHECK_GE(device, 0);
dh::safe_cuda(cudaSetDevice(device));
CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
@ -27,8 +30,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Enforce single batch
CHECK(!adapter->Next());
info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), adapter->DeviceIdx(),
missing, sparse_page_.get());
info_.num_nonzero_ =
CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get());
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns

View File

@ -15,6 +15,29 @@ MetaInfo &SparsePageDMatrix::Info() { return info_; }
const MetaInfo &SparsePageDMatrix::Info() const { return info_; }
namespace detail {
// Use device dispatch
size_t NSamplesDevice(DMatrixProxy *proxy)
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
return 0;
}
#endif
size_t NFeaturesDevice(DMatrixProxy *proxy)
#if defined(XGBOOST_USE_CUDA)
; // NOLINT
#else
{
common::AssertGPUSupport();
return 0;
}
#endif
} // namespace detail
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
@ -35,13 +58,24 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
size_t nnz = 0;
auto num_rows = [&]() {
return HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumRows(); });
bool type_error {false};
size_t n_samples = HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumRows(); }, &type_error);
if (type_error) {
n_samples = detail::NSamplesDevice(proxy);
}
return n_samples;
};
auto num_cols = [&]() {
return HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumCols(); });
bool type_error {false};
size_t n_features = HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumCols(); }, &type_error);
if (type_error) {
n_features = detail::NFeaturesDevice(proxy);
}
return n_features;
};
// the proxy is iterated together with the sparse page source so we can obtain all
// information in 1 pass.
for (auto const &page : this->GetRowBatchesImpl()) {

View File

@ -7,8 +7,24 @@
namespace xgboost {
namespace data {
namespace detail {
size_t NSamplesDevice(DMatrixProxy *proxy) {
return Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
}
size_t NFeaturesDevice(DMatrixProxy *proxy) {
return Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
}
} // namespace detail
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) {
auto device = proxy->DeviceIdx();
if (device < 0) {
device = dh::CurrentDevice();
}
CHECK_GE(device, 0);
Dispatch(proxy, [&](auto const &value) {
CopyToSparsePage(value, device, missing, page);
});

View File

@ -236,7 +236,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
iter_{iter}, proxy_{proxy} {
if (!cache_info_->written) {
iter_.Reset();
iter_.Next();
CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch.";
}
this->Fetch();
}

View File

@ -444,7 +444,7 @@ class ColMaker: public TreeUpdater {
}
// update the solution candidate
virtual void UpdateSolution(const SparsePage &batch,
virtual void UpdateSolution(const SortedCSCPage &batch,
const std::vector<bst_feature_t> &feat_set,
const std::vector<GradientPair> &gpair,
DMatrix*) {

View File

@ -77,12 +77,23 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
for (size_t i = 0; i < iterators.size(); ++i) {
ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector());
if (i != iterators.size() - 1) {
ASSERT_EQ(iterators[i].use_count(), 1);
} else {
// The last batch is still being held by sparse page DMatrix.
ASSERT_EQ(iterators[i].use_count(), 2);
}
}
// Make sure it's const and the caller cannot modify the content of the page.
for (auto& page : m->GetBatches<EllpackPage>({0, 32})) {
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value, "");
}
// The above iteration clears out all references inside DMatrix.
for (auto const& ptr : iterators) {
ASSERT_TRUE(ptr.unique());
}
}
TEST(SparsePageDMatrix, EllpackPageContent) {

View File

@ -0,0 +1,32 @@
import numpy as np
import xgboost as xgb
from hypothesis import given, strategies, settings
import pytest
import sys
sys.path.append("tests/python")
from test_data_iterator import SingleBatch, make_batches
from test_data_iterator import test_single_batch as cpu_single_batch
from test_data_iterator import run_data_iterator
from testing import IteratorForTest, no_cupy
def test_gpu_single_batch() -> None:
cpu_single_batch("gpu_hist")
@pytest.mark.skipif(**no_cupy())
@given(
strategies.integers(0, 1024), strategies.integers(1, 7), strategies.integers(0, 13)
)
@settings(deadline=None)
def test_gpu_data_iterator(
n_samples_per_batch: int, n_features: int, n_batches: int
) -> None:
run_data_iterator(n_samples_per_batch, n_features, n_batches, "gpu_hist", True)
run_data_iterator(n_samples_per_batch, n_features, n_batches, "gpu_hist", False)
def test_cpu_data_iterator() -> None:
"""Make sure CPU algorithm can handle GPU inputs"""
run_data_iterator(1024, 2, 3, "approx", True)

View File

@ -9,7 +9,7 @@ import test_demos as td # noqa
@pytest.mark.skipif(**tm.no_cupy())
def test_data_iterator():
script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py')
script = os.path.join(td.PYTHON_DEMO_DIR, 'quantile_data_iterator.py')
cmd = ['python', script]
subprocess.check_call(cmd)

View File

@ -112,7 +112,6 @@ class TestGPUUpdaters:
tm.dataset_strategy)
@settings(deadline=None)
def test_external_memory(self, param, num_rounds, dataset):
pytest.xfail(reason='TestGPUUpdaters::test_external_memory is flaky')
# We cannot handle empty dataset yet
assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist'

View File

@ -0,0 +1,135 @@
import xgboost as xgb
from xgboost.data import SingleBatchInternalIter as SingleBatch
import numpy as np
from testing import IteratorForTest
from typing import Tuple, List
import pytest
from hypothesis import given, strategies, settings
from scipy.sparse import csr_matrix
def make_batches(
n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
X = []
y = []
if use_cupy:
import cupy
rng = cupy.random.RandomState(1994)
else:
rng = np.random.RandomState(1994)
for i in range(n_batches):
_X = rng.randn(n_samples_per_batch, n_features)
_y = rng.randn(n_samples_per_batch)
X.append(_X)
y.append(_y)
return X, y
def test_single_batch(tree_method: str = "approx") -> None:
from sklearn.datasets import load_breast_cancer
n_rounds = 10
X, y = load_breast_cancer(return_X_y=True)
X = X.astype(np.float32)
y = y.astype(np.float32)
Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
Xy = xgb.DMatrix(X, y)
from_dmat = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
assert from_it.get_dump() == from_dmat.get_dump()
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X = X.astype(np.float32)
Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
from_pd = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
# remove feature info to generate exact same text representation.
from_pd.feature_names = None
from_pd.feature_types = None
assert from_pd.get_dump() == from_it.get_dump()
X, y = load_breast_cancer(return_X_y=True)
X = csr_matrix(X)
Xy = xgb.DMatrix(SingleBatch(data=X, label=y))
from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
X, y = load_breast_cancer(return_X_y=True)
Xy = xgb.DMatrix(SingleBatch(data=X, label=y), missing=0.0)
from_np = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds)
assert from_np.get_dump() == from_it.get_dump()
def run_data_iterator(
n_samples_per_batch: int,
n_features: int,
n_batches: int,
tree_method: str,
use_cupy: bool,
) -> None:
n_rounds = 2
it = IteratorForTest(
*make_batches(n_samples_per_batch, n_features, n_batches, use_cupy)
)
if n_batches == 0:
with pytest.raises(ValueError, match="1 batch"):
Xy = xgb.DMatrix(it)
return
Xy = xgb.DMatrix(it)
assert Xy.num_row() == n_samples_per_batch * n_batches
assert Xy.num_col() == n_features
results_from_it: xgb.callback.EvaluationMonitor.EvalsLog = {}
from_it = xgb.train(
{"tree_method": tree_method, "max_depth": 2},
Xy,
num_boost_round=n_rounds,
evals=[(Xy, "Train")],
evals_result=results_from_it,
verbose_eval=False,
)
it_predt = from_it.predict(Xy)
X, y = it.as_arrays()
Xy = xgb.DMatrix(X, y)
assert Xy.num_row() == n_samples_per_batch * n_batches
assert Xy.num_col() == n_features
results_from_arrays: xgb.callback.EvaluationMonitor.EvalsLog = {}
from_arrays = xgb.train(
{"tree_method": tree_method, "max_depth": 2},
Xy,
num_boost_round=n_rounds,
evals=[(Xy, "Train")],
evals_result=results_from_arrays,
verbose_eval=False,
)
arr_predt = from_arrays.predict(Xy)
if tree_method != "gpu_hist":
rtol = 1e-1 # flaky
else:
np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-3)
rtol = 1e-6
np.testing.assert_allclose(
results_from_it["Train"]["rmse"],
results_from_arrays["Train"]["rmse"],
rtol=rtol,
)
@given(
strategies.integers(0, 1024), strategies.integers(1, 7), strategies.integers(0, 13)
)
@settings(deadline=None)
def test_data_iterator(
n_samples_per_batch: int, n_features: int, n_batches: int
) -> None:
run_data_iterator(n_samples_per_batch, n_features, n_batches, "approx", False)
run_data_iterator(n_samples_per_batch, n_features, n_batches, "hist", False)

View File

@ -8,7 +8,7 @@ from io import StringIO
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
from xgboost.compat import DASK_INSTALLED
import pytest
import tempfile
import gc
import xgboost as xgb
import numpy as np
import platform
@ -143,10 +143,35 @@ def skip_s390x():
return {"condition": condition, "reason": reason}
class IteratorForTest(xgb.core.DataIter):
def __init__(self, X, y):
assert len(X) == len(y)
self.X = X
self.y = y
self.it = 0
super().__init__("./")
def next(self, input_data):
if self.it == len(self.X):
return 0
# Use copy to make sure the iterator doesn't hold a reference to the data.
input_data(data=self.X[self.it].copy(), label=self.y[self.it].copy())
gc.collect()  # clear up the copy; see if XGBoost accesses freed memory.
self.it += 1
return 1
def reset(self):
self.it = 0
def as_arrays(self):
X = np.concatenate(self.X, axis=0)
y = np.concatenate(self.y, axis=0)
return X, y
# Contains a dataset in numpy format as well as the relevant objective and metric
class TestDataset:
def __init__(self, name, get_dataset, objective, metric
):
def __init__(self, name, get_dataset, objective, metric):
self.name = name
self.objective = objective
self.metric = metric
@ -171,16 +196,23 @@ class TestDataset:
return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)
def get_external_dmat(self):
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'tmptmp_1234.csv')
np.savetxt(path,
np.hstack((self.y.reshape(len(self.y), 1), self.X)),
delimiter=',')
assert os.path.exists(path)
uri = path + '?format=csv&label_column=0#tmptmp_'
# The uri looks like:
# 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_'
return xgb.DMatrix(uri, weight=self.w, base_margin=self.margin)
n_samples = self.X.shape[0]
n_batches = 10
per_batch = n_samples // n_batches + 1
predictor = []
response = []
for i in range(n_batches):
beg = i * per_batch
end = min((i + 1) * per_batch, n_samples)
assert end != beg
X = self.X[beg: end, ...]
y = self.y[beg: end]
predictor.append(X)
response.append(y)
it = IteratorForTest(predictor, response)
return xgb.DMatrix(it)
def __repr__(self):
return self.name