Support optimal partitioning for GPU hist. (#7652)

* Implement `MaxCategory` in quantile.
* Implement partition-based split for GPU evaluation.  Currently, it's based on the existing evaluation function.
* Extract an evaluator from GPU Hist to store the needed states.
* Add some CUDA stream/event utilities.
* Update the document with references.
* Fix a bug in the approx evaluator when the number of data points is less than the number of categories.
Jiaming Yuan 2022-02-15 03:03:12 +08:00 committed by GitHub
parent 2369d55e9a
commit 0d0abe1845
26 changed files with 1088 additions and 528 deletions


@@ -61,7 +61,12 @@ def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:
    return X, y

params = {
    "tree_method": "gpu_hist",
    "use_label_encoder": False,
    "n_estimators": 32,
    "colsample_bylevel": 0.7,
}
def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
@@ -70,13 +75,13 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
        X, y, random_state=1994, test_size=0.2
    )
    # Specify `enable_categorical`.
    clf = xgb.XGBClassifier(
        **params,
        eval_metric="auc",
        enable_categorical=True,
        max_cat_to_onehot=1,  # We use optimal partitioning exclusively
    )
    clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)])
    clf.save_model(os.path.join(output_dir, "categorical.json"))

    y_score = clf.predict_proba(X_test)[:, 1]  # proba of positive samples


@@ -3,15 +3,15 @@ Getting started with categorical data
=====================================

Experimental support for categorical data. After 1.5, the XGBoost `gpu_hist` tree method has
experimental support for one-hot encoding based tree splits, and in 1.6 `approx` support
was added.

Previously, users needed to run an encoder themselves before passing the data into XGBoost,
which creates a sparse matrix and potentially increases memory usage. This demo
showcases the experimental categorical data support; more advanced features are planned.

Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with
categorical data.

.. versionadded:: 1.5.0
@@ -55,8 +55,11 @@ def main() -> None:
    # For scikit-learn interface, the input data must be pandas DataFrame or cudf
    # DataFrame with categorical features
    X, y = make_categorical(100, 10, 4, False)
    # Specify `enable_categorical` to True; also, we use one-hot encoding based split
    # here for demonstration. For details see the document of `max_cat_to_onehot`.
    reg = xgb.XGBRegressor(
        tree_method="gpu_hist", enable_categorical=True, max_cat_to_onehot=5
    )
    reg.fit(X, y, eval_set=[(X, y)])

    # Pass in already encoded data


@@ -245,8 +245,8 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method
  - Use single precision to build histograms instead of double precision.

Additional parameters for ``approx`` and ``gpu_hist`` tree method
=================================================================

* ``max_cat_to_onehot``

@@ -257,7 +257,8 @@ Additional parameters for ``approx`` tree method
  - A threshold for deciding whether XGBoost should use one-hot encoding based split for
    categorical data. When the number of categories is less than the threshold then one-hot
    encoding is chosen, otherwise the categories will be partitioned into children nodes.
    Only relevant for regression and binary classification. Also, `approx` or `gpu_hist`
    tree method is required.
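A minimal sketch of how these parameters might be passed through the native interface. The toy data and the threshold value are invented for illustration, and a CUDA device is assumed for ``gpu_hist``:

```python
import pandas as pd
import xgboost as xgb

# Toy data for illustration; the categorical column must use pandas' "category" dtype.
X = pd.DataFrame({"f0": [0.5, 1.5, 2.5, 3.5], "f1": pd.Categorical(list("abba"))})
y = [0, 1, 1, 0]

Xy = xgb.DMatrix(X, label=y, enable_categorical=True)
booster = xgb.train(
    # Features with fewer than 4 categories fall back to one-hot based splits.
    {"tree_method": "gpu_hist", "max_cat_to_onehot": 4},
    Xy,
    num_boost_round=8,
)
```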
Additional parameters for Dart Booster (``booster=dart``)
=========================================================


@@ -2,6 +2,10 @@
Categorical Data
################

.. note::

   As of XGBoost 1.6, the feature is highly experimental and has limited functionality.

Starting from version 1.5, XGBoost has experimental support for categorical data available
for public testing. At the moment, the support is implemented as one-hot encoding based
categorical tree splits. For numerical data, the split condition is defined as
@@ -107,6 +111,28 @@ For numerical data, the feature type can be ``"q"`` or ``"float"``, while for ca
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
:class:`dask.Array <dask.Array>` can also be used as categorical data.
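As a hedged sketch of marking feature types explicitly on a numpy input (the array values are made up for illustration):

```python
import numpy as np
import xgboost as xgb

# Two columns: the first is numerical, the second holds integer-encoded categories.
X = np.array([[1.2, 0.0], [3.4, 2.0], [0.5, 1.0]], dtype=np.float32)
y = np.array([1.0, 0.0, 1.0], dtype=np.float32)

# "q" marks a numerical (quantile) feature, "c" marks a categorical feature.
Xy = xgb.DMatrix(X, label=y, feature_types=["q", "c"], enable_categorical=True)
```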
********************
Optimal Partitioning
********************
.. versionadded:: 1.6
Optimal partitioning is a technique for partitioning the categorical predictors for each
node split; the proof of optimality for numerical objectives like ``RMSE`` was first
introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
regression and binary classification tasks `[2] <#references>`__. Later, LightGBM `[3]
<#references>`__ brought it to the context of gradient boosting trees, and it is now also
adopted in XGBoost as an optional feature for handling categorical splits. More
specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
partition a set of discrete values into groups based on the distances between a measure of
these values, one only needs to look at sorted partitions instead of enumerating all
possible permutations. In the context of decision trees, the discrete values are
categories, and the measure is the output leaf value. Intuitively, we want to group the
categories that output similar leaf values. During split finding, we first sort the
gradient histogram to prepare the contiguous partitions, then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be
used for each feature; see :doc:`/parameter` for details.
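To make the idea concrete, here is a toy Python sketch of the search, not the CUDA implementation added in this commit: categories are sorted by their gradient statistics, and only sorted prefixes are evaluated as candidate left partitions. The names (``best_sorted_partition``, ``lam``) and the numbers are invented for illustration.

```python
# A toy sketch (not XGBoost's CUDA implementation) of partition-based split finding.
# `grad` and `hess` are per-category gradient/hessian sums taken from the histogram,
# and `lam` is the L2 regularisation term in the usual G^2 / (H + lambda) gain.
def best_sorted_partition(grad, hess, lam=1.0):
    # Sort categories by their gradient statistics (a proxy for the leaf value).
    cats = sorted(range(len(grad)), key=lambda c: grad[c] / (hess[c] + lam))
    total_g, total_h = sum(grad), sum(hess)

    def gain(g, h):
        return g * g / (h + lam)

    best_gain, best_left = -float("inf"), None
    g_left = h_left = 0.0
    # Fisher's result: only sorted prefixes need to be considered as left partitions.
    for i, c in enumerate(cats[:-1]):
        g_left += grad[c]
        h_left += hess[c]
        change = (
            gain(g_left, h_left)
            + gain(total_g - g_left, total_h - h_left)
            - gain(total_g, total_h)
        )
        if change > best_gain:
            best_gain, best_left = change, set(cats[: i + 1])
    return best_gain, best_left


# Four categories with made-up accumulated gradients/hessians.
print(best_sorted_partition([4.0, -3.0, 2.5, -1.0], [2.0, 2.0, 1.5, 1.0]))
```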
*************
Miscellaneous
@@ -120,10 +146,20 @@ actual number of unique categories. During training this is validated but for p
it's treated as the same as missing value for performance reasons. Lastly, missing values
are treated as the same as numerical features (using the learned split direction).

**********
References
**********

[1] Walter D. Fisher. "`On Grouping for Maximum Homogeneity`_." Journal of the American Statistical Association. Vol. 53, No. 284 (Dec., 1958), pp. 789-798.

[2] Trevor Hastie, Robert Tibshirani, Jerome Friedman. "`The Elements of Statistical Learning`_". Springer Series in Statistics, Springer New York Inc. (2001).

[3] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, Tie-Yan Liu. "`LightGBM\: A Highly Efficient Gradient Boosting Decision Tree`_." Advances in Neural Information Processing Systems 30 (NIPS 2017), pp. 3149-3157.

.. _On Grouping for Maximum Homogeneity: https://www.tandfonline.com/doi/abs/10.1080/01621459.1958.10501479
.. _The Elements of Statistical Learning: https://link.springer.com/book/10.1007/978-0-387-84858-7
.. _LightGBM\: A Highly Efficient Gradient Boosting Decision Tree: https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf


@@ -1,5 +1,5 @@
/*!
 * Copyright 2021-2022 by XGBoost Contributors
 */
#ifndef XGBOOST_TASK_H_
#define XGBOOST_TASK_H_
@@ -34,6 +34,10 @@ struct ObjInfo {
  explicit ObjInfo(Task t) : task{t} {}
  ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {}

  constexpr bool UseOneHot() const {
    return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
  }
};
}  // namespace xgboost
#endif  // XGBOOST_TASK_H_


@@ -581,10 +581,10 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
            .. versionadded:: 1.3.0

            Experimental support of specializing for categorical features.  Do not set
            to True unless you are interested in development. Currently it's only
            available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON
            serialization format is required. (XGBoost 1.6 for approx)

        """
        if group is not None and qid is not None:
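As an illustrative sketch of the flag described above (toy data; a CUDA device is assumed for ``gpu_hist``), with the model saved to JSON so the categorical splits survive serialization:

```python
import pandas as pd
import xgboost as xgb

# Hypothetical frame; the categorical column uses pandas' "category" dtype.
df = pd.DataFrame({"c": pd.Categorical(["low", "high", "low", "mid"])})
y = [0.0, 1.0, 0.0, 0.5]

Xy = xgb.DMatrix(df, label=y, enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=4)
# JSON (or UBJSON) keeps the categorical split information intact.
booster.save_model("categorical_model.json")
```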


@@ -207,7 +207,9 @@ __model_doc = f'''
        .. versionadded:: 1.5.0

        Experimental support for categorical data. Do not set to true unless you are
        interested in development. Only valid when `gpu_hist` or `approx` is used along
        with dataframe as input. Also, JSON/UBJSON serialization format is
        required. (XGBoost 1.6 for approx)

    max_cat_to_onehot : Optional[int]

@@ -216,10 +218,11 @@ __model_doc = f'''
        .. note:: This parameter is experimental

        A threshold for deciding whether XGBoost should use one-hot encoding based split
        for categorical data. When the number of categories is less than the threshold
        then one-hot encoding is chosen, otherwise the categories will be partitioned
        into children nodes. Only relevant for regression and binary
        classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See
        :doc:`Categorical Data </tutorials/categorical>` for details.

    eval_metric : Optional[Union[str, List[str], Callable]]
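A hedged sketch of the same parameters through the scikit-learn interface (the frame is invented for illustration; requires a CUDA device):

```python
import pandas as pd
import xgboost as xgb

X = pd.DataFrame(
    {"f0": [1.0, 2.0, 3.0, 4.0], "f1": pd.Categorical(["a", "b", "b", "c"])}
)
y = [0, 1, 1, 0]

clf = xgb.XGBClassifier(
    tree_method="gpu_hist",
    enable_categorical=True,
    max_cat_to_onehot=1,  # always use partition-based splits for categorical features
)
clf.fit(X, y, eval_set=[(X, y)])
```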


@@ -16,6 +16,10 @@
namespace xgboost {
namespace common {

using CatBitField = LBitField32;
using KCatBitField = CLBitField32;

// Cast the categorical type.
template <typename T>
XGBOOST_DEVICE bst_cat_t AsCat(T const& v) {
@@ -57,6 +61,11 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, float cat
  if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) {
    return dft_left;
  }

  auto pos = KCatBitField::ToBitPos(cat);
  if (pos.int_pos >= cats.size()) {
    return true;
  }
  return !s_cats.Check(AsCat(cat));
}
@@ -73,18 +82,14 @@ inline void InvalidCategory() {
/*!
 * \brief Whether we should use onehot encoding for categorical data.
 */
XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
  bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot();
  return use_one_hot;
}

struct IsCatOp {
  XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
};

}  // namespace common
}  // namespace xgboost


@@ -952,22 +952,22 @@ thrust::device_ptr<T const> tcend(xgboost::HostDeviceVector<T> const& vector) {
}

template <typename T>
XGBOOST_DEVICE thrust::device_ptr<T> tbegin(xgboost::common::Span<T>& span) {  // NOLINT
  return thrust::device_ptr<T>(span.data());
}

template <typename T>
XGBOOST_DEVICE thrust::device_ptr<T> tbegin(xgboost::common::Span<T> const& span) {  // NOLINT
  return thrust::device_ptr<T>(span.data());
}

template <typename T>
XGBOOST_DEVICE thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) {  // NOLINT
  return tbegin(span) + span.size();
}

template <typename T>
XGBOOST_DEVICE thrust::device_ptr<T> tend(xgboost::common::Span<T> const& span) {  // NOLINT
  return tbegin(span) + span.size();
}
@@ -982,12 +982,12 @@ XGBOOST_DEVICE auto trend(xgboost::common::Span<T> &span) {  // NOLINT
}

template <typename T>
XGBOOST_DEVICE thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) {  // NOLINT
  return thrust::device_ptr<T const>(span.data());
}

template <typename T>
XGBOOST_DEVICE thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) {  // NOLINT
  return tcbegin(span) + span.size();
}
@@ -1536,4 +1536,69 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
  safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
                            sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}
class CUDAStreamView;
class CUDAEvent {
cudaEvent_t event_{nullptr};
public:
CUDAEvent() { dh::safe_cuda(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); }
~CUDAEvent() {
if (event_) {
dh::safe_cuda(cudaEventDestroy(event_));
}
}
CUDAEvent(CUDAEvent const &that) = delete;
CUDAEvent &operator=(CUDAEvent const &that) = delete;
inline void Record(CUDAStreamView stream); // NOLINT
operator cudaEvent_t() const { return event_; } // NOLINT
};
class CUDAStreamView {
cudaStream_t stream_{nullptr};
public:
explicit CUDAStreamView(cudaStream_t s) : stream_{s} {}
void Wait(CUDAEvent const &e) {
#if defined(__CUDACC_VER_MAJOR__)
#if __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0
// CUDA == 11.0
dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, 0));
#else
// CUDA > 11.0
dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault));
#endif // __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0:
#else // clang
dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault));
#endif // defined(__CUDACC_VER_MAJOR__)
}
operator cudaStream_t() const { // NOLINT
return stream_;
}
void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); }
};
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
}
inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }
class CUDAStream {
cudaStream_t stream_;
public:
CUDAStream() {
dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
}
~CUDAStream() {
dh::safe_cuda(cudaStreamDestroy(stream_));
}
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
};
}  // namespace dh


@ -33,66 +33,84 @@ namespace common {
*/ */
using GHistIndexRow = Span<uint32_t const>; using GHistIndexRow = Span<uint32_t const>;
// A CSC matrix representing histogram cuts, used in CPU quantile hist. // A CSC matrix representing histogram cuts.
// The cut values represent upper bounds of bins containing approximately equal numbers of elements // The cut values represent upper bounds of bins containing approximately equal numbers of elements
class HistogramCuts { class HistogramCuts {
bool has_categorical_{false};
float max_cat_{-1.0f};
protected: protected:
using BinIdx = uint32_t; using BinIdx = uint32_t;
void Swap(HistogramCuts&& that) noexcept(true) {
std::swap(cut_values_, that.cut_values_);
std::swap(cut_ptrs_, that.cut_ptrs_);
std::swap(min_vals_, that.min_vals_);
std::swap(has_categorical_, that.has_categorical_);
std::swap(max_cat_, that.max_cat_);
}
void Copy(HistogramCuts const& that) {
cut_values_.Resize(that.cut_values_.Size());
cut_ptrs_.Resize(that.cut_ptrs_.Size());
min_vals_.Resize(that.min_vals_.Size());
cut_values_.Copy(that.cut_values_);
cut_ptrs_.Copy(that.cut_ptrs_);
min_vals_.Copy(that.min_vals_);
has_categorical_ = that.has_categorical_;
max_cat_ = that.max_cat_;
}
public: public:
HostDeviceVector<bst_float> cut_values_; // NOLINT HostDeviceVector<float> cut_values_; // NOLINT
HostDeviceVector<uint32_t> cut_ptrs_; // NOLINT HostDeviceVector<uint32_t> cut_ptrs_; // NOLINT
// storing minimum value in a sketch set. // storing minimum value in a sketch set.
HostDeviceVector<float> min_vals_; // NOLINT HostDeviceVector<float> min_vals_; // NOLINT
HistogramCuts(); HistogramCuts();
HistogramCuts(HistogramCuts const& that) { HistogramCuts(HistogramCuts const& that) { this->Copy(that); }
cut_values_.Resize(that.cut_values_.Size());
cut_ptrs_.Resize(that.cut_ptrs_.Size());
min_vals_.Resize(that.min_vals_.Size());
cut_values_.Copy(that.cut_values_);
cut_ptrs_.Copy(that.cut_ptrs_);
min_vals_.Copy(that.min_vals_);
}
HistogramCuts(HistogramCuts&& that) noexcept(true) { HistogramCuts(HistogramCuts&& that) noexcept(true) {
*this = std::forward<HistogramCuts&&>(that); this->Swap(std::forward<HistogramCuts>(that));
} }
HistogramCuts& operator=(HistogramCuts const& that) { HistogramCuts& operator=(HistogramCuts const& that) {
cut_values_.Resize(that.cut_values_.Size()); this->Copy(that);
cut_ptrs_.Resize(that.cut_ptrs_.Size());
min_vals_.Resize(that.min_vals_.Size());
cut_values_.Copy(that.cut_values_);
cut_ptrs_.Copy(that.cut_ptrs_);
min_vals_.Copy(that.min_vals_);
return *this; return *this;
} }
HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) { HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) {
cut_ptrs_ = std::move(that.cut_ptrs_); this->Swap(std::forward<HistogramCuts>(that));
cut_values_ = std::move(that.cut_values_);
min_vals_ = std::move(that.min_vals_);
return *this; return *this;
} }
uint32_t FeatureBins(uint32_t feature) const { uint32_t FeatureBins(bst_feature_t feature) const {
return cut_ptrs_.ConstHostVector().at(feature + 1) - return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature];
cut_ptrs_.ConstHostVector()[feature];
} }
// Getters. Cuts should be of no use after building histogram indices, but currently
// they are deeply linked with quantile_hist, gpu sketcher and gpu_hist, so we preserve
// these for now.
std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_.ConstHostVector(); } std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_.ConstHostVector(); }
std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); } std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); }
std::vector<float> const& MinValues() const { return min_vals_.ConstHostVector(); } std::vector<float> const& MinValues() const { return min_vals_.ConstHostVector(); }
bool HasCategorical() const { return has_categorical_; }
float MaxCategory() const { return max_cat_; }
/**
* \brief Set meta info about categorical features.
*
* \param has_cat Do we have categorical feature in the data?
* \param max_cat The maximum categorical value in all features.
*/
void SetCategorical(bool has_cat, float max_cat) {
has_categorical_ = has_cat;
max_cat_ = max_cat;
}
size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); } size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); }
// Return the index of a cut point that is strictly greater than the input // Return the index of a cut point that is strictly greater than the input
// value, or the last available index if none exists // value, or the last available index if none exists
BinIdx SearchBin(float value, uint32_t column_id, std::vector<uint32_t> const& ptrs, BinIdx SearchBin(float value, bst_feature_t column_id, std::vector<uint32_t> const& ptrs,
std::vector<float> const& values) const { std::vector<float> const& values) const {
auto end = ptrs[column_id + 1]; auto end = ptrs[column_id + 1];
auto beg = ptrs[column_id]; auto beg = ptrs[column_id];
@ -102,7 +120,7 @@ class HistogramCuts {
return idx; return idx;
} }
BinIdx SearchBin(float value, uint32_t column_id) const { BinIdx SearchBin(float value, bst_feature_t column_id) const {
return this->SearchBin(value, column_id, Ptrs(), Values()); return this->SearchBin(value, column_id, Ptrs(), Values());
} }


@ -272,7 +272,7 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
// move all categories into a flatten vector to prepare for allreduce // move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back(); size_t total = feature_ptr.back();
std::vector<bst_cat_t> flatten(total, 0); std::vector<float> flatten(total, 0);
auto cursor{flatten.begin()}; auto cursor{flatten.begin()};
for (auto const &feat : categories) { for (auto const &feat : categories) {
cursor = std::copy(feat.cbegin(), feat.cend(), cursor); cursor = std::copy(feat.cbegin(), feat.cend(), cursor);
@ -287,15 +287,15 @@ void AllreduceCategories(Span<FeatureType const> feature_types, int32_t n_thread
auto gtotal = global_worker_ptr.back(); auto gtotal = global_worker_ptr.back();
// categories in all workers with all features. // categories in all workers with all features.
std::vector<bst_cat_t> global_categories(gtotal, 0); std::vector<float> global_categories(gtotal, 0);
auto rank_begin = global_worker_ptr[rank]; auto rank_begin = global_worker_ptr[rank];
auto rank_size = global_worker_ptr[rank + 1] - rank_begin; auto rank_size = global_worker_ptr[rank + 1] - rank_begin;
CHECK_EQ(rank_size, total); CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin); std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers. // gather values from all workers.
rabit::Allreduce<rabit::op::Sum>(global_categories.data(), global_categories.size()); rabit::Allreduce<rabit::op::Sum>(global_categories.data(), global_categories.size());
QuantileAllreduce<bst_cat_t> allreduce_result{global_categories, global_worker_ptr, QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
global_feat_ptrs, categories.size()}; categories.size()};
ParallelFor(categories.size(), n_threads, [&](auto fidx) { ParallelFor(categories.size(), n_threads, [&](auto fidx) {
if (!IsCat(feature_types, fidx)) { if (!IsCat(feature_types, fidx)) {
return; return;
@ -531,6 +531,22 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
InvalidCategory(); InvalidCategory();
} }
} }
auto const &ptrs = cuts->Ptrs();
auto const &vals = cuts->Values();
float max_cat{-std::numeric_limits<float>::infinity()};
for (size_t i = 1; i < ptrs.size(); ++i) {
if (IsCat(feature_types_, i - 1)) {
auto beg = ptrs[i - 1];
auto end = ptrs[i];
auto feat = Span<float const>{vals}.subspan(beg, end - beg);
auto max_elem = *std::max_element(feat.cbegin(), feat.cend());
if (max_elem > max_cat) {
max_cat = max_elem;
}
}
}
cuts->SetCategorical(true, max_cat);
} }
monitor_.Stop(__func__); monitor_.Stop(__func__);


@ -1,22 +1,23 @@
/*! /*!
* Copyright 2020 by XGBoost Contributors * Copyright 2020 by XGBoost Contributors
*/ */
#include <thrust/unique.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/binary_search.h> #include <thrust/binary_search.h>
#include <thrust/transform_scan.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/transform_scan.h>
#include <thrust/unique.h>
#include <limits> // std::numeric_limits
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "xgboost/span.h"
#include "quantile.h"
#include "quantile.cuh"
#include "hist_util.h"
#include "device_helpers.cuh"
#include "categorical.h" #include "categorical.h"
#include "common.h" #include "common.h"
#include "device_helpers.cuh"
#include "hist_util.h"
#include "quantile.cuh"
#include "quantile.h"
#include "xgboost/span.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
@ -586,7 +587,7 @@ struct InvalidCatOp {
Span<uint32_t const> ptrs; Span<uint32_t const> ptrs;
Span<FeatureType const> ft; Span<FeatureType const> ft;
XGBOOST_DEVICE bool operator()(size_t i) { XGBOOST_DEVICE bool operator()(size_t i) const {
auto fidx = dh::SegmentId(ptrs, i); auto fidx = dh::SegmentId(ptrs, i);
return IsCat(ft, fidx) && InvalidCat(values[i]); return IsCat(ft, fidx) && InvalidCat(values[i]);
} }
@ -683,18 +684,36 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
out_column[idx] = in_column[idx+1].value; out_column[idx] = in_column[idx+1].value;
}); });
float max_cat{-1.0f};
if (has_categorical_) { if (has_categorical_) {
dh::XGBCachingDeviceAllocator<char> alloc; auto invalid_op = InvalidCatOp{out_cut_values, d_out_columns_ptr, d_ft};
auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan(); auto it = dh::MakeTransformIterator<thrust::pair<bool, float>>(
auto it = thrust::make_counting_iterator(0ul); thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto fidx = dh::SegmentId(d_out_columns_ptr, i);
if (IsCat(d_ft, fidx)) {
auto invalid = invalid_op(i);
auto v = out_cut_values[i];
return thrust::make_pair(invalid, v);
}
return thrust::make_pair(false, std::numeric_limits<float>::min());
});
CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size()); bool invalid{false};
auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(), dh::XGBCachingDeviceAllocator<char> alloc;
InvalidCatOp{out_cut_values, ptrs, d_ft}); thrust::tie(invalid, max_cat) =
thrust::reduce(thrust::cuda::par(alloc), it, it + out_cut_values.size(),
thrust::make_pair(false, std::numeric_limits<float>::min()),
[=] XGBOOST_DEVICE(thrust::pair<bool, bst_cat_t> const &l,
thrust::pair<bool, bst_cat_t> const &r) {
return thrust::make_pair(l.first || r.first, std::max(l.second, r.second));
});
if (invalid) { if (invalid) {
InvalidCategory(); InvalidCategory();
} }
} }
p_cuts->SetCategorical(this->has_categorical_, max_cat);
timer_.Stop(__func__); timer_.Stop(__func__);
} }
} // namespace common } // namespace common


@ -1,9 +1,14 @@
/*! /*!
* Copyright 2020-2021 by XGBoost Contributors * Copyright 2020-2022 by XGBoost Contributors
*/ */
#include <algorithm> // std::max
#include <limits> #include <limits>
#include "evaluate_splits.cuh"
#include "../../common/categorical.h" #include "../../common/categorical.h"
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh"
#include "expand_entry.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -23,7 +28,7 @@ XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
float missing_right_gain = evaluator.CalcSplitGain( float missing_right_gain = evaluator.CalcSplitGain(
param, nidx, fidx, GradStats(scan), GradStats(parent_sum - scan)); param, nidx, fidx, GradStats(scan), GradStats(parent_sum - scan));
if (missing_left_gain >= missing_right_gain) { if (missing_left_gain > missing_right_gain) {
missing_left_out = true; missing_left_out = true;
return missing_left_gain - parent_gain; return missing_left_gain - parent_gain;
} else { } else {
@ -69,108 +74,61 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
return shared_sum; return shared_sum;
} }
template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
SumCallbackOp<GradientSumT> *,
GradientPairPrecise const &missing,
EvaluateSplitInputs<GradientSumT> const &inputs,
TempStorageT *) {
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
auto rest = inputs.parent_sum - GradientPairPrecise(bin) - missing;
return GradientSumT{rest};
}
};
template <typename GradientSumT>
struct UpdateOneHot {
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
bst_feature_t fidx, GradientPairPrecise const &missing,
GradientSumT const &bin,
EvaluateSplitInputs<GradientSumT> const &inputs,
DeviceSplitCandidate *best_split) {
int split_gidx = (scan_begin + threadIdx.x);
float fvalue = inputs.feature_values[split_gidx];
GradientPairPrecise left =
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
GradientPairPrecise right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, true,
inputs.param);
}
};
template <typename GradientSumT, typename TempStorageT, typename ScanT>
struct NumericBin {
GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
SumCallbackOp<GradientSumT> *prefix_callback,
GradientPairPrecise const &missing,
EvaluateSplitInputs<GradientSumT> inputs,
TempStorageT *temp_storage) {
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback);
return bin;
}
};
template <typename GradientSumT>
struct UpdateNumeric {
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
bst_feature_t fidx, GradientPairPrecise const &missing,
GradientSumT const &bin,
EvaluateSplitInputs<GradientSumT> const &inputs,
DeviceSplitCandidate *best_split) {
// Use pointer from cut to indicate begin and end of bins for each feature.
uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = inputs.min_fvalue[fidx];
} else {
fvalue = inputs.feature_values[split_gidx];
}
GradientPairPrecise left =
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
GradientPairPrecise right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, false,
inputs.param);
}
};
/*! \brief Find the thread with best gain. */ /*! \brief Find the thread with best gain. */
template <int BLOCK_THREADS, typename ReduceT, typename ScanT, template <int BLOCK_THREADS, typename ReduceT, typename ScanT, typename MaxReduceT,
typename MaxReduceT, typename TempStorageT, typename GradientSumT, typename TempStorageT, typename GradientSumT, SplitType type>
typename BinFn, typename UpdateFn>
__device__ void EvaluateFeature( __device__ void EvaluateFeature(
int fidx, EvaluateSplitInputs<GradientSumT> inputs, int fidx, EvaluateSplitInputs<GradientSumT> inputs,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
DeviceSplitCandidate* best_split, // shared memory storing best split common::Span<bst_feature_t> sorted_idx, size_t offset,
TempStorageT* temp_storage // temp memory for cub operations DeviceSplitCandidate *best_split, // shared memory storing best split
TempStorageT *temp_storage // temp memory for cub operations
) { ) {
// Use pointer from cut to indicate begin and end of bins for each feature. // Use pointer from cut to indicate begin and end of bins for each feature.
uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin
uint32_t gidx_end = uint32_t gidx_end =
inputs.feature_segments[fidx + 1]; // end bin for i^th feature inputs.feature_segments[fidx + 1]; // end bin for i^th feature
auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin); auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
auto bin_fn = BinFn();
auto update_fn = UpdateFn();
// Sum histogram bins for current feature // Sum histogram bins for current feature
GradientSumT const feature_sum = GradientSumT const feature_sum =
ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>( ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(feature_hist, temp_storage);
feature_hist, temp_storage);
GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum}; GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum};
float const null_gain = -std::numeric_limits<bst_float>::infinity(); float const null_gain = -std::numeric_limits<bst_float>::infinity();
SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>(); SumCallbackOp<GradientSumT> prefix_op = SumCallbackOp<GradientSumT>();
for (int scan_begin = gidx_begin; scan_begin < gidx_end; for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) {
scan_begin += BLOCK_THREADS) {
bool thread_active = (scan_begin + threadIdx.x) < gidx_end; bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage);
auto calc_bin_value = [&]() {
GradientSumT bin;
switch (type) {
case kOneHot: {
auto rest =
thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing}; // NOLINT
break;
}
case kNum: {
bin =
thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT();
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
break;
}
case kPart: {
auto rest = thread_active
? inputs.gradient_histogram[sorted_idx[scan_begin + threadIdx.x] - offset]
: GradientSumT();
// No min value for cat feature, use inclusive scan.
ScanT(temp_storage->scan).InclusiveScan(rest, rest, cub::Sum(), prefix_op);
bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing}; // NOLINT
break;
}
}
return bin;
};
auto bin = calc_bin_value();
// Whether the gradient of missing values is put to the left side. // Whether the gradient of missing values is put to the left side.
bool missing_left = true; bool missing_left = true;
float gain = null_gain; float gain = null_gain;
@ -193,10 +151,48 @@ __device__ void EvaluateFeature(
cub::CTA_SYNC(); cub::CTA_SYNC();
// Best thread updates split // Best thread updates the split
if (threadIdx.x == block_max.key) { if (threadIdx.x == block_max.key) {
update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs, switch (type) {
best_split); case kNum: {
// Use pointer from cut to indicate begin and end of bins for each feature.
uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = inputs.min_fvalue[fidx];
} else {
fvalue = inputs.feature_values[split_gidx];
}
GradientPairPrecise left =
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
GradientPairPrecise right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
false, inputs.param);
break;
}
case kOneHot: {
int32_t split_gidx = (scan_begin + threadIdx.x);
float fvalue = inputs.feature_values[split_gidx];
GradientPairPrecise left =
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
GradientPairPrecise right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
true, inputs.param);
break;
}
case kPart: {
int32_t split_gidx = (scan_begin + threadIdx.x);
float fvalue = inputs.feature_values[split_gidx];
GradientPairPrecise left =
missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin};
GradientPairPrecise right = inputs.parent_sum - left;
auto best_thresh = block_max.key; // index of best threshold inside a feature.
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
right, true, inputs.param);
break;
}
}
} }
cub::CTA_SYNC(); cub::CTA_SYNC();
} }
@ -206,6 +202,8 @@ template <int BLOCK_THREADS, typename GradientSumT>
__global__ void EvaluateSplitsKernel( __global__ void EvaluateSplitsKernel(
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right, EvaluateSplitInputs<GradientSumT> right,
ObjInfo task,
common::Span<bst_feature_t> sorted_idx,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_candidates) { common::Span<DeviceSplitCandidate> out_candidates) {
// KeyValuePair here used as threadIdx.x -> gain_value // KeyValuePair here used as threadIdx.x -> gain_value
@ -240,22 +238,26 @@ __global__ void EvaluateSplitsKernel(
// One block for each feature. Features are sampled, so fidx != blockIdx.x // One block for each feature. Features are sampled, so fidx != blockIdx.x
int fidx = inputs.feature_set[is_left ? blockIdx.x int fidx = inputs.feature_set[is_left ? blockIdx.x
: blockIdx.x - left.feature_set.size()]; : blockIdx.x - left.feature_set.size()];
if (common::IsCat(inputs.feature_types, fidx)) { if (common::IsCat(inputs.feature_types, fidx)) {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx];
TempStorage, GradientSumT, if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) {
OneHotBin<GradientSumT, TempStorage>, EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
UpdateOneHot<GradientSumT>>(fidx, inputs, evaluator, &best_split, kOneHot>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
&temp_storage); } else {
auto node_sorted_idx = is_left ? sorted_idx.first(inputs.feature_values.size())
: sorted_idx.last(inputs.feature_values.size());
size_t offset = is_left ? 0 : inputs.feature_values.size();
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
kPart>(fidx, inputs, evaluator, node_sorted_idx, offset, &best_split,
&temp_storage);
}
} else { } else {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
TempStorage, GradientSumT, kNum>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
NumericBin<GradientSumT, TempStorage, BlockScanT>,
UpdateNumeric<GradientSumT>>(fidx, inputs, evaluator, &best_split,
&temp_storage);
} }
cub::CTA_SYNC(); cub::CTA_SYNC();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
// Record best loss for each feature // Record best loss for each feature
out_candidates[blockIdx.x] = best_split; out_candidates[blockIdx.x] = best_split;
@ -267,71 +269,175 @@ __device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a,
return b.loss_chg > a.loss_chg ? b : a; return b.loss_chg > a.loss_chg ? b : a;
} }
/**
* \brief Set the bits for categorical splits based on the split threshold.
*/
template <typename GradientSumT> template <typename GradientSumT>
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits, __device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, common::Span<bst_feature_t const> d_sorted_idx, bst_feature_t fidx,
EvaluateSplitInputs<GradientSumT> left, bool is_left, common::Span<common::CatBitField::value_type> out,
EvaluateSplitInputs<GradientSumT> right) { DeviceSplitCandidate *p_out_split) {
size_t combined_num_features = auto &out_split = *p_out_split;
left.feature_set.size() + right.feature_set.size(); out_split.split_cats = common::CatBitField{out};
dh::TemporaryArray<DeviceSplitCandidate> feature_best_splits( auto node_sorted_idx =
combined_num_features); is_left ? d_sorted_idx.subspan(0, input.feature_values.size())
: d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size());
size_t node_offset = is_left ? 0 : input.feature_values.size();
auto best_thresh = out_split.PopBestThresh();
auto f_sorted_idx =
node_sorted_idx.subspan(input.feature_segments[fidx], input.FeatureBins(fidx));
if (out_split.dir != kLeftDir) {
// forward, missing on right
auto beg = dh::tcbegin(f_sorted_idx);
// Don't put all the categories into one side
auto boundary = std::min(static_cast<size_t>((best_thresh + 1)), (f_sorted_idx.size() - 1));
boundary = std::max(boundary, static_cast<size_t>(1ul));
auto end = beg + boundary;
thrust::for_each(thrust::seq, beg, end, [&](auto c) {
auto cat = input.feature_values[c - node_offset];
assert(!out_split.split_cats.Check(cat) && "already set");
out_split.SetCat(cat);
});
} else {
assert((f_sorted_idx.size() - best_thresh + 1) != 0 && " == 0");
thrust::for_each(thrust::seq, dh::tcrbegin(f_sorted_idx),
dh::tcrbegin(f_sorted_idx) + (f_sorted_idx.size() - best_thresh), [&](auto c) {
auto cat = input.feature_values[c - node_offset];
out_split.SetCat(cat);
});
}
}
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits) {
if (!split_cats_.empty()) {
this->SortHistogram(left, right, evaluator);
}
size_t combined_num_features = left.feature_set.size() + right.feature_set.size();
dh::TemporaryArray<DeviceSplitCandidate> feature_best_splits(combined_num_features);
// One block for each feature // One block for each feature
uint32_t constexpr kBlockThreads = 256; uint32_t constexpr kBlockThreads = 256;
dh::LaunchKernel {uint32_t(combined_num_features), kBlockThreads, 0}( dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, evaluator, EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, task, this->SortedIdx(left),
dh::ToSpan(feature_best_splits)); evaluator, dh::ToSpan(feature_best_splits));
// Reduce to get best candidate for left and right child over all features // Reduce to get best candidate for left and right child over all features
auto reduce_offset = auto reduce_offset = dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0llu),
dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) -> size_t {
[=] __device__(size_t idx) -> size_t { if (idx == 0) {
if (idx == 0) { return 0;
return 0; }
} if (idx == 1) {
if (idx == 1) { return left.feature_set.size();
return left.feature_set.size(); }
} if (idx == 2) {
if (idx == 2) { return combined_num_features;
return combined_num_features; }
} return 0;
return 0; });
});
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
auto num_segments = out_splits.size(); auto num_segments = out_splits.size();
cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes, cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes, feature_best_splits.data(),
feature_best_splits.data(), out_splits.data(), out_splits.data(), num_segments, reduce_offset,
num_segments, reduce_offset, reduce_offset + 1); reduce_offset + 1);
dh::TemporaryArray<int8_t> temp(temp_storage_bytes); dh::TemporaryArray<int8_t> temp(temp_storage_bytes);
cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes, cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes, feature_best_splits.data(),
feature_best_splits.data(), out_splits.data(), out_splits.data(), num_segments, reduce_offset,
num_segments, reduce_offset, reduce_offset + 1); reduce_offset + 1);
} }
template <typename GradientSumT> template <typename GradientSumT>
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split, void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT> const &input,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, common::Span<CatST> cats_out) {
EvaluateSplitInputs<GradientSumT> input) { if (has_sort_) {
EvaluateSplits(out_split, evaluator, input, {}); dh::CUDAEvent event;
event.Record(dh::DefaultStream());
auto h_cats = this->HostCatStorage(input.nidx);
copy_stream_.View().Wait(event);
dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(),
cudaMemcpyDeviceToHost, copy_stream_.View()));
}
} }
template void EvaluateSplits<GradientPair>( template <typename GradientSumT>
common::Span<DeviceSplitCandidate> out_splits, void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientPair> left, EvaluateSplitInputs<GradientSumT> right,
EvaluateSplitInputs<GradientPair> right); common::Span<GPUExpandEntry> out_entries) {
template void EvaluateSplits<GradientPairPrecise>( auto evaluator = this->tree_evaluator_.template GetEvaluator<GPUTrainingParam>();
common::Span<DeviceSplitCandidate> out_splits,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(2);
EvaluateSplitInputs<GradientPairPrecise> left, auto out_splits = dh::ToSpan(splits_out_storage);
EvaluateSplitInputs<GradientPairPrecise> right); this->EvaluateSplits(left, right, task, evaluator, out_splits);
template void EvaluateSingleSplit<GradientPair>(
common::Span<DeviceSplitCandidate> out_split, auto d_sorted_idx = this->SortedIdx(left);
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, auto d_entries = out_entries;
EvaluateSplitInputs<GradientPair> input); // turn candidate into entry, along with handling sort based split.
template void EvaluateSingleSplit<GradientPairPrecise>( // turn candidate into entry, along with hanlding sort based split.
common::Span<DeviceSplitCandidate> out_split, dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) {
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, auto const &input = i == 0 ? left : right;
EvaluateSplitInputs<GradientPairPrecise> input); auto &split = out_splits[i];
auto fidx = out_splits[i].findex;
if (split.is_cat &&
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
bool is_left = i == 0;
auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
}
float base_weight =
evaluator.CalcWeight(input.nidx, input.param, GradStats{split.left_sum + split.right_sum});
float left_weight = evaluator.CalcWeight(input.nidx, input.param, GradStats{split.left_sum});
float right_weight = evaluator.CalcWeight(input.nidx, input.param, GradStats{split.right_sum});
d_entries[i] = GPUExpandEntry{input.nidx, candidate.depth + 1, out_splits[i],
base_weight, left_weight, right_weight};
});
this->CopyToHost(left, cats_out);
}
template <typename GradientSumT>
GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
EvaluateSplitInputs<GradientSumT> input, float weight, ObjInfo task) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
auto out_split = dh::ToSpan(splits_out);
auto evaluator = tree_evaluator_.GetEvaluator<GPUTrainingParam>();
this->EvaluateSplits(input, {}, task, evaluator, out_split);
auto cats_out = this->DeviceCatStorage(input.nidx);
auto d_sorted_idx = this->SortedIdx(input);
dh::TemporaryArray<GPUExpandEntry> entries(1);
auto d_entries = entries.data().get();
dh::LaunchN(1, [=] __device__(size_t i) {
auto &split = out_split[i];
auto fidx = out_split[i].findex;
if (split.is_cat &&
!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
}
float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum});
float right_weight = evaluator.CalcWeight(0, input.param, GradStats{split.right_sum});
d_entries[0] = GPUExpandEntry(0, 0, split, weight, left_weight, right_weight);
});
this->CopyToHost(input, cats_out);
GPUExpandEntry root_entry;
dh::safe_cuda(cudaMemcpyAsync(&root_entry, entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
return root_entry;
}
template class GPUHistEvaluator<GradientPair>;
template class GPUHistEvaluator<GradientPairPrecise>;
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost


@ -3,15 +3,20 @@
*/ */
#ifndef EVALUATE_SPLITS_CUH_ #ifndef EVALUATE_SPLITS_CUH_
#define EVALUATE_SPLITS_CUH_ #define EVALUATE_SPLITS_CUH_
#include <thrust/system/cuda/experimental/pinned_allocator.h>
#include <xgboost/span.h> #include <xgboost/span.h>
#include "../../data/ellpack_page.cuh"
#include "../../common/categorical.h"
#include "../split_evaluator.h" #include "../split_evaluator.h"
#include "../constraints.cuh"
#include "../updater_gpu_common.cuh" #include "../updater_gpu_common.cuh"
#include "expand_entry.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace common {
class HistogramCuts;
}
namespace tree {
template <typename GradientSumT> template <typename GradientSumT>
struct EvaluateSplitInputs { struct EvaluateSplitInputs {
int nidx; int nidx;
@ -23,16 +28,131 @@ struct EvaluateSplitInputs {
common::Span<const float> feature_values; common::Span<const float> feature_values;
common::Span<const float> min_fvalue; common::Span<const float> min_fvalue;
common::Span<const GradientSumT> gradient_histogram; common::Span<const GradientSumT> gradient_histogram;
XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
__device__ auto FeatureBins(bst_feature_t fidx) const {
return feature_segments[fidx + 1] - feature_segments[fidx];
}
}; };
template <typename GradientSumT> template <typename GradientSumT>
void EvaluateSplits(common::Span<DeviceSplitCandidate> out_splits, class GPUHistEvaluator {
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, using CatST = common::CatBitField::value_type; // categorical storage type
EvaluateSplitInputs<GradientSumT> left, // use pinned memory to stage the categories, used for sort based splits.
EvaluateSplitInputs<GradientSumT> right); using Alloc = thrust::system::cuda::experimental::pinned_allocator<CatST>;
template <typename GradientSumT>
void EvaluateSingleSplit(common::Span<DeviceSplitCandidate> out_split, private:
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator, TreeEvaluator tree_evaluator_;
EvaluateSplitInputs<GradientSumT> input); // storage for categories for each node, used for sort based splits.
dh::device_vector<CatST> split_cats_;
// host storage for categories for each node, used for sort based splits.
std::vector<CatST, Alloc> h_split_cats_;
// stream for copying categories from device back to host for expanding the decision tree.
dh::CUDAStream copy_stream_;
// storage for sorted index of feature histogram, used for sort based splits.
dh::device_vector<bst_feature_t> cat_sorted_idx_;
TrainParam param_;
// whether the input data requires sort based split, which is more complicated so we try
// to avoid it if possible.
bool has_sort_{false};
// Copy the categories from device to host asynchronously.
void CopyToHost(EvaluateSplitInputs<GradientSumT> const &input, common::Span<CatST> cats_out);
/**
* \brief Get host category storage of nidx for internal calculation.
*/
auto HostCatStorage(bst_node_t nidx) {
auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
if (nidx == RegTree::kRoot) {
auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
return cats_out;
}
auto cats_out = common::Span<CatST>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits * 2);
return cats_out;
}
/**
* \brief Get device category storage of nidx for internal calculation.
*/
auto DeviceCatStorage(bst_node_t nidx) {
auto cat_bits = split_cats_.size() / param_.MaxNodes();
if (nidx == RegTree::kRoot) {
auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits);
return cats_out;
}
auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits * 2);
return cats_out;
}
/**
 * \brief Get sorted index storage based on the left node of the inputs.
*/
auto SortedIdx(EvaluateSplitInputs<GradientSumT> left) {
if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
return dh::ToSpan(cat_sorted_idx_).first(left.feature_values.size());
}
return dh::ToSpan(cat_sorted_idx_);
}
public:
GPUHistEvaluator(TrainParam const &param, bst_feature_t n_features, int32_t device)
: tree_evaluator_{param, n_features, device}, param_{param} {}
/**
* \brief Reset the evaluator, should be called before any use.
*/
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft, ObjInfo task,
bst_feature_t n_features, TrainParam const &param, int32_t device);
/**
* \brief Get host category storage for nidx. Different from the internal version, this
* returns strictly 1 node.
*/
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
copy_stream_.View().Sync();
auto cat_bits = h_split_cats_.size() / param_.MaxNodes();
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * cat_bits, cat_bits);
return cats_out;
}
/**
* \brief Add a split to the internal tree evaluator.
*/
void ApplyTreeSplit(GPUExpandEntry const &candidate, RegTree *p_tree) {
auto &tree = *p_tree;
// Set up child constraints
auto left_child = tree[candidate.nid].LeftChild();
auto right_child = tree[candidate.nid].RightChild();
tree_evaluator_.AddSplit(candidate.nid, left_child, right_child,
tree[candidate.nid].SplitIndex(), candidate.left_weight,
candidate.right_weight);
}
auto GetEvaluator() { return tree_evaluator_.GetEvaluator<GPUTrainingParam>(); }
/**
* \brief Sort the histogram based on output to obtain contiguous partitions.
*/
common::Span<bst_feature_t const> SortHistogram(
EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator);
// Implementation of split evaluation; it contains CUDA kernels, so it's public.
void EvaluateSplits(EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits);
/**
* \brief Evaluate splits for left and right nodes.
*/
void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
common::Span<GPUExpandEntry> out_splits);
/**
* \brief Evaluate splits for root node.
*/
GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight,
ObjInfo task);
};
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost
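A quick note on the pinned-memory staging declared above: `split_cats_` lives in device memory, `h_split_cats_` is pinned host memory, and `copy_stream_` lets the device-to-host copy overlap with further split evaluation until `GetHostNodeCats` synchronizes it. Below is a minimal, standalone sketch of that pattern using the raw CUDA runtime API; the buffer names and sizes are illustrative only, not the evaluator's actual state.

// Sketch (assumed names, not the evaluator's actual members): stage device
// results into pinned host memory on a dedicated stream and synchronize that
// stream only when the host actually needs the data.
#include <cuda_runtime.h>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr std::size_t kWords = 256;
  std::uint32_t *d_cats = nullptr;  // device-side category bitsets
  std::uint32_t *h_cats = nullptr;  // pinned host staging buffer
  cudaMalloc(reinterpret_cast<void **>(&d_cats), kWords * sizeof(std::uint32_t));
  cudaMallocHost(reinterpret_cast<void **>(&h_cats), kWords * sizeof(std::uint32_t));
  cudaMemset(d_cats, 0, kWords * sizeof(std::uint32_t));

  cudaStream_t copy_stream;
  cudaStreamCreate(&copy_stream);
  // ... split evaluation kernels would write d_cats here ...
  cudaMemcpyAsync(h_cats, d_cats, kWords * sizeof(std::uint32_t),
                  cudaMemcpyDeviceToHost, copy_stream);
  // Other work may proceed; block only right before expanding the tree on the host.
  cudaStreamSynchronize(copy_stream);
  std::printf("first word: %u\n", h_cats[0]);

  cudaStreamDestroy(copy_stream);
  cudaFreeHost(h_cats);
  cudaFree(d_cats);
  return 0;
}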

View File

@ -0,0 +1,100 @@
/*!
* Copyright 2022 by XGBoost Contributors
*
* \brief Some components of the GPU Hist evaluator; this file only exists to reduce nvcc
* compilation time.
*/
#include <thrust/logical.h> // thrust::any_of
#include <thrust/sort.h> // thrust::stable_sort
#include "../../common/device_helpers.cuh"
#include "../../common/hist_util.h" // common::HistogramCuts
#include "evaluate_splits.cuh"
#include "xgboost/data.h"
namespace xgboost {
namespace tree {
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
common::Span<FeatureType const> ft, ObjInfo task,
bst_feature_t n_features, TrainParam const &param,
int32_t device) {
param_ = param;
tree_evaluator_ = TreeEvaluator{param, n_features, device};
if (cuts.HasCategorical() && !task.UseOneHot()) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
auto beg = thrust::make_counting_iterator<size_t>(1ul);
auto end = thrust::make_counting_iterator<size_t>(ptrs.size());
auto to_onehot = param.max_cat_to_onehot;
// This condition avoids sort-based split function calls if the users want
// onehot-encoding-based splits.
// For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x.
has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
auto idx = i - 1;
if (common::IsCat(ft, idx)) {
auto n_bins = ptrs[i] - ptrs[idx];
bool use_sort = !common::UseOneHot(n_bins, to_onehot, task);
return use_sort;
}
return false;
});
if (has_sort_) {
auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1);
CHECK_NE(bit_storage_size, 0);
// We need to allocate for all nodes since the updater can grow the tree layer by
// layer; all nodes in the same layer must be preserved until that layer is
// finished. We could allocate one layer at a time, but the best case only reduces the
// size of the bitset by about a half, at the cost of invoking CUDA malloc many more
// times than necessary.
split_cats_.resize(param.MaxNodes() * bit_storage_size);
h_split_cats_.resize(split_cats_.size());
dh::safe_cuda(
cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time.
}
}
}
template <typename GradientSumT>
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
dh::XGBDeviceAllocator<char> alloc;
auto sorted_idx = this->SortedIdx(left);
dh::Iota(sorted_idx);
// Sort 2 nodes and all the features at the same time, disregarding column sampling.
thrust::stable_sort(
thrust::cuda::par(alloc), dh::tbegin(sorted_idx), dh::tend(sorted_idx),
[evaluator, left, right] XGBOOST_DEVICE(size_t l, size_t r) {
auto l_is_left = l < left.feature_values.size();
auto r_is_left = r < left.feature_values.size();
if (l_is_left != r_is_left) {
return l_is_left; // not the same node
}
auto const &input = l_is_left ? left : right;
l -= (l_is_left ? 0 : input.feature_values.size());
r -= (r_is_left ? 0 : input.feature_values.size());
auto lfidx = dh::SegmentId(input.feature_segments, l);
auto rfidx = dh::SegmentId(input.feature_segments, r);
if (lfidx != rfidx) {
return lfidx < rfidx; // not the same feature
}
if (common::IsCat(input.feature_types, lfidx)) {
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[l]);
auto rw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[r]);
return lw < rw;
}
return l < r;
});
return dh::ToSpan(cat_sorted_idx_);
}
template class GPUHistEvaluator<GradientPair>;
template class GPUHistEvaluator<GradientPairPrecise>;
} // namespace tree
} // namespace xgboost
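For reference, the sort-based evaluation implemented above reduces to: sort the category bins of a feature by their optimal leaf weight, then scan the sorted order exactly like a numerical split to find the best left/right partition. A self-contained sketch of that idea follows; `CalcWeight` and `CalcGain` here are simplified stand-ins for the tree evaluator (lambda regularization only, no monotonic constraints, no missing-value handling).

// A minimal sketch of the partition-based split: sort category bins by their
// optimal leaf weight, then scan the sorted order like a numerical split.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

struct GradPair { double grad, hess; };

double CalcWeight(GradPair s, double lambda) { return -s.grad / (s.hess + lambda); }
double CalcGain(GradPair s, double lambda) { return s.grad * s.grad / (s.hess + lambda); }

int main() {
  double const lambda = 1.0;
  // One gradient/hessian sum per category bin of a single feature.
  std::vector<GradPair> hist{{1.0, 2.0}, {-3.0, 1.0}, {0.5, 0.5}, {2.0, 1.0}};
  std::vector<std::size_t> order(hist.size());
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(), [&](std::size_t l, std::size_t r) {
    return CalcWeight(hist[l], lambda) < CalcWeight(hist[r], lambda);
  });
  GradPair parent{0.0, 0.0};
  for (auto const &e : hist) { parent.grad += e.grad; parent.hess += e.hess; }
  double best = -std::numeric_limits<double>::infinity();
  std::size_t best_thresh = 0;
  GradPair left{0.0, 0.0};
  for (std::size_t i = 0; i + 1 < order.size(); ++i) {  // scan the sorted bins
    left.grad += hist[order[i]].grad;
    left.hess += hist[order[i]].hess;
    GradPair right{parent.grad - left.grad, parent.hess - left.hess};
    double gain = CalcGain(left, lambda) + CalcGain(right, lambda) - CalcGain(parent, lambda);
    if (gain > best) { best = gain; best_thresh = i; }
  }
  std::cout << "best gain " << best << "; sorted bins up to index " << best_thresh
            << " go to one side\n";
  return 0;
}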

View File

@ -4,8 +4,9 @@
#ifndef EXPAND_ENTRY_CUH_ #ifndef EXPAND_ENTRY_CUH_
#define EXPAND_ENTRY_CUH_ #define EXPAND_ENTRY_CUH_
#include <xgboost/span.h> #include <xgboost/span.h>
#include "../param.h" #include "../param.h"
#include "evaluate_splits.cuh" #include "../updater_gpu_common.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {

View File

@ -53,7 +53,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
return true; return true;
} }
} }
enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 };
// Enumerate/Scan the split values of specific feature // Enumerate/Scan the split values of specific feature
// Returns the sum of gradients corresponding to the data points that contains // Returns the sum of gradients corresponding to the data points that contains
@ -137,7 +136,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) - GradStats{right_sum}) -
parent.root_gain); parent.root_gain);
split_pt = cut_val[i]; split_pt = cut_val[i]; // not used for partition-based splits
improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum, improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
left_sum, right_sum); left_sum, right_sum);
} else { } else {
@ -180,10 +179,10 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
if (d_step == 1) { if (d_step == 1) {
std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1), std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
[&cat_bits](size_t c) { cat_bits.Set(c); }); [&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
} else { } else {
std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh), std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
[&cat_bits](size_t c) { cat_bits.Set(c); }); [&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
} }
} }
p_best->Update(best); p_best->Update(best);
@ -231,6 +230,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
} }
} }
auto evaluator = tree_evaluator_.GetEvaluator(); auto evaluator = tree_evaluator_.GetEvaluator();
auto const& cut_ptrs = cut.Ptrs();
common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) { common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num(); auto tidx = omp_get_thread_num();
@ -246,26 +246,22 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
continue; continue;
} }
if (is_cat) { if (is_cat) {
auto n_bins = cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx]; auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) { if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
} else { } else {
auto const &cut_ptr = cut.Ptrs();
std::vector<size_t> sorted_idx(n_bins); std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0); std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins); auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) { std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) < auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
evaluator.CalcWeightCat(param_, feat_hist[r]); evaluator.CalcWeightCat(param_, feat_hist[r]);
static_assert(std::is_same<decltype(ret), bool>::value, "");
return ret; return ret;
}); });
auto grad_stats = EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
} }
} else { } else {
auto grad_stats = auto grad_stats =
@ -313,6 +309,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
cat_bits.Set(cat); cat_bits.Set(cat);
} else { } else {
split_cats = candidate.split.cat_bits; split_cats = candidate.split.cat_bits;
common::CatBitField cat_bits{split_cats};
} }
tree.ExpandCategorical( tree.ExpandCategorical(

View File

@ -110,6 +110,9 @@ class TreeEvaluator {
template <typename GradientSumT> template <typename GradientSumT>
XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const { XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const {
// FIXME(jiamingy): This is a temporary solution until we have categorical feature
// specific regularization parameters. During sorting we should try to avoid any
// regularization.
return ::xgboost::tree::CalcWeight(param, stats); return ::xgboost::tree::CalcWeight(param, stats);
} }
@ -180,6 +183,15 @@ class TreeEvaluator {
.Eval(&lower_bounds_, &upper_bounds_, &monotone_); .Eval(&lower_bounds_, &upper_bounds_, &monotone_);
} }
}; };
enum SplitType {
// numerical split
kNum = 0,
// onehot encoding based categorical split
kOneHot = 1,
// partition-based categorical split
kPart = 2
};
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -8,6 +8,7 @@
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
#include "../common/categorical.h"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/random.h" #include "../common/random.h"
#include "param.h" #include "param.h"
@ -27,6 +28,7 @@ struct GPUTrainingParam {
// default=0 means no constraint on weight delta // default=0 means no constraint on weight delta
float max_delta_step; float max_delta_step;
float learning_rate; float learning_rate;
uint32_t max_cat_to_onehot;
GPUTrainingParam() = default; GPUTrainingParam() = default;
@ -35,14 +37,10 @@ struct GPUTrainingParam {
reg_lambda(param.reg_lambda), reg_lambda(param.reg_lambda),
reg_alpha(param.reg_alpha), reg_alpha(param.reg_alpha),
max_delta_step(param.max_delta_step), max_delta_step(param.max_delta_step),
learning_rate{param.learning_rate} {} learning_rate{param.learning_rate},
max_cat_to_onehot{param.max_cat_to_onehot} {}
}; };
using NodeIdT = int32_t;
/** used to assign default id to a Node */
static const bst_node_t kUnusedNode = -1;
/** /**
* @enum DefaultDirection node.cuh * @enum DefaultDirection node.cuh
* @brief Default direction to be followed in case of missing values * @brief Default direction to be followed in case of missing values
@ -59,6 +57,8 @@ struct DeviceSplitCandidate {
DefaultDirection dir {kLeftDir}; DefaultDirection dir {kLeftDir};
int findex {-1}; int findex {-1};
float fvalue {0}; float fvalue {0};
common::CatBitField split_cats;
bool is_cat { false }; bool is_cat { false };
GradientPairPrecise left_sum; GradientPairPrecise left_sum;
@ -75,6 +75,28 @@ struct DeviceSplitCandidate {
*this = other; *this = other;
} }
} }
/**
* \brief The largest encoded category in the split bitset
*/
bst_cat_t MaxCat() const {
// Reuse the fvalue for categorical values.
return static_cast<bst_cat_t>(fvalue);
}
/**
* \brief Return the best threshold for a categorical split, resetting the stored value afterwards.
*/
XGBOOST_DEVICE size_t PopBestThresh() {
// fvalue is also being used for storing the threshold for categorical split
auto best_thresh = static_cast<size_t>(this->fvalue);
this->fvalue = 0;
return best_thresh;
}
template <typename T>
XGBOOST_DEVICE void SetCat(T c) {
this->split_cats.Set(common::AsCat(c));
fvalue = std::max(this->fvalue, static_cast<float>(c));
}
XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in,
float fvalue_in, int findex_in, float fvalue_in, int findex_in,
@ -108,18 +130,6 @@ struct DeviceSplitCandidate {
} }
}; };
struct DeviceSplitCandidateReduceOp {
GPUTrainingParam param;
explicit DeviceSplitCandidateReduceOp(GPUTrainingParam param) : param(std::move(param)) {}
XGBOOST_DEVICE DeviceSplitCandidate operator()(
const DeviceSplitCandidate& a, const DeviceSplitCandidate& b) const {
DeviceSplitCandidate best;
best.Update(a, param);
best.Update(b, param);
return best;
}
};
template <typename T> template <typename T>
struct SumCallbackOp { struct SumCallbackOp {
// Running prefix // Running prefix
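As noted above, `DeviceSplitCandidate` keeps the chosen categories in a `CatBitField` and reuses `fvalue` to remember the largest category (`MaxCat`) and the best sorted threshold (`PopBestThresh`). The following is a rough, standalone illustration of the one-bit-per-category bookkeeping that `ComputeStorageSize` implies; the types and names are simplified and are not the actual `CatBitField` API.

// Rough illustration of one-bit-per-category storage (not the real CatBitField):
// storage is rounded up to whole 32-bit words, one bit per category value.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct CatBits {
  std::vector<std::uint32_t> storage;
  static std::size_t ComputeStorageSize(std::size_t n_cats) { return (n_cats + 31) / 32; }
  explicit CatBits(std::size_t n_cats) : storage(ComputeStorageSize(n_cats), 0) {}
  void Set(std::size_t cat) { storage[cat / 32] |= (1u << (cat % 32)); }
  bool Check(std::size_t cat) const { return (storage[cat / 32] >> (cat % 32)) & 1u; }
};

int main() {
  CatBits bits(40);  // categories 0..39 need 2 words of storage
  bits.Set(3);
  bits.Set(37);      // e.g. the categories that go to one side of the split
  std::cout << bits.Check(3) << bits.Check(4) << bits.Check(37) << "\n";  // prints 101
  return 0;
}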

View File

@ -159,6 +159,10 @@ class DeviceHistogram {
// Manage memory for a single GPU // Manage memory for a single GPU
template <typename GradientSumT> template <typename GradientSumT>
struct GPUHistMakerDevice { struct GPUHistMakerDevice {
private:
GPUHistEvaluator<GradientSumT> evaluator_;
public:
int device_id; int device_id;
EllpackPageImpl const* page; EllpackPageImpl const* page;
common::Span<FeatureType const> feature_types; common::Span<FeatureType const> feature_types;
@ -182,7 +186,6 @@ struct GPUHistMakerDevice {
dh::PinnedMemory pinned; dh::PinnedMemory pinned;
common::Monitor monitor; common::Monitor monitor;
TreeEvaluator tree_evaluator;
common::ColumnSampler column_sampler; common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints; FeatureInteractionConstraintDevice interaction_constraints;
@ -192,24 +195,20 @@ struct GPUHistMakerDevice {
// Storing split categories for last node. // Storing split categories for last node.
dh::caching_device_vector<uint32_t> node_categories; dh::caching_device_vector<uint32_t> node_categories;
GPUHistMakerDevice(int _device_id, GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page,
EllpackPageImpl const* _page, common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
common::Span<FeatureType const> _feature_types, TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
bst_uint _n_rows,
TrainParam _param,
uint32_t column_sampler_seed,
uint32_t n_features,
BatchParam _batch_param) BatchParam _batch_param)
: device_id(_device_id), : evaluator_{_param, n_features, _device_id},
device_id(_device_id),
page(_page), page(_page),
feature_types{_feature_types}, feature_types{_feature_types},
param(std::move(_param)), param(std::move(_param)),
tree_evaluator(param, n_features, _device_id),
column_sampler(column_sampler_seed), column_sampler(column_sampler_seed),
interaction_constraints(param, n_features), interaction_constraints(param, n_features),
batch_param(std::move(_batch_param)) { batch_param(std::move(_batch_param)) {
sampler.reset(new GradientBasedSampler( sampler.reset(new GradientBasedSampler(page, _n_rows, batch_param, param.subsample,
page, _n_rows, batch_param, param.subsample, param.sampling_method)); param.sampling_method));
if (!param.monotone_constraints.empty()) { if (!param.monotone_constraints.empty()) {
// Copy assigning an empty vector causes an exception in MSVC debug builds // Copy assigning an empty vector causes an exception in MSVC debug builds
monotone_constraints = param.monotone_constraints; monotone_constraints = param.monotone_constraints;
@ -219,9 +218,8 @@ struct GPUHistMakerDevice {
// Init histogram // Init histogram
hist.Init(device_id, page->Cuts().TotalBins()); hist.Init(device_id, page->Cuts().TotalBins());
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id));
feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, feature_groups.reset(new FeatureGroups(
dh::MaxSharedMemoryOptin(device_id), page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), sizeof(GradientSumT)));
sizeof(GradientSumT)));
} }
~GPUHistMakerDevice() { // NOLINT ~GPUHistMakerDevice() { // NOLINT
@ -231,13 +229,17 @@ struct GPUHistMakerDevice {
// Reset values for each update iteration // Reset values for each update iteration
// Note that the column sampler must be passed by value because it is not // Note that the column sampler must be passed by value because it is not
// thread safe // thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) { void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns,
ObjInfo task) {
auto const& info = dmat->Info(); auto const& info = dmat->Info();
this->column_sampler.Init(num_columns, info.feature_weights.HostVector(), this->column_sampler.Init(num_columns, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel, param.colsample_bynode, param.colsample_bylevel,
param.colsample_bytree); param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id)); dh::safe_cuda(cudaSetDevice(device_id));
tree_evaluator = TreeEvaluator(param, dmat->Info().num_col_, device_id);
this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param,
device_id);
this->interaction_constraints.Reset(); this->interaction_constraints.Reset();
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{}); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{});
@ -258,10 +260,8 @@ struct GPUHistMakerDevice {
hist.Reset(); hist.Reset();
} }
GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight, ObjInfo task) {
DeviceSplitCandidate EvaluateRootSplit(GradientPairPrecise root_sum) {
int nidx = RegTree::kRoot; int nidx = RegTree::kRoot;
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
GPUTrainingParam gpu_param(param); GPUTrainingParam gpu_param(param);
auto sampled_features = column_sampler.GetFeatureSet(0); auto sampled_features = column_sampler.GetFeatureSet(0);
sampled_features->SetDevice(device_id); sampled_features->SetDevice(device_id);
@ -277,32 +277,23 @@ struct GPUHistMakerDevice {
matrix.gidx_fvalue_map, matrix.gidx_fvalue_map,
matrix.min_fvalue, matrix.min_fvalue,
hist.GetNodeHistogram(nidx)}; hist.GetNodeHistogram(nidx)};
auto gain_calc = tree_evaluator.GetEvaluator<GPUTrainingParam>(); auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight, task);
EvaluateSingleSplit(dh::ToSpan(splits_out), gain_calc, inputs); return split;
std::vector<DeviceSplitCandidate> result(1);
dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
sizeof(DeviceSplitCandidate) * splits_out.size(),
cudaMemcpyDeviceToHost));
return result.front();
} }
void EvaluateLeftRightSplits( void EvaluateLeftRightSplits(GPUExpandEntry candidate, ObjInfo task, int left_nidx,
GPUExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree, int right_nidx, const RegTree& tree,
common::Span<GPUExpandEntry> pinned_candidates_out) { common::Span<GPUExpandEntry> pinned_candidates_out) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2); dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
GPUTrainingParam gpu_param(param); GPUTrainingParam gpu_param(param);
auto left_sampled_features = auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
column_sampler.GetFeatureSet(tree.GetDepth(left_nidx));
left_sampled_features->SetDevice(device_id); left_sampled_features->SetDevice(device_id);
common::Span<bst_feature_t> left_feature_set = common::Span<bst_feature_t> left_feature_set =
interaction_constraints.Query(left_sampled_features->DeviceSpan(), interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx);
left_nidx); auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
auto right_sampled_features =
column_sampler.GetFeatureSet(tree.GetDepth(right_nidx));
right_sampled_features->SetDevice(device_id); right_sampled_features->SetDevice(device_id);
common::Span<bst_feature_t> right_feature_set = common::Span<bst_feature_t> right_feature_set =
interaction_constraints.Query(right_sampled_features->DeviceSpan(), interaction_constraints.Query(right_sampled_features->DeviceSpan(), left_nidx);
left_nidx);
auto matrix = page->GetDeviceAccessor(device_id); auto matrix = page->GetDeviceAccessor(device_id);
EvaluateSplitInputs<GradientSumT> left{left_nidx, EvaluateSplitInputs<GradientSumT> left{left_nidx,
@ -323,29 +314,11 @@ struct GPUHistMakerDevice {
matrix.gidx_fvalue_map, matrix.gidx_fvalue_map,
matrix.min_fvalue, matrix.min_fvalue,
hist.GetNodeHistogram(right_nidx)}; hist.GetNodeHistogram(right_nidx)};
auto d_splits_out = dh::ToSpan(splits_out);
EvaluateSplits(d_splits_out, tree_evaluator.GetEvaluator<GPUTrainingParam>(), left, right);
dh::TemporaryArray<GPUExpandEntry> entries(2); dh::TemporaryArray<GPUExpandEntry> entries(2);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); this->evaluator_.EvaluateSplits(candidate, task, left, right, dh::ToSpan(entries));
auto d_entries = entries.data().get(); dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(),
dh::LaunchN(2, [=] __device__(size_t idx) { sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
auto split = d_splits_out[idx];
auto nidx = idx == 0 ? left_nidx : right_nidx;
float base_weight = evaluator.CalcWeight(
nidx, gpu_param, GradStats{split.left_sum + split.right_sum});
float left_weight =
evaluator.CalcWeight(nidx, gpu_param, GradStats{split.left_sum});
float right_weight = evaluator.CalcWeight(
nidx, gpu_param, GradStats{split.right_sum});
d_entries[idx] =
GPUExpandEntry{nidx, candidate.depth + 1, d_splits_out[idx],
base_weight, left_weight, right_weight};
});
dh::safe_cuda(cudaMemcpyAsync(
pinned_candidates_out.data(), entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
} }
void BuildHist(int nidx) { void BuildHist(int nidx) {
@ -369,12 +342,10 @@ struct GPUHistMakerDevice {
}); });
} }
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
int nidx_subtraction) {
// Make sure histograms are already allocated // Make sure histograms are already allocated
hist.AllocateHistogram(nidx_subtraction); hist.AllocateHistogram(nidx_subtraction);
return hist.HistogramExists(nidx_histogram) && return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
hist.HistogramExists(nidx_parent);
} }
void UpdatePosition(int nidx, RegTree* p_tree) { void UpdatePosition(int nidx, RegTree* p_tree) {
@ -503,13 +474,12 @@ struct GPUHistMakerDevice {
cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));
auto d_position = row_partitioner->GetPosition(); auto d_position = row_partitioner->GetPosition();
auto d_node_sum_gradients = device_node_sum_gradients.data().get(); auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); auto tree_evaluator = evaluator_.GetEvaluator();
dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__( dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(int local_idx) mutable {
int local_idx) mutable {
int pos = d_position[local_idx]; int pos = d_position[local_idx];
bst_float weight = evaluator.CalcWeight( bst_float weight =
pos, param_d, GradStats{d_node_sum_gradients[pos]}); tree_evaluator.CalcWeight(pos, param_d, GradStats{d_node_sum_gradients[pos]});
static_assert(!std::is_const<decltype(out_preds_d)>::value, ""); static_assert(!std::is_const<decltype(out_preds_d)>::value, "");
out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate; out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate;
}); });
@ -562,7 +532,6 @@ struct GPUHistMakerDevice {
void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) { void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree; RegTree& tree = *p_tree;
auto evaluator = tree_evaluator.GetEvaluator();
auto parent_sum = candidate.split.left_sum + candidate.split.right_sum; auto parent_sum = candidate.split.left_sum + candidate.split.right_sum;
auto base_weight = candidate.base_weight; auto base_weight = candidate.base_weight;
auto left_weight = candidate.left_weight * param.learning_rate; auto left_weight = candidate.left_weight * param.learning_rate;
@ -572,48 +541,50 @@ struct GPUHistMakerDevice {
if (is_cat) { if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max()) CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large."; << "Categorical feature value too large.";
if (common::InvalidCat(candidate.split.fvalue)) { std::vector<uint32_t> split_cats;
common::InvalidCategory(); if (candidate.split.split_cats.Bits().empty()) {
if (common::InvalidCat(candidate.split.fvalue)) {
common::InvalidCategory();
}
auto cat = common::AsCat(candidate.split.fvalue);
split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0);
common::CatBitField cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);
} else {
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
auto max_cat = candidate.split.MaxCat();
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
CHECK_LE(split_cats.size(), h_cats.size());
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());
node_categories.resize(candidate.split.split_cats.Bits().size());
dh::safe_cuda(cudaMemcpyAsync(
node_categories.data().get(), candidate.split.split_cats.Data(),
candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice));
} }
auto cat = common::AsCat(candidate.split.fvalue);
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cats_bits(split_cats);
cats_bits.Set(cat);
dh::CopyToD(split_cats, &node_categories);
tree.ExpandCategorical( tree.ExpandCategorical(
candidate.nid, candidate.split.findex, split_cats, candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir,
candidate.split.dir == kLeftDir, base_weight, left_weight, base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
right_weight, candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
} else { } else {
tree.ExpandNode(candidate.nid, candidate.split.findex, tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.fvalue, candidate.split.dir == kLeftDir, candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
candidate.split.right_sum.GetHess());
} }
evaluator_.ApplyTreeSplit(candidate, p_tree);
// Set up child constraints node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
auto left_child = tree[candidate.nid].LeftChild(); node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
auto right_child = tree[candidate.nid].RightChild();
tree_evaluator.AddSplit(candidate.nid, left_child, right_child, interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(),
tree[candidate.nid].SplitIndex(), candidate.left_weight, tree[candidate.nid].LeftChild(),
candidate.right_weight);
node_sum_gradients[tree[candidate.nid].LeftChild()] =
candidate.split.left_sum;
node_sum_gradients[tree[candidate.nid].RightChild()] =
candidate.split.right_sum;
interaction_constraints.Split(
candidate.nid, tree[candidate.nid].SplitIndex(),
tree[candidate.nid].LeftChild(),
tree[candidate.nid].RightChild()); tree[candidate.nid].RightChild());
} }
GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) { GPUExpandEntry InitRoot(RegTree* p_tree, ObjInfo task, dh::AllReducer* reducer) {
constexpr bst_node_t kRootNIdx = 0; constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>( auto gpair_it = dh::MakeTransformIterator<GradientPairPrecise>(
@ -634,39 +605,21 @@ struct GPUHistMakerDevice {
(*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight); (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight);
// Generate first split // Generate first split
auto split = this->EvaluateRootSplit(root_sum); auto root_entry = this->EvaluateRootSplit(root_sum, weight, task);
dh::TemporaryArray<GPUExpandEntry> entries(1);
auto d_entries = entries.data().get();
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
GPUTrainingParam gpu_param(param);
auto depth = p_tree->GetDepth(kRootNIdx);
dh::LaunchN(1, [=] __device__(size_t idx) {
float left_weight = evaluator.CalcWeight(kRootNIdx, gpu_param,
GradStats{split.left_sum});
float right_weight = evaluator.CalcWeight(
kRootNIdx, gpu_param, GradStats{split.right_sum});
d_entries[0] =
GPUExpandEntry(kRootNIdx, depth, split,
weight, left_weight, right_weight);
});
GPUExpandEntry root_entry;
dh::safe_cuda(cudaMemcpyAsync(
&root_entry, entries.data().get(),
sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
return root_entry; return root_entry;
} }
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo task,
RegTree* p_tree, dh::AllReducer* reducer) { RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree; auto& tree = *p_tree;
Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy)); Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
monitor.Start("Reset"); monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task);
monitor.Stop("Reset"); monitor.Stop("Reset");
monitor.Start("InitRoot"); monitor.Start("InitRoot");
driver.Push({ this->InitRoot(p_tree, reducer) }); driver.Push({ this->InitRoot(p_tree, task, reducer) });
monitor.Stop("InitRoot"); monitor.Stop("InitRoot");
auto num_leaves = 1; auto num_leaves = 1;
@ -703,8 +656,7 @@ struct GPUHistMakerDevice {
monitor.Stop("BuildHist"); monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits"); monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx, this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree,
right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2)); new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits"); monitor.Stop("EvaluateSplits");
} else { } else {
@ -819,14 +771,13 @@ class GPUHistMakerSpecialised {
CHECK(*local_tree == reference_tree); CHECK(*local_tree == reference_tree);
} }
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree) {
RegTree* p_tree) {
monitor_.Start("InitData"); monitor_.Start("InitData");
this->InitData(p_fmat); this->InitData(p_fmat);
monitor_.Stop("InitData"); monitor_.Stop("InitData");
gpair->SetDevice(device_); gpair->SetDevice(device_);
maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_); maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_);
} }
bool UpdatePredictionCache(const DMatrix *data, bool UpdatePredictionCache(const DMatrix *data,

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019-2021 by XGBoost Contributors * Copyright 2019-2022 by XGBoost Contributors
*/ */
#pragma once #pragma once
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -235,6 +235,7 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins,
ASSERT_EQ(dmat->Info().feature_types.Size(), 1); ASSERT_EQ(dmat->Info().feature_types.Size(), 1);
auto cuts = sketch(dmat.get(), num_bins); auto cuts = sketch(dmat.get(), num_bins);
ASSERT_EQ(cuts.MaxCategory(), num_categories - 1);
std::sort(x.begin(), x.end()); std::sort(x.begin(), x.end());
auto n_uniques = std::unique(x.begin(), x.end()) - x.begin(); auto n_uniques = std::unique(x.begin(), x.end()) - x.begin();
ASSERT_NE(n_uniques, x.size()); ASSERT_NE(n_uniques, x.size());

View File

@ -1,7 +1,11 @@
/*!
* Copyright 2020-2022 by XGBoost contributors
*/
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../../../../src/tree/gpu_hist/evaluate_splits.cuh" #include "../../../../src/tree/gpu_hist/evaluate_splits.cuh"
#include "../../helpers.h" #include "../../helpers.h"
#include "../../histogram_helpers.h" #include "../../histogram_helpers.h"
#include "../test_evaluate_splits.h" // TestPartitionBasedSplit
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -16,7 +20,6 @@ auto ZeroParam() {
} // anonymous namespace } // anonymous namespace
void TestEvaluateSingleSplit(bool is_categorical) { void TestEvaluateSingleSplit(bool is_categorical) {
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
GradientPairPrecise parent_sum(0.0, 1.0); GradientPairPrecise parent_sum(0.0, 1.0);
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
GPUTrainingParam param{tparam}; GPUTrainingParam param{tparam};
@ -50,11 +53,13 @@ void TestEvaluateSingleSplit(bool is_categorical) {
dh::ToSpan(feature_values), dh::ToSpan(feature_values),
dh::ToSpan(feature_min_values), dh::ToSpan(feature_min_values),
dh::ToSpan(feature_histogram)}; dh::ToSpan(feature_histogram)};
TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>();
EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input);
DeviceSplitCandidate result = out_splits[0]; GPUHistEvaluator<GradientPair> evaluator{
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
dh::device_vector<common::CatBitField::value_type> out_cats;
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.findex, 1);
EXPECT_EQ(result.fvalue, 11.0); EXPECT_EQ(result.fvalue, 11.0);
EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(),
@ -72,7 +77,6 @@ TEST(GpuHist, EvaluateCategoricalSplit) {
} }
TEST(GpuHist, EvaluateSingleSplitMissing) { TEST(GpuHist, EvaluateSingleSplitMissing) {
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
GradientPairPrecise parent_sum(1.0, 1.5); GradientPairPrecise parent_sum(1.0, 1.5);
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
GPUTrainingParam param{tparam}; GPUTrainingParam param{tparam};
@ -96,11 +100,10 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
dh::ToSpan(feature_min_values), dh::ToSpan(feature_min_values),
dh::ToSpan(feature_histogram)}; dh::ToSpan(feature_histogram)};
TreeEvaluator tree_evaluator(tparam, feature_set.size(), 0); GPUHistEvaluator<GradientPair> evaluator(tparam, feature_set.size(), 0);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); DeviceSplitCandidate result =
EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = out_splits[0];
EXPECT_EQ(result.findex, 0); EXPECT_EQ(result.findex, 0);
EXPECT_EQ(result.fvalue, 1.0); EXPECT_EQ(result.fvalue, 1.0);
EXPECT_EQ(result.dir, kRightDir); EXPECT_EQ(result.dir, kRightDir);
@ -109,27 +112,18 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
} }
TEST(GpuHist, EvaluateSingleSplitEmpty) { TEST(GpuHist, EvaluateSingleSplitEmpty) {
DeviceSplitCandidate nonzeroed;
nonzeroed.findex = 1;
nonzeroed.loss_chg = 1.0;
thrust::device_vector<DeviceSplitCandidate> out_split(1);
out_split[0] = nonzeroed;
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
TreeEvaluator tree_evaluator(tparam, 1, 0); GPUHistEvaluator<GradientPair> evaluator(tparam, 1, 0);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); DeviceSplitCandidate result = evaluator
EvaluateSingleSplit(dh::ToSpan(out_split), evaluator, .EvaluateSingleSplit(EvaluateSplitInputs<GradientPair>{}, 0,
EvaluateSplitInputs<GradientPair>{}); ObjInfo{ObjInfo::kRegression})
.split;
DeviceSplitCandidate result = out_split[0];
EXPECT_EQ(result.findex, -1); EXPECT_EQ(result.findex, -1);
EXPECT_LT(result.loss_chg, 0.0f); EXPECT_LT(result.loss_chg, 0.0f);
} }
// Feature 0 has a better split, but the algorithm must select feature 1 // Feature 0 has a better split, but the algorithm must select feature 1
TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
GradientPairPrecise parent_sum(0.0, 1.0); GradientPairPrecise parent_sum(0.0, 1.0);
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
tparam.UpdateAllowUnknown(Args{}); tparam.UpdateAllowUnknown(Args{});
@ -157,11 +151,10 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
dh::ToSpan(feature_min_values), dh::ToSpan(feature_min_values),
dh::ToSpan(feature_histogram)}; dh::ToSpan(feature_histogram)};
TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); GPUHistEvaluator<GradientPair> evaluator(tparam, feature_min_values.size(), 0);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); DeviceSplitCandidate result =
EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = out_splits[0];
EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.findex, 1);
EXPECT_EQ(result.fvalue, 11.0); EXPECT_EQ(result.fvalue, 11.0);
EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5)); EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5));
@ -170,7 +163,6 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
// Features 0 and 1 have identical gain, the algorithm must select 0 // Features 0 and 1 have identical gain, the algorithm must select 0
TEST(GpuHist, EvaluateSingleSplitBreakTies) { TEST(GpuHist, EvaluateSingleSplitBreakTies) {
thrust::device_vector<DeviceSplitCandidate> out_splits(1);
GradientPairPrecise parent_sum(0.0, 1.0); GradientPairPrecise parent_sum(0.0, 1.0);
TrainParam tparam = ZeroParam(); TrainParam tparam = ZeroParam();
tparam.UpdateAllowUnknown(Args{}); tparam.UpdateAllowUnknown(Args{});
@ -198,11 +190,10 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
dh::ToSpan(feature_min_values), dh::ToSpan(feature_min_values),
dh::ToSpan(feature_histogram)}; dh::ToSpan(feature_histogram)};
TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); GPUHistEvaluator<GradientPair> evaluator(tparam, feature_min_values.size(), 0);
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); DeviceSplitCandidate result =
EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
DeviceSplitCandidate result = out_splits[0];
EXPECT_EQ(result.findex, 0); EXPECT_EQ(result.findex, 0);
EXPECT_EQ(result.fvalue, 1.0); EXPECT_EQ(result.fvalue, 1.0);
} }
@ -250,9 +241,10 @@ TEST(GpuHist, EvaluateSplits) {
dh::ToSpan(feature_min_values), dh::ToSpan(feature_min_values),
dh::ToSpan(feature_histogram_right)}; dh::ToSpan(feature_histogram_right)};
TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); GPUHistEvaluator<GradientPair> evaluator{
auto evaluator = tree_evaluator.GetEvaluator<GPUTrainingParam>(); tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
EvaluateSplits(dh::ToSpan(out_splits), evaluator, input_left, input_right); evaluator.EvaluateSplits(input_left, input_right, ObjInfo{ObjInfo::kRegression},
evaluator.GetEvaluator(), dh::ToSpan(out_splits));
DeviceSplitCandidate result_left = out_splits[0]; DeviceSplitCandidate result_left = out_splits[0];
EXPECT_EQ(result_left.findex, 1); EXPECT_EQ(result_left.findex, 1);
@ -262,5 +254,36 @@ TEST(GpuHist, EvaluateSplits) {
EXPECT_EQ(result_right.findex, 0); EXPECT_EQ(result_right.findex, 0);
EXPECT_EQ(result_right.fvalue, 1.0); EXPECT_EQ(result_right.fvalue, 1.0);
} }
TEST_F(TestPartitionBasedSplit, GpuHist) {
dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}};
GPUHistEvaluator<GradientPairPrecise> evaluator{param_,
static_cast<bst_feature_t>(info_.num_col_), 0};
cuts_.cut_ptrs_.SetDevice(0);
cuts_.cut_values_.SetDevice(0);
cuts_.min_vals_.SetDevice(0);
ObjInfo task{ObjInfo::kRegression};
evaluator.Reset(cuts_, dh::ToSpan(ft), task, info_.num_col_, param_, 0);
dh::device_vector<GradientPairPrecise> d_hist(hist_[0].size());
auto node_hist = hist_[0];
dh::safe_cuda(cudaMemcpy(d_hist.data().get(), node_hist.data(), node_hist.size_bytes(),
cudaMemcpyHostToDevice));
dh::device_vector<bst_feature_t> feature_set{std::vector<bst_feature_t>{0}};
EvaluateSplitInputs<GradientPairPrecise> input{0,
total_gpair_,
GPUTrainingParam{param_},
dh::ToSpan(feature_set),
dh::ToSpan(ft),
cuts_.cut_ptrs_.ConstDeviceSpan(),
cuts_.cut_values_.ConstDeviceSpan(),
cuts_.min_vals_.ConstDeviceSpan(),
dh::ToSpan(d_hist)};
auto split = evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split;
ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
}
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -3,9 +3,11 @@
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include "../../../../src/common/hist_util.h"
#include "../../../../src/tree/hist/evaluate_splits.h" #include "../../../../src/tree/hist/evaluate_splits.h"
#include "../../../../src/tree/updater_quantile_hist.h" #include "../../../../src/tree/updater_quantile_hist.h"
#include "../../../../src/common/hist_util.h" #include "../test_evaluate_splits.h"
#include "../../helpers.h" #include "../../helpers.h"
namespace xgboost { namespace xgboost {
@ -108,80 +110,17 @@ TEST(HistEvaluator, Apply) {
ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f); ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
} }
TEST(HistEvaluator, CategoricalPartition) { TEST_F(TestPartitionBasedSplit, CPUHist) {
int static constexpr kRows = 128, kCols = 1; // check the evaluator is returning the optimal split
using GradientSumT = double; std::vector<FeatureType> ft{FeatureType::kCategorical};
std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
TrainParam param;
param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
size_t n_cats{8};
auto dmat =
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
int32_t n_threads = 16;
auto sampler = std::make_shared<common::ColumnSampler>(); auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{ HistEvaluator<double, CPUExpandEntry> evaluator{param_, info_, common::OmpGetNumThreads(0),
param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}}; sampler, ObjInfo{ObjInfo::kRegression}};
evaluator.InitRoot(GradStats{total_gpair_});
for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) { RegTree tree;
common::HistCollection<GradientSumT> hist; std::vector<CPUExpandEntry> entries(1);
evaluator.EvaluateSplits(hist_, cuts_, {ft}, tree, &entries);
std::vector<CPUExpandEntry> entries(1); ASSERT_NEAR(entries[0].split.loss_chg, best_score_, 1e-16);
entries.front().nid = 0;
entries.front().depth = 0;
hist.Init(gmat.cut.TotalBins());
hist.AddHistRow(0);
hist.AllocateAllData();
auto node_hist = hist[0];
ASSERT_EQ(node_hist.size(), n_cats);
ASSERT_EQ(node_hist.size(), gmat.cut.Ptrs().back());
GradientPairPrecise total_gpair;
for (size_t i = 0; i < node_hist.size(); ++i) {
node_hist[i] = {static_cast<double>(node_hist.size() - i), 1.0};
total_gpair += node_hist[i];
}
SimpleLCG lcg;
std::shuffle(node_hist.begin(), node_hist.end(), lcg);
RegTree tree;
evaluator.InitRoot(GradStats{total_gpair});
evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries);
ASSERT_TRUE(entries.front().split.is_cat);
auto run_eval = [&](auto fn) {
for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
GradStats left, right;
for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) {
auto loss_chg = evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
evaluator.Stats().front().root_gain;
fn(loss_chg);
left.Add(node_hist[j].GetGrad(), node_hist[j].GetHess());
right.SetSubstract(GradStats{total_gpair}, left);
}
}
};
// Assert that's the best split
auto best_loss_chg = entries.front().split.loss_chg;
run_eval([&](auto loss_chg) {
// Approximated test that gain returned by optimal partition is greater than
// numerical split.
ASSERT_GT(best_loss_chg, loss_chg);
});
// node_hist is captured in lambda.
std::sort(node_hist.begin(), node_hist.end(), [&](auto l, auto r) {
return evaluator.Evaluator().CalcWeightCat(param, l) <
evaluator.Evaluator().CalcWeightCat(param, r);
});
double reimpl = 0;
run_eval([&](auto loss_chg) { reimpl = std::max(loss_chg, reimpl); });
CHECK_EQ(reimpl, best_loss_chg);
}
} }
namespace { namespace {

View File

@ -0,0 +1,96 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <algorithm> // next_permutation
#include <numeric> // iota
#include "../../../src/tree/hist/evaluate_splits.h"
#include "../helpers.h"
namespace xgboost {
namespace tree {
/**
* \brief Enumerate all possible partitions for a categorical split.
*/
class TestPartitionBasedSplit : public ::testing::Test {
protected:
size_t n_bins_ = 6;
std::vector<size_t> sorted_idx_;
TrainParam param_;
MetaInfo info_;
float best_score_{-std::numeric_limits<float>::infinity()};
common::HistogramCuts cuts_;
common::HistCollection<double> hist_;
GradientPairPrecise total_gpair_;
void SetUp() override {
param_.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
sorted_idx_.resize(n_bins_);
std::iota(sorted_idx_.begin(), sorted_idx_.end(), 0);
info_.num_col_ = 1;
cuts_.cut_ptrs_.Resize(2);
cuts_.SetCategorical(true, n_bins_);
auto &h_cuts = cuts_.cut_ptrs_.HostVector();
h_cuts[0] = 0;
h_cuts[1] = n_bins_;
auto &h_vals = cuts_.cut_values_.HostVector();
h_vals.resize(n_bins_);
std::iota(h_vals.begin(), h_vals.end(), 0.0);
hist_.Init(cuts_.TotalBins());
hist_.AddHistRow(0);
hist_.AllocateAllData();
auto node_hist = hist_[0];
SimpleLCG lcg;
SimpleRealUniformDistribution<double> grad_dist{-4.0, 4.0};
SimpleRealUniformDistribution<double> hess_dist{0.0, 4.0};
for (auto &e : node_hist) {
e = GradientPairPrecise{grad_dist(&lcg), hess_dist(&lcg)};
total_gpair_ += e;
}
auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow<double> hist,
GradientPairPrecise parent_sum) {
int32_t best_thresh = -1;
float best_score{-std::numeric_limits<float>::infinity()};
TreeEvaluator evaluator{param_, static_cast<bst_feature_t>(n_feat), -1};
auto tree_evaluator = evaluator.GetEvaluator<TrainParam>();
GradientPairPrecise left_sum;
auto parent_gain = tree_evaluator.CalcGain(0, param_, GradStats{total_gpair_});
for (size_t i = 0; i < hist.size() - 1; ++i) {
left_sum += hist[i];
auto right_sum = parent_sum - left_sum;
auto gain =
tree_evaluator.CalcSplitGain(param_, 0, 0, GradStats{left_sum}, GradStats{right_sum}) -
parent_gain;
if (gain > best_score) {
best_score = gain;
best_thresh = i;
}
}
return std::make_tuple(best_thresh, best_score);
};
// enumerate all possible partitions to find the optimal split
do {
int32_t thresh;
float score;
std::vector<GradientPairPrecise> sorted_hist(node_hist.size());
for (size_t i = 0; i < sorted_hist.size(); ++i) {
sorted_hist[i] = node_hist[sorted_idx_[i]];
}
std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_);
if (score > best_score_) {
best_score_ = score;
}
} while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
}
};
} // namespace tree
} // namespace xgboost

View File

@ -262,7 +262,8 @@ TEST(GpuHist, EvaluateRootSplit) {
info.num_row_ = kNRows; info.num_row_ = kNRows;
info.num_col_ = kNCols; info.num_col_ = kNCols;
DeviceSplitCandidate res = maker.EvaluateRootSplit({6.4f, 12.8f}); DeviceSplitCandidate res =
maker.EvaluateRootSplit({6.4f, 12.8f}, 0, ObjInfo{ObjInfo::kRegression}).split;
ASSERT_EQ(res.findex, 7); ASSERT_EQ(res.findex, 7);
ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps); ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
@ -300,11 +301,11 @@ void TestHistogramIndexImpl() {
const auto &maker = hist_maker.maker; const auto &maker = hist_maker.maker;
auto grad = GenerateRandomGradients(kNRows); auto grad = GenerateRandomGradients(kNRows);
grad.SetDevice(0); grad.SetDevice(0);
maker->Reset(&grad, hist_maker_dmat.get(), kNCols); maker->Reset(&grad, hist_maker_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.HostVector()); std::vector<common::CompressedByteT> h_gidx_buffer(maker->page->gidx_buffer.HostVector());
const auto &maker_ext = hist_maker_ext.maker; const auto &maker_ext = hist_maker_ext.maker;
maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols); maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.HostVector()); std::vector<common::CompressedByteT> h_gidx_buffer_ext(maker_ext->page->gidx_buffer.HostVector());
ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins()); ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());

View File

@ -211,6 +211,34 @@ class TestTreeMethod:
) )
assert tm.non_increasing(by_builtin_results["Train"]["rmse"]) assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
parameters["max_cat_to_onehot"] = 1
parameters["reg_lambda"] = 0
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
parameters,
m,
num_boost_round=rounds,
evals=[(m, "Train")],
evals_result=by_grouping,
)
rmse_oh = by_builtin_results["Train"]["rmse"]
rmse_group = by_grouping["Train"]["rmse"]
# Always better than or equal to one-hot when there's no regularization.
for a, b in zip(rmse_oh, rmse_group):
assert a >= b
parameters["reg_lambda"] = 1.0
by_grouping = {}
xgb.train(
parameters,
m,
num_boost_round=32,
evals=[(m, "Train")],
evals_result=by_grouping,
)
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
@given(strategies.integers(10, 400), strategies.integers(3, 8), @given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7)) strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None) @settings(deadline=None)