diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 3cfba0468..2a7d51393 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -810,7 +810,7 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
  */
 XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
 
-/*!
+/**
 * \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
 *        quantized DMatrix, quantized values are returned instead.
 *
@@ -819,8 +819,10 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
 * XGBoost. This is to avoid allocating a huge memory buffer that can not be freed until
 * exiting the thread.
 *
+ * @since 1.7.0
+ *
 * \param handle the handle to the DMatrix
- * \param config Json configuration string. At the moment it should be an empty document,
+ * \param config JSON configuration string. At the moment it should be an empty document,
 *        preserved for future use.
 * \param out_indptr indptr of output CSR matrix.
 * \param out_indices Column index of output CSR matrix.
@@ -831,6 +833,24 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
 XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
                                   bst_ulong *out_indptr, unsigned *out_indices,
                                   float *out_data);
+/**
+ * @brief Export the quantile cuts used for training histogram-based models like `hist` and
+ *        `approx`. Useful for model compression.
+ *
+ * @since 2.0.0
+ *
+ * @param handle the handle to the DMatrix
+ * @param config JSON configuration string. At the moment it should be an empty document,
+ *               preserved for future use.
+ *
+ * @param out_indptr indptr of output CSC matrix represented by a JSON encoded
+ *        __(cuda_)array_interface__.
+ * @param out_data Data value of CSC matrix represented by a JSON encoded
+ *        __(cuda_)array_interface__.
+ */
+XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
+                                    char const **out_indptr, char const **out_data);
+
 /** @} */  // End of DMatrix
 
 /**
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 6305abff8..472ca43b3 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -282,7 +282,7 @@ struct BatchParam {
   BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
       : max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
 
-  bool ParamNotEqual(BatchParam const& other) const {
+  [[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {
     // Check non-floating parameters.
     bool cond = max_bin != other.max_bin;
     // Check sparse thresh.
@@ -293,11 +293,11 @@ struct BatchParam {
     return cond;
   }
 
-  bool Initialized() const { return max_bin != 0; }
+  [[nodiscard]] bool Initialized() const { return max_bin != 0; }
   /**
    * \brief Make a copy of self for DMatrix to describe how its existing index was generated.
    */
-  BatchParam MakeCache() const {
+  [[nodiscard]] BatchParam MakeCache() const {
     auto p = *this;
     // These parameters have nothing to do with how the gradient index was generated in the
     // first place.
@@ -319,7 +319,7 @@ struct HostSparsePageView {
             static_cast<Inst::index_type>(size)};
   }
 
-  size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
+  [[nodiscard]] size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
 };
 
 /*!
@@ -337,7 +337,7 @@ class SparsePage {
   /*! \brief an instance of sparse vector in the batch */
   using Inst = common::Span<Entry const>;
 
-  HostSparsePageView GetView() const {
+  [[nodiscard]] HostSparsePageView GetView() const {
     return {offset.ConstHostSpan(), data.ConstHostSpan()};
   }
 
@@ -353,12 +353,12 @@ class SparsePage {
   virtual ~SparsePage() = default;
 
   /*! \return Number of instances in the page. */
-  inline size_t Size() const {
+  [[nodiscard]] size_t Size() const {
     return offset.Size() == 0 ? 0 : offset.Size() - 1;
   }
 
   /*! \return estimation of memory cost of this page */
-  inline size_t MemCostBytes() const {
+  [[nodiscard]] size_t MemCostBytes() const {
     return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
   }
 
@@ -376,7 +376,7 @@ class SparsePage {
     base_rowid = row_id;
   }
 
-  SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
+  [[nodiscard]] SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
 
   /**
    * \brief Sort the column index.
   */
@@ -385,7 +385,7 @@ class SparsePage {
   /**
    * \brief Check whether the column index is sorted.
   */
-  bool IsIndicesSorted(int32_t n_threads) const;
+  [[nodiscard]] bool IsIndicesSorted(int32_t n_threads) const;
   /**
    * \brief Reindex the column index with an offset.
   */
@@ -440,49 +440,7 @@ class SortedCSCPage : public SparsePage {
   explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
 };
 
-class EllpackPageImpl;
-/*!
- * \brief A page stored in ELLPACK format.
- *
- * This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
- * including CUDA-specific implementation details in the header.
- */
-class EllpackPage {
- public:
-  /*!
-   * \brief Default constructor.
-   *
-   * This is used in the external memory case. An empty ELLPACK page is constructed with its content
-   * set later by the reader.
-   */
-  EllpackPage();
-
-  /*!
-   * \brief Constructor from an existing DMatrix.
-   *
-   * This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
-   * in CSR format.
-   */
-  explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
-
-  /*! \brief Destructor. */
-  ~EllpackPage();
-
-  EllpackPage(EllpackPage&& that);
-
-  /*! \return Number of instances in the page. */
-  size_t Size() const;
-
-  /*! \brief Set the base row id for this page. */
-  void SetBaseRowId(std::size_t row_id);
-
-  const EllpackPageImpl* Impl() const { return impl_.get(); }
-  EllpackPageImpl* Impl() { return impl_.get(); }
-
- private:
-  std::unique_ptr<EllpackPageImpl> impl_;
-};
-
+class EllpackPage;
 class GHistIndexMatrix;
 
 template <typename T>
@@ -492,7 +450,7 @@ class BatchIteratorImpl {
   virtual ~BatchIteratorImpl() = default;
   virtual const T& operator*() const = 0;
   virtual BatchIteratorImpl& operator++() = 0;
-  virtual bool AtEnd() const = 0;
+  [[nodiscard]] virtual bool AtEnd() const = 0;
   virtual std::shared_ptr<T const> Page() const = 0;
 };
 
@@ -519,12 +477,12 @@ class BatchIterator {
     return !impl_->AtEnd();
   }
 
-  bool AtEnd() const {
+  [[nodiscard]] bool AtEnd() const {
     CHECK(impl_ != nullptr);
     return impl_->AtEnd();
   }
 
-  std::shared_ptr<T const> Page() const {
+  [[nodiscard]] std::shared_ptr<T const> Page() const {
     return impl_->Page();
   }
 
@@ -563,15 +521,15 @@ class DMatrix {
     this->Info().SetInfo(ctx, key, StringView{interface_str});
   }
   /*! \brief meta information of the dataset */
-  virtual const MetaInfo& Info() const = 0;
+  [[nodiscard]] virtual const MetaInfo& Info() const = 0;
   /*! \brief Get thread local memory for returning data from DMatrix. */
-  XGBAPIThreadLocalEntry& GetThreadLocal() const;
+  [[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
   /**
    * \brief Get the context object of this DMatrix. The context is created during construction of
    *        DMatrix with user specified `nthread` parameter.
   */
-  virtual Context const* Ctx() const = 0;
+  [[nodiscard]] virtual Context const* Ctx() const = 0;
 
   /**
    * \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
@@ -583,16 +541,16 @@
   template <typename T>
   BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
   template <typename T>
-  bool PageExists() const;
+  [[nodiscard]] bool PageExists() const;
 
   // the following are column meta data, should be able to answer them fast.
   /*! \return Whether the data columns single column block. */
-  virtual bool SingleColBlock() const = 0;
+  [[nodiscard]] virtual bool SingleColBlock() const = 0;
   /*! \brief virtual destructor */
   virtual ~DMatrix();
 
   /*! \brief Whether the matrix is dense. */
-  bool IsDense() const {
+  [[nodiscard]] bool IsDense() const {
     return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
   }
 
@@ -695,9 +653,9 @@
                                                   BatchParam const& param) = 0;
   virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
 
-  virtual bool EllpackExists() const = 0;
-  virtual bool GHistIndexExists() const = 0;
-  virtual bool SparsePageExists() const = 0;
+  [[nodiscard]] virtual bool EllpackExists() const = 0;
+  [[nodiscard]] virtual bool GHistIndexExists() const = 0;
+  [[nodiscard]] virtual bool SparsePageExists() const = 0;
 };
 
 template <>
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 07e8d89cc..31f34256d 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -3,6 +3,7 @@
 """Core XGBoost Library."""
 import copy
 import ctypes
+import importlib.util
 import json
 import os
 import re
@@ -381,6 +382,54 @@ def c_array(
     return (ctype * len(values))(*values)
 
 
+def from_array_interface(interface: dict) -> NumpyOrCupy:
+    """Convert array interface to numpy or cupy array"""
+
+    class Array:  # pylint: disable=too-few-public-methods
+        """Wrapper type for communicating with numpy and cupy."""
+
+        _interface: Optional[dict] = None
+
+        @property
+        def __array_interface__(self) -> Optional[dict]:
+            return self._interface
+
+        @__array_interface__.setter
+        def __array_interface__(self, interface: dict) -> None:
+            self._interface = copy.copy(interface)
+            # converts some fields to tuple as required by numpy
+            self._interface["shape"] = tuple(self._interface["shape"])
+            self._interface["data"] = tuple(self._interface["data"])
+            if self._interface.get("strides", None) is not None:
+                self._interface["strides"] = tuple(self._interface["strides"])
+
+        @property
+        def __cuda_array_interface__(self) -> Optional[dict]:
+            return self.__array_interface__
+
+        @__cuda_array_interface__.setter
+        def __cuda_array_interface__(self, interface: dict) -> None:
+            self.__array_interface__ = interface
+
+    arr = Array()
+
+    if "stream" in interface:
+        # A CUDA stream is present, so this is a __cuda_array_interface__.
+        spec = importlib.util.find_spec("cupy")
+        if spec is None:
+            raise ImportError("`cupy` is required for handling CUDA buffer.")
+
+        import cupy as cp  # pylint: disable=import-error
+
+        arr.__cuda_array_interface__ = interface
+        out = cp.array(arr, copy=True)
+    else:
+        arr.__array_interface__ = interface
+        out = np.array(arr, copy=True)
+
+    return out
+
+
 def _prediction_output(
     shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
 ) -> NumpyOrCupy:
@@ -1060,6 +1109,32 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes,too-many-public-m
         )
         return ret
 
+    def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
+        """Get quantile cuts for quantization."""
+        n_features = self.num_col()
+
+        c_sindptr = ctypes.c_char_p()
+        c_sdata = ctypes.c_char_p()
+        config = make_jcargs()
+        _check_call(
+            _LIB.XGDMatrixGetQuantileCut(
+                self.handle, config, ctypes.byref(c_sindptr), ctypes.byref(c_sdata)
+            )
+        )
+        assert c_sindptr.value is not None
+        assert c_sdata.value is not None
+
+        i_indptr = json.loads(c_sindptr.value)
+        indptr = from_array_interface(i_indptr)
+        assert indptr.size == n_features + 1
+        assert indptr.dtype == np.uint64
+
+        i_data = json.loads(c_sdata.value)
+        data = from_array_interface(i_data)
+        assert data.size == indptr[-1]
+        assert data.dtype == np.float32
+        return indptr, data
+
     def num_row(self) -> int:
         """Get the number of rows in the DMatrix."""
         ret = c_bst_ulong()
diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py
index 862375026..8e2e13f43 100644
--- a/python-package/xgboost/testing/__init__.py
+++ b/python-package/xgboost/testing/__init__.py
@@ -265,6 +265,14 @@ def make_batches(
     return X, y, w
 
 
+def make_regression(
+    n_samples: int, n_features: int, use_cupy: bool
+) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
+    """Make a simple regression dataset."""
+    X, y, w = make_batches(n_samples, n_features, 1, use_cupy)
+    return X[0], y[0], w[0]
+
+
 def make_batches_sparse(
     n_samples_per_batch: int, n_features: int, n_batches: int, sparsity: float
 ) -> Tuple[List[sparse.csr_matrix], List[np.ndarray], List[np.ndarray]]:
diff --git a/python-package/xgboost/testing/updater.py b/python-package/xgboost/testing/updater.py
index 4086f92c8..62df8ec2e 100644
--- a/python-package/xgboost/testing/updater.py
+++ b/python-package/xgboost/testing/updater.py
@@ -1,7 +1,7 @@
 """Tests for updaters."""
 import json
 from functools import partial, update_wrapper
-from typing import Dict
+from typing import Any, Dict
 
 import numpy as np
 
@@ -159,3 +159,100 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
 
     for i in range(alpha.shape[0]):
         np.testing.assert_allclose(predts[:, i], predt_multi[:, i])
+
+
+def check_cut(
+    n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
+) -> None:
+    """Check the cut values."""
+    from pandas.api.types import is_categorical_dtype
+
+    assert data.shape[0] == indptr[-1]
+    assert data.shape[0] == n_entries
+
+    assert indptr.dtype == np.uint64
+    for i in range(1, indptr.size):
+        beg = int(indptr[i - 1])
+        end = int(indptr[i])
+        for j in range(beg + 1, end):
+            assert data[j] > data[j - 1]
+            if is_categorical_dtype(dtypes[i - 1]):
+                assert data[j] == data[j - 1] + 1
+
+
+def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
+    """Check with optional cupy."""
+    from pandas.api.types import is_categorical_dtype
+
+    n_samples = 1024
+    n_features = 14
+    max_bin = 16
+    dtypes = [np.float32] * n_features
+
+    # numerical
+    X, y, w = tm.make_regression(n_samples, n_features, use_cupy=use_cupy)
+    # - qdm
+    Xyw: xgb.DMatrix = xgb.QuantileDMatrix(X, y, weight=w, max_bin=max_bin)
+    indptr, data = Xyw.get_quantile_cut()
+    check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
+    # - dm
+    Xyw = xgb.DMatrix(X, y, weight=w)
+    xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
+    indptr, data = Xyw.get_quantile_cut()
+    check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
+    # - ext mem
+    n_batches = 3
+    n_samples_per_batch = 256
+    it = tm.IteratorForTest(
+        *tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
+        cache="cache",
+    )
+    Xy: xgb.DMatrix = xgb.DMatrix(it)
+    xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
+    indptr, data = Xy.get_quantile_cut()
+    check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
+
+    # categorical
+    n_categories = 32
+    X, y = tm.make_categorical(n_samples, n_features, n_categories, False, sparsity=0.8)
+    if use_cupy:
+        import cudf  # pylint: disable=import-error
+        import cupy as cp  # pylint: disable=import-error
+
+        X = cudf.from_pandas(X)
+        y = cp.array(y)
+    # - qdm
+    Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin, enable_categorical=True)
+    indptr, data = Xy.get_quantile_cut()
+    check_cut(n_categories * n_features, indptr, data, X.dtypes)
+    # - dm
+    Xy = xgb.DMatrix(X, y, enable_categorical=True)
+    xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
+    indptr, data = Xy.get_quantile_cut()
+    check_cut(n_categories * n_features, indptr, data, X.dtypes)
+
+    # mixed
+    X, y = tm.make_categorical(
+        n_samples, n_features, n_categories, False, sparsity=0.8, cat_ratio=0.5
+    )
+    n_cat_features = len([0 for dtype in X.dtypes if is_categorical_dtype(dtype)])
+    n_num_features = n_features - n_cat_features
+    n_entries = n_categories * n_cat_features + (max_bin + 1) * n_num_features
+    # - qdm
+    Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin, enable_categorical=True)
+    indptr, data = Xy.get_quantile_cut()
+    check_cut(n_entries, indptr, data, X.dtypes)
+    # - dm
+    Xy = xgb.DMatrix(X, y, enable_categorical=True)
+    xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
+    indptr, data = Xy.get_quantile_cut()
+    check_cut(n_entries, indptr, data, X.dtypes)
+
+
+def check_get_quantile_cut(tree_method: str) -> None:
+    """Check the quantile cut getter."""
+
+    use_cupy = tree_method == "gpu_hist"
+    check_get_quantile_cut_device(tree_method, False)
+    if use_cupy:
+        check_get_quantile_cut_device(tree_method, True)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 06bd43b2b..4e1f86ff2 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -3,7 +3,7 @@
  */
 #include "xgboost/c_api.h"
 
-#include <algorithm>  // for copy
+#include <algorithm>  // for copy, transform
 #include <cinttypes>  // for strtoimax
 #include <cmath>      // for nan
 #include <cstring>    // for strcmp
@@ -20,9 +20,11 @@
 #include "../collective/communicator-inl.h"  // for Allreduce, Broadcast, Finalize, GetProcessor...
 #include "../common/api_entry.h"             // for XGBAPIThreadLocalEntry
 #include "../common/charconv.h"              // for from_chars, to_chars, NumericLimits, from_ch...
+#include "../common/hist_util.h"             // for HistogramCuts
 #include "../common/io.h"                    // for FileExtension, LoadSequentialFile, MemoryBuf...
 #include "../common/threading_utils.h"       // for OmpGetNumThreads, ParallelFor
 #include "../data/adapter.h"                 // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
+#include "../data/ellpack_page.h" // for EllpackPage #include "../data/proxy_dmatrix.h" // for DMatrixProxy #include "../data/simple_dmatrix.h" // for SimpleDMatrix #include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN @@ -785,6 +787,104 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config API_END(); } +namespace { +template +void GetCutImpl(Context const *ctx, std::shared_ptr p_m, + std::vector *p_indptr, std::vector *p_data) { + auto &indptr = *p_indptr; + auto &data = *p_data; + for (auto const &page : p_m->GetBatches(ctx, {})) { + auto const &cut = page.Cuts(); + + auto const &ptrs = cut.Ptrs(); + indptr.resize(ptrs.size()); + + auto const &vals = cut.Values(); + auto const &mins = cut.MinValues(); + + bst_feature_t n_features = p_m->Info().num_col_; + auto ft = p_m->Info().feature_types.ConstHostSpan(); + std::size_t n_categories = std::count_if(ft.cbegin(), ft.cend(), + [](auto t) { return t == FeatureType::kCategorical; }); + data.resize(vals.size() + n_features - n_categories); // |vals| + |mins| + std::size_t i{0}, n_numeric{0}; + for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { + CHECK_LT(i, data.size()); + bool is_numeric = !common::IsCat(ft, fidx); + if (is_numeric) { + data[i] = mins[fidx]; + i++; + } + auto beg = ptrs[fidx]; + auto end = ptrs[fidx + 1]; + CHECK_LE(end, data.size()); + std::copy(vals.cbegin() + beg, vals.cbegin() + end, data.begin() + i); + i += (end - beg); + // shift by min values. + indptr[fidx] = ptrs[fidx] + n_numeric; + if (is_numeric) { + n_numeric++; + } + } + CHECK_EQ(n_numeric, n_features - n_categories); + + indptr.back() = data.size(); + CHECK_EQ(indptr.back(), vals.size() + mins.size() - n_categories); + break; + } +} +} // namespace + +XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config, + char const **out_indptr, char const **out_data) { + API_BEGIN(); + CHECK_HANDLE(); + + auto p_m = CastDMatrixHandle(handle); + + xgboost_CHECK_C_ARG_PTR(config); + xgboost_CHECK_C_ARG_PTR(out_indptr); + xgboost_CHECK_C_ARG_PTR(out_data); + + auto jconfig = Json::Load(StringView{config}); + + if (!p_m->PageExists() && !p_m->PageExists()) { + LOG(FATAL) << "The quantile cut hasn't been generated yet. Unless this is a `QuantileDMatrix`, " + "quantile cut is generated during training."; + } + // Get return buffer + auto &data = p_m->GetThreadLocal().ret_vec_float; + auto &indptr = p_m->GetThreadLocal().ret_vec_u64; + + if (p_m->PageExists()) { + auto ctx = p_m->Ctx()->IsCPU() ? *p_m->Ctx() : p_m->Ctx()->MakeCPU(); + GetCutImpl(&ctx, p_m, &indptr, &data); + } else { + auto ctx = p_m->Ctx()->IsCUDA() ? 
+    auto ctx = p_m->Ctx()->IsCUDA() ? *p_m->Ctx() : p_m->Ctx()->MakeCUDA(0);
+    GetCutImpl<EllpackPage>(&ctx, p_m, &indptr, &data);
+  }
+
+  // Create a CPU context
+  Context ctx;
+  // Get return buffer
+  auto &ret_vec_str = p_m->GetThreadLocal().ret_vec_str;
+  ret_vec_str.clear();
+
+  ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(
+      linalg::MakeTensorView(&ctx, common::Span{indptr.data(), indptr.size()}, indptr.size())));
+  ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(
+      linalg::MakeTensorView(&ctx, common::Span{data.data(), data.size()}, data.size())));
+
+  auto &charp_vecs = p_m->GetThreadLocal().ret_vec_charp;
+  charp_vecs.resize(ret_vec_str.size());
+  std::transform(ret_vec_str.cbegin(), ret_vec_str.cend(), charp_vecs.begin(),
+                 [](auto const &str) { return str.c_str(); });
+
+  *out_indptr = charp_vecs[0];
+  *out_data = charp_vecs[1];
+  API_END();
+}
+
 // xgboost implementation
 XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
                             xgboost::bst_ulong len,
diff --git a/src/common/api_entry.h b/src/common/api_entry.h
index db3bcfbc3..df1fcd704 100644
--- a/src/common/api_entry.h
+++ b/src/common/api_entry.h
@@ -24,6 +24,8 @@ struct XGBAPIThreadLocalEntry {
   std::vector<const char *> ret_vec_charp;
   /*! \brief returning float vector. */
   std::vector<float> ret_vec_float;
+  /*! \brief returning uint vector. */
+  std::vector<std::uint64_t> ret_vec_u64;
   /*! \brief temp variable of gradient pairs. */
   std::vector<GradientPair> tmp_gpair;
   /*! \brief Temp variable for returning prediction result. */
diff --git a/src/data/array_interface.h b/src/data/array_interface.h
index bd66c2a53..99effffef 100644
--- a/src/data/array_interface.h
+++ b/src/data/array_interface.h
@@ -455,7 +455,7 @@ class ArrayInterface {
   explicit ArrayInterface(std::string const &str) : ArrayInterface{StringView{str}} {}
 
-  explicit ArrayInterface(StringView str) : ArrayInterface{Json::Load(str)} {}
+  explicit ArrayInterface(StringView str) : ArrayInterface{Json::Load(str)} {}
 
   void AssignType(StringView typestr) {
     using T = ArrayInterfaceHandler::Type;
diff --git a/src/data/ellpack_page.cc b/src/data/ellpack_page.cc
index 1fd8f12b2..59cfd1943 100644
--- a/src/data/ellpack_page.cc
+++ b/src/data/ellpack_page.cc
@@ -3,12 +3,20 @@
  */
 #ifndef XGBOOST_USE_CUDA
 
+#include "ellpack_page.h"
+
 #include
 
 // dummy implementation of EllpackPage in case CUDA is not used
 namespace xgboost {
-class EllpackPageImpl {};
+class EllpackPageImpl {
+  common::HistogramCuts cuts_;
+
+ public:
+  [[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
+  [[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
+};
 
 EllpackPage::EllpackPage() = default;
 
@@ -32,6 +40,17 @@ size_t EllpackPage::Size() const {
   return 0;
 }
 
+[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
+  LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
+                "EllpackPage is required";
+  return impl_->Cuts();
+}
+
+[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
+  LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
+                "EllpackPage is required";
+  return impl_->Cuts();
+}
 }  // namespace xgboost
 
 #endif  // XGBOOST_USE_CUDA
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 13fcf9adf..0ccd7a081 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -4,6 +4,10 @@
 #include
 #include
 
+#include <algorithm>  // for copy
+#include <utility>    // for move
+#include <vector>     // for vector
+
 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"
 #include "../common/hist_util.cuh"
@@ -11,6 +15,7 @@
 #include "../common/transform_iterator.h"  // MakeIndexTransformIter
"./ellpack_page.cuh" #include "device_adapter.cuh" // for HasInfInData +#include "ellpack_page.h" #include "gradient_index.h" #include "xgboost/data.h" @@ -29,6 +34,16 @@ size_t EllpackPage::Size() const { return impl_->Size(); } void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); } +[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() { + CHECK(impl_); + return impl_->Cuts(); +} + +[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const { + CHECK(impl_); + return impl_->Cuts(); +} + // Bin each input data entry, store the bin indices in compressed form. __global__ void CompressBinEllpackKernel( common::CompressedBufferWriter wr, diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index ee6a2c221..96963463b 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -1,17 +1,18 @@ -/*! - * Copyright 2019 by XGBoost Contributors +/** + * Copyright 2019-2023, XGBoost Contributors */ -#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_ -#define XGBOOST_DATA_ELLPACK_PAGE_H_ +#ifndef XGBOOST_DATA_ELLPACK_PAGE_CUH_ +#define XGBOOST_DATA_ELLPACK_PAGE_CUH_ +#include #include +#include "../common/categorical.h" #include "../common/compressed_iterator.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" -#include "../common/categorical.h" -#include +#include "ellpack_page.h" namespace xgboost { /** \brief Struct for accessing and manipulating an ELLPACK matrix on the @@ -194,8 +195,8 @@ class EllpackPageImpl { base_rowid = row_id; } - common::HistogramCuts& Cuts() { return cuts_; } - common::HistogramCuts const& Cuts() const { return cuts_; } + [[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; } + [[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; } /*! \return Estimation of memory cost of this page. */ static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ; @@ -256,4 +257,4 @@ inline size_t GetRowStride(DMatrix* dmat) { } } // namespace xgboost -#endif // XGBOOST_DATA_ELLPACK_PAGE_H_ +#endif // XGBOOST_DATA_ELLPACK_PAGE_CUH_ diff --git a/src/data/ellpack_page.h b/src/data/ellpack_page.h new file mode 100644 index 000000000..07d6949b1 --- /dev/null +++ b/src/data/ellpack_page.h @@ -0,0 +1,59 @@ +/** + * Copyright 2017-2023 by XGBoost Contributors + */ +#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_ +#define XGBOOST_DATA_ELLPACK_PAGE_H_ + +#include // for unique_ptr + +#include "../common/hist_util.h" // for HistogramCuts +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for DMatrix, BatchParam + +namespace xgboost { +class EllpackPageImpl; +/** + * @brief A page stored in ELLPACK format. + * + * This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid + * including CUDA-specific implementation details in the header. + */ +class EllpackPage { + public: + /** + * @brief Default constructor. + * + * This is used in the external memory case. An empty ELLPACK page is constructed with its content + * set later by the reader. + */ + EllpackPage(); + /** + * @brief Constructor from an existing DMatrix. + * + * This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix + * in CSR format. + */ + explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param); + + /*! \brief Destructor. */ + ~EllpackPage(); + + EllpackPage(EllpackPage&& that); + + /*! \return Number of instances in the page. */ + [[nodiscard]] size_t Size() const; + + /*! 
+  /*! \brief Set the base row id for this page. */
+  void SetBaseRowId(std::size_t row_id);
+
+  [[nodiscard]] const EllpackPageImpl* Impl() const { return impl_.get(); }
+  EllpackPageImpl* Impl() { return impl_.get(); }
+
+  [[nodiscard]] common::HistogramCuts& Cuts();
+  [[nodiscard]] common::HistogramCuts const& Cuts() const;
+
+ private:
+  std::unique_ptr<EllpackPageImpl> impl_;
+};
+}  // namespace xgboost
+#endif  // XGBOOST_DATA_ELLPACK_PAGE_H_
diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu
index fb414f4ae..abfc400c1 100644
--- a/src/data/ellpack_page_source.cu
+++ b/src/data/ellpack_page_source.cu
@@ -5,10 +5,10 @@
 #include
 
 #include "ellpack_page.cuh"
+#include "ellpack_page.h"  // for EllpackPage
 #include "ellpack_page_source.h"
 
-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 void EllpackPageSource::Fetch() {
   dh::safe_cuda(cudaSetDevice(device_));
   if (!this->ReadCache()) {
@@ -27,5 +27,4 @@ void EllpackPageSource::Fetch() {
     this->WriteCache();
   }
 }
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h
index 121ffcf9e..146db94ed 100644
--- a/src/data/ellpack_page_source.h
+++ b/src/data/ellpack_page_source.h
@@ -6,17 +6,17 @@
 #define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
 
 #include
+
 #include
 #include
 #include
 
 #include "../common/common.h"
 #include "../common/hist_util.h"
+#include "ellpack_page.h"  // for EllpackPage
 #include "sparse_page_source.h"
 
-namespace xgboost {
-namespace data {
-
+namespace xgboost::data {
 class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
   bool is_dense_;
   size_t row_stride_;
@@ -53,7 +53,6 @@ inline void EllpackPageSource::Fetch() { common::AssertGPUSupport(); }
 #endif  // !defined(XGBOOST_USE_CUDA)
 
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
 
 #endif  // XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index 840be4b06..901451ad9 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -245,6 +245,9 @@ class GHistIndexMatrix {
                   std::vector<float> const& values, std::vector<float> const& mins,
                   bst_row_t ridx, bst_feature_t fidx, bool is_cat) const;
 
+  [[nodiscard]] common::HistogramCuts& Cuts() { return cut; }
+  [[nodiscard]] common::HistogramCuts const& Cuts() const { return cut; }
+
  private:
   std::unique_ptr<common::ColumnMatrix> columns_;
   std::vector<size_t> hit_count_tloc_;
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index b77c8fd84..5a2f6f8df 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -16,7 +16,8 @@
 #include "../common/threading_utils.h"
 #include "./simple_batch_iterator.h"
 #include "adapter.h"
-#include "batch_utils.h"  // for CheckEmpty, RegenGHist
+#include "batch_utils.h"   // for CheckEmpty, RegenGHist
+#include "ellpack_page.h"  // for EllpackPage
 #include "gradient_index.h"
 #include "xgboost/c_api.h"
 #include "xgboost/data.h"
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index f84fa8c01..ec9c90b10 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -165,7 +165,10 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
 
 BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
                                                                const BatchParam &param) {
-  CHECK_GE(param.max_bin, 2);
+  if (param.Initialized()) {
+    CHECK_GE(param.max_bin, 2);
+  }
+
   detail::CheckEmpty(batch_param_, param);
   auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
   this->InitializeSparsePage(ctx);
   if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu
index 0a4cde43d..38304f725 100644
--- a/src/data/sparse_page_dmatrix.cu
+++ b/src/data/sparse_page_dmatrix.cu
@@ -1,6 +1,8 @@
 /**
  * Copyright 2021-2023 by XGBoost contributors
  */
+#include <memory>
+
 #include "../common/hist_util.cuh"
 #include "batch_utils.h"  // for CheckEmpty, RegenGHist
 #include "ellpack_page.cuh"
@@ -11,7 +13,9 @@ namespace xgboost::data {
 BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
                                                            const BatchParam& param) {
   CHECK(ctx->IsCUDA());
-  CHECK_GE(param.max_bin, 2);
+  if (param.Initialized()) {
+    CHECK_GE(param.max_bin, 2);
+  }
   detail::CheckEmpty(batch_param_, param);
   auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
   size_t row_stride = 0;
@@ -21,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
     cache_info_.erase(id);
     MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
     std::unique_ptr<common::HistogramCuts> cuts;
-    cuts.reset(
-        new common::HistogramCuts{common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0)});
+    cuts = std::make_unique<common::HistogramCuts>(
+        common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0));
     this->InitializeSparsePage(ctx);  // reset after use.
 
     row_stride = GetRowStride(this);
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 9378bde20..e2a863e3d 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -21,6 +21,7 @@
 #include "../common/io.h"
 #include "../common/timer.h"
 #include "../data/ellpack_page.cuh"
+#include "../data/ellpack_page.h"
 #include "constraints.cuh"
 #include "driver.h"
 #include "gpu_hist/evaluate_splits.cuh"
diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc
index 675da940c..4e1b342ae 100644
--- a/tests/cpp/c_api/test_c_api.cc
+++ b/tests/cpp/c_api/test_c_api.cc
@@ -8,6 +8,7 @@
 #include
 #include
 
+#include <array>    // for array
 #include <cstddef>  // std::size_t
 #include <limits>   // std::numeric_limits
 #include <string>   // std::string
@@ -15,6 +16,11 @@
 #include "../../../src/c_api/c_api_error.h"
 #include "../../../src/common/io.h"
+#include "../../../src/data/adapter.h"              // for ArrayAdapter
+#include "../../../src/data/array_interface.h"      // for ArrayInterface
+#include "../../../src/data/gradient_index.h"       // for GHistIndexMatrix
+#include "../../../src/data/iterative_dmatrix.h"    // for IterativeDMatrix
+#include "../../../src/data/sparse_page_dmatrix.h"  // for SparsePageDMatrix
 #include "../helpers.h"
 
 TEST(CAPI, XGDMatrixCreateFromMatDT) {
@@ -137,9 +143,9 @@ TEST(CAPI, ConfigIO) {
   BoosterHandle handle = learner.get();
   learner->UpdateOneIter(0, p_dmat);
 
-  char const* out[1];
+  std::array<char const *, 1> out;
   bst_ulong len {0};
-  XGBoosterSaveJsonConfig(handle, &len, out);
+  XGBoosterSaveJsonConfig(handle, &len, out.data());
 
   std::string config_str_0 { out[0] };
   auto config_0 = Json::Load({config_str_0.c_str(), config_str_0.size()});
@@ -147,7 +153,7 @@ TEST(CAPI, ConfigIO) {
   bst_ulong len_1 {0};
   std::string config_str_1 { out[0] };
-  XGBoosterSaveJsonConfig(handle, &len_1, out);
+  XGBoosterSaveJsonConfig(handle, &len_1, out.data());
   auto config_1 = Json::Load({config_str_1.c_str(), config_str_1.size()});
 
   ASSERT_EQ(config_0, config_1);
@@ -266,9 +272,9 @@ TEST(CAPI, DMatrixSetFeatureName) {
     ASSERT_EQ(std::to_string(i), c_out_features[i]);
   }
 
-  char const* feat_types [] {"i", "q"};
+  std::array<char const *, 2> feat_types{"i", "q"};
   static_assert(sizeof(feat_types) / sizeof(feat_types[0]) == kCols);
"feature_type", feat_types, kCols); + XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types.data(), kCols); char const **c_out_types; XGDMatrixGetStrFeatureInfo(handle, u8"feature_type", &out_len, &c_out_types); @@ -405,4 +411,210 @@ TEST(CAPI, JArgs) { ASSERT_THROW({ RequiredArg(args, "null", __func__); }, dmlc::Error); } } + +namespace { +void MakeLabelForTest(std::shared_ptr Xy, DMatrixHandle cxy) { + auto n_samples = Xy->Info().num_row_; + std::vector y(n_samples); + for (std::size_t i = 0; i < y.size(); ++i) { + y[i] = static_cast(i); + } + + Xy->Info().labels.Reshape(n_samples); + Xy->Info().labels.Data()->HostVector() = y; + + auto y_int = GetArrayInterface(Xy->Info().labels.Data(), n_samples, 1); + std::string s_y_int; + Json::Dump(y_int, &s_y_int); + + XGDMatrixSetInfoFromInterface(cxy, "label", s_y_int.c_str()); +} + +auto MakeSimpleDMatrixForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) { + HostDeviceVector storage; + auto arr_int = RandomDataGenerator{n_samples, n_features, 0.5f}.GenerateArrayInterface(&storage); + + data::ArrayAdapter adapter{StringView{arr_int}}; + std::shared_ptr Xy{ + DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), Context{}.Threads())}; + + DMatrixHandle p_fmat; + std::string s_dconfig; + Json::Dump(dconfig, &s_dconfig); + CHECK_EQ(XGDMatrixCreateFromDense(arr_int.c_str(), s_dconfig.c_str(), &p_fmat), 0); + + MakeLabelForTest(Xy, p_fmat); + return std::pair{p_fmat, Xy}; +} + +auto MakeQDMForTest(Context const *ctx, bst_row_t n_samples, bst_feature_t n_features, + Json dconfig) { + bst_bin_t n_bins{16}; + dconfig["max_bin"] = Integer{n_bins}; + + std::size_t n_batches{4}; + std::unique_ptr iter_0; + if (ctx->IsCUDA()) { + iter_0 = std::make_unique(0.0f, n_samples, n_features, n_batches); + } else { + iter_0 = std::make_unique(0.0f, n_samples, n_features, n_batches); + } + std::string s_dconfig; + Json::Dump(dconfig, &s_dconfig); + DMatrixHandle p_fmat; + CHECK_EQ(XGQuantileDMatrixCreateFromCallback(static_cast(iter_0.get()), + iter_0->Proxy(), nullptr, Reset, Next, + s_dconfig.c_str(), &p_fmat), + 0); + + std::unique_ptr iter_1; + if (ctx->IsCUDA()) { + iter_1 = std::make_unique(0.0f, n_samples, n_features, n_batches); + } else { + iter_1 = std::make_unique(0.0f, n_samples, n_features, n_batches); + } + auto Xy = + std::make_shared(iter_1.get(), iter_1->Proxy(), nullptr, Reset, Next, + std::numeric_limits::quiet_NaN(), 0, n_bins); + return std::pair{p_fmat, Xy}; +} + +auto MakeExtMemForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) { + std::size_t n_batches{4}; + NumpyArrayIterForTest iter_0{0.0f, n_samples, n_features, n_batches}; + std::string s_dconfig; + dconfig["cache_prefix"] = String{"cache"}; + Json::Dump(dconfig, &s_dconfig); + DMatrixHandle p_fmat; + CHECK_EQ(XGDMatrixCreateFromCallback(static_cast(&iter_0), iter_0.Proxy(), Reset, + Next, s_dconfig.c_str(), &p_fmat), + 0); + + NumpyArrayIterForTest iter_1{0.0f, n_samples, n_features, n_batches}; + auto Xy = std::make_shared( + &iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits::quiet_NaN(), 0, ""); + MakeLabelForTest(Xy, p_fmat); + return std::pair{p_fmat, Xy}; +} + +template +void CheckResult(Context const *ctx, bst_feature_t n_features, std::shared_ptr Xy, + float const *out_data, std::uint64_t const *out_indptr) { + for (auto const &page : Xy->GetBatches(ctx, BatchParam{16, 0.2})) { + auto const &cut = page.Cuts(); + auto const &ptrs = cut.Ptrs(); + auto const &vals = cut.Values(); + auto const &mins = cut.MinValues(); + for 
+    for (bst_feature_t f = 0; f < Xy->Info().num_col_; ++f) {
+      ASSERT_EQ(ptrs[f] + f, out_indptr[f]);
+      ASSERT_EQ(mins[f], out_data[out_indptr[f]]);
+      auto beg = out_indptr[f];
+      auto end = out_indptr[f + 1];
+      auto val_beg = ptrs[f];
+      for (std::uint64_t i = beg + 1, j = val_beg; i < end; ++i, ++j) {
+        ASSERT_EQ(vals[j], out_data[i]);
+      }
+    }
+
+    ASSERT_EQ(ptrs[n_features] + n_features, out_indptr[n_features]);
+  }
+}
+
+void TestXGDMatrixGetQuantileCut(Context const *ctx) {
+  bst_row_t n_samples{1024};
+  bst_feature_t n_features{16};
+
+  Json dconfig{Object{}};
+  dconfig["nthread"] = Integer{Context{}.Threads()};
+  dconfig["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
+
+  auto check_result = [n_features, &ctx](std::shared_ptr<DMatrix> Xy, StringView s_out_data,
+                                         StringView s_out_indptr) {
+    auto i_out_data = ArrayInterface<1, false>{s_out_data};
+    ASSERT_EQ(i_out_data.type, ArrayInterfaceHandler::kF4);
+    auto out_data = static_cast<float const *>(i_out_data.data);
+    ASSERT_TRUE(out_data);
+
+    auto i_out_indptr = ArrayInterface<1, false>{s_out_indptr};
+    ASSERT_EQ(i_out_indptr.type, ArrayInterfaceHandler::kU8);
+    auto out_indptr = static_cast<std::uint64_t const *>(i_out_indptr.data);
+    ASSERT_TRUE(out_indptr);
+
+    if (ctx->IsCPU()) {
+      CheckResult<GHistIndexMatrix>(ctx, n_features, Xy, out_data, out_indptr);
+    } else {
+      CheckResult<EllpackPage>(ctx, n_features, Xy, out_data, out_indptr);
+    }
+  };
+
+  Json config{Null{}};
+  std::string s_config;
+  Json::Dump(config, &s_config);
+  char const *out_indptr;
+  char const *out_data;
+
+  {
+    // SimpleDMatrix
+    auto [p_fmat, Xy] = MakeSimpleDMatrixForTest(n_samples, n_features, dconfig);
+    // assert fail, we don't have the quantile yet.
+    ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), -1);
+
+    std::array mats{p_fmat};
+    BoosterHandle booster;
+    ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
+    ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
+    if (ctx->IsCUDA()) {
+      ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0);
+    }
+    ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
+    ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
+
+    check_result(Xy, out_data, out_indptr);
+
+    XGDMatrixFree(p_fmat);
+    XGBoosterFree(booster);
+  }
+
+  {
+    // IterativeDMatrix
+    auto [p_fmat, Xy] = MakeQDMForTest(ctx, n_samples, n_features, dconfig);
+    ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
+
+    check_result(Xy, out_data, out_indptr);
+    XGDMatrixFree(p_fmat);
+  }
+
+  {
+    // SparsePageDMatrix
+    auto [p_fmat, Xy] = MakeExtMemForTest(n_samples, n_features, dconfig);
+    // assert fail, we don't have the quantile yet.
+    ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), -1);
+
+    std::array mats{p_fmat};
+    BoosterHandle booster;
+    ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
+    ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
+    if (ctx->IsCUDA()) {
+      ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0);
+    }
+    ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
+    ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
+
+    XGDMatrixFree(p_fmat);
+    XGBoosterFree(booster);
+  }
+}
+}  // namespace
+
+TEST(CAPI, XGDMatrixGetQuantileCut) {
+  Context ctx;
+  TestXGDMatrixGetQuantileCut(&ctx);
+}
+
+#if defined(XGBOOST_USE_CUDA)
+TEST(CAPI, GPUXGDMatrixGetQuantileCut) {
+  auto ctx = MakeCUDACtx(0);
+  TestXGDMatrixGetQuantileCut(&ctx);
+}
+#endif  // defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index d56f1c7b5..4b279a1a4 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -8,6 +8,7 @@
 #include "../../../src/common/categorical.h"
 #include "../../../src/common/hist_util.h"
 #include "../../../src/data/ellpack_page.cuh"
+#include "../../../src/data/ellpack_page.h"
 #include "../../../src/tree/param.h"  // TrainParam
 #include "../helpers.h"
 #include "../histogram_helpers.h"
diff --git a/tests/cpp/data/test_iterative_dmatrix.cu b/tests/cpp/data/test_iterative_dmatrix.cu
index 2f2f1f84f..6b856f3fa 100644
--- a/tests/cpp/data/test_iterative_dmatrix.cu
+++ b/tests/cpp/data/test_iterative_dmatrix.cu
@@ -5,6 +5,7 @@
 
 #include "../../../src/data/device_adapter.cuh"
 #include "../../../src/data/ellpack_page.cuh"
+#include "../../../src/data/ellpack_page.h"
 #include "../../../src/data/iterative_dmatrix.h"
 #include "../../../src/tree/param.h"  // TrainParam
 #include "../helpers.h"
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu
index f2f828507..17ed64c90 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cu
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cu
@@ -5,6 +5,7 @@
 
 #include "../../../src/common/compressed_iterator.h"
 #include "../../../src/data/ellpack_page.cuh"
+#include "../../../src/data/ellpack_page.h"
 #include "../../../src/data/sparse_page_dmatrix.h"
 #include "../../../src/tree/param.h"  // TrainParam
 #include "../filesystem.h"  // dmlc::TemporaryDirectory
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 4f44b7b1e..111c7b30e 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -472,6 +472,18 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_
   return m;
 }
 
+#if !defined(XGBOOST_USE_CUDA)
+CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
+    : ArrayIterForTest{sparsity, rows, cols, batches} {
+  common::AssertGPUSupport();
+}
+
+int CudaArrayIterForTest::Next() {
+  common::AssertGPUSupport();
+  return 0;
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
+
 NumpyArrayIterForTest::NumpyArrayIterForTest(float sparsity, size_t rows, size_t cols,
                                              size_t batches)
     : ArrayIterForTest{sparsity, rows, cols, batches} {
@@ -650,7 +662,7 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
 ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
     : rows_{rows}, cols_{cols}, n_batches_{batches} {
   XGProxyDMatrixCreate(&proxy_);
-  rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity});
+  rng_ = std::make_unique<RandomDataGenerator>(rows_, cols_, sparsity);
   std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
 }
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index fd3034db5..b250cd2ab 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -11,6 +11,8 @@
 #include
 
 #include "../../../src/common/common.h"
+#include "../../../src/data/ellpack_page.cuh"  // for EllpackPageImpl
+#include "../../../src/data/ellpack_page.h"    // for EllpackPage
 #include "../../../src/data/sparse_page_source.h"
 #include "../../../src/tree/constraints.cuh"
 #include "../../../src/tree/param.h"  // for TrainParam
diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py
index 70080b13a..71667fa7b 100644
--- a/tests/python-gpu/test_from_cupy.py
+++ b/tests/python-gpu/test_from_cupy.py
@@ -1,3 +1,4 @@
+import json
 import sys
 
 import numpy as np
@@ -10,6 +11,16 @@ from test_dmatrix import set_base_margin_info
 
 from xgboost import testing as tm
 
+cupy = pytest.importorskip("cupy")
+
+
+def test_array_interface() -> None:
+    arr = cupy.array([[1, 2, 3, 4], [1, 2, 3, 4]])
+    i_arr = arr.__cuda_array_interface__
+    i_arr = json.loads(json.dumps(i_arr))
+    ret = xgb.core.from_array_interface(i_arr)
+    np.testing.assert_equal(cupy.asnumpy(arr), cupy.asnumpy(ret))
+
 
 def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
     '''Test constructing DMatrix from cupy'''
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index a6b183daf..7fea42f60 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -8,7 +8,11 @@ from hypothesis import assume, given, note, settings, strategies
 import xgboost as xgb
 from xgboost import testing as tm
 from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
-from xgboost.testing.updater import check_init_estimation, check_quantile_loss
+from xgboost.testing.updater import (
+    check_get_quantile_cut,
+    check_init_estimation,
+    check_quantile_loss,
+)
 
 sys.path.append("tests/python")
 import test_updaters as test_up
@@ -264,3 +268,7 @@ class TestGPUUpdaters:
             },
             num_boost_round=150,
         )
+
+    @pytest.mark.skipif(**tm.no_cudf())
+    def test_get_quantile_cut(self) -> None:
+        check_get_quantile_cut("gpu_hist")
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index 095c9936a..2027942fe 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -14,7 +14,11 @@ from xgboost.testing.params import (
     hist_multi_parameter_strategy,
     hist_parameter_strategy,
 )
-from xgboost.testing.updater import check_init_estimation, check_quantile_loss
+from xgboost.testing.updater import (
+    check_get_quantile_cut,
+    check_init_estimation,
+    check_quantile_loss,
+)
 
 
 def train_result(param, dmat, num_rounds):
@@ -537,3 +541,8 @@ class TestTreeMethod:
     @pytest.mark.parametrize("weighted", [True, False])
     def test_quantile_loss(self, weighted: bool) -> None:
         check_quantile_loss("hist", weighted)
+
+    @pytest.mark.skipif(**tm.no_pandas())
+    @pytest.mark.parametrize("tree_method", ["hist"])
+    def test_get_quantile_cut(self, tree_method: str) -> None:
+        check_get_quantile_cut(tree_method)
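---

For reviewers, a minimal usage sketch of the new getter from the Python side. It relies only on what this diff adds (`DMatrix.get_quantile_cut` returning a `uint64` indptr and `float32` data array; cuts available eagerly for `QuantileDMatrix`, and for a plain `DMatrix` only after a `hist`/`approx` training round). The toy data and parameter values are made up for illustration:

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(1024, 4)).astype(np.float32)
y = rng.normal(size=1024).astype(np.float32)

# A QuantileDMatrix sketches its input eagerly, so the cuts can be
# fetched without training a model first.
Xy = xgb.QuantileDMatrix(X, y, max_bin=16)
indptr, data = Xy.get_quantile_cut()

# One indptr entry per feature boundary; each numerical feature stores
# its minimum value followed by max_bin cut points.
assert indptr.dtype == np.uint64 and indptr.size == X.shape[1] + 1
assert data.dtype == np.float32 and data.size == indptr[-1]

# For a plain DMatrix the cuts exist only after a histogram-based
# booster has generated them; calling the getter earlier raises
# XGBoostError (the C API returns -1, as the C++ tests above assert).
Xy2 = xgb.DMatrix(X, y)
xgb.train({"tree_method": "hist", "max_bin": 16}, Xy2, num_boost_round=1)
indptr2, data2 = Xy2.get_quantile_cut()
```

The returned buffers are thread-local strings owned by the DMatrix handle (see `ret_vec_u64` / `ret_vec_float` in `api_entry.h`), which is why the C API hands back JSON `__array_interface__` documents rather than raw pointers the caller would have to free.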