Support exporting cut values (#9356)

This commit is contained in:
Jiaming Yuan 2023-07-08 15:32:41 +08:00 committed by GitHub
parent c3124813e8
commit 20c52f07d2
28 changed files with 722 additions and 101 deletions

View File

@ -810,7 +810,7 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
*/
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
/*!
/**
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
* quantized DMatrix, quantized values are returned instead.
*
@ -819,8 +819,10 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
 * XGBoost. This is to avoid allocating a huge memory buffer that cannot be freed until
* exiting the thread.
*
* @since 1.7.0
*
* \param handle the handle to the DMatrix
* \param config Json configuration string. At the moment it should be an empty document,
* \param config JSON configuration string. At the moment it should be an empty document,
* preserved for future use.
* \param out_indptr indptr of output CSR matrix.
* \param out_indices Column index of output CSR matrix.
@ -831,6 +833,24 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
bst_ulong *out_indptr, unsigned *out_indices, float *out_data);
/**
* @brief Export the quantile cuts used for training histogram-based models like `hist` and
* `approx`. Useful for model compression.
*
* @since 2.0.0
*
* @param handle the handle to the DMatrix
* @param config JSON configuration string. At the moment it should be an empty document,
* preserved for future use.
*
* @param out_indptr indptr of output CSC matrix represented by a JSON encoded
* __(cuda_)array_interface__.
* @param out_data Data value of CSC matrix represented by a JSON encoded
* __(cuda_)array_interface__.
*/
XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
char const **out_indptr, char const **out_data);
/** @} */ // End of DMatrix
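
For callers outside the Python package, here is a minimal sketch of driving this entry point through ctypes. It is illustrative only: `lib` (a loaded libxgboost) and `handle` (a valid `DMatrixHandle`) are assumed to exist, and the config document is empty as documented above.

```python
import ctypes
import json

# Hypothetical sketch: `lib` is a loaded libxgboost, `handle` a DMatrixHandle.
indptr_json = ctypes.c_char_p()
data_json = ctypes.c_char_p()
rc = lib.XGDMatrixGetQuantileCut(
    handle, b"{}", ctypes.byref(indptr_json), ctypes.byref(data_json)
)
assert rc == 0
# Both outputs are JSON-encoded __(cuda_)array_interface__ documents.
indptr_interface = json.loads(indptr_json.value)
```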
/**

View File

@ -282,7 +282,7 @@ struct BatchParam {
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
bool ParamNotEqual(BatchParam const& other) const {
[[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {
// Check non-floating parameters.
bool cond = max_bin != other.max_bin;
// Check sparse thresh.
@ -293,11 +293,11 @@ struct BatchParam {
return cond;
}
bool Initialized() const { return max_bin != 0; }
[[nodiscard]] bool Initialized() const { return max_bin != 0; }
/**
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
*/
BatchParam MakeCache() const {
[[nodiscard]] BatchParam MakeCache() const {
auto p = *this;
// These parameters have nothing to do with how the gradient index was generated in the
// first place.
@ -319,7 +319,7 @@ struct HostSparsePageView {
static_cast<Inst::index_type>(size)};
}
size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
[[nodiscard]] size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
};
/*!
@ -337,7 +337,7 @@ class SparsePage {
/*! \brief an instance of sparse vector in the batch */
using Inst = common::Span<Entry const>;
HostSparsePageView GetView() const {
[[nodiscard]] HostSparsePageView GetView() const {
return {offset.ConstHostSpan(), data.ConstHostSpan()};
}
@ -353,12 +353,12 @@ class SparsePage {
virtual ~SparsePage() = default;
/*! \return Number of instances in the page. */
inline size_t Size() const {
[[nodiscard]] size_t Size() const {
return offset.Size() == 0 ? 0 : offset.Size() - 1;
}
/*! \return estimation of memory cost of this page */
inline size_t MemCostBytes() const {
[[nodiscard]] size_t MemCostBytes() const {
return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
}
@ -376,7 +376,7 @@ class SparsePage {
base_rowid = row_id;
}
SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
[[nodiscard]] SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
/**
* \brief Sort the column index.
@ -385,7 +385,7 @@ class SparsePage {
/**
* \brief Check whether the column index is sorted.
*/
bool IsIndicesSorted(int32_t n_threads) const;
[[nodiscard]] bool IsIndicesSorted(int32_t n_threads) const;
/**
* \brief Reindex the column index with an offset.
*/
@ -440,49 +440,7 @@ class SortedCSCPage : public SparsePage {
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
class EllpackPageImpl;
/*!
* \brief A page stored in ELLPACK format.
*
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
* including CUDA-specific implementation details in the header.
*/
class EllpackPage {
public:
/*!
* \brief Default constructor.
*
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
* set later by the reader.
*/
EllpackPage();
/*!
* \brief Constructor from an existing DMatrix.
*
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
/*! \brief Destructor. */
~EllpackPage();
EllpackPage(EllpackPage&& that);
/*! \return Number of instances in the page. */
size_t Size() const;
/*! \brief Set the base row id for this page. */
void SetBaseRowId(std::size_t row_id);
const EllpackPageImpl* Impl() const { return impl_.get(); }
EllpackPageImpl* Impl() { return impl_.get(); }
private:
std::unique_ptr<EllpackPageImpl> impl_;
};
class EllpackPage;
class GHistIndexMatrix;
template<typename T>
@ -492,7 +450,7 @@ class BatchIteratorImpl {
virtual ~BatchIteratorImpl() = default;
virtual const T& operator*() const = 0;
virtual BatchIteratorImpl& operator++() = 0;
virtual bool AtEnd() const = 0;
[[nodiscard]] virtual bool AtEnd() const = 0;
virtual std::shared_ptr<T const> Page() const = 0;
};
@ -519,12 +477,12 @@ class BatchIterator {
return !impl_->AtEnd();
}
bool AtEnd() const {
[[nodiscard]] bool AtEnd() const {
CHECK(impl_ != nullptr);
return impl_->AtEnd();
}
std::shared_ptr<T const> Page() const {
[[nodiscard]] std::shared_ptr<T const> Page() const {
return impl_->Page();
}
@ -563,15 +521,15 @@ class DMatrix {
this->Info().SetInfo(ctx, key, StringView{interface_str});
}
/*! \brief meta information of the dataset */
virtual const MetaInfo& Info() const = 0;
[[nodiscard]] virtual const MetaInfo& Info() const = 0;
/*! \brief Get thread local memory for returning data from DMatrix. */
XGBAPIThreadLocalEntry& GetThreadLocal() const;
[[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
/**
* \brief Get the context object of this DMatrix. The context is created during construction of
* DMatrix with user specified `nthread` parameter.
*/
virtual Context const* Ctx() const = 0;
[[nodiscard]] virtual Context const* Ctx() const = 0;
/**
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
@ -583,16 +541,16 @@ class DMatrix {
template <typename T>
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
template <typename T>
bool PageExists() const;
[[nodiscard]] bool PageExists() const;
// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns are stored in a single column block. */
virtual bool SingleColBlock() const = 0;
[[nodiscard]] virtual bool SingleColBlock() const = 0;
/*! \brief virtual destructor */
virtual ~DMatrix();
/*! \brief Whether the matrix is dense. */
bool IsDense() const {
[[nodiscard]] bool IsDense() const {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}
@ -695,9 +653,9 @@ class DMatrix {
BatchParam const& param) = 0;
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
[[nodiscard]] virtual bool EllpackExists() const = 0;
[[nodiscard]] virtual bool GHistIndexExists() const = 0;
[[nodiscard]] virtual bool SparsePageExists() const = 0;
};
template <>

View File

@ -3,6 +3,7 @@
"""Core XGBoost Library."""
import copy
import ctypes
import importlib.util
import json
import os
import re
@ -381,6 +382,54 @@ def c_array(
return (ctype * len(values))(*values)
def from_array_interface(interface: dict) -> NumpyOrCupy:
"""Convert array interface to numpy or cupy array"""
class Array: # pylint: disable=too-few-public-methods
"""Wrapper type for communicating with numpy and cupy."""
_interface: Optional[dict] = None
@property
def __array_interface__(self) -> Optional[dict]:
return self._interface
@__array_interface__.setter
def __array_interface__(self, interface: dict) -> None:
self._interface = copy.copy(interface)
# converts some fields to tuple as required by numpy
self._interface["shape"] = tuple(self._interface["shape"])
self._interface["data"] = tuple(self._interface["data"])
if self._interface.get("strides", None) is not None:
self._interface["strides"] = tuple(self._interface["strides"])
@property
def __cuda_array_interface__(self) -> Optional[dict]:
return self.__array_interface__
@__cuda_array_interface__.setter
def __cuda_array_interface__(self, interface: dict) -> None:
self.__array_interface__ = interface
arr = Array()
if "stream" in interface:
# A CUDA stream is present, so this is a __cuda_array_interface__.
spec = importlib.util.find_spec("cupy")
if spec is None:
raise ImportError("`cupy` is required for handling CUDA buffer.")
import cupy as cp # pylint: disable=import-error
arr.__cuda_array_interface__ = interface
out = cp.array(arr, copy=True)
else:
arr.__array_interface__ = interface
out = np.array(arr, copy=True)
return out
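
A small round trip through NumPy illustrates the copy semantics; the interface dict here has the same shape as the documents the C API emits as JSON (a sketch, not part of the package):

```python
import numpy as np

# Round-trip a NumPy array through its __array_interface__ dict.
src = np.arange(6, dtype=np.float32)
out = from_array_interface(src.__array_interface__)
np.testing.assert_array_equal(src, out)
assert out is not src  # the helper always copies the buffer
```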
def _prediction_output(
shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
) -> NumpyOrCupy:
@ -1060,6 +1109,32 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
)
return ret
def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get quantile cuts for quantization."""
n_features = self.num_col()
c_sindptr = ctypes.c_char_p()
c_sdata = ctypes.c_char_p()
config = make_jcargs()
_check_call(
_LIB.XGDMatrixGetQuantileCut(
self.handle, config, ctypes.byref(c_sindptr), ctypes.byref(c_sdata)
)
)
assert c_sindptr.value is not None
assert c_sdata.value is not None
i_indptr = json.loads(c_sindptr.value)
indptr = from_array_interface(i_indptr)
assert indptr.size == n_features + 1
assert indptr.dtype == np.uint64
i_data = json.loads(c_sdata.value)
data = from_array_interface(i_data)
assert data.size == indptr[-1]
assert data.dtype == np.float32
return indptr, data
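
A quick usage sketch, with random data and illustrative names only:

```python
import numpy as np
import xgboost as xgb

# Inspect the cuts of a QuantileDMatrix built from random data.
rng = np.random.default_rng(0)
X = rng.normal(size=(128, 4)).astype(np.float32)
y = X.sum(axis=1)
Xy = xgb.QuantileDMatrix(X, y, max_bin=8)
indptr, data = Xy.get_quantile_cut()
# Cuts for feature f live in data[indptr[f]:indptr[f + 1]].
print(data[indptr[0]:indptr[1]])
```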
def num_row(self) -> int:
"""Get the number of rows in the DMatrix."""
ret = c_bst_ulong()

View File

@ -265,6 +265,14 @@ def make_batches(
return X, y, w
def make_regression(
n_samples: int, n_features: int, use_cupy: bool
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
"""Make a simple regression dataset."""
X, y, w = make_batches(n_samples, n_features, 1, use_cupy)
return X[0], y[0], w[0]
def make_batches_sparse(
n_samples_per_batch: int, n_features: int, n_batches: int, sparsity: float
) -> Tuple[List[sparse.csr_matrix], List[np.ndarray], List[np.ndarray]]:

View File

@ -1,7 +1,7 @@
"""Tests for updaters."""
import json
from functools import partial, update_wrapper
from typing import Dict
from typing import Any, Dict
import numpy as np
@ -159,3 +159,100 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
for i in range(alpha.shape[0]):
np.testing.assert_allclose(predts[:, i], predt_multi[:, i])
def check_cut(
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
) -> None:
"""Check the cut values."""
from pandas.api.types import is_categorical_dtype
assert data.shape[0] == indptr[-1]
assert data.shape[0] == n_entries
assert indptr.dtype == np.uint64
for i in range(1, indptr.size):
beg = int(indptr[i - 1])
end = int(indptr[i])
for j in range(beg + 1, end):
assert data[j] > data[j - 1]
if is_categorical_dtype(dtypes[i - 1]):
assert data[j] == data[j - 1] + 1
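
A hypothetical call with two numeric features shows the layout the helper expects: cut values strictly increasing within each feature, delimited by a uint64 indptr.

```python
import numpy as np

# Two numeric features, three strictly increasing cut values each.
indptr = np.array([0, 3, 6], dtype=np.uint64)
data = np.array([0.1, 0.5, 0.9, -1.0, 0.0, 1.0], dtype=np.float32)
check_cut(6, indptr, data, [np.float32, np.float32])
```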
def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
"""Check with optional cupy."""
from pandas.api.types import is_categorical_dtype
n_samples = 1024
n_features = 14
max_bin = 16
dtypes = [np.float32] * n_features
# numerical
X, y, w = tm.make_regression(n_samples, n_features, use_cupy=use_cupy)
# - qdm
Xyw: xgb.DMatrix = xgb.QuantileDMatrix(X, y, weight=w, max_bin=max_bin)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
# - dm
Xyw = xgb.DMatrix(X, y, weight=w)
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
# - ext mem
n_batches = 3
n_samples_per_batch = 256
it = tm.IteratorForTest(
*tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
cache="cache",
)
Xy: xgb.DMatrix = xgb.DMatrix(it)
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
indptr, data = Xyw.get_quantile_cut()
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
# categorical
n_categories = 32
X, y = tm.make_categorical(n_samples, n_features, n_categories, False, sparsity=0.8)
if use_cupy:
import cudf # pylint: disable=import-error
import cupy as cp # pylint: disable=import-error
X = cudf.from_pandas(X)
y = cp.array(y)
# - qdm
Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin, enable_categorical=True)
indptr, data = Xy.get_quantile_cut()
check_cut(n_categories * n_features, indptr, data, X.dtypes)
# - dm
Xy = xgb.DMatrix(X, y, enable_categorical=True)
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
indptr, data = Xy.get_quantile_cut()
check_cut(n_categories * n_features, indptr, data, X.dtypes)
# mixed
X, y = tm.make_categorical(
n_samples, n_features, n_categories, False, sparsity=0.8, cat_ratio=0.5
)
n_cat_features = len([0 for dtype in X.dtypes if is_categorical_dtype(dtype)])
n_num_features = n_features - n_cat_features
n_entries = n_categories * n_cat_features + (max_bin + 1) * n_num_features
# - qdm
Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin, enable_categorical=True)
indptr, data = Xy.get_quantile_cut()
check_cut(n_entries, indptr, data, X.dtypes)
# - dm
Xy = xgb.DMatrix(X, y, enable_categorical=True)
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
indptr, data = Xy.get_quantile_cut()
check_cut(n_entries, indptr, data, X.dtypes)
def check_get_quantile_cut(tree_method: str) -> None:
"""Check the quantile cut getter."""
use_cupy = tree_method == "gpu_hist"
check_get_quantile_cut_device(tree_method, False)
if use_cupy:
check_get_quantile_cut_device(tree_method, True)

View File

@ -3,7 +3,7 @@
*/
#include "xgboost/c_api.h"
#include <algorithm> // for copy
#include <algorithm> // for copy, transform
#include <cinttypes> // for strtoimax
#include <cmath> // for nan
#include <cstring> // for strcmp
@ -20,9 +20,11 @@
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
#include "../data/ellpack_page.h" // for EllpackPage
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
@ -785,6 +787,104 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
API_END();
}
namespace {
template <typename Page>
void GetCutImpl(Context const *ctx, std::shared_ptr<DMatrix> p_m,
std::vector<std::uint64_t> *p_indptr, std::vector<float> *p_data) {
auto &indptr = *p_indptr;
auto &data = *p_data;
for (auto const &page : p_m->GetBatches<Page>(ctx, {})) {
auto const &cut = page.Cuts();
auto const &ptrs = cut.Ptrs();
indptr.resize(ptrs.size());
auto const &vals = cut.Values();
auto const &mins = cut.MinValues();
bst_feature_t n_features = p_m->Info().num_col_;
auto ft = p_m->Info().feature_types.ConstHostSpan();
std::size_t n_categories = std::count_if(ft.cbegin(), ft.cend(),
[](auto t) { return t == FeatureType::kCategorical; });
data.resize(vals.size() + n_features - n_categories); // |vals| + |mins|
std::size_t i{0}, n_numeric{0};
for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) {
CHECK_LT(i, data.size());
bool is_numeric = !common::IsCat(ft, fidx);
if (is_numeric) {
data[i] = mins[fidx];
i++;
}
auto beg = ptrs[fidx];
auto end = ptrs[fidx + 1];
CHECK_LE(end, data.size());
std::copy(vals.cbegin() + beg, vals.cbegin() + end, data.begin() + i);
i += (end - beg);
// shift by min values.
indptr[fidx] = ptrs[fidx] + n_numeric;
if (is_numeric) {
n_numeric++;
}
}
CHECK_EQ(n_numeric, n_features - n_categories);
indptr.back() = data.size();
CHECK_EQ(indptr.back(), vals.size() + mins.size() - n_categories);
break;
}
}
} // namespace
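
For clarity, the layout `GetCutImpl` produces can be mirrored in a few lines of Python (a sketch, not the shipped code): each numeric feature's minimum value is prepended to its cut values, categorical features keep the raw cuts, and the indptr is shifted by the number of preceding numeric features.

```python
# Sketch of the exported CSC-style layout, written in Python for clarity.
def export_cuts(ptrs, vals, mins, is_cat):
    data, indptr, n_numeric = [], [0] * len(ptrs), 0
    for f in range(len(ptrs) - 1):
        # Feature f starts at its cut offset plus one min value per
        # preceding numeric feature.
        indptr[f] = ptrs[f] + n_numeric
        if not is_cat[f]:
            data.append(mins[f])  # numeric features get their min first
            n_numeric += 1
        data.extend(vals[ptrs[f]:ptrs[f + 1]])
    indptr[-1] = len(data)
    return indptr, data
```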
XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
char const **out_indptr, char const **out_data) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(config);
xgboost_CHECK_C_ARG_PTR(out_indptr);
xgboost_CHECK_C_ARG_PTR(out_data);
auto jconfig = Json::Load(StringView{config});
if (!p_m->PageExists<GHistIndexMatrix>() && !p_m->PageExists<EllpackPage>()) {
LOG(FATAL) << "The quantile cut hasn't been generated yet. Unless this is a `QuantileDMatrix`, "
"quantile cut is generated during training.";
}
// Get return buffer
auto &data = p_m->GetThreadLocal().ret_vec_float;
auto &indptr = p_m->GetThreadLocal().ret_vec_u64;
if (p_m->PageExists<GHistIndexMatrix>()) {
auto ctx = p_m->Ctx()->IsCPU() ? *p_m->Ctx() : p_m->Ctx()->MakeCPU();
GetCutImpl<GHistIndexMatrix>(&ctx, p_m, &indptr, &data);
} else {
auto ctx = p_m->Ctx()->IsCUDA() ? *p_m->Ctx() : p_m->Ctx()->MakeCUDA(0);
GetCutImpl<EllpackPage>(&ctx, p_m, &indptr, &data);
}
// Create a CPU context
Context ctx;
// Get return buffer
auto &ret_vec_str = p_m->GetThreadLocal().ret_vec_str;
ret_vec_str.clear();
ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(
linalg::MakeTensorView(&ctx, common::Span{indptr.data(), indptr.size()}, indptr.size())));
ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(
linalg::MakeTensorView(&ctx, common::Span{data.data(), data.size()}, data.size())));
auto &charp_vecs = p_m->GetThreadLocal().ret_vec_charp;
charp_vecs.resize(ret_vec_str.size());
std::transform(ret_vec_str.cbegin(), ret_vec_str.cend(), charp_vecs.begin(),
[](auto const &str) { return str.c_str(); });
*out_indptr = charp_vecs[0];
*out_data = charp_vecs[1];
API_END();
}
// xgboost implementation
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
xgboost::bst_ulong len,

View File

@ -24,6 +24,8 @@ struct XGBAPIThreadLocalEntry {
std::vector<const char *> ret_vec_charp;
/*! \brief returning float vector. */
std::vector<float> ret_vec_float;
/*! \brief returning uint vector. */
std::vector<std::uint64_t> ret_vec_u64;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */

View File

@ -455,7 +455,7 @@ class ArrayInterface {
explicit ArrayInterface(std::string const &str) : ArrayInterface{StringView{str}} {}
explicit ArrayInterface(StringView str) : ArrayInterface<D>{Json::Load(str)} {}
explicit ArrayInterface(StringView str) : ArrayInterface{Json::Load(str)} {}
void AssignType(StringView typestr) {
using T = ArrayInterfaceHandler::Type;

View File

@ -3,12 +3,20 @@
*/
#ifndef XGBOOST_USE_CUDA
#include "ellpack_page.h"
#include <xgboost/data.h>
// dummy implementation of EllpackPage in case CUDA is not used
namespace xgboost {
class EllpackPageImpl {};
class EllpackPageImpl {
common::HistogramCuts cuts_;
public:
[[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
};
EllpackPage::EllpackPage() = default;
@ -32,6 +40,17 @@ size_t EllpackPage::Size() const {
return 0;
}
[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
"EllpackPage is required";
return impl_->Cuts();
}
[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
"EllpackPage is required";
return impl_->Cuts();
}
} // namespace xgboost
#endif // XGBOOST_USE_CUDA

View File

@ -4,6 +4,10 @@
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <algorithm> // for copy
#include <utility> // for move
#include <vector> // for vector
#include "../common/categorical.h"
#include "../common/cuda_context.cuh"
#include "../common/hist_util.cuh"
@ -11,6 +15,7 @@
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "./ellpack_page.cuh"
#include "device_adapter.cuh" // for HasInfInData
#include "ellpack_page.h"
#include "gradient_index.h"
#include "xgboost/data.h"
@ -29,6 +34,16 @@ size_t EllpackPage::Size() const { return impl_->Size(); }
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
CHECK(impl_);
return impl_->Cuts();
}
[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
CHECK(impl_);
return impl_->Cuts();
}
// Bin each input data entry, store the bin indices in compressed form.
__global__ void CompressBinEllpackKernel(
common::CompressedBufferWriter wr,

View File

@ -1,17 +1,18 @@
/*!
* Copyright 2019 by XGBoost Contributors
/**
* Copyright 2019-2023, XGBoost Contributors
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_H_
#ifndef XGBOOST_DATA_ELLPACK_PAGE_CUH_
#define XGBOOST_DATA_ELLPACK_PAGE_CUH_
#include <thrust/binary_search.h>
#include <xgboost/data.h>
#include "../common/categorical.h"
#include "../common/compressed_iterator.h"
#include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "../common/categorical.h"
#include <thrust/binary_search.h>
#include "ellpack_page.h"
namespace xgboost {
/** \brief Struct for accessing and manipulating an ELLPACK matrix on the
@ -194,8 +195,8 @@ class EllpackPageImpl {
base_rowid = row_id;
}
common::HistogramCuts& Cuts() { return cuts_; }
common::HistogramCuts const& Cuts() const { return cuts_; }
[[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
/*! \return Estimation of memory cost of this page. */
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
@ -256,4 +257,4 @@ inline size_t GetRowStride(DMatrix* dmat) {
}
} // namespace xgboost
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_
#endif // XGBOOST_DATA_ELLPACK_PAGE_CUH_

src/data/ellpack_page.h Normal file
View File

@ -0,0 +1,59 @@
/**
* Copyright 2017-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_H_
#include <memory> // for unique_ptr
#include "../common/hist_util.h" // for HistogramCuts
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for DMatrix, BatchParam
namespace xgboost {
class EllpackPageImpl;
/**
* @brief A page stored in ELLPACK format.
*
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
* including CUDA-specific implementation details in the header.
*/
class EllpackPage {
public:
/**
* @brief Default constructor.
*
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
* set later by the reader.
*/
EllpackPage();
/**
* @brief Constructor from an existing DMatrix.
*
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
* in CSR format.
*/
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
/** @brief Destructor. */
~EllpackPage();
EllpackPage(EllpackPage&& that);
/** @return Number of instances in the page. */
[[nodiscard]] size_t Size() const;
/** @brief Set the base row id for this page. */
void SetBaseRowId(std::size_t row_id);
[[nodiscard]] const EllpackPageImpl* Impl() const { return impl_.get(); }
EllpackPageImpl* Impl() { return impl_.get(); }
[[nodiscard]] common::HistogramCuts& Cuts();
[[nodiscard]] common::HistogramCuts const& Cuts() const;
private:
std::unique_ptr<EllpackPageImpl> impl_;
};
} // namespace xgboost
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_

View File

@ -5,10 +5,10 @@
#include <utility>
#include "ellpack_page.cuh"
#include "ellpack_page.h" // for EllpackPage
#include "ellpack_page_source.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(device_));
if (!this->ReadCache()) {
@ -27,5 +27,4 @@ void EllpackPageSource::Fetch() {
this->WriteCache();
}
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data

View File

@ -6,17 +6,17 @@
#define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
#include <xgboost/data.h>
#include <memory>
#include <string>
#include <utility>
#include "../common/common.h"
#include "../common/hist_util.h"
#include "ellpack_page.h" // for EllpackPage
#include "sparse_page_source.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
bool is_dense_;
size_t row_stride_;
@ -53,7 +53,6 @@ inline void EllpackPageSource::Fetch() {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif // XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_

View File

@ -245,6 +245,9 @@ class GHistIndexMatrix {
std::vector<float> const& values, std::vector<float> const& mins,
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const;
[[nodiscard]] common::HistogramCuts& Cuts() { return cut; }
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cut; }
private:
std::unique_ptr<common::ColumnMatrix> columns_;
std::vector<size_t> hit_count_tloc_;

View File

@ -16,7 +16,8 @@
#include "../common/threading_utils.h"
#include "./simple_batch_iterator.h"
#include "adapter.h"
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "ellpack_page.h" // for EllpackPage
#include "gradient_index.h"
#include "xgboost/c_api.h"
#include "xgboost/data.h"

View File

@ -165,7 +165,10 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
const BatchParam &param) {
CHECK_GE(param.max_bin, 2);
if (param.Initialized()) {
CHECK_GE(param.max_bin, 2);
}
detail::CheckEmpty(batch_param_, param);
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
this->InitializeSparsePage(ctx);
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {

View File

@ -1,6 +1,8 @@
/**
* Copyright 2021-2023 by XGBoost contributors
*/
#include <memory>
#include "../common/hist_util.cuh"
#include "batch_utils.h" // for CheckEmpty, RegenGHist
#include "ellpack_page.cuh"
@ -11,7 +13,9 @@ namespace xgboost::data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
const BatchParam& param) {
CHECK(ctx->IsCUDA());
CHECK_GE(param.max_bin, 2);
if (param.Initialized()) {
CHECK_GE(param.max_bin, 2);
}
detail::CheckEmpty(batch_param_, param);
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
size_t row_stride = 0;
@ -21,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
cache_info_.erase(id);
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
std::unique_ptr<common::HistogramCuts> cuts;
cuts.reset(
new common::HistogramCuts{common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0)});
cuts = std::make_unique<common::HistogramCuts>(
common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0));
this->InitializeSparsePage(ctx); // reset after use.
row_stride = GetRowStride(this);

View File

@ -21,6 +21,7 @@
#include "../common/io.h"
#include "../common/timer.h"
#include "../data/ellpack_page.cuh"
#include "../data/ellpack_page.h"
#include "constraints.cuh"
#include "driver.h"
#include "gpu_hist/evaluate_splits.cuh"

View File

@ -8,6 +8,7 @@
#include <xgboost/learner.h>
#include <xgboost/version_config.h>
#include <array> // for array
#include <cstddef> // std::size_t
#include <limits> // std::numeric_limits
#include <string> // std::string
@ -15,6 +16,11 @@
#include "../../../src/c_api/c_api_error.h"
#include "../../../src/common/io.h"
#include "../../../src/data/adapter.h" // for ArrayAdapter
#include "../../../src/data/array_interface.h" // for ArrayInterface
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
#include "../../../src/data/sparse_page_dmatrix.h" // for SparsePageDMatrix
#include "../helpers.h"
TEST(CAPI, XGDMatrixCreateFromMatDT) {
@ -137,9 +143,9 @@ TEST(CAPI, ConfigIO) {
BoosterHandle handle = learner.get();
learner->UpdateOneIter(0, p_dmat);
char const* out[1];
std::array<char const* , 1> out;
bst_ulong len {0};
XGBoosterSaveJsonConfig(handle, &len, out);
XGBoosterSaveJsonConfig(handle, &len, out.data());
std::string config_str_0 { out[0] };
auto config_0 = Json::Load({config_str_0.c_str(), config_str_0.size()});
@ -147,7 +153,7 @@ TEST(CAPI, ConfigIO) {
bst_ulong len_1 {0};
std::string config_str_1 { out[0] };
XGBoosterSaveJsonConfig(handle, &len_1, out);
XGBoosterSaveJsonConfig(handle, &len_1, out.data());
auto config_1 = Json::Load({config_str_1.c_str(), config_str_1.size()});
ASSERT_EQ(config_0, config_1);
@ -266,9 +272,9 @@ TEST(CAPI, DMatrixSetFeatureName) {
ASSERT_EQ(std::to_string(i), c_out_features[i]);
}
char const* feat_types [] {"i", "q"};
std::array<char const *, 2> feat_types{"i", "q"};
static_assert(sizeof(feat_types) / sizeof(feat_types[0]) == kCols);
XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, kCols);
XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types.data(), kCols);
char const **c_out_types;
XGDMatrixGetStrFeatureInfo(handle, u8"feature_type", &out_len,
&c_out_types);
@ -405,4 +411,210 @@ TEST(CAPI, JArgs) {
ASSERT_THROW({ RequiredArg<String>(args, "null", __func__); }, dmlc::Error);
}
}
namespace {
void MakeLabelForTest(std::shared_ptr<DMatrix> Xy, DMatrixHandle cxy) {
auto n_samples = Xy->Info().num_row_;
std::vector<float> y(n_samples);
for (std::size_t i = 0; i < y.size(); ++i) {
y[i] = static_cast<float>(i);
}
Xy->Info().labels.Reshape(n_samples);
Xy->Info().labels.Data()->HostVector() = y;
auto y_int = GetArrayInterface(Xy->Info().labels.Data(), n_samples, 1);
std::string s_y_int;
Json::Dump(y_int, &s_y_int);
XGDMatrixSetInfoFromInterface(cxy, "label", s_y_int.c_str());
}
auto MakeSimpleDMatrixForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) {
HostDeviceVector<float> storage;
auto arr_int = RandomDataGenerator{n_samples, n_features, 0.5f}.GenerateArrayInterface(&storage);
data::ArrayAdapter adapter{StringView{arr_int}};
std::shared_ptr<DMatrix> Xy{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads())};
DMatrixHandle p_fmat;
std::string s_dconfig;
Json::Dump(dconfig, &s_dconfig);
CHECK_EQ(XGDMatrixCreateFromDense(arr_int.c_str(), s_dconfig.c_str(), &p_fmat), 0);
MakeLabelForTest(Xy, p_fmat);
return std::pair{p_fmat, Xy};
}
auto MakeQDMForTest(Context const *ctx, bst_row_t n_samples, bst_feature_t n_features,
Json dconfig) {
bst_bin_t n_bins{16};
dconfig["max_bin"] = Integer{n_bins};
std::size_t n_batches{4};
std::unique_ptr<ArrayIterForTest> iter_0;
if (ctx->IsCUDA()) {
iter_0 = std::make_unique<CudaArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
} else {
iter_0 = std::make_unique<NumpyArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
}
std::string s_dconfig;
Json::Dump(dconfig, &s_dconfig);
DMatrixHandle p_fmat;
CHECK_EQ(XGQuantileDMatrixCreateFromCallback(static_cast<DataIterHandle>(iter_0.get()),
iter_0->Proxy(), nullptr, Reset, Next,
s_dconfig.c_str(), &p_fmat),
0);
std::unique_ptr<ArrayIterForTest> iter_1;
if (ctx->IsCUDA()) {
iter_1 = std::make_unique<CudaArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
} else {
iter_1 = std::make_unique<NumpyArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
}
auto Xy =
std::make_shared<data::IterativeDMatrix>(iter_1.get(), iter_1->Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
return std::pair{p_fmat, Xy};
}
auto MakeExtMemForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) {
std::size_t n_batches{4};
NumpyArrayIterForTest iter_0{0.0f, n_samples, n_features, n_batches};
std::string s_dconfig;
dconfig["cache_prefix"] = String{"cache"};
Json::Dump(dconfig, &s_dconfig);
DMatrixHandle p_fmat;
CHECK_EQ(XGDMatrixCreateFromCallback(static_cast<DataIterHandle>(&iter_0), iter_0.Proxy(), Reset,
Next, s_dconfig.c_str(), &p_fmat),
0);
NumpyArrayIterForTest iter_1{0.0f, n_samples, n_features, n_batches};
auto Xy = std::make_shared<data::SparsePageDMatrix>(
&iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, "");
MakeLabelForTest(Xy, p_fmat);
return std::pair{p_fmat, Xy};
}
template <typename Page>
void CheckResult(Context const *ctx, bst_feature_t n_features, std::shared_ptr<DMatrix> Xy,
float const *out_data, std::uint64_t const *out_indptr) {
for (auto const &page : Xy->GetBatches<Page>(ctx, BatchParam{16, 0.2})) {
auto const &cut = page.Cuts();
auto const &ptrs = cut.Ptrs();
auto const &vals = cut.Values();
auto const &mins = cut.MinValues();
for (bst_feature_t f = 0; f < Xy->Info().num_col_; ++f) {
ASSERT_EQ(ptrs[f] + f, out_indptr[f]);
ASSERT_EQ(mins[f], out_data[out_indptr[f]]);
auto beg = out_indptr[f];
auto end = out_indptr[f + 1];
auto val_beg = ptrs[f];
for (std::uint64_t i = beg + 1, j = val_beg; i < end; ++i, ++j) {
ASSERT_EQ(vals[j], out_data[i]);
}
}
ASSERT_EQ(ptrs[n_features] + n_features, out_indptr[n_features]);
}
}
void TestXGDMatrixGetQuantileCut(Context const *ctx) {
bst_row_t n_samples{1024};
bst_feature_t n_features{16};
Json dconfig{Object{}};
dconfig["ntread"] = Integer{Context{}.Threads()};
dconfig["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
auto check_result = [n_features, &ctx](std::shared_ptr<DMatrix> Xy, StringView s_out_data,
StringView s_out_indptr) {
auto i_out_data = ArrayInterface<1, false>{s_out_data};
ASSERT_EQ(i_out_data.type, ArrayInterfaceHandler::kF4);
auto out_data = static_cast<float const *>(i_out_data.data);
ASSERT_TRUE(out_data);
auto i_out_indptr = ArrayInterface<1, false>{s_out_indptr};
ASSERT_EQ(i_out_indptr.type, ArrayInterfaceHandler::kU8);
auto out_indptr = static_cast<std::uint64_t const *>(i_out_indptr.data);
ASSERT_TRUE(out_indptr);
if (ctx->IsCPU()) {
CheckResult<GHistIndexMatrix>(ctx, n_features, Xy, out_data, out_indptr);
} else {
CheckResult<EllpackPage>(ctx, n_features, Xy, out_data, out_indptr);
}
};
Json config{Null{}};
std::string s_config;
Json::Dump(config, &s_config);
char const *out_indptr;
char const *out_data;
{
// SimpleDMatrix
auto [p_fmat, Xy] = MakeSimpleDMatrixForTest(n_samples, n_features, dconfig);
// Expected to fail: the quantile cuts have not been generated yet.
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), -1);
std::array<DMatrixHandle, 1> mats{p_fmat};
BoosterHandle booster;
ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
if (ctx->IsCUDA()) {
ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0);
}
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
check_result(Xy, out_data, out_indptr);
XGDMatrixFree(p_fmat);
XGBoosterFree(booster);
}
{
// IterativeDMatrix
auto [p_fmat, Xy] = MakeQDMForTest(ctx, n_samples, n_features, dconfig);
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
check_result(Xy, out_data, out_indptr);
XGDMatrixFree(p_fmat);
}
{
// SparsePageDMatrix
auto [p_fmat, Xy] = MakeExtMemForTest(n_samples, n_features, dconfig);
// Expected to fail: the quantile cuts have not been generated yet.
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), -1);
std::array<DMatrixHandle, 1> mats{p_fmat};
BoosterHandle booster;
ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
if (ctx->IsCUDA()) {
ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0);
}
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
XGDMatrixFree(p_fmat);
XGBoosterFree(booster);
}
}
} // namespace
TEST(CAPI, XGDMatrixGetQuantileCut) {
Context ctx;
TestXGDMatrixGetQuantileCut(&ctx);
}
#if defined(XGBOOST_USE_CUDA)
TEST(CAPI, GPUXGDMatrixGetQuantileCut) {
auto ctx = MakeCUDACtx(0);
TestXGDMatrixGetQuantileCut(&ctx);
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

View File

@ -8,6 +8,7 @@
#include "../../../src/common/categorical.h"
#include "../../../src/common/hist_util.h"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/ellpack_page.h"
#include "../../../src/tree/param.h" // TrainParam
#include "../helpers.h"
#include "../histogram_helpers.h"

View File

@ -5,6 +5,7 @@
#include "../../../src/data/device_adapter.cuh"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/ellpack_page.h"
#include "../../../src/data/iterative_dmatrix.h"
#include "../../../src/tree/param.h" // TrainParam
#include "../helpers.h"

View File

@ -5,6 +5,7 @@
#include "../../../src/common/compressed_iterator.h"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/ellpack_page.h"
#include "../../../src/data/sparse_page_dmatrix.h"
#include "../../../src/tree/param.h" // TrainParam
#include "../filesystem.h" // dmlc::TemporaryDirectory

View File

@ -472,6 +472,18 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_
return m;
}
#if !defined(XGBOOST_USE_CUDA)
CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
: ArrayIterForTest{sparsity, rows, cols, batches} {
common::AssertGPUSupport();
}
int CudaArrayIterForTest::Next() {
common::AssertGPUSupport();
return 0;
}
#endif // !defined(XGBOOST_USE_CUDA)
NumpyArrayIterForTest::NumpyArrayIterForTest(float sparsity, size_t rows, size_t cols,
size_t batches)
: ArrayIterForTest{sparsity, rows, cols, batches} {
@ -650,7 +662,7 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
: rows_{rows}, cols_{cols}, n_batches_{batches} {
XGProxyDMatrixCreate(&proxy_);
rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity});
rng_ = std::make_unique<RandomDataGenerator>(rows_, cols_, sparsity);
std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
}

View File

@ -11,6 +11,8 @@
#include <vector>
#include "../../../src/common/common.h"
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
#include "../../../src/data/ellpack_page.h" // for EllpackPage
#include "../../../src/data/sparse_page_source.h"
#include "../../../src/tree/constraints.cuh"
#include "../../../src/tree/param.h" // for TrainParam

View File

@ -1,3 +1,4 @@
import json
import sys
import numpy as np
@ -10,6 +11,16 @@ from test_dmatrix import set_base_margin_info
from xgboost import testing as tm
cupy = pytest.importorskip("cupy")
def test_array_interface() -> None:
arr = cupy.array([[1, 2, 3, 4], [1, 2, 3, 4]])
i_arr = arr.__cuda_array_interface__
i_arr = json.loads(json.dumps(i_arr))
ret = xgb.core.from_array_interface(i_arr)
np.testing.assert_equal(cupy.asnumpy(arr), cupy.asnumpy(ret))
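
The CPU path can be exercised the same way; a hypothetical NumPy analog of the round-trip above:

```python
def test_array_interface_numpy() -> None:
    # Hypothetical CPU counterpart of the cupy round-trip.
    arr = np.array([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])
    i_arr = json.loads(json.dumps(arr.__array_interface__))
    ret = xgb.core.from_array_interface(i_arr)
    np.testing.assert_equal(arr, ret)
```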
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
'''Test constructing DMatrix from cupy'''

View File

@ -8,7 +8,11 @@ from hypothesis import assume, given, note, settings, strategies
import xgboost as xgb
from xgboost import testing as tm
from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
from xgboost.testing.updater import (
check_get_quantile_cut,
check_init_estimation,
check_quantile_loss,
)
sys.path.append("tests/python")
import test_updaters as test_up
@ -264,3 +268,7 @@ class TestGPUUpdaters:
},
num_boost_round=150,
)
@pytest.mark.skipif(**tm.no_cudf())
def test_get_quantile_cut(self) -> None:
check_get_quantile_cut("gpu_hist")

View File

@ -14,7 +14,11 @@ from xgboost.testing.params import (
hist_multi_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
from xgboost.testing.updater import (
check_get_quantile_cut,
check_init_estimation,
check_quantile_loss,
)
def train_result(param, dmat, num_rounds):
@ -537,3 +541,8 @@ class TestTreeMethod:
@pytest.mark.parametrize("weighted", [True, False])
def test_quantile_loss(self, weighted: bool) -> None:
check_quantile_loss("hist", weighted)
@pytest.mark.skipif(**tm.no_pandas())
@pytest.mark.parametrize("tree_method", ["hist"])
def test_get_quantile_cut(self, tree_method: str) -> None:
check_get_quantile_cut(tree_method)