Support exporting cut values (#9356)
This commit is contained in:
parent
c3124813e8
commit
20c52f07d2
@ -810,7 +810,7 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle,
|
||||
*/
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
|
||||
|
||||
/*!
|
||||
/**
|
||||
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
|
||||
* quantized DMatrix, quantized values are returned instead.
|
||||
*
|
||||
@ -819,8 +819,10 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
|
||||
* XGBoost. This is to avoid allocating a huge memory buffer that can not be freed until
|
||||
* exiting the thread.
|
||||
*
|
||||
* @since 1.7.0
|
||||
*
|
||||
* \param handle the handle to the DMatrix
|
||||
* \param config Json configuration string. At the moment it should be an empty document,
|
||||
* \param config JSON configuration string. At the moment it should be an empty document,
|
||||
* preserved for future use.
|
||||
* \param out_indptr indptr of output CSR matrix.
|
||||
* \param out_indices Column index of output CSR matrix.
|
||||
@ -831,6 +833,24 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
bst_ulong *out_indptr, unsigned *out_indices, float *out_data);
|
||||
|
||||
/**
|
||||
* @brief Export the quantile cuts used for training histogram-based models like `hist` and
|
||||
* `approx`. Useful for model compression.
|
||||
*
|
||||
* @since 2.0.0
|
||||
*
|
||||
* @param handle the handle to the DMatrix
|
||||
* @param config JSON configuration string. At the moment it should be an empty document,
|
||||
* preserved for future use.
|
||||
*
|
||||
* @param out_indptr indptr of output CSC matrix represented by a JSON encoded
|
||||
* __(cuda_)array_interface__.
|
||||
* @param out_data Data value of CSC matrix represented by a JSON encoded
|
||||
* __(cuda_)array_interface__.
|
||||
*/
|
||||
XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
|
||||
char const **out_indptr, char const **out_data);
|
||||
|
||||
/** @} */ // End of DMatrix
|
||||
|
||||
/**
|
||||
|
||||
@ -282,7 +282,7 @@ struct BatchParam {
|
||||
BatchParam(bst_bin_t max_bin, common::Span<float> hessian, bool regenerate)
|
||||
: max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
|
||||
|
||||
bool ParamNotEqual(BatchParam const& other) const {
|
||||
[[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {
|
||||
// Check non-floating parameters.
|
||||
bool cond = max_bin != other.max_bin;
|
||||
// Check sparse thresh.
|
||||
@ -293,11 +293,11 @@ struct BatchParam {
|
||||
|
||||
return cond;
|
||||
}
|
||||
bool Initialized() const { return max_bin != 0; }
|
||||
[[nodiscard]] bool Initialized() const { return max_bin != 0; }
|
||||
/**
|
||||
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
|
||||
*/
|
||||
BatchParam MakeCache() const {
|
||||
[[nodiscard]] BatchParam MakeCache() const {
|
||||
auto p = *this;
|
||||
// These parameters have nothing to do with how the gradient index was generated in the
|
||||
// first place.
|
||||
@ -319,7 +319,7 @@ struct HostSparsePageView {
|
||||
static_cast<Inst::index_type>(size)};
|
||||
}
|
||||
|
||||
size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
|
||||
[[nodiscard]] size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
|
||||
};
|
||||
|
||||
/*!
|
||||
@ -337,7 +337,7 @@ class SparsePage {
|
||||
/*! \brief an instance of sparse vector in the batch */
|
||||
using Inst = common::Span<Entry const>;
|
||||
|
||||
HostSparsePageView GetView() const {
|
||||
[[nodiscard]] HostSparsePageView GetView() const {
|
||||
return {offset.ConstHostSpan(), data.ConstHostSpan()};
|
||||
}
|
||||
|
||||
@ -353,12 +353,12 @@ class SparsePage {
|
||||
virtual ~SparsePage() = default;
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
inline size_t Size() const {
|
||||
[[nodiscard]] size_t Size() const {
|
||||
return offset.Size() == 0 ? 0 : offset.Size() - 1;
|
||||
}
|
||||
|
||||
/*! \return estimation of memory cost of this page */
|
||||
inline size_t MemCostBytes() const {
|
||||
[[nodiscard]] size_t MemCostBytes() const {
|
||||
return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
|
||||
}
|
||||
|
||||
@ -376,7 +376,7 @@ class SparsePage {
|
||||
base_rowid = row_id;
|
||||
}
|
||||
|
||||
SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
|
||||
[[nodiscard]] SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
|
||||
|
||||
/**
|
||||
* \brief Sort the column index.
|
||||
@ -385,7 +385,7 @@ class SparsePage {
|
||||
/**
|
||||
* \brief Check wether the column index is sorted.
|
||||
*/
|
||||
bool IsIndicesSorted(int32_t n_threads) const;
|
||||
[[nodiscard]] bool IsIndicesSorted(int32_t n_threads) const;
|
||||
/**
|
||||
* \brief Reindex the column index with an offset.
|
||||
*/
|
||||
@ -440,49 +440,7 @@ class SortedCSCPage : public SparsePage {
|
||||
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
|
||||
};
|
||||
|
||||
class EllpackPageImpl;
|
||||
/*!
|
||||
* \brief A page stored in ELLPACK format.
|
||||
*
|
||||
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
|
||||
* including CUDA-specific implementation details in the header.
|
||||
*/
|
||||
class EllpackPage {
|
||||
public:
|
||||
/*!
|
||||
* \brief Default constructor.
|
||||
*
|
||||
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
|
||||
* set later by the reader.
|
||||
*/
|
||||
EllpackPage();
|
||||
|
||||
/*!
|
||||
* \brief Constructor from an existing DMatrix.
|
||||
*
|
||||
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
||||
* in CSR format.
|
||||
*/
|
||||
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
|
||||
|
||||
/*! \brief Destructor. */
|
||||
~EllpackPage();
|
||||
|
||||
EllpackPage(EllpackPage&& that);
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
size_t Size() const;
|
||||
|
||||
/*! \brief Set the base row id for this page. */
|
||||
void SetBaseRowId(std::size_t row_id);
|
||||
|
||||
const EllpackPageImpl* Impl() const { return impl_.get(); }
|
||||
EllpackPageImpl* Impl() { return impl_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<EllpackPageImpl> impl_;
|
||||
};
|
||||
|
||||
class EllpackPage;
|
||||
class GHistIndexMatrix;
|
||||
|
||||
template<typename T>
|
||||
@ -492,7 +450,7 @@ class BatchIteratorImpl {
|
||||
virtual ~BatchIteratorImpl() = default;
|
||||
virtual const T& operator*() const = 0;
|
||||
virtual BatchIteratorImpl& operator++() = 0;
|
||||
virtual bool AtEnd() const = 0;
|
||||
[[nodiscard]] virtual bool AtEnd() const = 0;
|
||||
virtual std::shared_ptr<T const> Page() const = 0;
|
||||
};
|
||||
|
||||
@ -519,12 +477,12 @@ class BatchIterator {
|
||||
return !impl_->AtEnd();
|
||||
}
|
||||
|
||||
bool AtEnd() const {
|
||||
[[nodiscard]] bool AtEnd() const {
|
||||
CHECK(impl_ != nullptr);
|
||||
return impl_->AtEnd();
|
||||
}
|
||||
|
||||
std::shared_ptr<T const> Page() const {
|
||||
[[nodiscard]] std::shared_ptr<T const> Page() const {
|
||||
return impl_->Page();
|
||||
}
|
||||
|
||||
@ -563,15 +521,15 @@ class DMatrix {
|
||||
this->Info().SetInfo(ctx, key, StringView{interface_str});
|
||||
}
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual const MetaInfo& Info() const = 0;
|
||||
[[nodiscard]] virtual const MetaInfo& Info() const = 0;
|
||||
|
||||
/*! \brief Get thread local memory for returning data from DMatrix. */
|
||||
XGBAPIThreadLocalEntry& GetThreadLocal() const;
|
||||
[[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
|
||||
/**
|
||||
* \brief Get the context object of this DMatrix. The context is created during construction of
|
||||
* DMatrix with user specified `nthread` parameter.
|
||||
*/
|
||||
virtual Context const* Ctx() const = 0;
|
||||
[[nodiscard]] virtual Context const* Ctx() const = 0;
|
||||
|
||||
/**
|
||||
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
|
||||
@ -583,16 +541,16 @@ class DMatrix {
|
||||
template <typename T>
|
||||
BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
|
||||
template <typename T>
|
||||
bool PageExists() const;
|
||||
[[nodiscard]] bool PageExists() const;
|
||||
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return Whether the data columns single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
[[nodiscard]] virtual bool SingleColBlock() const = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~DMatrix();
|
||||
|
||||
/*! \brief Whether the matrix is dense. */
|
||||
bool IsDense() const {
|
||||
[[nodiscard]] bool IsDense() const {
|
||||
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
|
||||
}
|
||||
|
||||
@ -695,9 +653,9 @@ class DMatrix {
|
||||
BatchParam const& param) = 0;
|
||||
virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
|
||||
|
||||
virtual bool EllpackExists() const = 0;
|
||||
virtual bool GHistIndexExists() const = 0;
|
||||
virtual bool SparsePageExists() const = 0;
|
||||
[[nodiscard]] virtual bool EllpackExists() const = 0;
|
||||
[[nodiscard]] virtual bool GHistIndexExists() const = 0;
|
||||
[[nodiscard]] virtual bool SparsePageExists() const = 0;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
"""Core XGBoost Library."""
|
||||
import copy
|
||||
import ctypes
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@ -381,6 +382,54 @@ def c_array(
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def from_array_interface(interface: dict) -> NumpyOrCupy:
|
||||
"""Convert array interface to numpy or cupy array"""
|
||||
|
||||
class Array: # pylint: disable=too-few-public-methods
|
||||
"""Wrapper type for communicating with numpy and cupy."""
|
||||
|
||||
_interface: Optional[dict] = None
|
||||
|
||||
@property
|
||||
def __array_interface__(self) -> Optional[dict]:
|
||||
return self._interface
|
||||
|
||||
@__array_interface__.setter
|
||||
def __array_interface__(self, interface: dict) -> None:
|
||||
self._interface = copy.copy(interface)
|
||||
# converts some fields to tuple as required by numpy
|
||||
self._interface["shape"] = tuple(self._interface["shape"])
|
||||
self._interface["data"] = tuple(self._interface["data"])
|
||||
if self._interface.get("strides", None) is not None:
|
||||
self._interface["strides"] = tuple(self._interface["strides"])
|
||||
|
||||
@property
|
||||
def __cuda_array_interface__(self) -> Optional[dict]:
|
||||
return self.__array_interface__
|
||||
|
||||
@__cuda_array_interface__.setter
|
||||
def __cuda_array_interface__(self, interface: dict) -> None:
|
||||
self.__array_interface__ = interface
|
||||
|
||||
arr = Array()
|
||||
|
||||
if "stream" in interface:
|
||||
# CUDA stream is presented, this is a __cuda_array_interface__.
|
||||
spec = importlib.util.find_spec("cupy")
|
||||
if spec is None:
|
||||
raise ImportError("`cupy` is required for handling CUDA buffer.")
|
||||
|
||||
import cupy as cp # pylint: disable=import-error
|
||||
|
||||
arr.__cuda_array_interface__ = interface
|
||||
out = cp.array(arr, copy=True)
|
||||
else:
|
||||
arr.__array_interface__ = interface
|
||||
out = np.array(arr, copy=True)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _prediction_output(
|
||||
shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
|
||||
) -> NumpyOrCupy:
|
||||
@ -1060,6 +1109,32 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
)
|
||||
return ret
|
||||
|
||||
def get_quantile_cut(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get quantile cuts for quantization."""
|
||||
n_features = self.num_col()
|
||||
|
||||
c_sindptr = ctypes.c_char_p()
|
||||
c_sdata = ctypes.c_char_p()
|
||||
config = make_jcargs()
|
||||
_check_call(
|
||||
_LIB.XGDMatrixGetQuantileCut(
|
||||
self.handle, config, ctypes.byref(c_sindptr), ctypes.byref(c_sdata)
|
||||
)
|
||||
)
|
||||
assert c_sindptr.value is not None
|
||||
assert c_sdata.value is not None
|
||||
|
||||
i_indptr = json.loads(c_sindptr.value)
|
||||
indptr = from_array_interface(i_indptr)
|
||||
assert indptr.size == n_features + 1
|
||||
assert indptr.dtype == np.uint64
|
||||
|
||||
i_data = json.loads(c_sdata.value)
|
||||
data = from_array_interface(i_data)
|
||||
assert data.size == indptr[-1]
|
||||
assert data.dtype == np.float32
|
||||
return indptr, data
|
||||
|
||||
def num_row(self) -> int:
|
||||
"""Get the number of rows in the DMatrix."""
|
||||
ret = c_bst_ulong()
|
||||
|
||||
@ -265,6 +265,14 @@ def make_batches(
|
||||
return X, y, w
|
||||
|
||||
|
||||
def make_regression(
|
||||
n_samples: int, n_features: int, use_cupy: bool
|
||||
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
|
||||
"""Make a simple regression dataset."""
|
||||
X, y, w = make_batches(n_samples, n_features, 1, use_cupy)
|
||||
return X[0], y[0], w[0]
|
||||
|
||||
|
||||
def make_batches_sparse(
|
||||
n_samples_per_batch: int, n_features: int, n_batches: int, sparsity: float
|
||||
) -> Tuple[List[sparse.csr_matrix], List[np.ndarray], List[np.ndarray]]:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Tests for updaters."""
|
||||
import json
|
||||
from functools import partial, update_wrapper
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -159,3 +159,100 @@ def check_quantile_loss(tree_method: str, weighted: bool) -> None:
|
||||
|
||||
for i in range(alpha.shape[0]):
|
||||
np.testing.assert_allclose(predts[:, i], predt_multi[:, i])
|
||||
|
||||
|
||||
def check_cut(
|
||||
n_entries: int, indptr: np.ndarray, data: np.ndarray, dtypes: Any
|
||||
) -> None:
|
||||
"""Check the cut values."""
|
||||
from pandas.api.types import is_categorical_dtype
|
||||
|
||||
assert data.shape[0] == indptr[-1]
|
||||
assert data.shape[0] == n_entries
|
||||
|
||||
assert indptr.dtype == np.uint64
|
||||
for i in range(1, indptr.size):
|
||||
beg = int(indptr[i - 1])
|
||||
end = int(indptr[i])
|
||||
for j in range(beg + 1, end):
|
||||
assert data[j] > data[j - 1]
|
||||
if is_categorical_dtype(dtypes[i - 1]):
|
||||
assert data[j] == data[j - 1] + 1
|
||||
|
||||
|
||||
def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
|
||||
"""Check with optional cupy."""
|
||||
from pandas.api.types import is_categorical_dtype
|
||||
|
||||
n_samples = 1024
|
||||
n_features = 14
|
||||
max_bin = 16
|
||||
dtypes = [np.float32] * n_features
|
||||
|
||||
# numerical
|
||||
X, y, w = tm.make_regression(n_samples, n_features, use_cupy=use_cupy)
|
||||
# - qdm
|
||||
Xyw: xgb.DMatrix = xgb.QuantileDMatrix(X, y, weight=w, max_bin=max_bin)
|
||||
indptr, data = Xyw.get_quantile_cut()
|
||||
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
|
||||
# - dm
|
||||
Xyw = xgb.DMatrix(X, y, weight=w)
|
||||
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
|
||||
indptr, data = Xyw.get_quantile_cut()
|
||||
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
|
||||
# - ext mem
|
||||
n_batches = 3
|
||||
n_samples_per_batch = 256
|
||||
it = tm.IteratorForTest(
|
||||
*tm.make_batches(n_samples_per_batch, n_features, n_batches, use_cupy),
|
||||
cache="cache",
|
||||
)
|
||||
Xy: xgb.DMatrix = xgb.DMatrix(it)
|
||||
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xyw)
|
||||
indptr, data = Xyw.get_quantile_cut()
|
||||
check_cut((max_bin + 1) * n_features, indptr, data, dtypes)
|
||||
|
||||
# categorical
|
||||
n_categories = 32
|
||||
X, y = tm.make_categorical(n_samples, n_features, n_categories, False, sparsity=0.8)
|
||||
if use_cupy:
|
||||
import cudf # pylint: disable=import-error
|
||||
import cupy as cp # pylint: disable=import-error
|
||||
|
||||
X = cudf.from_pandas(X)
|
||||
y = cp.array(y)
|
||||
# - qdm
|
||||
Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin, enable_categorical=True)
|
||||
indptr, data = Xy.get_quantile_cut()
|
||||
check_cut(n_categories * n_features, indptr, data, X.dtypes)
|
||||
# - dm
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
|
||||
indptr, data = Xy.get_quantile_cut()
|
||||
check_cut(n_categories * n_features, indptr, data, X.dtypes)
|
||||
|
||||
# mixed
|
||||
X, y = tm.make_categorical(
|
||||
n_samples, n_features, n_categories, False, sparsity=0.8, cat_ratio=0.5
|
||||
)
|
||||
n_cat_features = len([0 for dtype in X.dtypes if is_categorical_dtype(dtype)])
|
||||
n_num_features = n_features - n_cat_features
|
||||
n_entries = n_categories * n_cat_features + (max_bin + 1) * n_num_features
|
||||
# - qdm
|
||||
Xy = xgb.QuantileDMatrix(X, y, max_bin=max_bin, enable_categorical=True)
|
||||
indptr, data = Xy.get_quantile_cut()
|
||||
check_cut(n_entries, indptr, data, X.dtypes)
|
||||
# - dm
|
||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||
xgb.train({"tree_method": tree_method, "max_bin": max_bin}, Xy)
|
||||
indptr, data = Xy.get_quantile_cut()
|
||||
check_cut(n_entries, indptr, data, X.dtypes)
|
||||
|
||||
|
||||
def check_get_quantile_cut(tree_method: str) -> None:
|
||||
"""Check the quantile cut getter."""
|
||||
|
||||
use_cupy = tree_method == "gpu_hist"
|
||||
check_get_quantile_cut_device(tree_method, False)
|
||||
if use_cupy:
|
||||
check_get_quantile_cut_device(tree_method, True)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
*/
|
||||
#include "xgboost/c_api.h"
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <algorithm> // for copy, transform
|
||||
#include <cinttypes> // for strtoimax
|
||||
#include <cmath> // for nan
|
||||
#include <cstring> // for strcmp
|
||||
@ -20,9 +20,11 @@
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor...
|
||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch...
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf...
|
||||
#include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor
|
||||
#include "../data/adapter.h" // for ArrayAdapter, DenseAdapter, RecordBatchesIte...
|
||||
#include "../data/ellpack_page.h" // for EllpackPage
|
||||
#include "../data/proxy_dmatrix.h" // for DMatrixProxy
|
||||
#include "../data/simple_dmatrix.h" // for SimpleDMatrix
|
||||
#include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN
|
||||
@ -785,6 +787,104 @@ XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config
|
||||
API_END();
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename Page>
|
||||
void GetCutImpl(Context const *ctx, std::shared_ptr<DMatrix> p_m,
|
||||
std::vector<std::uint64_t> *p_indptr, std::vector<float> *p_data) {
|
||||
auto &indptr = *p_indptr;
|
||||
auto &data = *p_data;
|
||||
for (auto const &page : p_m->GetBatches<Page>(ctx, {})) {
|
||||
auto const &cut = page.Cuts();
|
||||
|
||||
auto const &ptrs = cut.Ptrs();
|
||||
indptr.resize(ptrs.size());
|
||||
|
||||
auto const &vals = cut.Values();
|
||||
auto const &mins = cut.MinValues();
|
||||
|
||||
bst_feature_t n_features = p_m->Info().num_col_;
|
||||
auto ft = p_m->Info().feature_types.ConstHostSpan();
|
||||
std::size_t n_categories = std::count_if(ft.cbegin(), ft.cend(),
|
||||
[](auto t) { return t == FeatureType::kCategorical; });
|
||||
data.resize(vals.size() + n_features - n_categories); // |vals| + |mins|
|
||||
std::size_t i{0}, n_numeric{0};
|
||||
for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) {
|
||||
CHECK_LT(i, data.size());
|
||||
bool is_numeric = !common::IsCat(ft, fidx);
|
||||
if (is_numeric) {
|
||||
data[i] = mins[fidx];
|
||||
i++;
|
||||
}
|
||||
auto beg = ptrs[fidx];
|
||||
auto end = ptrs[fidx + 1];
|
||||
CHECK_LE(end, data.size());
|
||||
std::copy(vals.cbegin() + beg, vals.cbegin() + end, data.begin() + i);
|
||||
i += (end - beg);
|
||||
// shift by min values.
|
||||
indptr[fidx] = ptrs[fidx] + n_numeric;
|
||||
if (is_numeric) {
|
||||
n_numeric++;
|
||||
}
|
||||
}
|
||||
CHECK_EQ(n_numeric, n_features - n_categories);
|
||||
|
||||
indptr.back() = data.size();
|
||||
CHECK_EQ(indptr.back(), vals.size() + mins.size() - n_categories);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XGB_DLL int XGDMatrixGetQuantileCut(DMatrixHandle const handle, char const *config,
|
||||
char const **out_indptr, char const **out_data) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
xgboost_CHECK_C_ARG_PTR(out_indptr);
|
||||
xgboost_CHECK_C_ARG_PTR(out_data);
|
||||
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
|
||||
if (!p_m->PageExists<GHistIndexMatrix>() && !p_m->PageExists<EllpackPage>()) {
|
||||
LOG(FATAL) << "The quantile cut hasn't been generated yet. Unless this is a `QuantileDMatrix`, "
|
||||
"quantile cut is generated during training.";
|
||||
}
|
||||
// Get return buffer
|
||||
auto &data = p_m->GetThreadLocal().ret_vec_float;
|
||||
auto &indptr = p_m->GetThreadLocal().ret_vec_u64;
|
||||
|
||||
if (p_m->PageExists<GHistIndexMatrix>()) {
|
||||
auto ctx = p_m->Ctx()->IsCPU() ? *p_m->Ctx() : p_m->Ctx()->MakeCPU();
|
||||
GetCutImpl<GHistIndexMatrix>(&ctx, p_m, &indptr, &data);
|
||||
} else {
|
||||
auto ctx = p_m->Ctx()->IsCUDA() ? *p_m->Ctx() : p_m->Ctx()->MakeCUDA(0);
|
||||
GetCutImpl<EllpackPage>(&ctx, p_m, &indptr, &data);
|
||||
}
|
||||
|
||||
// Create a CPU context
|
||||
Context ctx;
|
||||
// Get return buffer
|
||||
auto &ret_vec_str = p_m->GetThreadLocal().ret_vec_str;
|
||||
ret_vec_str.clear();
|
||||
|
||||
ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(
|
||||
linalg::MakeTensorView(&ctx, common::Span{indptr.data(), indptr.size()}, indptr.size())));
|
||||
ret_vec_str.emplace_back(linalg::ArrayInterfaceStr(
|
||||
linalg::MakeTensorView(&ctx, common::Span{data.data(), data.size()}, data.size())));
|
||||
|
||||
auto &charp_vecs = p_m->GetThreadLocal().ret_vec_charp;
|
||||
charp_vecs.resize(ret_vec_str.size());
|
||||
std::transform(ret_vec_str.cbegin(), ret_vec_str.cend(), charp_vecs.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
|
||||
*out_indptr = charp_vecs[0];
|
||||
*out_data = charp_vecs[1];
|
||||
API_END();
|
||||
}
|
||||
|
||||
// xgboost implementation
|
||||
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
xgboost::bst_ulong len,
|
||||
|
||||
@ -24,6 +24,8 @@ struct XGBAPIThreadLocalEntry {
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<float> ret_vec_float;
|
||||
/*! \brief returning uint vector. */
|
||||
std::vector<std::uint64_t> ret_vec_u64;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<GradientPair> tmp_gpair;
|
||||
/*! \brief Temp variable for returning prediction result. */
|
||||
|
||||
@ -455,7 +455,7 @@ class ArrayInterface {
|
||||
|
||||
explicit ArrayInterface(std::string const &str) : ArrayInterface{StringView{str}} {}
|
||||
|
||||
explicit ArrayInterface(StringView str) : ArrayInterface<D>{Json::Load(str)} {}
|
||||
explicit ArrayInterface(StringView str) : ArrayInterface{Json::Load(str)} {}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
using T = ArrayInterfaceHandler::Type;
|
||||
|
||||
@ -3,12 +3,20 @@
|
||||
*/
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
#include "ellpack_page.h"
|
||||
|
||||
#include <xgboost/data.h>
|
||||
|
||||
// dummy implementation of EllpackPage in case CUDA is not used
|
||||
namespace xgboost {
|
||||
|
||||
class EllpackPageImpl {};
|
||||
class EllpackPageImpl {
|
||||
common::HistogramCuts cuts_;
|
||||
|
||||
public:
|
||||
[[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
|
||||
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
|
||||
};
|
||||
|
||||
EllpackPage::EllpackPage() = default;
|
||||
|
||||
@ -32,6 +40,17 @@ size_t EllpackPage::Size() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
return impl_->Cuts();
|
||||
}
|
||||
|
||||
[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
return impl_->Cuts();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
|
||||
@ -4,6 +4,10 @@
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
|
||||
#include <algorithm> // for copy
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/hist_util.cuh"
|
||||
@ -11,6 +15,7 @@
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "./ellpack_page.cuh"
|
||||
#include "device_adapter.cuh" // for HasInfInData
|
||||
#include "ellpack_page.h"
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@ -29,6 +34,16 @@ size_t EllpackPage::Size() const { return impl_->Size(); }
|
||||
|
||||
void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id); }
|
||||
|
||||
[[nodiscard]] common::HistogramCuts& EllpackPage::Cuts() {
|
||||
CHECK(impl_);
|
||||
return impl_->Cuts();
|
||||
}
|
||||
|
||||
[[nodiscard]] common::HistogramCuts const& EllpackPage::Cuts() const {
|
||||
CHECK(impl_);
|
||||
return impl_->Cuts();
|
||||
}
|
||||
|
||||
// Bin each input data entry, store the bin indices in compressed form.
|
||||
__global__ void CompressBinEllpackKernel(
|
||||
common::CompressedBufferWriter wr,
|
||||
|
||||
@ -1,17 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2019 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2019-2023, XGBoost Contributors
|
||||
*/
|
||||
|
||||
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
#define XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
#ifndef XGBOOST_DATA_ELLPACK_PAGE_CUH_
|
||||
#define XGBOOST_DATA_ELLPACK_PAGE_CUH_
|
||||
|
||||
#include <thrust/binary_search.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/compressed_iterator.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/categorical.h"
|
||||
#include <thrust/binary_search.h>
|
||||
#include "ellpack_page.h"
|
||||
|
||||
namespace xgboost {
|
||||
/** \brief Struct for accessing and manipulating an ELLPACK matrix on the
|
||||
@ -194,8 +195,8 @@ class EllpackPageImpl {
|
||||
base_rowid = row_id;
|
||||
}
|
||||
|
||||
common::HistogramCuts& Cuts() { return cuts_; }
|
||||
common::HistogramCuts const& Cuts() const { return cuts_; }
|
||||
[[nodiscard]] common::HistogramCuts& Cuts() { return cuts_; }
|
||||
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cuts_; }
|
||||
|
||||
/*! \return Estimation of memory cost of this page. */
|
||||
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
|
||||
@ -256,4 +257,4 @@ inline size_t GetRowStride(DMatrix* dmat) {
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
#endif // XGBOOST_DATA_ELLPACK_PAGE_CUH_
|
||||
|
||||
59
src/data/ellpack_page.h
Normal file
59
src/data/ellpack_page.h
Normal file
@ -0,0 +1,59 @@
|
||||
/**
|
||||
* Copyright 2017-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
#define XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
|
||||
#include <memory> // for unique_ptr
|
||||
|
||||
#include "../common/hist_util.h" // for HistogramCuts
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for DMatrix, BatchParam
|
||||
|
||||
namespace xgboost {
|
||||
class EllpackPageImpl;
|
||||
/**
|
||||
* @brief A page stored in ELLPACK format.
|
||||
*
|
||||
* This class uses the PImpl idiom (https://en.cppreference.com/w/cpp/language/pimpl) to avoid
|
||||
* including CUDA-specific implementation details in the header.
|
||||
*/
|
||||
class EllpackPage {
|
||||
public:
|
||||
/**
|
||||
* @brief Default constructor.
|
||||
*
|
||||
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
|
||||
* set later by the reader.
|
||||
*/
|
||||
EllpackPage();
|
||||
/**
|
||||
* @brief Constructor from an existing DMatrix.
|
||||
*
|
||||
* This is used in the in-memory case. The ELLPACK page is constructed from an existing DMatrix
|
||||
* in CSR format.
|
||||
*/
|
||||
explicit EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param);
|
||||
|
||||
/*! \brief Destructor. */
|
||||
~EllpackPage();
|
||||
|
||||
EllpackPage(EllpackPage&& that);
|
||||
|
||||
/*! \return Number of instances in the page. */
|
||||
[[nodiscard]] size_t Size() const;
|
||||
|
||||
/*! \brief Set the base row id for this page. */
|
||||
void SetBaseRowId(std::size_t row_id);
|
||||
|
||||
[[nodiscard]] const EllpackPageImpl* Impl() const { return impl_.get(); }
|
||||
EllpackPageImpl* Impl() { return impl_.get(); }
|
||||
|
||||
[[nodiscard]] common::HistogramCuts& Cuts();
|
||||
[[nodiscard]] common::HistogramCuts const& Cuts() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<EllpackPageImpl> impl_;
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_ELLPACK_PAGE_H_
|
||||
@ -5,10 +5,10 @@
|
||||
#include <utility>
|
||||
|
||||
#include "ellpack_page.cuh"
|
||||
#include "ellpack_page.h" // for EllpackPage
|
||||
#include "ellpack_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
void EllpackPageSource::Fetch() {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
if (!this->ReadCache()) {
|
||||
@ -27,5 +27,4 @@ void EllpackPageSource::Fetch() {
|
||||
this->WriteCache();
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
|
||||
@ -6,17 +6,17 @@
|
||||
#define XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
|
||||
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "../common/common.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "ellpack_page.h" // for EllpackPage
|
||||
#include "sparse_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
namespace xgboost::data {
|
||||
class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
|
||||
bool is_dense_;
|
||||
size_t row_stride_;
|
||||
@ -53,7 +53,6 @@ inline void EllpackPageSource::Fetch() {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
|
||||
#endif // XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
|
||||
|
||||
@ -245,6 +245,9 @@ class GHistIndexMatrix {
|
||||
std::vector<float> const& values, std::vector<float> const& mins,
|
||||
bst_row_t ridx, bst_feature_t fidx, bool is_cat) const;
|
||||
|
||||
[[nodiscard]] common::HistogramCuts& Cuts() { return cut; }
|
||||
[[nodiscard]] common::HistogramCuts const& Cuts() const { return cut; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<common::ColumnMatrix> columns_;
|
||||
std::vector<size_t> hit_count_tloc_;
|
||||
|
||||
@ -16,7 +16,8 @@
|
||||
#include "../common/threading_utils.h"
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "adapter.h"
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
#include "ellpack_page.h" // for EllpackPage
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@ -165,7 +165,10 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
|
||||
|
||||
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ctx,
|
||||
const BatchParam ¶m) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
if (param.Initialized()) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
}
|
||||
detail::CheckEmpty(batch_param_, param);
|
||||
auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
|
||||
this->InitializeSparsePage(ctx);
|
||||
if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) {
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
/**
|
||||
* Copyright 2021-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <memory>
|
||||
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
#include "ellpack_page.cuh"
|
||||
@ -11,7 +13,9 @@ namespace xgboost::data {
|
||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
const BatchParam& param) {
|
||||
CHECK(ctx->IsCUDA());
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
if (param.Initialized()) {
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
}
|
||||
detail::CheckEmpty(batch_param_, param);
|
||||
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
size_t row_stride = 0;
|
||||
@ -21,8 +25,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
cache_info_.erase(id);
|
||||
MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
|
||||
std::unique_ptr<common::HistogramCuts> cuts;
|
||||
cuts.reset(
|
||||
new common::HistogramCuts{common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0)});
|
||||
cuts = std::make_unique<common::HistogramCuts>(
|
||||
common::DeviceSketch(ctx->gpu_id, this, param.max_bin, 0));
|
||||
this->InitializeSparsePage(ctx); // reset after use.
|
||||
|
||||
row_stride = GetRowStride(this);
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "../common/io.h"
|
||||
#include "../common/timer.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/ellpack_page.h"
|
||||
#include "constraints.cuh"
|
||||
#include "driver.h"
|
||||
#include "gpu_hist/evaluate_splits.cuh"
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/version_config.h>
|
||||
|
||||
#include <array> // for array
|
||||
#include <cstddef> // std::size_t
|
||||
#include <limits> // std::numeric_limits
|
||||
#include <string> // std::string
|
||||
@ -15,6 +16,11 @@
|
||||
|
||||
#include "../../../src/c_api/c_api_error.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../../../src/data/adapter.h" // for ArrayAdapter
|
||||
#include "../../../src/data/array_interface.h" // for ArrayInterface
|
||||
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix
|
||||
#include "../../../src/data/sparse_page_dmatrix.h" // for SparsePageDMatrix
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(CAPI, XGDMatrixCreateFromMatDT) {
|
||||
@ -137,9 +143,9 @@ TEST(CAPI, ConfigIO) {
|
||||
BoosterHandle handle = learner.get();
|
||||
learner->UpdateOneIter(0, p_dmat);
|
||||
|
||||
char const* out[1];
|
||||
std::array<char const* , 1> out;
|
||||
bst_ulong len {0};
|
||||
XGBoosterSaveJsonConfig(handle, &len, out);
|
||||
XGBoosterSaveJsonConfig(handle, &len, out.data());
|
||||
|
||||
std::string config_str_0 { out[0] };
|
||||
auto config_0 = Json::Load({config_str_0.c_str(), config_str_0.size()});
|
||||
@ -147,7 +153,7 @@ TEST(CAPI, ConfigIO) {
|
||||
|
||||
bst_ulong len_1 {0};
|
||||
std::string config_str_1 { out[0] };
|
||||
XGBoosterSaveJsonConfig(handle, &len_1, out);
|
||||
XGBoosterSaveJsonConfig(handle, &len_1, out.data());
|
||||
auto config_1 = Json::Load({config_str_1.c_str(), config_str_1.size()});
|
||||
|
||||
ASSERT_EQ(config_0, config_1);
|
||||
@ -266,9 +272,9 @@ TEST(CAPI, DMatrixSetFeatureName) {
|
||||
ASSERT_EQ(std::to_string(i), c_out_features[i]);
|
||||
}
|
||||
|
||||
char const* feat_types [] {"i", "q"};
|
||||
std::array<char const *, 2> feat_types{"i", "q"};
|
||||
static_assert(sizeof(feat_types) / sizeof(feat_types[0]) == kCols);
|
||||
XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types, kCols);
|
||||
XGDMatrixSetStrFeatureInfo(handle, "feature_type", feat_types.data(), kCols);
|
||||
char const **c_out_types;
|
||||
XGDMatrixGetStrFeatureInfo(handle, u8"feature_type", &out_len,
|
||||
&c_out_types);
|
||||
@ -405,4 +411,210 @@ TEST(CAPI, JArgs) {
|
||||
ASSERT_THROW({ RequiredArg<String>(args, "null", __func__); }, dmlc::Error);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void MakeLabelForTest(std::shared_ptr<DMatrix> Xy, DMatrixHandle cxy) {
|
||||
auto n_samples = Xy->Info().num_row_;
|
||||
std::vector<float> y(n_samples);
|
||||
for (std::size_t i = 0; i < y.size(); ++i) {
|
||||
y[i] = static_cast<float>(i);
|
||||
}
|
||||
|
||||
Xy->Info().labels.Reshape(n_samples);
|
||||
Xy->Info().labels.Data()->HostVector() = y;
|
||||
|
||||
auto y_int = GetArrayInterface(Xy->Info().labels.Data(), n_samples, 1);
|
||||
std::string s_y_int;
|
||||
Json::Dump(y_int, &s_y_int);
|
||||
|
||||
XGDMatrixSetInfoFromInterface(cxy, "label", s_y_int.c_str());
|
||||
}
|
||||
|
||||
auto MakeSimpleDMatrixForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) {
|
||||
HostDeviceVector<float> storage;
|
||||
auto arr_int = RandomDataGenerator{n_samples, n_features, 0.5f}.GenerateArrayInterface(&storage);
|
||||
|
||||
data::ArrayAdapter adapter{StringView{arr_int}};
|
||||
std::shared_ptr<DMatrix> Xy{
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads())};
|
||||
|
||||
DMatrixHandle p_fmat;
|
||||
std::string s_dconfig;
|
||||
Json::Dump(dconfig, &s_dconfig);
|
||||
CHECK_EQ(XGDMatrixCreateFromDense(arr_int.c_str(), s_dconfig.c_str(), &p_fmat), 0);
|
||||
|
||||
MakeLabelForTest(Xy, p_fmat);
|
||||
return std::pair{p_fmat, Xy};
|
||||
}
|
||||
|
||||
auto MakeQDMForTest(Context const *ctx, bst_row_t n_samples, bst_feature_t n_features,
|
||||
Json dconfig) {
|
||||
bst_bin_t n_bins{16};
|
||||
dconfig["max_bin"] = Integer{n_bins};
|
||||
|
||||
std::size_t n_batches{4};
|
||||
std::unique_ptr<ArrayIterForTest> iter_0;
|
||||
if (ctx->IsCUDA()) {
|
||||
iter_0 = std::make_unique<CudaArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
|
||||
} else {
|
||||
iter_0 = std::make_unique<NumpyArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
|
||||
}
|
||||
std::string s_dconfig;
|
||||
Json::Dump(dconfig, &s_dconfig);
|
||||
DMatrixHandle p_fmat;
|
||||
CHECK_EQ(XGQuantileDMatrixCreateFromCallback(static_cast<DataIterHandle>(iter_0.get()),
|
||||
iter_0->Proxy(), nullptr, Reset, Next,
|
||||
s_dconfig.c_str(), &p_fmat),
|
||||
0);
|
||||
|
||||
std::unique_ptr<ArrayIterForTest> iter_1;
|
||||
if (ctx->IsCUDA()) {
|
||||
iter_1 = std::make_unique<CudaArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
|
||||
} else {
|
||||
iter_1 = std::make_unique<NumpyArrayIterForTest>(0.0f, n_samples, n_features, n_batches);
|
||||
}
|
||||
auto Xy =
|
||||
std::make_shared<data::IterativeDMatrix>(iter_1.get(), iter_1->Proxy(), nullptr, Reset, Next,
|
||||
std::numeric_limits<float>::quiet_NaN(), 0, n_bins);
|
||||
return std::pair{p_fmat, Xy};
|
||||
}
|
||||
|
||||
auto MakeExtMemForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) {
|
||||
std::size_t n_batches{4};
|
||||
NumpyArrayIterForTest iter_0{0.0f, n_samples, n_features, n_batches};
|
||||
std::string s_dconfig;
|
||||
dconfig["cache_prefix"] = String{"cache"};
|
||||
Json::Dump(dconfig, &s_dconfig);
|
||||
DMatrixHandle p_fmat;
|
||||
CHECK_EQ(XGDMatrixCreateFromCallback(static_cast<DataIterHandle>(&iter_0), iter_0.Proxy(), Reset,
|
||||
Next, s_dconfig.c_str(), &p_fmat),
|
||||
0);
|
||||
|
||||
NumpyArrayIterForTest iter_1{0.0f, n_samples, n_features, n_batches};
|
||||
auto Xy = std::make_shared<data::SparsePageDMatrix>(
|
||||
&iter_1, iter_1.Proxy(), Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, "");
|
||||
MakeLabelForTest(Xy, p_fmat);
|
||||
return std::pair{p_fmat, Xy};
|
||||
}
|
||||
|
||||
template <typename Page>
|
||||
void CheckResult(Context const *ctx, bst_feature_t n_features, std::shared_ptr<DMatrix> Xy,
|
||||
float const *out_data, std::uint64_t const *out_indptr) {
|
||||
for (auto const &page : Xy->GetBatches<Page>(ctx, BatchParam{16, 0.2})) {
|
||||
auto const &cut = page.Cuts();
|
||||
auto const &ptrs = cut.Ptrs();
|
||||
auto const &vals = cut.Values();
|
||||
auto const &mins = cut.MinValues();
|
||||
for (bst_feature_t f = 0; f < Xy->Info().num_col_; ++f) {
|
||||
ASSERT_EQ(ptrs[f] + f, out_indptr[f]);
|
||||
ASSERT_EQ(mins[f], out_data[out_indptr[f]]);
|
||||
auto beg = out_indptr[f];
|
||||
auto end = out_indptr[f + 1];
|
||||
auto val_beg = ptrs[f];
|
||||
for (std::uint64_t i = beg + 1, j = val_beg; i < end; ++i, ++j) {
|
||||
ASSERT_EQ(vals[j], out_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(ptrs[n_features] + n_features, out_indptr[n_features]);
|
||||
}
|
||||
}
|
||||
|
||||
void TestXGDMatrixGetQuantileCut(Context const *ctx) {
|
||||
bst_row_t n_samples{1024};
|
||||
bst_feature_t n_features{16};
|
||||
|
||||
Json dconfig{Object{}};
|
||||
dconfig["ntread"] = Integer{Context{}.Threads()};
|
||||
dconfig["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
|
||||
|
||||
auto check_result = [n_features, &ctx](std::shared_ptr<DMatrix> Xy, StringView s_out_data,
|
||||
StringView s_out_indptr) {
|
||||
auto i_out_data = ArrayInterface<1, false>{s_out_data};
|
||||
ASSERT_EQ(i_out_data.type, ArrayInterfaceHandler::kF4);
|
||||
auto out_data = static_cast<float const *>(i_out_data.data);
|
||||
ASSERT_TRUE(out_data);
|
||||
|
||||
auto i_out_indptr = ArrayInterface<1, false>{s_out_indptr};
|
||||
ASSERT_EQ(i_out_indptr.type, ArrayInterfaceHandler::kU8);
|
||||
auto out_indptr = static_cast<std::uint64_t const *>(i_out_indptr.data);
|
||||
ASSERT_TRUE(out_data);
|
||||
|
||||
if (ctx->IsCPU()) {
|
||||
CheckResult<GHistIndexMatrix>(ctx, n_features, Xy, out_data, out_indptr);
|
||||
} else {
|
||||
CheckResult<EllpackPage>(ctx, n_features, Xy, out_data, out_indptr);
|
||||
}
|
||||
};
|
||||
|
||||
Json config{Null{}};
|
||||
std::string s_config;
|
||||
Json::Dump(config, &s_config);
|
||||
char const *out_indptr;
|
||||
char const *out_data;
|
||||
|
||||
{
|
||||
// SimpleDMatrix
|
||||
auto [p_fmat, Xy] = MakeSimpleDMatrixForTest(n_samples, n_features, dconfig);
|
||||
// assert fail, we don't have the quantile yet.
|
||||
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), -1);
|
||||
|
||||
std::array<DMatrixHandle, 1> mats{p_fmat};
|
||||
BoosterHandle booster;
|
||||
ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
|
||||
ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
|
||||
if (ctx->IsCUDA()) {
|
||||
ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0);
|
||||
}
|
||||
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
|
||||
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
|
||||
|
||||
check_result(Xy, out_data, out_indptr);
|
||||
|
||||
XGDMatrixFree(p_fmat);
|
||||
XGBoosterFree(booster);
|
||||
}
|
||||
|
||||
{
|
||||
// IterativeDMatrix
|
||||
auto [p_fmat, Xy] = MakeQDMForTest(ctx, n_samples, n_features, dconfig);
|
||||
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
|
||||
|
||||
check_result(Xy, out_data, out_indptr);
|
||||
XGDMatrixFree(p_fmat);
|
||||
}
|
||||
|
||||
{
|
||||
// SparsePageDMatrix
|
||||
auto [p_fmat, Xy] = MakeExtMemForTest(n_samples, n_features, dconfig);
|
||||
// assert fail, we don't have the quantile yet.
|
||||
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), -1);
|
||||
|
||||
std::array<DMatrixHandle, 1> mats{p_fmat};
|
||||
BoosterHandle booster;
|
||||
ASSERT_EQ(XGBoosterCreate(mats.data(), 1, &booster), 0);
|
||||
ASSERT_EQ(XGBoosterSetParam(booster, "max_bin", "16"), 0);
|
||||
if (ctx->IsCUDA()) {
|
||||
ASSERT_EQ(XGBoosterSetParam(booster, "tree_method", "gpu_hist"), 0);
|
||||
}
|
||||
ASSERT_EQ(XGBoosterUpdateOneIter(booster, 0, p_fmat), 0);
|
||||
ASSERT_EQ(XGDMatrixGetQuantileCut(p_fmat, s_config.c_str(), &out_indptr, &out_data), 0);
|
||||
|
||||
XGDMatrixFree(p_fmat);
|
||||
XGBoosterFree(booster);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(CAPI, XGDMatrixGetQuantileCut) {
|
||||
Context ctx;
|
||||
TestXGDMatrixGetQuantileCut(&ctx);
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST(CAPI, GPUXGDMatrixGetQuantileCut) {
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
TestXGDMatrixGetQuantileCut(&ctx);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include "../../../src/common/categorical.h"
|
||||
#include "../../../src/common/hist_util.h"
|
||||
#include "../../../src/data/ellpack_page.cuh"
|
||||
#include "../../../src/data/ellpack_page.h"
|
||||
#include "../../../src/tree/param.h" // TrainParam
|
||||
#include "../helpers.h"
|
||||
#include "../histogram_helpers.h"
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "../../../src/data/ellpack_page.cuh"
|
||||
#include "../../../src/data/ellpack_page.h"
|
||||
#include "../../../src/data/iterative_dmatrix.h"
|
||||
#include "../../../src/tree/param.h" // TrainParam
|
||||
#include "../helpers.h"
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
|
||||
#include "../../../src/common/compressed_iterator.h"
|
||||
#include "../../../src/data/ellpack_page.cuh"
|
||||
#include "../../../src/data/ellpack_page.h"
|
||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||
#include "../../../src/tree/param.h" // TrainParam
|
||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||
|
||||
@ -472,6 +472,18 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_
|
||||
return m;
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
|
||||
: ArrayIterForTest{sparsity, rows, cols, batches} {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
|
||||
int CudaArrayIterForTest::Next() {
|
||||
common::AssertGPUSupport();
|
||||
return 0;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
NumpyArrayIterForTest::NumpyArrayIterForTest(float sparsity, size_t rows, size_t cols,
|
||||
size_t batches)
|
||||
: ArrayIterForTest{sparsity, rows, cols, batches} {
|
||||
@ -650,7 +662,7 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs,
|
||||
ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches)
|
||||
: rows_{rows}, cols_{cols}, n_batches_{batches} {
|
||||
XGProxyDMatrixCreate(&proxy_);
|
||||
rng_.reset(new RandomDataGenerator{rows_, cols_, sparsity});
|
||||
rng_ = std::make_unique<RandomDataGenerator>(rows_, cols_, sparsity);
|
||||
std::tie(batches_, interface_) = rng_->GenerateArrayInterfaceBatch(&data_, n_batches_);
|
||||
}
|
||||
|
||||
|
||||
@ -11,6 +11,8 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/common/common.h"
|
||||
#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl
|
||||
#include "../../../src/data/ellpack_page.h" // for EllpackPage
|
||||
#include "../../../src/data/sparse_page_source.h"
|
||||
#include "../../../src/tree/constraints.cuh"
|
||||
#include "../../../src/tree/param.h" // for TrainParam
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
@ -10,6 +11,16 @@ from test_dmatrix import set_base_margin_info
|
||||
|
||||
from xgboost import testing as tm
|
||||
|
||||
cupy = pytest.importorskip("cupy")
|
||||
|
||||
|
||||
def test_array_interface() -> None:
|
||||
arr = cupy.array([[1, 2, 3, 4], [1, 2, 3, 4]])
|
||||
i_arr = arr.__cuda_array_interface__
|
||||
i_arr = json.loads(json.dumps(i_arr))
|
||||
ret = xgb.core.from_array_interface(i_arr)
|
||||
np.testing.assert_equal(cupy.asnumpy(arr), cupy.asnumpy(ret))
|
||||
|
||||
|
||||
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
|
||||
@ -8,7 +8,11 @@ from hypothesis import assume, given, note, settings, strategies
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
|
||||
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
|
||||
from xgboost.testing.updater import (
|
||||
check_get_quantile_cut,
|
||||
check_init_estimation,
|
||||
check_quantile_loss,
|
||||
)
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import test_updaters as test_up
|
||||
@ -264,3 +268,7 @@ class TestGPUUpdaters:
|
||||
},
|
||||
num_boost_round=150,
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_get_quantile_cut(self) -> None:
|
||||
check_get_quantile_cut("gpu_hist")
|
||||
|
||||
@ -14,7 +14,11 @@ from xgboost.testing.params import (
|
||||
hist_multi_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
)
|
||||
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
|
||||
from xgboost.testing.updater import (
|
||||
check_get_quantile_cut,
|
||||
check_init_estimation,
|
||||
check_quantile_loss,
|
||||
)
|
||||
|
||||
|
||||
def train_result(param, dmat, num_rounds):
|
||||
@ -537,3 +541,8 @@ class TestTreeMethod:
|
||||
@pytest.mark.parametrize("weighted", [True, False])
|
||||
def test_quantile_loss(self, weighted: bool) -> None:
|
||||
check_quantile_loss("hist", weighted)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
@pytest.mark.parametrize("tree_method", ["hist"])
|
||||
def test_get_quantile_cut(self, tree_method: str) -> None:
|
||||
check_get_quantile_cut(tree_method)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user