diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b108ed263..819e8a9a6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -95,36 +95,40 @@ jobs: cd build cmake .. -DBUILD_STATIC_LIB=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -GNinja ninja -v install + cd - - name: Build and run C API demo with static shell: bash -l {0} run: | + pushd . cd demo/c-api/ mkdir build cd build cmake .. -GNinja -DCMAKE_PREFIX_PATH=$CONDA_PREFIX ninja -v + ctest cd .. - ./build/api-demo rm -rf ./build - cd ../.. + popd - name: Build and install XGBoost shared library shell: bash -l {0} run: | cd build cmake .. -DBUILD_STATIC_LIB=OFF -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -GNinja ninja -v install + cd - - name: Build and run C API demo with shared shell: bash -l {0} run: | + pushd . cd demo/c-api/ mkdir build cd build cmake .. -GNinja -DCMAKE_PREFIX_PATH=$CONDA_PREFIX ninja -v - cd .. - ./build/api-demo - cd ../../ - ./tests/ci_build/verify_link.sh ./demo/c-api/build/api-demo + ctest + popd + ./tests/ci_build/verify_link.sh ./demo/c-api/build/basic/api-demo + ./tests/ci_build/verify_link.sh ./demo/c-api/build/external-memory/external-memory-demo lint: runs-on: ubuntu-latest diff --git a/Makefile b/Makefile index 85adfa2c7..09b137d92 100644 --- a/Makefile +++ b/Makefile @@ -92,7 +92,10 @@ endif mypy: cd python-package; \ mypy ./xgboost/dask.py && \ + mypy ../demo/guide-python/external_memory.py && \ mypy ../tests/python-gpu/test_gpu_with_dask.py && \ + mypy ../tests/python/test_data_iterator.py && \ + mypy ../tests/python-gpu/test_gpu_data_iterator.py && \ mypy ./xgboost/sklearn.py || exit 1; \ mypy . 
|| true ; diff --git a/demo/c-api/CMakeLists.txt b/demo/c-api/CMakeLists.txt index 2f7d5bbe1..852600155 100644 --- a/demo/c-api/CMakeLists.txt +++ b/demo/c-api/CMakeLists.txt @@ -1,14 +1,17 @@ cmake_minimum_required(VERSION 3.13) -project(api-demo LANGUAGES C VERSION 0.0.1) -find_package(xgboost REQUIRED) +project(xgboost-c-examples) -# xgboost is built as static libraries, all cxx dependencies need to be linked into the -# executable. -if (XGBOOST_BUILD_STATIC_LIB) - enable_language(CXX) - # find again for those cxx libraries. - find_package(xgboost REQUIRED) -endif(XGBOOST_BUILD_STATIC_LIB) +add_subdirectory(basic) +add_subdirectory(external-memory) -add_executable(api-demo c-api-demo.c) -target_link_libraries(api-demo PRIVATE xgboost::xgboost) +enable_testing() +add_test( + NAME test_xgboost_demo_c_basic + COMMAND api-demo + WORKING_DIRECTORY ${xgboost-c-examples_BINARY_DIR} +) +add_test( + NAME test_xgboost_demo_c_external_memory + COMMAND external-memory-demo + WORKING_DIRECTORY ${xgboost-c-examples_BINARY_DIR} +) diff --git a/demo/c-api/basic/CMakeLists.txt b/demo/c-api/basic/CMakeLists.txt new file mode 100644 index 000000000..32e2bc432 --- /dev/null +++ b/demo/c-api/basic/CMakeLists.txt @@ -0,0 +1,13 @@ +project(api-demo LANGUAGES C VERSION 0.0.1) +find_package(xgboost REQUIRED) + +# xgboost is built as static libraries, all cxx dependencies need to be linked into the +# executable. +if (XGBOOST_BUILD_STATIC_LIB) + enable_language(CXX) + # find again for those cxx libraries. 
+ find_package(xgboost REQUIRED) +endif(XGBOOST_BUILD_STATIC_LIB) + +add_executable(api-demo c-api-demo.c) +target_link_libraries(api-demo PRIVATE xgboost::xgboost) diff --git a/demo/c-api/Makefile b/demo/c-api/basic/Makefile similarity index 100% rename from demo/c-api/Makefile rename to demo/c-api/basic/Makefile diff --git a/demo/c-api/README.md b/demo/c-api/basic/README.md similarity index 100% rename from demo/c-api/README.md rename to demo/c-api/basic/README.md diff --git a/demo/c-api/c-api-demo.c b/demo/c-api/basic/c-api-demo.c similarity index 97% rename from demo/c-api/c-api-demo.c rename to demo/c-api/basic/c-api-demo.c index d3517da44..1c3d58de9 100644 --- a/demo/c-api/c-api-demo.c +++ b/demo/c-api/basic/c-api-demo.c @@ -24,8 +24,8 @@ int main(int argc, char** argv) { // load the data DMatrixHandle dtrain, dtest; - safe_xgboost(XGDMatrixCreateFromFile("../data/agaricus.txt.train", silent, &dtrain)); - safe_xgboost(XGDMatrixCreateFromFile("../data/agaricus.txt.test", silent, &dtest)); + safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain)); + safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest)); // create the booster BoosterHandle booster; diff --git a/demo/c-api/external-memory/CMakeLists.txt b/demo/c-api/external-memory/CMakeLists.txt new file mode 100644 index 000000000..0c21acb3c --- /dev/null +++ b/demo/c-api/external-memory/CMakeLists.txt @@ -0,0 +1,7 @@ +cmake_minimum_required(VERSION 3.13) +project(external-memory-demo LANGUAGES C VERSION 0.0.1) + +find_package(xgboost REQUIRED) + +add_executable(external-memory-demo external_memory.c) +target_link_libraries(external-memory-demo PRIVATE xgboost::xgboost) diff --git a/demo/c-api/external-memory/README.md b/demo/c-api/external-memory/README.md new file mode 100644 index 000000000..e578b535b --- /dev/null +++ b/demo/c-api/external-memory/README.md @@ -0,0 +1,16 @@ +Defining a Custom Data Iterator to Load Data from External Memory 
+================================================================= + +A simple demo for using custom data iterator with XGBoost. The feature is still +**experimental** and not ready for production use. If you are not familiar with C API, +please read its introduction in our tutorials and visit the basic demo first. + +Defining Data Iterator +---------------------- + +In the example, we define a custom data iterator with 2 methods: `reset` and `next`. The +`next` method passes data into XGBoost and tells XGBoost whether the iterator has reached +its end, and the `reset` method resets iterations. One important detail when using the C +API for data iterator is users need to make sure that the data passed into `next` method +must be kept in memory until the next iteration or `reset` is called. The external memory +DMatrix is not limited to training, but also valid for other features like prediction. \ No newline at end of file diff --git a/demo/c-api/external-memory/external_memory.c b/demo/c-api/external-memory/external_memory.c new file mode 100644 index 000000000..acafb51f8 --- /dev/null +++ b/demo/c-api/external-memory/external_memory.c @@ -0,0 +1,179 @@ +/*! + * Copyright 2021 XGBoost contributors + * + * \brief A simple example of using xgboost data callback API. + */ + +#include +#include +#include +#include + +#define safe_xgboost(err) \ + if ((err) != 0) { \ + fprintf(stderr, "%s:%d: error in %s: %s\n", __FILE__, __LINE__, #err, \ + XGBGetLastError()); \ + exit(1); \ + } + +#define N_BATCHS 32 +#define BATCH_LEN 512 + +/* Shorthands. */ +typedef DMatrixHandle DMatrix; +typedef BoosterHandle Booster; + +typedef struct _DataIter { + /* Data of each batch. */ + float **data; + /* Labels of each batch */ + float **labels; + /* Length of each batch. */ + size_t *lengths; + /* Total number of batches. */ + size_t n; + /* Current iteration. 
*/ + size_t cur_it; + + /* Private fields */ + DMatrix _proxy; + char _array[128]; +} DataIter; + +#define safe_malloc(ptr) \ + if ((ptr) == NULL) { \ + fprintf(stderr, "%s:%d: Failed to allocate memory.\n", __FILE__, \ + __LINE__); \ + exit(1); \ + } + +/** + * Initialize with random data for demo. In practice the data should be loaded + * from external memory. We just demonstrate how to use the iterator in + * XGBoost. + * + * \param batch_size Number of elements for each batch. The demo here is only using 1 + * column. + * \param n_batches Number of batches. + */ +void DataIterator_Init(DataIter *self, size_t batch_size, size_t n_batches) { + self->n = n_batches; + + self->lengths = (size_t *)malloc(self->n * sizeof(size_t)); + safe_malloc(self->lengths); + for (size_t i = 0; i < self->n; ++i) { + self->lengths[i] = batch_size; + } + + self->data = (float **)malloc(self->n * sizeof(float *)); + safe_malloc(self->data); + self->labels = (float **)malloc(self->n * sizeof(float *)); + safe_malloc(self->labels); + + /* Generate some random data. 
*/ + for (size_t i = 0; i < self->n; ++i) { + self->data[i] = (float *)malloc(self->lengths[i] * sizeof(float)); + safe_malloc(self->data[i]); + for (size_t j = 0; j < self->lengths[i]; ++j) { + float x = (float)rand() / (float)(RAND_MAX); + self->data[i][j] = x; + } + + self->labels[i] = (float *)malloc(self->lengths[i] * sizeof(float)); + safe_malloc(self->labels[i]); + for (size_t j = 0; j < self->lengths[i]; ++j) { + float y = (float)rand() / (float)(RAND_MAX); + self->labels[i][j] = y; + } + } + + self->cur_it = 0; + safe_xgboost(XGProxyDMatrixCreate(&self->_proxy)); +} + +void DataIterator_Free(DataIter *self) { + for (size_t i = 0; i < self->n; ++i) { + free(self->data[i]); + free(self->labels[i]); + } + free(self->data); + free(self->lengths); + safe_xgboost(XGDMatrixFree(self->_proxy)); +}; + +int DataIterator_Next(DataIterHandle handle) { + DataIter *self = (DataIter *)(handle); + if (self->cur_it == self->n) { + self->cur_it = 0; + return 0; /* At end */ + } + + /* A JSON string encoding array interface (standard from numpy). */ + char array[] = "{\"data\": [%lu, false], \"shape\":[%lu, 1], \"typestr\": " + "\"_array, '\0', sizeof(self->_array)); + sprintf(self->_array, array, (size_t)self->data[self->cur_it], + self->lengths[self->cur_it]); + + safe_xgboost(XGProxyDMatrixSetDataDense(self->_proxy, self->_array)); + /* The data passed in the iterator must remain valid (not being freed until the next + * iteration or reset) */ + safe_xgboost(XGDMatrixSetDenseInfo(self->_proxy, "label", + self->labels[self->cur_it], + self->lengths[self->cur_it], 1)); + self->cur_it++; + return 1; /* Continue. */ +} + +void DataIterator_Reset(DataIterHandle handle) { + DataIter *self = (DataIter *)(handle); + self->cur_it = 0; +} + +/** + * Train a regression model and save it into JSON model file. + */ +void TrainModel(DMatrix Xy) { + /* Create booster for training. 
*/ + Booster booster; + DMatrix cache[] = {Xy}; + safe_xgboost(XGBoosterCreate(cache, 1, &booster)); + /* Use approx for external memory training. */ + safe_xgboost(XGBoosterSetParam(booster, "tree_method", "approx")); + safe_xgboost(XGBoosterSetParam(booster, "objective", "reg:squarederror")); + + /* Start training. */ + const char *validation_names[1] = {"train"}; + const char *validation_result = NULL; + size_t n_rounds = 10; + for (size_t i = 0; i < n_rounds; ++i) { + safe_xgboost(XGBoosterUpdateOneIter(booster, i, Xy)); + safe_xgboost(XGBoosterEvalOneIter(booster, i, cache, validation_names, 1, + &validation_result)); + printf("%s\n", validation_result); + } + + /* Save the model to a JSON file. */ + safe_xgboost(XGBoosterSaveModel(booster, "model.json")); + + safe_xgboost(XGBoosterFree(booster)); +} + +int main() { + DataIter iter; + DataIterator_Init(&iter, BATCH_LEN, N_BATCHS); + + /* Create DMatrix from iterator. During training, some cache files with the + * prefix "cache-" will be generated in current directory */ + char config[] = "{\"missing\": NaN, \"cache_prefix\": \"cache\"}"; + DMatrix Xy; + safe_xgboost(XGDMatrixCreateFromCallback( + &iter, iter._proxy, DataIterator_Reset, DataIterator_Next, config, &Xy)); + + TrainModel(Xy); + + safe_xgboost(XGDMatrixFree(Xy)); + + DataIterator_Free(&iter); + return 0; +} diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py index 3f2e6a51d..7bca5db03 100644 --- a/demo/guide-python/external_memory.py +++ b/demo/guide-python/external_memory.py @@ -1,22 +1,92 @@ +"""Experimental support for external memory. This is similar to the one in +`quantile_data_iterator.py`, but for external memory instead of Quantile DMatrix. The +feature is not ready for production use yet. + + .. 
versionadded:: 1.5.0 + +""" import os -import xgboost as xgb +import xgboost +from typing import Callable, List, Tuple +import tempfile +import numpy as np -### simple example for using external memory version -# this is the only difference, add a # followed by a cache prefix name -# several cache file with the prefix will be generated -# currently only support convert from libsvm file -CURRENT_DIR = os.path.dirname(__file__) -dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train#dtrain.cache')) -dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test#dtest.cache')) +def make_batches( + n_samples_per_batch: int, n_features: int, n_batches: int +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Generate random batches.""" + X = [] + y = [] + rng = np.random.RandomState(1994) + for i in range(n_batches): + _X = rng.randn(n_samples_per_batch, n_features) + _y = rng.randn(n_samples_per_batch) + X.append(_X) + y.append(_y) + return X, y -# specify validations set to watch performance -param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'} -# performance notice: set nthread to be the number of your real cpu -# some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4 -#param['nthread']=num_real_cpu +class Iterator(xgboost.DataIter): + """A custom iterator for loading files in batches.""" + def __init__(self, file_paths: List[Tuple[str, str]]): + self._file_paths = file_paths + self._it = 0 + # XGBoost will generate some cache files under current directory with the prefix + # "cache" + super().__init__(cache_prefix=os.path.join(".", "cache")) -watchlist = [(dtest, 'eval'), (dtrain, 'train')] -num_round = 2 -bst = xgb.train(param, dtrain, num_round, watchlist) + def load_file(self) -> Tuple[np.ndarray, np.ndarray]: + X_path, y_path = self._file_paths[self._it] + X = np.loadtxt(X_path) + y = np.loadtxt(y_path) + assert X.shape[0] == y.shape[0] + return X, y + + def next(self, 
input_data: Callable) -> int: + """Advance the iterator by 1 step and pass the data to XGBoost. This function is + called by XGBoost during the construction of ``DMatrix`` + + """ + if self._it == len(self._file_paths): + # return 0 to let XGBoost know this is the end of iteration + return 0 + + # input_data is a function passed in by XGBoost who has the similar signature to + # the ``DMatrix`` constructor. + X, y = self.load_file() + input_data(data=X, label=y) + self._it += 1 + return 1 + + def reset(self) -> None: + """Reset the iterator to its beginning""" + self._it = 0 + + +def main(tmpdir: str) -> xgboost.Booster: + # generate some random data for demo + batches = make_batches(1024, 17, 31) + files = [] + for i, (X, y) in enumerate(zip(*batches)): + X_path = os.path.join(tmpdir, "X-" + str(i) + ".txt") + np.savetxt(X_path, X) + y_path = os.path.join(tmpdir, "y-" + str(i) + ".txt") + np.savetxt(y_path, y) + files.append((X_path, y_path)) + + it = Iterator(files) + # For non-data arguments, specify it here once instead of passing them by the `next` + # method. + missing = np.NaN + Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False) + + # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some + # caveats. This is still an experimental feature. 
+ booster = xgboost.train({"tree_method": "approx"}, Xy) + return booster + + +if __name__ == "__main__": + with tempfile.TemporaryDirectory() as tmpdir: + main(tmpdir) diff --git a/demo/guide-python/data_iterator.py b/demo/guide-python/quantile_data_iterator.py similarity index 97% rename from demo/guide-python/data_iterator.py rename to demo/guide-python/quantile_data_iterator.py index dc910c606..97cbf388f 100644 --- a/demo/guide-python/data_iterator.py +++ b/demo/guide-python/quantile_data_iterator.py @@ -85,7 +85,7 @@ def main(): rounds = 100 it = IterForDMatrixDemo() - # Use iterator, must be `DeviceQuantileDMatrix` + # Use iterator, must be `DeviceQuantileDMatrix` for quantile DMatrix. m_with_it = xgboost.DeviceQuantileDMatrix(it) # Use regular DMatrix. diff --git a/doc/tutorials/c_api_tutorial.rst b/doc/tutorials/c_api_tutorial.rst index b7d012e5e..fc7664c6d 100644 --- a/doc/tutorials/c_api_tutorial.rst +++ b/doc/tutorials/c_api_tutorial.rst @@ -1,8 +1,8 @@ -############################## -C API Tutorial -############################## +############## +C API Tutorial +############## -In this tutorial, we are going to install XGBoost library & configure the CMakeLists.txt file of our C/C++ application to link XGBoost library with our application. Later on, we will see some useful tips for using C API and code snippets as examples to use various functions available in C API to perform basic task like loading, training model & predicting on test dataset. +In this tutorial, we are going to install XGBoost library & configure the CMakeLists.txt file of our C/C++ application to link XGBoost library with our application. Later on, we will see some useful tips for using C API and code snippets as examples to use various functions available in C API to perform basic task like loading, training model & predicting on test dataset. .. 
contents:: :backlinks: none @@ -12,7 +12,7 @@ In this tutorial, we are going to install XGBoost library & configure the CMakeL Requirements ************ -Install CMake - Follow the `cmake installation documentation `_ for instructions. +Install CMake - Follow the `cmake installation documentation `_ for instructions. Install Conda - Follow the `conda installation documentation `_ for instructions ************************************* @@ -31,18 +31,18 @@ Run the following commands on your terminal. The below commands will install the # Activate the Conda environment, into which we'll install XGBoost conda activate [env_name] # Build the compiled version of XGBoost inside the build folder - cmake .. -DBUILD_STATIC_LIB=ON -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX + cmake .. -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX # install XGBoost in your conda environment (usually under [your home directory]/miniconda3) make install ********************************************************************* -Configure CMakeList.txt file of your application to link with XGBoost +Configure CMakeList.txt file of your application to link with XGBoost ********************************************************************* Here, we assume that your C++ application is using CMake for builds. Use ``find_package()`` and ``target_link_libraries()`` in your application's CMakeList.txt to link with the XGBoost library: - + .. code-block:: cmake cmake_minimum_required(VERSION 3.13) @@ -79,8 +79,8 @@ a. In a C application: Use the following macro to guard all calls to XGBoost's C .. code-block:: c - #define safe_xgboost(call) { \ - int err = (call); \ + #define safe_xgboost(call) { \ + int err = (call); \ if (err != 0) { \ fprintf(stderr, "%s:%d: error in %s: %s\n", __FILE__, __LINE__, #call, XGBGetLastError()); \ exit(1); \ @@ -101,8 +101,8 @@ b. In a C++ application: modify the macro ``safe_xgboost`` to throw an exception .. 
code-block:: cpp - #define safe_xgboost(call) { \ - int err = (call); \ + #define safe_xgboost(call) { \ + int err = (call); \ if (err != 0) { \ throw new Exception(std::string(__FILE__) + ":" + std::to_string(__LINE__) + \ ": error in " + #call + ":" + XGBGetLastError())); \ @@ -125,29 +125,29 @@ c. Assertion technique: It works both in C/ C++. If expression evaluates to 0 (f #include #include #include - + int main(int argc, char** argv) { int silent = 0; - + BoosterHandle booster; - + // do something with booster - + //free the memory XGBoosterFree(booster) DMatrixHandle DMatrixHandle_param; - + // do something with DMatrixHandle_param - + // free the memory XGDMatrixFree(DMatrixHandle_param); - + return 0; } -3. For tree models, it is important to use consistent data formats during training and scoring/ predicting otherwise it will result in wrong outputs. +3. For tree models, it is important to use consistent data formats during training and scoring/ predicting otherwise it will result in wrong outputs. Example if we our training data is in ``dense matrix`` format then your prediction dataset should also be a ``dense matrix`` or if training in ``libsvm`` format then dataset for prediction should also be in ``libsvm`` format. @@ -166,7 +166,7 @@ Sample examples along with Code snippet to use C API functions 1. If the dataset is available in a file, it can be loaded into a ``DMatrix`` object using the `XGDMatrixCreateFromFile `_ .. 
code-block:: c - + DMatrixHandle data; // handle to DMatrix // Load the dat from file & store it in data variable of DMatrixHandle datatype safe_xgboost(XGDMatrixCreateFromFile("/path/to/file/filename", silent, &data)); @@ -188,10 +188,10 @@ Sample examples along with Code snippet to use C API functions // dmatrix variable will contain the created DMatrix using it safe_xgboost(XGDMatrixCreateFromMat(data1, 1, 50, 0, &dmatrix)); // here -1 represents the missing value in the matrix dataset - safe_xgboost(XGDMatrixCreateFromMat(data2, ROWS, COLS, -1, &dmatrix2)(; + safe_xgboost(XGDMatrixCreateFromMat(data2, ROWS, COLS, -1, &dmatrix2)); -3. Create a Booster object for training & testing on dataset using `XGBoosterCreate `_ +3. Create a Booster object for training & testing on dataset using `XGBoosterCreate `_ .. code-block:: c @@ -201,7 +201,7 @@ Sample examples along with Code snippet to use C API functions DMatrixHandle eval_dmats[eval_dmats_size] = {train, test}; safe_xgboost(XGBoosterCreate(eval_dmats, eval_dmats_size, &booster)); - + 4. For each ``DMatrix`` object, set the labels using `XGDMatrixSetFloatInfo `_. Later you can access the label using `XGDMatrixGetFloatInfo `_. .. code-block:: c @@ -221,7 +221,7 @@ Sample examples along with Code snippet to use C API functions // Loading the labels safe_xgboost(XGDMatrixSetFloatInfo(dmatrix, "label", labels, ROWS)); - + // reading the labels and store the length of the result bst_ulong result_len; @@ -233,12 +233,12 @@ Sample examples along with Code snippet to use C API functions for(unsigned int i = 0; i < result_len; i++) { printf("label[%i] = %f\n", i, result[i]); } - - + + 5. Set the parameters for the ``Booster`` object according to the requirement using `XGBoosterSetParam `_ . Check out the full list of parameters available `here `_ . .. 
code-block :: c - + BoosterHandle booster; safe_xgboost(XGBoosterSetParam(booster, "booster", "gblinear")); // default max_depth =6 diff --git a/doc/tutorials/external_memory.rst b/doc/tutorials/external_memory.rst index 4dc22571b..b9acf09cb 100644 --- a/doc/tutorials/external_memory.rst +++ b/doc/tutorials/external_memory.rst @@ -1,6 +1,75 @@ ##################################### Using XGBoost External Memory Version ##################################### + +XGBoost supports loading data from external memory using the builtin data parser. And +starting from version 1.5, users can also define a custom iterator to load data in chunks. +The feature is still experimental and not yet ready for production use. In this tutorial +we will introduce both methods. Please note that training on data from external memory is +not supported by ``exact`` tree method. + +************* +Data Iterator +************* + +Starting from XGBoost 1.5, users can define their own data loader using Python or C +interface. There are some examples in the ``demo`` directory for quick start. This is a +generalized version of text input external memory, where users no longer need to prepare a +text file that XGBoost recognizes. To enable the feature, users need to define a data +iterator with 2 class methods ``next`` and ``reset`` then pass it into ``DMatrix`` +constructor. + +.. code-block:: python + + import os + from typing import List, Callable + import xgboost + from sklearn.datasets import load_svmlight_file + + class Iterator(xgboost.DataIter): + def __init__(self, svm_file_paths: List[str]): + self._file_paths = svm_file_paths + self._it = 0 + # XGBoost will generate some cache files under current directory with the prefix + # "cache" + super().__init__(cache_prefix=os.path.join(".", "cache")) + + def next(self, input_data: Callable): + """Advance the iterator by 1 step and pass the data to XGBoost. 
This function is + called by XGBoost during the construction of ``DMatrix`` + + """ + if self._it == len(self._file_paths): + # return 0 to let XGBoost know this is the end of iteration + return 0 + + # input_data is a function passed in by XGBoost who has the exact same signature of + # ``DMatrix`` + X, y = load_svmlight_file(self._file_paths[self._it]) + input_data(X, y) + self._it += 1 + # Return 1 to let XGBoost know we haven't seen all the files yet. + return 1 + + def reset(self): + """Reset the iterator to its beginning""" + self._it = 0 + + it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"]) + Xy = xgboost.DMatrix(it) + + # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some caveats + # as noted in following sections. + booster = xgboost.train({"tree_method": "approx"}, Xy) + + +The above snippet is a simplified version of ``demo/guide-python/external_memory.py``. For +an example in C, please see ``demo/c-api/external-memory/``. + +**************** +Text File Inputs +**************** + There is no big difference between using external memory version and in-memory version. The only difference is the filename format. @@ -36,10 +105,11 @@ more notes about text input formats, see :doc:`/tutorials/input_format`. For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train#dtrain.cache"``. -*********** -GPU Version -*********** -External memory is fully supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``). + +********************************** +GPU Version (GPU Hist tree method) +********************************** +External memory is supported in GPU algorithms (i.e. when ``tree_method`` is set to ``gpu_hist``). 
If you are still getting out-of-memory errors after enabling external memory, try subsampling the data to further reduce GPU memory usage: @@ -52,23 +122,14 @@ data to further reduce GPU memory usage: 'sampling_method': 'gradient_based', } -For more information, see `this paper `_. +For more information, see `this paper `_. Internally +the tree method still concatenates all the chunks into 1 final histogram index for +performance reasons, but in compressed format. So its scalability has an upper bound but +still has lower memory cost in general. -******************* -Distributed Version -******************* -The external memory mode naturally works on distributed version, you can simply set path like +******** +CPU Hist +******** -.. code-block:: none - - data = "hdfs://path-to-data/#dtrain.cache" - -XGBoost will cache the data to the local position. When you run on YARN, the current folder is temporary -so that you can directly use ``dtrain.cache`` to cache to current folder. - -*********** -Limitations -*********** -* The ``hist`` tree method hasn't been tested thoroughly with external memory support (see - `this issue `_). -* OSX is not tested. +It's limited by the same factor as GPU Hist, except that gradient based sampling is not +yet supported on CPU. diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py index e0ff434df..663b5a5a2 100644 --- a/python-package/xgboost/__init__.py +++ b/python-package/xgboost/__init__.py @@ -6,7 +6,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md import os -from .core import DMatrix, DeviceQuantileDMatrix, Booster +from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter from .training import train, cv from . import rabit # noqa from . 
import tracker # noqa @@ -25,7 +25,7 @@ VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION') with open(VERSION_FILE) as f: __version__ = f.read().strip() -__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', +__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter', 'train', 'cv', 'RabitTracker', 'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker', diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 51b87c3df..8f66eed52 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -5,7 +5,7 @@ import collections # pylint: disable=no-name-in-module,import-error from collections.abc import Mapping -from typing import List, Optional, Any, Union, Dict +from typing import List, Optional, Any, Union, Dict, TypeVar # pylint: enable=no-name-in-module,import-error from typing import Callable, Tuple import ctypes @@ -313,78 +313,130 @@ def _prediction_output(shape, dims, predts, is_cuda): return arr_predict -class DataIter: - '''The interface for user defined data iterator. Currently is only supported by Device - DMatrix. +class DataIter: # pylint: disable=too-many-instance-attributes + """The interface for user defined data iterator. + + Parameters + ---------- + cache_prefix: + Prefix to the cache files, only used in external memory. It can be either an URI + or a file path. + + """ + _T = TypeVar("_T") + + def __init__(self, cache_prefix: Optional[str] = None) -> None: + self.cache_prefix = cache_prefix - ''' - def __init__(self): self._handle = _ProxyDMatrix() - self.exception = None - self.enable_categorical = False - self._allow_host = False + self._exception: Optional[Exception] = None + self._enable_categorical = False + self._allow_host = True + # Stage data in Python until reset or next is called to avoid data being free. 
+ self._temporary_data = None + + def _get_callbacks( + self, allow_host: bool, enable_categorical: bool + ) -> Tuple[Callable, Callable]: + assert hasattr(self, "cache_prefix"), "__init__ is not called." + self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)( + self._reset_wrapper + ) + self._next_callback = ctypes.CFUNCTYPE( + ctypes.c_int, + ctypes.c_void_p, + )(self._next_wrapper) + self._allow_host = allow_host + self._enable_categorical = enable_categorical + return self._reset_callback, self._next_callback @property - def proxy(self): - '''Handler of DMatrix proxy.''' + def proxy(self) -> "_ProxyDMatrix": + """Handle of DMatrix proxy.""" return self._handle - def reset_wrapper(self, this): # pylint: disable=unused-argument - '''A wrapper for user defined `reset` function.''' - self.reset() + def _handle_exception(self, fn: Callable, dft_ret: _T) -> _T: + if self._exception is not None: + return dft_ret - def next_wrapper(self, this): # pylint: disable=unused-argument - '''A wrapper for user defined `next` function. + try: + return fn() + except Exception as e: # pylint: disable=broad-except + # Defer the exception in order to return 0 and stop the iteration. + # Exception inside a ctype callback function has no effect except + # for printing to stderr (doesn't stop the execution). + tb = sys.exc_info()[2] + # On dask, the worker is restarted and somehow the information is + # lost. 
+ self._exception = e.with_traceback(tb) + return dft_ret + + def _reraise(self) -> None: + self._temporary_data = None + if self._exception is not None: + # pylint 2.7.0 believes `self._exception` can be None even with `assert + # isinstace` + exc = self._exception + self._exception = None + raise exc # pylint: disable=raising-bad-type + + def __del__(self) -> None: + assert self._temporary_data is None, self._temporary_data + assert self._exception is None + + def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument + """A wrapper for user defined `reset` function.""" + # free the data + self._temporary_data = None + self._handle_exception(self.reset, None) + + def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument + """A wrapper for user defined `next` function. `this` is not used in Python. ctypes can handle `self` of a Python member function automatically when converting it to c function pointer. - ''' - if self.exception is not None: - return 0 - + """ + @_deprecate_positional_args def data_handle( - data, - feature_names=None, - feature_types=None, - **kwargs + data: Any, + *, + feature_names: Optional[List[str]] = None, + feature_types: Optional[List[str]] = None, + **kwargs: Any, ): from .data import dispatch_proxy_set_data from .data import _proxy_transform - data, feature_names, feature_types = _proxy_transform( - data, feature_names, feature_types, self.enable_categorical, + + transformed, feature_names, feature_types = _proxy_transform( + data, + feature_names, + feature_types, + self._enable_categorical, ) - dispatch_proxy_set_data(self.proxy, data, self._allow_host) + # Stage the data, meta info are copied inside C++ MetaInfo. + self._temporary_data = transformed + dispatch_proxy_set_data(self.proxy, transformed, self._allow_host) self.proxy.set_info( feature_names=feature_names, feature_types=feature_types, **kwargs, ) - try: - # Differ the exception in order to return 0 and stop the iteration. 
- # Exception inside a ctype callback function has no effect except - # for printing to stderr (doesn't stop the execution). - ret = self.next(data_handle) # pylint: disable=not-callable - except Exception as e: # pylint: disable=broad-except - tb = sys.exc_info()[2] - # On dask the worker is restarted and somehow the information is - # lost. - self.exception = e.with_traceback(tb) - return 0 - return ret + # pylint: disable=not-callable + return self._handle_exception(lambda: self.next(data_handle), 0) - def reset(self): - '''Reset the data iterator. Prototype for user defined function.''' + def reset(self) -> None: + """Reset the data iterator. Prototype for user defined function.""" raise NotImplementedError() - def next(self, input_data): - '''Set the next batch of data. + def next(self, input_data: Callable) -> int: + """Set the next batch of data. Parameters ---------- - data_handle: callable + data_handle: A function with same data fields like `data`, `label` with `xgboost.DMatrix`. @@ -392,7 +444,7 @@ class DataIter: ------- 0 if there's no more batch, otherwise 1. 
- ''' + """ raise NotImplementedError() @@ -546,7 +598,12 @@ class DMatrix: # pylint: disable=too-many-instance-attributes self.handle = None return - from .data import dispatch_data_backend + from .data import dispatch_data_backend, _is_iter + + if _is_iter(data): + self._init_from_iter(data, enable_categorical) + assert self.handle is not None + return handle, feature_names, feature_types = dispatch_data_backend( data, @@ -575,6 +632,33 @@ class DMatrix: # pylint: disable=too-many-instance-attributes if feature_types is not None: self.feature_types = feature_types + def _init_from_iter(self, iterator: DataIter, enable_categorical: bool): + it = iterator + args = { + "missing": self.missing, + "nthread": self.nthread, + "cache_prefix": it.cache_prefix if it.cache_prefix else "", + } + args = from_pystr_to_cstr(json.dumps(args)) + handle = ctypes.c_void_p() + # pylint: disable=protected-access + reset_callback, next_callback = it._get_callbacks( + True, enable_categorical + ) + ret = _LIB.XGDMatrixCreateFromCallback( + None, + it.proxy.handle, + reset_callback, + next_callback, + args, + ctypes.byref(handle), + ) + # pylint: disable=protected-access + it._reraise() + # delay check_call to throw intermediate exception first + _check_call(ret) + self.handle = handle + def __del__(self): if hasattr(self, "handle") and self.handle: _check_call(_LIB.XGDMatrixFree(self.handle)) @@ -907,7 +991,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes if len(feature_names) != len(set(feature_names)): raise ValueError('feature_names must be unique') if len(feature_names) != self.num_col() and self.num_col() != 0: - msg = 'feature_names must have the same length as data' + msg = ("feature_names must have the same length as data, ", + f"expected {self.num_col()}, got {len(feature_names)}") raise ValueError(msg) # prohibit to use symbols may affect to parse. e.g. 
[]< if not all(isinstance(f, str) and @@ -1001,30 +1086,44 @@ class _ProxyDMatrix(DMatrix): inplace_predict). """ + def __init__(self): # pylint: disable=super-init-not-called self.handle = ctypes.c_void_p() _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle))) def _set_data_from_cuda_interface(self, data): - '''Set data from CUDA array interface.''' + """Set data from CUDA array interface.""" interface = data.__cuda_array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') + interface_str = bytes(json.dumps(interface, indent=2), "utf-8") _check_call( - _LIB.XGProxyDMatrixSetDataCudaArrayInterface( - self.handle, - interface_str - ) + _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) ) def _set_data_from_cuda_columnar(self, data): - '''Set data from CUDA columnar format.''' + """Set data from CUDA columnar format.""" from .data import _cudf_array_interfaces + _, interfaces_str = _cudf_array_interfaces(data) + _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) + + def _set_data_from_array(self, data: np.ndarray): + """Set data from numpy array.""" + from .data import _array_interface + _check_call( - _LIB.XGProxyDMatrixSetDataCudaColumnar( - self.handle, - interfaces_str - ) + _LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data)) + ) + + def _set_data_from_csr(self, csr): + """Set data from scipy csr""" + from .data import _array_interface + + _LIB.XGProxyDMatrixSetDataCSR( + self.handle, + _array_interface(csr.indptr), + _array_interface(csr.indices), + _array_interface(csr.data), + ctypes.c_size_t(csr.shape[1]), ) @@ -1110,13 +1209,14 @@ class DeviceQuantileDMatrix(DMatrix): else: it = SingleBatchInternalIter(data=data, **meta) - it.enable_categorical = enable_categorical - reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper) - next_callback = ctypes.CFUNCTYPE( - ctypes.c_int, - ctypes.c_void_p, - )(it.next_wrapper) handle = 
ctypes.c_void_p() + # pylint: disable=protected-access + reset_callback, next_callback = it._get_callbacks(False, enable_categorical) + if it.cache_prefix is not None: + raise ValueError( + "DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix " + "in iterator to fix this error." + ) ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback( None, it.proxy.handle, @@ -1127,10 +1227,8 @@ class DeviceQuantileDMatrix(DMatrix): ctypes.c_int(self.max_bin), ctypes.byref(handle), ) - if it.exception is not None: - # pylint 2.7.0 believes `it.exception` can be None even with `assert - # isinstace` - raise it.exception # pylint: disable=raising-bad-type + # pylint: disable=protected-access + it._reraise() # delay check_call to throw intermediate exception first _check_call(ret) self.handle = handle @@ -2241,8 +2339,8 @@ class Booster(object): # pylint: disable=too-many-locals fmap = os.fspath(os.path.expanduser(fmap)) if not PANDAS_INSTALLED: - raise Exception(('pandas must be available to use this method.' - 'Install pandas before calling again.')) + raise ImportError(('pandas must be available to use this method.' 
+ 'Install pandas before calling again.')) if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}: raise ValueError('This method is not defined for Booster type {}' diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index aa4764c3c..501a2bee7 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,7 +5,7 @@ import ctypes import json import warnings import os -from typing import Any, Tuple +from typing import Any, Tuple, Callable import numpy as np @@ -238,10 +238,13 @@ def _transform_pandas_df(data, enable_categorical, if meta and len(data.columns) > 1: raise ValueError( 'DataFrame for {meta} cannot have multiple columns'.format( - meta=meta)) + meta=meta) + ) dtype = meta_type if meta_type else np.float32 - data = np.ascontiguousarray(data.values, dtype=dtype) + data = data.values + if meta_type: + data = data.astype(meta_type) return data, feature_names, feature_types @@ -759,19 +762,19 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902 area for meta info. 
''' - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.kwargs = kwargs self.it = 0 # pylint: disable=invalid-name super().__init__() - def next(self, input_data): + def next(self, input_data: Callable) -> int: if self.it == 1: return 0 self.it += 1 input_data(**self.kwargs) return 1 - def reset(self): + def reset(self) -> None: self.it = 0 @@ -785,6 +788,15 @@ def _proxy_transform(data, feature_names, feature_types, enable_categorical): return data, feature_names, feature_types if _is_dlpack(data): return _transform_dlpack(data), feature_names, feature_types + if _is_numpy_array(data): + return data, feature_names, feature_types + if _is_scipy_csr(data): + return data, feature_names, feature_types + if _is_pandas_df(data): + arr, feature_names, feature_types = _transform_pandas_df( + data, enable_categorical, feature_names, feature_types + ) + return arr, feature_names, feature_types raise TypeError("Value type is not supported for data iterator:" + str(type(data))) @@ -803,7 +815,16 @@ def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) - data = _transform_dlpack(data) proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 return - # Part of https://github.com/dmlc/xgboost/pull/7070 - assert allow_host is False, "host data is not yet supported." 
- raise TypeError('Value type is not supported for data iterator:' + - str(type(data))) + + err = TypeError("Value type is not supported for data iterator:" + str(type(data))) + + if not allow_host: + raise err + + if _is_numpy_array(data): + proxy._set_data_from_array(data) # pylint: disable=W0212 + return + if _is_scipy_csr(data): + proxy._set_data_from_csr(data) # pylint: disable=W0212 + return + raise err diff --git a/src/data/data.cc b/src/data/data.cc index 99041ef96..de27f51f6 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -874,8 +874,15 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { tid); } }); + + if (this->data.Empty()) { + transpose.offset.Resize(num_columns + 1); + transpose.offset.Fill(0); + } + CHECK_EQ(transpose.offset.Size(), num_columns + 1); return transpose; } + void SparsePage::Push(const SparsePage &batch) { auto& data_vec = data.HostVector(); auto& offset_vec = offset.HostVector(); @@ -1007,6 +1014,7 @@ void SparsePage::PushCSC(const SparsePage &batch) { auto const& other_offset = batch.offset.ConstHostVector(); if (other_data.empty()) { + self_offset = other_offset; return; } if (!self_data.empty()) { diff --git a/src/data/data.cu b/src/data/data.cu index a9397803c..de8a8c248 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -19,11 +19,16 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector* out) { cudaPointerAttributes attr; dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); int32_t ptr_device = attr.device; - dh::safe_cuda(cudaSetDevice(ptr_device)); + if (ptr_device >= 0) { + dh::safe_cuda(cudaSetDevice(ptr_device)); + } return ptr_device; }; auto ptr_device = SetDeviceToPtr(column.data); + if (column.num_rows == 0) { + return; + } out->SetDevice(ptr_device); out->Resize(column.num_rows); @@ -123,7 +128,12 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { << "MetaInfo: " << c_key << ". 
" << ArrayInterfaceErrors::Dimension(1); ArrayInterface array_interface(interface_str); std::string key{c_key}; - array_interface.AsColumnVector(); + if (!((array_interface.num_cols == 1 && array_interface.num_rows == 0) || + (array_interface.num_cols == 0 && array_interface.num_rows == 1))) { + // Not an empty column, transform it. + array_interface.AsColumnVector(); + } + CHECK(!array_interface.valid.Data()) << "Meta info " << key << " should be dense, found validity mask"; if (array_interface.num_rows == 0) { diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index a772a064f..2da786969 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -154,7 +154,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { size_t NumRows() const { return num_rows_; } size_t NumColumns() const { return columns_.size(); } - size_t DeviceIdx() const { return device_idx_; } + int32_t DeviceIdx() const { return device_idx_; } private: CudfAdapterBatch batch_; @@ -202,12 +202,12 @@ class CupyAdapter : public detail::SingleBatchDataIter { size_t NumRows() const { return array_interface_.num_rows; } size_t NumColumns() const { return array_interface_.num_cols; } - size_t DeviceIdx() const { return device_idx_; } + int32_t DeviceIdx() const { return device_idx_; } private: ArrayInterface array_interface_; CupyAdapterBatch batch_; - int device_idx_; + int32_t device_idx_ {-1}; }; // Returns maximum row length diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 115d593e1..6d79250a0 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -10,6 +10,7 @@ namespace xgboost { namespace data { void EllpackPageSource::Fetch() { + dh::safe_cuda(cudaSetDevice(param_.gpu_id)); if (!this->ReadCache()) { auto const &csr = source_->Page(); this->page_.reset(new EllpackPage{}); diff --git a/src/data/proxy_dmatrix.cu b/src/data/proxy_dmatrix.cu index adad6f4c4..6fbd72100 100644 --- 
a/src/data/proxy_dmatrix.cu +++ b/src/data/proxy_dmatrix.cu @@ -14,6 +14,9 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) { device_ = adapter->DeviceIdx(); this->Info().num_col_ = adapter->NumColumns(); this->Info().num_row_ = adapter->NumRows(); + if (device_ < 0) { + CHECK_EQ(this->Info().num_row_, 0); + } } void DMatrixProxy::FromCudaArray(std::string interface_str) { @@ -22,6 +25,9 @@ void DMatrixProxy::FromCudaArray(std::string interface_str) { device_ = adapter->DeviceIdx(); this->Info().num_col_ = adapter->NumColumns(); this->Info().num_row_ = adapter->NumRows(); + if (device_ < 0) { + CHECK_EQ(this->Info().num_row_, 0); + } } } // namespace data diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 11d664666..2f130c7af 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -141,9 +141,8 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_ } else { LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); } - auto value = dmlc::get>( - proxy->Adapter())->Value(); - return fn(value); + return std::result_of_t>()->Value()))>(); } } } // namespace data diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index 80de6706c..e770d4aa2 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -16,7 +16,10 @@ namespace data { // be supported in future. Does not currently support inferring row/column size template SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { - dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx())); + auto device = + adapter->DeviceIdx() < 0 ? 
dh::CurrentDevice() : adapter->DeviceIdx(); + CHECK_GE(device, 0); + dh::safe_cuda(cudaSetDevice(device)); CHECK(adapter->NumRows() != kAdapterUnknownSize); CHECK(adapter->NumColumns() != kAdapterUnknownSize); @@ -27,8 +30,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { // Enforce single batch CHECK(!adapter->Next()); - info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), adapter->DeviceIdx(), - missing, sparse_page_.get()); + info_.num_nonzero_ = + CopyToSparsePage(adapter->Value(), device, missing, sparse_page_.get()); info_.num_col_ = adapter->NumColumns(); info_.num_row_ = adapter->NumRows(); // Synchronise worker columns diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index d6e26195b..e0502675e 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -15,6 +15,29 @@ MetaInfo &SparsePageDMatrix::Info() { return info_; } const MetaInfo &SparsePageDMatrix::Info() const { return info_; } +namespace detail { +// Use device dispatch +size_t NSamplesDevice(DMatrixProxy *proxy) +#if defined(XGBOOST_USE_CUDA) +; // NOLINT +#else +{ + common::AssertGPUSupport(); + return 0; +} +#endif +size_t NFeaturesDevice(DMatrixProxy *proxy) +#if defined(XGBOOST_USE_CUDA) +; // NOLINT +#else +{ + common::AssertGPUSupport(); + return 0; +} +#endif +} // namespace detail + + SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, @@ -35,13 +58,24 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p size_t nnz = 0; auto num_rows = [&]() { - return HostAdapterDispatch( - proxy, [](auto const &value) { return value.NumRows(); }); + bool type_error {false}; + size_t n_samples = HostAdapterDispatch( + proxy, [](auto const &value) { return value.NumRows(); }, &type_error); + if (type_error) { + n_samples = detail::NSamplesDevice(proxy); + } + return 
n_samples; }; auto num_cols = [&]() { - return HostAdapterDispatch( - proxy, [](auto const &value) { return value.NumCols(); }); + bool type_error {false}; + size_t n_features = HostAdapterDispatch( + proxy, [](auto const &value) { return value.NumCols(); }, &type_error); + if (type_error) { + n_features = detail::NFeaturesDevice(proxy); + } + return n_features; }; + // the proxy is iterated together with the sparse page source so we can obtain all // information in 1 pass. for (auto const &page : this->GetRowBatchesImpl()) { diff --git a/src/data/sparse_page_source.cu b/src/data/sparse_page_source.cu index 8c292ded6..bcadffaff 100644 --- a/src/data/sparse_page_source.cu +++ b/src/data/sparse_page_source.cu @@ -7,8 +7,24 @@ namespace xgboost { namespace data { + +namespace detail { +size_t NSamplesDevice(DMatrixProxy *proxy) { + return Dispatch(proxy, [](auto const &value) { return value.NumRows(); }); +} + +size_t NFeaturesDevice(DMatrixProxy *proxy) { + return Dispatch(proxy, [](auto const &value) { return value.NumCols(); }); +} +} // namespace detail + void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) { auto device = proxy->DeviceIdx(); + if (device < 0) { + device = dh::CurrentDevice(); + } + CHECK_GE(device, 0); + Dispatch(proxy, [&](auto const &value) { CopyToSparsePage(value, device, missing, page); }); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 2b634e7aa..eec5052dc 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -236,7 +236,7 @@ class SparsePageSource : public SparsePageSourceImpl { iter_{iter}, proxy_{proxy} { if (!cache_info_->written) { iter_.Reset(); - iter_.Next(); + CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch."; } this->Fetch(); } diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index fd612b04b..952a60f0f 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -444,7 +444,7 @@ class ColMaker: public 
TreeUpdater { } // update the solution candidate - virtual void UpdateSolution(const SparsePage &batch, + virtual void UpdateSolution(const SortedCSCPage &batch, const std::vector &feat_set, const std::vector &gpair, DMatrix*) { diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index d598b420e..b9e91e6b1 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -77,12 +77,23 @@ TEST(SparsePageDMatrix, RetainEllpackPage) { for (size_t i = 0; i < iterators.size(); ++i) { ASSERT_EQ((*iterators[i]).Impl()->gidx_buffer.HostVector(), gidx_buffers.at(i).HostVector()); + if (i != iterators.size() - 1) { + ASSERT_EQ(iterators[i].use_count(), 1); + } else { + // The last batch is still being held by sparse page DMatrix. + ASSERT_EQ(iterators[i].use_count(), 2); + } } // make sure it's const and the caller can not modify the content of page. for (auto& page : m->GetBatches({0, 32})) { static_assert(std::is_const>::value, ""); } + + // The above iteration clears out all references inside DMatrix. 
+ for (auto const& ptr : iterators) { + ASSERT_TRUE(ptr.unique()); + } } TEST(SparsePageDMatrix, EllpackPageContent) { diff --git a/tests/python-gpu/test_gpu_data_iterator.py b/tests/python-gpu/test_gpu_data_iterator.py new file mode 100644 index 000000000..5b9130301 --- /dev/null +++ b/tests/python-gpu/test_gpu_data_iterator.py @@ -0,0 +1,32 @@ +import numpy as np +import xgboost as xgb +from hypothesis import given, strategies, settings +import pytest +import sys + +sys.path.append("tests/python") +from test_data_iterator import SingleBatch, make_batches +from test_data_iterator import test_single_batch as cpu_single_batch +from test_data_iterator import run_data_iterator +from testing import IteratorForTest, no_cupy + + +def test_gpu_single_batch() -> None: + cpu_single_batch("gpu_hist") + + +@pytest.mark.skipif(**no_cupy()) +@given( + strategies.integers(0, 1024), strategies.integers(1, 7), strategies.integers(0, 13) +) +@settings(deadline=None) +def test_gpu_data_iterator( + n_samples_per_batch: int, n_features: int, n_batches: int +) -> None: + run_data_iterator(n_samples_per_batch, n_features, n_batches, "gpu_hist", True) + run_data_iterator(n_samples_per_batch, n_features, n_batches, "gpu_hist", False) + + +def test_cpu_data_iterator() -> None: + """Make sure CPU algorithm can handle GPU inputs""" + run_data_iterator(1024, 2, 3, "approx", True) diff --git a/tests/python-gpu/test_gpu_demos.py b/tests/python-gpu/test_gpu_demos.py index 9dc6e4247..7e79378a3 100644 --- a/tests/python-gpu/test_gpu_demos.py +++ b/tests/python-gpu/test_gpu_demos.py @@ -9,7 +9,7 @@ import test_demos as td # noqa @pytest.mark.skipif(**tm.no_cupy()) def test_data_iterator(): - script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py') + script = os.path.join(td.PYTHON_DEMO_DIR, 'quantile_data_iterator.py') cmd = ['python', script] subprocess.check_call(cmd) diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 11140a708..a2da32d2f 100644 
--- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -112,7 +112,6 @@ class TestGPUUpdaters: tm.dataset_strategy) @settings(deadline=None) def test_external_memory(self, param, num_rounds, dataset): - pytest.xfail(reason='TestGPUUpdaters::test_external_memory is flaky') # We cannot handle empty dataset yet assume(len(dataset.y) > 0) param['tree_method'] = 'gpu_hist' diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py new file mode 100644 index 000000000..742430d5d --- /dev/null +++ b/tests/python/test_data_iterator.py @@ -0,0 +1,135 @@ +import xgboost as xgb +from xgboost.data import SingleBatchInternalIter as SingleBatch +import numpy as np +from testing import IteratorForTest +from typing import Tuple, List +import pytest +from hypothesis import given, strategies, settings +from scipy.sparse import csr_matrix + + +def make_batches( + n_samples_per_batch: int, n_features: int, n_batches: int, use_cupy: bool = False +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + X = [] + y = [] + if use_cupy: + import cupy + + rng = cupy.random.RandomState(1994) + else: + rng = np.random.RandomState(1994) + for i in range(n_batches): + _X = rng.randn(n_samples_per_batch, n_features) + _y = rng.randn(n_samples_per_batch) + X.append(_X) + y.append(_y) + return X, y + + +def test_single_batch(tree_method: str = "approx") -> None: + from sklearn.datasets import load_breast_cancer + + n_rounds = 10 + X, y = load_breast_cancer(return_X_y=True) + X = X.astype(np.float32) + y = y.astype(np.float32) + + Xy = xgb.DMatrix(SingleBatch(data=X, label=y)) + from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + + Xy = xgb.DMatrix(X, y) + from_dmat = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + assert from_it.get_dump() == from_dmat.get_dump() + + X, y = load_breast_cancer(return_X_y=True, as_frame=True) + X = X.astype(np.float32) + Xy = xgb.DMatrix(SingleBatch(data=X, 
label=y)) + from_pd = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + # remove feature info to generate exact same text representation. + from_pd.feature_names = None + from_pd.feature_types = None + + assert from_pd.get_dump() == from_it.get_dump() + + X, y = load_breast_cancer(return_X_y=True) + X = csr_matrix(X) + Xy = xgb.DMatrix(SingleBatch(data=X, label=y)) + from_it = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + + X, y = load_breast_cancer(return_X_y=True) + Xy = xgb.DMatrix(SingleBatch(data=X, label=y), missing=0.0) + from_np = xgb.train({"tree_method": tree_method}, Xy, num_boost_round=n_rounds) + assert from_np.get_dump() == from_it.get_dump() + + +def run_data_iterator( + n_samples_per_batch: int, + n_features: int, + n_batches: int, + tree_method: str, + use_cupy: bool, +) -> None: + n_rounds = 2 + + it = IteratorForTest( + *make_batches(n_samples_per_batch, n_features, n_batches, use_cupy) + ) + if n_batches == 0: + with pytest.raises(ValueError, match="1 batch"): + Xy = xgb.DMatrix(it) + return + + Xy = xgb.DMatrix(it) + assert Xy.num_row() == n_samples_per_batch * n_batches + assert Xy.num_col() == n_features + + results_from_it: xgb.callback.EvaluationMonitor.EvalsLog = {} + from_it = xgb.train( + {"tree_method": tree_method, "max_depth": 2}, + Xy, + num_boost_round=n_rounds, + evals=[(Xy, "Train")], + evals_result=results_from_it, + verbose_eval=False, + ) + it_predt = from_it.predict(Xy) + + X, y = it.as_arrays() + Xy = xgb.DMatrix(X, y) + assert Xy.num_row() == n_samples_per_batch * n_batches + assert Xy.num_col() == n_features + + results_from_arrays: xgb.callback.EvaluationMonitor.EvalsLog = {} + from_arrays = xgb.train( + {"tree_method": tree_method, "max_depth": 2}, + Xy, + num_boost_round=n_rounds, + evals=[(Xy, "Train")], + evals_result=results_from_arrays, + verbose_eval=False, + ) + arr_predt = from_arrays.predict(Xy) + + if tree_method != "gpu_hist": + rtol = 1e-1 # flaky + else: + 
np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-3) + rtol = 1e-6 + + np.testing.assert_allclose( + results_from_it["Train"]["rmse"], + results_from_arrays["Train"]["rmse"], + rtol=rtol, + ) + + +@given( + strategies.integers(0, 1024), strategies.integers(1, 7), strategies.integers(0, 13) +) +@settings(deadline=None) +def test_data_iterator( + n_samples_per_batch: int, n_features: int, n_batches: int +) -> None: + run_data_iterator(n_samples_per_batch, n_features, n_batches, "approx", False) + run_data_iterator(n_samples_per_batch, n_features, n_batches, "hist", False) diff --git a/tests/python/testing.py b/tests/python/testing.py index d47d1c18c..fe6d9b32c 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -8,7 +8,7 @@ from io import StringIO from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED from xgboost.compat import DASK_INSTALLED import pytest -import tempfile +import gc import xgboost as xgb import numpy as np import platform @@ -143,10 +143,35 @@ def skip_s390x(): return {"condition": condition, "reason": reason} +class IteratorForTest(xgb.core.DataIter): + def __init__(self, X, y): + assert len(X) == len(y) + self.X = X + self.y = y + self.it = 0 + super().__init__("./") + + def next(self, input_data): + if self.it == len(self.X): + return 0 + # Use copy to make sure the iterator doesn't hold a reference to the data. + input_data(data=self.X[self.it].copy(), label=self.y[self.it].copy()) + gc.collect() # clear up the copy, see if XGBoost access freed memory. 
+ self.it += 1 + return 1 + + def reset(self): + self.it = 0 + + def as_arrays(self): + X = np.concatenate(self.X, axis=0) + y = np.concatenate(self.y, axis=0) + return X, y + + # Contains a dataset in numpy format as well as the relevant objective and metric class TestDataset: - def __init__(self, name, get_dataset, objective, metric - ): + def __init__(self, name, get_dataset, objective, metric): self.name = name self.objective = objective self.metric = metric @@ -171,16 +196,23 @@ class TestDataset: return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin) def get_external_dmat(self): - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, 'tmptmp_1234.csv') - np.savetxt(path, - np.hstack((self.y.reshape(len(self.y), 1), self.X)), - delimiter=',') - assert os.path.exists(path) - uri = path + '?format=csv&label_column=0#tmptmp_' - # The uri looks like: - # 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_' - return xgb.DMatrix(uri, weight=self.w, base_margin=self.margin) + n_samples = self.X.shape[0] + n_batches = 10 + per_batch = n_samples // n_batches + 1 + + predictor = [] + response = [] + for i in range(n_batches): + beg = i * per_batch + end = min((i + 1) * per_batch, n_samples) + assert end != beg + X = self.X[beg: end, ...] + y = self.y[beg: end] + predictor.append(X) + response.append(y) + + it = IteratorForTest(predictor, response) + return xgb.DMatrix(it) def __repr__(self): return self.name