Support optimal partitioning for GPU hist. (#7652)
* Implement `MaxCategory` in quantile.
* Implement partition-based split for GPU evaluation. Currently, it's based on the existing evaluation function.
* Extract an evaluator from GPU Hist to store the needed states.
* Add some CUDA stream/event utilities.
* Update the documents with references.
* Fix a bug in the approx evaluator when the number of data points is less than the number of categories.
This commit is contained in:
parent 2369d55e9a
commit 0d0abe1845
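Reviewer note: a minimal sketch (not part of this commit) of how the feature added here is exercised from the Python package. It assumes an XGBoost 1.6 build with GPU support and a pandas DataFrame whose categorical columns use the `category` dtype; the toy data is made up for illustration.

```python
import pandas as pd
import xgboost as xgb

# Hypothetical toy frame; any DataFrame with `category` dtype columns works.
X = pd.DataFrame(
    {"c": pd.Series(list("abacbb"), dtype="category"), "n": range(6)}
)
y = [0, 1, 0, 1, 1, 0]

clf = xgb.XGBClassifier(
    tree_method="gpu_hist",   # `approx` also gains categorical support in 1.6
    enable_categorical=True,  # requires DataFrame input and JSON/UBJSON models
    max_cat_to_onehot=1,      # force optimal partitioning for every categorical feature
    n_estimators=4,
)
clf.fit(X, y)
```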
@@ -61,7 +61,12 @@ def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:
    return X, y


params = {"tree_method": "gpu_hist", "use_label_encoder": False, "n_estimators": 32}
params = {
    "tree_method": "gpu_hist",
    "use_label_encoder": False,
    "n_estimators": 32,
    "colsample_bylevel": 0.7,
}


def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
@@ -70,13 +75,13 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
        X, y, random_state=1994, test_size=0.2
    )
    # Specify `enable_categorical`.
    clf = xgb.XGBClassifier(**params, enable_categorical=True)
    clf.fit(
        X_train,
        y_train,
        eval_set=[(X_test, y_test), (X_train, y_train)],
    clf = xgb.XGBClassifier(
        **params,
        eval_metric="auc",
        enable_categorical=True,
        max_cat_to_onehot=1,  # We use optimal partitioning exclusively
    )
    clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)])
    clf.save_model(os.path.join(output_dir, "categorical.json"))

    y_score = clf.predict_proba(X_test)[:, 1]  # proba of positive samples
@@ -3,15 +3,15 @@ Getting started with categorical data
=====================================

Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has
experimental support for one-hot encoding based tree split, and in 1.6 `approx` supported
experimental support for one-hot encoding based tree split, and in 1.6 `approx` support
was added.

In before, users need to run an encoder themselves before passing the data into XGBoost,
which creates a sparse matrix and potentially increase memory usage. This demo showcases
the experimental categorical data support, more advanced features are planned.

Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with categorical data
which creates a sparse matrix and potentially increase memory usage. This demo
showcases the experimental categorical data support, more advanced features are planned.

Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with
categorical data.

.. versionadded:: 1.5.0
@@ -55,8 +55,11 @@ def main() -> None:
    # For scikit-learn interface, the input data must be pandas DataFrame or cudf
    # DataFrame with categorical features
    X, y = make_categorical(100, 10, 4, False)
    # Specify `enable_categorical` to True.
    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
    # Specify `enable_categorical` to True, also we use onehot encoding based split
    # here for demonstration. For details see the document of `max_cat_to_onehot`.
    reg = xgb.XGBRegressor(
        tree_method="gpu_hist", enable_categorical=True, max_cat_to_onehot=5
    )
    reg.fit(X, y, eval_set=[(X, y)])

    # Pass in already encoded data
@@ -245,8 +245,8 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method

  - Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` tree method
================================================
Additional parameters for ``approx`` and ``gpu_hist`` tree method
=================================================================

* ``max_cat_to_onehot``

@@ -257,7 +257,8 @@ Additional parameters for ``approx`` tree method

  - A threshold for deciding whether XGBoost should use one-hot encoding based split for
    categorical data. When number of categories is lesser than the threshold then one-hot
    encoding is chosen, otherwise the categories will be partitioned into children nodes.
    Only relevant for regression and binary classification with `approx` tree method.
    Only relevant for regression and binary classification. Also, `approx` or `gpu_hist`
    tree method is required.

Additional parameters for Dart Booster (``booster=dart``)
=========================================================
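Reviewer note: a short usage sketch of the parameter documented above, under the assumption of an XGBoost 1.6 build; the data is made up for illustration. With `max_cat_to_onehot=5`, a categorical feature with fewer than 5 categories is split with one-hot encoding, while one with 5 or more is partitioned.

```python
import pandas as pd
import xgboost as xgb

# Eight rows over six categories, so partitioning kicks in for this feature.
X = pd.DataFrame({"c": pd.Series(list("abcdefab"), dtype="category")})
y = list(range(8))
Xy = xgb.DMatrix(X, label=y, enable_categorical=True)

booster = xgb.train(
    {
        "tree_method": "approx",  # or "gpu_hist"
        "max_cat_to_onehot": 5,
        "objective": "reg:squarederror",
    },
    Xy,
    num_boost_round=4,
)
```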
@@ -2,6 +2,10 @@
Categorical Data
################

.. note::

   As of XGBoost 1.6, the feature is highly experimental and has limited features.

Starting from version 1.5, XGBoost has experimental support for categorical data available
for public testing. At the moment, the support is implemented as one-hot encoding based
categorical tree splits. For numerical data, the split condition is defined as
@@ -107,6 +111,28 @@ For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
:class:`dask.Array <dask.Array>` can also be used as categorical data.

********************
Optimal Partitioning
********************

.. versionadded:: 1.6

Optimal partitioning is a technique for partitioning the categorical predictors for each
node split; the proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__; later LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees, and it is now also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions, then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be
used for each feature; see :doc:`/parameter` for details.
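Reviewer note: a minimal NumPy sketch of the idea described above. Names, values, and the gain formula are simplified stand-ins for the actual implementation; it only illustrates that sorting by a per-category measure reduces the search to `n_cats - 1` partitions instead of `2^(n_cats - 1) - 1` subsets.

```python
import numpy as np

# Per-category gradient/hessian sums, as accumulated in the histogram (made-up values).
grad = np.array([-2.0, 4.0, -1.5, 3.0])
hess = np.array([1.0, 2.0, 1.0, 2.0])
lam = 1.0  # L2 regularization


def score(g, h):
    return g * g / (h + lam)


# Sort categories by the measure (here g/h, a proxy for the output leaf value).
order = np.argsort(grad / hess)
g_sorted, h_sorted = grad[order], hess[order]

parent = score(grad.sum(), hess.sum())
best_gain, best_k = -np.inf, None
for k in range(1, len(order)):  # enumerate sorted partitions only
    gl, hl = g_sorted[:k].sum(), h_sorted[:k].sum()
    gr, hr = g_sorted[k:].sum(), h_sorted[k:].sum()
    gain = score(gl, hl) + score(gr, hr) - parent
    if gain > best_gain:
        best_gain, best_k = gain, k

left_categories = set(order[:best_k])  # categories routed to one child
```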
*************
Miscellaneous
@@ -120,10 +146,20 @@ actual number of unique categories. During training this is validated but for prediction
it's treated as the same as missing value for performance reasons. Lastly, missing values
are treated as the same as numerical features (using the learned split direction).


**********
Next Steps
References
**********

As of XGBoost 1.5, the feature is highly experimental and have limited features like CPU
training is not yet supported. Please see `this issue
<https://github.com/dmlc/xgboost/issues/6503>`_ for progress.
[1] Walter D. Fisher. "`On Grouping for Maximum Homogeneity`_." Journal of the American Statistical Association. Vol. 53, No. 284 (Dec., 1958), pp. 789-798.

[2] Trevor Hastie, Robert Tibshirani, Jerome Friedman. "`The Elements of Statistical Learning`_." Springer Series in Statistics. Springer New York Inc. (2001).

[3] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, Tie-Yan Liu. "`LightGBM\: A Highly Efficient Gradient Boosting Decision Tree`_." Advances in Neural Information Processing Systems 30 (NIPS 2017), pp. 3149-3157.


.. _On Grouping for Maximum Homogeneity: https://www.tandfonline.com/doi/abs/10.1080/01621459.1958.10501479

.. _The Elements of Statistical Learning: https://link.springer.com/book/10.1007/978-0-387-84858-7

.. _LightGBM\: A Highly Efficient Gradient Boosting Decision Tree: https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf
@@ -1,5 +1,5 @@
/*!
 * Copyright 2021 by XGBoost Contributors
 * Copyright 2021-2022 by XGBoost Contributors
 */
#ifndef XGBOOST_TASK_H_
#define XGBOOST_TASK_H_
@@ -34,6 +34,10 @@ struct ObjInfo {

  explicit ObjInfo(Task t) : task{t} {}
  ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}

  constexpr bool UseOneHot() const {
    return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
  }
};
}  // namespace xgboost
#endif  // XGBOOST_TASK_H_
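Reviewer note: in plain Python, the decision encoded by `ObjInfo::UseOneHot` above (and by `common::UseOneHot` later in this diff) amounts to the following sketch; the task names are illustrative stand-ins for the `Task` enum.

```python
def obj_use_onehot(task: str) -> bool:
    # Partition-based splits are only implemented for regression and binary
    # classification; every other objective falls back to one-hot splits.
    return task not in ("regression", "binary")


def use_onehot(n_cats: int, max_cat_to_onehot: int, task: str) -> bool:
    # Few categories, or an unsupported objective: use one-hot based splits.
    return n_cats < max_cat_to_onehot or obj_use_onehot(task)


assert use_onehot(4, 5, "regression") is True   # few categories: one-hot
assert use_onehot(8, 5, "regression") is False  # many categories: partitioning
assert use_onehot(8, 5, "multiclass") is True   # unsupported task: one-hot
```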
@@ -581,10 +581,10 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes

            .. versionadded:: 1.3.0

            Experimental support of specializing for categorical features. Do not set to
            True unless you are interested in development. Currently it's only available
            for `gpu_hist` tree method with 1 vs rest (one hot) categorical split. Also,
            JSON serialization format is required.
            Experimental support of specializing for categorical features. Do not set
            to True unless you are interested in development. Currently it's only
            available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON
            serialization format is required. (XGBoost 1.6 for approx)

        """
        if group is not None and qid is not None:
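Reviewer note: a sketch of the native-interface path documented above, assuming an XGBoost 1.6 build. With a plain array (no DataFrame), the feature types can be passed explicitly: `"c"` marks a categorical column and `"q"` a quantized numerical one, matching the tutorial text in this diff; the data is made up.

```python
import numpy as np
import xgboost as xgb

# First column categorical ("c"), second numerical ("q").
X = np.array([[0, 1.2], [1, 0.3], [2, 3.4], [0, 0.7]])
y = np.array([0.0, 1.0, 2.0, 0.5])

Xy = xgb.DMatrix(X, label=y, feature_types=["c", "q"], enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
```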
@@ -207,7 +207,9 @@ __model_doc = f'''

        .. versionadded:: 1.5.0

        Experimental support for categorical data. Do not set to true unless you are
        interested in development. Only valid when `gpu_hist` and dataframe are used.
        interested in development. Only valid when `gpu_hist` or `approx` is used along
        with dataframe as input. Also, JSON/UBJSON serialization format is
        required. (XGBoost 1.6 for approx)

    max_cat_to_onehot : Optional[int]

@@ -216,10 +218,11 @@ __model_doc = f'''

        .. note:: This parameter is experimental

        A threshold for deciding whether XGBoost should use one-hot encoding based split
        for categorical data. When number of categories is lesser than the threshold then
        one-hot encoding is chosen, otherwise the categories will be partitioned into
        children nodes. Only relevant for regression and binary classification and
        `approx` tree method.
        for categorical data. When number of categories is lesser than the threshold
        then one-hot encoding is chosen, otherwise the categories will be partitioned
        into children nodes. Only relevant for regression and binary
        classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See
        :doc:`Categorical Data </tutorials/categorical>` for details.

    eval_metric : Optional[Union[str, List[str], Callable]]
@@ -16,6 +16,10 @@

namespace xgboost {
namespace common {

using CatBitField = LBitField32;
using KCatBitField = CLBitField32;

// Cast the categorical type.
template <typename T>
XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
@@ -57,6 +61,11 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat,
  if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
    return dft_left;
  }

  auto pos = KCatBitField::ToBitPos(cat);
  if (pos.int_pos >= cats.size()) {
    return true;
  }
  return !s_cats.Check(AsCat(cat));
}
@@ -73,18 +82,14 @@ inline void InvalidCategory() {
/*!
 * \brief Whether should we use onehot encoding for categorical data.
 */
inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
  bool use_one_hot = n_cats < max_cat_to_onehot ||
                     (task.task != ObjInfo::kRegression && task.task != ObjInfo::kBinary);
XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
  bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot();
  return use_one_hot;
}

struct IsCatOp {
  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};

using CatBitField = LBitField32;
using KCatBitField = CLBitField32;
}  // namespace common
}  // namespace xgboost
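Reviewer note: the `Decision` function above routes a sample by testing membership of its category in a bit field. A rough Python equivalent, using an integer as the bitset, might look like this (a sketch only: the exact left/right semantics and invalid-category handling live in the C++ above).

```python
def decision(cats_bits: int, n_bits: int, cat: float, dft_left: bool) -> bool:
    """Mirror of the bit test: the caller maps the boolean onto a branch."""
    c = int(cat)
    if cat < 0 or cat != c:  # invalid category: follow the default direction
        return dft_left
    if c >= n_bits:          # category never seen while building the cuts
        return True
    return not bool((cats_bits >> c) & 1)  # mirrors `!s_cats.Check(cat)`
```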
@@ -952,22 +952,22 @@ thrust::device_ptr<T const> tcend(xgboost::HostDeviceVector<T> const& vector) {
}

template <typename T>
thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) {  // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) {  // NOLINT
  return thrust::device_ptr<T>(span.data());
}

template <typename T>
thrust::device_ptr<T> tbegin(xgboost::common::Span<T> const& span) {  // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tbegin(xgboost::common::Span<T> const& span) {  // NOLINT
  return thrust::device_ptr<T>(span.data());
}

template <typename T>
thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) {  // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) {  // NOLINT
  return tbegin(span) + span.size();
}

template <typename T>
thrust::device_ptr<T> tend(xgboost::common::Span<T> const& span) {  // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T> tend(xgboost::common::Span<T> const& span) {  // NOLINT
  return tbegin(span) + span.size();
}

@@ -982,12 +982,12 @@ XGBOOST_DEVICE auto trend(xgboost::common::Span<T> &span) {  // NOLINT
}

template <typename T>
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) {  // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) {  // NOLINT
  return thrust::device_ptr<T const>(span.data());
}

template <typename T>
thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {  // NOLINT
XGBOOST_DEVICE thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {  // NOLINT
  return tcbegin(span) + span.size();
}
@@ -1536,4 +1536,69 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
  safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
                            sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}

class CUDAStreamView;

class CUDAEvent {
  cudaEvent_t event_{nullptr};

 public:
  CUDAEvent() { dh::safe_cuda(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); }
  ~CUDAEvent() {
    if (event_) {
      dh::safe_cuda(cudaEventDestroy(event_));
    }
  }

  CUDAEvent(CUDAEvent const &that) = delete;
  CUDAEvent &operator=(CUDAEvent const &that) = delete;

  inline void Record(CUDAStreamView stream);  // NOLINT

  operator cudaEvent_t() const { return event_; }  // NOLINT
};

class CUDAStreamView {
  cudaStream_t stream_{nullptr};

 public:
  explicit CUDAStreamView(cudaStream_t s) : stream_{s} {}
  void Wait(CUDAEvent const &e) {
#if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0
    // CUDA == 11.0
    dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, 0));
#else
    // CUDA > 11.0
    dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault));
#endif  // __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0:
#else   // clang
    dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault));
#endif  // defined(__CUDACC_VER_MAJOR__)
  }
  operator cudaStream_t() const {  // NOLINT
    return stream_;
  }
  void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
};

inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
  dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
}

inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }

class CUDAStream {
  cudaStream_t stream_;

 public:
  CUDAStream() {
    dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
  }
  ~CUDAStream() {
    dh::safe_cuda(cudaStreamDestroy(stream_));
  }

  CUDAStreamView View() const { return CUDAStreamView{stream_}; }
};
}  // namespace dh
@@ -33,66 +33,84 @@ namespace common {
 */
using GHistIndexRow = Span<uint32_t const>;

// A CSC matrix representing histogram cuts, used in CPU quantile hist.
// A CSC matrix representing histogram cuts.
// The cut values represent upper bounds of bins containing approximately equal numbers of elements
class HistogramCuts {
  bool has_categorical_{false};
  float max_cat_{-1.0f};

 protected:
  using BinIdx = uint32_t;

  void Swap(HistogramCuts&& that) noexcept(true) {
    std::swap(cut_values_, that.cut_values_);
    std::swap(cut_ptrs_, that.cut_ptrs_);
    std::swap(min_vals_, that.min_vals_);

    std::swap(has_categorical_, that.has_categorical_);
    std::swap(max_cat_, that.max_cat_);
  }

  void Copy(HistogramCuts const& that) {
    cut_values_.Resize(that.cut_values_.Size());
    cut_ptrs_.Resize(that.cut_ptrs_.Size());
    min_vals_.Resize(that.min_vals_.Size());
    cut_values_.Copy(that.cut_values_);
    cut_ptrs_.Copy(that.cut_ptrs_);
    min_vals_.Copy(that.min_vals_);
    has_categorical_ = that.has_categorical_;
    max_cat_ = that.max_cat_;
  }

 public:
  HostDeviceVector<bst_float> cut_values_;  // NOLINT
  HostDeviceVector<float> cut_values_;  // NOLINT
  HostDeviceVector<uint32_t> cut_ptrs_;  // NOLINT
  // storing minimum value in a sketch set.
  HostDeviceVector<float> min_vals_;  // NOLINT

  HistogramCuts();
  HistogramCuts(HistogramCuts const& that) {
    cut_values_.Resize(that.cut_values_.Size());
    cut_ptrs_.Resize(that.cut_ptrs_.Size());
    min_vals_.Resize(that.min_vals_.Size());
    cut_values_.Copy(that.cut_values_);
    cut_ptrs_.Copy(that.cut_ptrs_);
    min_vals_.Copy(that.min_vals_);
  }
  HistogramCuts(HistogramCuts const& that) { this->Copy(that); }

  HistogramCuts(HistogramCuts&& that) noexcept(true) {
    *this = std::forward<HistogramCuts&&>(that);
    this->Swap(std::forward<HistogramCuts>(that));
  }

  HistogramCuts& operator=(HistogramCuts const& that) {
    cut_values_.Resize(that.cut_values_.Size());
    cut_ptrs_.Resize(that.cut_ptrs_.Size());
    min_vals_.Resize(that.min_vals_.Size());
    cut_values_.Copy(that.cut_values_);
    cut_ptrs_.Copy(that.cut_ptrs_);
    min_vals_.Copy(that.min_vals_);
    this->Copy(that);
    return *this;
  }

  HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
    cut_ptrs_ = std::move(that.cut_ptrs_);
    cut_values_ = std::move(that.cut_values_);
    min_vals_ = std::move(that.min_vals_);
    this->Swap(std::forward<HistogramCuts>(that));
    return *this;
  }

  uint32_t FeatureBins(uint32_t feature) const {
    return cut_ptrs_.ConstHostVector().at(feature + 1) -
           cut_ptrs_.ConstHostVector()[feature];
  uint32_t FeatureBins(bst_feature_t feature) const {
    return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature];
  }

  // Getters.  Cuts should be of no use after building histogram indices, but currently
  // they are deeply linked with quantile_hist, gpu sketcher and gpu_hist, so we preserve
  // these for now.
  std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_.ConstHostVector(); }
  std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); }
  std::vector<float> const& MinValues() const { return min_vals_.ConstHostVector(); }

  bool HasCategorical() const { return has_categorical_; }
  float MaxCategory() const { return max_cat_; }
  /**
   * \brief Set meta info about categorical features.
   *
   * \param has_cat Do we have categorical feature in the data?
   * \param max_cat The maximum categorical value in all features.
   */
  void SetCategorical(bool has_cat, float max_cat) {
    has_categorical_ = has_cat;
    max_cat_ = max_cat;
  }

  size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); }

  // Return the index of a cut point that is strictly greater than the input
  // value, or the last available index if none exists
  BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs,
  BinIdx SearchBin(float value, bst_feature_t column_id, std::vector<uint32_t> const& ptrs,
                   std::vector<float> const& values) const {
    auto end = ptrs[column_id + 1];
    auto beg = ptrs[column_id];
@@ -102,7 +120,7 @@ class HistogramCuts {
    return idx;
  }

  BinIdx SearchBin(float value, uint32_t column_id) const {
  BinIdx SearchBin(float value, bst_feature_t column_id) const {
    return this->SearchBin(value, column_id, Ptrs(), Values());
  }
@@ -272,7 +272,7 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,

  // move all categories into a flatten vector to prepare for allreduce
  size_t total = feature_ptr.back();
  std::vector<bst_cat_t> flatten(total, 0);
  std::vector<float> flatten(total, 0);
  auto cursor{flatten.begin()};
  for (auto const &feat : categories) {
    cursor = std::copy(feat.cbegin(), feat.cend(), cursor);
@@ -287,15 +287,15 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_threads,
  auto gtotal = global_worker_ptr.back();

  // categories in all workers with all features.
  std::vector<bst_cat_t> global_categories(gtotal, 0);
  std::vector<float> global_categories(gtotal, 0);
  auto rank_begin = global_worker_ptr[rank];
  auto rank_size = global_worker_ptr[rank + 1] - rank_begin;
  CHECK_EQ(rank_size, total);
  std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
  // gather values from all workers.
  rabit::Allreduce<rabit::op::Sum>(global_categories.data(), global_categories.size());
  QuantileAllreduce<bst_cat_t> allreduce_result{global_categories, global_worker_ptr,
                                                global_feat_ptrs, categories.size()};
  QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
                                            categories.size()};
  ParallelFor(categories.size(), n_threads, [&](auto fidx) {
    if (!IsCat(feature_types, fidx)) {
      return;
@@ -531,6 +531,22 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
        InvalidCategory();
      }
    }
    auto const &ptrs = cuts->Ptrs();
    auto const &vals = cuts->Values();

    float max_cat{-std::numeric_limits<float>::infinity()};
    for (size_t i = 1; i < ptrs.size(); ++i) {
      if (IsCat(feature_types_, i - 1)) {
        auto beg = ptrs[i - 1];
        auto end = ptrs[i];
        auto feat = Span<float const>{vals}.subspan(beg, end - beg);
        auto max_elem = *std::max_element(feat.cbegin(), feat.cend());
        if (max_elem > max_cat) {
          max_cat = max_elem;
        }
      }
    }
    cuts->SetCategorical(true, max_cat);
  }

  monitor_.Stop(__func__);
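Reviewer note: a NumPy sketch of the loop above, with illustrative arrays. It finds the largest cut value among categorical features; that maximum later decides how many words the category bit field needs (the `LBitField32` name suggests 32 categories per word, an assumption here).

```python
import numpy as np

ptrs = np.array([0, 3, 5])                   # CSC-style feature boundaries into `vals`
vals = np.array([0.0, 1.0, 7.0, 0.5, 2.5])   # cut values; feature 0 is categorical
is_cat = [True, False]

max_cat = -np.inf
for f in range(len(ptrs) - 1):
    if is_cat[f]:
        max_cat = max(max_cat, vals[ptrs[f]:ptrs[f + 1]].max())

# Bit field sizing sketch: 32 categories packed per storage word.
n_words = (int(max_cat) + 1 + 31) // 32
assert max_cat == 7.0 and n_words == 1
```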
@@ -1,22 +1,23 @@
/*!
 * Copyright 2020 by XGBoost Contributors
 */
#include <thrust/unique.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/transform_scan.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/transform_scan.h>
#include <thrust/unique.h>

#include <limits>  // std::numeric_limits
#include <memory>
#include <utility>

#include "xgboost/span.h"
#include "quantile.h"
#include "quantile.cuh"
#include "hist_util.h"
#include "device_helpers.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
#include "hist_util.h"
#include "quantile.cuh"
#include "quantile.h"
#include "xgboost/span.h"

namespace xgboost {
namespace common {
@@ -586,7 +587,7 @@ struct InvalidCatOp {
  Span<uint32_t const> ptrs;
  Span<FeatureType const> ft;

  XGBOOST_DEVICE bool operator()(size_t i) {
  XGBOOST_DEVICE bool operator()(size_t i) const {
    auto fidx = dh::SegmentId(ptrs, i);
    return IsCat(ft, fidx) && InvalidCat(values[i]);
  }
@@ -683,18 +684,36 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
      out_column[idx] = in_column[idx+1].value;
    });

  float max_cat{-1.0f};
  if (has_categorical_) {
    dh::XGBCachingDeviceAllocator<char> alloc;
    auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan();
    auto it = thrust::make_counting_iterator(0ul);
    auto invalid_op = InvalidCatOp{out_cut_values, d_out_columns_ptr, d_ft};
    auto it = dh::MakeTransformIterator<thrust::pair<bool, float>>(
        thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
          auto fidx = dh::SegmentId(d_out_columns_ptr, i);
          if (IsCat(d_ft, fidx)) {
            auto invalid = invalid_op(i);
            auto v = out_cut_values[i];
            return thrust::make_pair(invalid, v);
          }
          return thrust::make_pair(false, std::numeric_limits<float>::min());
        });

    CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size());
    auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
                                  InvalidCatOp{out_cut_values, ptrs, d_ft});
    bool invalid{false};
    dh::XGBCachingDeviceAllocator<char> alloc;
    thrust::tie(invalid, max_cat) =
        thrust::reduce(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
                       thrust::make_pair(false, std::numeric_limits<float>::min()),
                       [=] XGBOOST_DEVICE(thrust::pair<bool, bst_cat_t> const &l,
                                          thrust::pair<bool, bst_cat_t> const &r) {
                         return thrust::make_pair(l.first || r.first, std::max(l.second, r.second));
                       });
    if (invalid) {
      InvalidCategory();
    }
  }

  p_cuts->SetCategorical(this->has_categorical_, max_cat);

  timer_.Stop(__func__);
}
}  // namespace common
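Reviewer note: the device code above fuses two reductions into a single pass, combining an "any cut value invalid" flag and a running maximum as a pair. A Python sketch of the same combiner, with made-up inputs:

```python
from functools import reduce

pairs = [(False, 0.0), (False, 7.0), (True, 2.5)]  # (is_invalid, value) per cut


def combine(l, r):
    # Mirrors the device lambda: OR the flags, take the max of the values.
    return (l[0] or r[0], max(l[1], r[1]))


invalid, max_cat = reduce(combine, pairs, (False, float("-inf")))
assert invalid and max_cat == 7.0
```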
@@ -1,9 +1,14 @@
/*!
 * Copyright 2020-2021 by XGBoost Contributors
 * Copyright 2020-2022 by XGBoost Contributors
 */
#include <algorithm>  // std::max
#include <limits>
#include "evaluate_splits.cuh"

#include "../../common/categorical.h"
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh"
#include "expand_entry.cuh"

namespace xgboost {
namespace tree {
@@ -23,7 +28,7 @@ XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
  float missing_right_gain = evaluator.CalcSplitGain(
      param, nidx, fidx, GradStats(scan), GradStats(parent_sum - scan));

  if (missing_left_gain >= missing_right_gain) {
  if (missing_left_gain > missing_right_gain) {
    missing_left_out = true;
    return missing_left_gain - parent_gain;
  } else {
@@ -69,83 +74,13 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
  return shared_sum;
}

template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
  GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
                                     SumCallbackOp<GradientSumT> *,
                                     GradientPairPrecise const &missing,
                                     EvaluateSplitInputs<GradientSumT> const &inputs,
                                     TempStorageT *) {
    GradientSumT bin = thread_active
                           ? inputs.gradient_histogram[scan_begin + threadIdx.x]
                           : GradientSumT();
    auto rest = inputs.parent_sum - GradientPairPrecise(bin) - missing;
    return GradientSumT{rest};
  }
};

template <typename GradientSumT>
struct UpdateOneHot {
  void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
                             bst_feature_t fidx, GradientPairPrecise const &missing,
                             GradientSumT const &bin,
                             EvaluateSplitInputs<GradientSumT> const &inputs,
                             DeviceSplitCandidate *best_split) {
    int split_gidx = (scan_begin + threadIdx.x);
    float fvalue = inputs.feature_values[split_gidx];
    GradientPairPrecise left =
        missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
    GradientPairPrecise right = inputs.parent_sum - left;
    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, true,
                       inputs.param);
  }
};

template <typename GradientSumT, typename TempStorageT, typename ScanT>
struct NumericBin {
  GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
                                     SumCallbackOp<GradientSumT> *prefix_callback,
                                     GradientPairPrecise const &missing,
                                     EvaluateSplitInputs<GradientSumT> inputs,
                                     TempStorageT *temp_storage) {
    GradientSumT bin = thread_active
                           ? inputs.gradient_histogram[scan_begin + threadIdx.x]
                           : GradientSumT();
    ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback);
    return bin;
  }
};

template <typename GradientSumT>
struct UpdateNumeric {
  void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
                             bst_feature_t fidx, GradientPairPrecise const &missing,
                             GradientSumT const &bin,
                             EvaluateSplitInputs<GradientSumT> const &inputs,
                             DeviceSplitCandidate *best_split) {
    // Use pointer from cut to indicate begin and end of bins for each feature.
    uint32_t gidx_begin = inputs.feature_segments[fidx];  // beginning bin
    int split_gidx = (scan_begin + threadIdx.x) - 1;
    float fvalue;
    if (split_gidx < static_cast<int>(gidx_begin)) {
      fvalue = inputs.min_fvalue[fidx];
    } else {
      fvalue = inputs.feature_values[split_gidx];
    }
    GradientPairPrecise left =
        missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
    GradientPairPrecise right = inputs.parent_sum - left;
    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, false,
                       inputs.param);
  }
};

/*! \brief Find the thread with best gain. */
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
          typename MaxReduceT, typename TempStorageT, typename GradientSumT,
          typename BinFn, typename UpdateFn>
template <int BLOCK_THREADS, typename ReduceT, typename ScanT, typename MaxReduceT,
          typename TempStorageT, typename GradientSumT, SplitType type>
__device__ void EvaluateFeature(
    int fidx, EvaluateSplitInputs<GradientSumT> inputs,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    common::Span<bst_feature_t> sorted_idx, size_t offset,
    DeviceSplitCandidate *best_split,  // shared memory storing best split
    TempStorageT *temp_storage         // temp memory for cub operations
    ) {
@@ -154,23 +89,46 @@ __device__ void EvaluateFeature(
  uint32_t gidx_end =
      inputs.feature_segments[fidx + 1];  // end bin for i^th feature
  auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
  auto bin_fn = BinFn();
  auto update_fn = UpdateFn();

  // Sum histogram bins for current feature
  GradientSumT const feature_sum =
      ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
          feature_hist, temp_storage);
      ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(feature_hist, temp_storage);

  GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum};
  float const null_gain = -std::numeric_limits<bst_float>::infinity();

  SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
  for (int scan_begin = gidx_begin; scan_begin < gidx_end;
       scan_begin += BLOCK_THREADS) {
  for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) {
    bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
    auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage);

    auto calc_bin_value = [&]() {
      GradientSumT bin;
      switch (type) {
        case kOneHot: {
          auto rest =
              thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
          bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing};  // NOLINT
          break;
        }
        case kNum: {
          bin =
              thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
          ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
          break;
        }
        case kPart: {
          auto rest = thread_active
                          ? inputs.gradient_histogram[sorted_idx[scan_begin + threadIdx.x] - offset]
                          : GradientSumT();
          // No min value for cat feature, use inclusive scan.
          ScanT(temp_storage->scan).InclusiveScan(rest, rest, cub::Sum(), prefix_op);
          bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing};  // NOLINT
          break;
        }
      }
      return bin;
    };
    auto bin = calc_bin_value();
    // Whether the gradient of missing values is put to the left side.
    bool missing_left = true;
    float gain = null_gain;
@@ -193,10 +151,48 @@ __device__ void EvaluateFeature(

    cub::CTA_SYNC();

    // Best thread updates split
    // Best thread updates the split
    if (threadIdx.x == block_max.key) {
      update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs,
                best_split);
      switch (type) {
        case kNum: {
          // Use pointer from cut to indicate begin and end of bins for each feature.
          uint32_t gidx_begin = inputs.feature_segments[fidx];  // beginning bin
          int split_gidx = (scan_begin + threadIdx.x) - 1;
          float fvalue;
          if (split_gidx < static_cast<int>(gidx_begin)) {
            fvalue = inputs.min_fvalue[fidx];
          } else {
            fvalue = inputs.feature_values[split_gidx];
          }
          GradientPairPrecise left =
              missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
          GradientPairPrecise right = inputs.parent_sum - left;
          best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
                             false, inputs.param);
          break;
        }
        case kOneHot: {
          int32_t split_gidx = (scan_begin + threadIdx.x);
          float fvalue = inputs.feature_values[split_gidx];
          GradientPairPrecise left =
              missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
          GradientPairPrecise right = inputs.parent_sum - left;
          best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
                             true, inputs.param);
          break;
        }
        case kPart: {
          int32_t split_gidx = (scan_begin + threadIdx.x);
          float fvalue = inputs.feature_values[split_gidx];
          GradientPairPrecise left =
              missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
          GradientPairPrecise right = inputs.parent_sum - left;
          auto best_thresh = block_max.key;  // index of best threshold inside a feature.
          best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
                             right, true, inputs.param);
          break;
        }
      }
    }
    cub::CTA_SYNC();
  }
@@ -206,6 +202,8 @@ template <int BLOCK_THREADS, typename GradientSumT>
__global__ void EvaluateSplitsKernel(
    EvaluateSplitInputs<GradientSumT> left,
    EvaluateSplitInputs<GradientSumT> right,
    ObjInfo task,
    common::Span<bst_feature_t> sorted_idx,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    common::Span<DeviceSplitCandidate> out_candidates) {
  // KeyValuePair here used as threadIdx.x -> gain_value
@@ -240,22 +238,26 @@ __global__ void EvaluateSplitsKernel(
  // One block for each feature. Features are sampled, so fidx != blockIdx.x
  int fidx = inputs.feature_set[is_left ? blockIdx.x
                                        : blockIdx.x - left.feature_set.size()];

  if (common::IsCat(inputs.feature_types, fidx)) {
    EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
                    TempStorage, GradientSumT,
                    OneHotBin<GradientSumT, TempStorage>,
                    UpdateOneHot<GradientSumT>>(fidx, inputs, evaluator, &best_split,
                                                &temp_storage);
    auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx];
    if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) {
      EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
                      kOneHot>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
    } else {
      EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
                      TempStorage, GradientSumT,
                      NumericBin<GradientSumT, TempStorage, BlockScanT>,
                      UpdateNumeric<GradientSumT>>(fidx, inputs, evaluator, &best_split,
      auto node_sorted_idx = is_left ? sorted_idx.first(inputs.feature_values.size())
                                     : sorted_idx.last(inputs.feature_values.size());
      size_t offset = is_left ? 0 : inputs.feature_values.size();
      EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
                      kPart>(fidx, inputs, evaluator, node_sorted_idx, offset, &best_split,
                             &temp_storage);
    }
  } else {
    EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
                    kNum>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
  }

  cub::CTA_SYNC();

  if (threadIdx.x == 0) {
    // Record best loss for each feature
    out_candidates[blockIdx.x] = best_split;
@@ -267,24 +269,65 @@ __device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a,
  return b.loss_chg > a.loss_chg ? b : a;
}

/**
 * \brief Set the bits for categorical splits based on the split threshold.
 */
template <typename GradientSumT>
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
__device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input,
                               common::Span<bst_feature_t const> d_sorted_idx, bst_feature_t fidx,
                               bool is_left, common::Span<common::CatBitField::value_type> out,
                               DeviceSplitCandidate *p_out_split) {
  auto &out_split = *p_out_split;
  out_split.split_cats = common::CatBitField{out};
  auto node_sorted_idx =
      is_left ? d_sorted_idx.subspan(0, input.feature_values.size())
              : d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size());
  size_t node_offset = is_left ? 0 : input.feature_values.size();
  auto best_thresh = out_split.PopBestThresh();
  auto f_sorted_idx =
      node_sorted_idx.subspan(input.feature_segments[fidx], input.FeatureBins(fidx));
  if (out_split.dir != kLeftDir) {
    // forward, missing on right
    auto beg = dh::tcbegin(f_sorted_idx);
    // Don't put all the categories into one side
    auto boundary = std::min(static_cast<size_t>((best_thresh + 1)), (f_sorted_idx.size() - 1));
    boundary = std::max(boundary, static_cast<size_t>(1ul));
    auto end = beg + boundary;
    thrust::for_each(thrust::seq, beg, end, [&](auto c) {
      auto cat = input.feature_values[c - node_offset];
      assert(!out_split.split_cats.Check(cat) && "already set");
      out_split.SetCat(cat);
    });
  } else {
    assert((f_sorted_idx.size() - best_thresh + 1) != 0 && " == 0");
    thrust::for_each(thrust::seq, dh::tcrbegin(f_sorted_idx),
                     dh::tcrbegin(f_sorted_idx) + (f_sorted_idx.size() - best_thresh), [&](auto c) {
                       auto cat = input.feature_values[c - node_offset];
                       out_split.SetCat(cat);
                     });
  }
}

template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
    EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    EvaluateSplitInputs<GradientSumT> left,
    EvaluateSplitInputs<GradientSumT> right) {
  size_t combined_num_features =
      left.feature_set.size() + right.feature_set.size();
  dh::TemporaryArray<DeviceSplitCandidate> feature_best_splits(
      combined_num_features);
    common::Span<DeviceSplitCandidate> out_splits) {
  if (!split_cats_.empty()) {
    this->SortHistogram(left, right, evaluator);
  }

  size_t combined_num_features = left.feature_set.size() + right.feature_set.size();
  dh::TemporaryArray<DeviceSplitCandidate> feature_best_splits(combined_num_features);

  // One block for each feature
  uint32_t constexpr kBlockThreads = 256;
  dh::LaunchKernel {uint32_t(combined_num_features), kBlockThreads, 0}(
      EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, evaluator,
      dh::ToSpan(feature_best_splits));
  dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
      EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, task, this->SortedIdx(left),
      evaluator, dh::ToSpan(feature_best_splits));

  // Reduce to get best candidate for left and right child over all features
  auto reduce_offset =
      dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0llu),
  auto reduce_offset = dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0llu),
      [=] __device__(size_t idx) -> size_t {
        if (idx == 0) {
          return 0;
@@ -299,39 +342,102 @@ void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
      });
  size_t temp_storage_bytes = 0;
  auto num_segments = out_splits.size();
  cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes,
                                  feature_best_splits.data(), out_splits.data(),
                                  num_segments, reduce_offset, reduce_offset + 1);
  cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes, feature_best_splits.data(),
                                  out_splits.data(), num_segments, reduce_offset,
                                  reduce_offset + 1);
  dh::TemporaryArray<int8_t> temp(temp_storage_bytes);
  cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes,
                                  feature_best_splits.data(), out_splits.data(),
                                  num_segments, reduce_offset, reduce_offset + 1);
  cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes, feature_best_splits.data(),
                                  out_splits.data(), num_segments, reduce_offset,
                                  reduce_offset + 1);
}

template <typename GradientSumT>
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
                         TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
                         EvaluateSplitInputs<GradientSumT> input) {
  EvaluateSplits(out_split, evaluator, input, {});
void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT> const &input,
                                                common::Span<CatST> cats_out) {
  if (has_sort_) {
    dh::CUDAEvent event;
    event.Record(dh::DefaultStream());
    auto h_cats = this->HostCatStorage(input.nidx);
    copy_stream_.View().Wait(event);
    dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(),
                                  cudaMemcpyDeviceToHost, copy_stream_.View()));
  }
}

template void EvaluateSplits<GradientPair>(
    common::Span<DeviceSplitCandidate> out_splits,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    EvaluateSplitInputs<GradientPair> left,
    EvaluateSplitInputs<GradientPair> right);
template void EvaluateSplits<GradientPairPrecise>(
    common::Span<DeviceSplitCandidate> out_splits,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    EvaluateSplitInputs<GradientPairPrecise> left,
    EvaluateSplitInputs<GradientPairPrecise> right);
template void EvaluateSingleSplit<GradientPair>(
    common::Span<DeviceSplitCandidate> out_split,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    EvaluateSplitInputs<GradientPair> input);
template void EvaluateSingleSplit<GradientPairPrecise>(
    common::Span<DeviceSplitCandidate> out_split,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
    EvaluateSplitInputs<GradientPairPrecise> input);
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
                                                    EvaluateSplitInputs<GradientSumT> left,
                                                    EvaluateSplitInputs<GradientSumT> right,
                                                    common::Span<GPUExpandEntry> out_entries) {
  auto evaluator = this->tree_evaluator_.template GetEvaluator<GPUTrainingParam>();

  dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(2);
  auto out_splits = dh::ToSpan(splits_out_storage);
  this->EvaluateSplits(left, right, task, evaluator, out_splits);

  auto d_sorted_idx = this->SortedIdx(left);
  auto d_entries = out_entries;
  auto cats_out = this->DeviceCatStorage(left.nidx);
  // turn candidate into entry, along with handling sort based split.
  dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) {
    auto const &input = i == 0 ? left : right;
    auto &split = out_splits[i];
    auto fidx = out_splits[i].findex;

    if (split.is_cat &&
        !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
      bool is_left = i == 0;
      auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
      SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
    }

    float base_weight =
        evaluator.CalcWeight(input.nidx, input.param, GradStats{split.left_sum + split.right_sum});
    float left_weight = evaluator.CalcWeight(input.nidx, input.param, GradStats{split.left_sum});
    float right_weight = evaluator.CalcWeight(input.nidx, input.param, GradStats{split.right_sum});

    d_entries[i] = GPUExpandEntry{input.nidx,  candidate.depth + 1, out_splits[i],
                                  base_weight, left_weight,         right_weight};
  });

  this->CopyToHost(left, cats_out);
}

template <typename GradientSumT>
GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
    EvaluateSplitInputs<GradientSumT> input, float weight, ObjInfo task) {
  dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
  auto out_split = dh::ToSpan(splits_out);
  auto evaluator = tree_evaluator_.GetEvaluator<GPUTrainingParam>();
  this->EvaluateSplits(input, {}, task, evaluator, out_split);

  auto cats_out = this->DeviceCatStorage(input.nidx);
  auto d_sorted_idx = this->SortedIdx(input);

  dh::TemporaryArray<GPUExpandEntry> entries(1);
  auto d_entries = entries.data().get();
  dh::LaunchN(1, [=] __device__(size_t i) {
    auto &split = out_split[i];
    auto fidx = out_split[i].findex;

    if (split.is_cat &&
        !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
      SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
    }

    float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum});
    float right_weight = evaluator.CalcWeight(0, input.param, GradStats{split.right_sum});
    d_entries[0] = GPUExpandEntry(0, 0, split, weight, left_weight, right_weight);
  });
  this->CopyToHost(input, cats_out);

  GPUExpandEntry root_entry;
  dh::safe_cuda(cudaMemcpyAsync(&root_entry, entries.data().get(),
                                sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
  return root_entry;
}

template class GPUHistEvaluator<GradientPair>;
template class GPUHistEvaluator<GradientPairPrecise>;
}  // namespace tree
}  // namespace xgboost
@ -3,15 +3,20 @@
|
||||
*/
|
||||
#ifndef EVALUATE_SPLITS_CUH_
|
||||
#define EVALUATE_SPLITS_CUH_
|
||||
#include <thrust/system/cuda/experimental/pinned_allocator.h>
|
||||
#include <xgboost/span.h>
|
||||
#include "../../data/ellpack_page.cuh"
|
||||
|
||||
#include "../../common/categorical.h"
|
||||
#include "../split_evaluator.h"
|
||||
#include "../constraints.cuh"
|
||||
#include "../updater_gpu_common.cuh"
|
||||
#include "expand_entry.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace common {
|
||||
class HistogramCuts;
|
||||
}
|
||||
|
||||
namespace tree {
|
||||
template <typename GradientSumT>
|
||||
struct EvaluateSplitInputs {
|
||||
int nidx;
|
||||
@ -23,16 +28,131 @@ struct EvaluateSplitInputs {
|
||||
common::Span<const float> feature_values;
|
||||
common::Span<const float> min_fvalue;
|
||||
common::Span<const GradientSumT> gradient_histogram;
|
||||
|
||||
XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
|
||||
__device__ auto FeatureBins(bst_feature_t fidx) const {
|
||||
return feature_segments[fidx + 1] - feature_segments[fidx];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename GradientSumT>
|
||||
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits,
|
||||
class GPUHistEvaluator {
|
||||
using CatST = common::CatBitField::value_type; // categorical storage type
|
||||
// use pinned memory to stage the categories, used for sort based splits.
|
||||
using Alloc = thrust::system::cuda::experimental::pinned_allocator<CatST>;
|
||||
|
||||
private:
|
||||
TreeEvaluator tree_evaluator_;
|
||||
// storage for categories for each node, used for sort based splits.
|
||||
dh::device_vector<CatST> split_cats_;
|
||||
// host storage for categories for each node, used for sort based splits.
|
||||
std::vector<CatST, Alloc> h_split_cats_;
|
||||
// stream for copying categories from device back to host for expanding the decision tree.
|
||||
dh::CUDAStream copy_stream_;
|
||||
// storage for sorted index of feature histogram, used for sort based splits.
|
||||
dh::device_vector<bst_feature_t> cat_sorted_idx_;
|
||||
TrainParam param_;
|
||||
// whether the input data requires sort based split, which is more complicated so we try
|
||||
// to avoid it if possible.
|
||||
bool has_sort_{false};
|
||||
|
||||
// Copy the categories from device to host asynchronously.
|
||||
void CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, common::Span<CatST> cats_out);
|
||||
|
||||
/**
|
||||
* \brief Get host category storage of nidx for internal calculation.
|
||||
*/
|
||||
auto HostCatStorage(bst_node_t nidx) {
|
||||
auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
|
||||
if (nidx == RegTree::kRoot) {
|
||||
auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
|
||||
return cats_out;
|
||||
}
|
||||
auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits * 2);
|
||||
return cats_out;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get device category storage of nidx for internal calculation.
|
||||
*/
|
||||
auto DeviceCatStorage(bst_node_t nidx) {
|
||||
auto cat_bits = split_cats_.size() / param_.MaxNodes();
|
||||
if (nidx == RegTree::kRoot) {
|
||||
auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits);
|
||||
return cats_out;
|
||||
}
|
||||
auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits * 2);
|
||||
return cats_out;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get sorted index storage based on the left node of inputs .
|
||||
*/
|
||||
auto SortedIdx(EvaluateSplitInputs<GradientSumT> left) {
|
||||
if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
|
||||
return dh::ToSpan(cat_sorted_idx_).first(left.feature_values.size());
|
||||
}
|
||||
return dh::ToSpan(cat_sorted_idx_);
|
||||
}

 public:
  GPUHistEvaluator(TrainParam const &param, bst_feature_t n_features, int32_t device)
      : tree_evaluator_{param, n_features, device}, param_{param} {}
  /**
   * \brief Reset the evaluator, should be called before any use.
   */
  void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft, ObjInfo task,
             bst_feature_t n_features, TrainParam const &param, int32_t device);

  /**
   * \brief Get host category storage for nidx. Different from the internal version, this
   * returns strictly 1 node.
   */
  common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
    copy_stream_.View().Sync();
    auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
    auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
    return cats_out;
  }
  /**
   * \brief Add a split to the internal tree evaluator.
   */
  void ApplyTreeSplit(GPUExpandEntry const &candidate, RegTree *p_tree) {
    auto &tree = *p_tree;
    // Set up child constraints
    auto left_child = tree[candidate.nid].LeftChild();
    auto right_child = tree[candidate.nid].RightChild();
    tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
                             tree[candidate.nid].SplitIndex(), candidate.left_weight,
                             candidate.right_weight);
  }

  auto GetEvaluator() { return tree_evaluator_.GetEvaluator<GPUTrainingParam>(); }
  /**
   * \brief Sort the histogram based on the output weight to obtain contiguous partitions.
   */
  common::Span<bst_feature_t const> SortHistogram(
      EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
      TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator);

  // impl of evaluate splits, contains CUDA kernels so it's public
  void EvaluateSplits(EvaluateSplitInputs<GradientSumT> left,
                      EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
                      TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
                      common::Span<DeviceSplitCandidate> out_splits);
  /**
   * \brief Evaluate splits for left and right nodes.
   */
  void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
                      EvaluateSplitInputs<GradientSumT> left,
                      EvaluateSplitInputs<GradientSumT> right,
                      common::Span<GPUExpandEntry> out_splits);
  template <typename GradientSumT>
  void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split,
                           TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
                           EvaluateSplitInputs<GradientSumT> input);
  /**
   * \brief Evaluate splits for root node.
   */
  GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight,
                                     ObjInfo task);
};
}  // namespace tree
}  // namespace xgboost

src/tree/gpu_hist/evaluator.cu
@ -0,0 +1,100 @@
/*!
 * Copyright 2022 by XGBoost Contributors
 *
 * \brief Some components of GPU Hist evaluator, this file only exists to reduce nvcc
 * compilation time.
 */
#include <thrust/logical.h>  // thrust::any_of
#include <thrust/sort.h>     // thrust::stable_sort

#include "../../common/device_helpers.cuh"
#include "../../common/hist_util.h"  // common::HistogramCuts
#include "evaluate_splits.cuh"
#include "xgboost/data.h"

namespace xgboost {
namespace tree {
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
                                           common::Span<FeatureType const> ft, ObjInfo task,
                                           bst_feature_t n_features, TrainParam const &param,
                                           int32_t device) {
  param_ = param;
  tree_evaluator_ = TreeEvaluator{param, n_features, device};
  if (cuts.HasCategorical() && !task.UseOneHot()) {
    dh::XGBCachingDeviceAllocator<char> alloc;
    auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
    auto beg = thrust::make_counting_iterator<size_t>(1ul);
    auto end = thrust::make_counting_iterator<size_t>(ptrs.size());
    auto to_onehot = param.max_cat_to_onehot;
    // This condition avoids sort-based split function calls if the users want
    // onehot-encoding-based splits.
    // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x.
    has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
      auto idx = i - 1;
      if (common::IsCat(ft, idx)) {
        auto n_bins = ptrs[i] - ptrs[idx];
        bool use_sort = !common::UseOneHot(n_bins, to_onehot, task);
        return use_sort;
      }
      return false;
    });

    if (has_sort_) {
      auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1);
      CHECK_NE(bit_storage_size, 0);
      // We need to allocate for all nodes since the updater can grow the tree layer by
      // layer, all nodes in the same layer must be preserved until that layer is
      // finished. We can allocate one layer at a time, but the best case is reducing the
      // size of the bitset by about a half, at the cost of invoking CUDA malloc many more
      // times than necessary.
      split_cats_.resize(param.MaxNodes() * bit_storage_size);
      h_split_cats_.resize(split_cats_.size());
      dh::safe_cuda(
          cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));

      cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2);  // evaluate 2 nodes at a time.
    }
  }
}
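
For reference, a self-contained sketch of what a CatBitField-style ComputeStorageSize boils down to, assuming 32-bit storage words (the names here are illustrative): one bit per category, rounded up to whole words. This is also why Reset multiplies by param.MaxNodes(): every node that may exist in a finished layer keeps its own bit span.

    #include <cstddef>
    #include <cstdint>

    using CatBitWord = std::uint32_t;  // one bit per category
    constexpr std::size_t kBitsPerWord = sizeof(CatBitWord) * 8;

    // Words needed to hold n_cats bits, rounded up to a whole word.
    constexpr std::size_t ComputeStorageSize(std::size_t n_cats) {
      return (n_cats + kBitsPerWord - 1) / kBitsPerWord;
    }

    static_assert(ComputeStorageSize(1) == 1, "one category still takes a word");
    static_assert(ComputeStorageSize(33) == 2, "33 categories spill into a second word");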

template <typename GradientSumT>
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
    EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
    TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
  dh::XGBDeviceAllocator<char> alloc;
  auto sorted_idx = this->SortedIdx(left);
  dh::Iota(sorted_idx);
  // sort 2 nodes and all the features at the same time, disregarding column sampling.
  thrust::stable_sort(
      thrust::cuda::par(alloc), dh::tbegin(sorted_idx), dh::tend(sorted_idx),
      [evaluator, left, right] XGBOOST_DEVICE(size_t l, size_t r) {
        auto l_is_left = l < left.feature_values.size();
        auto r_is_left = r < left.feature_values.size();
        if (l_is_left != r_is_left) {
          return l_is_left;  // not the same node
        }

        auto const &input = l_is_left ? left : right;
        l -= (l_is_left ? 0 : input.feature_values.size());
        r -= (r_is_left ? 0 : input.feature_values.size());

        auto lfidx = dh::SegmentId(input.feature_segments, l);
        auto rfidx = dh::SegmentId(input.feature_segments, r);
        if (lfidx != rfidx) {
          return lfidx < rfidx;  // not the same feature
        }
        if (common::IsCat(input.feature_types, lfidx)) {
          auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[l]);
          auto rw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[r]);
          return lw < rw;
        }
        return l < r;
      });
  return dh::ToSpan(cat_sorted_idx_);
}
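
Stripped of the device plumbing, the per-feature ordering this comparator establishes for categorical bins is just a sort by leaf weight. A minimal host-side sketch (GradPair and CalcWeight are illustrative stand-ins, with regularization taken as zero):

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    struct GradPair { double grad, hess; };

    // Unregularized leaf weight of a bin: -G / H.
    double CalcWeight(GradPair const &g) { return -g.grad / g.hess; }

    // Order the bins of one categorical feature by weight so that every
    // threshold over the sorted order corresponds to a candidate partition.
    std::vector<std::size_t> SortBinsByWeight(std::vector<GradPair> const &hist) {
      std::vector<std::size_t> idx(hist.size());
      std::iota(idx.begin(), idx.end(), 0);
      std::stable_sort(idx.begin(), idx.end(), [&](std::size_t l, std::size_t r) {
        return CalcWeight(hist[l]) < CalcWeight(hist[r]);
      });
      return idx;
    }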

template class GPUHistEvaluator<GradientPair>;
template class GPUHistEvaluator<GradientPairPrecise>;
}  // namespace tree
}  // namespace xgboost

@ -4,8 +4,9 @@
#ifndef EXPAND_ENTRY_CUH_
#define EXPAND_ENTRY_CUH_
#include <xgboost/span.h>

#include "../param.h"
#include "evaluate_splits.cuh"
#include "../updater_gpu_common.cuh"

namespace xgboost {
namespace tree {

@ -53,7 +53,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
      return true;
    }
  }
  enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 };

  // Enumerate/Scan the split values of specific feature
  // Returns the sum of gradients corresponding to the data points that contains
@ -137,7 +136,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
          static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
                                                     GradStats{right_sum}) -
                             parent.root_gain);
      split_pt = cut_val[i];
      split_pt = cut_val[i];  // not used for partition based
      improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
                             left_sum, right_sum);
    } else {
@ -180,10 +179,10 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {

      if (d_step == 1) {
        std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
                      [&cat_bits](size_t c) { cat_bits.Set(c); });
                      [&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
      } else {
        std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
                      [&cat_bits](size_t c) { cat_bits.Set(c); });
                      [&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
      }
    }
    p_best->Update(best);
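
The hunk above fixes which value lands in the bitset: the index c refers to a position in the sorted order, so it has to be mapped back through cut_val to an actual category before being set. A minimal single-feature sketch of that mapping (names are illustrative, not codebase API):

    #include <cstddef>
    #include <set>
    #include <vector>

    // Collect the categories on one side of the chosen threshold. The
    // threshold indexes the *sorted* order; cut_val maps a bin back to its
    // category value, which is what must go into the bitset.
    std::set<int> LeftCategories(std::vector<std::size_t> const &sorted_idx,
                                 std::vector<int> const &cut_val,
                                 std::size_t best_thresh) {
      std::set<int> cats;
      for (std::size_t i = 0; i <= best_thresh; ++i) {
        cats.insert(cut_val[sorted_idx[i]]);  // category value, not position
      }
      return cats;
    }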
@ -231,6 +230,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
      }
    }
    auto evaluator = tree_evaluator_.GetEvaluator();
    auto const& cut_ptrs = cut.Ptrs();

    common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
      auto tidx = omp_get_thread_num();
@ -246,27 +246,23 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
          continue;
        }
        if (is_cat) {
          auto n_bins = cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx];
          auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
          if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
            EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
            EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
          } else {
            auto const &cut_ptr = cut.Ptrs();
            std::vector<size_t> sorted_idx(n_bins);
            std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
            auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins);
            auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
            // Sort the histogram to get contiguous partitions.
            std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
              auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
                         evaluator.CalcWeightCat(param_, feat_hist[r]);
              static_assert(std::is_same<decltype(ret), bool>::value, "");
              return ret;
            });
            auto grad_stats =
                EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
            if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
              EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
            }
          }
        } else {
          auto grad_stats =
              EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
@ -313,6 +309,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
      cat_bits.Set(cat);
    } else {
      split_cats = candidate.split.cat_bits;
      common::CatBitField cat_bits{split_cats};
    }

    tree.ExpandCategorical(

@ -110,6 +110,9 @@ class TreeEvaluator {

  template <typename GradientSumT>
  XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const {
    // FIXME(jiamingy): This is a temporary solution until we have categorical feature
    // specific regularization parameters. During sorting we should try to avoid any
    // regularization.
    return ::xgboost::tree::CalcWeight(param, stats);
  }
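
CalcWeightCat currently forwards to the regular CalcWeight. As a rough self-contained sketch of what that computes under the usual second-order objective (an approximation for illustration, not a drop-in copy of the library routine; max_delta_step handling is omitted):

    // L1 shrinkage on the gradient sum.
    double ThresholdL1(double g, double alpha) {
      if (g > alpha) { return g - alpha; }
      if (g < -alpha) { return g + alpha; }
      return 0.0;
    }

    // Leaf weight w = -G / (H + lambda), with optional L1 term alpha.
    double CalcWeight(double sum_grad, double sum_hess, double reg_lambda,
                      double reg_alpha) {
      if (sum_hess <= 0.0) { return 0.0; }
      return -ThresholdL1(sum_grad, reg_alpha) / (sum_hess + reg_lambda);
    }

The FIXME matters because the optimal-partition guarantee holds for the unregularized ratio G/H; a nonzero lambda can reorder bins whose hessians differ.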

@ -180,6 +183,15 @@ class TreeEvaluator {
        .Eval(&lower_bounds_, &upper_bounds_, &monotone_);
  }
};

enum SplitType {
  // numerical split
  kNum = 0,
  // onehot encoding based categorical split
  kOneHot = 1,
  // partition-based categorical split
  kPart = 2
};
}  // namespace tree
}  // namespace xgboost

@ -8,6 +8,7 @@
#include <stdexcept>
#include <string>
#include <vector>
#include "../common/categorical.h"
#include "../common/device_helpers.cuh"
#include "../common/random.h"
#include "param.h"
@ -27,6 +28,7 @@ struct GPUTrainingParam {
  // default=0 means no constraint on weight delta
  float max_delta_step;
  float learning_rate;
  uint32_t max_cat_to_onehot;

  GPUTrainingParam() = default;

@ -35,14 +37,10 @@ struct GPUTrainingParam {
        reg_lambda(param.reg_lambda),
        reg_alpha(param.reg_alpha),
        max_delta_step(param.max_delta_step),
        learning_rate{param.learning_rate} {}
        learning_rate{param.learning_rate},
        max_cat_to_onehot{param.max_cat_to_onehot} {}
};

using NodeIdT = int32_t;

/** used to assign default id to a Node */
static const bst_node_t kUnusedNode = -1;

/**
 * @enum DefaultDirection node.cuh
 * @brief Default direction to be followed in case of missing values
@ -59,6 +57,8 @@ struct DeviceSplitCandidate {
  DefaultDirection dir {kLeftDir};
  int findex {-1};
  float fvalue {0};

  common::CatBitField split_cats;
  bool is_cat { false };

  GradientPairPrecise left_sum;
@ -75,6 +75,28 @@ struct DeviceSplitCandidate {
      *this = other;
    }
  }
  /**
   * \brief The largest encoded category in the split bitset.
   */
  bst_cat_t MaxCat() const {
    // Reuse the fvalue for categorical values.
    return static_cast<bst_cat_t>(fvalue);
  }
  /**
   * \brief Return the best threshold for cat split, reset the value after return.
   */
  XGBOOST_DEVICE size_t PopBestThresh() {
    // fvalue is also being used for storing the threshold for categorical split.
    auto best_thresh = static_cast<size_t>(this->fvalue);
    this->fvalue = 0;
    return best_thresh;
  }

  template <typename T>
  XGBOOST_DEVICE void SetCat(T c) {
    this->split_cats.Set(common::AsCat(c));
    fvalue = std::max(this->fvalue, static_cast<float>(c));
  }

  XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in,
                             float fvalue_in, int findex_in,
@ -108,18 +130,6 @@ struct DeviceSplitCandidate {
  }
};

struct DeviceSplitCandidateReduceOp {
  GPUTrainingParam param;
  explicit DeviceSplitCandidateReduceOp(GPUTrainingParam param) : param(std::move(param)) {}
  XGBOOST_DEVICE DeviceSplitCandidate operator()(
      const DeviceSplitCandidate& a, const DeviceSplitCandidate& b) const {
    DeviceSplitCandidate best;
    best.Update(a, param);
    best.Update(b, param);
    return best;
  }
};

template <typename T>
struct SumCallbackOp {
  // Running prefix

@ -159,6 +159,10 @@ class DeviceHistogram {
// Manage memory for a single GPU
template <typename GradientSumT>
struct GPUHistMakerDevice {
 private:
  GPUHistEvaluator<GradientSumT> evaluator_;

 public:
  int device_id;
  EllpackPageImpl const* page;
  common::Span<FeatureType const> feature_types;
@ -182,7 +186,6 @@ struct GPUHistMakerDevice {
  dh::PinnedMemory pinned;

  common::Monitor monitor;
  TreeEvaluator tree_evaluator;
  common::ColumnSampler column_sampler;
  FeatureInteractionConstraintDevice interaction_constraints;

@ -192,24 +195,20 @@ struct GPUHistMakerDevice {
  // Storing split categories for last node.
  dh::caching_device_vector<uint32_t> node_categories;

  GPUHistMakerDevice(int _device_id,
                     EllpackPageImpl const* _page,
                     common::Span<FeatureType const> _feature_types,
                     bst_uint _n_rows,
                     TrainParam _param,
                     uint32_t column_sampler_seed,
                     uint32_t n_features,
  GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page,
                     common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
                     TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
                     BatchParam _batch_param)
      : device_id(_device_id),
      : evaluator_{_param, n_features, _device_id},
        device_id(_device_id),
        page(_page),
        feature_types{_feature_types},
        param(std::move(_param)),
        tree_evaluator(param, n_features, _device_id),
        column_sampler(column_sampler_seed),
        interaction_constraints(param, n_features),
        batch_param(std::move(_batch_param)) {
    sampler.reset(new GradientBasedSampler(
        page, _n_rows, batch_param, param.subsample, param.sampling_method));
    sampler.reset(new GradientBasedSampler(page, _n_rows, batch_param, param.subsample,
                                           param.sampling_method));
    if (!param.monotone_constraints.empty()) {
      // Copy assigning an empty vector causes an exception in MSVC debug builds
      monotone_constraints = param.monotone_constraints;
@ -219,9 +218,8 @@ struct GPUHistMakerDevice {
    // Init histogram
    hist.Init(device_id, page->Cuts().TotalBins());
    monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
    feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense,
                                           dh::MaxSharedMemoryOptin(device_id),
                                           sizeof(GradientSumT)));
    feature_groups.reset(new FeatureGroups(
        page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), sizeof(GradientSumT)));
  }

  ~GPUHistMakerDevice() {  // NOLINT
@ -231,13 +229,17 @@ struct GPUHistMakerDevice {
  // Reset values for each update iteration
  // Note that the column sampler must be passed by value because it is not
  // thread safe
  void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
  void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns,
             ObjInfo task) {
    auto const& info = dmat->Info();
    this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
                              param.colsample_bynode, param.colsample_bylevel,
                              param.colsample_bytree);
    dh::safe_cuda(cudaSetDevice(device_id));
    tree_evaluator = TreeEvaluator(param, dmat->Info().num_col_, device_id);

    this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param,
                           device_id);

    this->interaction_constraints.Reset();
    std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});

@ -258,10 +260,8 @@ struct GPUHistMakerDevice {
    hist.Reset();
  }

  DeviceSplitCandidate EvaluateRootSplit(GradientPairPrecise root_sum) {
  GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight, ObjInfo task) {
    int nidx = RegTree::kRoot;
    dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
    GPUTrainingParam gpu_param(param);
    auto sampled_features = column_sampler.GetFeatureSet(0);
    sampled_features->SetDevice(device_id);
@ -277,32 +277,23 @@ struct GPUHistMakerDevice {
                                           matrix.gidx_fvalue_map,
                                           matrix.min_fvalue,
                                           hist.GetNodeHistogram(nidx)};
    auto gain_calc = tree_evaluator.GetEvaluator<GPUTrainingParam>();
    EvaluateSingleSplit(dh::ToSpan(splits_out), gain_calc, inputs);
    std::vector<DeviceSplitCandidate> result(1);
    dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
                             sizeof(DeviceSplitCandidate) * splits_out.size(),
                             cudaMemcpyDeviceToHost));
    return result.front();
    auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight, task);
    return split;
  }

  void EvaluateLeftRightSplits(
      GPUExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
  void EvaluateLeftRightSplits(GPUExpandEntry candidate, ObjInfo task, int left_nidx,
                               int right_nidx, const RegTree& tree,
                               common::Span<GPUExpandEntry> pinned_candidates_out) {
    dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
    GPUTrainingParam gpu_param(param);
    auto left_sampled_features =
        column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
    auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
    left_sampled_features->SetDevice(device_id);
    common::Span<bst_feature_t> left_feature_set =
        interaction_constraints.Query(left_sampled_features->DeviceSpan(),
                                      left_nidx);
    auto right_sampled_features =
        column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
        interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx);
    auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
    right_sampled_features->SetDevice(device_id);
    common::Span<bst_feature_t> right_feature_set =
        interaction_constraints.Query(right_sampled_features->DeviceSpan(),
                                      left_nidx);
        interaction_constraints.Query(right_sampled_features->DeviceSpan(), left_nidx);
    auto matrix = page->GetDeviceAccessor(device_id);

    EvaluateSplitInputs<GradientSumT> left{left_nidx,
@ -323,28 +314,10 @@ struct GPUHistMakerDevice {
                                           matrix.gidx_fvalue_map,
                                           matrix.min_fvalue,
                                           hist.GetNodeHistogram(right_nidx)};
    auto d_splits_out = dh::ToSpan(splits_out);
    EvaluateSplits(d_splits_out, tree_evaluator.GetEvaluator<GPUTrainingParam>(), left, right);

    dh::TemporaryArray<GPUExpandEntry> entries(2);
    auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
    auto d_entries = entries.data().get();
    dh::LaunchN(2, [=] __device__(size_t idx) {
      auto split = d_splits_out[idx];
      auto nidx = idx == 0 ? left_nidx : right_nidx;

      float base_weight = evaluator.CalcWeight(
          nidx, gpu_param, GradStats{split.left_sum + split.right_sum});
      float left_weight =
          evaluator.CalcWeight(nidx, gpu_param, GradStats{split.left_sum});
      float right_weight = evaluator.CalcWeight(
          nidx, gpu_param, GradStats{split.right_sum});

      d_entries[idx] =
          GPUExpandEntry{nidx, candidate.depth + 1, d_splits_out[idx],
                         base_weight, left_weight, right_weight};
    });
    dh::safe_cuda(cudaMemcpyAsync(
        pinned_candidates_out.data(), entries.data().get(),
    this->evaluator_.EvaluateSplits(candidate, task, left, right, dh::ToSpan(entries));
    dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(),
                                  sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
  }
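
The pinned_candidates_out copy above relies on the usual CUDA pattern: cudaMemcpyAsync only overlaps with other work when the host buffer is page-locked, otherwise it falls back to a synchronous copy. A standalone sketch of that pattern (illustrative, unrelated to the XGBoost types):

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      float *pinned_out = nullptr, *device_buf = nullptr;
      cudaMallocHost(&pinned_out, 2 * sizeof(float));  // page-locked host buffer
      cudaMalloc(&device_buf, 2 * sizeof(float));
      cudaMemsetAsync(device_buf, 0, 2 * sizeof(float));
      cudaMemcpyAsync(pinned_out, device_buf, 2 * sizeof(float),
                      cudaMemcpyDeviceToHost);
      cudaDeviceSynchronize();  // results are valid only after a sync
      std::printf("%f %f\n", pinned_out[0], pinned_out[1]);
      cudaFreeHost(pinned_out);
      cudaFree(device_buf);
      return 0;
    }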

@ -369,12 +342,10 @@ struct GPUHistMakerDevice {
    });
  }

  bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram,
                             int nidx_subtraction) {
  bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
    // Make sure histograms are already allocated
    hist.AllocateHistogram(nidx_subtraction);
    return hist.HistogramExists(nidx_histogram) &&
           hist.HistogramExists(nidx_parent);
    return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
  }

  void UpdatePosition(int nidx, RegTree* p_tree) {
@ -503,13 +474,12 @@ struct GPUHistMakerDevice {
                             cudaMemcpyHostToDevice));
    auto d_position = row_partitioner->GetPosition();
    auto d_node_sum_gradients = device_node_sum_gradients.data().get();
    auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
    auto tree_evaluator = evaluator_.GetEvaluator();

    dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(
                                   int local_idx) mutable {
    dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(int local_idx) mutable {
      int pos = d_position[local_idx];
      bst_float weight = evaluator.CalcWeight(
          pos, param_d, GradStats{d_node_sum_gradients[pos]});
      bst_float weight =
          tree_evaluator.CalcWeight(pos, param_d, GradStats{d_node_sum_gradients[pos]});
      static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
      out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
    });
@ -562,7 +532,6 @@ struct GPUHistMakerDevice {

  void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
    RegTree& tree = *p_tree;
    auto evaluator = tree_evaluator.GetEvaluator();
    auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
    auto base_weight = candidate.base_weight;
    auto left_weight = candidate.left_weight * param.learning_rate;
@ -572,48 +541,50 @@ struct GPUHistMakerDevice {
    if (is_cat) {
      CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
          << "Categorical feature value too large.";
      std::vector<uint32_t> split_cats;
      if (candidate.split.split_cats.Bits().empty()) {
        if (common::InvalidCat(candidate.split.fvalue)) {
          common::InvalidCategory();
        }
        auto cat = common::AsCat(candidate.split.fvalue);
        std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
        LBitField32 cats_bits(split_cats);
        split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0);
        common::CatBitField cats_bits(split_cats);
        cats_bits.Set(cat);
        dh::CopyToD(split_cats, &node_categories);
        tree.ExpandCategorical(
            candidate.nid, candidate.split.findex, split_cats,
            candidate.split.dir == kLeftDir, base_weight, left_weight,
            right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
            candidate.split.left_sum.GetHess(),
            candidate.split.right_sum.GetHess());
      } else {
        tree.ExpandNode(candidate.nid, candidate.split.findex,
                        candidate.split.fvalue, candidate.split.dir == kLeftDir,
                        base_weight, left_weight, right_weight,
                        candidate.split.loss_chg, parent_sum.GetHess(),
                        candidate.split.left_sum.GetHess(),
                        candidate.split.right_sum.GetHess());
        auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
        auto max_cat = candidate.split.MaxCat();
        split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
        CHECK_LE(split_cats.size(), h_cats.size());
        std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());

        node_categories.resize(candidate.split.split_cats.Bits().size());
        dh::safe_cuda(cudaMemcpyAsync(
            node_categories.data().get(), candidate.split.split_cats.Data(),
            candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice));
      }

      // Set up child constraints
      auto left_child = tree[candidate.nid].LeftChild();
      auto right_child = tree[candidate.nid].RightChild();
      tree.ExpandCategorical(
          candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
          base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    } else {
      tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
                      candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
                      candidate.split.loss_chg, parent_sum.GetHess(),
                      candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
    }
    evaluator_.ApplyTreeSplit(candidate, p_tree);

    tree_evaluator.AddSplit(candidate.nid, left_child, right_child,
                            tree[candidate.nid].SplitIndex(), candidate.left_weight,
                            candidate.right_weight);
    node_sum_gradients[tree[candidate.nid].LeftChild()] =
        candidate.split.left_sum;
    node_sum_gradients[tree[candidate.nid].RightChild()] =
        candidate.split.right_sum;
    node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
    node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;

    interaction_constraints.Split(
        candidate.nid, tree[candidate.nid].SplitIndex(),
    interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(),
                                  tree[candidate.nid].LeftChild(),
                                  tree[candidate.nid].RightChild());
  }
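
A hedged sketch of how a node's category bitset is consumed when routing a row (the helper below is illustrative; in particular, whether a set bit sends the row left or right is a convention of the tree implementation and is not asserted here):

    #include <cstdint>
    #include <vector>

    // One bit per category; a set bit places the category in the "matching"
    // branch (which side that is stays a convention of the tree code).
    bool CategoryMatches(std::vector<std::uint32_t> const &split_cats,
                         std::uint32_t cat) {
      auto word = cat / 32u, bit = cat % 32u;
      if (word >= split_cats.size()) {
        return false;  // unseen category: follow the default direction
      }
      return (split_cats[word] >> bit) & 1u;
    }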

  GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
  GPUExpandEntry InitRoot(RegTree* p_tree, ObjInfo task, dh::AllReducer* reducer) {
    constexpr bst_node_t kRootNIdx = 0;
    dh::XGBCachingDeviceAllocator<char> alloc;
    auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
@ -634,39 +605,21 @@ struct GPUHistMakerDevice {
    (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);

    // Generate first split
    auto split = this->EvaluateRootSplit(root_sum);
    dh::TemporaryArray<GPUExpandEntry> entries(1);
    auto d_entries = entries.data().get();
    auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
    GPUTrainingParam gpu_param(param);
    auto depth = p_tree->GetDepth(kRootNIdx);
    dh::LaunchN(1, [=] __device__(size_t idx) {
      float left_weight = evaluator.CalcWeight(kRootNIdx, gpu_param,
                                               GradStats{split.left_sum});
      float right_weight = evaluator.CalcWeight(
          kRootNIdx, gpu_param, GradStats{split.right_sum});
      d_entries[0] =
          GPUExpandEntry(kRootNIdx, depth, split,
                         weight, left_weight, right_weight);
    });
    GPUExpandEntry root_entry;
    dh::safe_cuda(cudaMemcpyAsync(
        &root_entry, entries.data().get(),
        sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
    auto root_entry = this->EvaluateRootSplit(root_sum, weight, task);
    return root_entry;
  }

  void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
  void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
                  RegTree* p_tree, dh::AllReducer* reducer) {
    auto& tree = *p_tree;
    Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));

    monitor.Start("Reset");
    this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
    this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task);
    monitor.Stop("Reset");

    monitor.Start("InitRoot");
    driver.Push({ this->InitRoot(p_tree, reducer) });
    driver.Push({ this->InitRoot(p_tree, task, reducer) });
    monitor.Stop("InitRoot");

    auto num_leaves = 1;
@ -703,8 +656,7 @@ struct GPUHistMakerDevice {
      monitor.Stop("BuildHist");

      monitor.Start("EvaluateSplits");
      this->EvaluateLeftRightSplits(candidate, left_child_nidx,
                                    right_child_nidx, *p_tree,
      this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree,
                                    new_candidates.subspan(i * 2, 2));
      monitor.Stop("EvaluateSplits");
    } else {
@ -819,14 +771,13 @@ class GPUHistMakerSpecialised {
    CHECK(*local_tree == reference_tree);
  }

  void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
                  RegTree* p_tree) {
  void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree) {
    monitor_.Start("InitData");
    this->InitData(p_fmat);
    monitor_.Stop("InitData");

    gpair->SetDevice(device_);
    maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_);
    maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_);
  }

  bool UpdatePredictionCache(const DMatrix *data,

@ -1,5 +1,5 @@
/*!
 * Copyright 2019-2021 by XGBoost Contributors
 * Copyright 2019-2022 by XGBoost Contributors
 */
#pragma once
#include <gtest/gtest.h>
@ -235,6 +235,7 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins,

  ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
  auto cuts = sketch(dmat.get(), num_bins);
  ASSERT_EQ(cuts.MaxCategory(), num_categories - 1);
  std::sort(x.begin(), x.end());
  auto n_uniques = std::unique(x.begin(), x.end()) - x.begin();
  ASSERT_NE(n_uniques, x.size());

@ -1,7 +1,11 @@
/*!
 * Copyright 2020-2022 by XGBoost contributors
 */
#include <gtest/gtest.h>
#include "../../../../src/tree/gpu_hist/evaluate_splits.cuh"
#include "../../helpers.h"
#include "../../histogram_helpers.h"
#include "../test_evaluate_splits.h"  // TestPartitionBasedSplit

namespace xgboost {
namespace tree {
@ -16,7 +20,6 @@ auto ZeroParam() {
}  // anonymous namespace

void TestEvaluateSingleSplit(bool is_categorical) {
  thrust::device_vector<DeviceSplitCandidate> out_splits(1);
  GradientPairPrecise parent_sum(0.0, 1.0);
  TrainParam tparam = ZeroParam();
  GPUTrainingParam param{tparam};
@ -50,11 +53,13 @@ void TestEvaluateSingleSplit(bool is_categorical) {
                                         dh::ToSpan(feature_values),
                                         dh::ToSpan(feature_min_values),
                                         dh::ToSpan(feature_histogram)};
  TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0);
  auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
  EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input);

  DeviceSplitCandidate result = out_splits[0];
  GPUHistEvaluator<GradientPair> evaluator{
      tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
  dh::device_vector<common::CatBitField::value_type> out_cats;
  DeviceSplitCandidate result =
      evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;

  EXPECT_EQ(result.findex, 1);
  EXPECT_EQ(result.fvalue, 11.0);
  EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
@ -72,7 +77,6 @@ TEST(GpuHist, EvaluateCategoricalSplit) {
}

TEST(GpuHist, EvaluateSingleSplitMissing) {
  thrust::device_vector<DeviceSplitCandidate> out_splits(1);
  GradientPairPrecise parent_sum(1.0, 1.5);
  TrainParam tparam = ZeroParam();
  GPUTrainingParam param{tparam};
@ -96,11 +100,10 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
                                         dh::ToSpan(feature_min_values),
                                         dh::ToSpan(feature_histogram)};

  TreeEvaluator tree_evaluator(tparam, feature_set.size(), 0);
  auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
  EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input);
  GPUHistEvaluator<GradientPair> evaluator(tparam, feature_set.size(), 0);
  DeviceSplitCandidate result =
      evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;

  DeviceSplitCandidate result = out_splits[0];
  EXPECT_EQ(result.findex, 0);
  EXPECT_EQ(result.fvalue, 1.0);
  EXPECT_EQ(result.dir, kRightDir);
@ -109,27 +112,18 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
}

TEST(GpuHist, EvaluateSingleSplitEmpty) {
  DeviceSplitCandidate nonzeroed;
  nonzeroed.findex = 1;
  nonzeroed.loss_chg = 1.0;

  thrust::device_vector<DeviceSplitCandidate> out_split(1);
  out_split[0] = nonzeroed;

  TrainParam tparam = ZeroParam();
  TreeEvaluator tree_evaluator(tparam, 1, 0);
  auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
  EvaluateSingleSplit(dh::ToSpan(out_split), evaluator,
                      EvaluateSplitInputs<GradientPair>{});

  DeviceSplitCandidate result = out_split[0];
  GPUHistEvaluator<GradientPair> evaluator(tparam, 1, 0);
  DeviceSplitCandidate result = evaluator
                                    .EvaluateSingleSplit(EvaluateSplitInputs<GradientPair>{}, 0,
                                                         ObjInfo{ObjInfo::kRegression})
                                    .split;
  EXPECT_EQ(result.findex, -1);
  EXPECT_LT(result.loss_chg, 0.0f);
}

// Feature 0 has a better split, but the algorithm must select feature 1
TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
  thrust::device_vector<DeviceSplitCandidate> out_splits(1);
  GradientPairPrecise parent_sum(0.0, 1.0);
  TrainParam tparam = ZeroParam();
  tparam.UpdateAllowUnknown(Args{});
@ -157,11 +151,10 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
                                         dh::ToSpan(feature_min_values),
                                         dh::ToSpan(feature_histogram)};

  TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0);
  auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
  EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input);
  GPUHistEvaluator<GradientPair> evaluator(tparam, feature_min_values.size(), 0);
  DeviceSplitCandidate result =
      evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;

  DeviceSplitCandidate result = out_splits[0];
  EXPECT_EQ(result.findex, 1);
  EXPECT_EQ(result.fvalue, 11.0);
  EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5));
@ -170,7 +163,6 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {

// Features 0 and 1 have identical gain, the algorithm must select 0
TEST(GpuHist, EvaluateSingleSplitBreakTies) {
  thrust::device_vector<DeviceSplitCandidate> out_splits(1);
  GradientPairPrecise parent_sum(0.0, 1.0);
  TrainParam tparam = ZeroParam();
  tparam.UpdateAllowUnknown(Args{});
@ -198,11 +190,10 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
                                         dh::ToSpan(feature_min_values),
                                         dh::ToSpan(feature_histogram)};

  TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0);
  auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
  EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input);
  GPUHistEvaluator<GradientPair> evaluator(tparam, feature_min_values.size(), 0);
  DeviceSplitCandidate result =
      evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;

  DeviceSplitCandidate result = out_splits[0];
  EXPECT_EQ(result.findex, 0);
  EXPECT_EQ(result.fvalue, 1.0);
}
@ -250,9 +241,10 @@ TEST(GpuHist, EvaluateSplits) {
                                               dh::ToSpan(feature_min_values),
                                               dh::ToSpan(feature_histogram_right)};

  TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0);
  auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
  EvaluateSplits(dh::ToSpan(out_splits), evaluator, input_left, input_right);
  GPUHistEvaluator<GradientPair> evaluator{
      tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
  evaluator.EvaluateSplits(input_left, input_right, ObjInfo{ObjInfo::kRegression},
                           evaluator.GetEvaluator(), dh::ToSpan(out_splits));

  DeviceSplitCandidate result_left = out_splits[0];
  EXPECT_EQ(result_left.findex, 1);
@ -262,5 +254,36 @@ TEST(GpuHist, EvaluateSplits) {
  EXPECT_EQ(result_right.findex, 0);
  EXPECT_EQ(result_right.fvalue, 1.0);
}

TEST_F(TestPartitionBasedSplit, GpuHist) {
  dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}};
  GPUHistEvaluator<GradientPairPrecise> evaluator{param_,
                                                  static_cast<bst_feature_t>(info_.num_col_), 0};

  cuts_.cut_ptrs_.SetDevice(0);
  cuts_.cut_values_.SetDevice(0);
  cuts_.min_vals_.SetDevice(0);

  ObjInfo task{ObjInfo::kRegression};
  evaluator.Reset(cuts_, dh::ToSpan(ft), task, info_.num_col_, param_, 0);

  dh::device_vector<GradientPairPrecise> d_hist(hist_[0].size());
  auto node_hist = hist_[0];
  dh::safe_cuda(cudaMemcpy(d_hist.data().get(), node_hist.data(), node_hist.size_bytes(),
                           cudaMemcpyHostToDevice));
  dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};

  EvaluateSplitInputs<GradientPairPrecise> input{0,
                                                 total_gpair_,
                                                 GPUTrainingParam{param_},
                                                 dh::ToSpan(feature_set),
                                                 dh::ToSpan(ft),
                                                 cuts_.cut_ptrs_.ConstDeviceSpan(),
                                                 cuts_.cut_values_.ConstDeviceSpan(),
                                                 cuts_.min_vals_.ConstDeviceSpan(),
                                                 dh::ToSpan(d_hist)};
  auto split = evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
  ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
}
}  // namespace tree
}  // namespace xgboost

@ -3,9 +3,11 @@
 */
#include <gtest/gtest.h>
#include <xgboost/base.h>

#include "../../../../src/common/hist_util.h"
#include "../../../../src/tree/hist/evaluate_splits.h"
#include "../../../../src/tree/updater_quantile_hist.h"
#include "../../../../src/common/hist_util.h"
#include "../test_evaluate_splits.h"
#include "../../helpers.h"

namespace xgboost {
@ -108,80 +110,17 @@ TEST(HistEvaluator, Apply) {
  ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
}

TEST(HistEvaluator, CategoricalPartition) {
  int static constexpr kRows = 128, kCols = 1;
  using GradientSumT = double;
  std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);

  TrainParam param;
  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});

  size_t n_cats{8};

  auto dmat =
      RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();

  int32_t n_threads = 16;
TEST_F(TestPartitionBasedSplit, CPUHist) {
  // check the evaluator is returning the optimal split
  std::vector<FeatureType> ft{FeatureType::kCategorical};
  auto sampler = std::make_shared<common::ColumnSampler>();
  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};

  for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
    common::HistCollection<GradientSumT> hist;

    std::vector<CPUExpandEntry> entries(1);
    entries.front().nid = 0;
    entries.front().depth = 0;

    hist.Init(gmat.cut.TotalBins());
    hist.AddHistRow(0);
    hist.AllocateAllData();
    auto node_hist = hist[0];
    ASSERT_EQ(node_hist.size(), n_cats);
    ASSERT_EQ(node_hist.size(), gmat.cut.Ptrs().back());

    GradientPairPrecise total_gpair;
    for (size_t i = 0; i < node_hist.size(); ++i) {
      node_hist[i] = {static_cast<double>(node_hist.size() - i), 1.0};
      total_gpair += node_hist[i];
    }
    SimpleLCG lcg;
    std::shuffle(node_hist.begin(), node_hist.end(), lcg);

  HistEvaluator<double, CPUExpandEntry> evaluator{param_, info_, common::OmpGetNumThreads(0),
                                                  sampler, ObjInfo{ObjInfo::kRegression}};
  evaluator.InitRoot(GradStats{total_gpair_});
  RegTree tree;
    evaluator.InitRoot(GradStats{total_gpair});
    evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries);
    ASSERT_TRUE(entries.front().split.is_cat);

    auto run_eval = [&](auto fn) {
      for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
        GradStats left, right;
        for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) {
          auto loss_chg = evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
                          evaluator.Stats().front().root_gain;
          fn(loss_chg);
          left.Add(node_hist[j].GetGrad(), node_hist[j].GetHess());
          right.SetSubstract(GradStats{total_gpair}, left);
        }
      }
    };
    // Assert that's the best split
    auto best_loss_chg = entries.front().split.loss_chg;
    run_eval([&](auto loss_chg) {
      // Approximated test that gain returned by optimal partition is greater than
      // numerical split.
      ASSERT_GT(best_loss_chg, loss_chg);
    });
    // node_hist is captured in lambda.
    std::sort(node_hist.begin(), node_hist.end(), [&](auto l, auto r) {
      return evaluator.Evaluator().CalcWeightCat(param, l) <
             evaluator.Evaluator().CalcWeightCat(param, r);
    });

    double reimpl = 0;
    run_eval([&](auto loss_chg) { reimpl = std::max(loss_chg, reimpl); });
    CHECK_EQ(reimpl, best_loss_chg);
  }
  std::vector<CPUExpandEntry> entries(1);
  evaluator.EvaluateSplits(hist_, cuts_, {ft}, tree, &entries);
  ASSERT_NEAR(entries[0].split.loss_chg, best_score_, 1e-16);
}

namespace {

tests/cpp/tree/test_evaluate_splits.h
@ -0,0 +1,96 @@
/*!
 * Copyright 2022 by XGBoost Contributors
 */
#include <gtest/gtest.h>

#include <algorithm>  // next_permutation
#include <numeric>    // iota

#include "../../../src/tree/hist/evaluate_splits.h"
#include "../helpers.h"

namespace xgboost {
namespace tree {
/**
 * \brief Enumerate all possible partitions for categorical split.
 */
class TestPartitionBasedSplit : public ::testing::Test {
 protected:
  size_t n_bins_ = 6;
  std::vector<size_t> sorted_idx_;
  TrainParam param_;
  MetaInfo info_;
  float best_score_{-std::numeric_limits<float>::infinity()};
  common::HistogramCuts cuts_;
  common::HistCollection<double> hist_;
  GradientPairPrecise total_gpair_;

  void SetUp() override {
    param_.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
    sorted_idx_.resize(n_bins_);
    std::iota(sorted_idx_.begin(), sorted_idx_.end(), 0);

    info_.num_col_ = 1;

    cuts_.cut_ptrs_.Resize(2);
    cuts_.SetCategorical(true, n_bins_);
    auto &h_cuts = cuts_.cut_ptrs_.HostVector();
    h_cuts[0] = 0;
    h_cuts[1] = n_bins_;
    auto &h_vals = cuts_.cut_values_.HostVector();
    h_vals.resize(n_bins_);
    std::iota(h_vals.begin(), h_vals.end(), 0.0);

    hist_.Init(cuts_.TotalBins());
    hist_.AddHistRow(0);
    hist_.AllocateAllData();
    auto node_hist = hist_[0];

    SimpleLCG lcg;
    SimpleRealUniformDistribution<double> grad_dist{-4.0, 4.0};
    SimpleRealUniformDistribution<double> hess_dist{0.0, 4.0};

    for (auto &e : node_hist) {
      e = GradientPairPrecise{grad_dist(&lcg), hess_dist(&lcg)};
      total_gpair_ += e;
    }

    auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow<double> hist,
                                                     GradientPairPrecise parent_sum) {
      int32_t best_thresh = -1;
      float best_score{-std::numeric_limits<float>::infinity()};
      TreeEvaluator evaluator{param_, static_cast<bst_feature_t>(n_feat), -1};
      auto tree_evaluator = evaluator.GetEvaluator<TrainParam>();
      GradientPairPrecise left_sum;
      auto parent_gain = tree_evaluator.CalcGain(0, param_, GradStats{total_gpair_});
      for (size_t i = 0; i < hist.size() - 1; ++i) {
        left_sum += hist[i];
        auto right_sum = parent_sum - left_sum;
        auto gain =
            tree_evaluator.CalcSplitGain(param_, 0, 0, GradStats{left_sum}, GradStats{right_sum}) -
            parent_gain;
        if (gain > best_score) {
          best_score = gain;
          best_thresh = i;
        }
      }
      return std::make_tuple(best_thresh, best_score);
    };

    // enumerate all possible partitions to find the optimal split
    do {
      int32_t thresh;
      float score;
      std::vector<GradientPairPrecise> sorted_hist(node_hist.size());
      for (size_t i = 0; i < sorted_hist.size(); ++i) {
        sorted_hist[i] = node_hist[sorted_idx_[i]];
      }
      std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_);
      if (score > best_score_) {
        best_score_ = score;
      }
    } while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
  }
};
}  // namespace tree
}  // namespace xgboost
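
The fixture above finds the reference optimum by permuting bins, which is O(n!) but exercises every contiguous threshold of every ordering. A cheaper way to see the claim behind the kPart splits is the sketch below: for the unregularized second-order gain G^2/H, scanning thresholds over bins sorted by G/H recovers the same best gain as brute-force enumeration of all 2^k subsets (illustrative code with hypothetical names, not test infrastructure):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    struct GP { double g, h; };

    // Unregularized second-order score of a node: G^2 / H.
    double Score(GP s) { return s.h > 0 ? s.g * s.g / s.h : 0.0; }

    // Enumerate every proper subset of bins as the left branch.
    double BestGainBruteForce(std::vector<GP> const &hist) {
      GP total{0, 0};
      for (auto e : hist) { total.g += e.g; total.h += e.h; }
      double best = -std::numeric_limits<double>::infinity();
      for (std::size_t mask = 1; mask + 1 < (1ull << hist.size()); ++mask) {
        GP left{0, 0};
        for (std::size_t i = 0; i < hist.size(); ++i) {
          if (mask & (1ull << i)) { left.g += hist[i].g; left.h += hist[i].h; }
        }
        GP right{total.g - left.g, total.h - left.h};
        best = std::max(best, Score(left) + Score(right) - Score(total));
      }
      return best;
    }

    // Sort bins by G/H, then scan contiguous thresholds only.
    double BestGainSorted(std::vector<GP> hist) {
      std::sort(hist.begin(), hist.end(),
                [](GP a, GP b) { return a.g / a.h < b.g / b.h; });
      GP total{0, 0};
      for (auto e : hist) { total.g += e.g; total.h += e.h; }
      GP left{0, 0};
      double best = -std::numeric_limits<double>::infinity();
      for (std::size_t i = 0; i + 1 < hist.size(); ++i) {
        left.g += hist[i].g; left.h += hist[i].h;
        GP right{total.g - left.g, total.h - left.h};
        best = std::max(best, Score(left) + Score(right) - Score(total));
      }
      return best;
    }

    int main() {
      std::vector<GP> hist{{1.5, 1.0}, {-2.0, 0.5}, {0.3, 2.0}, {-0.7, 1.2}};
      // The two searches agree up to floating-point noise.
      return std::fabs(BestGainBruteForce(hist) - BestGainSorted(hist)) < 1e-9 ? 0 : 1;
    }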

@ -262,7 +262,8 @@ TEST(GpuHist, EvaluateRootSplit) {
  info.num_row_ = kNRows;
  info.num_col_ = kNCols;

  DeviceSplitCandidate res = maker.EvaluateRootSplit({6.4f, 12.8f});
  DeviceSplitCandidate res =
      maker.EvaluateRootSplit({6.4f, 12.8f}, 0, ObjInfo{ObjInfo::kRegression}).split;

  ASSERT_EQ(res.findex, 7);
  ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
@ -300,11 +301,11 @@ void TestHistogramIndexImpl() {
  const auto &maker = hist_maker.maker;
  auto grad = GenerateRandomGradients(kNRows);
  grad.SetDevice(0);
  maker->Reset(&grad, hist_maker_dmat.get(), kNCols);
  maker->Reset(&grad, hist_maker_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
  std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.HostVector());

  const auto &maker_ext = hist_maker_ext.maker;
  maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols);
  maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
  std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.HostVector());

  ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());

@ -211,6 +211,34 @@ class TestTreeMethod:
        )
        assert tm.non_increasing(by_builtin_results["Train"]["rmse"])

        by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
        parameters["max_cat_to_onehot"] = 1
        parameters["reg_lambda"] = 0
        m = xgb.DMatrix(cat, label, enable_categorical=True)
        xgb.train(
            parameters,
            m,
            num_boost_round=rounds,
            evals=[(m, "Train")],
            evals_result=by_grouping,
        )
        rmse_oh = by_builtin_results["Train"]["rmse"]
        rmse_group = by_grouping["Train"]["rmse"]
        # always better or equal to onehot when there's no regularization.
        for a, b in zip(rmse_oh, rmse_group):
            assert a >= b

        parameters["reg_lambda"] = 1.0
        by_grouping = {}
        xgb.train(
            parameters,
            m,
            num_boost_round=32,
            evals=[(m, "Train")],
            evals_result=by_grouping,
        )
        assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping

    @given(strategies.integers(10, 400), strategies.integers(3, 8),
           strategies.integers(1, 2), strategies.integers(4, 7))
    @settings(deadline=None)